《VQ-VAE》论文笔记


论文链接: https://proceedings.neurips.cc/paper/2017/file/7a98af17e63a0ac09ce2e96d03992fbc-Paper.pdf

论文代码: https://github.com/ritheshkumar95/pytorch-vqvae

VQ-VAE(Vector Quantised- Variational AutoEncoder)和VAE(TODO)一样也是生成模型。虽然文中作者给自己的模型取名为VQ-VAE,但实际上和VAE的关系不太太,其模型其实是基于自编码器AE

Pixel CNN

既然思想是基于自编码器AE,那就得追溯到自编码生成模型的代表作Pixel CNN了,Pixcel CNN是Google Deep Mind在2016年提出的。其思想是通过前面的一些数值,得到当前数值的分布,其预测方式为:
$p(x)=\prod_{i=1}^{n^{2}}p(x_i|x_1,…,x_{i-1})$

每个像素是一个256分类的问题,且每个像素的分类,依赖于之前像素的信息。以cifa10数据集为例,图像大小为32323,可以将图片看成序列,则长度为32323=3072。生成cifa图片,需要对3072的序列按seq2seq的方式推理。
基于Pixel CNN的思想,由于模型用到的是CNN,为了实现在推理当前像素时能有效的遮住还未看到的信息。其提出的方式是Gate Convolutions Layers层,其思想如下图所示:

简言之就是对卷积做mask处理,生成一个n*n的卷积,按照从上到下,从左到右的顺序,将卷积中心及之后的特征值置为0,而其他位置置为1,保证卷积操作只能看到该像素之前的像素

Pixel CNN等自回归模型的缺点:

  • 模型耗时较长,对于分辨率稍微高点的图像,自回归模型需要逐像素推理才能还原出图像
  • 图像的像素时很冗余的,这一思想在最近的很多论文中都有论证,如MAE。虽然图像中每个像素时离散的,但事实上连续的像素时相近的,有时RGB值差个别数值,并不影响图像的生成,而转变成像素分类问题,只有非对即错的结果

Method

VQ-VAE的论文确实写的比较难懂,而苏神的blog则要清晰非常多,关于算法原理这块推荐大家可以读读苏神的blog

针对自回归模型存在的缺点,VQ-VAE的思想是先对图像降维,再对编码向量用Pixel CNN的方式建模。这种i想具有如下的坑:

  • 因为Pixel CNN建模使用的离散的序列,就意味着VQ-VAE降维度的时候需要转换为离散的序列。其实自编码器就是很常用的降维方法之一,然而其生成的编码向量都是连续的
  • 降维后的特征和原始特征存在差异,求梯度的时候不能用原始特征和gt比对,因为优化目标已经变成了降维后的特征。也不能直接用降维后的特征,因为VQ-VAE中的降维实际上是映射到编码表,这个过程是不能求梯度的

离散化

在VQ-VAE中,一张图片x会先经过encoder,得到连续的变量z

这里的z是一个大小为d的向量,VQ-VAE还维护一个Embedding层,我们也可以称为编码表,记为

其中每一个$e_i$都是大小为d的向量,接着,VQ-VAE通过最邻近搜索,将z映射为这K个向量之一:

将z映射到编码表后的特征记为$z_q$, 则$z_q$才是编码后的结果,会将$z_q$传给decoder做生成,这样以来就将连续的特征z转变为了降维后的离散特征$z_q$
上面的流程实际上是简化的,如果只编码一个向量,重构时容易出现失真,而且泛化性一般,因此实际编码时直接用多层卷积将x编码为m×m个大小为d的向量,也就是说,z的总大小为m×m×d,它依然保留着位置结构,然后每个向量都用前述方法映射为编码表中的一个,就得到一个同样大小的$z_q$,然后再用它来重构。这样一来,$z_q$也等价于一个m×m的整数矩阵,这就实现了离散型编码。

前向和反向传播

如果是普通的自编码器,直接用下述loss训练即可:

但是$z_q$并不是原来的z, 可就算换成$z_q$也不能计算梯度,换言之,我们的目标其实是$‖x−decoder(z_q)‖_2^2$最小,但是却不好优化,而$||x-decoder(z)||^2_2$容易优化,但却不是我们的优化目标。那怎么办呢?当然,一个很粗暴的方法是两个都用:

但这样并不好,因为decoder(z)并不是优化目标,会带来额外的约束

VQ-VAE中用了一个巧妙且直接的方法,称为Straight-Through Estimator,你也可以称之为“直通估计”。最早源于论文Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation

Straight-Through Estimator的思想很简单,就是前向传播的时候可以用想要的变量(哪怕不可导),而反向传播的时候用你设计的梯度去替代。根据这个思想,我们设计的目标函数是:

其中sg是stop gradient的意思,就是不要它的梯度。这样一来,前向传播计算(求loss)的时候,就直接等价于$decoder(z+z_q−z)=decoder(z_q)$.然后反向传播(求梯度)的时候,由于$z_q−z$不提供梯度,所以它也等价于decoder(z),这个就允许我们对encoder进行优化了。

维护编码表

上面我们提到离散化是通过编码表映射完成,期望是映射后的$z_q$和$z$相近,不然仍然会导致生成的图像失真严重,因为离散化其实是在做量化,而量化的目的是减少计算量的同时,尽量不损失精度。由于编码表E是相对自由的,而z要尽力保证重构效果,所以我们应当尽量“让$z_q$去靠近$z$”而不是“让z去靠近$z_q$”。而因为$‖z_q−z‖^2_2$的梯度等于对zq的梯度加上对z的梯度,所以我们将它等价地分解为:

第一项相等于固定$z$,让$z_q$靠近$z$,第二项则反过来固定$z_q$,让$z$靠近$z_q$。注意这个“等价”是对于反向传播(求梯度)来说的,对于前向传播(求loss)它是原来的两倍。根据我们刚才的讨论,我们希望“让$z_q$去靠近$z$”多于“让$z$去靠近$z_q$”,所以可以调一下最终的loss比例:

其中$\gamma<\beta$,在原论文中使用的是$\gamma=0.25\beta$

生成

经过上面的离散化之后,将图片编码为$m*m$的整数矩阵。该矩阵也一定程度保留了原始图片的信息,可以使用自回归模型如Pixel CNN,对编码矩阵拟合。

通过Pixel CNN得到编码分布后,可以随机生成一个新的编码矩阵,然后通过编码表E映射为浮点数$z_q$,最后经过decoder得到一张图片.

一般来说,得到的 比原来的要小的多,因此计算也更快速

欢迎关注我的公众号!

enter description here

-------------本文结束感谢您的阅读-------------