Transformer decoder block 在计算上可以看做六个操作的总和:pre-proj,attn,post-proj,ffn_ln1,ffn_ln2,others(比如说 layer normalization,activation functions,residual connection..)
Transformer 的输出可以视为一个 tensor X of shape [B, L, H]。其中 B 是 batch size,L 是 input tokens length,H 是模型的 embedding size。
Prefill 的第一步做 pre-proj。从数学上,就是简单的线性运算,分别用三个大小为 [H, H] 的矩阵 W^Q,W^K,W^V 和 X 做乘积。从计算上,就是输入 X 和一个大小为 [H, 3H] 的矩阵相乘。
attn 操作在计算上的输入是 Q, K, V,而输出 Y 仍旧是大小为 [B, L, H] 的 tensor。post-proj 采用大小为 [H, H] 的 W_0 矩阵和 Y 相乘,输出结果 Z 的大小仍旧为 [B, L, H]。
ffn_ln1 和 ffn_ln2 在计算上的输入是 Z。ffn_ln1 中,Z 和大小为 [H, H’] 的矩阵相乘,得到大小为 [B, L, H’] 的 tensor,接着和大小为 [H’, H] 的矩阵相乘,再投影回去得到一个大小仍旧为 [B, L, H] 的 tensor。上式中,H’ 为模型的 second hidden dimension。
Decode 阶段和 Prefill 阶段的操作完全一致,不过每次只会生成一个 token 并且输入给下个阶段。采用 KV Cache 后,实质上的输入是上次生成的那一个 token,输入的 tensor 大小是 [B, 1, H](input tokens 的长度为 1)。
每个 token 的 KV Cache 大小均为 [1, H]。