Args:
    query (Tensor): (batch_size, q_len, d_model)
    key_value (Tensor, optional): (batch_size, k_len/v_len, d_model) key and value are the same.
    mask (Tensor, optional): mask for padding or decoder. Defaults to None.
    past_key_value (Tuple[Tensor], optional): cached past key and value states. Defaults to None.
    use_cache (bool, optional): whether to use the KV cache during inference. Defaults to False.
    keep_attentions (bool): whether to keep attention weights or not. Defaults to False.

Returns:
    output (Tensor): (batch_size, q_len, d_model) attention output
    present_key_value (Tuple[Tensor], optional): cached present key and value states
"""
if past_key_value is not None:
    assert self.is_decoder is True, "Encoder cannot cache past key value states"
is_self_attention = key_value is None
_query = query
query = self._transform_and_split(self.q, query)
if is_self_attention:
    # self-attention: apply the Q/K/V projections first, then split into multiple heads
    key = self._transform_and_split(self.k, _query, is_key=True)
    value = self._transform_and_split(self.v, _query)
    # concatenate the newest key and value with the cached ones, depending on the case
    key, value = self._concat_key_value(key, value, past_key_value)
elif past_key_value is None:
    # cross-attention: key_value is the encoder memory
    key = self._transform_and_split(self.k, key_value, is_key=True)
    value = self._transform_and_split(self.v, key_value)
else:
    # cross-attention with a cache (is_self_attention == False and past_key_value is not None):
    # key_value is the encoder memory and its key/value were already cached, so there is no
    # need to compute them again; the memory does not change during inference.
    key, value = past_key_value
if self.is_decoder and use_cache:
    # cache the newest key and value
    present_key_value = (key, value)
else:
    present_key_value = None
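For reference, the two helpers used above could look roughly like the sketch below. This is an assumption, not the exact implementation: it presumes the module defines self.num_heads and self.d_head, and that keys are kept in a transposed (batch_size, num_heads, d_head, seq_len) layout (which is one way to read the is_key flag), while values stay in (batch_size, num_heads, seq_len, d_head).

# minimal sketch of the assumed helpers; requires `import torch` at module level
def _transform_and_split(self, linear, x, is_key=False):
    # project to d_model, then split the last dimension into (num_heads, d_head)
    batch_size = x.size(0)
    x = linear(x)  # (batch_size, seq_len, d_model)
    x = x.view(batch_size, -1, self.num_heads, self.d_head)
    # move the head dimension forward: (batch_size, num_heads, seq_len, d_head)
    x = x.permute(0, 2, 1, 3)
    if is_key:
        # assumption: keys are stored transposed, ready for Q @ K
        x = x.transpose(-2, -1)  # (batch_size, num_heads, d_head, seq_len)
    return x

def _concat_key_value(self, key, value, past_key_value):
    # append the newest key/value to the cached ones along the time dimension
    if past_key_value is not None:
        past_key, past_value = past_key_value
        key = torch.cat([past_key, key], dim=-1)    # seq dim is last for transposed keys
        value = torch.cat([past_value, value], dim=2)  # (batch, heads, seq_len, d_head)
    return key, value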
Args:
    tgt (Tensor): (batch_size, tgt_seq_len, d_model) the (target) sequence to the decoder block.
    memory (Tensor): (batch_size, src_seq_len, d_model) the sequence from the last layer of the encoder.
    tgt_mask (Tensor, optional): (batch_size, 1, tgt_seq_len, tgt_seq_len) the mask for the tgt sequence.
    memory_mask (Tensor, optional): (batch_size, 1, 1, src_seq_len) the mask for the memory sequence.
    past_key_value (Tuple[Tensor], optional): the cached key and value states. Defaults to None.
    use_cache (bool, optional): whether to use the KV cache during inference. Defaults to False.
    keep_attentions (bool): whether to keep attention weights or not. Defaults to False.

Returns:
    tgt (Tensor): (batch_size, tgt_seq_len, d_model) output of the decoder block
"""
if past_key_value is not None:
    # past_key_value is a 4-tuple: the first two elements are the key and value
    # of the self-attention layer, the last two are the key and value of the
    # cross-attention layer
    self_attn_past_key_value = past_key_value[:2]
    cross_attn_past_key_value = past_key_value[2:]
else:
    self_attn_past_key_value = None
    cross_attn_past_key_value = None
x = cross_attn_outputs[0]

if present_key_value_state is not None:
    # append the cross-attention key and value states to the present key value states,
    # giving the 4 elements of the tuple: self-attention key/value plus cross-attention key/value
    present_key_value_state = present_key_value_state + cross_attn_outputs[1]

x = self._ff_sub_layer(x)

# do not forget to return the cached states together with the output
return x, present_key_value_state
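To make the data flow concrete, the decoder that owns these blocks would typically collect the per-layer 4-tuples roughly as sketched below. Variable names here are illustrative and the call signature is an assumption based on the snippets above, not the repository's verbatim code.

# illustrative sketch: collecting per-layer caches inside the decoder's forward
present_key_values = () if use_cache else None
for layer, past_key_value in zip(self.layers, past_key_values):
    x, present_key_value_state = layer(
        x, memory, tgt_mask, memory_mask,
        past_key_value=past_key_value,
        use_cache=use_cache,
        keep_attentions=keep_attentions,
    )
    if use_cache:
        # each element is a 4-tuple: (self-attn key, self-attn value, cross-attn key, cross-attn value)
        present_key_values = present_key_values + (present_key_value_state,)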
Args:
    tgt (Tensor): (batch_size, tgt_seq_len) the sequence to the decoder.
    memory (Tensor): (batch_size, src_seq_len, d_model) the sequence from the last layer of the encoder.
    tgt_mask (Tensor, optional): (batch_size, 1, 1, tgt_seq_len) the mask for the target sequence. Defaults to None.
    memory_mask (Tensor, optional): (batch_size, 1, 1, src_seq_len) the mask for the memory sequence. Defaults to None.
    past_key_values (Tuple[Tensor], optional): the cached key and value states. Defaults to None.
    use_cache (bool, optional): whether to use the KV cache during inference. Defaults to False.
    keep_attentions (bool, optional): whether to keep attention weights or not. Defaults to False.

Returns:
    Tensor: output (batch_size, tgt_seq_len, tgt_vocab_size)
"""
if past_key_values is None:
    # no cache is used, so pass None to every layer
    past_key_values = [None] * len(self.decoder.layers)
    position_ids = None
else:
    # when using the cache we only care about the current position,
    # so pass the ID corresponding to that position
    position_ids = past_key_values[0][1].size(2)
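Under the cache layout assumed earlier, past_key_values[0][1] is the cached self-attention value tensor of the first layer with shape (batch_size, num_heads, seq_len, d_head), so its size(2) equals the number of tokens already decoded, which is exactly the position ID of the token being decoded now. A quick standalone check (shapes are illustrative):

import torch

batch_size, num_heads, cached_len, d_head = 2, 8, 5, 64
# pretend 5 tokens have already been decoded and cached
cached_value = torch.zeros(batch_size, num_heads, cached_len, d_head)
past_key_values = ((None, cached_value),)  # only the value is needed for this check

position_ids = past_key_values[0][1].size(2)
print(position_ids)  # 5 -> the next token sits at position 5 (0-based)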
# finished sentences should have their next token be a pad token
next_tokens = next_tokens * unfinished_sequences + self.pad_idx * (
    1 - unfinished_sequences
)
# mark a sequence as finished if eos_idx was found
unfinished_sequences = unfinished_sequences.mul(
    next_tokens.tile(eos_idx_tensor.shape[0], 1)
    .ne(eos_idx_tensor.unsqueeze(1))
    .prod(dim=0)
)
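These two tensor expressions can be checked in isolation. In the toy run below (values are hypothetical, with pad_idx and eos_idx_tensor standing in for the model's attributes), sequence 0 emits the EOS token, so its flag drops to 0; on later steps the first expression would then overwrite anything it produces with the pad index.

import torch

pad_idx = 0
eos_idx_tensor = torch.tensor([3])            # assume token id 3 is <eos>
unfinished_sequences = torch.ones(2).long()   # both sequences still running

next_tokens = torch.tensor([3, 7])            # sequence 0 just produced <eos>

# finished sequences keep emitting pad tokens (no effect yet, both flags are 1)
next_tokens = next_tokens * unfinished_sequences + pad_idx * (1 - unfinished_sequences)

# the flag goes to 0 for any sequence whose token matches one of the eos ids
unfinished_sequences = unfinished_sequences.mul(
    next_tokens.tile(eos_idx_tensor.shape[0], 1)
    .ne(eos_idx_tensor.unsqueeze(1))
    .prod(dim=0)
)
print(next_tokens)           # tensor([3, 7]) -> the eos token itself is still appended
print(unfinished_sequences)  # tensor([0, 1])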
if use_cache:
    # with the cache we only need to feed the last generated tokens
    input_ids = next_tokens[:, None]
else:
    input_ids = decoder_inputs
# all sentences have produced eos_idx
if unfinished_sequences.max() == 0:
    finished = True

if decoder_inputs.shape[-1] >= max_gen_len:
    finished = True
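Putting the stopping conditions together, the surrounding greedy-decoding loop looks roughly like the sketch below. decode_step is a hypothetical helper that runs the decoder stack plus the output projection and returns the logits together with the updated cache; its name and signature are assumptions, not the repository's actual API.

# illustrative sketch of the greedy generation loop
finished = False
past_key_values = None
input_ids = decoder_inputs  # starts as the <bos> tokens, shape (batch_size, 1)

while not finished:
    # hypothetical helper: logits (batch, seq_len, vocab) plus updated KV cache
    logits, past_key_values = decode_step(
        input_ids, memory, memory_mask, past_key_values, use_cache
    )
    next_tokens = logits[:, -1, :].argmax(dim=-1)  # greedy choice at the last position

    # finished sequences keep emitting pad tokens
    next_tokens = next_tokens * unfinished_sequences + pad_idx * (1 - unfinished_sequences)
    # drop the flag for sequences that just produced <eos>
    unfinished_sequences = unfinished_sequences.mul(
        next_tokens.tile(eos_idx_tensor.shape[0], 1)
        .ne(eos_idx_tensor.unsqueeze(1))
        .prod(dim=0)
    )

    decoder_inputs = torch.cat([decoder_inputs, next_tokens[:, None]], dim=-1)

    if use_cache:
        input_ids = next_tokens[:, None]  # only the newest token is fed
    else:
        input_ids = decoder_inputs

    # stop when every sequence has produced <eos> or the length limit is reached
    if unfinished_sequences.max() == 0 or decoder_inputs.shape[-1] >= max_gen_len:
        finished = True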