TL;DR
本文提出了一种Contrastive Loss的实现方式(
Inf-CL
),通过分块计算策略,我们在单台A800机器上就能把
batch size
扩展到
4M
。不严谨地说,该方案突破了以前公知的
”contrastive loss不能scaling batch size,显存会炸“
的前提,实现了
Contrastive Loss 的 batch size 近乎无限的扩展。
中国人不骗中国人,以后对比损失实现就用Inf-CL!!
对比学习有多炸不用多说,在图文检索(CLIP为代表),图像自监督学习(SimCLR,MoCo等),文本检索(DPR等)是核心地位。之前相关工作的前提都是”
增大batch size/负样本,GPU显存会炸“
,比如早期MoCo提出用”momenturm encoder“和“memory bank”来规避这个问题。这个工作直面显存痛点,
将对比损失的显存消耗打到底
,且额外时间开销极少,为对比损失相关辐射领域提供了新的scaling机会。
先放炸裂结果:
图 1:Inf-CL 与现有方法(CLIP 和 OpenCLIP)的 GPU 显存使用对比。
图中标出了常见的 GPU 显存限制。对于超过
80GB A800 显存瓶颈
的情况,通过曲线拟合估算显存消耗。
-
左图
:在
8×A800 GPU
配置下,CLIP 和 OpenCLIP 的显存消耗呈
二次增长
,而
Inf-CL
实现了
线性增长
。在
256k batch size
下,Inf-CL 将显存消耗降低了
78倍
。
-
右图
:在
1024k batch size
下,即使使用
128 块 GPU
,CLIP 和 OpenCLIP 的显存仍然会炸。而
Inf-CL
将显存需求减少了
281倍
。
题目:Breaking the Memory Barrier: Near Infinite Batch Size Scaling for Contrastive Loss
论文链接:
https://github.com/DAMO-NLP-SG/Inf-CLIP/blob/main/assets/inf_cl.pdf
Arxiv链接:
https://arxiv.org/abs/2410.17243
Huggingface Papers:
https://huggingface.co/papers/2410.17243
代码链接:
https://github.com/DAMO-NLP-SG/Inf-CLIP
1. 准备工作
1.1 Contrastive Loss
对比学习从20年以来开始爆火,从那个时代走过来的小伙伴,应该还记得这个简单的损失函数绽放了多大的光彩。在图像自监督领域
SimCLR 和 MoCo
两大模型系列相互争锋,跨模态检索领域,开启图文检索预训练的
CLIP
模型,在 NLP和信息检索领域,大家耳熟能详的
SimCSE
和
DPR
等模型,都采用了Contrastive Loss作为训练损失。
这里以CLIP中的实现为例简单回顾一下contrastive loss。假设 batch size 为 b,图像和文本特征的维度为 [b,c],则 CLIP 中的图像到文本的 Contrastive Loss 公式如下:
其中
是第 i 个图像和第 j 个文本之间的余弦相似度, 这里
是匹配样本(正样本对)的相似度。为了简化讨论,公式中省略了温度因子。
从公式中我们可以看到,对比损失会将batch 内非匹配的文本作为负样本,来计算匹配图文对(正样本对)归一化的概率。这个就叫做
In-batch negative
策略——即将 batch 内的所有其他样本视作负样本。这种策略的优点在于,batch size 越大,模型就能接触到更多的负样本,从而学到更具判别性的特征。因此,了解对比学习的同学们都知道,
batch size 理论上越大,效果就越好
,这点也有很多文章从理论上进行分析。
那么一个直观地想法是,我们直接batch size 扩大不就好了,就像别的分类,回归,或者文本生成的任务一样,把梯度累积步数多开一些,batch size不就能一直增大了吗?但遗憾的是,对比学习的batch size 方法一直是一个比较蛋疼的问题。实现过对比损失的同学都知道,核心限制主要是”
增大batch size/负样本,GPU显存会炸“。接下来我们来分析显存消耗到了什么地方。
1.2 显存限制
图2. (a) Vanilla 实现的 Contrastive Loss:将所有特征广播到所有显卡,并同时将完整的相似度矩阵实例化到显存中。存储复杂度为 ( 2),且在所有显卡上重复存储该矩阵。(b) Inf-CL 方法:采用分块-串行累加的策略减少显存占用。
在
经典
的对比损失实现中(如CLIP),首先需要构建
相似度矩阵
,并将其存储在
高带宽内存 (HBM)
中。然后对相似度矩阵应用
Softmax 归一化
和
负对数似然计算
来完成损失计算。
然而,相似度矩阵
及其归一化结果的显存需求,会随着
batch size
呈二次方增长,即显存复杂度是
,这意味着当 batch size 较大时,显存占用会变得非常庞大。例如即使在采用
ViT-B/16
这种轻量化模型的情况下,当 batch size 达到
64k
时,Loss 计算部分的GPU 显存消耗仍然极为惊人。如图 2 (a)所示,尽管模型自身的显存开销仅为
5.24GB
,但损失计算所需的显存却高达
66GB
。
这个例子我们可以清楚看到,在scaling batch size 时,
显存瓶颈主要集中在损失计算上
。现有的方法,如
Gradient Cache
和
BASIC
等,虽在一定程度上优化了模型的显存占用,但依然未能突破loss 计算过程中
显存二次增长
的限制。
2. 方法
2.1 分块计算策略
正如在 上一小节
vanilla 实现
中讨论的那样,显存消耗的核心问题在于
相似度矩阵
X
的完全实例化
。那么我们有没有办法避免将它存储呢?为了达到这个效果,我们首先分析这个矩阵是用来计算什么的,所以先将对比损失的公式进行拆解分析:
公式分解后,我们可以将contrastive loss的计算拆解为两部分:
-
第一部分
:计算所有
正样本对的相似度
并累加
。
这部分的计算复杂度是 \mathcal{O}(b) ,即线性增长,因此不会造成显存瓶颈。
-
第二部分
:计算
Log-Sum-Exp (LSE)
,即所有负样本对的相似度的对数-指数和。这部分是由
全局相似度矩阵
计算得到的
,
如果直接计算并存储整个矩阵,就会导致显存开销迅速增加。
将公式拆解后我们发现,原来相似度矩阵
的完全实例化是为了
计算LSE
这一项
,
其实也就是
Softmax操作的分母部分。
看到这里,熟悉
on-line Softmax
和
FlashAttention
技术的同学们可能已经秒懂了,本质问题是一样的:如果我们能通过分块计算避免一次性存储整个矩阵,LSE 的计算也就不会消耗很多的显存。既然
大模型
的输入长度都能扩展到
百万级别
(例如 FlashAttention 支持的超长序列),那么对比损失的
batch size scaling
问题自然也可以迎刃而解。
前向传播过程:
具体来说, 分块策略的
前向传播
计算过程如下:
其中,
和
分别表示行和列方向上的分块数量。通俗的说,就是
不把
矩阵
一次性计算并存储下来,
而是
将矩阵
的
计算划分为多个块
(即子矩阵)
,并在每个块内部计算局部LSE 值
, 之后沿着
行方向
逐步合并每列块的 LSE 值,最终得到全局 LSE 向量
。
这种
分块计算
方法显著减少了对显存的需求,因为每次只需计算和存储相似度矩阵的一部分,而不是整个
矩阵。此外,在列方向的运算支持并行,能够很好适应多 GPU 或GPU内部多芯片的并行架构,
防溢出策略:
为了避免在合并过程中出现数值不稳定或溢出,采用如下稳定的数值计算公式:
其中初始值
。每次迭代维护列
方向
的LSE向量
,将中间值
累积到
中,完成行方向所有块的计算后,得到最终的全局 LSE 向量
。
此外,在计算
时,直接对矩阵求指数可能导致数值溢出。为此,我们采用以下稳定的公式进行计算:
其中
是一个行最大值向量,每个元素代表
中对应行的最大值,用作确保指数计算不会溢出。
反向传播过程:
其实在传统实现方式的前向传播过程中,相似度矩阵
会存储在计算图内,能够直接调用torch的autograd机制来计算梯度。既然我们在前向过程中仅仅存储了最终得到的LSE向量
,那么就需要自定义实现反向传播的算子。
具体运算过程如下,假设已经计算得到loss的结果,要计算对于图像特征输入
和文本特征
的梯度
根据2.1小节拆解的公式,以 I_i 为例,完整的梯度公式为:
简化后:
从该公式可以看出,第二项计算依赖于相似度矩阵的值。我们在反向计算中也采用与前向过程相同的
分块计算策略
:
-
-
最终梯度为:
其中
是用于累积的临时变量。通过这种分块计算,我们在反向传播中同样避免了完整存储矩阵
的需求,进一步降低了显存开销,并实现了高效的梯度计算。详细的算法步骤在论文中可以找到。
2.2 Multi-Level Tiling
看到这里的小伙伴们可能会产生疑问,分块累加这种操作本质上是将并行计算的过程用串行合并来替代了,也是一种时间换空间的策略,而且反向传播的recompute过程也会带来额外的计算,难道不会很慢吗?其实问题的答案是:整体计算量会增加,但我们可以通过GPU的分布式运算特性来加速这个过程,运算速度却并不会减慢很多。加速过程主要是两块,即
跨GPU的通讯和GPU内显存的IO加速
。我们将其称为
多层级分块策略
。该策略将 LSE 的计算分配为
粗粒度的跨 GPU 分块
和
细粒度的单 GPU 分块
,以最大化计算效率。
图3. 多层级分块策略示意图。上:在 跨 GPU 分块中,每个 GPU 被分配多行数据,并负责对应行的LSE计算。计算与列方向的通信采用 异步 方式执行。下:在 单 GPU 分块 中,将行方向的计算任务分配给多个 CUDA 核心。每行的累积操作在一个 kernel 中执行,以减少 SRAM 和HBM之间I/O次数。
跨 GPU 分块 (Cross-GPU Tile)
在
并行训练
中,假设有
个 GPU,每个 GPU 处理一部分图像和文本数据,分别生成视觉特征
和文本特征
,其中
表示单个 GPU 上的 batch size。计算对比损失时,我们将不同行的数据分配给不同的 GPU,并逐步同步各 GPU 之间的列数据。
具体而言,第
个 GPU 负责相似度矩阵的第
行子块的
: 及其对应的 LSE 向量
。为了降低显存开销,结合
分块策略
,每个行块
可以进一步拆分为
步小块
来计算 LSE,具体过程可以在论文中
算法 1 中找到
。每个小块
的 LSE 计算采用
单 GPU 分块策略
(详见下节)。
由于计算
(当
时)需要访问其他 GPU 上的文本特征