专栏名称: 极市平台
极市平台是由深圳极视角推出的专业的视觉算法开发与分发平台,为视觉开发者提供多领域实景训练数据库等开发工具和规模化销售渠道。本公众号将会分享视觉相关的技术资讯,行业动态,在线分享信息,线下活动等。 网站: http://cvmart.net/
目录
相关文章推荐
南昌晚报  ·  太突然!她宣布离婚! ·  5 小时前  
南昌晚报  ·  太突然!她宣布离婚! ·  5 小时前  
煮娱星球  ·  不是吧...她怎么又扑一剧啊?! ·  昨天  
财宝宝  ·  如何相亲? ... ·  2 天前  
广西交通台  ·  太意外!44岁女演员官宣分手 ·  2 天前  
广西交通台  ·  太意外!44岁女演员官宣分手 ·  2 天前  
51好读  ›  专栏  ›  极市平台

MAR(Masked AutoRegressive): 破除封建迷信——谁说自回归图像生成一定需要 VQ的!

极市平台  · 公众号  ·  · 2024-08-09 22:00

主要观点总结

文章讨论了MAR模型和VQ技术在自然语言处理中的应用,指出传统自回归模型需要VQ进行离散化,但新研究提出不使用VQ的自回归图像生成方法,即“Masked Autoregressive (MAR) models”。文章首先介绍了自回归模型的基本原理,然后指出LLMs在处理这类模型时可能遇到的挑战。接着,文章重点介绍了VQ技术,这是一种将连续值向量映射到离散表示的方法,有助于提高模型的效率和性能。文章通过理论分析和实验证明,MAR模型能够摆脱对VQ的依赖,实现更灵活的自回归图像生成。

关键观点总结

关键观点1: 自回归模型的基本原理

自回归模型按照从左到右的顺序逐个生成数据,常用于自然语言处理和图像生成。

关键观点2: LLMs在处理自回归模型时可能遇到的挑战

由于自回归模型需要生成离散的数据,因此需要VQ技术将连续值映射到离散表示。

关键观点3: VQ技术的介绍

VQ技术将连续值向量映射到离散表示,有助于提高模型的效率和性能。

关键观点4: MAR模型的提出

文章提出不依赖VQ的自回归图像生成方法,即“Masked Autoregressive (MAR) models”。

关键观点5: MAR模型的特点

MAR模型能够摆脱对VQ的依赖,实现更灵活的自回归图像生成,同时提高了生成速度和精度。


正文

↑ 点击 蓝字 关注极市平台
作者丨CW不要無聊的風格
编辑丨极市平台

极市导读

文章讨论了MAR模型和VQ技术在自然语言处理中的应用。文章首先介绍了autoregressive模型的基本原理,然后指出了LLMs在处理这类模型时可能遇到的挑战。接着,文章重点介绍了VQ技术,这是一种将连续值向量映射到离散表示的方法,有助于提高模型的效率和性能。 >> 加入极市CV技术交流群,走在计算机视觉的最前沿

前言

提到自回归(autoregressive),相信有人会立马举手说:

这个我熟!就是 _从左到右按顺序一个个地进行预测_,现在如火如荼的 LLMs 就是这么玩的。

没毛病~ 这种认知似乎已经成为一种刻板印象烙在我们脑子里了。

进一步,如果将自回归生成用于图像,那么就需要对连续(continuous-valued)的像素进行离散化,变为离散的 token,从而才能在预测时实现对 token 的分类预测,这种离散化的技术被称作 "VQ(Vector Quantization)".

嗯,这又是一个刻板印象,或者说已经成为了一种封建迷信:

自回归图像生成需要 VQ,而且是必须!

然而,近来由恺明大神带队完成的一篇 paper( https://arxiv.org/abs/2406.11838 ) 却破除了以上谈到的封建迷信和刻板印象,即:

VQ 在自回归图像生成中 并非 是必需的,且自回归可以按 随机 顺序一次性预测 多个 ,只要是 根据之前已知的去预测未知的 即可。

这对于习惯照搬隔壁 NLP 那套来搞自回归图像生成的 CVer 们来说可能会造成些打击,但~无论如何,作为炼丹者,千万不能本本主义,接受现实、拥抱变化才是正解。要明白既然玩的本是玄学,那么就一切皆有可能~

自回归图像生成的封建迷信

论文开头第一句就揪出了封建迷信所在:

Conventional wisdom holds that autoregressive models for image generation are typically accompanied by vector-quantized tokens.

随后,作者就当机立断地破除了它:

it is not a necessity for autoregressive modeling.

开篇立意明确,一阵“爽朗”之风迎面吹来,em.. 这篇文章在高考场上应该能拿高分!呃,sorry,跑偏了,现在回到正题。

当今流行的自回归图像生成玩法都是借(照)鉴(抄)隔壁 NLP 的,NLP 的自回归生成是基于之前(已经生成)的 token 来预测下一个 token,通常是从左到右 one-by-one 地生成整个 token 序列。由于自然语言天然是离散的,因此每个 token 就顺理成章地被建模为类别分布(Categorical distribution),属于离散随机变量分布。这种简单直白的玩法在大力出奇迹的信念下取得了出奇好的效果,造就了如今 LLMs 不可一世的姿态。

看见隔壁 NLP 如此气盛,CV 小可爱们难免眼红。于是,CVer 们的心声:既然自回归这种简单无脑的玩法这么好使,何不拿(抄)过来试试?BUT! 下一秒他们便发现,直接抄是行不通的,因为图像宝宝们天然是连续(continuous-valued)的啊.. 卧勒个去!

但 CV 界从来都是人才济济,稍加思索,他们便想到了法子——基于图像数据集训练一个离散的 tokenizer 用于对图像进行离散化,从而将一批“特性”相似的连续值像素用一个共同的离散值表示(实际上该离散值背后还是对应着一个连续值的向量,离散值可看作是这个向量的“编号”),这法子在圈内叫作 " VQ ( V ector Q uantization)",经典代表有 VQ-VAE 等。

于是,对图像进行离散化后,也照样可以将像素如 NLP 的 token 一样建模为类别分布了(从而被叫作 "image token"),也同样可以自回归地基于已经生成的像素去预测(分类)下一个像素了。由此,后面就诞生了一批 "autoregressive with vq" 的代表:iGPT, DALL-E, VQ-GAN, MAGE, MaskGIT

虽然这么做是 work 了,但本论文作者不免觉得别扭,他由心地发出疑问:

Is it necessary for autoregressive models to be coupled with vector-quantized representations?

毕竟大伙有目共睹,VQ tokenizer 是真的难训,其中 quantized vector 的采样(从 codebook 中)是不可导的,于是通常采用 straight-through( https://blog.csdn.net/weixin_43135178/article/details/140160466 ) 这样的梯度估计方法将 quantized vector 的梯度(来自 Decoder)直接复制给 encoder output vectors,这种近似而不准确的梯度是导致其不容易训好的原因之一。

不妨来重新思考下 autoregressive 与 vq 的关系:自回归代表的仅仅是“基于已知的预测未知的”,与“数据值本身是离散还是连续”应该是毫不相干的,VQ 是基于照抄隔壁 NLP 的念头(从而才能将像素也变成像 language token 一样是离散的)才被理所当然地加入到自回归的玩法中了,这念头本身就政治不正确!

就像中国要发展特色社会主义一样,图像天然是连续的,没有必要盲目模仿自(资)然(本)语(主)言(义)而整容成离散的,要本其优势寻求合适的方法去发展壮大。也就是说,像素不一定要建模为类别分布(天然就不合适),在隔壁 NLP 中是因为自然语言天生是离散的所以才很自然地将 token 建模为类别分布,它们很好地利用了自己先天的“势”,找到了合适的“术”,从而在通往“道”的方向上前进了一大步,这个思想是很值得 CVer 们借鉴的,但切忌照搬他人之术。

由此可知,真正的关键在于 要合适地建模每个像素的分布,这个分布要使得我们可以从中采样,并且有相应的 loss 函数去衡量建模的好坏。

用扩散模型来建模分布

要说当下图像生成的流量明星,那自然是扩散模型啦!既然刚刚说了关键点在于建模每个像素的分布,那么何不把扩散模型拿过来使呢,并且其天然就适合建模连续型分布(扩散模型反而在建模离散型分布方面有些棘手)。

另外,上文一直在针对每个像素的分布来论述,然而实际上是可以像 LDM 一样 在 latent space 里玩,从而建模的就是每个 latent 变量的分布了。 为了便于与隔壁 NLP 统一(人们都偏好简单粗暴地对不同形式的事物进行统一),我们也将 latent 变量叫作 token,只不过这 token 是连续值的,美名曰: "continuous-valued token" .

至于如何将像素变成 latent(token),已经有诸多前辈(e.g. VAE)为我们铺好路了,实质上就是对原始图像进行压缩,使其变成更为“紧凑”的向量表征,同时提取了抽象语义。对于扩散模型来说,把它当成图像那样玩即可,什么扩散加噪、去噪生成等过程都不用改。

但是, 与通常扩散模型建模图像分布不同,在那里是等价于要建模所有像素的联合分布,而在此处则是变为建模每个 token 的分布。 引用论文原话表述就是:

in our case, the diffusion model is for representing the distribution for each token.

也因此模型的体量自然就无需那么大了,用个简单的 MLP 即可 ,而在建模图像分布的情况下则通常会用到 U-Net 甚至是 Transformer 等庞然大物并且结合 attention 机制(也就是说这里的扩散模型并没用上注意力机制)。

自回归网络辅助扩散模型做条件生成

既然扩散模型充当了建模 token 分布的角色,于是它就相当于用作预测的头部(prediction head),用于生成 token,就像图像分类网络的头部一样,预测结果是由它这里输出的。那么自回归网络那部分生成的就不是 token 而是某种辅助扩散模型去生成 token 的条件变量(也是连续值的),它与 token 是一一对应的关系。

也就是说, 自回归网络基于已知 token 去预测未知 token 所对应的条件变量, 然后进一步把它给到扩散模型去辅助生成对应的未知 token。 记已知 token 为 , 未知 token 所对应的条件变量为 , 那么自回归网络建模的过程就是 , 而扩散模型建模的则是 。结合如 DDPM 里用到的重参数化技巧, 扩散模型训练的 loss 函数就可以表示为:

其中 是标准高斯噪声, 就是在时间步 下的噪声扰动向量(此处是 token)。

这实际上训的就是条件扩散模型,以自回归网络的输出为条件变量来做条件生成,正是 CFG(Classifier-Free Guidance) 那套,于是训练方法还可以白嫖一波~

最吃香的是, 这个 loss 不仅能训练扩散模型,而且还能将自回归网络也一并训了! 因为梯度能从 传过去,这就是没有 VQ 的好处 —— z 是自回归网络输出的 continuous-valued latent,而非从 codebook 中采样而来(采样操作不可导)。

考虑到这个 loss 在这里起到这么关键的作用,作者觉得务必给它起个名字,大名曰:" Diffusion Loss" .

重新审视自回归的意义

在破除了“自回归图像生成需要和 VQ 绑定”这个封建迷信后,作者进一步重新审视了“自回归本身的意义”,即:到底什么是“真•自回归”?

如本文“前言”一节所述,大部分人对于自回归的刻板印象就是:“从左到右(raster order)”、“一个个(one-by-one)地”、“基于已知的去预测未知的”。然而 最贴近自回归本身意义的,应该仅仅是“基于已知的去预测未知的” ,而“从左到右”和“一个个地(每次只预测一个)”并非是必须的,既非充分也非必要条件。

基于这种觉悟,作者“重塑”了大家对自回归的认知。

首先是预测的顺序,不一定非得是先预测左边再预测右边(对于图像这种二维结构则延伸为从左上到右下),毕竟对于图像来说,像素之间并没有明确的顺序规定;其次是预测的数量,每次不只是预测一个,而是预测一批,引用论文中的表述就是“next set-of-tokens prediction”,这样,在相同的迭代步骤下就能更快地预测完所有 tokens,从而起到加速作用。将这两方面结合起来,就变成 先随机预测一批 tokens,然后再基于已经预测的这批 tokens 去预测未知的下一批(也是随机选择的) tokens。

另外,通常以 Transformer 架构去玩自回归时,会用到 causal attention,这是一种从左(前)到右(后)的单向注意力,于是后面的 tokens 就看不到前面的。然而作者认为 只要遵循“基于已知的去预测未知的”就符合自回归的定义了,与 token 之间是如何交互的没有关系 ,也就是说自回归不应当受到单向注意力的约束。

the goal of autoregression is to predict the next token given the previous tokens; it does not constrain how the previous tokens communicate with the next token.

于是,作者毫不犹豫地 采取了双向注意力(bidirectional attention)机制 (顺便 Q一下还有多少人记得 BERT),这样能够使得 tokens 之间的交互更加充分。最后,作者进一步结合 MAE 的做法—— 基于未 mask 的 tokens 去预测 masked tokens 中随机挑选的一批 ;新预测的这批 tokens 的 mask 被放开(成为 unmasked tokens),它们与之前的 unmasked tokens 再一起去预测剩下的 masked tokens 中随机挑选的一批。这是利用了掩码生成模型天然维持了自回归的特性——基于已知(unmasked)的去预测未知(masked)的。

Conceptually, masked generative models predict multiple output tokens simultaneously in a randomized order, while still maintaining the autoregressive nature of “predicting next tokens based on known ones.

就这样不断重复执行自回归预测,masked tokens 便逐步减少,最终所有 mask 都被放开,于是生成了所有 tokens.

作者将他这么玩的模型称作 " M asked A uto r egressive ( MAR )" models:

MAR is a random-order autoregressive model that can predict multiple tokens simultaneously.

Workflow

前面我们讲了这篇论文的主要方法和关键部分,但整个模型具体是怎么 work 的或许还没讲清楚。这一节会将模型的各部分串起来,从输入到输出,包括训练和推理流程,都扒得明明白白。

训练流程

  • from pixel to latent space

上文已经提到,MAR 是在 latent space 中玩的,原始图像 pixel 会先经过编码转换成 latent,前面谈论的 token 也是处于 latent space 中。对于 pixel 和 latent 之间的切换,作者采用了预训练的 VAE 来实现,其中的 Encoder 负责将 pixel 编码为 latent,而 Decoder 则负责将 latent 解码回 pixel。

所以训练流程的第一步就是将输入图像喂给 VAE Encoder 将其编码为 latent space 中的向量。

  • patchify

VAE Encoder 输出的 latent vectors 和输入图像一样是 (b, c, h, w) 的 4-dims 结构,为了方便接下来的 AR(自回归) 网络(通常是 Transformer)进行处理,于是将其划分成为 patches(如 ViT 一样的做法),成为 (b, l, d) 的 3-dims 结构(和隔壁 NLP 玩 token 序列时一样),其中 l = (h // p) x (w // p), d = c x p x p,p 代表 patch size,实质上这就是 reshape 操作。

划分后的每个 patch 被视作 image token(continuous-valued), 同时会 将它们克隆一份作为 ground truth latents, 作为扩散模型的输入 , 在每个时间步 对它们按照 noise schedule 进行加噪就得到被噪声扰动的latents , 如扩散模型操作像素空间(i.e. 对图像进行加噪)一般。

  • random masking

接下来就是 MAE 的随机 mask 掉部分 tokens 的操作:设置一个最小的 mask ratio(通常是 70%),然后从截断的正态分布(Truncated Normal distribution)中采样一个掩码比例,使得比例值在最小 mask ratio 与 100% 之间,然后对 tokens 按比例(在数量上)进行 mask,masked tokens 是随机挑选的。

需要注意的是, 同一个 batch 的掩码比例相同,但每个样本中哪些位置的 tokens 要被 mask 则是不同的,也就是每个样本单独随机挑选 masked tokens。

  • MAE

然后就是 MAE encode + decode 的流程了。首先 Encoder 接收 unmasked tokens 进行编码,然后 Decoder 将 masked tokens 连同 Encoder 的编码结果一起进行解码,输出提供给扩散模型的条件变量 zzz,其中在 Encoder 和 Decoder 中都要 为 tokens 加上位置编码 (在编/解码操作前)。

不过,这其中还藏着有别于 naive MAE 的操作。由于输入图像先经过 VAE 下采样变为 tokens(“尺寸”相比原图变小了),然后又 mask 掉一部分而仅把剩下 unmasked 的部分给到 Encoder,因此 Encoder 拿到手的 token 序列就可能非常短。为了充分利用上计算资源,作者就 在 token 序列的开头补上 64 个 [cls] tokens(也要加上位置编码) 而后再丢给 Encoder 去编码 。同时,为了能够直接把 CFG 的那套训练方法拿过来用, 每个样本所对应的 64 个 [cls] tokens 会以一定概率全部设置为真实的 class embeddings 或 fake latent (也就是假的条件向量,用于无条件生成)。

Decoder 解码后会先将 [cls] tokens 所对应的解码结果丢掉,然后再次加上位置编码 (与前面的位置编码向量是独立的),这才是最终给到扩散模型条件向量 zzz 。

不过 CW 认为最后这次的位置编码是否必须得加非常值得怀疑!

于是抱着又社恐又难以按捺的心情问了作者大大,没想到个人所想与作者的契合度还蛮高:

Decoder 的解码结果本身已包含了位置信息,因此就逻辑上来说,最后的这次位置编码是没有必要的。

无奈作者所用的预训练模型也是用了最后这次的位置编码来进行训练的,所以就把这个逻辑保留在代码中了。

  • Diffusion loss

最后一步就是计算 Diffusion loss 了,其实就是扩散模型的常规训练方法。

将先前在 patchify 阶段 clone 下来的 gt latents 和 MAE Decoder 解码出来的 conditioning vectors 一同喂给扩散模型, 然后随机采样一个时间步 , 根据 noise schedule 计算出该时间步对应的噪声强度从而对 加噪得到 , 接着扩散模型根据 去预测噪声 ,最后用 MSE loss 计算和真实噪声的误差即是。

需要注意的是, 真正用于计算梯度的仅仅是 masked tokens 那部分的 loss ,只需将计算出来的 loss tensor 对应乘上 mask 即可,因为 loss tensor 和 是一样的 shape.

另外,由于时间步是采样出来的, 为了让模型在每个时间步学习得更加充分 (每个时间步对应不同噪声强度,不同信噪比,模型需要懂得区分它们以便正确去噪) ,作者在每个时间步下都会将样本复制4份以达到对同一时间步采样4次的等价效果。 并且前面也提到了,扩散模型的结构非常小(small MLP network),因此这么做并不会带来太大的负担。

推理流程

在进行推理时,由于没有输入图像(目标就是要生成图像),因此 直接在 latent space 开玩 ,待 MAE + Diffusion models 自回归地生成所有 tokens 后,再由 VAE 的 Decoder 解码成图像。

那么究竟是如何自回归的呢?上文没有讲到,毕竟训练流程是体现不出来的(就像隔壁 NLP 训练 autoregressive models 一样,在训练过程中是并行解码的)。概括来说,就是 MAE encode + decode 后将条件向量给到扩散模型,后者结合该条件向量进行去噪生成(如常规扩散模型的采样生成一般),生成的 tokens 作为已知(unmasked) tokens 再回馈给 MAE 去预测未知(masked) tokens 所对应的条件向量 zzz ,然后喂给扩散模型再次进行去噪生成,生成的结果又作为已知 tokens 给到 MAE 进行下一轮的生成。

就 MAE + Diffusion models 这个整体来说,扩散模型在其中才像是真正的 Decoder ——MAE encode 出富含语义的条件向量辅助扩散模型去 decode 出未知 tokens.

OK,以上仅仅是简单粗暴的概述,接下来 CW 就为大家详细剖析清楚推理流程中各主要环节的具体操作。

  • sample order, autoregressive steps & mask schedule

首先,在采样前,会预先为每个样本随机指定不同的采样次序,从而规定了生成 tokens 的顺序。

接着,要设置自回归的步数,也就是你打算分几步来完成整个生成过程,论文中作者使用了64步。

然后,根据这个步数,定义一种 mask 策略,使得 mask 比例随步数增加而减少,从 100% 降至 0(实际上不会到 0,最后一步预测要保证至少有1个 masked token,而这一步结束后所有 tokens 就都预测完了),从而使模型能够顺利根据已知(unmasked)的去预测未知(masked)的 tokens. 作者使用的是 cosine schedule,使得 mask 比例呈余弦曲线下降的趋势。

假设指定好的 token 生成顺序为 [10, 13, 18, 15, 40, 50, 66, 70, ...],数字代表各 token 在原序列的位置,那么根据以上做法,可能产生的效果就是:在第一轮先是生成 10, 13, 18 这3个位置的 tokens,下一轮再生成 15, 40, 50, 66, 70 这几个位置 tokens. 一开始所有位置的 tokens 都被 mask 住,随着自回归迭代,mask 逐步放开,越来越多的 tokens 成为 unmasked(被生成了),但 它们之间 mask 被放开的相互顺序是在采样前就预先指定好的

  • classifier-free guidance

由于使用了 CFG 的训练方法,因此 MAR 天然就可以实现条件生成,比如生成 ImageNet 数据集里其中一个类别的图片。

如果要实现条件生成,那么在采样时就会额外给模型输入一个指定 label,然后将其编码成 class embeddings;同时将无条件的 fake latent 等量(在 batch 维度)进行复制,接着将其与 class embeddings 在 batch 维度拼接(concat)在一起,这是因为 CFG 需要同时预测含 label 情况下的条件噪声与无条件(i.e. 不含 label)的噪声,为避免让模型分别进行两次前向过程,就选择拓展样本数以达到同时预测不同类型噪声的效果(从而这种操作会使得 batch size x 2)。 最后,拼接后的结果就会作为 64 个(这 64 指的是在 sequence 维度) [cls] tokens 喂给 MAE Encoder;相对地,如果是不含 label 的无条件生成(i.e. without guidance),那么就全部使用 fake latent 作为 64 个 [cls] tokens.

另外, 由于在有 guidance 的情况下,batch size “被迫”增加了一倍,因此处理的 tokens 和 mask 都得相应增加,即直接在 batch 维度复制多一份。

  • MAE

在推理阶段 MAE encode + decode 的流程与训练时是一样的,这里就不再赘述了~

  • computing mask

根据当前自回归迭代的步数与预定义的 mask schedule 计算出 mask 比例,并设置好下一轮自回归生成所要用到的 mask( mask_next ),同时还要计算另一种特殊的 mask( mask_to_pred ) 用于指定该轮生成的 tokens 在哪些位置,以便从 MAE Decoder 的解码结果中取出这些 tokens 所对应的 conditioning vectors( zzz ).

mask_to_pred 是根据当前 mask mask_next 来进行计算的: 在当前 mask 中值为 True 而在 mask_next 中为 False 的那些位置就是本轮需要预测(生成)的 tokens 位置,这代表它们在本轮是 masked tokens 而在下一轮是 unmasked tokens ,于是对应在 mask_to_pred 中这些位置就为 True.

你或许会问:那在当前 mask 中本来为 False 但在 mask_next 中却变为 True 的那些位置咋办?

很抱歉,没有这种情况。因为前面已经说过,采样次序是在采样开始前就预先指定好的,mask 只是根据这个次序逐步放开。由于 tokens 生成的顺序已经被固定,因此当前已经是 unmasked 的位置在之后也会一直 keep 住是 unmasked 的。

  • token sampling by Diffusion models

利用计算好的 mask_to_pred 在 MAE Decoder 的解码结果中将所要生成的 tokens 的 conditioning vectors( ) 取出来, 然后喂给扩散模型做去噪生成(如常规扩散模型一般,从纯高斯噪声开始迭代去噪)。如果是含 label 的条件生成(i.e. with guidance),那么初始噪声(采样起始点)需要在 batch 维度复制多一倍,因为 CFG 需要同时估计含 label 的条件噪声与无条件噪声(前面也已经说过),此时的 也是 2 x batch size 的,包含了等量的 class embeddings 和 fake latents 的编解码结果。

另外,在扩散模型采样时,作者还参考了 Classifier-Guidance(CG) 这篇 paper 中的建议, 使用了温度参数 来 scale 每步采样时的噪声,从而达到调节生成多样性的效果。 待扩散模型生成 tokens 后,就将它们提供给 MAE 进行下一轮的自回归生成。

  • from latent to pixel

当自回归流程全部完成后,就生成了所有的 tokens,但它们是 (b, l, d) 的 3-dims 结构并且是处于 latent space 中的,所以我们需要先进行 unpatchify(i.e. reshape),将其变为 (b, c, h, w) 的 4-dims 结构,然后利用预训练 VAE 的 Decoder 将其解码回图像空间。

到此为止,整个推理流程就结束了,这就是由 latent vectors 生成图像的整个过程。最后有一点提一下: 一开始进行自回归生成时(第一轮) mask 全为 True,代表全为 masked tokens,从而 MAE Encoder 的输入仅仅是那 64 个 [cls] tokens ,这也体现了 pad 这些 [cls] tokens 的作用,否则 MAE Encoder 就只有玩空气的份~

核心源码解析

这一节会对 MAR 的“原创”代码实现进行解析,与上一节的理论剖析相对应。所谓“原创”即其核心思想逻辑但不包括从其它 codebase 搬运过来的部分,比如有关 VAE 的输入输出流程 以及 扩散模型的计算逻辑,诸如这些 CW 就不在这里展示了,有精神的友友们可以自行参考官方库。

附完整源码: https://github.com/LTH14/mar

训练流程

由于省略了 VAE 将 pixel 编码至 latent space 这部分,因此以下所涉及的 code 都是在 latent space 中玩的。尽管有些变量命名为 img ,但千万别当真,它其实是 latents.

  • 主要逻辑

MAR 自然是会被封装为一个类(继承 nn.Module )的,训练的主要逻辑(输入、输出 & 计算 loss)就放在了其 forward() 方法中。值得注意的是,这里用到了三个位置编码,并且每一个都是可学习的。

建议先不看初始化( __init__() )方法,直接看 forward() ,之后再调头回来看~

class MAR(nn.Module):
""" Masked Autoencoder with VisionTransformer backbone
"""


def __init__(self, img_size=256, vae_stride=16, patch_size=1,
encoder_embed_dim=1024, encoder_depth=16, encoder_num_heads=16,
decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
mlp_ratio=4., norm_layer=nn.LayerNorm,
vae_embed_dim=16,
mask_ratio_min=0.7,
label_drop_prob=0.1,
class_num=1000,
attn_dropout=0.1,
proj_dropout=0.1,
buffer_size=64,
diffloss_d=3,
diffloss_w=1024,
num_sampling_steps='100',
diffusion_batch_mul=4,
grad_checkpointing=False,
):
super().__init__()

# --------------------------------------------------------------------------
# VAE and patchify specifics
self.vae_embed_dim = vae_embed_dim

self.img_size = img_size
self.vae_stride = vae_stride
self.patch_size = patch_size
self.seq_h = self.seq_w = img_size // vae_stride // patch_size
self.seq_len = self.seq_h * self.seq_w
self.token_embed_dim = vae_embed_dim * patch_size**2
self.grad_checkpointing = grad_checkpointing

# --------------------------------------------------------------------------
# Class Embedding
self.num_classes = class_num
self.class_emb = nn.Embedding(1000, encoder_embed_dim)
self.label_drop_prob = label_drop_prob
# Fake class embedding for CFG's unconditional generation
self.fake_latent = nn.Parameter(torch.zeros(1, encoder_embed_dim))

# --------------------------------------------------------------------------
# MAR variant masking ratio, a left-half truncated Gaussian centered at 100% masking ratio with std 0.25
self.mask_ratio_generator = stats.truncnorm(
(mask_ratio_min - 1.0) / 0.25, 0, loc=1.0, scale=0.25)

# --------------------------------------------------------------------------
# MAR encoder specifics
self.z_proj = nn.Linear(self.token_embed_dim,
encoder_embed_dim, bias=True)
self.z_proj_ln = nn.LayerNorm(encoder_embed_dim, eps=1e-6)
self.buffer_size = buffer_size
self.encoder_pos_embed_learned = nn.Parameter(torch.zeros(
1, self.seq_len + self.buffer_size, encoder_embed_dim))

self.encoder_blocks = nn.ModuleList([
Block(encoder_embed_dim, encoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(encoder_depth)])
self.encoder_norm = norm_layer(encoder_embed_dim)

# --------------------------------------------------------------------------
# MAR decoder specifics
self.decoder_embed = nn.Linear(
encoder_embed_dim, decoder_embed_dim, bias=True)
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(
1, self.seq_len + self.buffer_size, decoder_embed_dim))

self.decoder_blocks = nn.ModuleList([
Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(decoder_depth)])

self.decoder_norm = norm_layer(decoder_embed_dim)
self.diffusion_pos_embed_learned = nn.Parameter(
torch.zeros(1, self.seq_len, decoder_embed_dim))

self.initialize_weights()

# --------------------------------------------------------------------------
# Diffusion Loss
self.diffloss = DiffLoss(
target_channels=self.token_embed_dim,
z_channels=decoder_embed_dim,
width=diffloss_w,
depth=diffloss_d,
num_sampling_steps=num_sampling_steps,
grad_checkpointing=grad_checkpointing
)
self.diffusion_batch_mul = diffusion_batch_mul


def forward(self, imgs, labels):
# class embed (B, D)
class_embedding = self.class_emb(labels)

''' patchify and mask (drop) tokens '''

# (B, C, H, W) -> (B, l = (H // P) * (W // P), C x P x P)
x = self.patchify(imgs)
# 相当于 x_0, 作为扩散模型训练的 gt, 根据 noise schedule 加噪可得 x_t
gt_latents = x.clone().detach()
# 对每个样本单独打乱 tokens 次序, 结合以下从而做到随机 mask 的效果
orders = self.sample_orders(bsz=x.size(0))
# 计算 mask 比例 r%, mask 掉以上 orders 中前 r% 位置的 tokens
# 由于 orders 是随机顺序, 因此实现了随机 mask 的效果
mask = self.random_masking(x, orders)

''' MAE encode & decode '''

# mae encoder
# 在 token 序列前 pad 上 64 个 [cls] tokens,
# 然后与 unmasked tokens 一起(加上位置编码)进入到 encoder 进行编码
x = self.forward_mae_encoder(x, mask, class_embedding)

# mae decoder
# 将 encoder 的编码结果与 masked tokens 一起(再次加上位置编码)进行解码,
# 解码后去掉 64 个 [cls] tokens 对应的解码结果(最后再加一次位置编码).
z = self.forward_mae_decoder(x, mask)

# diffloss
# 与常规扩散模型的 loss 计算类似, 这里是对 `gt_latents` 加噪得到 x_t,
# 然后将 x_t, t, z 输入扩散模型去估计噪声, 采用与真实噪声的 MSE 进行训练,
# 但是 loss 只取 masked tokens 所对应的部分
loss = self.forward_loss(z=z, target=gt_latents, mask=mask)

return loss
  • 随机 mask

随机 mask 实际上是通过随机采样 token 次序而实现的,看以下代码就懂了。

def sample_orders(self, bsz):
# generate a batch of random generation orders
orders = []
for _ in range(bsz):
order = np.array(list(range(self.seq_len)))
np.random.shuffle(order)
orders.append(order)
orders = torch.Tensor(np.array(orders)).cuda().long()

return orders

def random_masking(self, x, orders):
# generate token mask
bsz, seq_len, _ = x.shape
# 从截断的正态分布中采样出 mask 比例
mask_rate = self.mask_ratio_generator.rvs(1)[0]
num_masked_tokens = int(np.ceil(seq_len * mask_rate))
mask = torch.zeros(bsz, seq_len, device=x.device)
# 因为 orders 是随机的 tokens 次序, 所以计算出需要 mask 的 token 数量后,
# 将 orders 前面这么多数量的 tokens 掩盖掉即实现了随机 mask 的效果
mask = torch.scatter(mask, dim=-1, index=orders[:, :num_masked_tokens],
src=torch.ones(bsz, seq_len, device=x.device))

return mask
  • MAE

MAE 编解码的实现如下所示,重点我都在以下进行注释了,结合上一节的解释一起搭配食用即可。

def forward_mae_encoder(self, x, mask, class_embedding):
# 将最后一维映射到 encoder embedding dim
x = self.z_proj(x)
bsz, _, embed_dim = x.shape

# 提前预留出 64(即 `buffer_size`) 个 [cls] tokens 的位置, 初始化为 0, 拼接在原 token 序列前面
x = torch.cat([torch.zeros(bsz, self.buffer_size,
embed_dim, device=x.device), x], dim=1)
# mask 也要相应拓展, 值为 0 表示 [cls] tokens 均不会被 mask
mask_with_buffer = torch.cat(
[torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)

# random drop class embedding during training
# CFG 的那套玩法, 在训练时以一定概率 drop 掉条件项(此处以 `fake_latent` 作为无条件的表示),
# 从而实现有条件噪声与无条件噪声估计的训练
if self.training:
drop_latent_mask = torch.rand(bsz) < self.label_drop_prob
drop_latent_mask = drop_latent_mask.unsqueeze(
-1).cuda().to(x.dtype)
class_embedding = drop_latent_mask * self.fake_latent + \
(1 - drop_latent_mask) * class_embedding

# 将 [cls] tokens 放到序列的前 64 个位置
x[:, :self.buffer_size] = class_embedding.unsqueeze(1)

# encoder position embedding
x = x + self.encoder_pos_embed_learned
# 过一层 LayerNorm
x = self.z_proj_ln(x)

# dropping
# 仅拿 unmasked tokens 喂给 encoder
x = x[(1-mask_with_buffer).nonzero(as_tuple=True)
].reshape(bsz, -1, embed_dim)

''' encoder 编码 '''

# apply Transformer blocks
if self.grad_checkpointing and not torch.jit.is_scripting():
for block in self.encoder_blocks:
x = checkpoint(block, x)
else:
for block in self.encoder_blocks:
x = block(x)

# 最后过一个归一化层
x = self.encoder_norm(x)

return x

def forward_mae_decoder(self, x, mask):
# 将最后一维映射为 decoder embedding dim
x = self.decoder_embed(x)
# 对原始 mask 拓展出 64 个 [cls] tokens 的位置, 值为 0 表示它们均不被 mask
mask_with_buffer = torch.cat(
[torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)

# pad mask tokens
# 由于 masked 仅仅是1个维度为 decoder embedding dim 的向量,
# 因此要进行维度的扩展(在 batch 和 sequence 维度进行复制)
mask_tokens = self.mask_token.repeat(
mask_with_buffer.shape[0], mask_with_buffer.shape[1], 1).to(x.dtype)
# 先全部初始化为 masked tokens, 而后把 encoder 的编码结果放到 unmasked 部分
x_after_pad = mask_tokens.clone()
x_after_pad[(1 - mask_with_buffer).nonzero(as_tuple=True)] = \
x.reshape(x.shape[0] * x.shape[1], x.shape[2])

# decoder position embedding
x = x_after_pad + self.decoder_pos_embed_learned

''' decoder 解码 '''

# apply Transformer blocks
if self.grad_checkpointing and not torch.jit.is_scripting():
for block in self.decoder_blocks:
x = checkpoint(block, x)
else:
for block in self.decoder_blocks:
x = block(x)

# 经过一个归一化层
x = self.decoder_norm(x)

# 去掉 [cls] tokens 所对应的解码结果
x = x[:, self.buffer_size:]
# 最后再加上另一个位置编码(与前面的位置编码不同)
x = x + self.diffusion_pos_embed_learned

return x
  • Diffusion loss

以下相当于是计算 loss 前的“准备工作”,真正的计算逻辑并不在此处展现。

def forward_loss(self, z, target, mask):
bsz, seq_len, _ = target.shape

# 之所以要在个数上复制 `diffusion_batch_mul` 这么多倍,
# 是为了实现在每个时间步下采样多次从而达到充分训练的效果, 如论文中所述
target = target.reshape(
bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
z = z.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
mask = mask.reshape(bsz * seq_len).repeat(self.diffusion_batch_mul)

loss = self.diffloss(z=z, target=target, mask=mask)
return loss

Diffusion loss 被封装成一个类,其所用的扩散模型相关的计算逻辑“抄”自大名鼎鼎之 OpenAI 的 ADM(https://github.com/openai/guided-diffusion/tree/main).

class DiffLoss(nn.Module):
    """Diffusion Loss"""

    def __init__(self, target_channels, z_channels, depth, width, num_sampling_steps, grad_checkpointing=False):
        super(DiffLoss, self).__init__()
        self.in_channels = target_channels
        self.net = SimpleMLPAdaLN(
            in_channels=target_channels,
            model_channels=width,
            out_channels=target_channels * 2,  # for vlb loss
            z_channels=z_channels,
            num_res_blocks=depth,
            grad_checkpointing=grad_checkpointing
        )

        self.train_diffusion = create_diffusion(
            timestep_respacing="", noise_schedule="cosine")
        self.gen_diffusion = create_diffusion(
            timestep_respacing=num_sampling_steps, noise_schedule="cosine")

    def forward(self, target, z, mask=None):
        t = torch.randint(0, self.train_diffusion.num_timesteps,
                          (target.shape[0],), device=target.device)
        model_kwargs = dict(c=z)

        loss_dict = self.train_diffusion.training_losses(
            self.net, target, t, model_kwargs)
        loss = loss_dict["loss"]
        # 仅取 masked tokens 所对应的 loss
        if mask is not None:
            loss = (loss * mask).sum() / mask.sum()

        return loss.mean()

以上的 self.net 代表一个小型的扩散模型,用于估计噪声,使用带 AdaLN 的 MLP 结构来实现,其对于时间步的编码采用了正余弦编码的方式,而对于 conditioning vectors(即 MAE Decoder 的解码结果)则直接使用一个全连接层映射到特定维度,整个输入输出的流程非常简单,如下:

class SimpleMLPAdaLN(nn.Module):
    """
    The MLP for Diffusion Loss.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param z_channels: channels in the condition.
    :param num_res_blocks: number of residual blocks per downsample.
    """


    ...  # 省略, 懒得贴

  
def forward(self, x, t, c):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param t: a 1-D batch of timesteps.
        :param c: conditioning from AR transformer.
        :return: an [N x C x ...] Tensor of outputs.
        """

        x = self.input_proj(x)
        t = self.time_embed(t)
        c = self.cond_embed(c)

        y = t + c

        if self.grad_checkpointing and not torch.jit.is_scripting():
            for block in self.res_blocks:
                x = checkpoint(block, x, y)
        else:
            for block in self.res_blocks:
                x = block(x, y)

        return self.final_layer(x, y)

采样过程

  • 预备工作

在正式进入自回归生成前需要做一些预备工作,以下初始化 mask 全为 True,代表一开始全部都是 masked tokens;同时,还将 tokens 初始化为 0,但实际上“0并不发挥作用”,更多地像是起到了占位符的效果,原因 CW 写在以下注释中了;而上一节所说的在采样开始前确定采样次序即对应以下 sample_orders()

def sample_tokens(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None, temperature=1.0, progress=False):

        ''' init and sample generation orders '''
        
        # 一开始 mask 掉所有 tokens
        mask = torch.ones(bsz, self.seq_len).cuda()
        # 虽然初始 token 设为 0, 但由于一开始全被 mask 掉, 因此实际上是 64 个 [cls] tokens 和 `self.mask_token` 
        # 分别在 encoder 和 decoder 起作用
        tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim).cuda()
        # 采样前先确定采样次序
        orders = self.sample_orders(bsz)

        indices = list(range(num_iter))
        if progress:
            indices = tqdm(indices)

        ... # 省略,下文会接上
  • CFG 的相关设置

接下来就正式进入自回归迭代生成的流程了。

首先需要对当前的情况做判断,看是否是含 label 的条件生成,如果是,则需要将样本多复制一倍以便让网络同时估计含 label 的条件噪声和无条件噪声;否则,就将 class embedding 替换为 fake latent 按常规估计无条件噪声即可。

MAE 编解码的过程比较无聊,在前面的训练部分也已经展示过其中的代码逻辑了,于是就顺便在此处带过了。

        # 接以上内容

        ''' generate latents '''
        
        # 自回归迭代
        for step in indices:
            cur_tokens = tokens.clone()

            ''' class embedding and CFG '''
            
            # 含 label 的条件生成
            if labels is not None






请到「今天看啥」查看全文