元峰 · 2020年05月28日

Pytorch中交叉熵Loss趣解

最近一直在总结Pytorch中Loss的各种用法,交叉熵是深度学习中最常用的计算方法,写这个稿子把交叉熵的来龙去脉做一个总结。
作者:元峰
来源:AIZOO

什么是交叉熵

信息量

引用百度百科中信息量的例子来看,

在日常生活中,极少发生的事件一旦发生是容易引起人们关注的,而司空见惯的事不会引起注意,也就是说,极少见的事件所带来的信息量多。如果用统计学的术语来描述,就是出现概率小的事件信息量多。因此,事件出现得概率越小,信息量愈大。即信息量的多少是与事件发生频繁(即概率大小)成反比。

故越小概率的事情发生的事件本身具有的信息量就越大。例如在去年夏天,小卡拒了湖人投奔快船还捎带打劫雷霆了一个泡椒,这种闷声大发财的事情就有很大的信息量。

image.png

信息量的计算公式为:
image.png

信息熵

理解了信息量之后,信息熵的理解也就不再困难了。熵原本是热力学中的一个概念,是用来衡量混乱程度的物理量。信息熵则是借用热力学的概念,衡量在事件发生前对于产生信息量的期望。即信息量是确定的具体事件发生后的信息的度量,信息熵是事件发生前预估的期望。

信息熵的计算公式为:
image.png

可以看到信息熵是一个求和的函数,是求得信息量的期望。还是以小卡为例,小卡转会前,假设去湖人的概率是0.4,去其他30支球队的概率分别为image.png(计算方便),猛龙概率为0(心疼...),那么小卡转会的信息熵为
image.png

所以小卡转会这个事件预计的信息量为image.png,但是实际小卡去了快船,实际的信息量为。因为这是一个非常轰动的事件,所以实际的信息量大于了估计所得的期望。

image.png

KL散度与交叉熵

理解了信息量和信息熵之后,接下来就是交叉熵的概念了。介绍交叉熵之前,Loss是绕不开的。Loss的通俗解释就是预测值和真实值的差异,然后有各种各样的方法来衡量这个差异有多大,本文所介绍的交叉熵也是一种衡量Loss的方法。

KL散度

在讲交叉熵之前,有一个类似的东西叫KL散度,KL散度是用来衡量两个分布之间差异的指标,计算公式为
image.png
公式里面是真实值,是预测值,如果与相同时,即两者之间没有差异。

交叉熵

现在我们将KL散度的公式进行变形,
image.png

其中image.png是真实值的信息熵,第二项

image.png

就是多分类的交叉熵。因此KL散度也被成为相对熵。对于二分类而言,交叉熵为
image.png

二分类交叉熵

Pytorch总共提供了两种二分类交叉熵,一种是nn.BCELoss,另一种是nn.BCEWithLogitsLoss,这两个的差别非常细微,nn.BCEWithLogitsLoss=nn.Sigmoid+nn.BCELoss。这里结合Pytorch的代码做一下验证,首先先验证nn.BCELoss

image.png

对于nn.BCEWithLogitsLoss而言,使用的代码为

image.png

可以看到两者的输入完全相同,输出nn.BCEWithLogitsLoss完全等于nn.BCELoss加上nn.Sigmoid

多分类交叉熵

对于多分类交叉熵函数而言,一般使用nn.CrossEntropyLoss,该函数的计算流程为:

  1. 在输入值上施加nn.Softmax函数
  2. 对于第一步所得结果使用log函数,将较为耗时的乘法运算改为加法运算,并将其归一化到image.png之间
  3. 将第二步所得输出输入nn.NLLLoss函数中,nn.NNLLLoss的作用就是接受负对数似然值,然后对其求平均。

具体的案例在下一节的CIFAR-10的分类问题中。

实际应用

分类问题

这里我们使用Pytorch自带的CIFAR-10的数据集进行分类,训练的网络为

image.png

网络结构如下图所示:

image.png

使用的loss为nn.CrossEntropyLoss,使用的优化器为SGD优化器,batch size为4,分类的图片类别为10类。

网络的输入为:

image.png

网络的输出为

image.png

输出是一个4×10的矩阵,对应的label为,

image.png

将网络直接的输出输入到nn.softmax可得,并验证加和结果为

image.png

接下来,将nn.softmax的输出输入torch.log可以得到

image.png

最后将log_outputs通过nn.NLLLoss并与nn.CrossEntropy对比

image.png

可以发现经过组合的loss和直接用nn.CrossEntroyLoss得到的loss是一样的。

下一次推送我们将会解析一下Kaiming大神的Focal Loss。

参考文献

[1] https://gombru.github.io/2018...\_entropy\_loss/

[2] https://www.baidu.com/link?ur...\_TPjZ8WbzU3im5Hq1JstcfLngNj4y0P5H4gC9lAhGLWnTBAgoucSnBu-Ek\_fwM-RuyWSOfPxv4Idbxr0hm-udxOVd3Yz4rFgPymoQpsOb8\_UsSmub-I&wd=&eqid=b0181d0f0002cd51000000045eccec89

[3] https://pytorch.org/docs/mast...

[4] https://pytorch.org/docs/mast...

相关文章

关于作者

我是元峰,互联网+AI领域的创业者,欢迎在微信搜索“AIZOO”关注我们的公众号AIZOO。

如果您是有算法需求,例如目标检测、人脸识别、缺陷检测、行人检测的算法需求,欢迎添加我们的微信号AIZOOTech与我们交流,我们团队是一群算法工程师的创业团队,会以高效、稳定、高性价比的产品满足您的需求。

如果您是算法或者开发工程师,也可以添加我们的微信号AIXZOOTech,请备注学校or公司名称-研究方向-昵称,例如“西电-图像算法-元峰”,元峰会拉您进我们的算法交流群,一起交流算法和开发的知识,以及对接项目。

WX20200219-104846.png

关注元峰微信号“AIZOOTech”

更多算法模型相关请关注AIZOO专栏
推荐阅读
关注数
225
内容数
48
AIZOO.com 致力于搭建AI开发者、AI公司与需求方的桥梁,打造中国最大的算法和产品商城。传播AI领域资讯和技术,展示AI算法和产品。
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息