論文地址:https://arxiv.org/abs/2101.03961
作者:Google 團隊
引言#
在追求更強大、更智能的人工智能模型的道路上,“越大越好” 似乎已成為一條公認的法則。然而,模型的 “大” 往往伴隨著驚人的計算成本和訓練難度。傳統的深度學習模型在處理每個輸入時,都会動用其全部參數,這使得擴展模型規模變得異常昂貴。
“專家混合(Mixture of Experts, MoE)” 模型曾被視為一種有前景的解決方案 —— 它允許模型為不同的輸入激活不同的參數子集(即 “專家”),從而在擁有海量參數的同時保持計算成本恆定。但 MoE 模型的複雜性、高昂的通信開銷以及棘手的訓練不穩定性,使其難以得到廣泛應用。
谷歌研究團隊在論文《Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity》中帶來了突破。他們提出的 Switch Transformer 架構,通過一系列優雅而高效的設計,徹底改變了我們構建和訓練超大規模語言模型的方式。
本文將帶你深入了解 Switch Transformer 的核心思想:
- 它是如何簡化專家路由機制的?
- 它採用了哪些創新技術來克服訓練不穩定性,並首次實現了使用低精度(bfloat16)訓練大型稀疏模型?
- 這種架構帶來了多大的性能提升?
規模化困境:密集模型 vs 稀疏 MoE 範式#
傳統的密集型模型,以 Transformer 架構為例,自注意力 Module 的計算複雜度和存儲占用為 O (n2),隨著輸入序列的增加,計算操作成平方級增長,對於長文本任務來說對成本高昂,成為了 LLM 的顯著瓶頸。更為關鍵的是密集模型 (Dense Model) 的參數兩與每個樣本的計算成本 (FLOPs) 緊密耦合 -- 模型越大,計算越貴。混合專家(MoE)模型應運而生,它提出了一種富有遠見的 “稀疏激活” 策略:為每個輸入 token 動態選擇並激活網絡中的一小部分參數子集(即 “專家”)。這種設計的核心魅力在於,它允許模型參數規模可以擴展到極大(例如數萬億級別),而處理每個 token 的實際計算量卻能保持相對恆定,從而在理論上打破了 “模型越大、計算越貴” 的魔咒。然而,早期的 MoE 模型在實踐中卻面臨著模型設計複雜、專家間通信成本高昂以及訓練過程不穩定等一系列嚴峻挑戰。正是在這樣的背景下,Switch Transformer 登場,旨在直面並攻克這些難題,讓 MoE 的巨大潛力得以真正釋放。
架構的深度解析#
Switch 的創新:簡化的稀疏路由#
傳統的 MoE 一個 token 對應激活 Top k 個專家,這由超參數 k 決定。Switch Transformer 簡化了稀疏路由,一個 token 只對應激活一個專家。首先針對輸入詞元 (token) ,有路由參數(N 為專家數):
得到對應每一個專家的得分。然後經過 softmax 函數進行歸一化,得到路由到專家 i 的概率:
最終 MoE 層的輸出為:
其中 T 表示路由到的專家,E 為對應專家對 token 的運算操作。
Switch Transformer 的單專家路由相比傳統的 Top k 專家路由,減少了計算量,降低了專家批處理大小的需求並降低了通信開銷。
管理專家負載:容量與 Token 處理#
專家容量 (Expert Capacity) 是一個專家能處理的 token 數量,其衡量標準由公式計算:
其中為容量因子,用於幫助解決 token 溢出 (token overflow),默認為 1,其值越大表明緩衝能力越好,能夠容納更多的 token 運算。專家不能承載過多的 token 通過殘差連接 (Residual Connection) 直接傳遞到下一層 (也叫 token 丟棄)。同時需要注意的是,更大的容量因子雖然減少了丟棄的 token,但增加了計算和內存開銷。
確保公平性:可微分的負載均衡損失#
MoE 的路由機制 (即稀疏激活) 相比 Dense 模型有了很多的優勢,但是存在負載不均衡的情況:路由過程中可能存在某些專家被多次路由選中而部分專家存在無參與的情況,這樣違背了我們設計 N 個專家的初衷。所以需要一種策略進行負載均衡,讓每一個專家都能均衡參與到稀疏激活的運算中。Switch Transformer 提出了可微的輔助負載均衡損失函數(Auxiliary Load Balancing Loss)來解決這個問題。
,其中為控制損失的超參數,N 為專家總數,為實際專家 i 分配到的 token 比例,表示專家 i 被路由選中的概率
;
由損失函數可以直觀的看出損失越小表明負載越均衡,且當和相等為時,損失為。在論文中取為 0.01 時效果最好。
賦能穩定高效的大規模訓練#
混合精度訓練策略#
文中作者提出使用 bfloat16 精度訓練稀疏模型訓練非常不穩定,因此考慮使用混合精度訓練。參考先前的工作中通常使用 float32 精度訓練的情況,Switch Transformer 選擇使用 float32 和 bfloat16 混合訓練的策略:在路由函數中將輸入強制轉換為 float32 精度進行關鍵運算(如 softmax),但在昂貴的 all-to-all 通信前,將張量重新轉換成 bfloat16。實驗發現:在保持與 float32 精度的穩定性的同時獲得了接近 bfloat16 的訓練速度。
進一步探究穩定訓練:更小的參數初始化#
作者發現模型參數初始化對 Switch Transformer 的訓練穩定性有著關鍵的作用。通過階段正態分佈,即使用均值,標準差,其中 s 表示縮放超參數,取值為 0.1,n 是連接到該層(權重張量所應用的神經元)的輸入特徵數量。
微調中過擬合的抑制:Dropout#
文中特別強調了在下游任務訓練數據較少的情況下 Switch Transformer 出現過擬合的風險較高。因此採用了深度學習中使用最為廣泛的 Dropout,具體為:在 Expert Layer 中應用更高的 dropout 值,Non Expert Layer 引用較小的 dropout 值,並比較了 T5-Base 模型和 Switch-Base 模型使用不同的 dropout 值在 GLUE、CNNDM、SQuAD 和 SuperGLUE 數據集上微調的結果,最終得出結論。
模型能力躍遷:蒸餾的應用#
文中作者將大規模稀疏模型的能力蒸餾到小參數量的模型上,這一舉措對於提升大型稀疏模型在實際場景中的部署可行性至關重要。畢竟,動輒數千億乃至萬億參數的龐然大物,其推理成本和硬件需求是很多應用難以承受的。Switch Transformer 論文中探討的蒸餾方法主要包括:
- 權重初始化技巧:研究者們發現,在蒸餾時,使用稀疏教師模型中的 “非專家層” 權重來初始化密集學生模型的對應層,能夠帶來一定的性能提升。這是因為即便整體模型架構不同(稀疏 vs 密集),許多基礎的非專家層(如注意力機制的某些部分、嵌入層等)在功能上具有相似性,其訓練好的權重對學生模型而言是寶貴的先驗知識。
- 損失函數組合:蒸餾過程不僅僅依賴於學生模型在標準數據集上的表現(硬標籤),還會學習教師模型的輸出概率分佈(軟標籤)。論文發現,將教師模型的輸出概率與真實標籤(ground truth)以一定的比例混合(例如,0.25 的教師概率 + 0.75 的真實標籤損失)作為學生模型的優化目標,能取得更好的蒸餾效果。
通過這些蒸餾技術,Switch Transformer 成功地將大型稀疏模型的約 30% 的性能增益遷移到了參數量壓縮數十倍甚至近百倍的小型密集模型中。例如,一個擁有數十億參數的 Switch-Base 模型,經過蒸餾後,可以在參數量減少 95% 以上的情況下,學生模型仍能保留原教師模型性能提升的近三成。這一成果不僅適用於預訓練階段的語言模型,同樣在針對特定下游任務(如 SuperGLUE)微調後的模型蒸餾中也得到了驗證,極大地增強了這些先進稀疏架構的實用價值。
卓越成就:SOTA 性能速覽#
憑藉其創新的架構和訓練策略,Switch Transformer 不僅在模型規模和訓練效率上取得了突破,更是在多項自然語言處理基準測試中取得了當時的 SOTA(State-of-the-Art,頂尖水平)或極具競爭力的成績:
- 空前的規模與效率:論文成功訓練了參數量高達 1.6 萬億的 Switch-C 模型 和與 T5-XXL(110 億參數)計算量相當但擁有 3950 億參數的 Switch-XXL 模型。更重要的是,這些巨大的稀疏模型展現了極高的訓練效率,例如 Switch-XXL 相較於 T5-XXL 實現了 4 倍的預訓練速度提升,而針對 T5-Base 等模型,Switch Transformer 甚至能達到 7 倍的預訓練加速,且這些都是在相同的計算資源下完成的。
- 下游任務的強勁表現:經過預訓練的 Switch Transformer 模型在微調後,於眾多下游任務中展現了強大實力。無論是在 GLUE、SuperGLUE 等綜合性語言理解基準,還是在 SQuAD 等閱讀理解任務,以及 XSum 等文本摘要任務上,Switch 模型均顯著優於其 FLOPs 匹配的 T5 密集基線模型。特別是在知識密集型的閉卷問答任務(如 TriviaQA, Natural Questions, WebQuestions)上,Switch Transformer 也取得了 SOTA 或接近 SOTA 的表現。
- 多語言能力的普適提升:Switch Transformer 架構的優勢並不僅限於單一語言。在包含 101 種語言的 mC4 數據集上進行訓練的 mSwitch-Base 模型,在所有參與評估的語言上均超越了強大的 mT5-Base 基線,顯示了其架構在多任務和多語言學習中的普適性和高效性。
值得注意的是,論文也坦誠地指出,儘管在預訓練困惑度等指標上取得了巨大進步,但在當時,將這些優勢完全轉化為某些特定下游任務(尤其是複雜推理任務)的 SOTA 性能,對於最大規模的模型而言,仍是一個持續探索的領域。
總結與深遠影響#
Switch Transformer 論文無疑是深度學習,特別是大規模語言模型發展史上的一個重要里程碑。它通過以下幾個核心貢獻,深刻影響了後續的研究方向:
- 簡化並普及了 MoE 架構:通過引入 “單專家路由”(k=1)機制,極大地簡化了傳統 MoE 的複雜性,使其更易於理解、實現和訓練。
- 攻克了大規模稀疏模型訓練的關鍵難題:提出並驗證了一系列如選擇性精度訓練、小尺度參數初始化、負載均衡損失和專家 Dropout 等創新技術,顯著提升了訓練的穩定性和效率。
- 證明了稀疏化的巨大潛力:以無可辯駁的實驗結果展示了稀疏激活模型在達到數萬億參數規模、同時保持計算效率方面的巨大潛力,為 “大力出奇跡” 提供了新的、更經濟的路徑。
- 推動了開源生態與後續創新:相關代碼(如基於 Mesh TensorFlow 和後續的 T5X/JAX)的開放,極大地促進了學術界和工業界對 MoE 及相關稀疏技術的探索和應用。
對後續工作的影響
-
MoE 成為主流技術路線的關鍵推手:Switch Transformer 的成功使得 MoE 從一種相對小眾的探索方向,逐漸成為構建頂尖大型語言模型(如 Mixtral 系列及其他受此啟發的模型)的關鍵技術之一。
-
激發路由機制的持續創新:雖然單專家路由簡潔高效,但也激發了對更動態、更智能路由算法的研究,例如如何讓模型自適應學習專家容量、如何根據任務或數據特性優化令牌分配等。
-
深化對稀疏模型微調與部署的理解:論文中關於微調和蒸餾的初步探索,為後續如何高效利用和部署這些龐大的稀疏模型指明了方向,催生了更多關於專家特化、剪枝、量化等方面的研究。
-
大規模訓練經驗的普惠:其在解決訓練不穩定性、優化通信、利用混合精度等方面的經驗,也為整個大規模 AI 模型的訓練領域提供了寶貴的借鑒。
總而言之,Switch Transformer 不僅是一系列模型的名稱,更代表了一種設計哲學和一套行之有效的方法論。它雄辯地證明了,通過精心的工程設計和算法創新,我們完全有能力駕馭擁有海量知識的稀疏巨獸,並讓它們以更簡單、更高效的方式服務於日益複雜的智能應用。也為後續 DeepSeek 對 MoE 的負載均衡策略改進等研究奠定了基礎。