旋转位置编码
论文:RoFormer: Enhanced Transformer with Rotary Position Embedding
代码:https://github.com/ZhuiyiTechnology/roformer
RoPE旋转位置编码的出发点是“通过绝对位置编码的方式实现相对位置编码”,Transformer位置编码是加性的正弦位置编码(位置向量和词向量是直接相加的),然而RoPE采用的是乘性的正弦位置编码,因为自注意力机制计算过程中需要通过query词向量qm和key词向量kn点乘计算注意力得分,因此想到可以先用绝对位置编码表示qm和kn的位置信息,需要找到函数g(xm,xn,m−n)使得qm和kn点乘之后可以计算得到qm和kn的相对位置信息m-n:
<fq(xm,m),fk(xn,n)>=g(xm,xn,m−n)
为了实现这个目标,需要引入复数和旋转矩阵(复数在空间上可以表示旋转),假定现在词嵌入向量的维度是两维d = 2 ,然后RoPE利用2维度平面上的向量的几何性质,再结合复数的性质,神奇般的找到了满足上述等式的 f 和 g ,其形式如下:
fq(xm,m)=(Wqxm)eimθfk(xn,n)=(Wkxn)einθg(xm,xn,m−n)=Re[(Wqxm)(Wkxn)∗ei(m−n)θ]
为了实现这个目标,需要引入复数和旋转矩阵(复数在空间上可以表示旋转),假定现在词嵌入向量的维度是两维d = 2 ,然后RoPE利用2维度平面上的向量的几何性质,再结合复数的性质,神奇般的找到了满足上述等式的 f 和 g ,其形式如下:
fq(xm,m)=(Wqxm)eimθfk(xn,n)=(Wkxn)einθg(xm,xn,m−n)=Re[(Wqxm)(Wkxn)∗ei(m−n)θ]
其中 eimθ是复数的指数形式(可以通过旋转矩阵表示),Wq和Wk表示q和k输入注意力线性层权重,Re 表示复数的实部。
RoPE将θ设置为:
θ={θi=10000d−2i,i∈[1,2,…,2d]}
fq(xm,m)可以表示成下面的式子:
fq(xm,m)=(cosmθsinmθ−sinmθcosmθ)(Wq(1,1)Wq(2,1)Wq(1,2)Wq(2,2))(xm(2)xm(1))=(cosmθsinmθ−sinmθcosmθ)(qm(2)qm(1))
而前面说的能让乘性正弦位置编码绝对位置编码在q和k向量点乘之后具有相对位置信息m-n的就是旋转矩阵:
Rθ=[cosθsinθ−sinθcosθ]
复数旋转矩阵证明:https://chatgpt.com/s/t_6948207f1b7c8191935745beadbc563b
fk(xn,n)可以表示成下面的式子:
fk(xm,m)=(cosmθsinmθ−sinmθcosmθ)(Wk(1,1)Wk(2,1)Wk(1,2)Wk(2,2))(xm(2)xm(1))=(cosmθsinmθ−sinmθcosmθ)(km(2)km(1))
g(xm,xn,m−n)可以表示如下:
g(xm,xn,m−n)=(qm(1)qm(2))(cos((m−n)θ)sin((m−n)θ)−sin((m−n)θ)cos((m−n)θ))(kn(2)kn(1))
g(xm,xn,m−n)公式证明:
将fk(xn,n)和fk(xn,n)写成向量形式:
fq(xm,m)fk(xn,n)=[qm(1)cos(mθ)−qm(2)sin(mθ),qm(2)cos(mθ)+qm(1)sin(mθ)]=[kn(1)cos(nθ)−kn(2)sin(nθ),kn(2)cos(nθ)+kn(1)sin(nθ)]
g(xm,xn,m−n)可以表示如下向量形式:
<fq(xm,m),fk(xn,n)>=(qm(1)cos(mθ)−qm(2)sin(mθ))(kn(1)cos(nθ)−kn(2)sin(nθ))+(qm(2)cos(mθ)+qm(1)sin(mθ))(kn(2)cos(nθ)+kn(1)sin(nθ))=qm(1)cos(mθ)kn(1)cos(nθ)−qm(1)cos(mθ)kn(2)sin(nθ)−qm(2)sin(mθ)kn(1)cos(nθ)+qm(2)sin(mθ)kn(2)sin(nθ)+qm(2)cos(mθ)kn(2)cos(nθ)+qm(2)cos(mθ)kn(1)sin(nθ)+qm(1)sin(mθ)kn(2)cos(nθ)+qm(1)sin(mθ)kn(1)sin(nθ)
背景知识-三角函数恒等变换公式:
sin(a+b)sin(a−b)cos(a+b)cos(a−b)=sinacosb+cosasinb,=sinacosb−cosasinb,=cosacosb−sinasinb,=cosacosb+sinasinb.
首先,把上面第二点的式子整理一下,总计8项,为了把qk相关的项提取出来,第1项和8项合并处理、第2项和7项合并处理、第3项和6项合并处理、第4项和5项合并处理:
<fq(xm,m),fk(xn,n)>=qm(1)kn(1)(cos(mθ)cos(nθ)+sin(mθ)sin(nθ))+qm(1)kn(2)(−cos(mθ)sin(nθ)+sin(mθ)cos(nθ))+qm(2)kn(1)(−sin(mθ)cos(nθ)+cos(mθ)sin(nθ))+qm(2)kn(2)(sin(mθ)sin(nθ)+cos(mθ)cos(nθ))=qm(1)kn(1)cos((m−n)θ)+qm(1)kn(2)sin((m−n)θ)−qm(2)kn(1)sin((m−n)θ)+qm(2)kn(2)cos((m−n)θ)=(qm(1)kn(1)+qm(2)kn(2))cos((m−n)θ)+(qm(1)kn(2)−qm(2)kn(1))sin((m−n)θ)=(qm(1)kn(1)+qm(2)kn(2))cos((m−n)θ)−(qm(2)kn(1)−qm(1)kn(2))sin((m−n)θ)=g(xm,xn,m−n)
至此证明了位置m的query向量和位置n的key向量的内积就是函数g(xm,xn,m−n)。
RoPE位置编码只作用于q和k,不直接作用于v,但注意力权重(由 RoPE 的 q/k 计算而来)包含位置信息,最终 v 的加权输出会间接包含位置信息。
由于内积满足线性叠加性,因此任意偶数维的RoPE,我们都可以表示为二维情形的拼接
由于Rm的稀疏性,所以直接用矩阵乘法来实现会很浪费算力,所以在计算时采用逐位相乘再相加的方式进行:
⎝⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎛q0q1q2q3⋮qd−2qd−1⎠⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎞⊗⎝⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎛cosmθ0cosmθ0cosmθ1cosmθ1⋮cosmθd/2−1cosmθd/2−1⎠⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎞+⎝⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎛−q1q0−q3q2⋮−qd−1qd−2⎠⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎞⊗⎝⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎛sinmθ0sinmθ1sinmθ1sinmθ1⋮sinmθd/2−1sinmθd/2−1⎠⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎞
其中⊗是逐位对应相乘。
参考文章: