专栏名称: 学姐带你玩AI
这里有人工智能前沿信息、算法技术交流、机器学习/深度学习经验分享、AI大赛解析、大厂大咖算法面试分享、人工智能论文技巧、AI环境工具库教程等……学姐带你玩转AI!
目录
相关文章推荐
跟宇宙结婚  ·  “跟宇宙结婚”音频节目总目录 ·  16 小时前  
跟宇宙结婚  ·  日常絮叨:上饿了么搜【跟宇宙结婚】领红包哟 ·  16 小时前  
跟宇宙结婚  ·  节目更新:跟宇宙结婚悄悄话 vol.245 ... ·  3 天前  
51好读  ›  专栏  ›  学姐带你玩AI

Transformer从菜鸟到新手(六)

学姐带你玩AI  · 公众号  ·  · 2024-06-05 11:17

正文

来源:投稿  作者:175
编辑:学姐

引言

上篇文章 介绍了如何在多GPU上分布式训练,本文介绍大模型常用的一种推理加速技术——KV缓存。

KV Cache

KV缓存(KV Cache)是在大模型推理中常用的一种技巧。我们知道在推理阶段,Transformer也只能像RNN一样逐个进行预测,也称为自回归。KV cahce是用在注意力阶段缓存key和value状态,具体的我们可以看图示:

上图(灰色区域表示掩码)是在没有KV缓存的情况下,在每一步生成时,我们都在重新计算相同的之前的Token注意力,而实际上我们只想计算新Token的注意力。

比如在最后一步,即第4步时,我们再次计算了之前步骤已经算好的Token注意力Attention1到Attention3,实际上这是没有必要的。

如果我们可以缓存之前计算好的Key和Value,那么就可以不需要这么多重复计算,每次只关注最新Token的注意力:

上图(蓝色表示缓存起来的Key或Value)在有KV缓存的情况下,每次只需要传入新的Query,然后计算新的Key和Value,并且与之前的Key和Value缓存矩阵拼接在一起,最后计算出最新Token的注意力。这就是KV缓存的主要思想。可以看到这里不再需要掩码。

这里描述的是自注意力中的KV缓存,如果是交叉注意力那么更简单,因为编码器生成的memory不会改变,因此可以直接缓存memory计算出来的Key和Value矩阵,而不需要拼接。

为了让我们的Transformer能支持KV缓存技术,我们需要进行一些改造。首先对 MultiHeadAttention 模块动刀,主要修改它的 forward 方法:

 def forward(
        self,
        query: Tensor,
        key_value: Tensor = None,
        mask: Tensor = None,
        past_key_value: Tuple[Tensor] = None,
        use_cache: bool = False,
        keep_attentions: bool = False,
    ) -> Tuple[Tensor, Tensor]:
        """

        Args:
            query (Tensor): (batch_size, q_len, d_model)
            key_value (Tensor, optional): (batch_size, k_len/v_len, d_model) key and value are same.
            mask (Tensor, optional): mask for padding or decoder. Defaults to None.
            past_key_value (Tuple[Tensor], optional): cached past key and value states. Defaults to None.
            use_cache (bool, optional): whether to use kv cache during inference. Defaults to False.
            keep_attentions (bool): whether to keep attention weigths or not. Defaults to False.

        Returns:
            output (Tensor): (batch_size, q_len, d_model) attention output
            present_key_value (Tuple[Tensor], optional): Cached present key and value states
        "
""

        if past_key_value is not None:
            assert self.is_decoder is True, "Encoder cannot cache past key value states"

        is_self_attention = key_value is None

        _query = query

        query = self._transform_and_split(self.q, query)

        if is_self_attention:
            # the 'self' attention
            key = self._transform_and_split(self.k, _query, is_key=True) # 即先进行Q/K/V转换,再拆分成多头
            value = self._transform_and_split(self.v, _query)
            key, value = self._concat_key_value(key, value, past_key_value) # 分情况拼接最新的key和value
        elif past_key_value is None:
            # the cross attention, key_value is memory
            key = self._transform_and_split(self.k, key_value, is_key=True)
            value = self._transform_and_split(self.v, key_value)
        else:
            # if is_self_attention == False and past_key_value is not None
            # key_value is memory and use cache(past_key_value not None) we do not need to calculate the key and value again because it was cached.
            # since memory will not change during inference.
            key, value = past_key_value

        if self.is_decoder and use_cache:
            # cache newest key and value
            present_key_value = (key, value)
        else:
            present_key_value = None

        attn_output = self.attenion(query, key, value, mask, keep_attentions)

        # Concat
        concat_output = self.merge_heads(attn_output)
        # the final liear
        # output (batch_size, q_len, d_model)
        output = self.concat(concat_output)

        return output, present_key_value

其参数发生了一些变换,由原来的 query,key,value 变成了 query,key_value

首先,这里将 key value 合并了起来,因为如果是自注意力 query=key=value ,而如果是交叉注意力 key=value=memory ,然后我们可以通过判断 key_value 是否为空来分辨本次计算的是自注意力还是交叉注意力;

其次,增加了两个参数 past_key_value use_cache use_cache 表示是否使用kv缓存,而 past_key_value 代表缓存的kv,注意缓存的k和v是不同的,因为它们经过了Key和Value矩阵映射。

然后我们深入方法内部,注意只有在推理阶段的Decoder中才能使用kv cache。

这里要分两种情况:自注意力和交叉注意力。

如果是自注意力直接使用传入的 query 就可以计算映射后的query,key,value,见代码行32到37。当使用缓存时,传入的 query 的长度一定是1,因为我们只需要为最新的 query 去计算注意力分数,算出一个预测的token。但还是需要当前 query 对应K和V矩阵映射后的key和value,将它们与历史(缓存)的拼接起来去计算新的token。

如果是交叉注意力,即 Decoder 中第二个注意力模块,其query来自decoder,而key和value(即memory)来自encoder。显然这个memory在整个推理阶段都是一样的,因此只需要计算一次,然后存入 past_key_value 缓存,后续就不再需要重复计算,对应上面的代码行47。

只有在使用缓存且为Decoder的时候才会缓存最新的key和value。

最后和之前一样计算注意力得分即可。

接下来修改DecoderBlock中的forward代码:

 def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Tensor = None,
        memory_mask: Tensor = None,
        past_key_value: Tuple[Tensor] = None,
        use_cache: bool = True,
        keep_attentions: bool = False,
    ) -> Tuple[Tensor, Tensor]:
        """

        Args:
            tgt (Tensor):   (batch_size, tgt_seq_len, d_model) the (target) sequence to the decoder block.
            memory (Tensor):  (batch_size, src_seq_len, d_model) the sequence from the last layer of the encoder.
            tgt_mask (Tensor, optional):  (batch_size, 1, tgt_seq_len, tgt_seq_len) the mask for the tgt sequence.
            memory_mask (Tensor, optional): (batch_size, 1, 1, src_seq_len) the mask for the memory sequence.
            past_key_values (Tuple[Tensor], optional): the cached key and value states. Defaults to None.
            use_cache (bool, optional): whether use kv cache during inference or not. Defaults to False.
            keep_attentions (bool): whether keep attention weigths or not. Defaults to False.


        Returns:
            tgt (Tensor): (batch_size, tgt_seq_len, d_model) output of decoder block
        "
""
        if past_key_value is not None:
            # first two elements in the past_key_value tuple are self-attention
            # past_key_value是一个元组,其中前2个元素为自注意力层的key和value
            # 后2个元素为交叉注意力层的key和value
            self_attn_past_key_value = past_key_value[:2]
            cross_attn_past_key_value = past_key_value[2:]
        else:
            self_attn_past_key_value = None
            cross_attn_past_key_value = None

        x = tgt
        # 自注意力
        self_attn_outputs = self._sa_sub_layer(
            x,
            tgt_mask,
            self_attn_past_key_value,
            use_cache,
            keep_attentions,
        )
        # self attention output and present key value state
        # x和之前的输出一样,多了一个保存key和value的present_key_value_state
        x, present_key_value_state = self_attn_outputs
    # 交叉注意力
        cross_attn_outputs = self._ca_sub_layer(
            x,
            memory,
            memory_mask,
            cross_attn_past_key_value,
            use_cache,
            keep_attentions,
        )

        x = cross_attn_outputs[0]
        if present_key_value_state is not None:
            # append the cross-attention key and value states to present key value states   
            # 拼接注意力和交叉注意力中的key和value,得到元组的4个元素
            present_key_value_state = present_key_value_state + cross_attn_outputs[1]

        x = self._ff_sub_layer(x)
    # 别忘了返回
        return x, present_key_value_state

其中调用了两个子层对应的方法如下:

def _sa_sub_layer(
    self,
    x: Tensor,
    attn_mask: Tensor,
    past_key_value: Tensor,
    use_cache: bool,
    keep_attentions: bool,
) -> Tensor:
    residual = x
    x, present_key_value = self.masked_attention(
        query=self.norm1(x),
        past_key_value=past_key_value,
        use_cache=use_cache,
        mask=attn_mask,
        keep_attentions=keep_attentions,
    )
    x = self.dropout1(x) + residual
    return x, present_key_value

# cross attention sub layer
def _ca_sub_layer(
    self,
    x: Tensor,
    mem: Tensor,
    attn_mask: Tensor,
    past_key_value: Tensor,
    use_cache: bool,
    keep_attentions: bool,
) -> Tensor:
    residual = x
    x, present_key_value = self.cross_attention(
        query=self.norm2(x),
        key_value=mem,
        mask=attn_mask,
        past_key_value=past_key_value,
        use_cache=use_cache,
        keep_attentions=keep_attentions,
    )
    x = self.dropout2(x) + residual
    return x, present_key_value

这里改成了默认Pre-LN的形式,即先计算层归一化,最后再进行残差连接。

还有一个非常重要的修改是 PositionalEncoding

def forward(self, x: Tensor, position_ids: Union[int, list[int]] = None) -> Tensor:
    """

    Args:
        x (Tensor): (batch_size, seq_len, d_model) embeddings
        position_ids (Union[int, list[int]]): singe position id or list

    Returns:
        Tensor: (batch_size, seq_len, d_model)
    "
""
    if position_ids is None:
        position_ids = range(x.size(1))
    return self.dropout(x + self.pe[:, position_ids, :])

增加了一个参数表示位置id,我们知道如果使用缓存传入的 seq_len 恒等于1,但实际上它对应的位置ID是不停增加的,若不修改此处,默认通过 range(x.size(1)) 永远只能获取索引等于0时的位置编码,导致表现大幅下降。因此我们要传入当前的位置。

由于缓存只对Decoder生效,因此我们可以直接修改Transformer模块的decode方法:

def decode(
    self,
    tgt: Tensor,
    memory: Tensor,
    tgt_mask: Tensor = None,
    memory_mask: Tensor = None,
    past_key_values: Tuple[Tensor] = None,
    use_cache: bool = False,
    keep_attentions: bool = False,
) -> Tensor:
    """

    Args:
        tgt (Tensor):  (batch_size, tgt_seq_len) the sequence to the decoder.
        memory (Tensor): (batch_size, src_seq_len, d_model) the  sequence from the last layer of the encoder.
        tgt_mask (Tensor, optional): (batch_size, 1, 1, tgt_seq_len) the mask for the target sequence. Defaults to None.
        memory_mask (Tensor, optional): (batch_size, 1, 1, src_seq_len) the mask for the memory sequence. Defaults to None.
        past_key_values (Tuple[Tensor], optional): the cached key and value states. Defaults to None.
        use_cache (bool, optional): whether use kv cache during inference or not. Defaults to False.
        keep_attentions (bool, optional): whether keep attention weigths or not. Defaults to False.

    Returns:
        Tensor: output (batch_size, tgt_seq_len, tgt_vocab_size)
    "
""
    if past_key_values is None:
        past_key_values = [None] * len(self.decoder.layers)
        # 未使用缓存则传None
        position_ids = None
    else:
        # when use_cache we only care about the current position
        # 否则传入当前位置对应的ID
        position_ids = past_key_values[0][1].size(2)

    tgt_embed = self.dec_pos(self.tgt_embedding(tgt), position_ids)
    # logits (batch_size, tgt_seq_len, d_model)
    logits, past_key_values = self.decoder(
        tgt_embed,
        memory,
        tgt_mask,
        memory_mask,
        past_key_values,
        use_cache,
        keep_attentions,
    )

    return logits, past_key_values

代码增加了注释,大概意思是如果使用缓存,那么我们需要知道缓存的key或value对应的长度。而刚好seq_len恒等于1,因此不需要增加这个 seq_len past_key_values[0][1].size(2) 的值刚好就是我们想要的位置ID。

最后对贪心解码的实现进行一些小修改:

def _greedy_search(
    self,
    src: Tensor,
    src_mask: Tensor,
    max_gen_len: int,
    use_cache: bool,
    keep_attentions: bool,
):
    memory = self.transformer.encode(src, src_mask)

    batch_size = src.shape[0]

    device = src.device

    # keep track of which sequences are already finished
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)

    decoder_inputs = torch.LongTensor(batch_size, 1).fill_(self.bos_idx).to(device)

    input_ids = decoder_inputs

    eos_idx_tensor = torch.tensor([self.eos_idx]).to(device)

    finished = False

    past_key_values = None

    tgt_mask = None # 使用缓存的情况下可以传None,因为此时query可以看到所有的key。

    while True:
        if not use_cache:
            tgt_mask = self.generate_subsequent_mask(decoder_inputs.size(1), device)

        outputs = self.transformer.decode(
            input_ids,
            memory,
            tgt_mask=tgt_mask,
            memory_mask=src_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            keep_attentions=keep_attentions,
        )

        logits = self.lm_head(outputs[0])

        past_key_values = outputs[1]

        next_tokens = torch.argmax(logits[:, -1, :], dim=-1)

        # finished sentences should have their next token be a pad token
        next_tokens = next_tokens * unfinished_sequences + self.pad_idx * (
            1 - unfinished_sequences
        )

        decoder_inputs = torch.cat([decoder_inputs, next_tokens[:, None]], dim=-1)

        # set sentence to finished if eos_idx was found
        unfinished_sequences = unfinished_sequences.mul(
            next_tokens.tile(eos_idx_tensor.shape[0], 1)
            .ne(eos_idx_tensor.unsqueeze(1))
            .prod(dim=0)
        )

        if use_cache:
            # only need the last tokens
            input_ids = next_tokens[:, None]
        else:
            input_ids = decoder_inputs

        # all sentences have eos_idx
        if unfinished_sequences.max() == 0:
            finished = True

        if decoder_inputs.shape[-1] >= max_gen_len:
            finished = True

        if finished:
            break

    return decoder_inputs

在使用缓存的时候 input_ids = next_tokens[:, None] ,这样保证每次只传入最新预测的Token。

最后在测试集上进行推理来验证下加了kv cache速度提升了多少:

$ python train.py 
source tokenizer size: 32000
target tokenizer size: 32000
Loads cached dataframes.
The model has 93255680 trainable parameters
begin train with arguments: {'d_model': 512, 'n_heads': 8, 'num_encoder_layers': 6, 'num_decoder_layers': 6, 'd_ff': 2048, 'dropout': 0.1, 'max_positions': 5000, 'source_vocab_size': 32000, 'target_vocab_size': 32000, 'attention_bias': False, 'pad_idx': 0, 'dataset_path''nlp-in-action/transformers/transformer/data/wmt''src_tokenizer_file''nlp-in-action/transformers/transformer/model_storage/source.model''tgt_tokenizer_path''nlp-in-action/transformers/transformer/model_storage/target.model''model_save_path''nlp-in-action/transformers/transformer/model_storage/best_transformer.pt''dataframe_file''dataframe.{}.pkl''use_dataframe_cache': True, 'cuda': True, 'num_epochs': 40, 'batch_size': 32, 'gradient_accumulation_steps': 1, 'grad_clipping': 0, 'betas': (0.9, 0.98), 'eps': 1e-09, 'label_smoothing': 0, 'warmup_steps': 4000, 'warmup_factor': 0.5, 'only_test': True, 'max_gen_len': 60, 'use_wandb': False, 'patient': 5, 'calc_bleu_during_train': True, 'use_kv_cache': False}
total train steps: 221200
  0%|                                                                                                                                                                        | 0/1580 [00:00, ?it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1580/1580 [17:25<00:00,  1.51it/s]
TEST loss=0.0021 bleu score: 26.74


$ python train.py
source tokenizer size: 32000
target tokenizer size: 32000
Loads cached dataframes.
The model has 93255680 trainable parameters
begin train with arguments: {'d_model': 512, 'n_heads': 8, 'num_encoder_layers': 6, 'num_decoder_layers': 6, 'd_ff': 2048, 'dropout': 0.1, 'max_positions': 5000, 'source_vocab_size': 32000, 'target_vocab_size': 32000, 'attention_bias': False, 'pad_idx': 0, 'dataset_path''transformers/transformer/data/wmt''src_tokenizer_file''transformers/transformer/model_storage/source.model''tgt_tokenizer_path''transformers/transformer/model_storage/target.model''model_save_path''transformers/transformer/model_storage/best_transformer.pt''dataframe_file''dataframe.{}.pkl''use_dataframe_cache': True, 'cuda': True, 'num_epochs': 40, 'batch_size': 32, 'gradient_accumulation_steps': 1, 'grad_clipping': 0, 'betas': (0.9, 0.98), 'eps': 1e-09, 'label_smoothing': 0, 'warmup_steps': 4000, 'warmup_factor': 0.5, 'only_test': True, 'max_gen_len': 60, 'use_wandb': False, 'patient': 5, 'calc_bleu_during_train': True, 'use_kv_cache': True}
total train steps: 221200
  0%|                                                                                                                                                                        | 0/1580 [00:00, ?it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1580/1580 [13:37<00:00,  1.93it/s]
TEST loss=0.0021 bleu score: 26.74

这里加载之前训练效果最好的模型,可以看到计算出来的BLEU 分数都为26.74,使用kv cache耗时(单GPU推理)由17:25降到了13:37,快了接近4分钟。

kv cache实际上是一种空间换时间的技术,那么它会占多大的空间呢?

从上面代码可以看到,我们为每个Token都保存了4个向量,2个k和2个v,那么保存的字节数为:

第一个4表示有4个向量;第二个4表示假设在float-32下需要4个字节;为每层都保存kv cahce;每个向量的大小为

在base设定下(层数=6,d_model=512)批大小等于1,一个Token需要48kb的显存,假设最终生成512个长度的序列时,那么需要24M的显存。看起来不大,但对于大模型的参数量来说,显存占用就显著上升了。

我们这次结合多GPU和KV缓存进行训练:







请到「今天看啥」查看全文