def kd_ce_loss(logits_S, logits_T, temperature=1):
    '''
    Calculate the cross entropy between the temperature-scaled distributions
    of `logits_S` (student) and `logits_T` (teacher).

    Both logits are divided by `temperature`; the teacher logits are turned
    into probabilities with softmax and the student logits into
    log-probabilities with log-softmax. The cross entropy is summed over the
    last (label) dimension and averaged over every remaining position.

    :param logits_S: Tensor of shape (batch_size, length, num_labels) or (batch_size, num_labels)
    :param logits_T: Tensor of shape (batch_size, length, num_labels) or (batch_size, num_labels)
    :param temperature: A float or a tensor of shape (batch_size, length) or (batch_size,)
    :return: scalar tensor holding the mean cross entropy
    '''
    if isinstance(temperature, torch.Tensor) and temperature.dim() > 0:
        # A non-scalar temperature gets a trailing axis so that it broadcasts
        # across the label dimension of the logits.
        temperature = temperature.unsqueeze(-1)

    beta_logits_T = logits_T / temperature
    beta_logits_S = logits_S / temperature

    p_T = F.softmax(beta_logits_T, dim=-1)
    loss = -(p_T * F.log_softmax(beta_logits_S, dim=-1)).sum(dim=-1).mean()
    return loss
def hid_mse_loss(state_S, state_T, mask=None):
    '''
    * Calculates the mse loss between `state_S` and `state_T`, which are the hidden states of the models.
    * If `mask` is given, positions where ``mask==0`` do not contribute to the loss, and the
      summed squared error is divided by ``mask.sum() * hidden_size`` instead of the total element count.
    * `state_S` and `state_T` are compared element-wise. If the hidden sizes of student and teacher
      are different, the 'proj' option is required in `intermediate_matches` to match the dimensions.

    :param torch.Tensor state_S: tensor of shape (*batch_size*, *length*, *hidden_size*)
    :param torch.Tensor state_T: tensor of shape (*batch_size*, *length*, *hidden_size*)
    :param torch.Tensor mask:    tensor of shape (*batch_size*, *length*)
    :return: scalar tensor holding the (masked) mean squared error
    '''
    if mask is None:
        loss = F.mse_loss(state_S, state_T)
    else:
        # Cast the mask to the dtype/device of the hidden states so it can be multiplied in.
        mask = mask.to(state_S)
        # Number of elements that actually contribute: valid positions times hidden size.
        valid_count = mask.sum() * state_S.size(-1)
        loss = (F.mse_loss(state_S, state_T, reduction='none') * mask.unsqueeze(-1)).sum() / valid_count
    return loss
蒸馏 attention 矩阵时同样需要考虑 mask，但要注意这里需要处理的维度是 N×N（即 length × length），因此 mask 要同时作用在行和列两个维度上：
def att_mse_loss(attention_S, attention_T, mask=None):
    '''
    * Calculates the mse loss between `attention_S` and `attention_T`.
    * If `mask` is None, entries that are ``<= -1e-3`` are replaced by zero in both inputs
      before the plain mean squared error is taken.
    * If `mask` is given, an entry (i, j) contributes only when both position i and position j
      have ``mask != 0``; the summed squared error is divided by the number of such entries
      (the squared count of valid positions, summed over batch and heads).

    :param torch.Tensor attention_S: tensor of shape (*batch_size*, *num_heads*, *length*, *length*)
    :param torch.Tensor attention_T: tensor of shape (*batch_size*, *num_heads*, *length*, *length*)
    :param torch.Tensor mask:        tensor of shape (*batch_size*, *length*)
    :return: scalar tensor holding the (masked) mean squared error
    '''
    if mask is None:
        # NOTE(review): values <= -1e-3 are presumably large negative scores written by
        # attention masking upstream; they are zeroed so they do not dominate the loss - confirm.
        attention_S_select = torch.where(attention_S <= -1e-3, torch.zeros_like(attention_S), attention_S)
        attention_T_select = torch.where(attention_T <= -1e-3, torch.zeros_like(attention_T), attention_T)
        loss = F.mse_loss(attention_S_select, attention_T_select)
    else:
        # Broadcast the per-token mask over the head dimension.
        mask = mask.to(attention_S).unsqueeze(1).expand(-1, attention_S.size(1), -1)  # (bs, num_of_heads, len)
        # Each (batch, head) pair has (valid length)^2 entries that count.
        valid_count = torch.pow(mask.sum(dim=2), 2).sum()
        # mask.unsqueeze(-1) masks rows, mask.unsqueeze(2) masks columns.
        loss = (F.mse_loss(attention_S, attention_T, reduction='none') * mask.unsqueeze(-1) * mask.unsqueeze(2)).sum() / valid_count
    return loss
def cos_loss(state_S, state_T, mask=None):
    '''
    * Computes the cosine similarity loss between the inputs. This is the loss used in DistilBERT.
    * Every position is treated as one vector of size *hidden_size*; the loss is
      ``F.cosine_embedding_loss`` with an all-ones target, averaged over the vectors.
    * If `mask` is given, only the positions selected by the mask are kept.
    * If the hidden sizes of student and teacher are different, the 'proj' option is required in
      `intermediate_matches` to match the dimensions.

    :param torch.Tensor state_S: tensor of shape (*batch_size*, *length*, *hidden_size*)
    :param torch.Tensor state_T: tensor of shape (*batch_size*, *length*, *hidden_size*)
    :param torch.Tensor mask:    tensor of shape (*batch_size*, *length*)
    :return: scalar tensor holding the mean cosine embedding loss
    '''
    if mask is None:
        # reshape (rather than view) so that non-contiguous inputs, e.g. transposed
        # hidden states, are flattened instead of raising.
        state_S = state_S.reshape(-1, state_S.size(-1))
        state_T = state_T.reshape(-1, state_T.size(-1))
    else:
        # NOTE(review): `mask_dtype` is not defined in this block; it is expected to be a
        # module-level dtype accepted by torch.masked_select - confirm it is in scope.
        mask = mask.to(state_S).unsqueeze(-1).expand_as(state_S).to(mask_dtype)  # (bs, len, dim)
        state_S = torch.masked_select(state_S, mask).view(-1, mask.size(-1))  # (bs * select, dim)
        state_T = torch.masked_select(state_T, mask).view(-1, mask.size(-1))  # (bs * select, dim)

    # A target of 1 for every pair asks cosine_embedding_loss to pull the vectors together.
    target = state_S.new_ones(state_S.size(0))
    loss = F.cosine_embedding_loss(state_S, state_T, target, reduction='mean')
    return loss