RMSNorm
Prerequisite knowledge 先備知識
- 了解 Transformer 的殘差連接(residual connection)與「很深的堆疊」會讓訓練變得不穩定
- 知道常見的正規化(normalization)概念:把向量「縮放到比較穩定的尺度」
- 會看懂基本符號:平均值 \(\mu\)、方差 \(\sigma^2\)、維度 \(d\)、小常數 \(\epsilon\)
1. 故事背景
1-1 深度模型的老問題:訊號尺度在路上越走越失控
想像你在玩傳話遊戲:每一層都會把訊息「加工」一次,再用殘差把原訊息加回來。層數一多,訊號的整體大小(向量長度、能量)很容易忽大忽小,導致:
- 梯度不穩,訓練忽然爆掉或收斂很慢
- 不同 batch、不同 token 的激活尺度差異很大,優化器很難用同一套步伐學習
於是大家常在每個子模組前後加 LayerNorm,讓每一步的尺度更可控,Transformer 訓練才穩得住。
1-2 LayerNorm 很好用,但「每一步都要算平均與方差」很貴
LayerNorm 的核心是:先把向量減掉平均值(re-centering),再用標準差做縮放(re-scaling)。
在 2019 年,Zhang 與 Sennrich 提出一個關鍵觀察:LayerNorm 的「減平均(re-centering)不一定是必要的」,而且這一步會帶來額外計算與實作成本;他們因此提出 RMSNorm(Root Mean Square Layer Normalization)。 (arXiv)
1-3 RMSNorm 的想法:只做「縮放」,不做「搬到零均值」
RMSNorm 用 RMS(均方根)衡量向量大小,做純粹的縮放:
作者主張:很多網路結構只需要「對尺度不敏感」(re-scaling invariance)就夠了,不一定需要「對平移不敏感」(re-centering invariance);RMSNorm 因此更簡單也更快,並在多種模型上達到和 LayerNorm 相近的效果,還報告了不同模型上可觀的加速幅度。 (arXiv)
2. 解決的痛點
2-1 痛點一:LayerNorm 的額外成本,會拖慢大型模型的吞吐
把「減平均、算方差、再開根號」想成每一層都要多做一套統計工作;模型越大、層越深、序列越長,這個成本越顯眼。
RMSNorm 省掉「減平均」與與其相關的一些計算與反傳複雜度,因此更容易做出高效 kernel;也因此在現代 LLM 堆疊裡很常見(例如很多模型採用 pre-norm + RMSNorm 的配置)。 (arXiv)
2-2 痛點二:你不一定想把「平均值」洗掉,尤其在殘差訊號裡
用一個直覺例子來看差異。假設某層輸出向量 \(x=[3,4]\),\(d=2\):
- RMSNorm 只看能量大小
- LayerNorm 會先減平均變成零均值
你可以把「向量的平均值」想成殘差訊號中的某種共同偏移量:LayerNorm 會把它直接扣掉;RMSNorm 則保留方向,只把整體大小調到穩定範圍。近年的幾何觀點也強調:LayerNorm 的定義和「全 1 方向」有內在關聯,而 RMSNorm 的行為更像純粹控制向量長度。 (arXiv)
2-3 痛點三:訓練時尺度亂飄,RMSNorm 像「自動音量控制」
把每層輸出想成麥克風音量:
- 音量太大會爆音(梯度爆炸、數值不穩)
- 音量太小會聽不到(梯度消失、學不動)
RMSNorm 透過 \(\mathrm{RMS}(x)\) 把輸出自動縮放回「差不多的音量」,再用可學參數 \(g\) 決定每個維度應該放大或縮小多少;原論文也提到這種機制能帶來類似「隱式學習率調節」的效果,幫助穩定優化。 (arXiv)
2-4 小結:RMSNorm 到底解了什麼
- 更便宜的正規化:少一步 re-centering,實作與計算更簡單 (arXiv)
- 保留殘差訊號的某些特性:不強制零均值,主要控制「尺度」 (arXiv)
- 符合現代 LLM 的工程趨勢:大量模型採用 RMSNorm,甚至出現針對 RMSNorm 的更快實作研究 (arXiv)
3(實際案例運算)
3-1 真實場景設定:聊天機器人在 Transformer block 進入子模組前先做 RMSNorm
3-1-1 輸入 hidden states(2 個 token,hidden size=4)
我們把一句話切成 2 個 token(例如 token1=「今天」、token2=「下雨」),每個 token 都有 4 維的 hidden state。以 row 表示 token、以 column 表示特徵維度。
痛點:把多個 token 的表示整理成一個可批次處理的矩陣,讓後續層能用相同流程穩定處理序列資料。
3-1-2 設定 RMSNorm 的可學縮放參數 \(g\)(elementwise affine 的 scale)
RMSNorm 會用一個可學向量 \(g\)(每個 column 一個縮放係數)來調整每個維度的重要性;這裡先給一組「已訓練到一半」的數值。
痛點:保留「可學的幅度調整」,避免正規化把所有維度都強迫變得一樣重要,維持表徵能力。 (docs.pytorch.org)
3-2 計算每個 token 的 RMS(只對最後一個維度做)
3-2-1 RMSNorm 的核心公式(前向)
對每一個 row(每個 token 向量)計算:
並做(先除 RMS,再乘上 \(g\)):
這裡 \(d=4\),取 \(\epsilon=10^{-5}\)。 (docs.pytorch.org)
痛點:只用 RMS 控制「尺度」就能讓每層輸入幅度更穩定,同時比 LayerNorm 少做「減平均」等操作,降低正規化的計算負擔。 (arXiv)
3-2-2 把每個 token 的 RMS 算成一個 column 向量
token1(row1)平方和平均後再開根號:
- row1:\([2.0,-1.0,0.0,3.0]\)
- \(\frac{1}{4}(2.0^2+(-1.0)^2+0.0^2+3.0^2)=\frac{1}{4}(4+1+0+9)=3.5\)
- \(\mathrm{RMS}(row1)=\sqrt{3.5+10^{-5}}\approx 1.870831\)
token2(row2):
- row2:\([0.5,1.5,-2.0,1.0]\)
- \(\frac{1}{4}(0.5^2+1.5^2+(-2.0)^2+1.0^2)=\frac{1}{4}(0.25+2.25+4+1)=1.875\)
- \(\mathrm{RMS}(row2)=\sqrt{1.875+10^{-5}}\approx 1.369310\)
痛點:把每個 token 的向量長度量化成單一尺度因子,後面就能用同一套規則把不同 token 的幅度拉回可控範圍,避免激活值忽大忽小造成訓練不穩。
3-3 做 RMS 正規化(逐 row 除以 RMS)
3-3-1 逐 row 做 \(\frac{X}{\mathrm{RMS}(X)}\)(elementwise 除法)
計算後(四捨五入到小數點後 6 位):
痛點:讓每個 token 的向量整體幅度被「自動音量控制」,降低梯度爆炸或消失的風險,使深層 Transformer 更容易穩定訓練。 (arXiv)
3-3-2 乘上可學縮放 \(g\)(elementwise 乘法)
計算後:
痛點:在「穩定尺度」的同時,允許模型把某些 column 放大或縮小,避免過度正規化導致資訊被壓扁、表徵能力下降。 (docs.pytorch.org)
3-4 接到真實子模組:線性投影(例如注意力的 Q/K/V 或 FFN 的第一層)
3-4-1 設定線性層權重矩陣 \(W\)
這裡用一個把 hidden size=4 投影到 3 維的線性層(真實 Transformer 裡可能是 attention 的投影或 FFN 的投影)。
痛點:把已經尺度穩定的 token 表徵映射到任務需要的子空間,讓後續模組能用更合適的維度組合做運算(例如注意力打分或非線性變換)。
3-4-2 做矩陣乘法:\(Y\cdot W=Z\)
痛點:因為前面 RMSNorm 已經把每個 token 的幅度穩住,線性層比較不會遇到「輸入尺度飄動」導致輸出突然過大或過小,讓整個 block 的前向訊號更可控、更穩定。 (arXiv)