为了账号安全,请及时绑定邮箱和手机立即绑定

在 TensorFlow 2.x 中看似不连续地打乱后的批处理元素

在 TensorFlow 2.x 中看似不连续地打乱后的批处理元素

四季花海 2024-01-04 16:28:11
我有以下简单的例子:import tensorflow as tftensor1 = tf.constant(value = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])tensor2 = tf.constant(value = [20, 21, 22, 23])print(tensor1.shape)print(tensor2.shape)dataset = tf.data.Dataset.from_tensor_slices((tensor1, tensor2))print('Original dataset')for i in dataset:      print(i)dataset = dataset.repeat(3)print('Repeated dataset')for i in dataset:      print(i)如果我然后将其批处理dataset为:dataset = dataset.batch(3)print('Batched dataset')for i in dataset:   print(i)正如预期的那样,我收到:Batched dataset(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=array([[1, 2, 3],       [4, 5, 6],       [7, 8, 9]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([20, 21, 22], dtype=int32)>)(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=array([[10, 11, 12],       [ 1,  2,  3],       [ 4,  5,  6]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([23, 20, 21], dtype=int32)>)(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=array([[ 7,  8,  9],       [10, 11, 12],       [ 1,  2,  3]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([22, 23, 20], dtype=int32)>)(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=array([[ 4,  5,  6],       [ 7,  8,  9],       [10, 11, 12]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([21, 22, 23], dtype=int32)>)批处理数据集采用连续的元素。但是,当我先进行混音,然后进行批处理时:dataset = dataset.shuffle(3)print('Shuffled dataset')for i in dataset:  print(i)dataset = dataset.batch(3)print('Batched dataset')for i in dataset:   print(i)我正在使用 Google Colab 和TensorFlow 2.x.我的问题是:为什么在批处理之前进行洗牌会导致batch返回非连续元素?感谢您的任何答复。
查看完整描述

1 回答

?
12345678_0001

TA贡献1802条经验 获得超5个赞

这就是洗牌的作用。你是这样开始的:

[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]

您已指定,buffer_size=3因此它会创建前 3 个元素的缓冲区:

[[1, 2, 3], [4, 5, 6], [7, 8, 9]]

您指定了batch_size=3,因此它将从此样本中随机选择一个元素,并将其替换为初始缓冲区之外的第一个元素。假设[1, 2, 3]已被选中,您的批次现在是:

[[1, 2, 3]]

现在你的缓冲区是:

[[10, 11, 12], [4, 5, 6], [7, 8, 9]]

对于 的第二个元素batch=3,它将从此缓冲区中随机选择。假设[7, 8, 9]已挑选,您的批次现在是:

[[1, 2, 3], [7, 8, 9]]

现在你的缓冲区是:

[[10, 11, 12], [4, 5, 6]]

没有什么新内容可以填充缓冲区,因此它将随机选择这些元素之一,例如[10, 11, 12]。您的批次现在是:

[[1, 2, 3], [7, 8, 9], [10, 11, 12]]

下一批将只是[4, 5, 6]因为默认情况下, batch(drop_remainder=False).


查看完整回答
反对 回复 2024-01-04
  • 1 回答
  • 0 关注
  • 116 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信