我已经在 Pytorch 中为 224x224 大小的图像和 4 个类训练了这个网络。class CustomConvNet(nn.Module): def __init__(self, num_classes): super(CustomConvNet, self).__init__() self.layer1 = self.conv_module(3, 64) self.layer2 = self.conv_module(64, 128) self.layer3 = self.conv_module(128, 256) self.layer4 = self.conv_module(256, 256) self.layer5 = self.conv_module(256, 512) self.gap = self.global_avg_pool(512, num_classes) #self.linear = nn.Linear(512, num_classes) #self.relu = nn.ReLU() #self.softmax = nn.Softmax() def forward(self, x): out = self.layer1(x) out = self.layer2(out) out = self.layer3(out) out = self.layer4(out) out = self.layer5(out) out = self.gap(out) out = out.view(-1, 4) #out = self.linear(out) return out def conv_module(self, in_num, out_num): return nn.Sequential( nn.Conv2d(in_num, out_num, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=(2, 2), stride=None)) def global_avg_pool(self, in_num, out_num): return nn.Sequential( nn.Conv2d(in_num, out_num, kernel_size=3, stride=1, padding=1), #nn.BatchNorm2d(out_num), #nn.LeakyReLU(), nn.ReLU(), nn.Softmax(), nn.AdaptiveAvgPool2d((1, 1)))我从第一个 Conv2D 得到了权重,它的大小torch.Size([64, 3, 3, 3])我已将其保存为:weightsCNN = net.layer1[0].weight.datanp.save('CNNweights.npy', weightsCNN)这是我在 Tensorflow 中构建的模型。我想将从 Pytorch 模型中保存的权重传递到这个 Tensorflow CNN 中。 model = models.Sequential() model.add(layers.Conv2D(64, (3, 3), activation='relu', input_shape=(224, 224, 3))) model.add(layers.MaxPooling2D((2, 2))) model.add(layers.Conv2D(128, (3, 3), activation='relu')) model.add(layers.MaxPooling2D((2, 2)))我应该怎么做?Tensorflow 需要什么形状的权重?谢谢!
1 回答
郎朗坤
TA贡献1921条经验 获得超9个赞
keras
您可以非常简单地检查所有层的所有权重的形状:
for layer in model.layers: print([tensor.shape for tensor in layer.get_weights()])
这将为您提供所有权重的形状(包括偏差),因此您可以numpy
相应地准备加载的权重。
要设置它们,请执行类似的操作:
for torch_weight, layer in zip(model.layers, torch_weights): layer.set_weights(torch_weight)
wheretorch_weights
应该是一个列表,np.array
其中包含您必须加载的列表。
通常每个元素torch_weights
都包含一个np.array
权重和一个偏差。
请记住,从打印中收到的形状必须与您放入的形状完全相同set_weights
。
有关更多信息,请参阅文档。
顺便提一句。确切的形状取决于模型执行的层和操作,有时您可能必须转置一些数组以“适应它们”。
添加回答
举报
0/150
提交
取消