V · 1月2日

分布匹配蒸馏:扩散模型的单步生成优化方法研究

扩散模型在生成高质量图像领域具有显著优势,但其迭代去噪过程导致计算开销较大。分布匹配蒸馏(Distribution Matching Distillation,DMD)通过将多步扩散过程精简为单步生成器来解决这一问题。该方法结合分布匹配损失函数和对抗生成网络损失,实现从噪声图像到真实图像的高效映射,为快速图像生成应用提供了新的技术路径。

分布匹配机制

与传统扩散模型不同,单步生成器并不直接学习完整的数据分布,而是通过强制对齐的方式逼近目标分布。这种方法摒弃了逐步近似的过程,直接建立噪声样本到目标分布的映射关系。

在此过程中,蒸馏机制起到关键作用。预训练模型作为教师网络,提供目标分布的高精度中间表征。

DMD 技术实现流程

阶段 0:系统初始化

  1. 单步生成器基于预训练扩散 unet 进行初始化,时间步设定为 T-1
  2. real_unet 作为固定权重的教师网络,表征真实数据分布
  3. fake_unet 用于对生成器的数据分布进行建模

阶段 1:噪声到图像的生成

生成器接收随机噪声图作为输入,通过单步去噪操作生成图像 x,此时生成的图像 x 符合生成器的概率密度分布 p_fake

阶段 2:高斯噪声注入

对生成图像 x 施加高斯噪声,获得噪声图像 xt,在 0.2T0.98T 范围内均匀采样时间步 t(避开极端噪声状态),噪声注入操作促进 p_fakep_real 分布的重叠,为后续分布比较创造条件

阶段 3:双重网络处理

  1. real_unet 生成 pred_real_image,作为清晰图像的参考近似
  2. fake_unet 生成 pred_fake_image,反映当前时间步的生成器分布特征

通过对比 pred_real_imagepred_fake_image 量化真实分布与生成分布的差异

阶段 4:损失计算

计算 x 与 x — grad 之间的均方误差(MSE)作为损失度量。其中 x — grad 表示经过梯度校正的输出,用于减小与真实数据分布的偏差。

阶段 5:假分布更新机制

fake_unet 通过 x 和 pred_fake_image 之间的扩散损失进行参数更新。这一过程使 fake unet 能够追踪生成器分布的动态变化。与传统 unet 使用 xt-1_pred 和 xt-1_gt 计算损失不同,这里采用 xt-1_pred 和 x 之间的损失,使 fake UNet 能够将生成器输出的噪声版本(xt)还原为当前生成器输出 x。

核心问题解析

问题 1: 为何 fake_unet 采用 xt-1_pred 和 x0 之间的散度作为损失度量,而非采用 xt-1_pred 和 xt-1_gt 的比较?

选择 xt-1_pred 和 x 之间的散度是基于 fake_unet 的核心功能考虑。其目标是将生成器输出的噪声版本(xt)映射回生成器的当前输出(x)。这种设计确保了 fake_unet 能够准确捕获生成器的动态分布特征,从而提供有效的梯度信息来优化生成器输出。

问题 2:fake_unet 的必要性何在?是否可以直接利用预训练的 real_unet 输出与生成器输出计算 KL 散度?

生成器的设计目标是实现单步完全去噪,而预训练的 real_unet 在相同时间步内仅能实现部分去噪。这种本质差异导致 real_unet 输出无法提供有效的 KL 散度用于生成器训练。相比之下,fake_unet 通过持续学习生成器的动态分布,能够准确approximation当前生成器输出的特征。通过比较 real_unetfake_unet 的输出,可以获得用于优化生成器概率分布的有效梯度方向,从而提升单步图像合成的质量。# 分布匹配损失机制

训练过程中,通过 KL 散度定量评估生成器分布与真实分布之间的差异。

其中 Preal 代表真实数据的概率密度函数,Pfake 表示生成器 Gθ 产生的假分布概率密度函数。

对于高维数据集,直接计算概率密度在计算复杂度上存在显著挑战。例如,对于 32×32 像素的灰度图像,其维度空间为 256¹⁰²⁴,直接计算在实际应用中不可行。

因此,采用分数函数对真实分布和生成分布进行特征表征。

这种方法使得 KL 散度的计算成为可能:Sreal 引导 x 向 Preal 的模态靠近,而 −Sfake 则促使其远离真实分布。

其中 Sreal(x) 为真实数据分布的分数函数,Sfake(x) 为生成数据分布的分数函数,∇θ Gθ(z) 表示生成器输出 x 对参数的梯度。

Sreal(x)−Sfake(x) 表征了真实分数与生成分数的差异。对于生成样本 x,由于其 Sreal 接近零,需要引入扰动以支持扩散模型从 xt 进行去噪。

Sfake 和 Sreal 的定义参考自论文 "Song et al. — Score-based generative modeling through stochastic differential equations"

最终损失函数

技术原理剖析

在时间步 t−1,利用 real_unetfake_unet 的输出构建梯度,引导生成器的当前输出 x 向 real_unet 在 t=0 时刻的输出收敛。随后计算生成器原始输出与梯度校正后输出的均方误差(MSE)。这一校正机制确保 x 能够逐步对齐真实数据分布。

损失函数的代码实现

该图展示了不同时间步的损失函数变化,详细说明了多步生成器对单步生成器的训练过程。注意: 图中未详细展示 weighting_factor 相关细节,并对底层分布作出了特定假设。

核心思想在于利用 xfake 和 xreal 之间的差异产生的梯度,将生成器输出引导至 real_unet 在 t=0 时刻的目标输出。随着训练进行,生成器输出逐步向真实分布靠近,同时带动 fake_unet 输出的优化。最终,校正后的图像 ∥x−grad∥ 收敛至真实分布。

总结

本文深入探讨了分布匹配蒸馏(DMD)的技术原理和实现机制,着重阐述了其在图像生成领域的应用价值。欢迎学术界同仁就相关技术细节提供建议和讨论,以促进该领域的持续发展。

https://avoid.overfit.cn/post/c8b74a7d05944be5908b583559294a24

推荐阅读
关注数
4197
内容数
906
SegmentFault 思否旗下人工智能领域产业媒体,专注技术与产业,一起探索人工智能。
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息