为了账号安全,请及时绑定邮箱和手机立即绑定

pytorch 预测稳定性

pytorch 预测稳定性

芜湖不芜 2022-04-27 15:58:42
这是我的预测功能。这有什么问题吗?预测并不稳定,每次我运行相同的数据时,我都会得到不同的预测。def predict(model, device, inputs, batch_size=1024):    model = model.to(device)    dataset = torch.utils.data.TensorDataset(*inputs)    loader = torch.utils.data.DataLoader(                    dataset,                     batch_size=batch_size,                    pin_memory=False                )    predictions = []    for i, batch in enumerate(loader):        with torch.no_grad():            pred = model(*(item.to(device) for item in batch))            pred = pred.detach().cpu().numpy()        predictions.append(pred)    return np.concatenate(predictions)
查看完整描述

2 回答

?
智慧大石

TA贡献1946条经验 获得超3个赞

正如Usman Ali建议的那样,您需要eval通过调用将模型设置为模式

model.eval()

在你的prediction功能之前。

什么eval模式:

将模块设置为评估模式。

这仅对某些模块有任何影响。请参阅特定模块的文档以了解其在训练/评估模式下的行为的详细信息,如果它们受到影响,例如 Dropout、BatchNorm 等。

当您完成预测并希望继续训练时,不要忘记通过调用将模型重置为训练模式

model.train()

模型中有几层可能会将随机性引入网络的前向传播。一个这样的例子是dropout层。辍学层随机“丢弃”p其神经元的百分比以增加模型的泛化性。
此外,BatchNorm(以及可能的其他自适应归一化层)跟踪数据的统计信息,因此在train模式或eval模式中具有不同的“行为”。


查看完整回答
反对 回复 2022-04-27
?
慕妹3242003

TA贡献1824条经验 获得超6个赞

您已定义函数,但尚未训练模型。该模型在训练之前将预测随机化,这就是您的预测不一致的原因。如果您设置一个带有损失函数的优化器,并运行多个 epoch,则预测将稳定。此链接可能会有所帮助:https ://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html 。看第 3 节和第 4 节



查看完整回答
反对 回复 2022-04-27
  • 2 回答
  • 0 关注
  • 156 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信