我有以下简单的例子:import tensorflow as tftensor1 = tf.constant(value = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])tensor2 = tf.constant(value = [20, 21, 22, 23])print(tensor1.shape)print(tensor2.shape)dataset = tf.data.Dataset.from_tensor_slices((tensor1, tensor2))print('Original dataset')for i in dataset: print(i)dataset = dataset.repeat(3)print('Repeated dataset')for i in dataset: print(i)如果我然后将其批处理dataset为:dataset = dataset.batch(3)print('Batched dataset')for i in dataset: print(i)正如预期的那样,我收到:Batched dataset(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([20, 21, 22], dtype=int32)>)(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=array([[10, 11, 12], [ 1, 2, 3], [ 4, 5, 6]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([23, 20, 21], dtype=int32)>)(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=array([[ 7, 8, 9], [10, 11, 12], [ 1, 2, 3]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([22, 23, 20], dtype=int32)>)(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=array([[ 4, 5, 6], [ 7, 8, 9], [10, 11, 12]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([21, 22, 23], dtype=int32)>)批处理数据集采用连续的元素。但是,当我先进行混音,然后进行批处理时:dataset = dataset.shuffle(3)print('Shuffled dataset')for i in dataset: print(i)dataset = dataset.batch(3)print('Batched dataset')for i in dataset: print(i)我正在使用 Google Colab 和TensorFlow 2.x.我的问题是:为什么在批处理之前进行洗牌会导致batch返回非连续元素?感谢您的任何答复。
1 回答
12345678_0001
TA贡献1802条经验 获得超5个赞
这就是洗牌的作用。你是这样开始的:
[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
您已指定,buffer_size=3
因此它会创建前 3 个元素的缓冲区:
[[1, 2, 3], [4, 5, 6], [7, 8, 9]]
您指定了batch_size=3
,因此它将从此样本中随机选择一个元素,并将其替换为初始缓冲区之外的第一个元素。假设[1, 2, 3]
已被选中,您的批次现在是:
[[1, 2, 3]]
现在你的缓冲区是:
[[10, 11, 12], [4, 5, 6], [7, 8, 9]]
对于 的第二个元素batch=3
,它将从此缓冲区中随机选择。假设[7, 8, 9]
已挑选,您的批次现在是:
[[1, 2, 3], [7, 8, 9]]
现在你的缓冲区是:
[[10, 11, 12], [4, 5, 6]]
没有什么新内容可以填充缓冲区,因此它将随机选择这些元素之一,例如[10, 11, 12]
。您的批次现在是:
[[1, 2, 3], [7, 8, 9], [10, 11, 12]]
下一批将只是[4, 5, 6]
因为默认情况下, batch(drop_remainder=False)
.
添加回答
举报
0/150
提交
取消