2 回答
TA贡献1872条经验 获得超3个赞
类似下面这样的做法应该可行(未经测试):
# Serialize the data into two tfrecord files
tf.enable_eager_execution()
feature_extractor = ...

# The context managers guarantee both writers are flushed and closed even if
# feature extraction raises part-way through; an unclosed writer can leave
# the tail of the file unwritten.
with tf.python_io.TFRecordWriter('features.tfrec') as features_file, \
        tf.python_io.TFRecordWriter('labels.tfrec') as label_file:
    for images, labels in dataset:
        features = feature_extractor(images)
        # In eager mode serialize_tensor returns a scalar string tensor;
        # TFRecordWriter.write expects raw bytes, hence .numpy().
        features_file.write(tf.serialize_tensor(features).numpy())
        label_file.write(tf.serialize_tensor(labels).numpy())
# Parse the files and zip them together
def parse(type, shape):
    """Build a ``Dataset.map`` function that decodes one serialized tensor.

    Args:
        type: dtype the tensor was serialized with (e.g. ``tf.float32``).
            The name shadows the builtin but is kept so existing keyword
            callers keep working.
        shape: static shape to restore on the decoded tensor.

    Returns:
        A callable taking one serialized record and returning the decoded
        tensor reshaped to ``shape``.
    """
    def _parse(x):
        # out_type must be the dtype, not the shape.
        result = tf.parse_tensor(x, out_type=type)
        # parse_tensor yields an unknown static shape; restore the one the
        # caller asked for rather than a hard-coded global.
        result = tf.reshape(result, shape)
        return result
    return _parse
# Placeholders, in the same spirit as `feature_extractor = ...`: fill these
# in before running.
LABEL_SHAPE = ...  # shape of one serialized label batch
SHUFFLE_BUFFER_SIZE = ...
BATCH_SIZE = ...

features_ds = tf.data.TFRecordDataset('features.tfrec')
features_ds = features_ds.map(
    parse(tf.float32, FEATURE_SHAPE), num_parallel_calls=AUTOTUNE)

labels_ds = tf.data.TFRecordDataset('labels.tfrec')
# Labels have their own shape; reshaping them to FEATURE_SHAPE would fail.
# NOTE(review): tf.float32 is kept from the original; confirm it matches the
# dtype of the labels that were serialized.
labels_ds = labels_ds.map(
    parse(tf.float32, LABEL_SHAPE), num_parallel_calls=AUTOTUNE)

# Dataset.zip takes ONE nested structure of datasets, so pass a tuple.
ds = tf.data.Dataset.zip((features_ds, labels_ds))

# Records were written one per batch: unbatch to single examples, then
# reshuffle and rebatch. shuffle/batch/prefetch all require a size argument.
ds = (
    ds.unbatch()
    .shuffle(SHUFFLE_BUFFER_SIZE)
    .repeat()
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)
您也可以使用 Dataset.cache 来完成它,但我不是 100% 确定细节。
添加回答
举报