专栏名称: AINLP
关注AI、NLP相关技术,关注算法研发职位和课程;回复"文章"获取历史信息;双语聊天机器人"无名";中英翻译请输入:翻译 翻译内容;自动对联,请输入:上联 上联内容;调戏夸夸聊天机器人,请求夸、求赞;查询相似词,请输入: 相似词 词条
目录
相关文章推荐
新浪科技  ·  【#DeepSeek怎么看待概念股大涨#】“ ... ·  11 小时前  
i黑马  ·  第一批回老家的人,已经后悔县城有房了 ·  16 小时前  
药明康德  ·  加入新型“鸡尾酒”,CAR-T疗法再升级! ·  3 天前  
药明康德  ·  全球首款!脐血来源通用型CAR-T疗法IND ... ·  4 天前  
51好读  ›  专栏  ›  AINLP

手写大模型组件之Group Query Attention,从 MHA,MQA 到 GQA

AINLP  · 公众号  ·  · 2025-01-19 20:01

正文

 


  • 备注::本文首发于 https://bruceyuan.com/hands-on-code/hands-on-lora.html 后续有修改会更新于 博客中(实在是博客没人看啊)


  •  

  • • GQA(Group Query Attention)的优点:效果损失小,推理的时候可以加速(来自于kvcache小,内存取数少)。
  • • 仔细阅读 MHA, MQA 和 GQA的区别,就会发现 MHA 和 MQA 都是 GQA 的特殊表达形式
    • • 三者可以用同一套代码,只需要修改【GQA】代码里面的 nums_key_value_head 参数就可
    • • nums_key_value_head 设置等于 1 就是 MQA
    • • nums_key_value_head 设置等于 nums_head 就是 MHA

> 不喜欢看视频的可以直接看视频 里面有详细介绍 MHA MQA GQA 的区别: 

Group Query Attention

备注:以下代码省略了 attention_dropout attention_mask等情况的处理,真实实现过程中需要考虑。

import torch
import torch.nn as nn
import math

# 忽略了 attention_mask, attention_dropout; 
classGroupQueryAttention(nn.Module):
    def__init__(self, hidden_dim, nums_head, nums_key_value_head):
        super().__init__()
        assert hidden_dim % nums_head == 0# 可以整除
        assert nums_head % nums_key_value_head == 0# N 个 query head 为一组

        self.hidden_dim = hidden_dim
        self.nums_head = nums_head
        self.nums_key_value_head = nums_key_value_head
        self.head_dim = hidden_dim // nums_head

        # 初始化 qkv o
        self.q_proj = nn.Linear(hidden_dim, nums_head * self.head_dim)  # out feature_size (nums_head * head_dim)
        # k v out shape (nums_key_value_head * head_dim)
        self.k_proj = nn.Linear(hidden_dim, nums_key_value_head * self.head_dim)
        self.v_proj = nn.Linear(hidden_dim, nums_key_value_head * self.head_dim)

        self.o_proj = nn.Linear(hidden_dim, hidden_dim) # input_size nums_head * head_dim

    defforward(self, X, attention_mask=None):
        # X shape (batch, seq, hidden_dim)
        batch_size, seq, _ = X.size()

        # qkv projection
        q = self.q_proj(X)  # (batch, seq, hidden_dim)
        k = self.k_proj(X)
        v = self.v_proj(X) 

        # attention_weight 目标shape 是 (batch, nums_head, seq, seq)
        q = q.view(batch_size, seq, self.nums_head, self.head_dim)
        k = k.view(batch_size, seq, self.nums_key_value_head, self.head_dim)
        v = v.view(batch_size, seq, self.nums_key_value_head, self.head_dim)

        # 关注: nums_head 和 nums_key_value_head 的关系
        q = q.transpose(12# (b, nums_head, seq, head_dim)
        k = k.transpose(12# (b, nums_key_value_head, seq, head_dim)
        v = v.transpose(12)  # (b, nums_key_value_head, seq, head_dim)

        # k v repeat;
        k = k.repeat_interleave(self.nums_head // self.nums_key_value_head, dim=1)
        v = v.repeat_interleave(self.nums_head // self.nums_key_value_head, dim=1)

        attention_score = (q @ k.transpose(23)) / math.sqrt(self.head_dim)

        attention_weight = torch.softmax(attention_score, dim=-1)
        # (attention_mask 忽略) # 可以看前面的视频

        output = attention_weight @ v  # (b, nums_head, seq, head_dim)

        # output projection 变成 (b, seq, hidden_dim)
        output = output.transpose(12).contiguous()
        final_output = self





请到「今天看啥」查看全文