0%

大模型归一化方法

LayerNorm

LayerNorm(层归一化)是一种在深度学习中常用的归一化技术。

计算公式:

y=xE[x]Var[x]+ϵγ+βy=\frac{x-\mathrm{E} [ x ]} {\sqrt{\mathrm{V a r} [ x ]+\epsilon}} * \gamma+\beta

其中:

  • x:输入向量。
  • E[x]:输入向量均值。
  • Var[x]:输入向量方差。
  • γ:可学习的缩放参数。
  • β:可学习的偏移参数。

优点

  • 完全归一化(均值 0,方差 1),对分布偏移的矫正更彻底;
  • 偏移参数beta可灵活调整分布中心,适配更多任务(如 CV、NLP 基础模型)。

缺点

  • 均值计算增加计算量,大模型场景下效率偏低;
  • 中心化可能破坏特征的 “绝对偏移信息”(如语言模型中词向量的固有偏移);
  • 小批量 / 低维特征下,均值估计易受噪声影响,导致稳定性下降。

RMSNorm

RMSNorm( Root Mean Square Layer Normalization )核心思想是通过对输入向量进行缩放归一化,以提升训练稳定性和效率。

计算公式:

RMSNorm(x)=γxmean(x2)+ϵ\mathrm{R M S N o r m} ( x )=\gamma\odot\frac{x} {\sqrt{\mathrm{m e a n} ( x^{2} )+\epsilon}}

其中:

  • x:输入向量。
  • mean(x^2):向量元素的平方均值。
  • ϵ:极小常数( 如 10e−8 ),防止分母为零。
  • γ:可学习的缩放参数。

代码实现:

class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.gamma = nn.Parameter(torch.ones(dim)) # 可学习参数γ

def _norm(self, x: torch.Tensor):
# 计算平方均值的根 (RMS)
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def forward(self, x: torch.Tensor):
return self.gamma * self._norm(x.float()).type_as(x)

优点

  • 计算高效(少均值步骤),显存 / 算力开销更低,适配大模型(如 LLaMA、GPT-3);
  • 保留特征的原始均值偏移,更贴合语言模型的语义特性(词向量的偏移包含语义信息);
  • 无中心化带来的噪声放大,训练更稳定,收敛速度更快。

缺点

  • 仅归一化方差,对分布中心的偏移无矫正作用;
  • 无偏移参数,适配性稍弱(但实践中可通过其他层补偿)。

关键区别

维度 LayerNorm RMSNorm
中心化操作 有(减均值) 无(直接用原始值计算均方根)
可学习参数 gamma(缩放)+ beta(偏移) 仅gamma(缩放)
计算复杂度 稍高(多一步均值计算) 更低(少均值计算,浮点运算量减少~20%)
数值稳定性 均值可能放大噪声(尤其小批量) 避免均值偏移,稳定性更好(尤其大模型)
保留的信息 消除均值 + 方差偏移,改变分布中心 仅消除方差偏移,保留原始分布中心
梯度传播 均值计算引入额外依赖,梯度稍复杂 梯度路径更简洁,训练更高效

💡LayerNorm 是 “中心化 + 标准化”,RMSNorm 是 “仅标准化”。

浮点运算量(FLOPs)对比:

操作类型 LayerNorm(d 维向量) RMSNorm(d 维向量) 差异(RMSNorm 减少)
加法 d(求 μ) + d(求方差)= 2d d(求平方和)= d d 次加法
减法 d(中心化 xi - mu) 0(无中心化) d 次减法
乘法(含平方) d(方差中 xi²) d(RMS 中 xi²) 0 次
除法 2 次(μ = sum/d;σ² = sum/d) 1 次(sum/d) 1 次除法
开方 1 次(sqrt (σ²)) 1 次(sqrt (RMS)) 0 次
仿射变换 d(gamma 乘) + d(beta 加)= 2d d(gamma 乘)= d d 次加法

总运算量对比:

  • LayerNorm 总 FLOPs ≈ 4d + 3(简化后,忽略常数项)
  • RMSNorm 总 FLOPs ≈ 2d + 2(简化后,忽略常数项)
  • RMSNorm 比 LayerNorm 少 ≈ 2d 次运算(当 d 很大时,比如 d=1024,少 2048 次运算;d=4096 时,少 8192 次运算)

直观理解(d=1024 时的实际差异):

  • LayerNorm:21024(加法) + 1024(减法) + 21024(仿射)= 5120 次核心运算
  • RMSNorm:1024(加法) + 0(减法) + 1024(仿射)= 2048 次核心运算
  • 计算量减少约 60%(实际工程中因内存访问优化,实测减少~20%~30%,但仍显著优于 LayerNorm)

欢迎关注我的其它发布渠道