Linear Attention的论文如下: Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention:
https://arxiv.org/pdf/2006.16236.pdf
。官方给出实现代码地址:
https://github.com/idiap/fast-transformers
。虽然这个仓库是Linear Attention的原始实现,但基于这个codebase也引出了后续的一系列线性Attention的工作比如:Efficient Attention: Attention with Linear Complexities (
https://arxiv.org/abs/1812.01243
), Linformer: SelfAttention with Linear Complexity(
https://arxiv.org/abs/2006.04768
), Reformer: The Efficient Transformer (
https://arxiv.org/abs/2001.04451
) 等等。
这篇文章是对Linear Attention的forward cuda kernel进行解析, 在此之前我先基于论文的3.2节对Linear Attention做一个复述, 明确这里要计算的是什么。
# 检查并确保 attn_mask 是全部为一的,这表明这种注意力不支持任意的注意力掩码。 if not attn_mask.all_ones: raise RuntimeError(("LinearAttention does not support arbitrary " "attention masks")) K = K * key_lengths.float_matrix[:, :, None, None]
# 这个私有方法用于确保 Q 和 K 张量的大小兼容,通过对 K 进行切片或填充以匹配 Q 的大小。 def _make_sizes_compatible(self, Q, K): """Either slice or pad K in case that the sizes do not match between Q and K.""" N, L, H, E = Q.shape _, S, _, _ = K.shape if L == S: return Q, K
if L < S: return Q, K[:, :L, :, :]
if L > S: return Q, torch.cat([K, K.new_zeros(N, L-S, H, E)], dim=1)
# 检查 attn_mask 是否为下三角因果掩码,并应用键长度掩码。 if not attn_mask.lower_triangular: raise RuntimeError(("CausalLinearAttention only supports full " "lower triangular masks")) K = K * key_lengths.float_matrix[:, :, None, None]
# 确保query和key的大小(长度)兼容。 Q, K = self._make_sizes_compatible(Q, K)
# TODO: Shall we divide the Q and K with a relatively large number to # avoid numerical instabilities in computing the denominator? # We used to divide each with the max norm of all q and k but # that seems relatively costly for a simple normalization.
# Compute the normalizers Z = 1/(torch.einsum("nlhi,nlhi->nlh", Q, K.cumsum(1)) + self.eps)
# Compute the unnormalized result V = causal_linear( Q, K, values )
// 这个函数计算两个向量 a 和 b 的外积(a 和 b 的转置的点积)并将结果保存在 out 中。 // a 是一个长度为 A 的向量,b 是一个长度为 B 的向量。 // 外积的结果是一个 AxB 的矩阵。 inline void vvt_dot(float *a, float *b, float *out, int A, int B) { for (int i=0; ifloat * bi = b; for (int j=0; j *out += (*a) * (*bi); out++; bi++; } a++; } }
// 这个函数实现了向量 v 和矩阵 m 的乘积,并将结果保存在 out 中。 // v 是一个长度为 A 的向量,m 是一个 AxB 的矩阵。 // 结果是一个长度为 B 的向量。 inline void vm_dot(float *v, float *m, float *out, int A, int B) { // TODO: Consider removing the zeroing part and assuming out already // contains 0s for (int i=0; i out[i] = 0; }
// 这个函数计算向量 v 和矩阵 m 转置的乘积,并将结果保存在 out 中。 // v 是一个长度为 B 的向量,m 是一个 AxB 的矩阵。 // 结果是一个长度为 A 的向量。 inline void vmt_dot(float *v, float *m, float *out, int A, int B) { for (int i=0; ifloat *vi = v; float s = 0; for (int j=0; j s += (*vi) * (*m); vi++; m++; } // TODO: Should we be aggregating? See the comment on vm_dot. *out = s; out++; } }
// 这个函数计算查询(queries)、键(keys)和值(values)的因果掩码点积。 // N、H、L 和 E 分别代表 batch 大小、头数、序列长度和特征维度。M 是value的特征维度。 // 计算公式为:V_j' = (Q_{0:j} * K_{0:j}^T) * V_{0:j} void causal_dot_product( const torch::Tensor queries, const torch::Tensor keys, const torch::Tensor values, torch::Tensor product ) { // Extract some shapes int N = queries.size(0); int H = queries.size(1); int L = queries.size(2); int E = queries.size(3); int M = values.size(3);
// Create accessors for all the arguments auto qa = queries.accessor(); auto ka = keys.accessor(); auto va = values.accessor(); auto pa = product.accessor();
// 使用 OpenMP 实现并行计算,增加计算效率。 #pragma omp parallel for collapse(2) for (int n=0; n for (int h=0; h auto kv = torch::zeros({E, M}, queries.options()); float *kvp = kv.data_ptr(); for (int l=0; l // 该函数首先计算 K 和 V 的外积(vvt_dot),然后计算 Q 和这个外积的结果(vm_dot)。 vvt_dot( &ka[n][h][l][0], &va[n][h][l][0], kvp, E, M ); vm_dot( &qa[n][h][l][0], kvp, &pa[n][h][l][0], E, M ); } } } }
可以清晰的看到这里的计算过程就是
,也就是先计算 K 和 V 的外积(vvt_dot),然后计算 Q 和这个外积的结果(vm_dot)。为了更高效,还使用了openmp做并行计算。
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def( "causal_dot_product", &causal_dot_product, "Compute the weighted sum of values but attending only to previous " "values." ); }
__global__ void causal_dot_product_kernel( const float_accessor queries, const float_accessor keys, const float_accessor values, float_accessor result, const int N, const int H, const int L, const int E, const int M ) { int n = blockIdx.y; // 确定 batch 所在的id int h = blockIdx.z; // 确定 attention 所在的头的id
int e_start = blockIdx.x * E_BLOCK_SIZE; // 确定query的特征维度的开始位置 int m = threadIdx.x % M; // 确定 value 的特征维度 id
// N、H、L 和 E 分别代表 batch 大小、头数、序列长度和特征维度。M 是value的特征维度。 int N = queries.size(0); int H = queries.size(1); int L = queries.size(2); int E = queries.size(3); int M = values.size(3);
// 每个Block处理E_BLOCK_SIZE(=8)个隐藏层的元素 // 一共需要blocks_per_sequence这么多个Block来进行处理 // 注意:这里的blocks_per_sequence还要乘以N和H才是真正的Block个数 const int blocks_per_sequence = (E + E_BLOCK_SIZE - 1) / E_BLOCK_SIZE;
// 每个Block固定有M个线程 dim3 blockDim(M, 1, 1); dim3 gridDim(blocks_per_sequence, N, H); // 每个Block固定使用 E_BLOCK_SIZE(=8)* M个float这么大的shm const int shared_mem_forward = E_BLOCK_SIZE * M * sizeof(float);
causal_dot_product_kernel<<>>( queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(), keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(), values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(), product.packed_accessor32<float, 4, torch::RestrictPtrTraits>(), N, H, L, E, M ); }
总的来说,这个kernel使用了
(E + E_BLOCK_SIZE - 1) / E_BLOCK_SIZE * N * H
个Block,并且每个Block里面有M个线程,并且每个Block里面开了一个长度为E_BLOCK_SIZE * M的共享内存来存储当前Block计算出来的KV乘积。
// 从query张量 q 中提取出维度信息(如批量大小 N、头数 H、序列长度 L 和特征维度 E)并设置到 params 结构体中。 int N = q.size(0); int H = q.size(1); int L = q.size(2); int E = q.size(3); int M = v.size(3);
// Make sure that we are using the correct GPU device torch::DeviceGuard _guard(queries.device());
// Make sure the inner-most dimension of the tensors is packed. // 使用 assert 语句检查张量的最内层维度(即特征维度)是否是packed的 assert(queries.stride(3) == 1); assert(keys .stride(3) == 1); assert(values .stride(3) == 1); assert(product.stride(3) == 1);
// 提取张量的维度信息,如批量大小、头数、序列长度等。 int N = queries.size(0); int H = queries.size(1); int L = queries.size(2); int E = queries.size(3); int M = values.size (3);
// The structure of params. Lmha_params<float> params; // 调用 set_params 函数来初始化 Lmha_params<float> 结构体。 set_params(params, queries, keys, values, product);
// 确定lmha kernel需要的共享内存大小 template< int E, typename Params > static inline __device__ __host__ int smem_buffer_elts_(const Params ¶ms) { int M = round_up(params.M, 4); return 2*E + 2*M; }
// E: 代表特征维度的大小。 // THREADS_PER_HEAD: 每个 attention 头分配的线程数。 // GO_BACKWARD: 布尔类型的模板参数,指示是进行前向计算还是反向传播。 template< int E, int THREADS_PER_HEAD, bool GO_BACKWARD > int lmha_(const Lmha_params<float> ¶ms) { // 调整 M 维度: M 是 params.M 的调整值,向上取整到最接近的 4 的倍数。这种调整可能是出于内存对齐或性能优化的考虑。 int M = round_up(params.M, 4);
可以看到这个 kernel 将会启动
H*B
个Block,每个Block里面的线程数由 E 以及
M*THREADS_PER_HEAD
共同决定。注意,E 是 query 的隐藏层大小,而 M 是 value 的隐藏层大小。此外,还通过
smem_buffer_elts_
函数确定了这个kernel需要的共享内存大小,这个函数里面的
2*
表示的是double buffering。
接下来就是最关键的 lmha_kernel 的实现了,阅读之前还是要先想着我们要计算的东西是
,也就是先计算 K 和 V 的外积(vvt_dot),然后计算 Q 和这个外积的结果(vm_dot)。
// 确定lmha kernel需要的共享内存大小 template< int E, typename Params > static inline __device__ __host__ int smem_buffer_elts_(const Params ¶ms) { int M = round_up(params.M, 4); return 2*E + 2*M; }
// E: 特征维度的大小。 // THREADS_PER_HEAD: 每个 attention 头分配的线程数。 // GO_BACKWARD: 布尔类型的模板参数,指示是进行前向计算还是反向传播。 // params: Lmha_params<float> 类型的结构体,包含多头自注意力所需的各种参数。 template< int E, int THREADS_PER_HEAD, bool GO_BACKWARD > __global__ void lmha_kernel(Lmha_params<float> params) {
// Make sure E is a multiple of 4. static_assert(E % 4 == 0, "");
// The amount of shared memory per buffer (2 buffers for double-buffering). const int smem_buffer_elts = smem_buffer_elts_(params); // The M dimension for shared memory. const int M = round_up(params.M, 4);
// Shared memory to store Q, K and V. Size is 2*smem_buffer_elts. // 分配共享内存用于存储 Q、K、V(query、key、value)。 // 注意上面的smem_buffer_elts是 (2E + 2M) extern __shared__ float smem_[];
// The index of the shared memory buffer (for double-buffering). // 使用 smem_curr 管理双缓冲区策略,以平滑地在不同迭代间交换共享内存。 int smem_curr = 0;
// 确定处理的序列(bi)和头(hi)。 const int bi = blockIdx.y; const int hi = blockIdx.x;
// 线程的id const int tidx = threadIdx.x;
// 根据线程索引(tidx)和 params 中的 stride 计算 Q、K、的偏移量 // The offset to the position loaded by the thread in Q. int offset_q = bi*params.q_stride_B + hi*params.q_stride_H + tidx; // The offset to the position loaded by the thread in K. int offset_k = bi*params.k_stride_B + hi*params.k_stride_H + tidx;
// Determine the base pointers for Q and K. const float *ptr_q = ¶ms.q[offset_q]; const float *ptr_k = ¶ms.k[offset_k];