为了账号安全,请及时绑定邮箱和手机立即绑定

使用 [:, :, 0] 的 TF2 / Keras 切片张量

使用 [:, :, 0] 的 TF2 / Keras 切片张量

holdtom 2022-04-27 16:14:02
在 TF 2.0 Beta 中,我正在尝试:x = tf.keras.layers.Input(shape=(240, 2), dtype=tf.float32)print(x.shape) # (None, 240, 2)a = x[:, :, 0]print(a.shape) # <unknown>在 TF 1.x 中,我可以这样做:x = tf1.placeholder(tf1.float32, (None, 240, 2)a = x[:, :, 0]它会正常工作。如何在 TF 2.0 中实现这一点?我认为tf.split(x, 2, axis=2)可能有效,但是我想使用切片而不是硬编码 2(轴 2 的暗淡)。
查看完整描述

1 回答

?
慕工程0101907

TA贡献1887条经验 获得超5个赞

不同之处在于返回的对象Input代表一个层,而不是任何类似于占位符或张量的东西。因此,x上面的 tf 2.0 代码中的 是层对象,而xtf 1.x 代码中的 是张量的占位符。


您可以定义一个切片层来执行该操作。有现成可用的层,但是对于像这样的简单切片,Lambda层非常容易阅读,并且可能最接近您在 tf 1.x 中使用的切片方式。


像这样的东西:


input_lyr = tf.keras.layers.Input(shape=(240, 2), dtype=tf.float32)

sliced_lyr = tf.keras.layers.Lambda(lambda x: x[:,:,0])

您可以像这样在您的 keras 模型中使用它:


model = tf.keras.models.Sequential([

    input_lyr,

    sliced_lyr,

    # ...

    # <other layers>

    # ...

])


当然,以上是特定于keras模型的。相反,如果你有一个张量而不是一个 keras 层对象,那么切片就像以前一样工作。像这样的东西:


my_tensor = tf.random.uniform((8,240,2))

sliced = my_tensor[:,:,0]


print(my_tensor.shape)

print(sliced.shape)

输出:


(8, 240, 2)

(8, 240)

正如预期的那样


查看完整回答
反对 回复 2022-04-27
  • 1 回答
  • 0 关注
  • 138 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信