Skip to content

Better-Relative Positional Encoding

1. 故事背景:圖書館的藏書迷航

想像你是一座大型圖書館的管理員,館內藏書數以萬計。為了方便管理,每本書都有一個固定的「絕對位置」:例如「3 號書架,第 5 層,第 2 本」。這個編號就像 Transformer 最初使用的「絕對位置編碼」,每個詞(書本)都有一個獨一無二的座標。

有一天,一位讀者想找與某本書相關的參考文獻。他手裡拿著一本書,想找到書架上與它內容相近的其他書。如果只靠絕對位置,他必須知道每一本書的書架號,然後手動比對哪些書在附近。更糟的是,當圖書館擴建,增加了全新的 100 號書架時,管理員從未見過這個編號,完全不知道該如何引導讀者去找新書架上的書。這就是絕對位置編碼的困境:模型必須記住每個位置的具體數字,卻無法理解位置之間的「相對關係」,也無法應對從未見過的長度。

1-1 絕對位置的孤立感

  • 每個位置的向量都是獨立的,模型無法直接從「第 3 本」的位置推導出它與「第 5 本」的距離關係,就像只知道門牌號卻不知道兩棟房子相距多遠。
  • 當輸入序列長度超過訓練時的最大長度,模型就無法處理新出現的位置,如同圖書館突然冒出一個 100 號書架,管理員從未學過它的任何資訊,只好束手無策。

2. 解決的痛點:從粗略的相對距離到連續的空間感知

為了解決絕對位置的問題,研究人員提出了「相對位置編碼」。這就像在圖書館裡不再只靠書架號,而是引入一個「相對距離指示器」:當你拿著一本書時,指示器會告訴你另一本書與它相隔幾個書架。但最初的相對位置編碼有一個缺點:它通常只將距離分成幾個離散的區間,例如「相鄰」、「間隔 1-2 個書架」、「間隔 3-5 個書架」、「間隔 5 個以上」。這種分段雖然比絕對位置好,但仍然不夠精細,就像用「近、中、遠」來描述距離,失去了許多細節。

Better-Relative Positional Encoding 則更進一步,它將相對距離建模為一個連續且平滑的函數,就像給圖書館裝上了精密的 GPS 定位系統,可以精確到公分,而且距離越遠的書,其影響力會自然地、平滑地減弱。

2-1 更精細的相對距離感知

在傳統的相對位置編碼中,距離 2 和距離 3 可能被歸為同一類(例如「間隔 1-2 個書架」與「間隔 3-5 個書架」),導致模型無法區分它們的細微差異。但在自然語言中,間隔一個詞和間隔兩個詞的語義影響往往不同。Better-Relative 用連續函數取代離散區間,讓模型能夠感知到任意精度的距離變化。

例如,對於兩個位置 \(i\)\(j\),它們的相對距離 \(d = i - j\)。Better-Relative 使用一個連續的函數 \(f(d)\) 來表示它們的相對位置偏差,這個 \(f(d)\)\(d\) 的細微變化敏感,且通常是可學習的: $$ f(d) = \text{SomeContinuousFunction}(d) $$ 這樣,當 \(d\) 從 2 變成 3 時,\(f(d)\) 會平滑變化,而不是跳變。

2-2 自然的距離衰減效應

在語言中,相距越遠的詞,彼此的影響通常越小。傳統的離散區間很難體現這種平滑的衰減,往往在區間邊界處產生不連續。Better-Relative 透過設計 \(f(d)\) 讓它隨著 \(|d|\) 增大而逐漸趨近於零,例如: $$ \lim_{|d| \to \infty} f(d) = 0 $$ 這種設計讓模型能夠自動學習到長距離詞之間應該有較低的注意力權重,而不需要人工設定一個硬性的截斷閾值。

2-3 卓越的長度外推能力

因為 Better-Relative 學到的是相對距離 \(d\) 的連續函數,而不是針對特定位置的查詢表,所以即使遇到訓練時從未見過的長距離(例如 \(d = 200\)),模型仍然可以根據函數 \(f(d)\) 的趨勢給出合理的值。這就像一個學過連續距離感的人,即使被問到「相距 200 公尺的兩個物體應該有什麼關係」,也能根據常識推斷它們幾乎不相關。這種能力讓 Transformer 在處理超長序列時表現得更穩定,不再被訓練時的最大長度束縛。

3 實際案例運算

我們用一個具體的例子,一步步計算帶有 Better-Relative 位置編碼的自注意力機制。假設有一個句子包含 3 個詞,每個詞用長度為 4 的向量表示。我們將展示從詞嵌入到最終輸出的完整前向傳播,並在關鍵步驟解釋 Better-Relative 解決的痛點。

3-1 輸入與權重初始化

詞嵌入矩陣 \(X\) 形狀為 \(3 \times 4\)(3 個詞,嵌入維度 4):

\[ X\ (shape=3\times 4)= \begin{bmatrix} 1 & 0 & 1 & 0 \\ 0 & 1 & 0 & 1 \\ 1 & 1 & 0 & 0 \end{bmatrix} \]

為了生成查詢 Q、鍵 K、值 V,我們定義三個權重矩陣(為簡化計算,使用簡單整數):

\[ W_Q\ (shape=4\times 3)= \begin{bmatrix} 1 & 0 & 1 \\ 0 & 1 & 0 \\ 1 & 0 & 0 \\ 0 & 1 & 1 \end{bmatrix} \]
\[ W_K\ (shape=4\times 3)= \begin{bmatrix} 0 & 1 & 0 \\ 1 & 0 & 1 \\ 0 & 0 & 1 \\ 1 & 1 & 0 \end{bmatrix} \]
\[ W_V\ (shape=4\times 3)= \begin{bmatrix} 1 & 0 & 0 \\ 0 & 1 & 0 \\ 0 & 0 & 1 \\ 1 & 0 & 1 \end{bmatrix} \]

3-1-1 計算 Q、K、V

計算 \(Q = X W_Q\)

\[ Q\ (shape=3\times 3)= X \cdot W_Q = \begin{bmatrix} 1 & 0 & 1 & 0 \\ 0 & 1 & 0 & 1 \\ 1 & 1 & 0 & 0 \end{bmatrix} \cdot \begin{bmatrix} 1 & 0 & 1 \\ 0 & 1 & 0 \\ 1 & 0 & 0 \\ 0 & 1 & 1 \end{bmatrix} = \begin{bmatrix} 2 & 0 & 1 \\ 0 & 2 & 1 \\ 1 & 1 & 1 \end{bmatrix} \]

計算 \(K = X W_K\)

\[ K\ (shape=3\times 3)= X \cdot W_K = \begin{bmatrix} 1 & 0 & 1 & 0 \\ 0 & 1 & 0 & 1 \\ 1 & 1 & 0 & 0 \end{bmatrix} \cdot \begin{bmatrix} 0 & 1 & 0 \\ 1 & 0 & 1 \\ 0 & 0 & 1 \\ 1 & 1 & 0 \end{bmatrix} = \begin{bmatrix} 0 & 1 & 1 \\ 2 & 1 & 1 \\ 1 & 1 & 1 \end{bmatrix} \]

計算 \(V = X W_V\)

\[ V\ (shape=3\times 3)= X \cdot W_V = \begin{bmatrix} 1 & 0 & 1 & 0 \\ 0 & 1 & 0 & 1 \\ 1 & 1 & 0 & 0 \end{bmatrix} \cdot \begin{bmatrix} 1 & 0 & 0 \\ 0 & 1 & 0 \\ 0 & 0 & 1 \\ 1 & 0 & 1 \end{bmatrix} = \begin{bmatrix} 1 & 0 & 1 \\ 1 & 1 & 1 \\ 1 & 1 & 0 \end{bmatrix} \]

3-2 計算未加位置信息的注意力分數

首先求 \(K\) 的轉置 \(K^T\)

\[ K^T\ (shape=3\times 3)= \begin{bmatrix} 0 & 2 & 1 \\ 1 & 1 & 1 \\ 1 & 1 & 1 \end{bmatrix} \]

計算 \(S = Q K^T\)

\[ S\ (shape=3\times 3)= Q \cdot K^T = \begin{bmatrix} 2 & 0 & 1 \\ 0 & 2 & 1 \\ 1 & 1 & 1 \end{bmatrix} \cdot \begin{bmatrix} 0 & 2 & 1 \\ 1 & 1 & 1 \\ 1 & 1 & 1 \end{bmatrix} = \begin{bmatrix} 1 & 5 & 3 \\ 3 & 3 & 3 \\ 2 & 4 & 3 \end{bmatrix} \]

此處的 \(S\) 僅基於詞義內容,完全忽略了詞與詞之間的位置關係。

3-3 加入 Better-Relative 位置偏置

Better-Relative 使用一個連續函數來生成位置偏置矩陣 \(B\),這裡我們選用 \(b_{ij} = -0.5 \times |i-j|\),其中 \(i, j\) 是詞的位置索引(從 0 開始)。計算所有 \(i, j\) 對應的 \(b_{ij}\)

  • \(i=0, j=0\): \(|0-0|=0 \Rightarrow b=0\)
  • \(i=0, j=1\): \(|0-1|=1 \Rightarrow b=-0.5\)
  • \(i=0, j=2\): \(|0-2|=2 \Rightarrow b=-1.0\)
  • \(i=1, j=0\): \(|1-0|=1 \Rightarrow b=-0.5\)
  • \(i=1, j=1\): \(|1-1|=0 \Rightarrow b=0\)
  • \(i=1, j=2\): \(|1-2|=1 \Rightarrow b=-0.5\)
  • \(i=2, j=0\): \(|2-0|=2 \Rightarrow b=-1.0\)
  • \(i=2, j=1\): \(|2-1|=1 \Rightarrow b=-0.5\)
  • \(i=2, j=2\): \(|2-2|=0 \Rightarrow b=0\)

得到 \(B\) 矩陣:

\[ B\ (shape=3\times 3)= \begin{bmatrix} 0 & -0.5 & -1.0 \\ -0.5 & 0 & -0.5 \\ -1.0 & -0.5 & 0 \end{bmatrix} \]

\(B\) 加到 \(S\) 上得到新的注意力分數 \(S'\)

\[ S'\ (shape=3\times 3)= S + B = \begin{bmatrix} 1 & 5 & 3 \\ 3 & 3 & 3 \\ 2 & 4 & 3 \end{bmatrix} + \begin{bmatrix} 0 & -0.5 & -1.0 \\ -0.5 & 0 & -0.5 \\ -1.0 & -0.5 & 0 \end{bmatrix} = \begin{bmatrix} 1 & 4.5 & 2 \\ 2.5 & 3 & 2.5 \\ 1 & 3.5 & 3 \end{bmatrix} \]

痛點:連續函數生成的 \(B\) 讓模型能根據精確的相對距離調整注意力分數,距離越遠懲罰越大,且變化平滑。相比離散區間(例如距離 1 和 2 使用相同偏置),連續設計避免了邊界跳變,使模型能感知到更細微的位置差異。

3-4 計算注意力權重(Softmax)

\(S'\) 的每一行進行 softmax 歸一化。先計算每行的指數值,再除以行和。

行 0: \([1, 4.5, 2]\)
指數: \(e^1 \approx 2.718\), \(e^{4.5} \approx 90.017\), \(e^2 \approx 7.389\)
\(\approx 100.124\)
softmax: \([2.718/100.124, 90.017/100.124, 7.389/100.124] \approx [0.0271, 0.899, 0.0738]\)

行 1: \([2.5, 3, 2.5]\)
指數: \(e^{2.5} \approx 12.182\), \(e^3 \approx 20.085\), \(e^{2.5} \approx 12.182\)
\(\approx 44.449\)
softmax: \([12.182/44.449, 20.085/44.449, 12.182/44.449] \approx [0.274, 0.452, 0.274]\)

行 2: \([1, 3.5, 3]\)
指數: \(e^1 \approx 2.718\), \(e^{3.5} \approx 33.115\), \(e^3 \approx 20.085\)
\(\approx 55.918\)
softmax: \([2.718/55.918, 33.115/55.918, 20.085/55.918] \approx [0.0486, 0.592, 0.359]\)

得到注意力權重矩陣 \(A\)

\[ A\ (shape=3\times 3)= \begin{bmatrix} 0.0271 & 0.899 & 0.0738 \\ 0.274 & 0.452 & 0.274 \\ 0.0486 & 0.592 & 0.359 \end{bmatrix} \]

痛點:透過 softmax,模型將連續的相對位置偏置轉化為對不同詞的關注程度,距離近的詞獲得更高權重(例如第 0 個詞對第 1 個詞的權重 0.899),距離遠的詞權重較低,且這種衰減是平滑的,符合真實語境中「遠距離詞影響較小」的直覺。

3-5 計算最終輸出

輸出 \(O = A V\)

\[ O\ (shape=3\times 3)= A \cdot V = \begin{bmatrix} 0.0271 & 0.899 & 0.0738 \\ 0.274 & 0.452 & 0.274 \\ 0.0486 & 0.592 & 0.359 \end{bmatrix} \cdot \begin{bmatrix} 1 & 0 & 1 \\ 1 & 1 & 1 \\ 1 & 1 & 0 \end{bmatrix} \]

逐行計算:

  • 第 0 行:
    \([0.0271\times1 + 0.899\times1 + 0.0738\times1,\ \ 0.0271\times0 + 0.899\times1 + 0.0738\times1,\ \ 0.0271\times1 + 0.899\times1 + 0.0738\times0]\)
    \(= [1,\ 0.9728,\ 0.9261]\)

  • 第 1 行:
    \([0.274\times1 + 0.452\times1 + 0.274\times1,\ \ 0.274\times0 + 0.452\times1 + 0.274\times1,\ \ 0.274\times1 + 0.452\times1 + 0.274\times0]\)
    \(= [1,\ 0.726,\ 0.726]\)

  • 第 2 行:
    \([0.0486\times1 + 0.592\times1 + 0.359\times1,\ \ 0.0486\times0 + 0.592\times1 + 0.359\times1,\ \ 0.0486\times1 + 0.592\times1 + 0.359\times0]\)
    \(= [1,\ 0.951,\ 0.6406]\)

最終輸出 \(O\)

\[ O\ (shape=3\times 3)= \begin{bmatrix} 1 & 0.9728 & 0.9261 \\ 1 & 0.726 & 0.726 \\ 1 & 0.951 & 0.6406 \end{bmatrix} \]

這個輸出矩陣已經融入了相對位置信息,每個詞的表示都根據與其他詞的距離進行了加權調整。相較於不使用位置編碼或使用離散位置編碼,Better-Relative 讓模型能夠以連續、平滑的方式感知距離,從而在長序列上表現出更強的泛化能力。

痛點:最終輸出保留了詞義內容,同時透過連續相對位置編碼,使模型能更細膩地捕捉語序結構,避免了離散編碼帶來的邊界突變和長度外推問題。