为了账号安全,请及时绑定邮箱和手机立即绑定

手把手教你使用PyTorch(2)-requires_grad&computation graph

import torch

1. Requires_grad

图片描述

但是,模型毕竟不是人,它的智力水平还不足够去自主辨识那些量的梯度需要计算,既然如此,就需要手动对其进行标记

在PyTorch中,通用的数据结构tensor包含一个attributerequires_grad,它被用于说明当前量是否需要在计算中保留对应的梯度信息,以上文所述的线性回归为例,容易知道参数www为需要训练的对象,为了得到最合适的参数值,我们需要设置一个相关的损失函数,根据梯度回传的思路进行训练。

图片描述

官方文档中的说明如下

If there’s a single input to an operation that requires gradient, its output will also require gradient.

只要某一个输入需要相关梯度值,则输出也需要保存相关梯度信息,这样就保证了这个输入的梯度回传。

而反之,若所有的输入都不需要保存梯度,那么输出的requires_grad会自动设置为False。既然没有了相关的梯度值,自然进行反向传播时会将这部分子图从计算中剔除。

Conversely, only if all inputs don’t require gradient, the output also won’t require it. Backward computation is never performed in the subgraphs, where all Tensors didn’t require gradients.

对于那些要求梯度的tensor,PyTorch会存储他们相关梯度信息和产生他们的操作,这产生额外内存消耗,为了优化内存使用,默认产生的tensor是不需要梯度的。

而我们在使用神经网络时,这些全连接层卷积层等结构的参数都是默认需要梯度的。

a = torch.tensor([1., 2., 3.])
print('a:', a.requires_grad)
b = torch.tensor([1., 4., 2.], requires_grad = True)
print('b:', b.requires_grad)
print('sum of a and b:', (a+b).requires_grad)
a: False
b: True
sum of a and b: True

2. Computation Graph

从PyTorch的设计原理上来说,在每次进行前向计算得到pred时,会产生一个用于梯度回传的计算图,这张图储存了进行back propagation需要的中间结果,当调用了.backward()后,会从内存中将这张图进行释放

这张计算图保存了计算的相关历史和提取计算所需的所有信息,以output作为root节点,以input和所有的参数为leaf节点,

we only retain the grad of the leaf node with requires_grad =True

在完成了前向计算的同时,PyTorch也获得了一张由计算梯度所需要的函数所组成的图

而从数据集中获得的input其requires_grad为False,故我们只会保存参数的梯度,进一步据此进行参数优化

在PyTorch中,multi-task任务一个标准的train from scratch流程为

for idx, data in enumerate(train_loader):
    xs, ys = data

    optmizer.zero_grad()
    # 计算d(l1)/d(x)
    pred1 = model1(xs) #生成graph1
    loss = loss_fn1(pred1, ys)
    loss.backward()  #释放graph1

    # 计算d(l2)/d(x)
    pred2 = model2(xs)#生成graph2
    loss2 = loss_fn2(pred2, ys)
    loss.backward()  #释放graph2

    # 使用d(l1)/d(x)+d(l2)/d(x)进行优化
    optmizer.step()

Computation Graph本质上是一个operation的图,所有的节点都是一个operation,而进行相应计算的参数则以叶节点的形式进行输入

借助torchviz库以下面的模型作为示例

import torch.nn.functional as F
import torch.nn as nn

class Conv_Classifier(nn.Module):
    def __init__(self):
        super(Conv_Classifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 5, 5)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(5, 16, 5)
        self.pool2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(256, 20)
        self.fc2 = nn.Linear(20, 10)
    
    def forward(self, x):
        x = F.relu(self.pool1((self.conv1(x))))
        x = F.relu(self.pool2((self.conv2(x))))
        x = F.dropout2d(x, training=self.training)
        x = x.view(-1, 256)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x

Mnist_Classifier = Conv_Classifier()
from torchviz import make_dot
input_sample = torch.rand((1, 1, 28, 28))
make_dot(Mnist_Classifier(input_sample), params=dict(Mnist_Classifier.named_parameters()))

其对应的计算梯度所需的图(计算图)为

图片描述
可以看到,所有的叶子节点对应的操作都被记录,以便之后的梯度回传。

点击查看更多内容
TA 点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
  • 推荐
  • 评论
  • 收藏
  • 共同学习,写下你的评论
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消