2 回答
TA贡献1776条经验 获得超12个赞
这是一个解决方案:
tf.reset_default_graph()
a = tf.constant([[1, 2], [3, 4], [5, 6]], dtype=tf.int32)
columns = tf.constant([1, 5], dtype=tf.int32)
a_padded = tf.Variable(tf.zeros((3, 7), dtype=tf.int32))
indices = tf.stack(tf.meshgrid(tf.range(tf.shape(a_padded)[0]), columns, indexing='ij'), axis=-1)
update_cols = tf.scatter_nd_update(a_padded, indices, a)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(update_cols))
TA贡献1821条经验 获得超4个赞
(OP 在这里)我设法使用tf.scatter_nd. 诀窍是对齐 a、列和输出形状的尺寸。
a_np = np.array([[1, 2],
[3, 4],
[5, 6]])
# Note the Transpose on every line below
a = tf.constant(a_np.T)
columns = tf.constant(np.array([[1, 5]]).T.astype('int32'))
shape = tf.constant((7, 3))
a_padded = tf.transpose(tf.scatter_nd(columns, a, shape))
添加回答
举报