为了账号安全,请及时绑定邮箱和手机立即绑定

Pytorch 使用知识转移保存和加载 VGG16

Pytorch 使用知识转移保存和加载 VGG16

拉莫斯之舞 2021-06-04 18:45:26
我使用以下语句保存了一个带有知识转移的 VGG16:torch.save(model.state_dict(), 'checkpoint.pth')并使用以下语句重新加载:state_dict = torch.load('checkpoint.pth') model.load_state_dict(state_dict)只要我重新加载 VGG16 模型并使用以下代码为其提供与以前相同的设置,就可以工作:model = models.vgg16(pretrained=True)model.cuda()for param in model.parameters(): param.requires_grad = Falseclass Network(nn.Module):    def __init__(self, input_size, output_size, hidden_layers, drop_p=0.5):#             input_size: integer, size of the input#             output_size: integer, size of the output layer#             hidden_layers: list of integers, the sizes of the hidden layers#             drop_p: float between 0 and 1, dropout probability        super().__init__()        # Add the first layer, input to a hidden layer        self.hidden_layers = nn.ModuleList([nn.Linear(input_size, hidden_layers[0])])        # Add a variable number of more hidden layers        layer_sizes = zip(hidden_layers[:-1], hidden_layers[1:])        self.hidden_layers.extend([nn.Linear(h1, h2) for h1, h2 in layer_sizes])        self.output = nn.Linear(hidden_layers[-1], output_size)        self.dropout = nn.Dropout(p=drop_p)    def forward(self, x):        ''' Forward pass through the network, returns the output logits '''        # Forward through each layer in `hidden_layers`, with ReLU activation and dropout        for linear in self.hidden_layers:            x = F.relu(linear(x))            x = self.dropout(x)        x = self.output(x)        return F.log_softmax(x, dim=1)classifier = Network(25088, 102, [4096], drop_p=0.5)model.classifier = classifier如何避免这种情况?如何重新加载模型而不必重新加载 VGG16 并重新定义分类器?
查看完整描述

1 回答

  • 1 回答
  • 0 关注
  • 274 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信