论文地址:https://arxiv.org/abs/2101.03961
作者:Google 团队
引言#
在追求更强大、更智能的人工智能模型的道路上,“越大越好” 似乎已成为一条公认的法则。然而,模型的 “大” 往往伴随着惊人的计算成本和训练难度。传统的深度学习模型在处理每个输入时,都会动用其全部参数,这使得扩展模型规模变得异常昂贵。
“专家混合(Mixture of Experts, MoE)” 模型曾被视为一种有前景的解决方案 —— 它允许模型为不同的输入激活不同的参数子集(即 “专家”),从而在拥有海量参数的同时保持计算成本恒定。但 MoE 模型的复杂性、高昂的通信开销以及棘手的训练不稳定性,使其难以得到广泛应用。
谷歌研究团队在论文《Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity》中带来了突破。他们提出的 Switch Transformer 架构,通过一系列优雅而高效的设计,彻底改变了我们构建和训练超大规模语言模型的方式。
本文将带你深入了解 Switch Transformer 的核心思想:
- 它是如何简化专家路由机制的?
- 它采用了哪些创新技术来克服训练不稳定性,并首次实现了使用低精度(bfloat16)训练大型稀疏模型?
- 这种架构带来了多大的性能提升?
规模化困境:密集模型 vs 稀疏 MoE 范式#
传统的密集型模型,以 Transformer 架构为例,自注意力 Module 的计算复杂度和存储占用为 O (n2),随着输入序列的增加,计算操作成平方级增长,对于长文本任务来说对成本高昂,成为了 LLM 的显著瓶颈。更为关键的是密集模型 (Dense Model) 的参数两与每个样本的计算成本 (FLOPs) 紧密耦合 -- 模型越大,计算越贵。混合专家(MoE)模型应运而生,它提出了一种富有远见的 “稀疏激活” 策略:为每个输入 token 动态选择并激活网络中的一小部分参数子集(即 “专家”)。这种设计的核心魅力在于,它允许模型参数规模可以扩展到极大(例如数万亿级别),而处理每个 token 的实际计算量却能保持相对恒定,从而在理论上打破了 “模型越大、计算越贵” 的魔咒。 然而,早期的 MoE 模型在实践中却面临着模型设计复杂、专家间通信成本高昂以及训练过程不稳定等一系列严峻挑战。正是在这样的背景下,Switch Transformer 登场,旨在直面并攻克这些难题,让 MoE 的巨大潜力得以真正释放。
架构的深度解析#
Switch 的创新:简化的稀疏路由#
传统的 MoE 一个 token 对应激活 Top k 个专家,这由超参数 k 决定。Switch Transformer 简化了稀疏路由,一个 token 只对应激活一个专家。首先针对输入词元 (token) ,有路由参数(N 为专家数):
得到对应每一个专家的得分。然后经过 softmax 函数进行归一化,得到路由到专家 i 的概率:
最终 MoE 层的输出为:
其中 T 表示路由到的专家,E 为对应专家对 token 的运算操作。
Switch Transformer 的单专家路由相比传统的 Top k 专家路由,减少了计算量,降低了专家批处理大小的需求并降低了通信开销。
管理专家负载:容量与 Token 处理#
专家容量 (Expert Capacity) 是一个专家能处理的 token 数量,其衡量标准由公式计算:
其中为容量因子,用于帮助解决 token 溢出 (token overflow),默认为 1,其值越大表明缓冲能力越好,能够容纳更多的 token 运算。专家不能承载过多的 token 通过残差连接 (Residual Connection) 直接传递到下一层 (也叫 token 丢弃)。同时需要注意的是,更大的容量因子虽然减少了丢弃的 token,但增加了计算和内存开销。
确保公平性:可微分的负载均衡损失#
MoE 的路由机制 (即稀疏激活) 相比 Dense 模型有了很多的优势,但是存在负载不均衡的情况:路由过程中可能存在某些专家被多次路由选中而部分专家存在无参与的情况,这样违背了我们设计 N 个专家的初衷。所以需要一种策略进行负载均衡,让每一个专家都能均衡参与到稀疏激活的运算中。Switch Transofrmer 提出了可微的辅助负载均衡损失函数(Auxiliary Load Balancing Loss)来解决这个问题。
,其中为控制损失的超参数,N 为专家总数,为实际专家 i 分配到的 token 比例,表示专家 i 被路由选中的概率
;
由损失函数可以直观的看出损失越小表明负载越均衡,且当和相等为时,损失为。在论文中取为 0.01 时效果最好。
赋能稳定高效的大规模训练#
混合精度训练策略#
文中作者提出使用 bfloat16 精度训练稀疏模型训练非常不稳定,因此考虑使用混合精度训练。参考先前的工作中通常使用 float32 精度训练的情况,Switch Transformer 选择使用 float32 和 bfloat16 混合训练的策略:在路由函数中将输入强制转换为 float32 精度进行关键运算(如 softmax),但在昂贵的 all-to-all 通信前,将张量重新转换成 bfloat16。实验发现:在保持与 float32 精度的稳定性的同时获得了接近 bfloat16 的训练速度。
进一步探究稳定训练:更小的参数初始化#
作者发现模型参数初始化对 Switch Transformer 的训练稳定性有着关键的作用。通过阶段正态分布,即使用均值,标准差,其中 s 表示缩放超参数,取值为 0.1,n 是连接到该层(权重张量所应用的神经元)的输入特征数量。
微调中过拟合的抑制:Dropout#
文中特别强调了在下游任务训练数据较少的情况下 Switch Transformer 出现过拟合的风险较高。因此采用了深度学习中使用最为广泛的 Dropout,具体为:在 Expert Layer 中应用更高的 dropout 值,Non Expert Layer 引用较小的 dropout 值,并比较了 T5-Base 模型和 Switch-Base 模型使用不同的 dropout 值在 GLUE、CNNDM、SQuAD 和 SuperGLUE 数据集上微调的结果,最终得出结论。
模型能力跃迁:蒸馏的应用#
文中作者将大规模稀疏模型的能力蒸馏到小参数量的模型上,这一举措对于提升大型稀疏模型在实际场景中的部署可行性至关重要 。毕竟,动辄数千亿乃至万亿参数的庞然大物,其推理成本和硬件需求是很多应用难以承受的。Switch Transformer 论文中探讨的蒸馏方法主要包括:
- 权重初始化技巧:研究者们发现,在蒸馏时,使用稀疏教师模型中的 “非专家层” 权重来初始化密集学生模型的对应层,能够带来一定的性能提升 。这是因为即便整体模型架构不同(稀疏 vs 密集),许多基础的非专家层(如注意力机制的某些部分、嵌入层等)在功能上具有相似性,其训练好的权重对学生模型而言是宝贵的先验知识。
- 损失函数组合:蒸馏过程不仅仅依赖于学生模型在标准数据集上的表现(硬标签),还会学习教师模型的输出概率分布(软标签)。论文发现,将教师模型的输出概率与真实标签(ground truth)以一定的比例混合(例如,0.25 的教师概率 + 0.75 的真实标签损失)作为学生模型的优化目标,能取得更好的蒸馏效果 。
通过这些蒸馏技术,Switch Transformer 成功地将大型稀疏模型的约 30% 的性能增益迁移到了参数量压缩数十倍甚至近百倍的小型密集模型中 。例如,一个拥有数十亿参数的 Switch-Base 模型,经过蒸馏后,可以在参数量减少 95% 以上的情况下,学生模型仍能保留原教师模型性能提升的近三成 。这一成果不仅适用于预训练阶段的语言模型,同样在针对特定下游任务(如 SuperGLUE)微调后的模型蒸馏中也得到了验证 ,极大地增强了这些先进稀疏架构的实用价值。
卓越成就:SOTA 性能速览#
凭借其创新的架构和训练策略,Switch Transformer 不仅在模型规模和训练效率上取得了突破,更是在多项自然语言处理基准测试中取得了当时的 SOTA(State-of-the-Art,顶尖水平)或极具竞争力的成绩:
- 空前的规模与效率:论文成功训练了参数量高达 1.6 万亿的 Switch-C 模型 和与 T5-XXL(110 亿参数)计算量相当但拥有 3950 亿参数的 Switch-XXL 模型 。更重要的是,这些巨大的稀疏模型展现了极高的训练效率,例如 Switch-XXL 相较于 T5-XXL 实现了 4 倍的预训练速度提升 ,而针对 T5-Base 等模型,Switch Transformer 甚至能达到 7 倍的预训练加速,且这些都是在相同的计算资源下完成的 。
- 下游任务的强劲表现:经过预训练的 Switch Transformer 模型在微调后,于众多下游任务中展现了强大实力。无论是在 GLUE、SuperGLUE 等综合性语言理解基准,还是在 SQuAD 等阅读理解任务,以及 XSum 等文本摘要任务上,Switch 模型均显著优于其 FLOPs 匹配的 T5 密集基线模型 。特别是在知识密集型的闭卷问答任务(如 TriviaQA, Natural Questions, WebQuestions)上,Switch Transformer 也取得了 SOTA 或接近 SOTA 的表现 。
- 多语言能力的普适提升:Switch Transformer 架构的优势并不仅限于单一语言。在包含 101 种语言的 mC4 数据集上进行训练的 mSwitch-Base 模型,在 所有参与评估的语言上均超越了强大的 mT5-Base 基线 ,显示了其架构在多任务和多语言学习中的普适性和高效性。
值得注意的是,论文也坦诚地指出,尽管在预训练困惑度等指标上取得了巨大进步,但在当时,将这些优势完全转化为某些特定下游任务(尤其是复杂推理任务)的 SOTA 性能,对于最大规模的模型而言,仍是一个持续探索的领域 。
总结与深远影响#
Switch Transformer 论文无疑是深度学习,特别是大规模语言模型发展史上的一个重要里程碑。它通过以下几个核心贡献,深刻影响了后续的研究方向:
- 简化并普及了 MoE 架构:通过引入 “单专家路由”(k=1)机制,极大地简化了传统 MoE 的复杂性,使其更易于理解、实现和训练 。
- 攻克了大规模稀疏模型训练的关键难题:提出并验证了一系列如选择性精度训练、小尺度参数初始化、负载均衡损失和专家 Dropout 等创新技术,显著提升了训练的稳定性和效率 。
证明了稀疏化的巨大潜力:以无可辩驳的实验结果展示了稀疏激活模型在达到数万亿参数规模、同时保持计算效率方面的巨大潜力,为 “大力出奇迹” 提供了新的、更经济的路径。 - 推动了开源生态与后续创新:相关代码(如基于 Mesh TensorFlow 和后续的 T5X/JAX)的开放,极大地促进了学术界和工业界对 MoE 及相关稀疏技术的探索和应用。
对后续工作的影响
-
MoE 成为主流技术路线的关键推手:Switch Transformer 的成功使得 MoE 从一种相对小众的探索方向,逐渐成为构建顶尖大型语言模型(如 Mixtral 系列及其他受此启发的模型)的关键技术之一。
-
激发路由机制的持续创新:虽然单专家路由简洁高效,但也激发了对更动态、更智能路由算法的研究,例如如何让模型自适应学习专家容量、如何根据任务或数据特性优化令牌分配等。
-
深化对稀疏模型微调与部署的理解:论文中关于微调和蒸馏的初步探索,为后续如何高效利用和部署这些庞大的稀疏模型指明了方向,催生了更多关于专家特化、剪枝、量化等方面的研究。
-
大规模训练经验的普惠:其在解决训练不稳定性、优化通信、利用混合精度等方面的经验,也为整个大规模 AI 模型的训练领域提供了宝贵的借鉴。
总而言之,Switch Transformer 不仅是一系列模型的名称,更代表了一种设计哲学和一套行之有效的方法论。它雄辩地证明了,通过精心的工程设计和算法创新,我们完全有能力驾驭拥有海量知识的稀疏巨兽,并让它们以更简单、更高效的方式服务于日益复杂的智能应用。也为后续 DeepSeek 对 MoE 的负载均衡策略改进等研究奠定了基础。