为了账号安全,请及时绑定邮箱和手机立即绑定

具有不同批量大小的pytorch恢复模型

具有不同批量大小的pytorch恢复模型

呼如林 2021-06-14 15:31:52
我有一个关于如何重新加载具有不同批量大小的 pytorch 模型的问题。在训练中,我的批量大小为 64,但在推理中,我希望批量大小为 1(一个接一个地馈送数据)。这是我用来保存和恢复模型的代码:torch.save(agent.qnetwork_local.state_dict(), './ckpt/checkpoint.pth')saved_model = QNetwork(state_size=37, action_size=4, seed=0)saved_model.load_state_dict(torch.load('./ckpt/checkpoint.pth'))运行推理模型时出现此错误:RuntimeError: size mismatch, m1: [37 x 1], m2: [37 x 64] at /Users/soumith/code/builder/wheel/pytorch-src/aten/src/TH/generic/THTensorMath.cpp:2070这个错误意味着模型的输入必须是 37x64,其中 37 是数据维度,64 是训练批次大小。但是测试输入是 37x1,这意味着数据维度是 37,批大小是 1。重载pytorch模型中不同的批量大小有什么解决方案吗?非常感谢你。
查看完整描述

2 回答

?
吃鸡游戏

TA贡献1829条经验 获得超7个赞

当您建立模型时,您可以使用 -1 来动态表示您的批量大小。例如,下面是前向阶段代码


def forward(self, x):

     x = self.conv1(x)

     x = self.layer1(x)

     x = self.layer2(x)

     x = self.avgpool(x)

     x = x.view(-1, 37)

 #instead using x.view(64,37) 

     x = self.fc(x)

希望它可以帮助你


查看完整回答
反对 回复 2021-06-22
  • 2 回答
  • 0 关注
  • 173 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信