CVPR2022:计算机视觉中长尾数据平衡对比学习

【前言】

现实中的数据通常存在长尾分布,其中一些类别占据数据集的大部分,而大多数稀有样本包含的数量有限,使用交叉熵的分类模型难以很好的分类尾部数据。在这篇论文中,作者专注不平衡数据的表示学习。通过作者的理论分析,发现对于长尾数据,它无法形成理想的几何结构(在下文中解释该结构)。为了纠正 SCL(Supervised Contrastive Learning,有监督对比学习) 的优化行为并进一步提高长尾视觉识别的性能,作者提出了一种新的BCL(Balanced Contrastive Learning,平衡对比学习)损失。

与 SCL 相比,作者在 BCL 上有两个改进:

  • 一个是平衡负样本的梯度贡献,称为类平均;
  • 一个是使所有类都出现在每个mini-batch中,称为类补充。

image.png
Paper:https://openaccess.thecvf.com/content/CVPR2022/papers/Zhu_Balanced_Contrastive_Learning_for_Long-Tailed\_Visual_Recognition_CVPR_2022_paper.pdf

Code:https://github.com/FlamieZhu/BCL

一、过去的方式

为了解决数据不均衡的问题,早期的方法有:

  1. 对训练数据进行重新采样,低采样高频类或过采样低频类;
  2. 采用加权计算损失函数的方式来关注稀缺类别,为每个类或每个例子的不同的训练样本分配不同的损失;

最近也有一些新研究,如logits补偿方式来校准分布的数据。解耦则是采取一种二阶的训练方案,其中分类器在第二阶段被重新加强训练。然而,对比学习的方法却很少被探索。

二、长尾识别对比学习

2.1 监督对比损失

image.png

image.png

2.2 简单正则构体:

image.png

2.3 分析

为了分析SCL在长尾数据上的问题,作者主要关注由每个类形成的几何构体产生的变化。当监督对比损失在平衡数据集上达到最小值时,每个类的表示会内接到简单正则构体的顶点,但SCL在长尾数据上会形成不对称分布。

2.4 进一步探究

作者分析了损失的极小值,由于直接计算整个长尾数据集的下界是比较困难的,所以只关注某个特定小批处理的损失,更加合理的损失函数应该需要被表示成:

image.png

上述 SCL 损失由排斥项和吸引项组成。随着训练不断进行,吸引项会导致所有类内表示最终都会趋向于到它们的类均值。这意味着无论数据集是否平衡,同一类内的样本最终都会尽可能接近。排斥项影响类间均匀性并由高频类主导,但 SCL得到的特征难以分离,说明数据不平衡主要影响排斥项。并显然,排斥项与小批量中出现的类数据分布密切相关。当数据集长尾分布时,作者对每个 mini-batch 采样都是不均衡的,这导致了头类在排斥项中占主导地位,并使每个样本离头部更远。此外,对于每个样本,来自头类的梯度将远大于尾类。这不可避免地导致损失函数的优化更加集中在头类上,并导致不对称的几何构体。
image.png

2.5 解决方案

image.png

2.6 平衡对比学习

一种包含两部分,一个是类平均,一个是类补充

1. 类平均: 关键思想是在一个小批量中平均每个类的实例,以便每个类对优化有一个近似相近的贡献(就是减少分母中头类的比例),作者给出L1,L2,L3三种损失形式如下:

image.png

当表示的正类被平均时,左部权重分母将减一。  L1 和 L2 之间的唯一区别是平均操作发生在不同的位置。L1 是在指数函数外进行平均,而 L2 在指数函数内进行平均。L3 中每个样本被拉向其原类并远离其他类,以下为类平均的代码实现。

2. 类补充: 为了让所有类都出现在每个 mini-batch 中,作者引入了类中心表示,即平衡对比学习的原型,公式如下:

image.png

`class BalSCL(nn.Module):  
    def __init__(self, cls_num_list=None, temperature=0.1):  
        super(BalSCL, self).__init__()  
        self.temperature = temperature  
        self.cls_num_list = cls_num_list  
  
    def forward(self, centers1, features, targets, ):  
  
        device = (torch.device('cuda')  
                  if features.is_cuda  
                  else torch.device('cpu'))  
        batch_size = features.shape[0]  
        targets = targets.contiguous().view(-1, 1)  
        targets_centers = torch.arange(len(self.cls_num_list), device=device).view(-1, 1)  
        targets = torch.cat([targets.repeat(2, 1), targets_centers], dim=0)  
        batch_cls_count = torch.eye(len(self.cls_num_list))[targets].sum(dim=0).squeeze()  
  
        mask = torch.eq(targets[:2 * batch_size], targets.T).float().to(device)  
        logits_mask = torch.scatter(  
            torch.ones_like(mask),  
            1,  
            torch.arange(batch_size * 2).view(-1, 1).to(device),  
            0  
        )  
        mask = mask * logits_mask  
          
        # class-complement  
        features = torch.cat(torch.unbind(features, dim=1), dim=0)  
        features = torch.cat([features, centers1], dim=0)  
        logits = features[:2 * batch_size].mm(features.T)  
        logits = torch.div(logits, self.temperature)  
  
        # For numerical stability  
        logits_max, _ = torch.max(logits, dim=1, keepdim=True)  
        logits = logits - logits_max.detach()  
  
# class-averaging  
        exp_logits = torch.exp(logits) * logits_mask  
        per_ins_weight = torch.tensor([batch_cls_count[i] for i in targets], device=device).view(1, -1).expand(  
            2 * batch_size, 2 * batch_size + len(self.cls_num_list)) - mask  
            # 计算每个mini batch的均值作为右部分前系数权重  
        exp_logits_sum = exp_logits.div(per_ins_weight).sum(dim=1, keepdim=True)  
        # 上位除以下位  
          
        log_prob = logits - torch.log(exp_logits_sum)  
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)  
        # 左半系数权重  
          
        loss = - mean_log_prob_pos  
        loss = loss.view(2, batch_size).mean()  
        return loss  
`

image.png

左:将Anchor Sample与其他样本进行对比。右:对特定类同时应用类平均和类补充。作者对 Anchor Sample 和 Batch sample 以及 Class prototype 之间的相似性进行平均。请注意,蓝色的类不会出现在 mini-batch 中,因此可以直接将 anchor 与其prototype的相似度作为结果。

3. 架构: 所提出的架构如下图所示。由两个主要部分构成:分类分支和对比学习分支。两个分支同时训练并共享相同的特征提取器。BCL 是一个端到端模型,不同于传统两阶段训练策略的对比学习方法。作者对这两个分支有不同的扩充方法。总共生成了三个不同的视图,其中 v1 是用于分类任务的视图,v2 和 v3 是对比学习的成对视图。

image.png
v2 和 v3 采用了与 v1 不同的相同增强方法,两个分支共享主干的特征。分类器权重由 MLP 单独转换以用作 prototypes。所有的表示都经过 ℓ2 归一化,以实现对比损失的平衡。

4. 使用 Logit补偿进行优化: 对于长尾学习任务,由于数据的不平衡性,最后一个分类层的输出 logit 通常存在偏差。Logit 补偿旨在消除由数据不平衡引起的偏差。补偿可以应用在训练或测试期间,概括为以下形式:

image.png

三、实验

3.1 实验细节

对于 CIFAR-10-LT 和 CIFAR-100-LT(LT指long tail),作者使用 ResNet-32 作为主干。使用 AutoAugment 和 Cutout 作为分类分支的数据增强策略,使用 SimAugment 作为对比学习分支的数据增强策略。为了控制 和 的影响,λ 设置为 2.0,μ 为 0.6,温度T设置为 0.1。将批量大小设置为 256,权重衰减为 5e−4。MLP 的隐藏层和输出层的维度分别设置为 512 和 128。运行 BCL 200 个 epoch,学习率在前 5 个 epoch 将设为 0.15,并在 epoch 160 和 180 处以 0.1 的步长衰减。使用 Nvidia GeForce 1080Ti GPU 训练上述模型。

对于 ImageNet,使用 ResNet-50 和 ResNeXt50 作为主干。运行 BCL 90 个 epoch,初始学习率为 0.1,权重衰减为 5e-4。对于 iNaturalist,使用 ResNet-50 作为主干,并使用 0.2 的初始学习率和 1e-4 的权重衰减运行 BCL 100 个 epoch。两类数据集上的学习率使用余弦下降,λ 设置为 1.0,μ 设置为 0.35。Batch设置为 256。对分类分支使用 RandAug 增强策略,对比学习分支使用 SimAug。所有模型都使用 SGD 优化器进行训练,动量设置为 0.9。

3.2 实验过程

  1. 首先比较不同类平均实现(即 L1、L2 和 L3)的性能,所有实验均在 CIFAR-100 上进行。L1 和 L2 之间的主要区别在于执行平均操作的顺序。对于 L3,使用prototype而不是同一类的平均值。如下表 所示,L1 取得了最佳性能,这与作者在之前的分析一致。令人惊讶的是,L3 实现了比 L2 更好的性能,这可能归因于prototype的良好表征特性。

image.png

  1. 为了证明平衡对比损失函数的优越性,作者在下表中比较了不同损失的性能。使用带有 logit 补偿 (LC) 的交叉熵损失作为基线。SC 表示添加对比学习分支和常规监督对比损失的基线。类补充和类平均是所提出的平衡对比损失函数的主要思想。结果表明,单独使用类补充或类平均不能提高整体准确性。 相比之下,两者同时应用可以获得显着的性能提升,这表明这两个策略都是实现更强性能不可或缺的组件。
    image.png

3.3 实验结果

1. CIFAR-LT: BCL 与其他现有方法在 CIFAR 上的比较结果如下表所示。从表中可以看出,BCL 始终优于其他方法。作者注意到 BCL 和 Hybrid-SC 之间的准确度差距随着数据不平衡程度的降低而减小。这一结果主要是由于当不平衡问题更严重时,传统的监督对比损失会导致表示学习中的偏差更严重。

image.png

2. ImageNet-LT:  表 5 和表 6 列出了 ImageNet-LT 上的结果,与 Balanced Softmax 相比,通过根据类频率调整 prediction 来添加 logit 补偿,BCL 在所有分组上的效果都显着优于 Balanced Softmax 。 LWS 、T-norm 和 DisAlign 采用 two-stage 学习策略。这些方法侧重于在第二阶段对分类器进行微调,而忽略了表示学习阶段隐含的偏差。PaCo 在监督对比学习中使用了一组参数中心,这些中心被分配了更大的权重,可以看作是分类器的权重。但是,BCL 中使用的 prototypes 补充了每个类的样本,以确保所有类都出现在每个 mini-batch 中。与 PaCo 相比,BCL 实现了 57.1% 的准确率,头类和低频类的准确率显着提高。

image.png
image.png
3. iNaturalist 2018: 表 5 显示了在 iNaturalist 2018 上的实验结果。由于 BCL 是一种对比学习方法,可以从更长的训练时间中获得更多的受益。 然而,为了公平比较,作者仅训练了各种模型 100 个 epoch 的结果。 Hybrid-SC 和 Hybrid-PSC 是对比学习方法,由于表示学习中产生的潜在偏差,它们的性能不如 BCL。与基于集成模型的 RIDE 策略相比,BCL 也始终表现出更好的性能,整体准确率达到 71.8%。

image.png

总结

在这项工作中,作者从表示学习的角度研究了长尾问题,提供了深入的分析来证明现有的监督对比学习对于长尾数据形成了一种分布不友好的不对称几何构体。为了解决不平衡的数据表示学习问题,作者研究出一种平衡的对比损失 BCL 。除了 BCL 之外,作者还采用了一个带有 logit 补偿的分类分支来解决分类器产生偏差的问题(个人感觉类似于 model ensemble)。并在 CIFAR、ImageNet-LT 和 iNaturalist 2018 的长尾数据上进行了广泛的实验,结果充分证明了 BCL 与现有长尾学习方法相比具备更好的性能。

陈er
GiantPandaCV

推荐阅读

三星提出XFormer | 超越MobileViT、DeiT、MobileNet等模型
万字长文 | 手把手教你优化轻量姿态估计模型(算法篇)
超越 ConvNeXt、RepLKNet | 看 51×51 卷积核如何破万卷!
一种融合卷积的ViT模型

更多嵌入式AI相关技术干货请关注嵌入式AI专栏。欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。
1 阅读 316
推荐阅读
关注数
14912
内容数
809
嵌入式端AI,包括AI算法在推理框架Tengine,MNN,NCNN,PaddlePaddle及相关芯片上的实现。欢迎加入微信交流群,微信号:aijishu20(备注:嵌入式)
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息