1 回答

TA贡献1863条经验 获得超2个赞
事实证明,tf.data.TFRecordDataset有一个其他函数被调用padded_batch,它基本上是在做这件事tf.train.batch(dynamic_pad=True)。这很容易解决问题......
dataset = tf.data.TFRecordDataset(filename)
dataset = dataset.map(decode)
dataset = dataset.map(augment)
dataset = dataset.map(normalize)
dataset = dataset.shuffle(1000+3*batch_size)
dataset = dataset.repeat(num_epochs)
dataset = dataset.padded_batch(batch_size,
drop_remainder=False,
padded_shapes=([None, None, None],
[None, 4],
[None, 1])
)
添加回答
举报