专栏名称: 极市平台
极市平台是由深圳极视角推出的专业的视觉算法开发与分发平台,为视觉开发者提供多领域实景训练数据库等开发工具和规模化销售渠道。本公众号将会分享视觉相关的技术资讯,行业动态,在线分享信息,线下活动等。 网站: http://cvmart.net/
目录
相关文章推荐
江苏教育新闻  ·  推迟VS抢跑,开学“时差”反映了啥? ·  14 小时前  
江苏教育新闻  ·  推迟VS抢跑,开学“时差”反映了啥? ·  14 小时前  
中油工程  ·  发展“创新链” 他们落棋“三子” ·  14 小时前  
六里投资报  ·  景林、但斌300亿持仓披露:东方港湾All ... ·  15 小时前  
六里投资报  ·  景林、但斌300亿持仓披露:东方港湾All ... ·  15 小时前  
51好读  ›  专栏  ›  极市平台

Inf-CL: 把 Contrastive Loss 的 Batch Size 冲到100M!

极市平台  · 公众号  · 科技自媒体  · 2024-11-24 22:00

主要观点总结

本文介绍了一种新的Contrastive Loss实现方式——Inf-CL,它通过分块计算策略,在单台A800机器上将batch size扩展到4M,几乎实现了Contrastive Loss batch size的无限扩展,突破了以往认为增加batch size会导致显存不足的限制。文章详细描述了Inf-CL的方法原理、实验结果的对比,包括显存节省度、速度和精度。

关键观点总结

关键观点1: 对比学习的重要性和限制

对比学习在多个领域如图文检索、图像自监督学习、文本检索中占据重要地位。但增大batch size或负样本会导致GPU显存爆炸,成为该领域的一个难题。

关键观点2: Inf-CL方法介绍

Inf-CL采用分块计算策略,通过减少显存占用实现大batch size的对比损失计算。包括前向传播和反向传播的过程以及Multi-Level Tiling策略。

关键观点3: 实验结果

实验结果显示,Inf-CL在降低显存占用的同时,只引入了极少的时间开销。并且在降低显存占用后,仍然保持了较高的训练速度和精度。

关键观点4: 相关工作与灵感来源

介绍了与本文相关的工作,如Gradient Cache、Flash Attention和Ring Attention等,这些工作为Inf-CL的灵感来源。


正文

↑ 点击 蓝字 关注极市平台
作者丨藤原豆腐皮儿@知乎(已授权)
来源丨https://zhuanlan.zhihu.com/p/1681887214
编辑丨极市平台

极市导读

本文介绍了一种新的Contrastive Loss实现方式——Inf-CL,它通过分块计算策略,在单台A800机器上将batch size扩展到4M,几乎实现了Contrastive Loss batch size的无限扩展,突破了以往认为增加batch size会导致显存不足的限制。 >> 加入极市CV技术交流群,走在计算机视觉的最前沿

TL;DR

本文提出了一种Contrastive Loss的实现方式( Inf-CL ),通过分块计算策略,我们在单台A800机器上就能把 batch size 扩展到 4M 。不严谨地说,该方案突破了以前公知的 ”contrastive loss不能scaling batch size,显存会炸“ 的前提,实现了 Contrastive Loss 的 batch size 近乎无限的扩展。 中国人不骗中国人,以后对比损失实现就用Inf-CL!!

对比学习有多炸不用多说,在图文检索(CLIP为代表),图像自监督学习(SimCLR,MoCo等),文本检索(DPR等)是核心地位。之前相关工作的前提都是” 增大batch size/负样本,GPU显存会炸“ ,比如早期MoCo提出用”momenturm encoder“和“memory bank”来规避这个问题。这个工作直面显存痛点, 将对比损失的显存消耗打到底 ,且额外时间开销极少,为对比损失相关辐射领域提供了新的scaling机会。

先放炸裂结果:

图 1:Inf-CL 与现有方法(CLIP 和 OpenCLIP)的 GPU 显存使用对比。

图中标出了常见的 GPU 显存限制。对于超过 80GB A800 显存瓶颈 的情况,通过曲线拟合估算显存消耗。

  1. 左图 :在 8×A800 GPU 配置下,CLIP 和 OpenCLIP 的显存消耗呈 二次增长 ,而 Inf-CL 实现了 线性增长 。在 256k batch size 下,Inf-CL 将显存消耗降低了 78倍
  2. 右图 :在 1024k batch size 下,即使使用 128 块 GPU ,CLIP 和 OpenCLIP 的显存仍然会炸。而 Inf-CL 将显存需求减少了 281倍

题目:Breaking the Memory Barrier: Near Infinite Batch Size Scaling for Contrastive Loss

论文链接: https://github.com/DAMO-NLP-SG/Inf-CLIP/blob/main/assets/inf_cl.pdf

Arxiv链接: https://arxiv.org/abs/2410.17243

Huggingface Papers: https://huggingface.co/papers/2410.17243

代码链接: https://github.com/DAMO-NLP-SG/Inf-CLIP

1. 准备工作

1.1 Contrastive Loss

对比学习从20年以来开始爆火,从那个时代走过来的小伙伴,应该还记得这个简单的损失函数绽放了多大的光彩。在图像自监督领域 SimCLR 和 MoCo 两大模型系列相互争锋,跨模态检索领域,开启图文检索预训练的 CLIP 模型,在 NLP和信息检索领域,大家耳熟能详的 SimCSE DPR 等模型,都采用了Contrastive Loss作为训练损失。

这里以CLIP中的实现为例简单回顾一下contrastive loss。假设 batch size 为 b,图像和文本特征的维度为 [b,c],则 CLIP 中的图像到文本的 Contrastive Loss 公式如下:

其中 是第 i 个图像和第 j 个文本之间的余弦相似度, 这里 是匹配样本(正样本对)的相似度。为了简化讨论,公式中省略了温度因子。

从公式中我们可以看到,对比损失会将batch 内非匹配的文本作为负样本,来计算匹配图文对(正样本对)归一化的概率。这个就叫做 In-batch negative 策略——即将 batch 内的所有其他样本视作负样本。这种策略的优点在于,batch size 越大,模型就能接触到更多的负样本,从而学到更具判别性的特征。因此,了解对比学习的同学们都知道, batch size 理论上越大,效果就越好 ,这点也有很多文章从理论上进行分析。

那么一个直观地想法是,我们直接batch size 扩大不就好了,就像别的分类,回归,或者文本生成的任务一样,把梯度累积步数多开一些,batch size不就能一直增大了吗?但遗憾的是,对比学习的batch size 方法一直是一个比较蛋疼的问题。实现过对比损失的同学都知道,核心限制主要是” 增大batch size/负样本,GPU显存会炸“。接下来我们来分析显存消耗到了什么地方。

1.2 显存限制

图2. (a) Vanilla 实现的 Contrastive Loss:将所有特征广播到所有显卡,并同时将完整的相似度矩阵实例化到显存中。存储复杂度为 ( 2),且在所有显卡上重复存储该矩阵。(b) Inf-CL 方法:采用分块-串行累加的策略减少显存占用。

经典 的对比损失实现中(如CLIP),首先需要构建 相似度矩阵 ,并将其存储在 高带宽内存 (HBM) 中。然后对相似度矩阵应用 Softmax 归一化 负对数似然计算 来完成损失计算。

然而,相似度矩阵 及其归一化结果的显存需求,会随着 batch size 呈二次方增长,即显存复杂度是 ,这意味着当 batch size 较大时,显存占用会变得非常庞大。例如即使在采用 ViT-B/16 这种轻量化模型的情况下,当 batch size 达到 64k 时,Loss 计算部分的GPU 显存消耗仍然极为惊人。如图 2 (a)所示,尽管模型自身的显存开销仅为 5.24GB ,但损失计算所需的显存却高达 66GB

这个例子我们可以清楚看到,在scaling batch size 时, 显存瓶颈主要集中在损失计算上 。现有的方法,如 Gradient Cache BASIC 等,虽在一定程度上优化了模型的显存占用,但依然未能突破loss 计算过程中 显存二次增长 的限制。

2. 方法

2.1 分块计算策略

正如在 上一小节 vanilla 实现 中讨论的那样,显存消耗的核心问题在于 相似度矩阵 X 的完全实例化 。那么我们有没有办法避免将它存储呢?为了达到这个效果,我们首先分析这个矩阵是用来计算什么的,所以先将对比损失的公式进行拆解分析:

公式分解后,我们可以将contrastive loss的计算拆解为两部分:

  1. 第一部分 :计算所有 正样本对的相似度 并累加 这部分的计算复杂度是 \mathcal{O}(b) ,即线性增长,因此不会造成显存瓶颈。
  2. 第二部分 :计算 Log-Sum-Exp (LSE) ,即所有负样本对的相似度的对数-指数和。这部分是由 全局相似度矩阵 计算得到的 如果直接计算并存储整个矩阵,就会导致显存开销迅速增加。

将公式拆解后我们发现,原来相似度矩阵 的完全实例化是为了 计算LSE 这一项 其实也就是 Softmax操作的分母部分。 看到这里,熟悉 on-line Softmax FlashAttention 技术的同学们可能已经秒懂了,本质问题是一样的:如果我们能通过分块计算避免一次性存储整个矩阵,LSE 的计算也就不会消耗很多的显存。既然 大模型 的输入长度都能扩展到 百万级别 (例如 FlashAttention 支持的超长序列),那么对比损失的 batch size scaling 问题自然也可以迎刃而解。

前向传播过程:

具体来说, 分块策略的 前向传播 计算过程如下:

其中, 分别表示行和列方向上的分块数量。通俗的说,就是 不把 矩阵 一次性计算并存储下来, 而是 将矩阵 计算划分为多个块 (即子矩阵) ,并在每个块内部计算局部LSE 值 , 之后沿着 行方向 逐步合并每列块的 LSE 值,最终得到全局 LSE 向量

这种 分块计算 方法显著减少了对显存的需求,因为每次只需计算和存储相似度矩阵的一部分,而不是整个 矩阵。此外,在列方向的运算支持并行,能够很好适应多 GPU 或GPU内部多芯片的并行架构,

防溢出策略:

为了避免在合并过程中出现数值不稳定或溢出,采用如下稳定的数值计算公式:

其中初始值 。每次迭代维护列 方向 的LSE向量 ,将中间值 累积到 中,完成行方向所有块的计算后,得到最终的全局 LSE 向量

此外,在计算 时,直接对矩阵求指数可能导致数值溢出。为此,我们采用以下稳定的公式进行计算:

其中 是一个行最大值向量,每个元素代表 中对应行的最大值,用作确保指数计算不会溢出。

反向传播过程:

其实在传统实现方式的前向传播过程中,相似度矩阵 会存储在计算图内,能够直接调用torch的autograd机制来计算梯度。既然我们在前向过程中仅仅存储了最终得到的LSE向量 ,那么就需要自定义实现反向传播的算子。

具体运算过程如下,假设已经计算得到loss的结果,要计算对于图像特征输入 和文本特征 的梯度

根据2.1小节拆解的公式,以 I_i 为例,完整的梯度公式为:

简化后:

从该公式可以看出,第二项计算依赖于相似度矩阵的值。我们在反向计算中也采用与前向过程相同的 分块计算策略

  1. 在前向传播时,仅存储大小为 b 的向量
  2. 在反向传播时, 逐块累积 计算梯度:

最终梯度为:

其中 是用于累积的临时变量。通过这种分块计算,我们在反向传播中同样避免了完整存储矩阵 的需求,进一步降低了显存开销,并实现了高效的梯度计算。详细的算法步骤在论文中可以找到。

2.2 Multi-Level Tiling

看到这里的小伙伴们可能会产生疑问,分块累加这种操作本质上是将并行计算的过程用串行合并来替代了,也是一种时间换空间的策略,而且反向传播的recompute过程也会带来额外的计算,难道不会很慢吗?其实问题的答案是:整体计算量会增加,但我们可以通过GPU的分布式运算特性来加速这个过程,运算速度却并不会减慢很多。加速过程主要是两块,即 跨GPU的通讯和GPU内显存的IO加速 。我们将其称为 多层级分块策略 。该策略将 LSE 的计算分配为 粗粒度的跨 GPU 分块 细粒度的单 GPU 分块 ,以最大化计算效率。

图3. 多层级分块策略示意图。上:在 跨 GPU 分块中,每个 GPU 被分配多行数据,并负责对应行的LSE计算。计算与列方向的通信采用 异步 方式执行。下:在 单 GPU 分块 中,将行方向的计算任务分配给多个 CUDA 核心。每行的累积操作在一个 kernel 中执行,以减少 SRAM 和HBM之间I/O次数。

跨 GPU 分块 (Cross-GPU Tile)

并行训练 中,假设有 个 GPU,每个 GPU 处理一部分图像和文本数据,分别生成视觉特征 和文本特征 ,其中 表示单个 GPU 上的 batch size。计算对比损失时,我们将不同行的数据分配给不同的 GPU,并逐步同步各 GPU 之间的列数据。

具体而言,第 个 GPU 负责相似度矩阵的第 行子块的 : 及其对应的 LSE 向量 。为了降低显存开销,结合 分块策略 ,每个行块 可以进一步拆分为 步小块 来计算 LSE,具体过程可以在论文中 算法 1 中找到 。每个小块 的 LSE 计算采用 单 GPU 分块策略 (详见下节)。

由于计算 (当 时)需要访问其他 GPU 上的文本特征







请到「今天看啥」查看全文