专栏名称: 学姐带你玩AI
这里有人工智能前沿信息、算法技术交流、机器学习/深度学习经验分享、AI大赛解析、大厂大咖算法面试分享、人工智能论文技巧、AI环境工具库教程等……学姐带你玩转AI!
目录
相关文章推荐
河北交通广播  ·  【992 | ... ·  昨天  
河北交通广播  ·  【992 | ... ·  昨天  
河北交通广播  ·  骇人听闻!“每20个80后就有1人去世”?— ... ·  3 天前  
河北交通广播  ·  【992 | ... ·  3 天前  
51好读  ›  专栏  ›  学姐带你玩AI

Llama改进之——均方根层归一化RMSNorm

学姐带你玩AI  · 公众号  ·  · 2024-09-25 18:21

正文

来源:投稿  作者:175
编辑:学姐

unset unset 引言 unset unset

在学习完GPT2之后,从本文开始进入Llama模型系列。

本文介绍Llama模型的改进之RMSNorm(均方根层归一化)。它是由Root Mean Square Layer Normalization论文提出来的,可以参阅其论文笔记1。

unset unset LayerNorm unset unset

层归一化(LayerNorm)对Transformer等模型来说非常重要,它可以帮助稳定训练并提升模型收敛性。LayerNorm针对一个样本所有特征计算均值和方差,然后使用这些来对样本进行归一化:

这里 表示某个时间步LN层的输入向量表示,向量维度为H;h实LN层的输出;g,b实两个可学习的参数。

为什么层归一化有用?一些解释如下:

  1. 减少内部协变量偏移(Internal Covariate Shift):内部协变量偏移是指在深度神经网络的训练过程中,每一层输入的分布会发生变化,导致网络的训练变得困难。层归一化通过对每一层的输入进行归一化处理,可以减少内部协变量偏移,使得每一层的输入分布更加稳定。
  2. 稳定化梯度:层归一化有助于保持每一层输出的均值和方差稳定,从而使得梯度的传播更加稳定。这有助于减少梯度消失或梯度爆炸的问题,提高梯度在网络中的流动性,加快训练速度。
  3. 更好的参数初始化和学习率调整:通过层归一化,每一层的输入分布被归一化到均值为0、方差为1的标准正态分布,这有助于更好地初始化网络参数和调整学习率。参数初始化与学习率调整的稳定性对模型的训练效果至关重要。
  4. 增强模型的泛化能力:层归一化可以减少网络对训练数据分布的依赖,降低了过拟合的风险,从而提高模型的泛化能力。稳定的输入分布有助于模型更好地适应不同数据集和任务。

unset unset RMSNorm unset unset

虽然LayerNorm很好,但是它每次需要计算均值和方差。RMSNorm的思想就是移除(1)式中μ 的计算部分:

同时在实现也可以移除平移偏置b。

单看(2)式的话,相当于仅使用x \pmb xx的均方根来对输入进行归一化,它简化了层归一化的计算,变得更加高效,同时还有可能带来性能上的提升。

unset unset 实现 unset unset

RMSNorm的实现很简单:

import torch
import torch.nn as nn
from torch import Tensor

class RMSNorm(nn.Module):
  def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
    super().__init__()
    self.eps = eps
    self.weight = nn.Parameter(torch.ones(hidden_size))
  
  def _norm(self, hidden_states: Tensor) -> Tensor:
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    return hidden_states * torch.rsqrt(variance + self.eps)
  
  def forward(self, hidden_states: Tensor) -> Tensor:
    return self.weight * self._norm(hidden_states.float()).type_as(hidden_states)

torch.rsqrt torch.sqrt 的倒数; eps 是一个很小的数,防止除零; hidden_states.float() 确保了标准差计算的精确度和稳定性,然后在 forward 方法中,通过 .type_as(hidden_states) 将结果转换回原来的数据类型,以保持与输入张量相同的数据类型,使得归一化处理后的结果与输入数据类型一致。

下面通过一个简单的网络来测试一下:

import torch






请到「今天看啥」查看全文