searchusermenu
  • 发布文章
  • 消息中心
点赞
收藏
评论
分享
原创

RoPE与长度外推

2025-11-28 09:36:00
0
0

一. RoPE

1.为什么需要位置编码

接下来计算注意力后的特征向量

接下来思考一个问题:改变"求关注"的顺序再求自注意力

思考:"求关注"和"求注关"存在语义上的差别,其token的特征编码应该要体现差异

2.Embedding 如何体现文本的位置关系

最简单的方法: 构造f(x,i), i表示token的位置,引入位置编码

特性​​

​​绝对位置编码​​

​​旋转位置编码(RoPE)​​

​​位置信息类型​​

绝对位置

绝对+相对位置

​​长序列支持​​

有限(正弦外推弱,可学习固定长度)

极强(连续旋转支持任意长度)

​​计算效率​​

高(简单加法)

中(需旋转操作,但支持KV Cache优化)

​​适用场景​​

短序列、简单任务

长文本生成、大模型推理

3.RoPE的数学理论

参考此篇文章

https://zhuanlan.zhihu.com/p/30190764463

苏神:https://spaces.ac.cn/archives/8265

4.RoPE的性质

4.1 远距离衰减

远距离衰减指的是随着q和k的相对距离的增大,加入位置编码之后的内积应该随着距离增大而减小,这样相当于离得远的token分配到的attention会比较小,而离得近的token会得到更多的注意力。这样的特性确实直觉上比较符合人类的注意力机制。

把各个参数(base、window size、head size)下的内积值画出来看看是怎么衰减的。实现参考下面的代码。这里偷懒没有实现得很高效,勉强能用就行。

参数​​

​​典型值​​

​​影响范围​​

​​示例模型​​

Base (θ)

10000或动态调整

位置编码的频率分布

LLaMA(固定10000)

Window Size (w)

512或2048

注意力计算的实际范围

Longformer(滑动窗口)

Head Size (dh​)

64、128或256

旋转操作的维度和多头注意力设计

 

Base:(通常记为 θ或 base)是控制旋转角度频率的超参数,决定了位置编码的波长分布。其计算公式通常为 θi=10000−2i/d,其中 d是向量的维度,i是维度索引。

作用:

控制旋转角度的变化速率:较大的 base值会使高频维度(靠近向量末尾的维度)旋转更慢,低频维度(靠近向量开头的维度)旋转更快,从而捕捉不同尺度的位置关系。

影响外推能力:合适的 base值能平衡短距离和长距离依赖的建模,例如在插值方法(如NTK-aware)中调整 base可以优化长序列处理能力

维度 i

θi​

0

10000−0=1

1

10000−0.25≈0.56

2

10000−0.5≈0.01

3

10000−0.75≈0.0002

(1)q = k = 1

假设q和k都是1向量,如果q在位置0,画出k在0~4096位置下和q在位置编码后的内积如下。

这里使用了base=10000,d=512,可以看到整体趋势是震荡下降的。

不过如果把窗口从4096增大到65536,图像会变成这样。

可以看到图像不再是单纯的衰减,在距离超过大约15000的时候,出现了上升。

实际上这个包含多个周期函数的内积也具有一定的周期性,并不是在整个域上保持衰减的特性。只要相对距离够大,超过这个周期的1/4,内积就会再次上升。

而这个内积的周期受base调控,base越大,周期越长,因此现在的长窗口模型起步就是base=5M或者10M。

(2)q、k随机

前面是把q和k固定为1向量,现在试着把q和k初始化为随机向量,图像如下

相比1向量出现了更多的震荡,但是大体上还是能保持一定的远距离衰减特性。

  • RoPE的远距离衰减是震荡的,并且整个内积本身也具有一定的周期性,只有把base设得足够大,才能让内积结果在模型窗口大小内保持远距离衰减的特性。
  • 在q和k的相对距离小的时候,内积差距较大,也就是衰减较快;到了远距离之后,衰减变慢,也就是从内积角度来看,分辨率会变小。

5.RoPE与上下文拓展

6.1.推理外推

RoPE通过旋转矩阵将位置信息融入查询(Q)和键(K)向量中,其核心思想是:

(1)旋转操作:对向量的每一对维度施加旋转,旋转角度与位置线性相关(例如,位置 m的向量旋转 角度)。这种设计使得两个向量的点积仅依赖于它们的相对距离 ∣mn∣,而与绝对位置无关。因此,模型在训练时学习的是相对位置关系,而非固定长度的绝对位置编码

(2)连续性保证:旋转角度的连续变化(如 θi=10000−2i/d)确保了位置外推时角度的平滑过渡,避免了离散跳跃导致的信息断裂。即使遇到超出训练长度的位置,旋转后的向量仍能保持合理的几何关系

关键点:RoPE的旋转机制天然支持相对位置建模,使得模型在推理时能泛化到任意长度的序列,只要相对距离在训练范围内出现过

6.2.训练外推

比如YaRN外推训练:https://www.cnblogs.com/laozhanghahaha/p/18345815

7.RoPE的数学特性与KV Cache的兼容性

RoPE通过旋转矩阵将位置信息融入查询(Q)和键(K)向量中,其核心思想是对向量的每一对维度进行旋转变换。旋转角度与token的位置相关,但旋转矩阵本身是正交的,这意味着:

(1)位置信息解耦:RoPE将位置编码与向量内容分离,使得位置变换后的Q和K向量仍能通过点积保留相对位置关系。这种特性允许KV Cache存储未旋转的原始K向量,仅在计算注意力时动态应用旋转,避免了重复存储不同位置的K向量

(2)缓存原始K向量:由于RoPE的旋转操作是线性的,推理时只需缓存未旋转的K向量(或低秩压缩后的潜在向量),在需要时根据当前位置动态生成旋转后的K向量。这显著减少了KV Cache的存储需求,因为无需为每个位置单独缓存旋转后的K向量

(3)传统绝对位置编码(如正弦编码或可学习嵌入)需要将位置信息直接加到K向量中,导致KV Cache必须存储已编码的K向量,无法动态调整位置。而RoPE的旋转操作是位置相关的函数,允许K向量在缓存后仍能灵活适应不同位置,从而支持长序列推理

8.代码解读

```plaintext
import torch
def precompute_freqs_cis(dim, end, theta=10000):
    # e^i*m*(10000)^-2i/d ->(10000)^(-2i/d) i->(0,2/d -1)
    # 1/(10000**[0.0]/2)     shape [1]
    freqs = 1.0/(theta**(torch.arange(0, dim, 2)[:(dim//2)].float()/dim))
    # [3]
    m = torch.arange(end)
    # [3],[1] ->[3,1]
    freqs = torch.outer(m, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis

def reshape_for_broadcast(freqs_cis, x):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    # [3,1] ->[1,3,1,1]
    shape = [d if i==1 or i==ndim-1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

def apply_rotary_emb(xq, xk, freqs_cis):
    # 复数 实部+虚部   xq shape [1,3,2] [1,3,1,2]   *xq.shape[:-1]  [1,3,1,2] -> view_as_complex ->[1,3,1,1]
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # freqs_cis [3,1] ->[1,3,1,1]
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # q*e^I*m*theta
    #[1,3,2]  xq_*ffreqs_cis->[1,3,1,1]->view_as_real->[1,3,1,2]->[1,3,2]
    xq_out = torch.view_as_real(xq_*freqs_cis).flatten(3)
    # xq [1,3,1,2]->[1,3,1,1,2]->VIEW_CIMPLEX-》[1,3,1,1,1]->view_as_real->[1,3,1,1,2]->[1,3,1,2]
    xk_out = torch.view_as_real(xk_*freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xq)

dim = 2
end = 3
fres_cis = precompute_freqs_cis(dim, end)
xq = torch.randn(1, end, dim)
xk = torch.randn(1, end, dim)
res = apply_rotary_emb(xq, xk, fres_cis)
print(res[0].shape)
```
0条评论
0 / 1000
f****n
6文章数
0粉丝数
f****n
6 文章 | 0 粉丝
原创

RoPE与长度外推

2025-11-28 09:36:00
0
0

一. RoPE

1.为什么需要位置编码

接下来计算注意力后的特征向量

接下来思考一个问题:改变"求关注"的顺序再求自注意力

思考:"求关注"和"求注关"存在语义上的差别,其token的特征编码应该要体现差异

2.Embedding 如何体现文本的位置关系

最简单的方法: 构造f(x,i), i表示token的位置,引入位置编码

特性​​

​​绝对位置编码​​

​​旋转位置编码(RoPE)​​

​​位置信息类型​​

绝对位置

绝对+相对位置

​​长序列支持​​

有限(正弦外推弱,可学习固定长度)

极强(连续旋转支持任意长度)

​​计算效率​​

高(简单加法)

中(需旋转操作,但支持KV Cache优化)

​​适用场景​​

短序列、简单任务

长文本生成、大模型推理

3.RoPE的数学理论

参考此篇文章

https://zhuanlan.zhihu.com/p/30190764463

苏神:https://spaces.ac.cn/archives/8265

4.RoPE的性质

4.1 远距离衰减

远距离衰减指的是随着q和k的相对距离的增大,加入位置编码之后的内积应该随着距离增大而减小,这样相当于离得远的token分配到的attention会比较小,而离得近的token会得到更多的注意力。这样的特性确实直觉上比较符合人类的注意力机制。

把各个参数(base、window size、head size)下的内积值画出来看看是怎么衰减的。实现参考下面的代码。这里偷懒没有实现得很高效,勉强能用就行。

参数​​

​​典型值​​

​​影响范围​​

​​示例模型​​

Base (θ)

10000或动态调整

位置编码的频率分布

LLaMA(固定10000)

Window Size (w)

512或2048

注意力计算的实际范围

Longformer(滑动窗口)

Head Size (dh​)

64、128或256

旋转操作的维度和多头注意力设计

 

Base:(通常记为 θ或 base)是控制旋转角度频率的超参数,决定了位置编码的波长分布。其计算公式通常为 θi=10000−2i/d,其中 d是向量的维度,i是维度索引。

作用:

控制旋转角度的变化速率:较大的 base值会使高频维度(靠近向量末尾的维度)旋转更慢,低频维度(靠近向量开头的维度)旋转更快,从而捕捉不同尺度的位置关系。

影响外推能力:合适的 base值能平衡短距离和长距离依赖的建模,例如在插值方法(如NTK-aware)中调整 base可以优化长序列处理能力

维度 i

θi​

0

10000−0=1

1

10000−0.25≈0.56

2

10000−0.5≈0.01

3

10000−0.75≈0.0002

(1)q = k = 1

假设q和k都是1向量,如果q在位置0,画出k在0~4096位置下和q在位置编码后的内积如下。

这里使用了base=10000,d=512,可以看到整体趋势是震荡下降的。

不过如果把窗口从4096增大到65536,图像会变成这样。

可以看到图像不再是单纯的衰减,在距离超过大约15000的时候,出现了上升。

实际上这个包含多个周期函数的内积也具有一定的周期性,并不是在整个域上保持衰减的特性。只要相对距离够大,超过这个周期的1/4,内积就会再次上升。

而这个内积的周期受base调控,base越大,周期越长,因此现在的长窗口模型起步就是base=5M或者10M。

(2)q、k随机

前面是把q和k固定为1向量,现在试着把q和k初始化为随机向量,图像如下

相比1向量出现了更多的震荡,但是大体上还是能保持一定的远距离衰减特性。

  • RoPE的远距离衰减是震荡的,并且整个内积本身也具有一定的周期性,只有把base设得足够大,才能让内积结果在模型窗口大小内保持远距离衰减的特性。
  • 在q和k的相对距离小的时候,内积差距较大,也就是衰减较快;到了远距离之后,衰减变慢,也就是从内积角度来看,分辨率会变小。

5.RoPE与上下文拓展

6.1.推理外推

RoPE通过旋转矩阵将位置信息融入查询(Q)和键(K)向量中,其核心思想是:

(1)旋转操作:对向量的每一对维度施加旋转,旋转角度与位置线性相关(例如,位置 m的向量旋转 角度)。这种设计使得两个向量的点积仅依赖于它们的相对距离 ∣mn∣,而与绝对位置无关。因此,模型在训练时学习的是相对位置关系,而非固定长度的绝对位置编码

(2)连续性保证:旋转角度的连续变化(如 θi=10000−2i/d)确保了位置外推时角度的平滑过渡,避免了离散跳跃导致的信息断裂。即使遇到超出训练长度的位置,旋转后的向量仍能保持合理的几何关系

关键点:RoPE的旋转机制天然支持相对位置建模,使得模型在推理时能泛化到任意长度的序列,只要相对距离在训练范围内出现过

6.2.训练外推

比如YaRN外推训练:https://www.cnblogs.com/laozhanghahaha/p/18345815

7.RoPE的数学特性与KV Cache的兼容性

RoPE通过旋转矩阵将位置信息融入查询(Q)和键(K)向量中,其核心思想是对向量的每一对维度进行旋转变换。旋转角度与token的位置相关,但旋转矩阵本身是正交的,这意味着:

(1)位置信息解耦:RoPE将位置编码与向量内容分离,使得位置变换后的Q和K向量仍能通过点积保留相对位置关系。这种特性允许KV Cache存储未旋转的原始K向量,仅在计算注意力时动态应用旋转,避免了重复存储不同位置的K向量

(2)缓存原始K向量:由于RoPE的旋转操作是线性的,推理时只需缓存未旋转的K向量(或低秩压缩后的潜在向量),在需要时根据当前位置动态生成旋转后的K向量。这显著减少了KV Cache的存储需求,因为无需为每个位置单独缓存旋转后的K向量

(3)传统绝对位置编码(如正弦编码或可学习嵌入)需要将位置信息直接加到K向量中,导致KV Cache必须存储已编码的K向量,无法动态调整位置。而RoPE的旋转操作是位置相关的函数,允许K向量在缓存后仍能灵活适应不同位置,从而支持长序列推理

8.代码解读

```plaintext
import torch
def precompute_freqs_cis(dim, end, theta=10000):
    # e^i*m*(10000)^-2i/d ->(10000)^(-2i/d) i->(0,2/d -1)
    # 1/(10000**[0.0]/2)     shape [1]
    freqs = 1.0/(theta**(torch.arange(0, dim, 2)[:(dim//2)].float()/dim))
    # [3]
    m = torch.arange(end)
    # [3],[1] ->[3,1]
    freqs = torch.outer(m, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis

def reshape_for_broadcast(freqs_cis, x):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    # [3,1] ->[1,3,1,1]
    shape = [d if i==1 or i==ndim-1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

def apply_rotary_emb(xq, xk, freqs_cis):
    # 复数 实部+虚部   xq shape [1,3,2] [1,3,1,2]   *xq.shape[:-1]  [1,3,1,2] -> view_as_complex ->[1,3,1,1]
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # freqs_cis [3,1] ->[1,3,1,1]
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # q*e^I*m*theta
    #[1,3,2]  xq_*ffreqs_cis->[1,3,1,1]->view_as_real->[1,3,1,2]->[1,3,2]
    xq_out = torch.view_as_real(xq_*freqs_cis).flatten(3)
    # xq [1,3,1,2]->[1,3,1,1,2]->VIEW_CIMPLEX-》[1,3,1,1,1]->view_as_real->[1,3,1,1,2]->[1,3,1,2]
    xk_out = torch.view_as_real(xk_*freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xq)

dim = 2
end = 3
fres_cis = precompute_freqs_cis(dim, end)
xq = torch.randn(1, end, dim)
xk = torch.randn(1, end, dim)
res = apply_rotary_emb(xq, xk, fres_cis)
print(res[0].shape)
```
文章来自个人专栏
文章 | 订阅
0条评论
0 / 1000
请输入你的评论
0
0