专栏名称: 深度学习自然语言处理
一个从大三就接触NLP的小小NLPer,本公众号每天记录自己的一点一滴,每篇文章最后也有托福单词等新知识,学技术同时,也一点一滴积累额外的知识。期待与你在知识的殿堂与你相遇!
目录
相关文章推荐
新北方  ·  幸福之旅:开往春天的专列① 港澳 + ... ·  1小时前  
新北方  ·  再冷一天,辽宁气温即将大反转! ·  6 小时前  
新北方  ·  来东北,不到大连“血”后悔!福利来了 ·  2 天前  
51好读  ›  专栏  ›  深度学习自然语言处理

无痛理解旋转位置编码RoPE

深度学习自然语言处理  · 公众号  ·  · 2024-12-14 18:14

正文

作者:车中草同学
原文:https://zhuanlan.zhihu.com/p/8306958113

排版:青稞AI

本篇博客目的:

Rope比较难理解主要是因为它涉及数学比较多:复变函数,欧拉公式,两角和公式等。以及不同地方(论文,博客)引用复数的不同表达形式(原论文使用复变函数的指数形式),导致忘记复变函数的数学基础的人很难将不同形式转换。

LLM所有 细分方向 群+ ACL25/ICML25/NAACL25 投稿群-> LLM所有细分领域群、投稿群从这里进入!

本篇博客希望用通俗易懂的语言把RoPE(旋转位置编码)的论文和原始博客串起来,让大家无痛理解,另外,在掌握了RoPE之后,也用来分析下长文本外推的两种技术。(位置内插以及NTK-Aware)

本篇博客包含以下内容:

  1. 1. 为了通俗易读,本篇博客在附录介绍了RoPE涉及到的数学基础(建议可以先扫一眼,或者看到看不懂的公式的时候,去看下),另外,在介绍时,除了解释原论文和博客中的:复数函数的指数形式和矩阵形式,还会从 几何意义(旋转)的角度来介绍 ,更加通俗易懂点。
  2. 2. 代码实现:我们还介绍了RoPE的代码实现(以及为啥常见的大模型代码RoPE实现和论文中表达有所差别, GPT-J style,GPT-NeoX )。
  3. 3. 超参数分析:会用代码实际地进行超参数的分析。(base对远程衰减的影响, 的变化);
  4. 4.也介绍两个扩展长文本的技术:位置内插(PI),NTK-Aware(拿钟表举例说明为什么:高频外推,低频内插)。

问题:

其实写这篇文章的过程中,我一直有很多问题,逐步探索,有些有了答案,有的还没有答案。希望你看完能获得自己的答案。

  • 1.transformer的正余弦位置编码(sinusoidal)不是也有外推功能(比起BERT的PE),不是也可以表示相对距离?(也有远程衰减),为什么大模型都使用RoPE?
  • 2.RoPE的base,到底有什么作用,到底在控制什么?(很多长文本的工作都在调整它)
  • 3.RoPE的2维很好理解(从几何上很漂亮),为什么能从2维扩展到n维?
  • 4.RoPE的qwen代码里的实现为什么和理论不一样?(最后发现是两种实现:GPT-J 和 GPT-NeoX )他们俩等价嘛?

进一步问题:

  • 1.长度外推中,位置编码的OOD问题和RoPE的OOD问题,分别是什么?
  • 2.RoPE是绝对位置编码,那么我们在训练过程中到底在训练什么?
  • 3.如何免训练对RoPE进行外推?或者训练少量长文本让其外推?
  • 4.2维rope 从几何意义很完美,就是对向量不改变模长的逆时针旋转一个角度。但要从几何角度理解 n 维在做什么(感觉要从傅里叶角度来理解,d/2 个向量在分别旋转不同的角度,整体代表什么?)
  • 5.把训练过程和转圈是如何联系起来?(高频转的圈数多,低频转的圈数少)

感悟:

这篇是受到猛猿大佬的激励,认真看了很久,写的一篇文章,希望能带给大家有所帮助。

写完最大的领悟就是:纸上得来终觉浅,绝知此事要躬行。那些模棱两可,草草看过就了事,以为自己理解了,但完全不理解的知识,其实有很多。

感觉深入的做一件事的好处就是:你后面会更有耐心的做事情,你知道即使在难,只要慢慢的深入理解,和别人交流,你也可以掌握。

另外:本人工作中并没有长文本的实践,这篇文章如果有什么不对的地方,欢迎大家在评论区指教。

1.位置编码的作用:

自然语言中位置是很重要的信息。(不管是绝对位置还是相对位置)因此,我们希望加入句子的位置信息到模型中来。

对于nlp任务而言,因为Transformer 的编码器结构本身无法识别序列中元素的顺序,因此,我们希望使用 合适的位置编码(Position Encoding, PE)引入一定的位置先验信息 。(不管是绝对位置和相对位置)

比如 远程衰减的先验 :即位置相近的Token平均来说获得更多的注意力,而距离比较远的Token平均获得更好的注意力。

2.初步定义:

对于Transformer架构,我们希望在词向量计算self-attention的时候,引入位置编码。因此有如下的定义:

假设有句子 是一个句子每个token的词向量(多维向量)。(i从1到N),那么,带有位置信息的self-attention机制是:

公式1:将位置信息m融入词向量

解释:

  • 1. 是第m个位置的词向量,没有位置信息。
  • 2. , 是一个函数,首先给词向量加入位置信息,然后转换他们为 , , 表示。(通过线性层)

计算 的重要性(attention weights),然后对 归一化

公式2:计算m对q的重要性(attention weights),然后进行归一化。

3.绝对位置编码:

绝对位置编码:不同词的位置编码仅由其位置唯一决定。

1 transformer的正余弦位置编码(sinusoidal)

对于公式1中的qkv,在transformer里面,定义为下面的公式。

公式3:sinusoidal的f_q

解释:

只在最底层把 进行相加,来将位置信息融入到词向量 中。

其中 定义如下:

公式4:sinusoidal的p_i

解释:

  • 1.其中 k 是句子中token的位置,t是位置向量的维度index。偶数的时候用sin,奇数的时候用cos。
  • 2.d是词向量的维度。

sinusoidal的优点:

  • 1.理论上可以长度外推。(因为位置编码不像bert一样,是需要训练的固定维度的psotion encoder([512,768]))
  • 2.不同位置的 位置编码内积 具有远程衰减。

缺点:

1.不同位置的位置编码内积确实会有远程衰减。但是它和词向量相加后,经过attention层,还会有q和k的线性变化,经过这个线性变化后,导致q和k的点积没有了远程衰减。

举例说明:

应用了sinusoidal位置编码的q和k点积如下:

公式5:应用了sinusoidal位置编码的q和k点积。(其中△t是相对t变化量)

我们只观察 的变化,其中可以发现,经过attention层之后,位置编码真正起作用的,不是两者的两个位置编码的点积,还要引入两个线性变化。

图1:d=128,是两个sinusoidal位置编码的点积,下面两个是引入两个随机初始化的线性变化的矩阵的点积。

可以看出,两个不同位置sinusoidal位置编码的点积,确实有很好的远程衰减。但是引入线性变化后,远程衰减的性质受到了极大的破坏。

因此,即使sinusoidal位置编码本身拥有很好的形式,但是其只是在底层和词向量相加,这导致经过Attention层后,真正起作用的不是两个位置编码的乘积,而还要带来一个线性变化,而引入这个线性变化后,其丢失了远程衰减这个性质。

此外:

1. transformer的作者表示 :他们还尝试使用学习的位置嵌入 ,发现这两个版本产生了几乎相同的结果。我们选择正弦版本,因为它可能允许模型预测到时候处理比训练期间遇到的序列长度更长的句子。

2 bert的position encoder

bert的psotion encoder:但它是随机初始化一个embedding[512,768],然后作为参数来学习的,也就是说它没有作出任何假设,但允许模型学到位置信息。

应用方式:

1.在transformer底层和词向量相加,输入到transformer中去。

存在的问题:

1.没有外推能力。(训练长度如果是512,那么预测不能大于512。)

4.相对位置编码:

除了绝对位置编码外,在NLP领域,还有另外一种相对位置编码。由于自然语言一般更依赖于相对位置,所以相对位置编码通常也有着优秀的表现。

相对位置编码 :在计算Attention的时候考虑当前位置与被Attention的位置的相对距离。(想办法微调一下Attention结构,使得它有能力分辨不同位置的Token。)

这里为了介绍RoPE的特点,简单介绍下其思想,不展开说具体的方法。

5.RoPE的原理:

一言以概之:而RoPE(旋转位置编码)则是通过绝对位置编码的方式实现相对位置编码 。(结合了相对位置编码和绝对位置编码)

苏神在论文中说:

Specifically, the proposed RoPE encodes the absolute position with a rotation matrix and meanwhile incorporates the explicit relative position dependency in self-attention formulation

翻译为中文是:

RoPE使用旋转矩阵编码绝对位置信息,同时可以将显示的相对位置信息融入自注意力计算中

其优点

1.包括序列长度灵活性。(方便外推性)

2.非显式的长期衰减性。(随着相对距离的增加而衰减attention得分)

3.可以应用于线性自注意力模型。

1 RoPE的定义:

为了包含相对位置信息,rope希望把 的内积操作,可以编码成一个 以及相对位置m-n的函数g。
公式6:RoPE希望位置为m的q和位置为n的k内积找到一个函数g,其中g的自变量有m-n,x_m和x_n。

解释:

  • 1.<>表示 进行内积操作。
  • 2.如果找到这样一个函数g,那么 的内积操作,也会蕴含相对位置m-n。
公式6:RoPE希望位置为m的q和位置为n的k内积找到一个函数g,其中g的自变量有m-n,x_m和x_n。

解释:

  • 1.<>表示 进行内积操作。
  • 2.如果找到这样一个函数g,那么 的内积操作,也会蕴含相对位置m-n。

为什么要找到这样一个函数呢?

因为我们希望 进行内积操作,受到他们相对位置的影响。(符合自然语言的习惯)

1.两个词相对位置近的时候(m-n小),内积可以大一点。

2.两个词相对位置远的时候(m-n大),内积可以小一点。(长度衰减)

2 RoPE在词向量是二维的情况下:

维度为2时,可以找到如下的函数f和g:(这里是复数的指数形式:(苏神博客以及论文中的版本))

注意:

1.如果感觉这里不理解的话,也没关系,往下看矩阵形式以及几何形式就很容易理解了,也可以参考附录2:复习下复数的不同形式之间的转换。

公式7:找到的满足公式6的f和g

解释:

  • 1.其中Re是复数的实部, 的共轭复数。
  • 2.m是词在句子中的位置, 是一个非零的常量。

我们首先将 其展开成矩阵乘法的形式

公式8:f_q(x_m,m)的矩阵形式

解释:

  • 1. 是一个非零的常数。 是词向量的第一维度,m是位置。
  • 2.可以看出其实就是 乘了一个旋转矩阵。(旋转位置编码的来源)
  • 3.几何意义:给 旋转其索引的 倍数。(逆时针旋转其索引的 倍数)

的几何意义解释

我们首先回顾下:一个列向量(1,0),乘一个旋转矩阵度数为45度。公式以及图片:

图片2: 旋转矩阵的图示包括一个蓝色箭头,表示原始向量 [1,0];旋转角度 $\theta$ 为 45 度,红色箭头表示应用二维旋转矩阵得到的旋转向量 $[\cos(\theta), \sin(\theta)]$。

图片来源: RoPE: Addressing the Position Encoding Flaw in Transformer Models [1]

解释:

1.可以看成对向量逆时针旋转45度。

2.向量的模长不变,只有角度发生变化。

将几何意义应用到上面公式8 中,我们发现,其实公式8就是将二维的向量q,逆时针旋转 度,并且只改变方向,不会改变q的模长

接下来,我们需要再看下,在二维情况下,应用了旋转位置编码的 的点积会发生什么变化呢?

首先定义 是:

表示绝对位置,从0,1,2…


公式9:应用了旋转位置编码的 和q_m和k_n 的点积

这里的推导涉及:附录3的两角和公式。

可以看出:

1.RoPE确实是用绝对位置编码的形式(红色部分),实现了相对位置编码(绿色部分),这是旋转位置编码设计的巧妙之处

3 RoPE在词向量是d维的情况下:

维度为d时,可以找到如下的函数f

公式10:扩展到d维的f函数

其中的 是:

公式11:扩展到d维的R定义

解释:

  • 1.首先解释下从2维扩展到n维的做法:对于d维度的q向量,我们分为d/2组,每两个相邻维度为一组,共同旋转一个角度
  • 2.其中 ,和transforfmer的位置编码一样。( 从0到d/2-1是渐渐变小的),它可以带来一定的远程衰减性。
  • 3.m是第m个位置的位置编码。(对于不同位置,在 的基础上要乘m倍, 也是绝对位置信息必不可少的

为什么从2维可以扩展到n维呢?

原论文中说: 因为内积满足线性叠加性,因此任意偶数维的RoPE,我们都可以表示为二维情形的拼接

这里怎么理解呢?

上一小节,我们证明了应用了RoPE旋转矩阵的2维q和k点积,他们的内积可以满足g函数(带有相对关系m-n)

对于n维而言,n维度q和k最终要通过内积来计算其他词对当前词的重要性,因为内积是element-wise相乘,然后相加

因此,我们可以分解d维度,分为d/2组,每一组都应用了RoPE旋转矩阵,都满足一个函数g(带有相对关系m-n),最后他们相加,也一定会满足g函数。

应用我们找到的函数f到f_q和f_k的内积操作,公式如下。(因此,其也满足我们4.1小节要找的函数g)

也是复数的指数形式

公式12:扩展到d维的f(q,m)和f(k,m)的点积

解释:

1.值得指出的是, 是一个正交矩阵,它不会改变向量的模长,因此通常来说它不会改变原模型的稳定性。

2.由于 比较稀疏,因此,我们推荐使用下面的方式来实现。( 复数的矩阵形式

3. 可以看出RoPE形式上和Sinusoidal位置编码有点相似,只不过Sinusoidal位置编码是加性的,而RoPE可以视为乘性的。

公式13:旋转位置编码的高效实现方法

如何从几何角度来思考,对于d维度的q,旋转位置编码做了什么呢?

图3:多维RoPE旋转示例

上图很直观的展示了:

1.对于位置为m的d维q向量,我们分为d/2组,每两个相邻维度为一组,共同旋转一个角度

2. 是一个这是一个从1渐变到接近于0的函数,因此,前面维度的 旋转的更快,后面的旋转的更慢。

下面讲一个n维度RoPE从转圈角度,在训练过程中的特点。

参考 苏剑林的:Transformer升级之路:16、“复盘”长度外推技术 [2] ,我们形式定义应用了RoPE q和k内积如下:

公式14:应用旋转位置编码的q和k点积展开

实际就是单位圆上的点,这个点逆时针旋转 度,其实在训练位置编码的过程中,我们可以看成是在训练d/2个单位圆。(如果圆上的点都被训练过了,那么就认为训练充分了)

在训练过程中,如果训练长度是 (假设预训练过程中,长度都是一样的)。那么在 m-n -1] 。

前面的维度 较大,所以在训练过程中,已经转了很多圈了,圆上的每个点都被训练过。(如下图左边)

而对于 后面的维度 较小,所以在训练过程中,转圈不充分,只有部分弧长。(如下图右边)

对于这种情况,如果测试的时候遇到更大的L_test,那么就超出了训练过的弧范围,从而有无法预估的表现。

这个时候就要想办法将它压缩到已经被充分训练过的那段弧上(位置内插)。

图4:训练过程中,不同\theta_i的训练图示

6.代码实现: transformer里面qwen2的rope实现方式:

https://github.com/huggingface/... [3]

def _compute_default_rope_parameters(
    configOptional[PretrainedConfig] = None,
    deviceOptional["torch.device"] = None,
    seq_lenOptional[int] = None,
    **rope_kwargs,
) -> Tuple["torch.Tensor", float]:
    """
    Computes the inverse frequencies according to the original RoPE implementation
    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration.
        device (`torch.device`):
            The device to use for initialization of the inverse frequencies.
        seq_len (`int`, *optional*):
            The current sequence length. Unused for this type of RoPE.
        rope_kwargs (`Dict`, *optional*):
            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
    Returns:
        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
    "
""
    if config is not None and len(rope_kwargs) > 0:
        raise ValueError(
            "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
            f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
        )
    iflen (rope_kwargs) > 0:
        base = rope_kwargs["base"]
        dim = rope_kwargs["dim"]
    elif config is not None:
        base = config.rope_theta
        partial_rotary_factor = config.partial_rotary_factorifhasattr(config, "partial_rotary_factor"else1.0
        head_dim = getattr(config, "head_dim", config.hidden_size// config.num_attention_heads)
        dim = int(head_dim * partial_rotary_factor)

    attention_factor = 1.0  # Unusedinthis type ofRoPE

    # Compute the inverse frequencies
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
    return inv_freq, attention_factor

这部分代码其实就计算了 ,可以看出对于d维向量,每两个维度共享一个

inv_freq的维度:[1,dim/2]

Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbeddingwithLlama->Qwen2
classQwen2RotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim=None,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        rope_type="default",
        configOptional[Qwen2Config] = None,
    ):
        super().__init__()
        # TODO (joao): remove the `if` below, only used forBC
        self.rope_kwargs = {}
        if config is None:
            logger.warning_once(
                "`Qwen2RotaryEmbedding` can now be fully parameterized by passing the model config through the "
                "`config` argument. All other arguments will be removed in v4.46"
            )
            self.rope_kwargs = {
                "rope_type": rope_type,
                "factor": scaling_factor,
                "dim": dim,
                "base": base,
                "max_position_embeddings": max_position_embeddings,
            }
            self.rope_type = rope_type
            self.max_seq_len_cached = max_position_embeddings
            self.original_max_seq_len = max_position_embeddings
        else:
            # BC"rope_type" was originally "type"
            if config.rope_scaling is not None:
                self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
            else:
                self.rope_type = "default"
            self.max_seq_len_cached = config.max_position_embeddings
            self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        #这里初始化的\theta_i
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        "
""
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODOjoao: may breakwith compilation
            self.max_seq_len_cached = seq_len

        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        if "dynamic"in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # CoreRoPE block
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -11)
        position_ids_expanded = position_ids[:, None, :].float()
        # Forcefloat32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.typ
        device_type = device_type ifisinstance(device_type, str) and device_type != "mps"else"cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(12)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        # AdvancedRoPEtypes (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

这里是在计算 的值。

Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1// 2]
    x2 = x[..., x.shape[-1// 2 :]
    return torch.cat((-x2, x1), dim=-1)


Copiedfrom transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    "
""
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

1.对分别经过q_proj和k_proj的q和v应用旋转位置编码。 https://github.com/huggingface/... [4]

2.通过rotate_half,是将q分为d/2组。(注意,这里实现的是GPT-NeoX的旋转位置编码,而非上面理论部分介绍的。)

GPT-NeoX style :(为了高效实现)

不是相邻两个元素为一组,而是 为一组

图5:GPT-NeoX style rope推导

图来自: https://discuss.huggingface.co/... [5]

GPT-J style

是和原始论文和博客一样,使用的相邻两个为一组

7.超参数的分析:

1. 旋转角的随着i的变化。

我们首先了解下 在做什么变化?(i从0到d/2-1)。比如d=512维度的 会怎么变化

可以看出: 从1到0递减变化。(i从0到d/2-1)

图3:当维度为512时,base为10000,\theta_i随机i的变化而变化。

画图的代码如下:

import numpy as np
import matplotlib.pyplotas plt
import torch

# 设置参数
theta_base = 10_000
head_dim = 512

# 计算 inv_freq
inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2).float() / head_dim))

# 创建横坐标
x = torch.arange(0256)  # 0 到 256

# 开始绘图
plt.figure(figsize=(106))
plt.plot(x.numpy(), inv_freq.numpy(), label='inv_freq', color='blue')
plt.xlim(0, head_dim)
plt.ylim(0, np.max(inv_freq.numpy()) * 1.1)  # 从 0 到最大值稍微增加一些
plt.xlabel('Index (0 to 256)')
plt.ylabel('inv_freq')
plt.title('Inverse Frequency Plot')
plt.grid()
plt.legend()
plt.show()

2.远程衰减:f(q,m)和f(k,m)的点积随着theta_base的变化。

我们假设q和k是全1的向量在应用了RoPE后,他们的点积如下图所示:

theta_base=10000

theta_base=1000000


我们可以看出,基数影响衰减的范围,基数越大,衰减的越慢。

因此,更长的文本,需要更大的base。

import torch

import matplotlib.pyplotas plt

import numpy as np
import torch

def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096):
    assert head_dim % 2 == 0"Embedding dimension must be even"

    # Compute the inverse frequencies
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2).float() / head_dim))

    # Generate position indices
    positions = torch.arange(context_length)

    # Compute the angles
    angles = positions[:, None] * inv_freq[None, :]  # Shape: (context_length, head_dim // 2)

    # Expand angles to match the head_dim
    angles = torch.cat([angles, angles], dim=1)  # Shape: (context_length, head_dim)

    # Precompute sine and cosine
    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos, sin

def compute_rope(x, cos, sin):
    # x: (batch_size, num_heads, seq_len, head_dim)
    batch_size, num_heads, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0"Head dimension must be even"

    # Split x into first half and second half
    #[2,4,5,8]
    x1 = x[..., : head_dim // 2]  # First half
    x2 = x[..., head_dim // 2 :]  # Second half

    # Adjust sin and cos shapes
    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)  # Shape: (11, seq_len, head_dim)
    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)

    # Apply the rotary transformation
    #[2,4,5,16]
    rotated = torch.cat((-x2, x1), dim=-1)
    x_rotated = (x * cos) + (rotated * sin)

    return x_rotated.to(dtype=x.dtype)

def rope_test():

    # Settings
    batch_size = 1
    context_len = 64000
    num_heads = 4
    head_dim = 4096

    # InstantiateRoPE parameters
    cos, sin = precompute_rope_params(head_dim=head_dim, context_length=context_len)
    torch.manual_seed(123)
    queries = torch.ones(1,1,context_len, head_dim)
    keys = torch.ones(1,1,context_len, head_dim)

    # Dummy query and key tensors
    #[2,4,5,16]
    # queries = torch.randn(batch_size, num_heads, context_len, head_dim)
    # keys = torch.randn(batch_size, num_heads, context_len, head_dim)

    # # Apply rotary position embeddings
    queries_rot = compute_rope(queries, cos, sin)
    keys_rot = compute_rope(keys, cos, sin)

    print(queries_rot)

    print(queries_rot.shape)
    print(queries_rot[0,0,[0]].shape)
    print(keys_rot[0,0,:,:].transpose(1,0).shape)

    #[q对k的点积得分]
    q_k_dot = queries_rot[0,0,[0]] @ keys_rot[0,0,:,:].transpose(1,0)
    distances = np.arange(064000)  # Distances to test
    dot_products = q_k_dot[0].numpy()

    import matplotlib.pyplotas plt

    # Plot the results
    plt.figure(figsize=(106))
    plt.plot(distances, dot_products, label='q·k after RoPE', color='blue')
    plt.xlabel('Relative Distance')
    plt.ylabel('Dot Product')
    plt.title






请到「今天看啥」查看全文