LayerNorm
LayerNorm(层归一化)是一种在深度学习中常用的归一化技术。
计算公式:
其中:
- x:输入向量。
- E[x]:输入向量均值。
- Var[x]:输入向量方差。
- γ:可学习的缩放参数。
- β:可学习的偏移参数。
优点:
- 完全归一化(均值 0,方差 1),对分布偏移的矫正更彻底;
- 偏移参数beta可灵活调整分布中心,适配更多任务(如 CV、NLP 基础模型)。
缺点:
- 均值计算增加计算量,大模型场景下效率偏低;
- 中心化可能破坏特征的 “绝对偏移信息”(如语言模型中词向量的固有偏移);
- 小批量 / 低维特征下,均值估计易受噪声影响,导致稳定性下降。
RMSNorm
RMSNorm( Root Mean Square Layer Normalization )核心思想是通过对输入向量进行缩放归一化,以提升训练稳定性和效率。
计算公式:
其中:
- x:输入向量。
- mean(x^2):向量元素的平方均值。
- ϵ:极小常数( 如 10e−8 ),防止分母为零。
- γ:可学习的缩放参数。
代码实现:
class RMSNorm(nn.Module): |
优点:
- 计算高效(少均值步骤),显存 / 算力开销更低,适配大模型(如 LLaMA、GPT-3);
- 保留特征的原始均值偏移,更贴合语言模型的语义特性(词向量的偏移包含语义信息);
- 无中心化带来的噪声放大,训练更稳定,收敛速度更快。
缺点:
- 仅归一化方差,对分布中心的偏移无矫正作用;
- 无偏移参数,适配性稍弱(但实践中可通过其他层补偿)。
关键区别
| 维度 | LayerNorm | RMSNorm |
|---|---|---|
| 中心化操作 | 有(减均值) | 无(直接用原始值计算均方根) |
| 可学习参数 | gamma(缩放)+ beta(偏移) | 仅gamma(缩放) |
| 计算复杂度 | 稍高(多一步均值计算) | 更低(少均值计算,浮点运算量减少~20%) |
| 数值稳定性 | 均值可能放大噪声(尤其小批量) | 避免均值偏移,稳定性更好(尤其大模型) |
| 保留的信息 | 消除均值 + 方差偏移,改变分布中心 | 仅消除方差偏移,保留原始分布中心 |
| 梯度传播 | 均值计算引入额外依赖,梯度稍复杂 | 梯度路径更简洁,训练更高效 |
💡LayerNorm 是 “中心化 + 标准化”,RMSNorm 是 “仅标准化”。
浮点运算量(FLOPs)对比:
| 操作类型 | LayerNorm(d 维向量) | RMSNorm(d 维向量) | 差异(RMSNorm 减少) |
|---|---|---|---|
| 加法 | d(求 μ) + d(求方差)= 2d | d(求平方和)= d | d 次加法 |
| 减法 | d(中心化 xi - mu) | 0(无中心化) | d 次减法 |
| 乘法(含平方) | d(方差中 xi²) | d(RMS 中 xi²) | 0 次 |
| 除法 | 2 次(μ = sum/d;σ² = sum/d) | 1 次(sum/d) | 1 次除法 |
| 开方 | 1 次(sqrt (σ²)) | 1 次(sqrt (RMS)) | 0 次 |
| 仿射变换 | d(gamma 乘) + d(beta 加)= 2d | d(gamma 乘)= d | d 次加法 |
总运算量对比:
- LayerNorm 总 FLOPs ≈ 4d + 3(简化后,忽略常数项)
- RMSNorm 总 FLOPs ≈ 2d + 2(简化后,忽略常数项)
- RMSNorm 比 LayerNorm 少 ≈ 2d 次运算(当 d 很大时,比如 d=1024,少 2048 次运算;d=4096 时,少 8192 次运算)
直观理解(d=1024 时的实际差异):
- LayerNorm:21024(加法) + 1024(减法) + 21024(仿射)= 5120 次核心运算
- RMSNorm:1024(加法) + 0(减法) + 1024(仿射)= 2048 次核心运算
- 计算量减少约 60%(实际工程中因内存访问优化,实测减少~20%~30%,但仍显著优于 LayerNorm)