Happy · 2021年02月03日

Pytorch量化入门之超分量化(一)

首发:AIWalker
作者:HappyAIWalker

最近Happy在尝试进行图像超分的INT8量化,发现:pytorch量化里面的坑真多,远不如TensorFlow的量化好用。不过花了点时间终于还是用pytorch把图像超分模型完成了量化,以EDSR为例,模型大小73%,推理速度提升40%左右(PC端),视觉效果几乎无损,定量指标待补充。有感于网络上介绍量化的博客一堆,但真正有帮助的较少,所以Happy会尽量以图像超分为例提供一个完整的可复现的量化示例(分两章内容进行)。

背景

量化在不同领域有不同的定义,而在深度学习领域,量化有两个层面的意义:(1) 存储量化,即更少的bit来存储原本需要用浮点数(一般为FP32)存储的tensor;(2) 计算量化,即用更少的bit来完成原本需要基于浮点数(一般为FP32,FP16现在也是常用的一种)完成的计算。量化一般有这样两点好处:

  • 更小的模型体积,理论上减少为FP32模型的75%左右,从笔者不多的经验来看,往往可以减少73%;
  • 更少的内存访问与更快的INT8计算,从笔者的几个简单尝试来看,一般可以加速40%左右,这个还会跟平台相关。

对于量化后模型而言,其部分或者全部tensor(与量化方式、量化op的支持程度有关)将采用INT类型进行计算,而非量化前的浮点类型。量化对于底层的硬件支持、推理框架等要求还是比较高的,目前X86CPU,ARMCPU,Qualcomm DSP等主流硬件对量化都提供了支持;而NCNN、MACE、MNN、TFLite、Caffe2、TensorRT等推理框架也都对量化提供了支持,不过不同框架的支持度还是不太一样,这个就不细说了,感兴趣的同学可以自行百度一下。
笔者主要用Pytorch进行研发,所以花了点精力对其进行了一些研究&尝试。目前Pytorch已经更新到了1.7版本,基本上支持常见的op,可以参考如下:

  • Activation:ReLU、ReLU6、Hardswish、ELU;
  • Normalization:BatchNorm、LayerNorm、GroupNorm、InstanceNorm;
  • Convolution:Conv1d、Conv2d、Conv3d、ConvTranspose1d、ConvTranspose2d、Linear;
  • Other:Embedding、EmbeddingBag。

目前Pytorch支持的量化有如下三种方式:

  • Post Training Dynamic Quantization:动态量化,推理过程中的量化,这种量化方式常见诸于NLP领域,在CV领域较少应用;
  • Post Training Static Quantization:静态量化,训练后静态量化,这是CV领域应用非常多的一种量化方式;
  • Quantization Aware Training:感知量化,边训练边量化,一种比静态量化更优的量化方式,但量化时间会更长,但精度几乎无损。

注:笔者主要关注CV领域,所以本文也将主要介绍静态量化与感知量化这种方式。

Tensor量化

要实现量化,那么就不可避免会涉及到tensor的量化,一般来说,量化公式可以描述如下:
目前Pytorch中的tensor支持int8/uint8/int32等类型的数据,并同时scale、zero\_point、quantization\_scheme等量化信息。这里,我们给出一个tensor量化的简单示例:

x = torch.rand(3, 3)  
print(x)  
x = torch.quantize_per_tensor(x, scale=0.2, zero_point=3, dtype=torch.quint8)  
print(x)  
print(x.int_repr())  

一个参考输出如下所示:

image.png

注1:蓝框为原始的浮点数据,红框为tensor的量化信息,绿框则对应了量化后的INT8数值。
注2:量化不可避免会出现精度损失,这个损失与scale、zero\_point有关。
在量化方面,Tensor一般有两种量化模式:per tensor与per channel。对于PerTensor而言,它的所有数值都按照相同方式进行scale和zero\_point处理;而对于PerChannel而言,它有多种不同的scale和zero\_point参数,这种方式的量化精度损失更少。

Post Training Static Quantization

静态量化一般有两种形式:(1) 仅weight量化;(2) weight与activation同时量化。对于第一种“仅weight量化”而言,只针对weight量化可以使得模型参数所占内存显著减小,但在实际推理过程中仍需要转换成浮点数进行计算;而第二种“weight与activation同时量化”则不仅对weight进行量化,还需要结合校验数据进行activation的量化。第一种的量化非常简单,这里略过,本文仅针对第二种方式进行介绍。
Pytorch的静态量化一把包含五个步骤:

  • fuse\_model:该步骤用来对可以融合的op进行融合,比如Conv与BN的融合、Conv与ReLU的融合、Conv与BN以及ReLU的融合、Linear与BN的融合、Linear与BN以及ReLU的融合。目前Pytorch已经内置的融合code:
fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None)  

在完成融合后,第一个op将被替换会融合后的op,而其他op则会替换为

nn.Identity。
  • qconfig:该步骤用于设置用于模型量化的方式,它将插入两个observer,一个用于监测activation,一个用于监测weight。考虑到推理平台的不同,pytorch提供了两种量化配置:针对x86平台的fbgemm以及针对arm平台的qnnpack

不同平台的量化配置方式存在些微的区别,大概如下:

image.png

  • Prepare:该步骤用于给每个支持量化的模块插入Observer,用于收集数据并进行量化数据分析。以activation为例,它将根据所喂入数据统计min\_val与max\_val,一般观察几个次迭代即可,然后根据所观察到数据进行统计分析得到scale与zero\_point。
  • Feed Data:为了更好的获得activation的量化参数信息,我们需要一个合适大小的校验数据,并将其送入到前述模型中。这个就比较简单了,就按照模型验证方式往里面送数据就可以了。
  • Convert:在完成前述四个步骤后,接下来就需要将完成量化的模型转换为量化后模型了,这个就比较简单了,通过如下命令即可。
torch.quantization.convert(model, inplace=True)  

该过程本质上就是用量化OP替换模型中的费量化OP,比如用nnq.Conv2d替换nn.Conv2d, nnq.ConvReLU2d替换nni.ConvReLU2d(注:这是Conv与ReLU的合并)。之前的量化op以及对应的被替换op列表如下:

DEFAULT_STATIC_QUANT_MODULE_MAPPINGS = {  
    QuantStub: nnq.Quantize,  
    DeQuantStub: nnq.DeQuantize,  
    nn.BatchNorm2d: nnq.BatchNorm2d,  
    nn.BatchNorm3d: nnq.BatchNorm3d,  
    nn.Conv1d: nnq.Conv1d,  
    nn.Conv2d: nnq.Conv2d,  
    nn.Conv3d: nnq.Conv3d,  
    nn.ConvTranspose1d: nnq.ConvTranspose1d,  
    nn.ConvTranspose2d: nnq.ConvTranspose2d,  
    nn.ELU: nnq.ELU,  
    nn.Embedding: nnq.Embedding,  
    nn.EmbeddingBag: nnq.EmbeddingBag,  
    nn.GroupNorm: nnq.GroupNorm,  
    nn.Hardswish: nnq.Hardswish,  
    nn.InstanceNorm1d: nnq.InstanceNorm1d,  
    nn.InstanceNorm2d: nnq.InstanceNorm2d,  
    nn.InstanceNorm3d: nnq.InstanceNorm3d,  
    nn.LayerNorm: nnq.LayerNorm,  
    nn.LeakyReLU: nnq.LeakyReLU,  
    nn.Linear: nnq.Linear,  
    nn.ReLU6: nnq.ReLU6,  
    # Wrapper Modules:  
    nnq.FloatFunctional: nnq.QFunctional,  
    # Intrinsic modules:  
    nni.BNReLU2d: nniq.BNReLU2d,  
    nni.BNReLU3d: nniq.BNReLU3d,  
    nni.ConvReLU1d: nniq.ConvReLU1d,  
    nni.ConvReLU2d: nniq.ConvReLU2d,  
    nni.ConvReLU3d: nniq.ConvReLU3d,  
    nni.LinearReLU: nniq.LinearReLU,  
    nniqat.ConvBn1d: nnq.Conv1d,  
    nniqat.ConvBn2d: nnq.Conv2d,  
    nniqat.ConvBnReLU1d: nniq.ConvReLU1d,  
    nniqat.ConvBnReLU2d: nniq.ConvReLU2d,  
    nniqat.ConvReLU2d: nniq.ConvReLU2d,  
    nniqat.LinearReLU: nniq.LinearReLU,  
    # QAT modules:  
    nnqat.Linear: nnq.Linear,  
    nnqat.Conv2d: nnq.Conv2d,  
}   

在完成模型量化后,我们就要考虑量化模型的推理了。其实量化模型的推理与浮点模型的推理没什么本质区别,最大的区别有这么两点:

  • 量化节点插入:需要在网络的forward里面插入QuantStub与DeQuantSub两个节点。一个非常简单的参考示例,摘自torchvision.model.quantization.resnet.py。
class QuantizableResNet(ResNet):  
  
    def __init__(self, *args, **kwargs):  
        super(QuantizableResNet, self).__init__(*args, **kwargs)  
  
        self.quant = torch.quantization.QuantStub()  
        self.dequant = torch.quantization.DeQuantStub()  
  
    def forward(self, x):  
        x = self.quant(x)  
        # Ensure scriptability  
        # super(QuantizableResNet,self).forward(x)  
        # is not scriptable  
        x = self._forward_impl(x)  
        x = self.dequant(x)  
        return x  
  • op替换:需要将模型中的Add、Concat等操作替换为支持量化的FloatFunctional,可参考如下示例。
class QuantizableBasicBlock(BasicBlock):  
    def __init__(self, *args, **kwargs):  
        super(QuantizableBasicBlock, self).__init__(*args, **kwargs)  
        self.add_relu = torch.nn.quantized.FloatFunctional()  
  
    def forward(self, x):  
        identity = x  
  
        out = self.conv1(x)  
        out = self.bn1(out)  
        out = self.relu(out)  
  
        out = self.conv2(out)  
        out = self.bn2(out)  
  
        if self.downsample is not None:  
            identity = self.downsample(x)  
  
        out = self.add_relu.add_relu(out, identity)  
  
        return out  

推荐阅读

本文章著作权归作者所有,任何形式的转载都请注明出处。更多动态滤波,图像质量,超分辨相关请关注我的专栏深度学习从入门到精通
推荐阅读
关注数
6197
内容数
191
夯实深度学习知识基础, 涵盖动态滤波,超分辨,轻量级框架等
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息