为了账号安全,请及时绑定邮箱和手机立即绑定

在 Tensorflow 数据集 API 中拆分数据集问题

在 Tensorflow 数据集 API 中拆分数据集问题

慕侠2389804 2021-09-28 21:00:31
我正在读取一个tf.contrib.data.make_csv_dataset用于形成数据集的 csv 文件,然后我使用该命令take()来形成另一个只有一个元素的数据集,但它仍然返回所有元素。这里有什么问题?我带来了下面的代码:import tensorflow as tfimport ostf.enable_eager_execution()# Constantscolumn_names = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species']class_names = ['Iris setosa', 'Iris versicolor', 'Iris virginica']batch_size   = 1feature_names = column_names[:-1]label_name = column_names[-1]# to reorient data strucutedef pack_features_vector(features, labels):  """Pack the features into a single array."""  features = tf.stack(list(features.values()), axis=1)  return features, labels# Download the filetrain_dataset_url = "http://download.tensorflow.org/data/iris_training.csv"train_dataset_fp = tf.keras.utils.get_file(fname=os.path.basename(train_dataset_url),                                       origin=train_dataset_url)# form the datasettrain_dataset = tf.contrib.data.make_csv_dataset(train_dataset_fp,batch_size, column_names=column_names,label_name=label_name,num_epochs=1)# perform the mappingtrain_dataset = train_dataset.map(pack_features_vector)# construct a databse with one element train_dataset= train_dataset.take(1)# inspect elementsfor step in range(10):    features, labels = next(iter(train_dataset))    print(list(features))
查看完整描述

1 回答

?
守着一只汪

TA贡献1872条经验 获得超3个赞

基于这个答案,我们可以用Dataset.take()和分割数据集Dataset.skip():


train_size = int(0.7 * DATASET_SIZE)


train_dataset = full_dataset.take(train_size)

test_dataset = full_dataset.skip(train_size)

如何修复你的代码?


不要在循环中多次创建迭代器,而是使用一个迭代器:


# inspect elements

for feature, label in train_dataset:

    print(feature)

在您的代码中发生了什么导致这种行为?


1) 内置pythoniter函数从对象获取迭代器或对象本身必须提供自己的迭代器。所以当你调用的时候iter(train_dataset),就相当于调用了Dataset.make_one_shot_iterator()。


2) 默认情况下,tf.contrib.data.make_csv_dataset()shuffle 中的参数为 True ( shuffle=True)。因此,每次调用iter(train_dataset)它时都会创建包含不同数据的新迭代器。


3)最后,当循环通过for step in range(10)它时,类似于创建10个不同的迭代器,大小为1,每个迭代器都有自己的数据,因为它们被打乱了。


建议:如果你想避免这样的事情在循环外初始化(创建)迭代器:


train_dataset = train_dataset.take(1)

iterator = train_dataset.make_one_shot_iterator()

# inspect elements

for step in range(10):

    features, labels = next(iterator)

    print(list(features))

    # throws exception because size of iterator is 1


查看完整回答
反对 回复 2021-09-28
  • 1 回答
  • 0 关注
  • 281 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信