专栏名称: arXiv每日学术速递
跟踪计算机视觉、人工智能、机器学习、NLP、语音识别、量化金融等热门方向学术信息
目录
相关文章推荐
物道  ·  春日氛围感家居服,正在派送中~ ·  15 小时前  
51好读  ›  专栏  ›  arXiv每日学术速递

线性扩散模型LiT来了,用极简线性注意力助力扩散模型AIPC时代端侧部署

arXiv每日学术速递  · 公众号  ·  · 2025-02-03 14:52

正文


太长不看版:香港大学联合上海人工智能实验室,华为诺亚方舟实验室提出 高效扩散模型 LiT :探索了扩散模型中极简线性注意力的 架构设计 训练策略 。LiT-0.6B 可以在断网状态,离线部署在 Windows 笔记本电脑上,遵循用户指令快速生成 1K 分辨率逼真图片。

图 1:LiT 在 Windows 笔记本电脑的离线端侧部署:LiT 可以在端侧,断网状态,以完全离线的方式遵循用户指令,快速生成 1K 分辨率图片

  • 论文名称:LiT: Delving into a Simplified Linear Diffusion Transformer for Image Generation
  • 论文地址:https://arxiv.org/pdf/2501.12976v1
  • 项目主页:https://techmonsterwang.github.io/LiT /

为了提高扩散模型的计算效率,一些工作使用 Sub-quadratic 计算复杂度的模块来替代二次计算复杂度的自注意力(Self-attention)机制。这其中,线性注意力的主要特点是: 1) 简洁;2) 并行化程度高 。这对于大型语言模型、扩散模型这样的大尺寸、大计算的模型而言很重要。

就在几天前,MiniMax 团队著名的《MiniMax-01: Scaling Foundation Models with Lightning Attention》已经在大型语言模型中验证了线性模型的有效性。而在扩散模型中,关于「线性注意力要怎么样设计,如何训练好基于纯线性注意力的扩散模型」的讨论仍然不多。

本文针对这个问题,该团队提出了几条「拿来即用」的解决方案,向社区读者报告了可以如何设计和训练你的线性扩散 Transformer(linear diffusion Transformers)。列举如下:

  1. 使用极简线性注意力机制足够扩散模型完成图像生成。除此之外,线性注意力还有一个「免费午餐」,即: 使用更少的头(head) ,可以在增加理论 GMACs 的同时 (给模型更多计算),不增加实际的 GPU 延迟。
  2. 线性扩散 Transformer 强烈建议从一个预训练好的 Diffusion Transformer 里做权重继承。但是,继承权重的时候, 不要继承自注意力中的任何权重 (Query, Key, Value, Output 的投影权重)。
  3. 可以使用知识蒸馏(Knowledge Distillation)加速训练。但是,在设计 KD 策略时,我们强烈建议不但蒸馏噪声预测结果,同样 也蒸馏方差预测结果 (这一项权重更小)

LiT 将上述方案汇总成了 5 条指导原则,方便社区读者拿来即用。

在标准 ImageNet 基准上,LiT 只使用 DiT 20% 和 23% 的训练迭代数,即可实现相当 FID 结果。LiT 同样比肩基于 Mamba 和门控线性注意力的扩散模型。

在文生图任务中,LiT-0.6B 可以在断网状态,离线部署在 Windows 笔记本电脑上,遵循用户指令快速生成 1K 分辨率逼真图片,助力 AIPC 时代降临。

目录

1 LiT 研究背景
2 线性注意力计算范式
3 线性扩散 Transformer 的架构设计
4 线性扩散 Transformer 的训练方法
5 图像生成实验验证
6 文生图实验验证
7 离线端侧部署

1 LiT 研究背景

Diffusion Transformer 正在助力文生图应用的商业化,展示出了极强的商业价值和潜力。但是,自注意力的二次计算复杂度也成为了 Diffusion Transformer 的一个老大难问题。因为这对于高分辨率的场景,或者端侧设备的部署都不算友好。

常见的 Sub-quadratic 计算复杂度的模块有 Mamba 的状态空间模型(SSM)、门控线性注意力(GLA)、线性注意力等等。目前也有相关的工作将其用在基于类别的(class-conditional)图像生成领域 (非文生图),比如使用了 Mamba 的 DiM、使用了 GLA 的 DiG 。但是,虽然这些工作确实实现了 Sub-quadratic 的计算复杂度,但是,这些做法也存在明显的不足:

  • 其一,SSM 和 GLA 模块都依赖递归的状态 (State) 变量,需要序列化迭代计算,对于并行化并不友好。
  • 其二,SSM 和 GLA 模块的计算图相对于 线性注意力 而言更加复杂,而且会引入一些算数强度 (arithmetic-intensity) 比较低的操作,比如逐元素乘法。

而线性注意力相比前两者,如下图 2 所示,不但设计简单,而且很容易实现并行化。这样的特点使得线性注意力对于高分辨率极其友好。比如对于 2048px 分辨率图片,线性注意力比自注意力快约 9 倍,对于 DiT-S/2 生成所需要的 GPU 内存也可以从约 14GB 降低到 4GB。因此,训练出一个性能优异的基于线性注意力的扩散模型很有价值。

图 2:与 SSM 和 GLA 相比,线性注意力同样实现 sub-quadratic 的计算复杂度,同时设计极其简洁,且不依赖递归的状态变量

但是, 对于有挑战性的图像生成任务,怎么快速,有效地训练好基于线性注意力的扩散模型呢?

这个问题很重要,因为一方面,尽管线性注意力在视觉识别领域已经被探索很多,可以取代自注意力,但是在图像生成中仍然是一个探索不足的问题。另一方面,从头开始训练扩散模型成本高昂。比如训练 RAPHAEL 需要 60K A100 GPU days ( 中报告)。因此,针对线性扩散 Transformer 的高性价比训练策略仍然值得探索。

LiT 从架构设计和训练策略中系统地研究了纯线性注意力的扩散 Transformer 实现。LiT 是一种使用纯线性注意力的 Diffusion Transformer。LiT 训练时的成本效率很高,同时在推理过程中保持高分辨率友好属性,并且可以在 Windows 11 笔记本电脑上离线部署。在基于类别的 ImageNet 256×256 基准上面,100K 训练步数的 LiT-S/B/L 在 FID 方面优于 400K 训练步数的 DiT-S/B/L。对于 ImageNet 256×256 和 512×512,LiT-XL/2 在训练步骤只有 20% 和 23% 的条件下,实现了与 DiT-XL/2 相当的 FID。在文生图任务中,LiT-0.6B 可以在断网状态,离线部署在 Windows 笔记本电脑上,遵循用户指令快速生成 1K 分辨率逼真图片。

2 线性注意力计算范式


3 线性扩散 Transformer 的架构设计

鉴于对生成任务上的线性扩散 Transformer 的探索不多,LiT 先以 DiT 为基础,构建了一个使用线性注意力的基线模型。基线模型与 DiT 共享相同的宏观架构,唯一的区别是将自注意力替换为 线性注意力。所有实验均在基于类别的 ImageNet 256×256 基准上进行,使用 256 的 Batch Size 训练了 400K 迭代次数。

Guideline 1:Simplified 线性注意力对于基于 DiT 的图像生成扩散模型完全足够。

我们首先尝试了在通用视觉基础模型中成功验证的常见线性注意力的架构设计,比如 ReLU 线性注意力 (使用 ReLU 激活函数作为线性注意力的 Kernel Function)。

对于性能参考,将其与 DiT 进行比较,其中任何性能差异都可以归因于线性注意力对生成质量的影响。如图 4 中所示。与 DiT 相比,使用 ReLU 线性注意力的 LiT-S/2 和 B/2 性能下降很大。结果表明,视觉识别中常用的线性注意力在噪声预测任务中有改进的空间。

然后我们探索以下方法:

  • 简化型线性注意力 (图 3,相当于在 ReLU 线性注意力的基础上加上 Depth-wise 卷积)。
  • Focused 线性注意力。
  • Focused 线性注意力 (使用 GELU 替换 ReLU)。

这些选择中的每一个都保持了线性复杂度,保持了 LiT 在计算效率方面的优势。我们使用相对较大的卷积核 (Kernel Size 5) 来确保在预测噪声时足够大的感受野。

图 3:在 Simplified 线性注意力中使用更少的 heads

图 4:不同架构的线性注意力消融研究

实验结果如图 4 所示。加了 DWC 的模块都可以取得大幅的性能提升,我们认为这是因为模型在预测给定像素的噪声时关注相邻像素信息。同时,我们发现 Focused Function 的有效性有限,我们将其归因于其设计动机,以帮助线性注意聚焦于特定区域。此功能可能适合分类模型,但可能不是噪声预测所必需的。为了简单起见,最后使用简化 线性注意力。

Guideline 2:在线性注意力中建议使用很少的头,可以在增加计算的同时不增加时延。

多头自注意力和线性注意力的计算量分别为:


直觉上似乎使用更多头可以减少计算压力。但相反,我们建议使用更少的头,因为我们观察到线性注意力存在 Free Lunch 效应,如图 5 所示。图 5 展示了使用线性注意力的 Small,Base,Large,XLarge 模型使用不同头数量的延迟和 GMACs 变化。

图 5:线性注意力中的 Free Lunch 效应:不同头数量线性注意的延迟与理论 GMACs 比较

我们使用 NVIDIA A100 GPU 生成 256×256 分辨率的图像,批量大小为 8 (NVIDIA V100 GPU 出现类似现象)。结果表明,减小头数量会导致理论 GMACs 稳定增加,实际延迟却并没有呈现出增加的趋势,甚至出现下降。我们将这种现象总结为线性注意力的「免费午餐(Free Lunch)」效应。

我们认为在线性注意力中使用更少的头之后,允许模型有较高的理论计算,根据 scaling law,允许模型在生成性能上达到更高的上限。

实验结果如图 6 所示,对于不同的模型尺度,线性注意力中使用更少的头数 (比如,2,3,4) 优于 DiT 中的默认设置。相反,使用过多的头(例如,S/2 的 96 或 B/2 的 192),则会严重阻碍生成质量。

4 线性扩散 Transformer 的训练方法

LiT 与 DiT 共享一些相同的结构,允许权重继承自预训练的 DiT 架构。这些权重包含丰富的与噪声预测相关的知识,有望以成本高效的方式转移到 LiT。因此,在这个部分我们探索把预训练的 DiT 权重 (FFN 模块、adaLN、位置编码和 Conditional Embedding 相关的参数) 继承给线性 DiT,除了线性注意力部分。

图 6:线性扩散 Transformer 的权重继承策略

Guideline 3:线性扩散 Transformer 的参数应该从一个预训练到收敛的 DiT 初始化。

我们首先预训练 DiT-S/2 不同的训练迭代次数:200K、300K、400K、600K 和 800K,并且在每个实验中,分别将这些预训练的权重加载到 LiT-S/2 中,同时线性注意力部分的参数保持随机。然后将初始化的 LiT-S/2 在 ImageNet 上训练 400K 迭代次数,结果如图 6 所示。

我们观察到一些有趣的发现:

  1. DiT 的预训练权重,即使只训练了 200K 步,也起着重要作用,将 FID 从 63.24 提高到 57.84。
  2. 使用预训练权重的指数移动平均 (EMA) 影响很小。
  3. DiT 训练更收敛时 (800K 步),更适合作为 LiT 的初始化,即使架构没有完全对齐。

我们认为这种现象的一种可能解释是 Diffusion Transformer 中不同模块的功能是解耦的。尽管 DiT 和 LiT 具有不同的架构,但它们的共享组件 (例如 FFN 和 adaLN) 的行为非常相似。因此,可以迁移这些组件预训练参数中的知识。同时,即使把 DiT 训练到收敛并迁移共享组件的权重,也不会阻碍线性注意力部分的优化。

图 7:ImageNet 256×256 上的权重继承消融实验结果

Guideline 4:线性注意力中的 Query、Key、Value 和 Output 投影矩阵参数应该随机初始化,不要继承自自注意力。

在 LiT 中,线性注意力中的一些权重与 DiT 的自注意力中的权重重叠,包括 Query、Key、Value 和 Output 投影矩阵。尽管计算范式存在差异,但这些权重可以直接从 DiT 加载到 LiT 中,而不需要从头训练。但是,这是否可以加速其收敛性仍然是一个悬而未决的问题。

我们使用经过 600K 次迭代预训练的 DiT-S/2 进行消融实验。探索了 5 种不同类型的加载策略,包括:

  • 加载 Query,Key 和 Value 投影矩阵。
  • 加载 Key 和 Value 投影矩阵。
  • 加载 Value 投影矩阵。
  • 加载 Query 投影矩阵。
  • 加载 Output 投影矩阵。

结果如图 7 所示。与没有加载自注意力权重的基线相比,没有一个探索的策略显示出更好的生成性能。这种现象可归因于计算范式的差异。具体来说,线性注意力直接计算键和值矩阵的乘积,但是自注意力就不是这样的。因此,自注意力中的 Key 和 Value 相关的权重对线性注意力的好处有限。

我们建议继承除线性注意力之外的所有预训练参数从预训练好的 DiT 中,因为它易于实现并且非常适合基于 Transformer 架构的扩散模型。

图 8:混合知识蒸馏训练线性扩散 Transformer

Guideline 5:使用混合知识蒸馏训练线性扩散 Transformer 很关键,不仅蒸馏噪声预测结果,还蒸馏方差的预测结果。

知识蒸馏通常采用教师网络来帮助训练轻量级学生网络。对于扩散模型,蒸馏通常侧重于减少目标模型的采样步骤。相比之下,我们专注于在保持采样步骤的前提下,从复杂的模型蒸馏出更简单的模型。



图 9:ImageNet 256×256 上的知识蒸馏实验结果,带有下划线的结果表示不使用知识蒸馏

到目前为止,LiT 遵循 DiT 的宏观 / 微观设计,但采用了高效的线性注意力。使用我们的训练策略,LiT-S/2 显著地提高了 FID。接下来,我们在更大的变体 (例如 B/L/XL) 和具有挑战性的任务 (比如 T2I) 上验证它。

5 图像生成实验验证

ImageNet 256×256 基准






请到「今天看啥」查看全文