为了账号安全,请及时绑定邮箱和手机立即绑定

pytorch 的 autograd 可以处理 torch.cat 吗?

pytorch 的 autograd 可以处理 torch.cat 吗?

交互式爱情 2021-09-11 15:21:40
我正在尝试实现一个应该学习灰度图像的简单神经网络。输入由像素的 2d 索引组成,输出应该是该像素的值。该网络的构造如下:每个神经元都连接到输入(即像素的索引)以及每个先前神经元的输出。输出只是这个序列中最后一个神经元的输出。这种网络在学习图像方面非常成功,如这里所示。问题: 我在执行之间的损失的功能住宿0.2,并0.4取决于神经元数目,学习速率和使用迭代的次数,这是非常糟糕的。此外,如果您将输出与我们在那里训练的内容进行比较,它看起来就像是噪音。但这是我第一次torch.cat在网络内使用,所以我不确定这是否是罪魁祸首。谁能看到我做错了什么?from typing import Listimport torchimport torch.nn as nnimport torch.nn.functional as Fimport torch.optim as optimfrom torch.nn import Linearclass My_Net(nn.Module):    lin: List[Linear]    def __init__(self):        super(My_Net, self).__init__()        self.num_neurons = 10        self.lin = nn.ModuleList([nn.Linear(k+2, 1) for k in range(self.num_neurons)])    def forward(self, x):        v = x        recent = torch.Tensor(0)        for k in range(self.num_neurons):            recent = F.relu(self.lin[k](v))            v = torch.cat([v, recent], dim=1)        return recent    def num_flat_features(self, x):        size = x.size()[1:]        num = 1        for i in size():            num *= i        return nummy_net = My_Net()print(my_net)#define a small 3x3 image that the net is supposed to learnmy_image = [[1.0, 1.0, 1.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]] #represents a T-shapemy_image_flat = []    #output of the net is the value of a pixelmy_image_indices = [] #input to the net is are the 2d indices of a pixelfor i in range(len(my_image)):    for j in range(len(my_image[i])):        my_image_flat.append(my_image[i][j])        my_image_indices.append([i, j])#optimization loopfor i in range(100):    inp = torch.Tensor(my_image_indices)    out = my_net(inp)    target = torch.Tensor(my_image_flat)    criterion = nn.MSELoss()    loss = criterion(out.view(-1), target)    print(loss)    my_net.zero_grad()    loss.backward()    optimizer = optim.SGD(my_net.parameters(), lr=0.001)    optimizer.step()print("output of current image")print([[my_net(torch.Tensor([[i,j]])).item() for i in range(3)] for j in range(3)])print("output of original image")print(my_image)
查看完整描述

1 回答

  • 1 回答
  • 0 关注
  • 221 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信