为了账号安全,请及时绑定邮箱和手机立即绑定

如何同时使用 Tensorflow tf.nn.Conv2d 进行训练和预测?

如何同时使用 Tensorflow tf.nn.Conv2d 进行训练和预测?

繁花不似锦 2021-09-01 19:15:36
我目前正在深入研究 tensorflow,我对正确使用tf.nn.Conv2d(input, filter, strides, padding). 尽管乍一看看起来很简单,但我无法听到以下问题:的使用filter, strides, padding对我来说很清楚。然而,不清楚的是input.我来自强化学习 Atari (Pong) 问题,在该问题中我想使用网络进行批量训练,并且(以一定的概率)也用于每一步的预测。这意味着,对于训练,我正在为网络提供一整批(假设为 100 ),每个单元由 3 个帧组成,大小为 160、128。使用 tensorflow 的 NHWC 格式,我的输入input将是 atf.placeholder形状(100,160,128,3)。所以为了训练,我喂了 100 个 160x128x3 的包。但是,在某种情况下预测来自我的网络的输出(用乒乓球拍向上或向下)时,我只向网络馈送一包160x128x3(即一包三帧)。现在这是 tensorflow 崩溃的地方。它期望(100,160,128,3)但接收(1,160,128,3).现在我很困惑。我显然不想将批量大小设置为 1 并且总是只提供一个包进行训练。但是我怎么能在这里继续呢?这将如何实施tf.nn.conv2d?如果有人能在这里引导我走向正确的方向,我将不胜感激
查看完整描述

2 回答

?
繁星点点滴滴

TA贡献1803条经验 获得超3个赞

我目前正在深入研究 tensorflow,我对正确使用tf.nn.Conv2d(input, filter, strides, padding). 尽管乍一看看起来很简单,但我无法听到以下问题:


的使用filter, strides, padding对我来说很清楚。然而,不清楚的是input.


我来自强化学习 Atari (Pong) 问题,在该问题中我想使用网络进行批量训练,并且(以一定的概率)也用于每一步的预测。这意味着,对于训练,我正在为网络提供一整批(假设为 100 ),每个单元由 3 个帧组成,大小为 160、128。使用 tensorflow 的 NHWC 格式,我的输入input将是 atf.placeholder形状(100,160,128,3)。所以为了训练,我喂了 100 个 160x128x3 的包。


但是,在某种情况下预测来自我的网络的输出(用乒乓球拍向上或向下)时,我只向网络馈送一包160x128x3(即一包三帧)。现在这是 tensorflow 崩溃的地方。它期望(100,160,128,3)但接收(1,160,128,3).


现在我很困惑。我显然不想将批量大小设置为 1 并且总是只提供一个包进行训练。但是我怎么能在这里继续呢?这将如何实施tf.nn.conv2d?


如果有人能在这里引导我走向正确的方向,我将不胜感激


查看完整回答
反对 回复 2021-09-01
?
慕斯709654

TA贡献1840条经验 获得超5个赞

你需要像下面这样设置你的占位符 tf.placeholder(shape=(None,160,128,3) ....) ,在第一维中有 None ,你的占位符将灵活到你输入的任何值 1 或 100。


查看完整回答
反对 回复 2021-09-01
  • 2 回答
  • 0 关注
  • 185 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信