stable diffusion(三)——LoRA原理与实践


LoRA

论文链接:https://arxiv.org/abs/2106.09685

代码链接: https://github.com/microsoft/LoRA

LoRA是一种fine tune扩散模型的训练技术。通过对标准的checkpoint模型做微小的修改,可以比checkpoint模型小10到100倍。区别于Dreambooth和textual inversion技术,Dreambooth训练得到的模型很强大,但模型也很大(2-7GB);而textual inversion生成的模型很小(大约100KB),但模型效果一般。而LoRA则在模型大小(大约200MB)和效果上达到了平衡。

和textual inversion一样,不能直接使用LoRA模型,需要和checkpoint文件一起使用。

How does LoRA work?

LoRA通过在checkpoint上做小的修改替换风格,具体而言修改的地方是UNet中的cross-attention层。该层是图像和文本prompt交界的层。LORA的作者们发现微调该部分足以实现良好的性能。

cross attention层的权重是一个矩阵,LoRA fine tune这些权重来微调模型。那么LoRA是怎么做到模型文件如此小?LoRA的做法是将一个权重矩阵分解为两个矩阵存储,能起到的作用可以用下图表示,参数量由(1000 2000)减少到(1000 2+2000*2), 大概300多倍!

思路其实很简单,LoRA的论文中表明,在对cross attention finetune的时候能大大减少参数量,但不会对性能产生什么负面影响。C站和huggingface上有大量LoRA模型,可见LoRA模型的受欢迎程度。

cross attention

cross-attention是扩散模型中关键的技术之一,在LoRA中通过微调该模块,即可微调生成图片的样式,而在Hypernetwork中使用两个带有dropout和激活函数的全链接层,分别修改cross attention中的key和value,也可以定制想要的生成风格。可见cross attention的重要性。

在讲cross-attention之前,先看看经典的transformer中attention的含义,attnetion实际上用了三个QKV矩阵,来计算不同token之间的彼此的依赖关系,Q和K可以用来计算当前token和其他token的相似度,这个相似度作为权值对V进行加权求和,可以作为下一层的token。更通俗点说,Q和k的作用是用来在token之间搬运信息,而value本身就是从当前token当中提取出来的信息. 比较常见的是self-attention,该注意力是一个sequence内部不同token间产生注意力,而cross-attention的区别是在不同的sequence之间产生注意力。

扩散模型中cross-attention可以用下面这张图说明,出自High-Resolution Image Synthesis with Latent Diffusion Models

先说下这张图的整体的含义,x是原始的图片,$\tilde{x}$是生成的图片,$\varepsilon$是编码器,$D$是解码器Diffusion Process即前向过程,会随机添加噪声。Denoising U-Net即反向过程,目的是将随机噪声的分布,逐渐去躁生成真实的样本。具体细节可以看之前的文章扩散模型汇总——从DDPM到DALLE2

cross-attention的部分主要体现在上图右侧,和添加进U-Net中作为QKV。生成模型可以建模成$p(z|y)$,$z$是隐变量(Latent Variable), y是条件。在DDPM中有将time也建模出来,用于告知U-Net模型现在是反向传播的第几步,则建模为$\epsilon_{\theta}(z_t,t,y)$。为了能从多个不同的模态获取y,使用了领域专用编码器(domain specific encoder)$\tau_{\theta}$, 可以编码Semantic Map, Text, Representations, Images。得到了$y$和完全是噪声的$z_T$, 再经过cross attention即可融合两部分的特征,可以看作是$y$指导反向过程中$z_T$如何一步步变成去躁后的$z$。下面贴下cross-attention的代码实现。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)

self.scale = dim_head ** -0.5
self.heads = heads

self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
)

def forward(self, x, context=None, mask=None):
h = self.heads

q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)

q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)

# attention, what we cannot get enough of
attn = sim.softmax(dim=-1)

out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return self.to_out(out)

训练LoRA

区别于之前介绍的Textual Inversion、Dreambooth方法。LoRA是一种训练技巧,可以和其他的方法结合,例如在Dreambooth中,训练的时候可以选择LoRA的方式。在之前Dreambooth的篇章中,介绍了人物的训练,这篇填上风格训练的坑。数据的处理方式和之前的处理一样,训练数据改用特定风格的图片,Concepts中选用Training Wizard (Object/Style)

下面是训练过程中的图片

为了和Textual Inversion方法对比,选用了和之前一样的prompt

prompt: flower, grass, outdoors, playground, a photo of style_zhu cartoon

prompt: beamed_eighth_notes, book, clock, eighth_note, musical_note,plant,potted_plant, a photo of style_zhu cartoon

实验了一些prompt的生成效果,和textual inversionx相比Dreambooth+LoRA的方法产生的图片细节更好,也更完整,整体看起来更接近想要学习出的绘画风格。

-------------本文结束感谢您的阅读-------------