为了账号安全,请及时绑定邮箱和手机立即绑定

在 Tensorflow 数据集管道中返回不同长度的数组

在 Tensorflow 数据集管道中返回不同长度的数组

慕虎7371278 2023-08-22 10:42:35
我正在 python 中使用 Tensorflow 进行对象检测。我想使用张量流输入管道来加载批量输入数据。问题是图像中的对象数量是可变的。想象一下我想做以下事情。注释是图像文件名及其包含的边界框的数组。标签被排除在外。每个边界框由四个数字表示。import tensorflow as tf@tf.function()def prepare_sample(annotation):    annotation_parts = tf.strings.split(annotation, sep=' ')    image_file_name = annotation_parts[0]    image_file_path = tf.strings.join(["/images/", image_file_name])    depth_image = tf.io.read_file(image_file_path)    bboxes = tf.reshape(annotation_parts[1:], shape=[-1,4])    return depth_image, bboxesannotations = ['image1.png 1 2 3 4', 'image2.png 1 2 3 4 5 6 7 8']dataset = tf.data.Dataset.from_tensor_slices(annotations)dataset = dataset.shuffle(len(annotations))dataset = dataset.map(prepare_sample)dataset = dataset.batch(16)for image, bboxes in dataset:  pass在上面的示例中,image1 包含单个对象,而 image2 包含两个对象。我收到以下错误:InvalidArgumentError:无法将张量添加到批次:元素数量不匹配。形状为:[张量]:[1,4],[批次]:[2,4]这就说得通了。我正在寻找从映射函数返回不同长度数组的方法。我能做些什么?谢谢你!编辑:我想我找到了解决方案;我不再收到错误。我dataset.batch(16)改为dataset.padded_batch(16).
查看完整描述

1 回答

?
绝地无双

TA贡献1946条经验 获得超4个赞


dataset.batch(16)更改为 后该错误将得到解决dataset.padded_batch(16)。


下面是相同的修改后的代码。


import tensorflow as tf


@tf.function()

def prepare_sample(annotation):

    annotation_parts = tf.strings.split(annotation, sep=' ')

    image_file_name = annotation_parts[0]

    image_file_path = tf.strings.join(["/images/", image_file_name])

    depth_image = tf.io.read_file(image_file_path)

    bboxes = tf.reshape(annotation_parts[1:], shape=[-1,4])

    return depth_image, bboxes


annotations = ['image1.png 1 2 3 4', 'image2.png 1 2 3 4 5 6 7 8']

dataset = tf.data.Dataset.from_tensor_slices(annotations)

dataset = dataset.shuffle(len(annotations))

dataset = dataset.map(prepare_sample)

dataset = dataset.padded_batch(16)


for image, bboxes in dataset:

  pass



查看完整回答
反对 回复 2023-08-22
  • 1 回答
  • 0 关注
  • 4418 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信