AI学习者 · 2024年08月09日

ArcFace的原理以及代码的理解

最近因需要粗浅的学习了一下ArcFace损失函数,由于在学习中遇到了很多问题,特将问题的思考分享出来,权当分享个人愚见,希望可以有人看到后进行讨论进步。

ArcFace的引入

人脸识别分为四个过程:人脸检测、人脸对齐、特征提取、特征匹配。其中,特征提取作为人脸识别最关键的步骤,提取到的特征更偏向于该人脸“独有”的特征,对于特征匹配起到举足轻重的作用,而我们的网络和模型承担着提取特征的重任,优秀的网络和训练策略使得模型更加健壮。

但在Resnet网络表现力十分优秀的情况下,要提高人脸识别模型的性能,除了优化网络结构,修改损失函数是另一种选择,优化损失函数可以使模型从现有数据中学习到更多有价值的信息。

而在我们以往接触的分类问题有很大一部分使用了Softmax loss来作为网络的损失层,实验表明Softmax loss考虑到样本是否能正确分类,而在扩大异类样本间的类间距离和缩小同类样本间的类内距离的问题上有很大的优化空间,因而作者在Arcface文章中讨论了Softmax到Arcface的变化过程,同时作者还指出了数据清洗的重要性,改善了Resnet网络结构使其“更适合”学习人脸的特征。

image.png

对特征提取和分类的个人理解

首先我们思考在分类层前全连接层的意义是什么,全连接层可以视为一个权重矩阵W和网络模型提取到的特征X(我们可以理解为通过全连接层之前的网络结构并且已经进行过flatten的特征)相乘的过程。即为一个W1 * X的过程。那么这个相乘操作的物理意义是什么呢,此时我们可以回忆向量的点乘,向量的点乘即为两个向量的模(常数)的乘积再乘上他们之间夹角的cos值,它的物理意义是两个向量之间的相似度大小。

我们来看一个例子:

image.png

这个分类任务的目的是为了区分输入图像为笔记本电脑还是平板电脑。我们假设通过网络模型提取到了

第I个样本的特征featurei,一共有五个特征。全连接层的操作就是将[2,5]的权重矩阵,乘上这个[5,1]的特征矩阵,得到[2,1]的分类结果矩阵。(此时为logits因为还没有经过Softmax层)。我们可以理解全连接层的权重就是这个样本的“标准特征向量“而提取到的特征向量与”标准特征向量“进行点乘其实是计算出了,第i个样本提取特征和分类Ci的标准特征向量的相似度,所以我们取Ci相似度最大的结果作为最后的分类结果。

我们通常的操作就是将全连接层和提取特征向量的乘积结果送入全连接层,得到一个sum为1的概率向量,取向量中概率最大的index作为分类结果。

但是这样的分类,我们只能得到类似下图的分类结果:

image.png

这种结果只能让不同类别(用颜色表示)简单分开,并不能拉大类别之间的距离,减小类别内样本之间的距离。

这样的简单分类不适合做人脸识别的任务(我们可以思考一下,如果仅仅使用softmax完成如图所示的分类效果,如果存在双胞胎这种两个人长的很像的类型,类别之间距离不够,便很难将其分开)于是ArcFace出现了。

ArcFace的推导与遇到问题的个人见解

image.png

image.png

image.png

下面是ArcFace是数学推导,推导前我们要注意一个问题就是L2归一化,即为在推导过程中将W和X化为1的计算过程,L2归一化是将向量内的每个元素除以向量的L2范数的过程。

image.png

  • 代码实现(基于pytorch)
class ArcMarginModel(nn.Module):
    def __init__(self, m=0.5,s=64,easy_margin=False,emb_size=512):
        super(ArcMarginModel, self).__init__()
 
        self.weight = Parameter(torch.FloatTensor(num_classes, emb_size))
        # num_classes 训练集中总的人脸分类数
        # emb_size 特征向量长度
        nn.init.xavier_uniform_(self.weight)
        # 使用均匀分布来初始化weight
 
        self.easy_margin = easy_margin
        self.m = m
        # 夹角差值 0.5 公式中的m
        self.s = s
        # 半径 64 公式中的s
        # 二者大小都是论文中推荐值
 
        self.cos_m = math.cos(self.m)
        self.sin_m = math.sin(self.m)
        # 差值的cos和sin
        self.th = math.cos(math.pi - self.m)
        # 阈值,避免theta + m >= pi
        self.mm = math.sin(math.pi - self.m) * self.m
 
    def forward(self, input, label):
        x = F.normalize(input)
        W = F.normalize(self.weight)
        # 正则化
        cosine = F.linear(x, W)
        # cos值
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        # sin
        phi = cosine * self.cos_m - sine * self.sin_m
        # cos(theta + m) 余弦公式
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
            # 如果使用easy_margin
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        one_hot = torch.zeros(cosine.size(), device=device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        # 将样本的标签映射为one hot形式 例如N个标签,映射为(N,num_classes)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        # 对于正确类别(1*phi)即公式中的cos(theta + m),对于错误的类别(1*cosine)即公式中的cos(theta)
        # 这样对于每一个样本,比如[0,0,0,1,0,0]属于第四类,则最终结果为[cosine, cosine, cosine, phi, cosine, cosine]
        # 再乘以半径,经过交叉熵,正好是ArcFace的公式
        output *= self.s
        # 乘以半径
        return output
  • 个人遇到的主要问题以及查找和思考

1.参数s和m具体代表什么:

通过ArcFace,分类结果可以”进化“为

image.png

这种样子,我们把分类的可视化结果视为一个圆,s就代表这个圆的半径,m则可以调整类别之间的夹角距离(?)

2.代码中这easy_margin部分的意义(为什么需要这两行代码):

image.png

首先没有查到easy_margin相关的资料,希望有人可以指点下作者这部分相关。
这个部分的代码主要意义是为了保持Cos的单调性,那么我们首先思考为什么要保持Cos这个函数的单调性。因为在ArcFace中,我们将特征向量和“类别标准向量”的相似度衡量标准从点乘结果转变为了仅仅看两者之间Cos“夹角”(此处的夹角表达意思不完全准确,仅供此部分的理解所需)的值。根据余弦函数的特点,当角度超过Pi时,余弦函数会丢失单调性特征。但是我们在衡量相似度时,所用的夹角是建立在余弦函数的单调性之上的,比如夹角(>0)时,夹角越大,余弦值越小,因此我们就可以说余弦值小的两个向量,相似度较小。但是一旦丢失单调性,这种理论基础便不复存在了。因此我们需要cosine-self.mm的这个操作在Cos(theta+m) > Pi 的时候进行代替,强制使其小于Pi

  • 未解决的问题:

    1.为什么减去的值是

在这里插入图片描述

以上内容均为本人个人理解,不代表准确立场,希望大家在评论区指出错误,一起讨论问题。

作者:火车切片
文章来源:CSDN

推荐阅读

更多芯擎AI开发板干货请关注芯擎AI开发板专栏。欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。


  1. T
推荐阅读
关注数
10834
内容数
80
搭载基于安谋科技自研“周易”NPU的芯擎科技工业级“龍鹰一号”SE1000-I处理器
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息