为了账号安全,请及时绑定邮箱和手机立即绑定

如何指定 tf.Data.iterator 的起点(或跳过前 X 个批次)?

如何指定 tf.Data.iterator 的起点(或跳过前 X 个批次)?

一只名叫tom的猫 2022-05-24 10:40:34
对于张量流数据集迭代器(tf.data.Iterator),跳过前 X 个批次的最佳方法是什么,但仅在第一次迭代中,而不是在指定 repeat() 时的后续迭代)?我尝试了以下但没有奏效:import tensorflow as tfimport pandas as pdfrom pyspark.sql import SparkSessionspark = SparkSession.builder.master('local[*]').config("spark.jars",'some/path/spark-tensorflow-connector_2.11-1.10.0.jar').getOrCreate()df = pd.DataFrame({'x': range(10), 'y': [i*2 for i in range(10)]})df = spark.createDataFrame(df)df.write.format('tfrecords').option('recordType', 'Example').mode("overwrite").save('testdata')def parse_function(proto):    feature_description = {    'x': tf.FixedLenFeature([], tf.int64),    'y': tf.FixedLenFeature([], tf.int64)    }    parsed_features = tf.parse_single_example(proto, feature_description)    x = parsed_features['x']    y = parsed_features['y']    return {'x': x, 'y': y}def load_data(filename_pattern, parse_function, batch_size=200, skip_batches=0):    files = tf.data.Dataset.list_files(file_pattern=filename_pattern, shuffle=False)    dataset = tf.data.TFRecordDataset(files)    dataset = dataset.repeat()    dataset = dataset.map(parse_function)    dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))    dataset = dataset.prefetch(2)    # Create an iterator    iterator = dataset.make_one_shot_iterator()    data = iterator.get_next()    with tf.Session() as sess:        for i in range(skip_batches):            sess.run(data)    return data# skip first three batchesdata = load_data('testdata/part-*', parse_function, batch_size=2, skip_batches=3)sess = tf.Session()for i in range(3):    print(sess.run(data))预期/期望:    {'y': array([12, 14]), 'x': array([6, 7])}    {'y': array([16, 18]), 'x': array([8, 9])}    {'y': array([0, 2]), 'x': array([0, 1])}实际的:    {'y': array([0, 2]), 'x': array([0, 1])}    {'y': array([4, 6]), 'x': array([2, 3])}    {'y': array([8, 10]), 'x': array([4, 5])}
查看完整描述

1 回答

?
摇曳的蔷薇

TA贡献1793条经验 获得超6个赞

tf.Dataset.iterator()你为什么不跳过前 X 批,而不是通过?

假设您想要 10 个批次,每个批次有 32 个元素,这意味着总共 320 个元素。因此,您可以使用tf.Dataset.skip(320)skip ) 跳过这些,它会为您提供跳过前 10 个批次的数据集。


查看完整回答
反对 回复 2022-05-24
  • 1 回答
  • 0 关注
  • 121 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信