对于 TensorFlow 数据集迭代器(tf.data.Iterator),跳过前 X 个批次的最佳方法是什么?要求只在第一次迭代时跳过,而在指定了 repeat() 之后的后续迭代中不跳过。我尝试了以下代码,但没有奏效:

```python
import tensorflow as tf
import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.master('local[*]').config("spark.jars",'some/path/spark-tensorflow-connector_2.11-1.10.0.jar').getOrCreate()

df = pd.DataFrame({'x': range(10), 'y': [i*2 for i in range(10)]})
df = spark.createDataFrame(df)
df.write.format('tfrecords').option('recordType', 'Example').mode("overwrite").save('testdata')

def parse_function(proto):
    feature_description = {
        'x': tf.FixedLenFeature([], tf.int64),
        'y': tf.FixedLenFeature([], tf.int64)
    }
    parsed_features = tf.parse_single_example(proto, feature_description)
    x = parsed_features['x']
    y = parsed_features['y']
    return {'x': x, 'y': y}

def load_data(filename_pattern, parse_function, batch_size=200, skip_batches=0):
    files = tf.data.Dataset.list_files(file_pattern=filename_pattern, shuffle=False)
    dataset = tf.data.TFRecordDataset(files)
    dataset = dataset.repeat()
    dataset = dataset.map(parse_function)
    dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
    dataset = dataset.prefetch(2)

    # Create an iterator
    iterator = dataset.make_one_shot_iterator()
    data = iterator.get_next()
    with tf.Session() as sess:
        for i in range(skip_batches):
            sess.run(data)
    return data

# skip first three batches
data = load_data('testdata/part-*', parse_function, batch_size=2, skip_batches=3)
sess = tf.Session()
for i in range(3):
    print(sess.run(data))
```

预期输出:

```
{'y': array([12, 14]), 'x': array([6, 7])}
{'y': array([16, 18]), 'x': array([8, 9])}
{'y': array([0, 2]), 'x': array([0, 1])}
```

实际输出:

```
{'y': array([0, 2]), 'x': array([0, 1])}
{'y': array([4, 6]), 'x': array([2, 3])}
{'y': array([8, 10]), 'x': array([4, 5])}
```
添加回答
举报
0/150
提交
取消