本文约7000字,建议阅读14分钟
本文介绍了一种贝叶斯风格的注意力机制,用于序列预测。我们将详细阐述如何使用马尔可夫链蒙特卡罗法(MCMC)训练该模型。
当前的大型语言模型在处理长序列文本时面临挑战。主要的瓶颈在于注意力机制,它将文本处理为单词(或 tokens)序列。注意力计算的复杂度随序列长度 T 呈平方增长,导致处理长文本的成本显著增加。为了降低计算成本,研究人员积极探索注意力的替代方案,包括递归模型(如 Mamba [1] 和 xLSTM [2])、卷积模型(如 Hyena [3])以及基于稀疏性的模型(如 Longformer [4] 和 BigBird [5])。其中状态空间模型作为一种有效的注意力替代方案,受到了越来越多的关注。例如基于状态空间模型的大型语言模型(LLM),如 Mamba 2 [6],在多项任务上表现出与 transformers 相当甚至更优的性能 [7]。状态空间模型还为理解注意力与半可分离矩阵和状态空间模型之间的关系提供了新的视角 [6]。
然而当前基于注意力和状态空间的模型通常需要大量的训练数据。当训练数据有限,或者需要将领域知识融入模型时,贝叶斯方法是一种有效的选择。与标准神经网络训练不同,贝叶斯模型不易受到过度自信的影响 [8],并且支持利用未标记数据进行训练。此外贝叶斯模型能够提供不确定性估计,这在金融和医疗等高风险领域具有重要价值。
本文介绍了一种贝叶斯风格的注意力机制,用于序列预测。我们将详细阐述如何使用马尔可夫链蒙特卡罗法(MCMC)训练该模型。
贝叶斯注意力适用性评估
以下是一些建议,帮助您快速评估贝叶斯注意力是否适合您的应用场景:
如果满足以下一个或多个条件,可以考虑使用贝叶斯注意力:
注意力机制回顾
在开始之前,我们回顾一下注意力机制,它可以被视为一种字典查找过程,其中查询
q ᵢ
与键
k
进行比较,并检索相应的值
v ᵢ[9]
。具体而言,对于
d
维潜在空间:
Attn(
Q ,K ,V
) = S(
QK
⊤/√d)
V
,
其中 softmax 函数定义为 S(
x
) = [exp[x ₁]/Z ,exp[x ₂]/Z,…],Z 是配分函数(归一化常数)。在自注意力机制中,权重矩阵
W
用于生成键
K
=
W
(k)
x
,查询
Q
=
W
(q)x 和值
V
=
W
(v)x ,它们都从同一个输入序列 x 导出(但使用不同的矩阵)。关键在于,
QK
⊤ 项的形状为
T
×
T
,导致其计算复杂度为二次方。通过使用核技巧 [10],我们可以消除 softmax 激活函数并实现线性化。该技巧的核心思想是利用泰勒展开将 exp[xy] 重写为内积形式:
exp[
xy
] =
𝛙
(x)·**
𝛙
**(y),
其中
𝛙
(x) = [1, x, x ²/√2,x ³/√6,…]。因此,注意力机制可以视为一个线性变换:
Attn(Q’ ,K’ ,V’) =
Q’K’
⊤** V’** ,
在由核函数
𝛙
定义的线性基
𝛙
(Q) ->
Q
’ 中进行。
状态空间模型简介
状态空间模型通过一阶微分方程描述系统
y
(t) 对给定输入
V
(t) 的响应。状态空间模型常用于描述电路动态、化学反应器中的浓度分布或机械系统动态。具体来说,状态空间模型描述了 状态
h
(t) 的演变:
d/d t
h
(t) =
A
(t)
h
(t) +
K
(t)
V
(t),
y
(t) =
Q
(t)
h
(t) +
D
(t)
V
(t).
这里,
A
(t) 称为状态矩阵 ,
K
(t) 称为输入矩阵 ,它将输入
V
(t) 投影到状态上,
Q
(t) 称为输出矩阵 ,它将状态
h
(t) 转换回输出,
D
(t) 称为 反馈
矩阵 。矩阵
Q
、
K
和
V
的命名是为了体现它们作为查询、键和值的角色。
在机器学习中,我们处理的是这些方程的离散化版本,序列元素 t = 0,1,2… 而不是连续时间。将 Q ₜ, K ₜ, 和 V ₜ 称为 Q(t)、 K(t) 和 V(t) 的离散化对应物。我们选择将 A(t) 的离散化版本设置为标量 aₜ 乘以单位矩阵,并删除 D ₜ 而不失一般性。
离散化后的模型可以表示为:
hₜ = aₜ h ₜ ₋₁ +
K
ₜ** V** ₜ ,
yₜ =
Q
ₜ ⊤ hₜ,
这是一个具有隐藏状态 hₜ 的线性递归神经网络(图 1)。
接下来,定义 h ₀ =
K
₀ _**V** ₀ 并展开递归,得到 _yₜ 的表达式:
yₜ =
Q
ₜ ⊤ aₜ … a ₁** K** ₀ V ₀ +
Q
ₜ ⊤ aₜ … a ₂**
K
** ₁ V ₁ +
Q
ₜ ⊤ aₜ … a ₃**
K
** ₂ _**V** ₂ + … + **
Q
** _ₜ ⊤** K** ₜ** V** ₜ
≡ ∑ ₛ
M
ₜₛ
V
ₛ.
因此状态空间模型等价于具有特定因果掩码 aₜ … a ₁ 的线性注意力机制。系数 aₜ 可以被认为是可训练的位置编码 [6]。计算 yₜ 的复杂度为线性时间,而非二次时间。
贝叶斯方法融合
在回顾了基础知识后,我们考虑一个具有输入
x
和高斯输出
y
的自注意力模块。键
K
=
W
(k)
x
,查询
Q
=
W
(q)
x
和值
V
=
W
(v)
x
依赖于输入
x
,即自注意力机制。
我们不再将权重 {
W
(k),
W
(q),
W
(v)} 视为固定数值,而是赋予它们高斯分布。同样我们也对位置编码 aₜ 采用这种处理方式。这样做有以下两个关键优势:(i) 我们可以通过指定权重在观察数据之前的先验分布,将领域知识融入模型。(ii) 允许模型对权重和编码的不确定性进行建模。
生成模型
现在我们定义一个生成模型来描述数据集的生成过程。给定一个包含 i = 1...m 个输入 x 的数据集,每个输入包含 t = 0,…,T −1 个序列元素,以及 R 个特征。设 N 为隐藏单元 h 的维度。
模型的参数和输出根据以下统计过程生成:
Eq. (1): 贝叶斯注意力的生成模型。
符号
x
~ N(
μ ,Σ
) 表示
x
服从均值为
μ
,协方差矩阵为
Σ
的高斯分布。Eq. (1) 表示每个输出
y
都是从隐藏状态
h
递归生成的,并叠加了大小为 σ 的高斯噪声。此外,键、查询和值权重 {
W
(k),
W
(q),
W
(v)} 和位置编码 a 也是从先验分布中抽取的有噪声的样本,其均值分别为
p
(
W
) 和
p
(
a
),并具有预定义的协方差。
通过训练模型,我们可以根据观测数据
y
更新对权重
W
和位置编码
a
的合理值的信念。
贝叶斯模型训练
简要介绍贝叶斯模型的训练方法。标准的机器学习模型训练方法是找到一组参数 Θ ={W(k),W(q),W(v), a},以优化损失函数,通常是(负)对数似然函数。贝叶斯方法与之不同,它不是寻找参数 Θ 的最优值,而是推断参数在给定观测值 y 下的分布 p(Θ |y),即
后验分布。
由于后验分布 p(Θ |y) 通常难以直接求解,因此需要对其进行近似。估计 p(Θ |y) 的主要方法有两种:变分推断和马尔可夫链蒙特卡罗 (MCMC) 模拟。变分推断试图优化一个代理分布 q(Θ),使其逼近 p(Θ |y)。而 MCMC 方法则通过从 p(Θ |y) 中抽取样本 Θ 来进行估计。虽然我们无法直接从 p(Θ |y) 中抽取样本,但可以通过马尔可夫链蒙特卡罗过程逐步逼近。MCMC 描述了一个随机过程 p(Θ ₛ |Θ ₛ ₋₁),当步数 s 足够大时,该过程将收敛到 p(Θ |y)。
我们仍然需要指定 p(Θ ₛ |Θ ₛ ₋₁),即如何进行随机抽样。有多种方法可以保证 p(Θ ₛ |Θ ₛ ₋₁) 收敛到后验分布。
本文将重点介绍一种方法:
吉布斯采样
。吉布斯采样是一种高效的随机抽样方法。其一般步骤如下:给定一个包含 D 个参数的样本 Θ ₛ ₋₁ = [Θ ₛ ₋₁(1), Θ ₛ ₋₁(2),…,Θ ₛ ₋₁(D)],通过交替抽取每个参数 Θ ₛ(i),并以其他参数的先前值为条件,来采样 p(Θ ₛ |Θ ₛ ₋₁)。具体来说:
-
Θ ₛ(1)~ p[Θ ₛ(1)|Θ ₛ ₋₁(2),Θ ₛ ₋₁(3),..,Θ ₛ ₋₁(D), y],
-
Θ ₛ(2)~ p[Θ ₛ(2)|Θ ₛ(1), Θ ₛ ₋₁(3),..,Θ ₛ ₋₁(D), y],
-
…
-
Θ ₛ(D)~ p[Θ ₛ(D)|Θ ₛ(1), Θ ₛ(3),..,Θ ₛ(D -1), y].
重复此过程,直到 Θ ₛ 收敛到后验分布 p(Θ |y)。高效吉布斯采样的关键在于条件分布 p[Θ ₛ(i)|Θ ₛ(1), Θ ₛ(2),..,Θ ₛ(i -1),Θ ₛ ₋₁(i +1),…** Θ** ₛ ₋₁(D), y] 易于采样。幸运的是,参数的先验分布和似然函数均为高斯分布 [Eq. (1)],这意味着条件分布也必须是高斯分布。
基于吉布斯采样的推理
回到我们的注意力模块。由于参数和输出 y 都是高斯分布,我们可以解析地计算条件分布,从而进行吉布斯采样以获得后验分布。算法如下:
Eq. (2): 贝叶斯注意力的吉布斯采样算法。
请注意,表示先验均值和协方差的 breve 符号已被 hat符号替换,表示条件均值和协方差。从直观上看,该算法可以解释为一系列交替回归,其中回归系数为模型参数,其余参数保持固定。
在 JAX 代码中,Eq. (2) 的核心部分如下所示:
def kernel(key, y, X, state):
"""Take one Gibbs sampling step, updating the model parameters
Args:
key: Pseudo random number generator key.
y: Observed output of the module.
X: Input to the module.
state: A tuple with the current model parameters.
Returns: A new configuration of the (unobserved) model parameters."""
key_seq = KeySeq(key)
(ln_a, W_V, W_K, W_Q) = state
L = jnp.exp(segsum(ln_a))[jnp.newaxis,...] # Broadcast across batch.
L = jnp.real(L) # Discard zero imaginary part.
# 1) W_K[α] ~ p(W_K[α]|--).
Q = jnp.einsum('αβ,bsβ->bsα', W_Q, X)
V = jnp.einsum('pi,bsi->bsp', W_V, X)
W_K = sample_weights_keys(next(key_seq), W_K, Q, V, L, X, y)
# 2) W_Q[α] ~ p(W_Q[α]|--).
K = jnp.einsum('nk,bsk->bsn', W_K, X)
W_Q = sample_weights_queries(next(key_seq), W_Q, K, V, L, X, y)
# 3) W_V[β] ~ p(W_V[β]|--).
Q = jnp.einsum('αβ,bsβ->bsα', W_Q, X)
W_V = sample_weights_values(next(key_seq), W_V, K, Q, L, X, y)
# 4) a[t] ~ p(a[t]|--).
V = jnp.einsum('pi,bsi->bsp', W_V, X)
ln_a = sample_positional_encodings(next(key_seq), ln_a, K, V, Q, X, y)
state = (ln_a, W_V, W_K, W_Q)
return state
其中,segsum 用于构建因果掩码 L [6](如下所述),KeySeq 是一个辅助类,用于管理伪随机数生成器。
接下来,我们将逐一介绍各个采样函数。
位置编码 aₜ 的采样
为了对位置编码系数 aₜ 进行吉布斯采样,我们将预测分解为与 aₜ 成比例的项和常数项:
Eq. (3a): 预测分解为与位置编码 _aₜ 成比例的项和常数项。
我们暂时忽略训练样本索引 i 和特征索引 p。只有 𝜏 ≥ t 的项包含 y -lesser 预测中的 aₜ。这是生成模型 [Eq. (1)] 的直接结果,其中 aₜ 出现在 h ₜ = aₜ h ₜ ₋₁ + K ₜ** V** ₜ 中,因此仅影响 y ₜ 及其后续输出。通过乘以并除以 aₜ,我们将其转化为一个回归问题,截距为 y -greater。经过整理,Eq. (2) 中的均值和方差变为:
Eq. (3b): 用于吉布斯采样位置编码 aₜ 的条件均值和协方差矩阵。
其中,我们定义了 c 和 d:
Eq. (3c): 位置编码 aₜ 的精度和相关性。
这可以解释为精度和相关性的度量。
在代码中将状态 aₜ 的对数存储为复数,以跟踪负号。此外还构建
因果掩码
矩阵
L
,其中下三角元素设置为 Lₜₛ = aₜ … aₛ ₊₁,对角线上为 1。
def sample_positional_encodings(key, ln_a, K, V, Q, X, y):
key_seq = KeySeq(key)
G = jnp.einsum('btn,bsn->bts', Q, K)
for t in range(1, T):
L = jnp.exp(segsum(ln_a))[None,...] # (B,T,S)
L = jnp.real(L) # Discard zero imaginary part.
M = jnp.einsum('bts,bts->bts', G, L)
# Compute y<.>
y_pred_cumul = jnp.cumsum(
M[..., jnp.newaxis] * V[:,jnp.newaxis,...], axis=-2,
) # (B,T,S,P)
y_pred = y_pred_cumul.diagonal(axis1=-3, axis2=-2) # (B,P,T)
y_pred = rearrange(y_pred, 'b p t -> b t p')
mask = jnp.triu(jnp.ones([T, T], dtype=bool)) # (T,S)
mask = mask[jnp.newaxis, ..., jnp.newaxis] # (B,T,S,P)
y_pred_lt = jnp.where(mask, 0, y_pred_cumul)
# Insert column with zeros at index 0.
y_pred_lt = jnp.pad(y_pred_lt, ((0, 0), (0, 0), (1, 0), (0, 0)))
y_pred_lt = y_pred_lt[..., :-1,:] # (B,T,S,P)
# Interpret as a regression with regression coefficient a:
# y-y≥ = a * (y< / a),
# where we move all terms not containing `a` into the output.
y_pred_lt_t = y_pred_lt[...,t,:]
y_pred_ge_t = y_pred - y_pred_lt_t # y≥.
is_geg_t = (jnp.arange(T) >= t).astype(int)[None, :, None] # (B,T,P)
outputs = (y - y_pred_ge_t) * is_geg_t # tau >= t
outputs = rearrange(outputs, 'b t p->(b t p)')
a_t = jnp.real(jnp.exp(ln_a[t]))
inputs = y_pred_lt_t * is_geg_t / a_t # (B,T,P)
inputs = rearrange(inputs, 'b t p->(b t p)')
μ_a_posterior, σ_a_posterior = scalar_gaussian_posterior(
outputs, inputs, μ_a_prior[t], Λ_a_prior[t],
)
a_t = random.normal(next(key_seq)) * σ_a_posterior + μ_a_posterior
ln_a = ln_a.at[t].set(jnp.log(a_t.astype(complex)))
return ln_a
其中,scalar_gaussian_posterior 计算单变量高斯回归的均值和方差。
键权重 W(k) 的采样
为了采样键权重
W
(k),我们将预测的输出写成
W
(k) 和矩阵
Γ
之间的点积。利用高斯分布的共轭性质,条件均值和协方差矩阵为:
Eq. (4a): 用于吉布斯采样键权重的条件均值和协方差矩阵。
其中定义了:
Eq. (4b): 回归数据和偏移量。
将这些方程解释为线性回归,Γ 可以被认为是协变量,没有 W(k) 的行 α = 1…N 的预测作为偏移量,y 作为输出。如 Eq. (2) 所示,每行单独采样。在代码中,采样步骤如下:
def sample_weights_keys(key, W_K, Q, V, L, X, y):
key_seq = KeySeq(key)
Γ = jnp.einsum('btα,bts,bsβ,bsp->btpαβ', Q, L, X, V)
for α in range(n_components):
y_pred = jnp.einsum('αβ,btpαβ->btp', W_K, Γ)
y_pred_α = jnp.einsum('β,btpβ->btp', W_K[α], Γ[:,:,:,α])
y_residual = y - (y_pred - y_pred_α)
y_residual = rearrange(y_residual, 'b t p->(b t p)')
λ = rearrange(Γ[:,:,:,α], 'b t p β->(b t p) β')
μ, Σ = gaussian_posterior(
y_residual, μ_prior=μ_k_prior[α], Λ_prior=Λ_k_prior[α], X=λ,
)
w_α = random.multivariate_normal(next(key_seq), μ, Σ)
W_K = W_K.at[α].set(w_α)
return W_K
其中,函数 gaussian_posterior 计算多元贝叶斯回归的均值和协方差矩阵。
查询权重 W(q) 的采样
接下来,对与查询对应的权重进行采样几乎完全类似。将问题视为具有协变量的回归:
Eq. (5a): 用于采样查询权重的回归协变量。
可以使用条件分布对查询权重
W
(q) 进行采样:
Eq. (5b): 用于吉布斯采样查询权重的条件均值和协方差矩阵。
代码类似于 sample_weights_keys,只是对协变量和截距进行了细微更改:
def sample_weights_queries(key, W_Q, K, V, L, X, y):
key_seq = KeySeq(key)
Λ = jnp.einsum('btβ,bts,bsα,bsp->btpαβ', X, L, K, V)
for α in range(n_components):
y_pred = jnp.einsum('αβ,btpαβ->btp', W_Q, Λ)
y_pred_α = jnp.einsum('β,btpβ->btp', W_Q[α], Λ[:,:,:,α])
y_residual = y - (y_pred - y_pred_α)
y_residual = rearrange(y_residual, 'b t p->(b t p)')
λ = rearrange(Λ[:,:,:,α], 'b t p β->(b t p) β')
μ, Σ = gaussian_posterior(
y_residual, μ_prior=μ_k_prior[α], Λ_prior=Λ_k_prior[α], X=λ,
)
w_α = random.multivariate_normal(next(key_seq), μ, Σ)
W_Q = W_Q.at[α].set(w_α)
return W_Q
值权重 W(v) 的采样
与
W
(v) 回归对应的协变量由下式给出:
Eq. (6a): 值的回归协变量。
与键和查询的条件分布不同,回归不再包含截距项。条件分布如下:
Eq. (6b): 用于吉布斯采样值权重的条件均值和协方差矩阵。
每行 β = 1… P, 单独采样,对应于单个输出维度。
def sample_weights_values(key, W_V, K, Q, L, X, y):
key_seq = KeySeq(key)
Ω = jnp.einsum('btn,bts,bsn,bsβ->btβ', Q, L, K, X)
Ω = rearrange(Ω, 'b t β->(b t) β')
y_pred = jnp.einsum('βγ,iγ->iβ', W_V, Ω)
y_flat = rearrange(y, 'b t α->(b t) α')
for β in range(p_features):
y_β = y_flat[:,β]
μ, Σ = gaussian_posterior(
y_β, μ_prior=μ_v_prior[β], Λ_prior=Λ_v_prior[β], X=Ω,
)
w_β = random.multivariate_normal(next(key_seq), μ, Σ)
W_V = W_V.at[β].set(w_β)
return W_V
整合
现在我们有了所有单独的采样步骤,可以训练模型。如 Eq. (2) 所示,在初始化参数后,重复调用之前定义的函数 kernel,直到收敛到后验分布。吉布斯采样器容易陷入局部最优,因此,需要检查似然函数,并比较不同的起点,以确保模型正确识别了后验分布的模式。最后,在收敛到后验分布后,收集 state 的样本来估计后验分布。
如果训练成功,模型训练过程可能如下所示:
贝叶斯注意力的训练过程。每条线显示一个不同的马尔可夫链蒙特卡罗模拟。红点是真实值。
至此,我们已经成功地从头开始实现了贝叶斯注意力机制!
未来工作
本文为构建一个简单的贝叶斯自注意力模块。一些计算方面的问题尚未解决,并有进一步优化的空间。首先,我们孤立地考虑了一个自注意力模块。复杂的建模任务(如语言建模)通常需要深度模型才能获得良好的性能。如何以完全贝叶斯的方式连接这些模块是一个有趣的开放性问题。其次,计算条件分布的汇总统计量需要进行矩阵求逆,这对于大型矩阵而言计算成本很高。第三,马尔可夫链蒙特卡罗 (MCMC) 模拟需要为每个 MCMC 步骤处理整个训练集,因此难以扩展到大型数据集。第四,本文的方法没有充分利用预测可以在线性时间 O(T) 内完成的特性,而是使用了朴素但简单的二次注意力方法。最后,我们假设输出 y 的方差是固定的,并且事先已知(即,一个超参数),这可以使用正态-逆伽马共轭关系进一步分析。