1 回答
TA贡献1824条经验 获得超5个赞
似乎您无法从映射函数(1、2)内部调用 .numpy() 函数,尽管我能够使用来自(doc)的 py_function 进行管理。
在下面的示例中,我已将我解析的数据集映射到一个函数,该函数将我的图像转换为np.uint8
以便使用 matplotlib绘制它们。
records_path = data_directory+'TFRecords'+'/data_0.tfrecord'
# Create a dataset
dataset = tf.data.TFRecordDataset(filenames=records_path)
# Map our dataset to the parsing function
parsed_dataset = dataset.map(parsing_fn)
converted_dataset = parsed_dataset.map(lambda image,label:
tf.py_function(func=converting_function,
inp=[image,label],
Tout=[np.uint8,tf.int64]))
# Gets the iterator
iterator = tf.compat.v1.data.make_one_shot_iterator(converted_dataset)
for i in range(5):
image,label = iterator.get_next()
plt.imshow(image)
plt.show()
print('label: ', label)
输出:
解析函数:
def parsing_fn(serialized):
# Define a dict with the data-names and types we expect to
# find in the TFRecords file.
features = \
{
'image': tf.io.FixedLenFeature([], tf.string),
'label': tf.io.FixedLenFeature([], tf.int64)
}
# Parse the serialized data so we get a dict with our data.
parsed_example = tf.io.parse_single_example(serialized=serialized,
features=features)
# Get the image as raw bytes.
image_raw = parsed_example['image']
# Decode the raw bytes so it becomes a tensor with type.
image = tf.io.decode_jpeg(image_raw)
# Get the label associated with the image.
label = parsed_example['label']
# The image and label are now correct TensorFlow types.
return image, label
更新:实际上并没有签出,但 tf.shape() 似乎也是一个有前途的选择。
添加回答
举报