1 回答
TA贡献1872条经验 获得超3个赞
基于这个答案,我们可以用Dataset.take()和分割数据集Dataset.skip():
train_size = int(0.7 * DATASET_SIZE)
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
如何修复你的代码?
不要在循环中多次创建迭代器,而是使用一个迭代器:
# inspect elements
for feature, label in train_dataset:
print(feature)
在您的代码中发生了什么导致这种行为?
1) 内置pythoniter函数从对象获取迭代器或对象本身必须提供自己的迭代器。所以当你调用的时候iter(train_dataset),就相当于调用了Dataset.make_one_shot_iterator()。
2) 默认情况下,tf.contrib.data.make_csv_dataset()shuffle 中的参数为 True ( shuffle=True)。因此,每次调用iter(train_dataset)它时都会创建包含不同数据的新迭代器。
3)最后,当循环通过for step in range(10)它时,类似于创建10个不同的迭代器,大小为1,每个迭代器都有自己的数据,因为它们被打乱了。
建议:如果你想避免这样的事情在循环外初始化(创建)迭代器:
train_dataset = train_dataset.take(1)
iterator = train_dataset.make_one_shot_iterator()
# inspect elements
for step in range(10):
features, labels = next(iterator)
print(list(features))
# throws exception because size of iterator is 1
添加回答
举报