专栏名称: 极市平台
极市平台是由深圳极视角推出的专业的视觉算法开发与分发平台,为视觉开发者提供多领域实景训练数据库等开发工具和规模化销售渠道。本公众号将会分享视觉相关的技术资讯,行业动态,在线分享信息,线下活动等。 网站: http://cvmart.net/
目录
相关文章推荐
数据派THU  ·  WWW2025 | ... ·  昨天  
数据派THU  ·  【AAAI2025】TimeDP:通过领域提 ... ·  8 小时前  
黑马程序员  ·  喜报!应届生均薪破万,最高薪资24000元! ·  昨天  
黑马程序员  ·  喜报!应届生均薪破万,最高薪资24000元! ·  昨天  
51好读  ›  专栏  ›  极市平台

超分辨图像无限生成!清华甩出Inf-DiT:Diffusion Transformer 任意分辨率上采样

极市平台  · 公众号  ·  · 2024-06-03 18:37

正文

↑ 点击 蓝字 关注极市平台
作者丨科技猛兽
编辑丨极市平台

极市导读

综合实验表明,Inf-DiT 在生成超高分辨率图像方面取得了 SOTA 性能。与常用的 UNet 结构相比,Inf-DiT 在生成 4096×4096 图像时可以节省超过5倍显存。 >> 加入极市CV技术交流群,走在计算机视觉的最前沿

本文目录

1 Inf-DiT:Diffusion Transformer 任意分辨率上采样
(来自清华大学,唐杰团队)
1 Inf-DiT 论文解读
1.1 超高分辨率图像生成问题的挑战:GPU 显存需求
1.2 单向块注意力机制
1.3 O(N) 显存消耗的推理过程
1.4 Inf-DiT 架构
1.5 全局和局部一致性
1.6 实验结果

太长不看版

扩散模型在图像生成方面表现出了很显著的性能。然而对于生成超高分辨率的图像 (比如 4096 ×4096) 而言,由于其 Memory 也会二次方增加,因此生成的图像的分辨率通常限制在 1024×1024。在这项工作中。作者提出了一种单向块注意力机制,可以在推理过程中自适应地调整显存开销并处理全局依赖关系。在这个模块的基础上,作者使用 DiT 的架构,并逐渐执行上采样,最终开发了一个无限的超分辨率模型 Inf-DiT,能够对各种形状和分辨率的图像进行上采样。综合实验表明,Inf-DiT 在生成超高分辨率图像方面取得了 SOTA 性能。与常用的 UNet 结构相比,Inf-DiT 在生成 4096×4096 图像时可以节省超过5倍显存。

图1:基于 SDXL、DALL-E 3 和真实图像,选择出的 Inf-DiT 超高分辨率上采样示例

本文做了哪些具体的工作

  1. 提出了单向块注意力机制 (Unidirectional Block Attention,UniBA) 算法,在推理过程中将最小显存消耗从 降低到 , 其中 表示边长。该机制还能够通过调整并行生成的块数量、在显存和时间开销之间进行权衡来适应各种显存限制。
  2. 基于这些方法,训练了一个图像上采样扩散模型 Inf-DiT,这是一个 700M 的模型,能够对不同分辨率的和形状图像进行上采样。Inf-DiT 在机器 (HPDV2 和 DIV2K 数据集) 和人工评估中都实现了最先进的性能。
  3. 设计了多种技术来进一步增强局部和全局一致性,并为灵活的文本控制提供 Zero-Shot 的能力。

1 Inf-DiT:Diffusion Transformer 任意分辨率上采样

论文名称:Inf-DiT: Upsampling Any-Resolution Image with Memory-Efficient Diffusion Transformer (Arxiv 2024.03)

论文地址:

https://arxiv.org/pdf/2405.04312

项目地址:

https://github.com/THUDM/Inf-DiT

1.1 超高分辨率图像生成问题的挑战:GPU 显存需求

近年来,扩散模型取得了快速发展,显着推动了图像生成和编辑领域的发展。尽管取得了进步,但仍然存在一个关键的限制:现有图像扩散模型生成的图像的分辨率通常被限制在 1024×1024 像素或更低,这对生成超高分辨率图像提出了重大挑战,这在包括复杂的设计项目、广告和海报和墙壁纸的创建等各种实际应用中是必不可少的。

生成高分辨率的常用方法是 Cascaded Generation,它首先生成低分辨率图像,然后应用多个上采样模型逐步提高图像的分辨率。这种方法将高分辨率图像的生成分解为多个子任务。基于前一阶段产生的结果,后期的模型只需要执行局部的生成。在级联结构的基础上,DALL-E2[1]和 Imagen[2]都可以有效地生成 1024×1024 分辨率的图像。

上采样到更高分辨率的图像的最大挑战是关于 GPU 显存需求。例如,如果使用广泛采用的 U-Net 架构,例如 SDXL[3]进行图像推理 (见下图2),可以观察到显存消耗随着分辨率的增加而急剧增加。具体来说,如果生成 4096×4096 分辨率的图像,其包含超过 16 亿个像素需要超过 80GB 的显存,超过了标准 RTX 4090 或 A100 显卡的容量。此外,用于高分辨率图像生成的训练模型的过程加剧了这些需求,因为它需要额外的显存来存储梯度、优化器状态等。LDM[4]通过利用变分自动编码器 (Variational Autoencoder,VAE) 压缩图像并在更小的 Latent Space 中生成图像来减少显存消耗。然而,过高的压缩比会大大降低生成的质量,对显存消耗的减少造成了严重的限制。

图2:本文模型和 SDXL 架构之间不同分辨率的推理期间显存使用的比较

1.2 单向块注意力机制

作者观察到生成超高分辨率图像的关键障碍是显存限制。随着图像的分辨率的增加,网络中相应的 hidden states 的大小呈二次方的复杂度扩展。例如,1层中形状为 2048×2048×1280 的单个 hidden state 需要 20GB 的显存,这使得很难生成非常大的图像。如何避免将整个图像的 hidden state 存储在内存中成为关键的问题。

作者的主要想法是将图像 划分为 Blocks , 其中 是块大小, 。当图像被送入网络时, Block 的大小和分辨率可能会改变, 但 Block之间的布局和相对位置关系保持不变。如果有一种方法可以应用顺序批量生成 Blocks,其中每个 Batch 同时生成 Blocks 的子集,则只需要同时在内存中保留少量 Blocks 的隐藏状态,就可以生成超高分辨率图像。

本文的方法单向块注意力 (Unidirectional Block Attention, UniBA) 如下图3所示。对于每个层,每个 Block 直接依赖于3个一阶相邻的 Block:顶部的 Block、左侧和左上角的 Block。例如,如果采用 Diffusion Transformer (DiT) 架构,Block 之间的依赖关系是注意力操作,每个 Block 的 Query 向量与4个 Block 的 Key,Value 向量交互:位于其左上角和自身的3个 Blocks,如图3所示。

图3:左侧:单向块注意力。在我们的实现中,每个 Block 直接取决于每一层的3个 Blocks:左上角的块、左侧和顶部的 Block;右侧:Inf-DiT 的推理过程。Inf-DiT 根据内存大小每次生成 n×n 个 Block。在这个过程中,只有后续块所依赖的块的 KV-cache 存储在内存中

Transformer 中的 UniBA 过程可以表述为:

其中, 是第 层, 第 行, 第 列的 hidden state, 是 Block-level 的相对位置编码。

1.3 O(N) 显存消耗的推理过程

尽管本文的方法可以按顺序生成每个 Block,但它与自回归的生成模型不同。在自回归的生成模型中,下一个 Block 取决于前面 Blocks 的最终输出。本文方法可以并行生成任意数量的块。基于这一特性,作者实现了一个简单但有效的推理过程。如图3所示,一次性生成 n×nn×n 个块,从左上角到右下角。在生成一组块后,丢弃不再使用的隐藏状态,并将新生成的 KV-cache 附加到显存中。

可以很容易地证明, 在此过程中保留在显存中的 Block KV-cache 的数量总是 。假设模型在生成单个 Block 时所需的空间为 , 一个 Block 的 KV-cache 的空间为 , 其他基本空间消耗 (例如存储原始输入图像) 为C, 则推理过程的最大空间使用为 。当 远小于 时, 内存消耗与 成正比。

在实际应用中, 尽管对于不同的 值, 生成图片的总 FLOPs 是恒定的, 但是受算子初始化时间与显存分配时间的影响, 当 增加时, 生成时间减少。因此, 最好选择内存限制允许的最大

1.4 Inf-DiT 架构

如下图4所示是 Inf-DiT 架构,它基于 DiT[5]。与基于卷积的结构 (如 U-Net[6]) 相比,DiT 仅利用注意力作为 Patch 之间的交互机制,可以方便地实现 UniBA。为了适应 UniBA,提高上采样的性能,作者做了如下几个修改和优化。

图4:Inf-DiT 架构

模型输入

Inf-DiT首先将输入图像划分为多个不重叠的 Blocks,进一步划分为 Patches。与 DiT 不同,考虑到颜色偏移和细节损失等压缩损失,Inf-DiT 的修补是在 RGB 像素空间中进行的,而不是在 Latent Space。在超分辨率 次的情况下,Inf-DiT 首先将低分辨率 RGB 图像条件上采样 倍,然后将其与扩散的噪声输入在特征维数上 Concat 起来,然后将其输入到模型中。

位置编码

最近 LLM 的结果表明,与绝对位置编码相比,相对位置编码在捕获词位置相关性方面更有效。作者参考了 Rotary 位置编码 (RoPE)[7]的设计,它在长上下文生成中表现良好,并将其适配到二维形式的图像生成中。具体来说,作者将隐藏状态的通道分成两半,一个用于编码 坐标,另一个用于编码 坐标,分别使用 RoPE。

作者创建了一个足够大的 Rotary 位置编码表。为了确保训练过程中模型可以看到位置编码表的所有部分, 作者使用随机起点: 对于每个训练图像, 为图像的左上角随机分配一个位置 , 而不是默认的

此外, 考虑到同一个 Block 内和不同 Block 之间的交互差异, 作者还引入了 Block-level 的相对位置编码 , 它根据注意前的相对位置分配不同的 learnable embedding。

1.5 全局和局部一致性

使用 CLIP Image Embedding 针对全局一致性

低分辨率 (LR) 图像中的全局语义信息,如艺术风格和物体材料,在上采样过程中起着至关重要的作用。然而,与文生图像模型相比,上采样模型还有一个额外的任务:理解和分析 LR 图像的语义信息,大大增加了模型的负担。在没有文本数据进行训练时尤其具有挑战性,因为高分辨率图像很少具有高质量的配对文本,这使得模型的这些方面变得困难。

受 DALL-E2[1]的启发,作者利用预训练的 CLIP[8]中的图像编码器从低分辨率图像中提取 Image Embedding ,称之为语义输入。由于 CLIP 是在互联网上海量的图像-文本对上训练的,其图像编码器可以有效地从低分辨率图像中提取全局信息。作者将全局语义嵌入添加到 Diffusion Transformer 的 time Embedding 中,并将其输入到每一层,使模型能够直接从高级语义信息中学习。

全局语义嵌入的另一个有趣优势是,使用 CLIP 中的对齐图像-文本 Latent Space,即使本文模型没有在任何图像-文本对上进行训练, 也可以使用文本来指导生成。给定一个正提示 和一个负提示 , 可以更新图像嵌入:

其中, 可以控制指导的强度。在推理过程中, 可以简单地使用 代替 作为全局语义嵌入来进行控制。例如, 为了获得更清晰的结果集, "clear" 和 " blur" 有时会有所帮助。

使用 Nearby LR Cross Attention 针对局部一致性

尽管将 LR 图像与噪声输入 Concat 起来已经为模型学习 LR 和 HR 图像之间的局部对应关系提供了良好的归纳偏差,但仍然可能存在连续性的问题。原因是,对于给定的 LR Block,有几种上采样的可能性,这需要与附近的几个 LR Block 一起分析以选择一种解决方案。假设上采样仅基于其左侧的 LR Block 执行,它可能会选择一个与右侧和下方 LR Block 冲突的 HR 生成解决方案。然后,当将 LR Block 上采样到右侧时,如果模型认为符合其对应的 LR Block 比与左侧的 Block 连续更重要,则会生成一个与先前块不连续的 HR Block。一个简单的解决方案是将整个 LR 图像输入到每个 Block,但当 LR 图像的分辨率也很大时,它的成本太高。

为了解决这个问题,作者引入了 Nearby LR Cross-Attention。在第一层中,每个 Block 对周围的 3×3 LR Block 进行 Cross-Attention,以捕获附近的 LR 信息。实验结果表明,这种方法显着减少了生成不连续图像的概率。值得注意的是,这个操作不会改变我们的推理过程,因为在生成之前知道整个 LR 图像。

1.6 实验结果

训练细节

本文的数据集包括 LAION-5B[9]的一个子集,分辨率高于 1024×1024,美学得分高于 5 的 100000 来自互联网的分辨率墙纸。在训练过程中,作者使用 512×512 分辨率的固定大小的 Image crop。由于上采样只能使用局部信息进行,因此在推理过程中可以直接用于更高的分辨率,这对于大多数生成模型来说并不容易。

数据准备

由于扩散模型生成的图像通常包含残余噪声和各种细节不准确,因此增强上采样模型的鲁棒性以解决这些问题变得至关重要。作者采用类似于 Real-ESRGAN[10]的方法对训练数据中的低分辨率输入图像执行各种退化。

在处理分辨率高于 512 的图像时,有两种替代方法:一种是直接执行随机裁剪,另一种是在执行随机裁剪之前将较短的边调整为 512。虽然直接裁剪方法在高分辨率图像中保留了高频特征,但调整大小后裁剪方法避免了频繁裁剪单个颜色背景的区域,不利于模型的收敛。因此在实践中,作者从这两种处理方法中随机选择裁剪训练图像。







请到「今天看啥」查看全文