【新智元导读】
AlphaFold 3的论文太晦涩?没关系,斯坦福大学的两位博士生「图解」AlphaFold 3 ,将模型架构可视化,同时不遗漏任何一个细节。
谷歌DeepMind的人工智能模型AlphaFold 3两个月前横空出世,颠覆了生物学。
这个「值得获得诺贝尔奖的发明」不仅在学术圈引起了巨震,还轰动了制药界——它可能带来数千亿美元的商业价值,并对药物研发产生深远影响。
论文地址:https://www.nature.com/articles/s41586-024-07487-w
如此重要的AlphaFold3,其具体工作原理是什么?
因为AlphaFold3的结构非常复杂,论文有相当高的阅读门槛,让人望而却步。两位斯坦福大学的两位博士生制作了一个论文的「图解版」,比论文阅读起来友好多了,而且还很详尽!
每一位机器学习工程师都不应该错过这篇图文并茂的文章——
博客地址:https://elanapearl.github.io/blog/2024/the-illustrated-alphafold/
此前已经有很多关于蛋白质结构预测的研究动机、CASP竞赛、模型失效模式、关于评估的争论、对生物技术的影响等主题的文章,因此以上内容都不是这篇博文关注的重点。
这篇博文关注的重点在于AlphaFold3(以下简称AF3)是如何在技术上实现的:分子在模型中被如何表示,它们又是如何被转换成预测结构的?
AlphaFold3和前代模型最大的不同点在于——预测目标不同。
AF3不仅预测单个蛋白质序列(AF2)或蛋白质复合物(AF-multimeter)的结构,还能预测蛋白质与其他蛋白质、核酸、小分子中的一种或多种物质的复合结构,而且仅根据序列信息。
因此,前代的AF模型只需表示标准的氨基酸序列,但AF3需要引入更复杂的输入类型,因此设计了更复杂的特征表示和tokenization机制。
tokenization过程会在稍后单独描述,目前我们只需要知道,token可能代表单个氨基酸(蛋白质)、核苷酸(DNA/RNA),或者单个原子(其他物质)。
- 输入准备:给定输入的分子序列,模型需要检索一系列的结构相似的分子。这一步骤会识别出这些分子,并将其编码为数值张量。
- 表征学习:给定上一步中创建的张量,使用注意力机制的多种变体来更新这些表征。
- 结构预测:基于第一部分创建的原始输入以及第二部分改进后的表征,使用条件扩散进行结构预测。
在整个模型中,蛋白质复合物有两种表示形式:单一表征(single representation)和配对表征(pair representation),这两种表示都可以应用于token级别或原子级别。
前者仅仅表示复合物中的所有token或原子,后者则表征了物质中所有token/原子之间的关系(如距离、潜在相互作用等)。
为了简单起见,下述的结构中忽略了大多数LayerNorm层,但其实它们无处不在。
用户向AF3提供的实际输入是一个蛋白质序列和可选的其他分子。
本节的目标是将这些序列转换成一系列6个张量,这些张量将作为模型主干的输入.
本节包含5个步骤,分别是tokenization、检索、创建原子级表征、更新原子级表征、原子级到token级集成。
tokenization
在AF2中,由于模型只表示具有固定氨基酸集的蛋白质,因此每个氨基酸都拥有自己的token。
AF3保留了这一点,但也引入了额外的token,以便处理其他分子类型:
-非标准氨基酸或核苷酸(甲基化核苷酸、翻译后修饰的氨基酸等):每个原子1个token
因此,我们可以认为某些token(如标准氨基酸/核苷酸)对应多个原子,而其他token (如配体中的原子)只对应单个原子。
比如,35个标准氨基酸的序列(可能大于600个原子)将由35个token来表示;同时,一个由35个原子组成的配体也同样由35个token表示。
检索
AF3的早期关键步骤之一类似于语言模型中的检索增强生成(Retrieval Augmented Generation,RAG)。
模型会检索到与与输入序列相似的序列(收集到多序列比对中,即MSA),以及与这些序列相关的任何结构(称为「模板」),将它们作为模型的附加输入,分别写作m和t。
与AF-multimer相比,这些检索步骤中唯一的新内容是,除了蛋白质序列外,我们现在还对RNA序列进行检索。
请注意,这在传统上并不被称为「检索」,因为早在RAG这个术语出现之前,使用结构模板指导蛋白质结构建模就已经是同源建模领域的常见做法了。
不过,尽管AlphaFold没有明确将这一过程称为检索,但它确实与现在流行的RAG非常相似。
通过模板搜索,我们获得了每个模板的三维结构,以及有关哪些token位于哪些链中的信息。
首先,计算给定模板中所有token对之间的欧氏距离。如果是对应多个原子token,则使用一个有代表性的「中心原子」来计算距离。
比如,对于氨基酸来说,中心原子是C
ɑ
原子,而标准核苷酸的中心原子是C1'原子。
这会为每个模板生成一个N
token
x N
token
大小的矩阵。不过,距离并不是用数值来表示,而是将距离离散化为「距离直方图」。
然后,我们向每个直方图添加元数据,关于每个token属于哪个链、该token是否在晶体结构中得到解析、以及每个氨基酸内部的局部距离信息。
然后,我们对这个矩阵进行掩码,只查看每条链内部的距离(忽略链之间的距离)。根据论文的解释,这样做的原因是「并不尝试选择模板......以获取链间交互的信息」。
创建原子级表征
为了创建q(原子级单一表征),我们需要提取所有原子级特征。
第一步是计算每个氨基酸、核苷酸和配体的「参考构象」(reference conformer)。虽然我们还不知道整个复合物的结构,但我们对每个单独组件的局部结构有很多的先验知识。
构象(构象异构体的简称)是分子中原子的三维排列,通过对单键的旋转角度进行采样而生成。
每种氨基酸都有一个「标准」构象,但这只是该氨基酸存在的低能量构象之一,可以通过查表找到。
不过,每个小分子都需要生成自己的构象,利用RDKit中的ETKDGv3算法,同时结合了实验数据和扭转角偏好来生成三维构象。
然后,我们将该构象信息(相对位置)与每个原子的电荷、原子序数和其他标识符连接起来。矩阵c存储了序列中所有原子的这些信息。
然后,我们用c来初始化原子级别的配对表征p,以存储原子间的相对距离。
由于我们只知道每个token内的参考距离,因此先使用掩码机制(v)来确保这个初始距离矩阵只代表我们在构象生成过程中计算出的距离。
最后,我们将原子级别的单一表征复制一份,并将这个副本称为q。这个矩阵q是我们接下来要更新的,但c确实会被保存并稍后使用。
原子Transformer
在生成了q(单个原子的表征)和 p(原子配对表征)之后,我们现在要根据附近的其他原子更新这些表征。
AF3使用一个名为原子Transformer的模块,在原子级别应用注意力机制时
原子Transformer主要遵循标准的Transformer结构。不过,其具体步骤都经过了调整,以处理来自c和p的额外输入。
原子级->token级
到目前为止,所有数据都是以原子级别存储的,而AF3的表征学习部分则从这里开始以token级别运行。
为了创建token级表征,我们首先将原子级表征投影到一个更大的维度(c
atom
=128,c
token
=384)。然后对分配给同一token的所有原子取平均值。
请注意,这只适用于与标准氨基酸和核苷酸相关的原子,其余原子保持不变。
现在我们就从「原子空间」进入了「token空间」。
之后将token级特征和MSA中的统计信息连接起来,形成矩阵s
inputs
并被向下投影到c
token
,作为序列的起始表征s
init
。
s
init
将在表征学习部分中被更新,但s
inputs
保持不变,用于结构预测部分。
经过一系列输入准备后,我们就来到了模型的主干部分,也是完成大部分计算量的部分。
这部分模型的学习目标是改进上述token级别的单一(s)或成对(z)张量的初始化表示,因此被称为「表征学习」。
- 模板模块(template module):使用模板t更新张量z
- MSA模块:更新MSA表征m,再将其引入token级别的张量z
- Pairformer:使用三角注意力更新张量s、z
以上步骤会重复运行多次,每次输出结果后再将其反馈到自身继续作为输入,继续进行计算(如上图中蓝色需先所示),这种做法被称为「回收」(recycling)。
模板模块
该模块的计算流程如下图所示(模板个数N
template
=2)。
每个模板t和张量z经过线性投影后相加得到矩阵v,再经过一系列被称作Pairformer Stack的操作(后文详述)。
之后,N个模板被平均到一起,再通过另一个线性层和一次ReLU,得到最终结果。
有趣的是,这是AF3模型中唯二使用ReLU的地方之一,但论文中并没有解释为什么选择ReLU而非其他非线性函数。
MSA模块
AF3中的MSA与AF2中的Evoformer非常类似,都是在同时改进MSA表征和配对表征,对两者分别独立执行一系列操作后进行交互。
处理MSA表征的第一步是下采样,而非使用之前生成的MSA的所有行(最多可达 16k)。下采样后,还要加入经过投影映射的单一表征s。
之后,MSA表征m通过外积均值方法(outer product mean)被合并到配对表征中。
如下图所示,比较MSA中的两列揭示了有关序列中两个位点之间的关系信息(比如进化过程中的相关性)。
对于每对标记索引i,j,我们迭代所有进化序列s,获取 m
s,i
和 m
s,j
的外积,在所有进化序列中进行平均。
然后,我们压平这个外部积并将其投影回去,最后将其添加至配对表征z
i,j
。
虽然每个外积仅对给定序列m
s
内的值进行操作,但取平均值时会混合序列之间的信息。这是模型中唯一能在进化序列之间共享信息的机制。
这个方法是相对于AF2的重大改变,旨在降低Evoformer的计算复杂度。
根据MSA更新配对表征后,模型接下来根据后者更新MSA,这种特定的更新模式是原子Transformer中所述的「具有配对偏差的自注意力」的简化版本,被称为「仅使用配对偏差的行内门控自注意力」(row-wise gated self-attention using only pair bias)。
这种方法受到注意力机制的启发,但并不使用查询和键计算,而是直接使用存储在配对表征z中的token间关系。
如下图所示,在张量z中,每个c
z
维度的向量z
i,j
都表示第i个和第j个token间的关系。将z线性投影到矩阵b后,每个z
i,j
向量变为标量,就可以相当于「注意力分数」(attention score),用于加权平均。
最后,MSA通过一系列「三角更新」(triangle updates)和注意力机制来更新配对表征,其中「三角更新」与下面Pairformer的描述相同。
Pairformer
经过前两个模块后,模板t和MSA表征m的作用就结束了,只有单一表征z和经过更新的配对表征s进入Pairformer并用于相互更新。
Pairformer中值得注意的是「三角更新」和「三角自注意力」方法,它们首次在AF2模型中出现,并被保留在AF3中,而且正在被应用到越来越多的架构中。
这里的指导原则是三角形不等式的思想:「三角形任意两边之和大于或等于第三条边。」
回想一下,张量z中的值z
i,j
编码序列中位置i和j之间的关系。虽然并没有显式地对token间的物理距离进行编码,但的确包含了这层含义。
如果我们想象每个z
i,j
代表两个氨基酸之间的距离,并且有z
i,j
=1和z
j,k
=1。那么根据三角形不等式,z
i,k
不能大于2。
「三角更新」和「三角形自注意力」的目标就是尝试将这些几何约束编码到模型中,但并不会强制执行,而是鼓励模型在每次更新z
i,j
的值时考虑所有可能的三元组(i,j,k)。
此外,z不仅代表物理距离,还编码了token之间复杂的物理关系,因此向量z
i,j
是有方向的。
所以,如上图所示,在对节点k进行「三角更新」和「三角自注意力」操作时,需要分别查看两种有向路径,出边(outgoing edge)和入边(incoming edges)。
从图论角度理解「三角」操作后,我们就能明白以下的张量更新和注意力机制是如何通过张量运算实现的。
使用出边进行更新时,使用到了z的三个线性投影a、b和g。
为了更新z
i,j
,需要对z
i,k
和z
j,k
进行操作,即对a中的第i行和b中的第j行进行逐元素(element-wise)乘法,之后对所有行(不同k值)求和,再用g进行门控。
为了更新z
i,j
,需要对z
k,i
和z
k,j
进行操作,即对a中的第i列和b中的第j列进行逐元素乘法,再对所有行求和。
可以发现,出边和入边的「三角更新」操作与上面标出有向路径的两个紫色三角形完全对应。
接下来,分别使用出边和入边的「三角自注意力」更新每个z
i,j
值。AF3论文将这两个过程分别称为「围绕起始节点」(around starting node)和「围绕结束节点」(around ending node)。
回忆一下,典型的一维序列自注意力中,查询、键和值都是原始一维序列的转换。自注意力的二维变体——轴向注意力中,在二维矩阵的不同轴上(行,然后列)上独立应用一维自注意力。
以此类推,「三角自注意力」在轴向注意力的基础上添加了之前讨论的「三角形不等式」,通过合并所有k值的z
i,k
和z
j,k
来更新z
i,j
。
比如,在围绕起始节点的情况中(下图),为了计算注意力分数z
i,j
,需要将q
i,j
与k矩阵中第i行每个值相乘(以确定第j列受到其他列的影响),然后加上z
j,k
的注意力偏置。
围绕结束节点的情况同样是行列对称,为了计算z
i,j
,需要将q
i,j
与k矩阵中第i列每个值相乘,注意力偏置则来自第j列。