我有两个尺寸分别为 asA:[B,3000,3]和 的张量C:[B,4000]。我想使用tf.gather()张量 C 中的每一行作为索引,并使用张量 A 中的每一行作为参数,以获得 size 的结果[B,4000,3]。这是一个更容易理解的例子:假设我有张量A = [[1,2,3],[4,5,6],[7,8,9]],C = [0,2,1,2,1],结果 = [[1,2,3],[7,8,9],[4,5,6],[7,8,9],[4,5,6]],通过使用tf.gather(A,C). 当应用于维度小于 3 的张量时,这一切都很好。但是当以描述为开头的情况下,通过应用tf.gather(A,C,axis=1),结果张量的形状为[B,B,4000,3]似乎tf.gather()只是将张量 C 中的每个元素都作为索引来收集张量 A 中的元素。我正在考虑的唯一解决方案是使用for循环,但这会极大地降低计算能力,因为使用tf.gather(A[i,...],C[i,...])以获得正确的张量的大小[B,4000,3]因此,是否有任何功能能够类似地完成此任务?
1 回答
慕码人8056858
TA贡献1803条经验 获得超6个赞
你需要使用tf.gather_nd:
import tensorflow as tf
A = ... # B x 3000 x 3
C = ... # B x 4000
s = tf.shape(C)
B, cols = s[0], s[1]
# Make indices for first dimension
idx = tf.tile(tf.expand_dims(tf.range(B, dtype=C.dtype), 1), [1, cols])
# Complete index for gather_nd
gather_idx = tf.stack([idx, C], axis=-1)
# Gather result
result = tf.gather_nd(A, gather_idx)
添加回答
举报
0/150
提交
取消