V · 7月22日

VQ-VAE:矢量量化变分自编码器,离散化特征学习模型

VQ-VAE 是变分自编码器(VAE)的一种改进。这些模型可以用来学习有效的表示。本文将深入研究 VQ-VAE 之前,不过,在这之前我们先讨论一些概率基础和 VAE 架构。

image.png

后验和先验分布

image.png

证据下界(ELBO)

在机器学习模型中,大多数后验分布都相当复杂。我们使用变分推理这一基于优化的方法来近似这些分布。ELBO 是变分推理中一个至关重要的目标函数。其推导方式如下。

image.png

重构项用于评估解码器从潜在变量重构输入的能力。KL散度项则充当正则化机制。

变分自编码器(VAE)

标准的自编码器将输入映射到潜在空间中的单个点。然而,VAE的编码器输出概率分布的参数(均值和方差)。模型从这个分布中采样一个点,然后将其输入到解码器中。

image.png

我们使用ELBO作为损失函数。

VAE存在后验崩溃的问题:模型中的正则化项开始主导损失函数,后验分布变得与先验分布相似。解码器变得过于强大,忽略了潜在表示。因此后验分布将不包含有关潜在变量的信息。

在VQ-VAE中,通过矢量量化步骤避免了后验崩溃。

矢量量化变分自编码器(VQ-VAE)

离散表示可以有效地用来提高机器学习模型的性能。人类语言本质上是离散的,使用符号表示。我们可以使用语言来解释图像。因此在机器学习中使用潜在空间的离散表示是一个自然的选择。

image.png

首先,编码器生成嵌入。然后从码本中为给定嵌入选择最佳近似。码本由离散向量组成。使用L2距离进行最近邻查找。

在反向传播过程中,通过嵌入选择步骤的梯度流动并非易事。编码器的输出嵌入和解码器的输入嵌入具有相同的维度。所以直接将解码器输入的梯度复制到编码器输出(红色箭头)。这样可以产生一个良好的梯度近似。

image.png

在训练过程中,梯度可以推动编码器嵌入(绿色圆圈)靠近不同的离散表示(紫色圆圈)。

优化编码器、解码器和嵌入(即码本)。损失函数可以用以下方式表达。

image.png

第一个术语是重构损失(类似于标准的VAE)。它衡量解码器在生成与输入分布相似的输出方面的表现。如果输入是正态分布的,这一项将是简单的均方误差。

sg 是停止梯度操作符,用来停止参数学习。由于从解码器到编码器的直接路径,重构损失项不会向嵌入提供学习信号。所以使用第二项来优化码本,将嵌入推向编码器表示。

第三项是commitment损失。它防止嵌入任意增长。

解码器仅由第一项优化。第一项和第三项优化编码器。第二项优化码本。

在训练期间,先验保持均匀。因此,ELBO的KL散度项是恒定的。

image.png

Pytorch实现

矢量量化器可以通过以下方式实现。

 classVectorQuantizer(nn.Module):
     def__init__(self, num_embeddings, embedding_dim, commitment_cost):
         super(VectorQuantizer, self).__init__()
         
         self._embedding_dim=embedding_dim
         self._num_embeddings=num_embeddings
         
         self._embedding=nn.Embedding(self._num_embeddings, self._embedding_dim)
         self._embedding.weight.data.uniform_(-1/self._num_embeddings, 1/self._num_embeddings)
         self._commitment_cost=commitment_cost
 
     defforward(self, inputs):
         # convert inputs from BCHW -> BHWC
         inputs=inputs.permute(0, 2, 3, 1).contiguous()
         input_shape=inputs.shape
         
         # Flatten input
         flat_input=inputs.view(-1, self._embedding_dim)
         
         # Calculate distances
         distances= (torch.sum(flat_input**2, dim=1, keepdim=True) 
                     +torch.sum(self._embedding.weight**2, dim=1)
                     -2*torch.matmul(flat_input, self._embedding.weight.t()))
             
         # Encoding
         encoding_indices=torch.argmin(distances, dim=1).unsqueeze(1)
         encodings=torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
         encodings.scatter_(1, encoding_indices, 1)
         
         # Quantize and unflatten
         quantized=torch.matmul(encodings, self._embedding.weight).view(input_shape)
         
         # Loss
         e_latent_loss=F.mse_loss(quantized.detach(), inputs)
         q_latent_loss=F.mse_loss(quantized, inputs.detach())
         loss=q_latent_loss+self._commitment_cost*e_latent_loss
         
         quantized=inputs+ (quantized-inputs).detach()
         avg_probs=torch.mean(encodings, dim=0)
         perplexity=torch.exp(-torch.sum(avg_probs*torch.log(avg_probs+1e-10)))
         
         # convert quantized from BHWC -> BCHW
         returnloss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings

我们将输入扁平化,并保持嵌入空间的维数为_embedding_dim。假设输入为 16,32,32,64 BHWC/ batch, height, width, channels 。被压扁成[16384,64]。

 # Flatten input
 flat_input = inputs.view(-1, self._embedding_dim)

然后计算从每个嵌入向量到每个码本向量的距离的平方。假设(N, D)是编码器的输出,(K, D)是码本。得到(N, K)大小的结果。

 distances = (torch.sum(flat_input**2, dim=1, keepdim=True) 
                     + torch.sum(self._embedding.weight**2, dim=1)
                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

接下来,我们跨dim = 1(跨码本)执行简单的argmin,获得与编码器输出距离最小的嵌入。我们生成N个大小为K的一元向量。

 distances = (torch.sum(flat_input**2, dim=1, keepdim=True) 
                     + torch.sum(self._embedding.weight**2, dim=1)
                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

将嵌入表与这个独热向量相乘以提取最接近的码本向量。这就是量化过程。

接下来定义损失项(重建损失除外)。Mse代表均方误差,.detach作为停止梯度操作。

 e_latent_loss = F.mse_loss(quantized.detach(), inputs)
 q_latent_loss = F.mse_loss(quantized, inputs.detach())
 loss = q_latent_loss + self._commitment_cost * e_latent_loss

最后确保梯度可以直接从解码器流向编码器。

 quantized = inputs + (quantized - inputs).detach()

从数学上讲,左右两边是相等的(+输入和-输入将相互抵消)。在反向传播过程中,.detach部分将被忽略

以上就是VQ VAE的完整实现,原始的完整代码可以在这里找到:

https://avoid.overfit.cn/post/85355d48ece84f77b7c1b02f60de9c8f

推荐阅读
关注数
4189
内容数
866
SegmentFault 思否旗下人工智能领域产业媒体,专注技术与产业,一起探索人工智能。
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息