为了账号安全,请及时绑定邮箱和手机立即绑定

使用 tf.gather 基于另一个张量逐行提取张量

使用 tf.gather 基于另一个张量逐行提取张量

一只名叫tom的猫 2021-11-23 17:01:03
我有两个尺寸分别为 asA:[B,3000,3]和 的张量C:[B,4000]。我想使用tf.gather()张量 C 中的每一行作为索引,并使用张量 A 中的每一行作为参数,以获得 size 的结果[B,4000,3]。这是一个更容易理解的例子:假设我有张量A = [[1,2,3],[4,5,6],[7,8,9]],C = [0,2,1,2,1],结果 = [[1,2,3],[7,8,9],[4,5,6],[7,8,9],[4,5,6]],通过使用tf.gather(A,C). 当应用于维度小于 3 的张量时,这一切都很好。但是当以描述为开头的情况下,通过应用tf.gather(A,C,axis=1),结果张量的形状为[B,B,4000,3]似乎tf.gather()只是将张量 C 中的每个元素都作为索引来收集张量 A 中的元素。我正在考虑的唯一解决方案是使用for循环,但这会极大地降低计算能力,因为使用tf.gather(A[i,...],C[i,...])以获得正确的张量的大小[B,4000,3]因此,是否有任何功能能够类似地完成此任务?
查看完整描述

1 回答

?
慕码人8056858

TA贡献1803条经验 获得超6个赞

你需要使用tf.gather_nd:


import tensorflow as tf


A = ...  # B x 3000 x 3

C = ...  # B x 4000

s = tf.shape(C)

B, cols = s[0], s[1]

# Make indices for first dimension

idx = tf.tile(tf.expand_dims(tf.range(B, dtype=C.dtype), 1), [1, cols])

# Complete index for gather_nd

gather_idx = tf.stack([idx, C], axis=-1)

# Gather result

result = tf.gather_nd(A, gather_idx)


查看完整回答
反对 回复 2021-11-23
  • 1 回答
  • 0 关注
  • 261 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信