0%

大模型位置编码笔记

旋转位置编码

论文:RoFormer: Enhanced Transformer with Rotary Position Embedding

代码:https://github.com/ZhuiyiTechnology/roformer

RoPE旋转位置编码的出发点是“通过绝对位置编码的方式实现相对位置编码”,Transformer位置编码是加性的正弦位置编码(位置向量和词向量是直接相加的),然而RoPE采用的是乘性的正弦位置编码,因为自注意力机制计算过程中需要通过query词向量qmq_m和key词向量knk_n点乘计算注意力得分,因此想到可以先用绝对位置编码表示qmq_mknk_n的位置信息,需要找到函数g(xm,xn,mn)g(x_m, x_n, m-n)使得qmq_mknk_n点乘之后可以计算得到qmq_mknk_n的相对位置信息m-n:

<fq(xm,m),fk(xn,n)>=g(xm,xn,mn)< f_{q} \left( x_{m} , m \right) , f_{k} \left( x_{n} , n \right) >=g \left( x_{m} , x_{n} , m-n \right)

为了实现这个目标,需要引入复数和旋转矩阵(复数在空间上可以表示旋转),假定现在词嵌入向量的维度是两维d = 2 ,然后RoPE利用2维度平面上的向量的几何性质,再结合复数的性质,神奇般的找到了满足上述等式的 f 和 g ,其形式如下:

fq(xm,m)=(Wqxm)eimθfk(xn,n)=(Wkxn)einθg(xm,xn,mn)=Re[(Wqxm)(Wkxn)ei(mn)θ]\begin{aligned} & f_q\left(\boldsymbol{x}_m, m\right)=\left(\boldsymbol{W}_q \boldsymbol{x}_m\right) e^{i m \theta} \\ & f_k\left(\boldsymbol{x}_n, n\right)=\left(\boldsymbol{W}_k \boldsymbol{x}_n\right) e^{i n \theta} \\ & g\left(\boldsymbol{x}_m, \boldsymbol{x}_n, m-n\right)=\operatorname{Re}\left[\left(\boldsymbol{W}_q \boldsymbol{x}_m\right)\left(\boldsymbol{W}_k \boldsymbol{x}_n\right)^* e^{i(m-n) \theta}\right] \end{aligned}

为了实现这个目标,需要引入复数和旋转矩阵(复数在空间上可以表示旋转),假定现在词嵌入向量的维度是两维d = 2 ,然后RoPE利用2维度平面上的向量的几何性质,再结合复数的性质,神奇般的找到了满足上述等式的 f 和 g ,其形式如下:

fq(xm,m)=(Wqxm)eimθfk(xn,n)=(Wkxn)einθg(xm,xn,mn)=Re[(Wqxm)(Wkxn)ei(mn)θ]\begin{aligned} & f_q\left(\boldsymbol{x}_m, m\right)=\left(\boldsymbol{W}_q \boldsymbol{x}_m\right) e^{i m \theta} \\ & f_k\left(\boldsymbol{x}_n, n\right)=\left(\boldsymbol{W}_k \boldsymbol{x}_n\right) e^{i n \theta} \\ & g\left(\boldsymbol{x}_m, \boldsymbol{x}_n, m-n\right)=\operatorname{Re}\left[\left(\boldsymbol{W}_q \boldsymbol{x}_m\right)\left(\boldsymbol{W}_k \boldsymbol{x}_n\right)^* e^{i(m-n) \theta}\right] \end{aligned}

其中 eimθe^{imθ}是复数的指数形式(可以通过旋转矩阵表示),WqW_qWkW_k表示q和k输入注意力线性层权重,Re 表示复数的实部。

RoPE将θ设置为:

θ={θi=100002id,i[1,2,,d2]}\theta=\left\{\theta_{i}=1 0 0 0 0^{\frac{-2 i} {d}} , i \in[ 1 , 2 , \ldots, \frac{d} {2} ] \right\}

fq(xm,m)f_q(x_m, m)可以表示成下面的式子:

fq(xm,m)=(cosmθsinmθsinmθcosmθ)(Wq(1,1)Wq(1,2)Wq(2,1)Wq(2,2))(xm(1)xm(2))=(cosmθsinmθsinmθcosmθ)(qm(1)qm(2))\begin{aligned} f_q\left(\boldsymbol{x}_m, m\right) & =\left(\begin{array}{cc} \cos m \theta & -\sin m \theta \\ \sin m \theta & \cos m \theta \end{array}\right)\left(\begin{array}{ll} W_q^{(1,1)} & W_q^{(1,2)} \\ W_q^{(2,1)} & W_q^{(2,2)} \end{array}\right)\binom{x_m^{(1)}}{x_m^{(2)}} \\ & =\left(\begin{array}{cc} \cos m \theta & -\sin m \theta \\ \sin m \theta & \cos m \theta \end{array}\right)\binom{q_m^{(1)}}{q_m^{(2)}} \end{aligned}

而前面说的能让乘性正弦位置编码绝对位置编码在q和k向量点乘之后具有相对位置信息m-n的就是旋转矩阵

Rθ=[cosθsinθsinθcosθ]\mathbf{R}_\theta=\left[\begin{array}{cc} \cos \theta & -\sin \theta \\ \sin \theta & \cos \theta \end{array}\right]

复数旋转矩阵证明:https://chatgpt.com/s/t_6948207f1b7c8191935745beadbc563b

fk(xn,n)f_k(x_n, n)可以表示成下面的式子:

fk(xm,m)=(cosmθsinmθsinmθcosmθ)(Wk(1,1)Wk(1,2)Wk(2,1)Wk(2,2))(xm(1)xm(2))=(cosmθsinmθsinmθcosmθ)(km(1)km(2))\begin{gathered} f_k\left(\boldsymbol{x}_m, m\right)=\left(\begin{array}{cc} \cos m \theta & -\sin m \theta \\ \sin m \theta & \cos m \theta \end{array}\right)\left(\begin{array}{ll} W_k^{(1,1)} & W_k^{(1,2)} \\ W_k^{(2,1)} & W_k^{(2,2)} \end{array}\right)\binom{x_m^{(1)}}{x_m^{(2)}} \\ =\left(\begin{array}{cc} \cos m \theta & -\sin m \theta \\ \sin m \theta & \cos m \theta \end{array}\right)\binom{k_m^{(1)}}{k_m^{(2)}} \end{gathered}

g(xm,xn,mn)g(x_m, x_n, m-n)可以表示如下:

g(xm,xn,mn)=(qm(1)qm(2))(cos((mn)θ)sin((mn)θ)sin((mn)θ)cos((mn)θ))(kn(1)kn(2))g\left(\boldsymbol{x}_m, \boldsymbol{x}_n, m-n\right)=\left(\begin{array}{cc} q_m^{(1)} & q_m^{(2)} \end{array}\right)\left(\begin{array}{cc} \cos ((m-n) \theta) & -\sin ((m-n) \theta) \\ \sin ((m-n) \theta) & \cos ((m-n) \theta) \end{array}\right)\binom{k_n^{(1)}}{k_n^{(2)}}

g(xm,xn,mn)g(x_m, x_n, m-n)公式证明:
fk(xn,n)f_k(x_n, n)fk(xn,n)f_k(x_n, n)写成向量形式:

fq(xm,m)=[qm(1)cos(mθ)qm(2)sin(mθ),qm(2)cos(mθ)+qm(1)sin(mθ)]fk(xn,n)=[kn(1)cos(nθ)kn(2)sin(nθ),kn(2)cos(nθ)+kn(1)sin(nθ)]\begin{aligned} f_q\left(x_m, m\right) & =\left[q_m^{(1)} \cos (m \theta)-q_m^{(2)} \sin (m \theta), q_m^{(2)} \cos (m \theta)+q_m^{(1)} \sin (m \theta)\right] \\ f_k\left(x_n, n\right) & =\left[k_n^{(1)} \cos (n \theta)-k_n^{(2)} \sin (n \theta), k_n^{(2)} \cos (n \theta)+k_n^{(1)} \sin (n \theta)\right] \end{aligned}

g(xm,xn,mn)g(x_m, x_n, m-n)可以表示如下向量形式:

<fq(xm,m),fk(xn,n)>=(qm(1)cos(mθ)qm(2)sin(mθ))(kn(1)cos(nθ)kn(2)sin(nθ))+(qm(2)cos(mθ)+qm(1)sin(mθ))(kn(2)cos(nθ)+kn(1)sin(nθ))=qm(1)cos(mθ)kn(1)cos(nθ)qm(1)cos(mθ)kn(2)sin(nθ)qm(2)sin(mθ)kn(1)cos(nθ)+qm(2)sin(mθ)kn(2)sin(nθ)+qm(2)cos(mθ)kn(2)cos(nθ)+qm(2)cos(mθ)kn(1)sin(nθ)+qm(1)sin(mθ)kn(2)cos(nθ)+qm(1)sin(mθ)kn(1)sin(nθ)\begin{gathered} <f_q\left(x_m, m\right), f_k\left(x_n, n\right)> \\ =\left(q_m^{(1)} \cos (m \theta)-q_m^{(2)} \sin (m \theta)\right)\left(k_n^{(1)} \cos (n \theta)-k_n^{(2)} \sin (n \theta)\right) \\ +\left(q_m^{(2)} \cos (m \theta)+q_m^{(1)} \sin (m \theta)\right)\left(k_n^{(2)} \cos (n \theta)+k_n^{(1)} \sin (n \theta)\right) \\ =q_m^{(1)} \cos (m \theta) k_n^{(1)} \cos (n \theta)-q_m^{(1)} \cos (m \theta) k_n^{(2)} \sin (n \theta) \\ -q_m^{(2)} \sin (m \theta) k_n^{(1)} \cos (n \theta)+q_m^{(2)} \sin (m \theta) k_n^{(2)} \sin (n \theta) \\ +q_m^{(2)} \cos (m \theta) k_n^{(2)} \cos (n \theta)+q_m^{(2)} \cos (m \theta) k_n^{(1)} \sin (n \theta) \\ +q_m^{(1)} \sin (m \theta) k_n^{(2)} \cos (n \theta)+q_m^{(1)} \sin (m \theta) k_n^{(1)} \sin (n \theta) \end{gathered}

背景知识-三角函数恒等变换公式:

sin(a+b)=sinacosb+cosasinb,sin(ab)=sinacosbcosasinb,cos(a+b)=cosacosbsinasinb,cos(ab)=cosacosb+sinasinb.\begin{aligned} \sin(a+b) &= \sin a \cos b + \cos a \sin b, \\ \sin(a-b) &= \sin a \cos b - \cos a \sin b, \\ \cos(a+b) &= \cos a \cos b - \sin a \sin b, \\ \cos(a-b) &= \cos a \cos b + \sin a \sin b. \end{aligned}

首先,把上面第二点的式子整理一下,总计8项,为了把qk相关的项提取出来,第1项和8项合并处理、第2项和7项合并处理、第3项和6项合并处理、第4项和5项合并处理:

<fq(xm,m),fk(xn,n)>=qm(1)kn(1)(cos(mθ)cos(nθ)+sin(mθ)sin(nθ))+qm(1)kn(2)(cos(mθ)sin(nθ)+sin(mθ)cos(nθ))+qm(2)kn(1)(sin(mθ)cos(nθ)+cos(mθ)sin(nθ))+qm(2)kn(2)(sin(mθ)sin(nθ)+cos(mθ)cos(nθ))=qm(1)kn(1)cos((mn)θ)+qm(1)kn(2)sin((mn)θ)qm(2)kn(1)sin((mn)θ)+qm(2)kn(2)cos((mn)θ)=(qm(1)kn(1)+qm(2)kn(2))cos((mn)θ)+(qm(1)kn(2)qm(2)kn(1))sin((mn)θ)=(qm(1)kn(1)+qm(2)kn(2))cos((mn)θ)(qm(2)kn(1)qm(1)kn(2))sin((mn)θ)=g(xm,xn,mn)\begin{gathered} <f_q\left(x_m, m\right), f_k\left(x_n, n\right)> \\ =q_m^{(1)} k_n^{(1)}(\cos (m \theta) \cos (n \theta)+\sin (m \theta) \sin (n \theta)) \\ +q_m^{(1)} k_n^{(2)}(-\cos (m \theta) \sin (n \theta)+\sin (m \theta) \cos (n \theta)) \\ +q_m^{(2)} k_n^{(1)}(-\sin (m \theta) \cos (n \theta)+\cos (m \theta) \sin (n \theta)) \\ +q_m^{(2)} k_n^{(2)}(\sin (m \theta) \sin (n \theta)+\cos (m \theta) \cos (n \theta)) \\ =q_m^{(1)} k_n^{(1)} \cos ((m-n) \theta) \\ +q_m^{(1)} k_n^{(2)} \sin ((m-n) \theta) \\ -q_m^{(2)} k_n^{(1)} \sin ((m-n) \theta) \\ +q_m^{(2)} k_n^{(2)} \cos ((m-n) \theta) \\ =\left(q_m^{(1)} k_n^{(1)}+q_m^{(2)} k_n^{(2)}\right) \cos ((m-n) \theta)+\left(q_m^{(1)} k_n^{(2)}-q_m^{(2)} k_n^{(1)}\right) \sin ((m-n) \theta) \\ =\left(q_m^{(1)} k_n^{(1)}+q_m^{(2)} k_n^{(2)}\right) \begin{array}{c} \cos ((m-n) \theta)-\left(q_m^{(2)} k_n^{(1)}-q_m^{(1)} k_n^{(2)}\right) \sin ((m-n) \theta) \\ =g\left(x_m, x_n, m-n\right) \end{array} \end{gathered}

至此证明了位置m的query向量和位置n的key向量的内积就是函数g(xm,xn,mn)g(x_m, x_n, m-n)

RoPE位置编码只作用于q和k,不直接作用于v,但注意力权重(由 RoPE 的 q/k 计算而来)包含位置信息,最终 v 的加权输出会间接包含位置信息。

由于内积满足线性叠加性,因此任意偶数维的RoPE,我们都可以表示为二维情形的拼接

由于Rm的稀疏性,所以直接用矩阵乘法来实现会很浪费算力,所以在计算时采用逐位相乘再相加的方式进行:

(q0q1q2q3qd2qd1)(cosmθ0cosmθ0cosmθ1cosmθ1cosmθd/21cosmθd/21)+(q1q0q3q2qd1qd2)(sinmθ0sinmθ1sinmθ1sinmθ1sinmθd/21sinmθd/21)\left(\begin{array}{c} q_0 \\ q_1 \\ q_2 \\ q_3 \\ \vdots \\ q_{d-2} \\ q_{d-1} \end{array}\right) \otimes\left(\begin{array}{c} \cos m \theta_0 \\ \cos m \theta_0 \\ \cos m \theta_1 \\ \cos m \theta_1 \\ \vdots \\ \cos m \theta_{d / 2-1} \\ \cos m \theta_{d / 2-1} \end{array}\right)+\left(\begin{array}{c} -q_1 \\ q_0 \\ -q_3 \\ q_2 \\ \vdots \\ -q_{d-1} \\ q_{d-2} \end{array}\right) \otimes\left(\begin{array}{c} \sin m \theta_0 \\ \sin m \theta_1 \\ \sin m \theta_1 \\ \sin m \theta_1 \\ \vdots \\ \sin m \theta_{d / 2-1} \\ \sin m \theta_{d / 2-1} \end{array}\right)

其中⊗是逐位对应相乘。

参考文章:

欢迎关注我的其它发布渠道