为了账号安全,请及时绑定邮箱和手机立即绑定

如何索引具有形状 (batch_size, 200, 256) 的张量以获得

如何索引具有形状 (batch_size, 200, 256) 的张量以获得

HUX布斯 2022-06-22 18:27:00
我有形状为 (batch_size, 200, 256) 的 LSTM 层的输出,其中 200 是标记序列的长度,256 是 LSTM 输出维度。我还有另一个形状为 (batch_size) 的张量,它是我想从批次中的每个样本序列中切出的标记的索引列表。如果令牌索引不是 -1,我将切出一个令牌向量表示(长度 = 256)。如果令牌索引为 -1,我将给出零向量(长度 = 256)。预期的输出结果具有形状 (batch_size, 1, 256)。我该怎么做?谢谢这是我到目前为止尝试过的bidir = concatenate([forward, backward]) # shape = (batch_size, 200, 256) dropout = Dropout(params['dropout_rate'])(bidir)def slice_by_tensor(x):    matrix_to_slice = x[0]    index_tensor = x[1]    out_tensor = tf.where(index_tensor == -1,                           tf.zeros(tf.shape(tf.gather(matrix_to_slice,                                                       index_tensor, axis=1))),                           tf.gather(matrix_to_slice, index_tensor, axis=1))    return out_tensorrepresentation_stack0 = Lambda(lambda x: slice_by_tensor(x))([dropout,stack_idx0]) # stack_idx0 shape is (batch_size) # I got output with shape (batch_size, batch_size, 256) with this code
查看完整描述

1 回答

?
慕娘9325324

TA贡献1783条经验 获得超4个赞

a=tf.reshape(tf.range(2*3*4),shape=(2,3,4))

#     [[[ 0,  1,  2,  3],

#        [ 4,  5,  6,  7],

#        [ 8,  9, 10, 11]],


#      [[12, 13, 14, 15],

#      [16, 17, 18, 19],

#       [20, 21, 22, 23]]]


b=tf.constant([-1,2]) 


aa=tf.pad(a,[[0,0],[1,0],[0,0]]) 


bb=b+1 


index=tf.stack([tf.range(tf.size(b)),bb],axis=-1) 

res=tf.expand_dims(tf.gather_nd(aa, index),axis=1)

#[[[ 0,  0,  0,  0]],

#[[20, 21, 22, 23]]]

当 index 为 -1 时,我们需要像张量这样的零。所以我们可以先沿第二个轴填充原始张量。然后将索引增加 1。在此之后,使用tf.gather_nd将返回答案。


查看完整回答
反对 回复 2022-06-22
  • 1 回答
  • 0 关注
  • 98 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信