为了账号安全,请及时绑定邮箱和手机立即绑定

如何将 PReLU 合并到量化模型中?

如何将 PReLU 合并到量化模型中?

慕哥9229398 2023-03-08 16:13:07
我正在尝试量化使用PReLU. 替换PReLU为ReLU是不可能的,因为它会极大地影响网络性能,以至于无法使用。据我所知,PReLU在量化方面,Pytorch 不支持。所以我尝试手动重写这个模块并实现乘法和加法torch.FloatFunctional()来绕过这个限制。这是我到目前为止提出的:class PReLU_Quantized(nn.Module):    def __init__(self, prelu_object):        super().__init__()        self.weight = prelu_object.weight        self.quantized_op = nn.quantized.FloatFunctional()        self.quant = torch.quantization.QuantStub()        self.dequant = torch.quantization.DeQuantStub()    def forward(self, inputs):        # inputs = torch.max(0, inputs) + self.weight * torch.min(0, inputs)            self.weight = self.quant(self.weight)        weight_min_res = self.quantized_op.mul(self.weight, torch.min(inputs)[0])        inputs = self.quantized_op.add(torch.max(inputs)[0], weight_min_res).unsqueeze(0)        self.weight = self.dequant(self.weight)        return inputs和更换:class model(nn.Module):     def __init__(self)         super().__init__()         ....         self.prelu = PReLU()        self.prelu_q = PReLU_Quantized(self.prelu)         ....基本上,我读取现有 prelu 模块的学习参数,并在新模块中自己运行计算。从某种意义上说,该模块似乎在工作,它并没有使整个应用程序失败。但是,为了评估我的实现是否真的正确并产生与原始模块相同的结果,我尝试对其进行测试。这是普通模型(即非量化模型)的对应物:由于某种原因,实际与我的实现之间的误差PReLU非常大!以下是不同层的示例差异:diff : 1.1562038660049438diff : 0.02868632599711418diff : 0.3653906583786011diff : 1.6100226640701294diff : 0.8999372720718384diff : 0.03773299604654312diff : -0.5090572834014893diff : 0.1654307246208191diff : 1.161868691444397diff : 0.026089997962117195diff : 0.4205571115016937diff : 1.5337920188903809diff : 0.8799554705619812diff : 0.03827812895178795diff : -0.40296515822410583diff : 0.15618863701820374并且在正向传播中 diff 是这样计算的:def forward(self, x):    residual = x    out = self.bn0(x)    out = self.conv1(out)    out = self.bn1(out)    out = self.prelu(out)    out2 = self.prelu2(out)    print(f'diff : {( out - out2).mean().item()}')    out = self.conv2(out)...我在这里错过了什么?
查看完整描述

1 回答

?
偶然的你

TA贡献1841条经验 获得超3个赞

我想到了!我一开始就犯了一个大错误。我需要计算

PReLU(x)=max(0,x)+a∗min(0,x)

或者 不是实际的!或者!这没有任何意义!这是普通模型(即未量化)的最终解决方案!:

//img1.sycdn.imooc.com//640843da0001030d03330168.jpg

torch.mintorch.max


class PReLU_2(nn.Module):

    def __init__(self, prelu_object):

        super().__init__()

        self.prelu_weight = prelu_object.weight

        self.weight = self.prelu_weight


    def forward(self, inputs):

        pos = torch.relu(inputs)

        neg = -self.weight * torch.relu(-inputs)

        inputs = pos + neg

        return inputs

这是量化版本:


class PReLU_Quantized(nn.Module):

    def __init__(self, prelu_object):

        super().__init__()

        self.prelu_weight = prelu_object.weight

        self.weight = self.prelu_weight

        self.quantized_op = nn.quantized.FloatFunctional()

        self.quant = torch.quantization.QuantStub()

        self.dequant = torch.quantization.DeQuantStub()


    def forward(self, inputs):

        # inputs = max(0, inputs) + alpha * min(0, inputs) 

        self.weight = self.quant(self.weight)

        weight_min_res = self.quantized_op.mul(-self.weight, torch.relu(-inputs))

        inputs = self.quantized_op.add(torch.relu(inputs), weight_min_res)

        inputs = self.dequant(inputs)

        self.weight = self.dequant(self.weight)

        return inputs


旁注:

我在计算差异时也有错别字:


    out = self.prelu(out)

    out2 = self.prelu2(out)

    print(f'diff : {( out - out2).mean().item()}')


    out = self.conv2(out)

需要是


    out1 = self.prelu(out)

    out2 = self.prelu2(out)

    print(f'diff : {( out1 - out2).mean().item()}')

    out = self.conv2(out1)

更新:

如果您在量化方面遇到问题,您可以尝试这个版本


import torch

import torch.nn as nn

import torch.nn.functional as F

import torch.nn.quantized as nnq

from torch.quantization import fuse_modules



class QPReLU(nn.Module):

    def __init__(self, num_parameters=1, init: float = 0.25):

        super(QPReLU, self).__init__()

        self.num_parameters = num_parameters

        self.weight = nn.Parameter(torch.Tensor(num_parameters).fill_(init))

        self.relu1 = nn.ReLU()

        self.relu2 = nn.ReLU()

        self.f_mul_neg_one1 = nnq.FloatFunctional()

        self.f_mul_neg_one2 = nnq.FloatFunctional()

        self.f_mul_alpha = nnq.FloatFunctional()

        self.f_add = nnq.FloatFunctional()

        self.quant = torch.quantization.QuantStub()

        self.dequant = torch.quantization.DeQuantStub()

        self.quant2 = torch.quantization.QuantStub()

        self.quant3 = torch.quantization.QuantStub()

        # self.dequant2 = torch.quantization.QuantStub()

        self.neg_one = torch.Tensor([-1.0])

        

    

    def forward(self, x):

        x = self.quant(x)

        

        # PReLU, with modules only

        x1 = self.relu1(x)

        

        neg_one_q = self.quant2(self.neg_one)

        weight_q = self.quant3(self.weight)

        x2 = self.f_mul_alpha.mul(

            weight_q, self.f_mul_neg_one2.mul(

                self.relu2(

                    self.f_mul_neg_one1.mul(x, neg_one_q),

                ),

            neg_one_q)

        )

        

        x = self.f_add.add(x1, x2)

        x = self.dequant(x)

        return x

    

m1 = nn.PReLU()

m2 = QPReLU()


# check correctness in fp

for i in range(10):

    data = torch.randn(2, 2) * 1000

    assert torch.allclose(m1(data), m2(data))


# toy model

class M(nn.Module):

    def __init__(self):

        super(M, self).__init__()

        self.prelu = QPReLU()

        

    def forward(self, x):

        x = self.prelu(x)

        return x

    

# quantize it

m = M()

m.qconfig = torch.quantization.default_qconfig

torch.quantization.prepare(m, inplace=True)

# calibrate

m(torch.randn(4, 4))

# convert

torch.quantization.convert(m, inplace=True)

# run some data through

res = m(torch.randn(4, 4))

print(res)

并确保阅读此处的注释



查看完整回答
反对 回复 2023-03-08
  • 1 回答
  • 0 关注
  • 345 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信