Note: this article was first published at https://bruceyuan.com/hands-on-code/hands-on-lora.html; any later revisions will be posted on the blog (honestly, hardly anyone reads the blog).
- Advantages of GQA (Group Query Attention): little loss in quality, and faster inference (the KV cache is smaller, so fewer memory reads are needed).
- If you look closely at the differences between MHA, MQA, and GQA, you will find that MHA and MQA are both special cases of GQA.
- All three can share the same code; you only need to change the `nums_key_value_head` parameter in the GQA code (see the sketch right after this list):
  - setting `nums_key_value_head` to 1 gives MQA;
  - setting `nums_key_value_head` equal to `nums_head` gives MHA.
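To make the last point concrete, here is a minimal sketch using the `GroupQueryAttention` class defined later in this post; the concrete `hidden_dim` and `nums_head` values are arbitrary examples:

```python
# Sketch: the same GroupQueryAttention class (defined below) covers all three
# variants; only nums_key_value_head changes. The concrete sizes are arbitrary.
hidden_dim, nums_head = 128, 8

mha = GroupQueryAttention(hidden_dim, nums_head, nums_key_value_head=nums_head)  # MHA: one KV head per query head
gqa = GroupQueryAttention(hidden_dim, nums_head, nums_key_value_head=2)          # GQA: 4 query heads share each KV head
mqa = GroupQueryAttention(hidden_dim, nums_head, nums_key_value_head=1)          # MQA: all query heads share one KV head
```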
> If you would rather watch a video, there is one that explains the differences between MHA, MQA, and GQA in detail: Group Query Attention
Note: the code below omits the handling of attention_dropout, attention_mask, and similar details; a real implementation needs to take these into account.
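For reference, here is a minimal standalone sketch of what that omitted handling typically looks like (a causal mask and dropout applied to the attention scores/weights); the tensor shapes and the dropout probability are assumptions for illustration only, not part of the implementation below.

```python
import torch
import torch.nn as nn

# Minimal sketch of what is omitted below: a causal mask and attention dropout.
# The shapes and the dropout probability are arbitrary, for illustration only.
batch, nums_head, seq = 2, 8, 4
attention_score = torch.randn(batch, nums_head, seq, seq)

# Causal mask: position i may only attend to positions <= i.
causal_mask = torch.tril(torch.ones(seq, seq)).view(1, 1, seq, seq)
attention_score = attention_score.masked_fill(causal_mask == 0, float("-inf"))

attention_weight = torch.softmax(attention_score, dim=-1)
attention_weight = nn.functional.dropout(attention_weight, p=0.1)  # attention_dropout
```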
```python
import torch
import torch.nn as nn
import math


# attention_mask and attention_dropout handling is omitted here
class GroupQueryAttention(nn.Module):
    def __init__(self, hidden_dim, nums_head, nums_key_value_head):
        super().__init__()
        assert hidden_dim % nums_head == 0  # hidden_dim must be divisible by nums_head
        assert nums_head % nums_key_value_head == 0  # every N query heads form one group

        self.hidden_dim = hidden_dim
        self.nums_head = nums_head
        self.nums_key_value_head = nums_key_value_head
        self.head_dim = hidden_dim // nums_head

        # initialize q, k, v and output projections
        self.q_proj = nn.Linear(hidden_dim, nums_head * self.head_dim)  # out feature size: nums_head * head_dim
        # k, v out shape: nums_key_value_head * head_dim
        self.k_proj = nn.Linear(hidden_dim, nums_key_value_head * self.head_dim)
        self.v_proj = nn.Linear(hidden_dim, nums_key_value_head * self.head_dim)
        self.o_proj = nn.Linear(hidden_dim, hidden_dim)  # input size: nums_head * head_dim

    def forward(self, X, attention_mask=None):
        # X shape: (batch, seq, hidden_dim)
        batch_size, seq, _ = X.size()

        # qkv projection
        q = self.q_proj(X)  # (batch, seq, hidden_dim)
        k = self.k_proj(X)
        v = self.v_proj(X)

        # the target shape of attention_weight is (batch, nums_head, seq, seq)
        q = q.view(batch_size, seq, self.nums_head, self.head_dim)
        k = k.view(batch_size, seq, self.nums_key_value_head, self.head_dim)
        v = v.view(batch_size, seq, self.nums_key_value_head, self.head_dim)

        # note the relationship between nums_head and nums_key_value_head
        q = q.transpose(1, 2)  # (b, nums_head, seq, head_dim)
        k = k.transpose(1, 2)  # (b, nums_key_value_head, seq, head_dim)
        v = v.transpose(1, 2)  # (b, nums_key_value_head, seq, head_dim)

        # repeat k and v so that each group of query heads has its own copy
        k = k.repeat_interleave(self.nums_head // self.nums_key_value_head, dim=1)
        v = v.repeat_interleave(self.nums_head // self.nums_key_value_head, dim=1)

        attention_score = (q @ k.transpose(2, 3)) / math.sqrt(self.head_dim)
        attention_weight = torch.softmax(attention_score, dim=-1)
        # (attention_mask handling omitted; see the earlier videos)

        output = attention_weight @ v  # (b, nums_head, seq, head_dim)

        # output projection back to (b, seq, hidden_dim)
        output = output.transpose(1, 2).contiguous()
        final_output = self.o_proj(output.view(batch_size, seq, -1))
        return final_output
```
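A minimal usage sketch for the class above (the tensor sizes are arbitrary):

```python
# Minimal usage sketch; the concrete sizes are arbitrary.
x = torch.rand(3, 2, 128)  # (batch, seq, hidden_dim)
net = GroupQueryAttention(hidden_dim=128, nums_head=8, nums_key_value_head=2)
print(net(x).shape)  # torch.Size([3, 2, 128])
```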