Happy · 2022年04月14日

NAFNet :无需非线性激活,真“反直觉”!但复原性能也是真强!

edf0b350ea9c1abfe664048b069cce73.jpg

本文提出一种超简基线方案Baseline,它不仅计算高效同时性能优于之前SOTA方案;在所得Baseline基础上进一步简化得到了NAFNet:移除了非线性激活单元且性能进一步提升。所提方案在SIDD降噪与GoPro去模糊任务上均达到了新的SOTA性能,同时计算量大幅降低(可参考下图)。

bbe57ad74babe659c6be470a7d9c777f.jpg

1、Building A Simple Baseline

在该部分内容中,我们将从头开始构建一个用于图像复原的简单基线(Simple Baseline)。为保证结构的简洁性,基本原则是:如无必要,勿增实体(奥卡姆剃刀)。参考HINet-Simple,我们主要在16GMACs计算量(输入为)范围进行试验分析,其他计算量模型的结果见试验部分;在任务方面,我们主要在SIDD降噪、GoPro去模糊上进行验证。

Architecture

25360dca8934f736f4d906a6d7636afc.jpg

上图给出了图像复原领域常用架构示意图,包含多阶段架构、多尺度融合架构以及UNet架构。为减少块间(inter-block)复杂度,我们采用了带跳过连接的UNet架构

2、A Plain Block

神经网络一般采用模块堆叠方式构建,所选用的UNet架构决定了模块堆叠方式,但模块的设计仍然是个问题。

a8d9384beb9ed844a208db84a43c9a04.jpg

上图a给出了Restormer一文所构建的模块,我们以其作为参考并进行简化:采用卷积替代Transformer(见上图b)。这里的替换主要是基于以下三个考量:

  • 尽管Transformer在CV领域表现出了惊人的优势,但一些研究表明:Transformer并非达成SOTA结果的必要条件;
  • depthwise卷积比自注意力更简单;
  • 本文并非旨在讨论Transformer与卷积的优劣,而仅在于提供了一个简单基线。

Normalization

归一化技术在high-level任务中已被广泛应用,但在low-level任务中应用极少。但是,依托于Transformer,LN得到了越来越多的应用。基于该事实,我们猜想:LN可能是达成SOTA复原器的关键,故在上述Plain模块中添加了LN(见上面图示c)。LN的引入使得训练更平滑,甚至可以将学习率放大10倍更大的学习率可以带来显著性能提升:0.44dB@SIDD(39.29dB→39.73dB),3.39dB@GoPro(28.51dB → 31.90dB)。

Activation

尽管ReLU是最常用的激活函数,现有SOTA方案中采用GELU进行代替。激活函数的替代在性能方面导致:-0.02dB@SIDD(39.73dB → 39.71dB),0.21dB@GoPro(31.90dB → 21.11dB)。由于GELU可以保持降噪性能相当且大幅提升去模糊性能,故我们采用GELU替代ReLU(见上面图示c)。

Attention

受启发于Restormer中的注意力机制,我们意识到:普通通道注意力可以满足计算效率需求并引入全局信息;此外通道注意力的有效性已在多个图像复原任务中得到验证。因此,我们进一步添加通道注意力,见上面图示c。通道注意力可以带来额外的性能提升:0.14dB@SIDD(37.71dB → 39.85dB)、0.24dB@GoPro(32.11dB → 32.35dB)。

Summary

1ec3e7c5133d97bc5e418c3a461c45a0.jpg

到此,我们从头开始构建了本文的Baseline(结果见上表)。尽管所设计模块中的每个成分都非常简单,但组合后可以得到一个强基线方案:在SIDD与GoPro数据集上超越了其他SOTA方案,同时计算量大幅降低。

3、Nonlinear Activation Free Network

尽管上述所提Baseline足够简单且竞争力,那么是否可能在确保简洁性的同时进一步提升性能呢是否可以更简介且无性能损失呢?我们尝试从SOTA方案(VRT, MAXIM, Restormer)中寻找共性点以回答上述问题,我们发现:这些方案均采用了Gated Linear Units(GLU,定义如下)。
1649902746(1).png
将GLU引入到Baseline中可能会改善性能,但同时会导致块内(intra-block)计算复杂度提升,而这并非我们所期望的。

为此,我们对Baseline中的激活函数进行了回顾,其定义与近似实现如下:
1649902765(1).png
GELU与GLU的实现可以发现:GELU是GLU的一种特例。我们从另一个角度猜想:GLU可视作一种广义激活函数,它是可以用于替代非线性激活函数。此外,我们注意到:**GLU自身已包含非线性且该非线性并不依赖****1649902808(1)(1).png**。

54e3a4f3cfc0d9b651a4e2f3ac7df983.jpg

基于上述,我们提出了一种简化版GLU变种(见上图c):直接将特征沿通道维度分成两部分并相乘。采用所提SimpleGate对GELU进行替换导致的性能提升为:0.08dB@SIDD(39.85dB → 39.93dB)、0.41dB@GoPro(32.35dB → 32.76dB)。相比GELU的复杂实现,SimpleGate的实现非常简单:

Simplified Channel AttentionBaseline方案中采用了通道注意力(见上图a),它定义如下:

可以看到:CA的定义与GLU非常像。这就是促使我们将CA视作GLU的一种特例并可进一步简化。通过保留通道注意力的两个重要作用(全局信息聚合、通道信息交互),我们提出了如下简化版通道注意力(见上面图示b)

为公平对比,我们调整了CA的特征维度以保持与SCA计算复杂度相当。尽管SCA足够简单,但它并未造成性能损失:0.03dB@SIDD(39.93dB → 39.96dB),0.09dB@GoPro(32.76dB → 32.85dB)。

d11f0b1ba4616855eff6dfae61cfc439.jpg

以Baseline为基础,我们采用SimpleGate替换GELU、采用SCA替换CA达成了进一步的简化,且未噪声性能损失。值得一提的是:简化后的网络中不包含非线性激活函数。因此,我们将所得方案称之为NAFNet(Nonlinear Activation Free Network)。

4、Experiments

008062a6fe4dbfe19b50df7e527da802.jpg

f8ba9799168947c03c7b1de7214c8d58.jpg

上图与表为SIDD数据集上不同方案的性能对比,可以看到:

  • 所提Baseline与NAFNet以0.28dB指标优于此前最佳方案Restormer,同时计算量更低
  • 在重建效果方面,相比其他方案,所提方案可以重建更细粒度细节。

69b02f46cb47ecd5bff354c0ac6d8ed3.jpg

11356a573388ea5053353596b5ad7ff4.jpg

上图与表为GoPro数据集上不同方案的性能对比,可以看到:

  • 所提Baseline与NAFNet分别比此前最佳方案MPRNet-local高0.09dB与0.38dB,同时仅需8.4%!的(MISSING)计算量
  • 在重建效果方面,相比其他方案,所提方案的重建结果更锐利。

391503b07a202ca20996fd6904309323.jpg

de131604de7c34af99ac8a0a77ac0d73.jpg

上图与表为Raw降噪与JPEG伪影+去模糊组合任务(NTIRE2021图像去模糊Track2)上的性能对比,可以看到:

  • 相比PMRID,所提方案NAFNet(通道数与模块数进行了减少以确保计算量相当)具有更高的PSNR指标,同时具有更优的重建效果。该实验同时说明了NAFNet - 从Table8可以看到:相比NTIRE2021竞赛冠军方案HINet与MAXIM,所提NAFNet取得了更优的PSNr与SSIM指标,同时具有更低的计算量(约三分之一)。

5、Code Implement

class SimpleGate(nn.Module):
    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        return x1 * x2
        
class NAFBlock(nn.Module):
    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_ratio=0):
        dw_channel = c * DW_Expand
        self.conv1 = nn.Conv2d(c, dw_channel, 1)
        self.conv2 = nn.Conv2d(dw_channel, dw_channel, 3, 1, 1 group=dw_channel)
        self.conv3 = nn.Conv2d(dw_channel//2, dw_channel, 1)
        
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(dw_channel//2, dw_channel//2, 1)
        )
        
        self.sg = SimpleGate()
        
        ffn_channel = FFN_Expand * c
        self.conv4 = nn.Conv2d(c, ffn_channel, 1)
        self.conv5 = nn.Conv2d(ffn_channel, c, 1)
        
        self.norm1 = LayerNorm2d()
        self.norm2 = LayerNorm2d()
        
        # skip-init trick to stabilize training.
        self.beta = nn.Parameter(torch.zeros((1,c,1,1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1,c,1,1)), requires_grad=True)
        
    def forward(self, inp):
        x = inp 
        
        x = self.norm(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.sg(x)
        x = x * self.sca(x)
        x = self.conv3(x)
        
        y = inp + x * self.beta
        
        x = self.norm2(y)
        x = self.conv4(x)
        x = self.sg(x)
        x = self.conv5(x)
        
        y = y + x * self.gamma
        return y
来源:AIWalker
作者:HappyAIWalker

推荐阅读

本文章著作权归作者所有,任何形式的转载都请注明出处。更多动态滤波,图像质量,超分辨相关请关注我的专栏AIWalker
推荐阅读
关注数
6194
内容数
191
夯实深度学习知识基础, 涵盖动态滤波,超分辨,轻量级框架等
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息