我正在使用 tf.data 从大型文本语料库中迭代批处理。我只想将函数应用于数据子集(或批处理子集),而不是一个一个元素。具体来说,我的数据迭代器产生 query, reply批处理。它们都是正对,所以我只想洗牌下一批的子集(在这种情况下,只有“回复”批次)以生成随机负数。例如,输入:query1 reply1query2 reply2query3 reply3...输出:正对:(query1 reply1与输入相同)负对:(query1 replyN回复随机洗牌)当然,可以使用 python 对文本进行混洗,但我想使用 tf.data 使其高效,因为数据大小太大。
1 回答

千万里不及你
TA贡献1784条经验 获得超9个赞
假设你有queries和replies作为两个张量。您需要的是我认为类似于下面的内容,您可以将其与原始批次连接起来。
batch_size = 10
def reply_shuffle(queries, replies):
shuffled_indices = tf.random_uniform(minval=0, maxval=batch_size+1, shape=[batch_size], dtype=tf.int32)
shuffled_replies = tf.gather_nd(replies, shuffled_indices)
return queries, shuffled_replies
添加回答
举报
0/150
提交
取消