我正在训练一个 RNN,我需要使用索引来查找示例时间流的另一部分中的值v = tf.constant([ [[.1, .2], [.3, .4]], # timestream 1 values [[.6, .5], [.7, .8]] # timestream 2 values])ixs = tf.constant([ [1, 0], # indices into timestream 1 values [0, 1] # indices into timestream 2 values])我正在寻找一个可以进行查找并用张量值替换索引并产生的操作:[ [[.3, .4], [.1, .2]], [[.6, .5], [.7, .8]]]tf.gather 和 tf.gather_nd 听起来他们可能是正确的道路,但我真的不明白我从他们那里得到的结果。v_at_ix = tf.gather(v, ixs, axis=-1)sess.run(v_at_ix)array([[[[0.2, 0.1], [0.1, 0.2]], [[0.4, 0.3], [0.3, 0.4]]], [[[0.5, 0.6], [0.6, 0.5]], [[0.8, 0.7], [0.7, 0.8]]]], dtype=float32)v_at_ix = tf.gather_nd(v, ixs)sess.run(v_at_ix)array([[0.6, 0.5], [0.3, 0.4]], dtype=float32)有谁知道正确的方法来做到这一点?
添加回答
举报
0/150
提交
取消