为了账号安全,请及时绑定邮箱和手机立即绑定

Keras:收集张量更改批量维度

Keras:收集张量更改批量维度

墨色风雨 2022-06-14 16:57:59
我有一个形状为 (5, 2) 的输入张量,代表 2D 空间中的五个点。我想取第一点,然后从所有五点中减去它。仔细阅读,我想我可以用它K.gather来切片和重复第一层。在 Lambda 层中应用它后,批处理维度被覆盖:_input = Input(shape=(5, 2))x = Reshape((5 * 2,))(_input)x_ = Lambda(lambda t: K.gather(t, [0, 1] * 5))(x)结果是:__________________________________________________________________________________________________Layer (type)                    Output Shape         Param #     Connected to                     ==================================================================================================input_1 (InputLayer)            (None, 5, 2)         0                                            __________________________________________________________________________________________________reshape_1 (Reshape)             (None, 10)           0           input_1[0][0]                    __________________________________________________________________________________________________lambda_1 (Lambda)               (10, 10)             0           reshape_1[0][0]                  __________________________________________________________________________________________________我究竟做错了什么?另外,有没有更简单的方法来做到这一点?
查看完整描述

2 回答

?
幕布斯6054654

TA贡献1876条经验 获得超7个赞

gather函数从批处理(0th)轴返回提供的索引值。因此,它为我们提供了形状为 (10, 10) 的批次中的第一个 (index:0) 和第二个 (index:1) 样本 (形状 (10,)) 的列表 (length=10) 而我们想要第一个批次中每个样本的(索引:0)和第二(索引:1)特征点。为了解决这个问题,我们可以在使用gather函数之前转置张量,以便gather函数选择正确的值,最后生成的张量应该再次转置。


_input = Input(shape=(5, 2))

x = Reshape((5 * 2,))(_input)

x_ = Lambda(lambda t: K.transpose(K.gather(K.transpose(t), [0, 1]*5)))(x)

输出:


_________________________________________________________________

Layer (type)                 Output Shape              Param #   

=================================================================

input_1 (InputLayer)         [(None, 5, 2)]            0         

_________________________________________________________________

reshape (Reshape)            (None, 10)                0         

_________________________________________________________________

lambda (Lambda)              (None, 10)                0         

=================================================================


查看完整回答
反对 回复 2022-06-14
?
慕雪6442864

TA贡献1812条经验 获得超5个赞

如果你使用tf.gather(),你可以避免使用@bit01 描述的转置操作。中有一个axis论点tf.gather()。


_input = Input(shape=(5, 2))

x = Reshape((5 * 2,))(_input)

x_ = Lambda(lambda t: tf.gather(t, [0, 1]*5, axis=1))(x)


# Layer (type)                 Output Shape              Param #   

# =================================================================

# input_2 (InputLayer)         (None, 5, 2)              0         

# _________________________________________________________________

# reshape_2 (Reshape)          (None, 10)                0         

# _________________________________________________________________

# lambda_1 (Lambda)            (None, 10)                0         

# =================================================================


查看完整回答
反对 回复 2022-06-14
  • 2 回答
  • 0 关注
  • 216 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信