1 回答
![?](http://img1.sycdn.imooc.com/54584ee0000179f302200220-100-100.jpg)
TA贡献1911条经验 获得超7个赞
ResNet 实现
我能想到的最简单的投影通用版本是这样的:
class Residual(torch.nn.Module):
def __init__(self, module: torch.nn.Module, projection: torch.nn.Module = None):
super().__init__()
self.module = module
self.projection = projection
def forward(self, inputs):
output = self.module(inputs)
if self.projection is not None:
inputs = self.projection(inputs)
return output + inputs
您可以传递module两个堆叠卷积之类的东西,并添加1x1卷积(带有填充或步幅或其他东西)作为投影模块。
对于tabular数据,您可以将其用作module(假设您的输入具有50功能):
torch.nn.Sequential(
torch.nn.Linear(50, 50),
torch.nn.ReLU(),
torch.nn.Linear(50, 50),
torch.nn.ReLU(),
torch.nn.Linear(50, 50),
)
基本上,您所要做的就是将input某个模块添加到其输出中,仅此而已。
理由如下nn.Identity
构建神经网络(然后读取它们)可能会更容易,例如批量归一化(取自上述 PR):
batch_norm = nn.BatchNorm2d
if dont_use_batch_norm:
batch_norm = Identity
现在您可以nn.Sequential轻松地使用它:
nn.Sequential(
...
batch_norm(N, momentum=0.05),
...
)
当打印网络时,它总是具有相同数量的子模块(带有BatchNorm或Identity),这也使整个过程在我看来更加流畅。
这里提到的另一个用例可能是删除现有神经网络的部分内容:
net = tv.models.alexnet(pretrained=True)
# Assume net has two parts
# features and classifier
net.classifier = Identity()
net.features(input)现在,您可以运行而不是运行net(input),这对其他人来说也更容易阅读。
添加回答
举报