V · 2023年07月21日 · 北京市

使用Cleanlab、PCA和Procrustes可视化ViT微调

与传统的卷积神经网络不同,vit使用最初设计用于自然语言处理任务的Transformers 架构来处理图像。微调这些模型以获得最佳性能可能是一个复杂的过程。

下面是使用动画演示了在微调过程中嵌入的变化。这是通过对嵌入执行主成分分析(PCA)来实现的。这些嵌入是从处于不同微调阶段的模型及其相应的检查点生成的。

在本文中,我们将介绍如何创建这样一个动画,主要包括:微调、创建嵌入、异常值检测、PCA、Procrustes、创建动画。

微调

第一步是对预训练好的ViT模型进行微调。为了简单起见我们使用了CIFAR-10数据集,其中包含6万张图像,分为10个不同的类别

微调代码很简单,我们这里主要就是在微调时增加日志记录

 from transformers import TrainerCallback
 
 class PrinterCallback(TrainerCallback):
     def on_log(self, args, state, control, logs=None, **kwargs):
         _ = logs.pop("total_flos", None)
         if state.is_local_process_zero:
             if len(logs) == 3:  # skip last row
                 with open("log.csv", "a") as f:
                     f.write(",".join(map(str, logs.values())) + "\n")

通过在TrainingArguments中设置save_strategy="step"和一个较低的save_step值来增加检查点的保存间隔是很重要的,这样可以确保动画有足够的检查点。动画中的每一帧对应一个检查点。在训练期间为每个检查点和CSV文件创建一个文件夹

创建嵌入

我们使用Transformers库中的AutoFeatureExtractor和autommodel来使用不同的模型检查点中生成嵌入。

每个嵌入是一个768维向量,测试图像总计有10,000个。生成的这些嵌入与检查点存储在同一个文件夹中

提取离群值

我们可以使用Cleanlab库提供的OutOfDistribution类,根据每个检查点的嵌入来识别离群值,可以识别出动画的前10个离群值。这些值也就是我们所说的分类错误的特征,对我们研究模型是非常有用的

 from cleanlab.outlier import OutOfDistribution
 
 def get_ood(sorted_checkpoint_folder, df):
   ...
   ood = OutOfDistribution()
   ood_train_feature_scores = ood.fit_score(features=embedding_np)
   df["scores"] = ood_train_feature_scores

PCA和Procrustes

使用scikit-learn包的主成分分析(PCA),我们通过将768维向量减少到2维来可视化二维空间中的嵌入。当为每个时间步重新计算PCA时,由于轴翻转或旋转,可能会出现动画中的大的条约,这样显示效果很不好。所以为了解决这个问题,我们还从SciPy包中应用了一个额外的Procrustes Analysis,以几何方式将每一帧转换为最后一帧,这只涉及平移、旋转和均匀缩放。这使得动画中的过渡更加平滑。

 from sklearn.decomposition import PCA
 from scipy.spatial import procrustes
 
 def make_pca(sorted_checkpoint_folder, pca_np):
   ...
   embedding_np_flat = embedding_np.reshape(-1, 768)
   pca = PCA(n_components=2)
   pca_np_new = pca.fit_transform(embedding_np_flat)
   _, pca_np_new, disparity = procrustes(pca_np, pca_np_new)

使用Spotlight进行检查

在完成整个动画之前,可以在Spotlight中进行最后的检查。我们用第一个和最后一个检查点来执行嵌入生成、PCA和异常值检测。在Spotlight中加载结果DataFrame如下:

创建动画

通过使用make_pca(…)和get_ood(…)函数对每个模型的检查点创建一个图表,它们分别生成代表嵌入的2D点并提取前8个异常值。2D点用对应于它们各自类别的颜色绘制。异常值是根据他们的分数排序的,最后的训练损失从CSV文件加载并绘制的线形图。

最后,图像使用imageio或类似的库编译成GIF。

总结

本文介绍了如何创建视ViT模型的微调过程可视化。我们通过生成和分析嵌入、可视化结果以及创建将这些元素结合在一起的动画的步骤。

创建这样的动画不仅有助于理解微调ViT模型的复杂过程,而且还可以作为向他人传达这些概念的强大工具。

本文的源代码:

https://avoid.overfit.cn/post/96c2cedd55204af687ea63cfee149dd0

推荐阅读
关注数
4189
内容数
867
SegmentFault 思否旗下人工智能领域产业媒体,专注技术与产业,一起探索人工智能。
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息