我确实知道在将数据加载到我的网络之前,我可以通过它们的标签将数据分开。假设有 3 个类,标签为 0、1、2。我可以通过以下方式做到:dataset1 = tf.data.TextLineDataset(train_csv_file1).map(_parse_csv_train)dataset2 = tf.data.TextLineDataset(train_csv_file2).map(_parse_csv_train)dataset3 = tf.data.TextLineDataset(train_csv_file3).map(_parse_csv_train)我只是对以下内容感到好奇:假设我们有数据集:dataset = tf.data.TextLineDataset(train_csv_file).map(_parse_csv_train)其中包含来自 3 个类的所有数据,有没有办法调用像 dataset.selectDataByLabel(label=="2")[这是一个虚构的函数]这样的函数,以便我可以根据它们的标签将数据集分成3部分?
1 回答

墨色风雨
TA贡献1853条经验 获得超6个赞
所以最后我选择了用csvs分隔文件,即生成每个只包含一个类的数据的csvs。当类太多时,这可能不是一个完美的解决方案,但在我的情况下只有 5 个类,所以没关系。
添加回答
举报
0/150
提交
取消