02-TTT背景简介
2020年,OpenAI缩放定律论文表明,LSTM(一种RNN)不能像Transformers那样进行缩放,也不能有效地使用长上下文。
如上图所示,在左边,我们观察到Mamba,当今最受欢迎的RNN之一,它的规模与强大的Transformer相似,显示出自2020年LSTM以来的巨大进步。然而,在右边,我们观察到
Mamba的问题与Kaplan等人对LSTM的问题相同。序列中较晚的令牌平均来说应该更容易预测,因为它们以更多的信息为条件Transfor mer的情况确实如此,其在每个令牌索引处的平均困惑度在整个32k上下文中都会降低。
相比之下,Mamba在16k后也出现了同样的指标平稳期。
这个结果代表了现有RNN的尴尬现实。一方面,
RNN(与Transformer相比)的主要优点是其线性(与二次型)复杂性
。这种渐近优势只有在长上下文的实践中才能实现,长上下文是在8k之后。
另一方面,一旦上下文足够长,现有的RNN(如Mamba)就很难真正利用所依赖的额外信息。
长上下文的困难是RNN层固有的,与自我注意力不同,RNN层必须将上下文压缩到固定大小的隐藏状态。作为一种压缩启发式方法,更新规则需要发现数千个或可能数百万个令牌之间的底层结构和关系。在本文中,作者观察到:
自监督学习可以将大量训练集压缩为LLM等模型的权重,LLM通常对其训练数据之间的语义连接表现出深刻的理解,这正是我们所需要的。
自注意力机制在长上下文中表现良好,但具有二次复杂性。
现有的RNN层具有线性复杂性,但它们在长上下文中的性能受到其隐藏状态的表达能力的限制。
本文作者提出了一类新的序列建模层,它具有线性复杂性和可表达的隐藏状态。关键思想是使隐藏状态本身成为机器学习模型,更新规则成为自监督学习的一个步骤。
由于隐藏状态甚至在测试序列上也通过训练来更新,因此该层被称为测试时间训练(TTT)层。
随后,作者考虑了两种实例:
TTT-Linear和TTT-MLP
,它们的隐藏状态分别是线性模型和两层MLP。作者在125M到1.3B参数的范围内评估这些实例的性能,与强大的Transformer和现代RNN Mamba进行了比较。
大量的实验结果表明:
TTT-Linear和TTT-MLP都匹配或超过基线。与Transformer类似,它们可以通过限制更多的代币来不断减少困惑,而Mamba在16k上下文后则不能。
经过初步的系统优化,TTT Linear在8k环境下已经比Transformer更快,并且在wall-clock时间上与Mamba相匹配。
然而,
TTT-MLP在内存I/O方面仍然面临挑战
,但在长上下文情况下显示出更大的潜力,为未来的研究指明了一个有希望的方向。
如上图所示,
所有的序列建模层都可以表示为根据更新规则转换的隐藏状态。
TTT架构的关键思想是使隐藏状态本身成为权重为W的模型f,并且更新规则是自监督损失上的梯度步长Ş。
因此,在测试序列上更新隐藏状态相当于在测试时训练模型f。这个过程被称为测试时间训练(TTT),被编程到我们的TTT层中。
将任何RNN层集成到更大架构中的最简单的方法是直接替换Transformer中的自注意力机制,在本文中称为主干。
然而,现有的RNN,如Mamba和Griffin,都使用与Transformer不同的主干。最值得注意的是,它们的主干在RNN层之前包含时间卷积,这可能有助于跨时间收集局部信息。
如上图所示,
左图展示了一个残差块
,它是Transformer的基本构建块。序列建模块被实例化为两个变体:Transformer主干和Mamba主干。
中间的图表示Transfo rmer主干中的TTT层
。O之前的LN来自NormFormer。
右图表示受到Mamba和Griffin的启发,在骨干中的TTT层
。根据这两种架构,σ在这里指的是GELU。为了在不改变嵌入维度的情况下容纳门的额外参数,作者简单地将θK和θQ组合成一个投影。
如上图所示,所有序列建模层都可以从将历史上下文存储到隐藏状态的角度进行查看。
顶部表示了通用序列建模层,表示为根据更新规则转换的隐藏状态。所有序列建模层都可以被视为该图中三个组件的不同实例化:初始状态、更新规则和输出规则。
底部表示了序列建模层的示例及其三个组件的实例化。初始TTT层如图所示。
自注意力有一种隐藏状态,随着上下文的增长而增长,因此每个令牌的成本也在增长。原生的RNN和TTT层都将不断增长的上下文压缩为固定大小的隐藏状态,因此它们的每个令牌的成本保持不变。
05.02-TTT高级计算图
上图展示了第一个TTT小批量的高级计算图,
其中节点表示变量,边表示计算。蓝色节点表示输入变量,黄色节点表示输出变量。
由于G1, ...,Gb之间没有连接,它们之间没有顺序依赖关系,因此它们可以并行计算。
实际上作者并没有具体化白色节点中间的Gs和Ws来计算对偶形式的输出变量。
上面的代码按照PyTorch的风格,用线性模型和在线GD实现了TTT层。TTT_Layer可以像其它序列建模层一样被放入更大的网络中。
训练网络将优化TTT_Layer中Task的参数,因为两者都是nn的子类。单元由于学习者不是nn的子类。模块state.model在内部循环中为state.train的每次调用手动更新。为了简单起见,作者有时会将模型重载为model.parameters。
from transformers import AutoTokenizer
from modeling_ttt import TTTForCausalLM, TTTConfig, TTT_STANDARD_CONFIGS
# Initializing a TTT ttt-1b style configuration
# configuration = TTTConfig(**TTT_STANDARD_CONFIGS['ttt-1b']) is equivalent to the following
configuration = TTTConfig()
# Initializing a model from the ttt-1b style configuration
model = TTTForCausalLM(configuration)
model.eval()
# Accessing the model configuration
configuration = model.config
# Tokenizer
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
# Prefill
input_ids = tokenizer("Greeting from TTT!", return_tensors="pt").input_ids
logits = model(input_ids=input_ids)
print(logits)
# Decoding
out_ids = model.generate(input_ids=input_ids, max_length=50)
out_str = tokenizer.batch_decode(out_ids, skip_special_tokens=True)
print(out_str)
07-TTT性能评估
如上图所示,左图展示了在书籍的缩放趋势,在350M~1.3B参数之间放大。在760M和1.3B时,
TTT Linear在使用较少FLOP的困惑方面优于Mamba,在线性插值下优于Transformer。
右图展示了Transformer和TTT Linear可以在更多Tokens的条件下不断减少复杂度,而Mamba在16k上下文后则不能。所有方法都匹配训练FLOP为Mamba 1.4B。
通过观察与分析,我们可以发现:
与Mamba相比,TTT Linear具有更好的困惑性和更少的FLOP,以及更好地使用长上下文。
上图展示了该算法
在Pie上的
多个变种架构与Transformer、Mamba在上下文长度为2k和8k情况下的性能评估。通过观察与分析,我们可以得出以下的初步结论 :
为了评估该架构的长上下文处理
能力,作者使用一个名为Books3的流行测试集,按照2×增量对1k到32k的上下文长度进行实验。这里的训练配方与Pile的训练配方相同,TTT层的所有实验都在一次训练中进行。通过观察与分析上图,我们可以得出以下的初步结论: