为了账号安全,请及时绑定邮箱和手机立即绑定

如何在 pytorch nn.module 中设置图层的值?

如何在 pytorch nn.module 中设置图层的值?

狐的传说 2022-10-18 16:06:52
我有一个模型,我正在尝试使用它。我正在解决这些错误,但现在我认为它已经归结为我层中的值。我收到此错误:RuntimeError: Given groups=1, weight of size 24 1 3 3, expected input[512, 50, 50, 3] to have 1 channels, but got 50 channels instead我的参数是:LR = 5e-2N_EPOCHS = 30BATCH_SIZE = 512DROPOUT = 0.5我的图像信息是:depth=24channels=3original height = 1600original width = 1200resized to 50x50这是我的数据的大小:Train shape (743, 50, 50, 3) (743, 7)Test shape (186, 50, 50, 3) (186, 7)Train pixels 0 255 188.12228712427097 61.49539262385051Test pixels 0 255 189.35559211469533 60.688278787628775我在这里试图了解每一层的期望值,但是当我在这里输入它所说的内容时,https://towardsdatascience.com/pytorch-layer-dimensions-what-sizes-should-they-be-and-为什么-4265a41e01fd,它给了我关于错误通道和内核的错误。我发现 torch_summary 让我对输出有更多的了解,但它只会提出更多的问题。这是我的 torch_summary 代码:from torchvision import modelsfrom torchsummary import summaryimport torchimport torch.nn as nnclass CNN(nn.Module):    def __init__(self):        super(CNN, self).__init__()        self.conv1 = nn.Conv2d(1,24, kernel_size=5)  # output (n_examples, 16, 26, 26)        self.convnorm1 = nn.BatchNorm2d(24) # channels from prev layer        self.pool1 = nn.MaxPool2d((2, 2))  # output (n_examples, 16, 13, 13)        self.conv2 = nn.Conv2d(24,48,kernel_size=5)  # output (n_examples, 32, 11, 11)        self.convnorm2 = nn.BatchNorm2d(48) # 2*channels?        self.pool2 = nn.AvgPool2d((2, 2))  # output (n_examples, 32, 5, 5)        self.linear1 = nn.Linear(400,120)  # input will be flattened to (n_examples, 32 * 5 * 5)        self.linear1_bn = nn.BatchNorm1d(400) # features?        self.drop = nn.Dropout(DROPOUT)        self.linear2 = nn.Linear(400, 10)        self.act = torch.relu
查看完整描述

1 回答

?
慕姐4208626

TA贡献1852条经验 获得超7个赞

看来您输入x张量轴的顺序错误。
正如您在输入中看到的,必须是doc Conv2d(N, C, H, W)

N是批量大小,C表示通道数,H是以像素为单位的输入平面的高度,以像素为单位W的宽度。

因此,为了正确使用torch.permute前传中的交换轴。

...

def forward(self, x):

    x = x.permute(0, 3, 1, 2)

    ...

    ...

    return self.linear2(x)

...

示例permute:


t = torch.rand(512, 50, 50, 3)

t.size()

torch.Size([512, 50, 50, 3])


t = t.permute(0, 3, 1, 2)

t.size()

torch.Size([512, 3, 50, 50])


查看完整回答
反对 回复 2022-10-18
  • 1 回答
  • 0 关注
  • 88 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
微信客服

购课补贴
联系客服咨询优惠详情

帮助反馈 APP下载

慕课网APP
您的移动学习伙伴

公众号

扫描二维码
关注慕课网微信公众号