专栏名称: 机器之心
专业的人工智能媒体和产业服务平台
目录
相关文章推荐
机器之心  ·  技术大神授课,百亿AI项目招标,2025全球 ... ·  15 小时前  
爱可可-爱生活  ·  本文创新性地提出将 LLM ... ·  21 小时前  
机器之心  ·  DeepSeek ... ·  昨天  
量子位  ·  马斯克“地表最强”Grok ... ·  2 天前  
量子位  ·  马斯克“地表最强”Grok ... ·  2 天前  
51好读  ›  专栏  ›  机器之心

首个基于统计学的线性注意力机制ToST,高分拿下ICLR Spotlight

机器之心  · 公众号  · AI  · 2025-02-17 09:17

正文

图片

AIxiv专栏是机器之心发布学术、技术内容的栏目。过去数年,机器之心AIxiv专栏接收报道了2000多篇内容,覆盖全球各大高校与企业的顶级实验室,有效促进了学术交流与传播。如果您有优秀的工作想要分享,欢迎投稿或者联系报道。 投稿邮箱:[email protected][email protected]


本文第一作者为加州大学伯克利分校三年级博士生吴梓阳,导师为马毅教授。吴的主要研究方向为表征学习与多模态学习。该工作由多所学校与机构的研究者共同完成,包括加州大学伯克利分校、宾夕法尼亚大学、密歇根大学、清华大学、忆生科技、香港大学、约翰·霍普金斯大学等。据悉,马毅教授已受邀在今年四月的ICLR大会上就和此项成果相关的一系列白盒神经网络相关工作,进行为时一小时的主题报告(Keynote)。


Transformer 架构在过去几年中通过注意力机制在多个领域(如计算机视觉、自然语言处理和长序列任务)中取得了非凡的成就。然而,其核心组件「自注意力机制」 的计算复杂度随输入 token 数量呈二次方增长,导致资源消耗巨大,难以扩展到更长的序列或更大的模型。


Token Statistics Transformer (ToST) 提出了一种新的注意力机制,它的时间复杂度是线性的。通过对序列特征的统计建模,ToST 提高了序列处理任务中的效率。文章探讨了基于变分编码率缩减(Variational Rate Reduction, VRR)的框架,并通过实验验证了其在不同任务中的性能,通过革新传统注意力机制,解决了这些长期困扰 Transformer 架构的效率瓶颈。


T oST 也作为 Spotlight 论文,入选了 ICLR 2025 大会。



  • 论文标题:Token Statistics Transformer: Linear-Time Attention via Variational Rate Reduction
  • 论文地址:https://arxiv.org/abs/2412.17810
  • 项目主页:https://robinwu218.github.io/ToST/
  • 目前该工作已开源:https://github.com/RobinWu218/ToST


研究背景与动机


一直以来,自注意力机制依赖于对输入 token 两两相似性的计算,这一过程虽然有效,但其资源开销显著;尤其当输入 token 数量极大时,传统注意力机制(如 Transformer 中的全局注意力)在计算复杂度和内存使用上的瓶颈问题愈发显著。


为了应对这一挑战,本文提出了一种基于统计学特征的注意力机制:Token Statistics Self-Attention (TSSA)。它通过避免两两相似性的计算,仅依赖于 token 特征的统计量,显著降低了计算复杂度。

Token Statistics Transformer (ToST) 的架构。Token Statistics Self-Attention (TSSA) 运算符通过对投影后的 token 进行行标量化变换,从而实现了线性复杂度。


核心方法


ToST 的核心方法是通过特定的概率分布函数对输入序列进行建模,减少冗余信息并提取关键特征。具体包括:


1. 统计特征提取 :对序列中的每个 token 提取其统计特征。

2. 变分编码率缩减 :利用 VRR 框架对特征进行压缩,减少信息冗余。

3. 线性复杂度实现 :通过一系列优化,其计算复杂度从 O (n²) 降低为 O (n)。

ToST 的方法概述。在 CRATE 的理论基础上,ToST 通过几何空间的结构化特征实现 token 分组和映射。


网络架构的推导


该团队通过扩展先前的 CRATE 工作推导出网络架构。CRATE 显示,一种 Transformer 风格的架构可以通过 "白盒" 架构设计自然生成,其中网络的每一层都旨在实现最大编码率缩减目标 (MCR²) 的增量优化步骤。


具体来说,该团队推导了 MCR² 目标的一个新颖 的变分形式,并表明通过对该变分目标进行展开梯度下降所得到的架构会引入一种新的注意力模块,称为 Token Statistics Self-Attention (TSSA)。 TSSA 拥有线性的计算和内存复杂度,并从根本上不同于典型的注意力架构,其后者通过计算 token 之间的两两相似性来实现。


关键公式 MCR² 目标函数定义


技术细节


1. 线性时间注意力机制:Token Statistics Self-Attention (TSSA)


通过白盒设计方法(algorithmic unrolling),TSSA 从最大编码率减少(Maximal Coding Rate Reduction, MCR² )的变分形式中推导而来。


传统 Transformer 依赖于 pairwise 相似度计算,而 TSSA 则基于 token 特征的统计量构建注意力机制,其计算复杂度从 O (n²) 降低为 O (n),内存占用同样显著减少。


2. 创新性的网络结构:Token Statistics Transformer (ToST)


ToST 通过将 TSSA 替代标准的自注意力模块,不仅实现了显著的效率提升,还增强了模型的可解释性。


传统模型不同,ToST 架构中的注意力操作基于统计量的低秩投影,通过减少不必要的计算路径,大幅优化了资源使用。


3. 理论支撑与数学推导


基于 MCR² 的变分形式,提出了一种新颖的压缩项公式,可对大型矩阵进行有效的特征提取。


通过设计数据相关的低秩投影,TSSA 在保留关键信息的同时,消除了冗余方向。


实验验证与性能分析


实验覆盖了自然言语处理(NLP)、计算机视觉(CV)等多个领域的任务,包括文本分类、机器翻译、图像识别等。结果表明,ToST 在保证模型性能的同时,大幅降低了计算资源消耗。


1. 计算和内存的线性复杂度分析


实验结果显示,与现有的注意力机制相比,TSSA 的时间和内存复杂度更低。具体而言,TSSA 的复杂度为 O (pn),显著优于传统 Transformer 的 O (n²)。

ToST 在计算时间和内存使用上均随序列长度实现线性扩展,使其显著优于标准 Transformer 的效率。如下:


复杂度分析对比

在 GPU 上评估的速度和内存使用对比


2. 视觉任务性能分析


在 ImageNet-1k 等主流视觉数据集上的实验表明,ToST 的性能可与传统 Transformer 架构(如 ViT 和 XCiT)相媲美,同时显著减少了模型参数量和计算开销。


迁移学习实验中,ToST 在 CIFAR、Oxford Flowers 等数据集上的表现进一步验证了其在多种视觉任务中的适应性。


结果展示了与传统 Transformer 相当的性能,同时在计算效率上显著更高。

3. 长序列任务和语言建模


  • 长序列任务


在长序列任务基准测试(如 Long-Range Arena)中,ToST 展现出优异的长距离建模能力,其性能超越了现有 Transformer 变体。


  • 语言建模


ToST 可以扩展并适用于多种任务场景,包括因果语言建模。针对语言建模,ToST 采用了一种因果版本的 TSSA,在多个数据集上实现了高效的预测能力。此外,即使在参数规模扩大的情况下,ToST 依然保持了优异的时间和内存效率。


NLP 任务中的表现


4. 有原理支持的模型设计


由于 ToST 是通过展开从学习目标中推导出来的,我们可以以有原理支持的方式逐层分析学习到的模型行为。


ToST 模型不同层次的 TSSA 输出的变分压缩项







请到「今天看啥」查看全文