ronghuaiyang · 2021年04月16日

融合EfficientNet和YoloV5,非常实用的物体检测二阶段pipeline

首发:AI公园公众号
作者:Mostafa Ibrahim
编译:ronghuaiyang

导读

使用EfficientNet和YoloV5的融合可以提升20%的performance。

image.png
在本文中,我将解释上一篇文章中称之为“2 class filter”的概念。这是一种用于目标检测和分类模型的综合技术,在过去几周我一直在做的Kaggle比赛中被大量使用。几乎所有参加比赛的人都使用了这种技术,它似乎可以提高大约5-25%的性能,这是非常有用的。

目标检测:YoloV5

我们首先在我们的数据集上训练YoloV5模型,同时使用加权框融合(WBF)进行后处理/预处理,如果你想了解更多,我建议查看这两篇文章:

1、Kaggle竞赛中使用YoloV5将物体检测的性能翻倍的心路历程

2、WBF:优化目标检测,融合过滤预测框

我不想再深入讨论使用WBF训练YoloV5的细节。但是,你需要做的基本上就是使用WBF消除重复的框,然后对数据进行预处理,在其上运行YoloV5。YoloV5需要一个特定的层次结构来显示数据集,以便开始训练和评估。

分类:EfficientNet

接下来要做的是在数据集上训练一个分类网络。但是,有趣的一点是,虽然目标检测模型在14个不同的类(13个不同类型的疾病和1个无疾病类)上训练,但我们只在2个类(疾病和无疾病)上训练分类网络。你可以认为这是一种建模方法,简化了我们的分类问题,因为2分类网络比14分类容易得多,当我们融合这两个网络时,我们真的不需要每一个疾病的细节,我们将只需要一个2分类。当然,对于你的问题,这可能有点不同,因此你可能需要试验不同的设置,但是希望你能从本文中获得一些想法。

目前最先进的分类网络之一是EfficientNet。对于这个数据集,我们将使用使用Keras (TensorFlow)训练的B6 EfficientNet,以及这些扩展:

(    rescale=1.0 / 255,    rotation_range=40,    width_shift_range=0.2,    height_shift_range=0.2,    shear_range=0.2,    zoom_range=0.2,    horizontal_flip=True,    fill_mode="nearest",)

集成

这就是使用2分类过滤器来提高性能的原因,也是本文真正要讨论的内容。关于训练YoloV5和EfficientNet,我不想说太多,因为有很多资源可以提供给他们。

我想强调的主要思想是,尽管Yolo的分类预测非常好,但如果你可以将它们与另一个更强大的网络的分类混合在一起,你可以获得相当不错的性能提升。让我们看看这是如何实现的。这里使用的想法是设置一个高阈值和一个低阈值。然后我们要检查每个分类预测。如果概率小于低阈值,我们将预测设置为“无疾病”。回想一下,我们最初的问题是对14种疾病中的一种进行分类,或者对“无疾病”进行分类。这个低阈值可以是0到1之间的任何值,但可能是0到0.1之间的某个值。此外,如果分类预测在低阈值和高阈值之间,我们得到一个“No Disease”的预测,该预测具有EfficientNet的置信度(不是Yolo)。最后,如果分类预测高于高阈值,我们什么也不做,因为这意味着网络是高度自信的。

可以这样实现:

low_thr  = 0.08high_thr = 0.95def filter_2cls(row, low_thr=low_thr, high_thr=high_thr):    prob = row['target']    if prob<low_thr:        ## Less chance of having any disease        row['PredictionString'] = '14 1 0 0 1 1'    elif low_thr<=prob<high_thr:        ## More chance of having any disease        row['PredictionString']+=f' 14 {prob} 0 0 1 1'    elif high_thr<=prob:        ## Good chance of having any disease so believe in object detection model        row['PredictionString'] = row['PredictionString']    else:        raise ValueError('Prediction must be from [0-1]')    return row

最后的思考

在比赛期间,我已经在各种不同的场景和模型上试验了这2分类过滤器,它似乎总能提高25%的性能,这是令人惊讶的。我认为如果你想把它应用到你的自定义场景中,你需要考虑在哪些情况下分类网络预测可以帮助你的目标检测模型。这并不完全是交换预测的执行度,而是用一种聪明的方式“融合”它们。

—END—

英文原文:https://towardsdatascience.co...

推荐阅读

关注图像处理,自然语言处理,机器学习等人工智能领域,请点击关注AI公园专栏
欢迎关注微信公众号
AI公园 公众号二维码.jfif
推荐阅读
关注数
8244
内容数
210
关注图像处理,NLP,机器学习等人工智能领域
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息