2 回答
TA贡献1876条经验 获得超7个赞
gather函数从批处理(0th)轴返回提供的索引值。因此,它为我们提供了形状为 (10, 10) 的批次中的第一个 (index:0) 和第二个 (index:1) 样本 (形状 (10,)) 的列表 (length=10) 而我们想要第一个批次中每个样本的(索引:0)和第二(索引:1)特征点。为了解决这个问题,我们可以在使用gather函数之前转置张量,以便gather函数选择正确的值,最后生成的张量应该再次转置。
_input = Input(shape=(5, 2))
x = Reshape((5 * 2,))(_input)
x_ = Lambda(lambda t: K.transpose(K.gather(K.transpose(t), [0, 1]*5)))(x)
输出:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 5, 2)] 0
_________________________________________________________________
reshape (Reshape) (None, 10) 0
_________________________________________________________________
lambda (Lambda) (None, 10) 0
=================================================================
TA贡献1812条经验 获得超5个赞
如果你使用tf.gather(),你可以避免使用@bit01 描述的转置操作。中有一个axis论点tf.gather()。
_input = Input(shape=(5, 2))
x = Reshape((5 * 2,))(_input)
x_ = Lambda(lambda t: tf.gather(t, [0, 1]*5, axis=1))(x)
# Layer (type) Output Shape Param #
# =================================================================
# input_2 (InputLayer) (None, 5, 2) 0
# _________________________________________________________________
# reshape_2 (Reshape) (None, 10) 0
# _________________________________________________________________
# lambda_1 (Lambda) (None, 10) 0
# =================================================================
添加回答
举报