0%

Switch Transformers:用简单高效的稀疏性扩展到万亿参数模型 — 深度技术评审

1. 这篇论文做了什么

想象一下:你有一座巨大的图书馆,里面有上千位专业图书管理员,每位都是不同领域的专家。当你带着一个关于古罗马水渠的问题走进去时,你不需要问遍每一位管理员(那太花时间了),而是被一个智能导航系统瞬间引导到最了解罗马工程学的那位管理员面前。你用问一个人的时间就得到了专家级别的答案。这就是 Switch Transformer 的核心思想。

在深度学习的世界里,传统模型用全部参数来处理每一个输入。一个拥有 110 亿参数的模型,对它读到的每一个词都会动用全部 110 亿个权重。Switch Transformer 彻底颠覆了这种做法:它拥有海量参数——最多达到 1.6 万亿——但对于任意一个输入的 token,只有一小部分参数被激活。其余的参数处于休眠状态,等待需要它们特定专长的输入。

这篇论文由 Google 的 William Fedus、Barret Zoph 和 Noam Shazeer 撰写,提出了 Switch Transformer 架构。它简化了早期的混合专家(Mixture-of-Experts, MoE)方法,核心改动是把每个 token 的路由目标从两个或更多专家减少到仅仅一个专家。这看似微小的改变带来了巨大的实践好处:更低的计算开销、更简单的实现、更少的通信成本和更好的训练稳定性。

实验结果令人瞩目:

  • 在相同计算预算下,比 T5-Base 模型实现 7 倍预训练加速
  • 比规模大得多的 T5-XXL 模型实现 4 倍加速
  • 在多语言预训练中,全部 101 种语言均获得提升
  • 成功扩展到 1.6 万亿参数
  • 可以将稀疏模型蒸馏回密集模型,在 99% 压缩率下仍保留约 30% 的质量增益

这篇论文是现代大语言模型设计的基石。Mixtral、DeepSeek-V2 等模型以及 Google 的许多生产系统都直接建立在这里提出的思想之上。理解 Switch Transformer 对于任何从事高效 AI 系统工作的人来说都至关重要。


2. 前置知识:你需要先了解什么

在我们深入 Switch Transformer 本身之前,让我们仔细梳理你需要的所有背景知识。即使其中一些概念你已经熟悉,这一节的设计也力求全面,因为每一块知识都是理解 Switch Transformer 为何有效、为何重要的基础。

2.1 Transformer 架构

Transformer 由 Vaswani 等人在 2017 年提出,是几乎所有现代大语言模型(GPT、BERT、T5、PaLM、LLaMA 等)的骨架。其核心是由一组相同结构的层堆叠而成,每一层包含两个主要子组件:

  1. 自注意力层(Self-Attention Layer):让序列中的每个词(token)都能"看到"其他所有词来理解上下文。例如在句子"猫坐在垫子上因为它累了"中,自注意力机制帮助模型理解"它"指的是"猫"而不是"垫子"。

  2. 前馈网络层(Feed-Forward Network, FFN):注意力计算之后,每个 token 独立地通过一个小型神经网络(通常是两个线性变换中间夹一个非线性激活函数)。这是模型"知识"存储的主要位置——FFN 层就像一个记忆库。

每个子组件都包裹了残差连接(输入被加回输出)和层归一化(数值被重新缩放以保持稳定的统计特性)。这种架构简洁、高度可并行化,且扩展性极好。

本文的关键洞察:FFN 层对每个 token 的处理是完全独立的——在这一步中 token 之间不存在交互。正是这种独立性使得将不同 token 路由到不同的 FFN "专家"成为可能,而不会破坏整体计算。

2.2 T5 模型(Text-to-Text Transfer Transformer)

T5 由 Raffel 等人在 2019 年提出,是 Switch Transformer 的直接构建基础。T5 将每个 NLP 任务都转化为文本到文本的问题:无论你做的是翻译、摘要、问答还是分类,输入和输出都被统一格式化为文本字符串。

T5 有多种尺寸:

  • T5-Small:6000 万参数
  • T5-Base:2.23 亿参数
  • T5-Large:7.39 亿参数
  • T5-XL:30 亿参数
  • T5-XXL:110 亿参数

Switch Transformer 论文主要与 T5-Base 和 T5-Large 做对照实验,与 T5-XXL 做最大规模的对比。理解 T5 的角色很重要,因为所有 Switch Transformer 变体都被设计为与其 T5 对应模型 FLOP 匹配——即每个 token 使用相同的计算量,只是参数量大幅增加。

2.3 什么是 FLOPs?为什么重要?

FLOPs(浮点运算次数)衡量的是处理数据通过神经网络所需的计算量。当我们说两个模型是"FLOP 匹配"的时候,意思是它们对每个输入 token 执行的算术运算次数相同。

这对公平比较至关重要。如果一个 70 亿参数的 Switch Transformer 比 2.23 亿参数的 T5-Base 表现更好,你可能会想:"当然了——它参数多了 30 倍!"但 Switch Transformer 使用的计算量完全相同,因为对于任意给定的 token,那 70 亿参数中的绝大多数都处于休眠状态。只有被选中的那个专家的参数被激活。

打个比方:FLOP 匹配的比较就像比较两个用电量相同的工厂。一个工厂有一条全天运行的通用生产线。另一个工厂有 128 条专用生产线,但每次只根据需要运行其中一条。它们用的电一样多,但专用生产线的工厂可能生产出质量更高的产品,因为每条生产线都针对特定任务做了优化。

2.4 混合专家模型(MoE):Switch Transformer 的前身

混合专家概念最早追溯到 1991 年(Jacobs 等人),是 Switch Transformer 的直接前身。它是这样工作的:

基本思想:不用单一的 FFN 层处理所有 token,而是设置 N 个独立的 FFN 层(称为"专家"),并用一个路由器来决定每个 token 应该由哪个(些)专家来处理。

路由器:路由器是一个小型神经网络(通常是一个线性层加 softmax),以 token 的表征作为输入,输出一个在所有 N 个专家上的概率分布。然后 token 被发送到概率最高的专家。

传统 MoE(Top-k 路由):在 Shazeer 等人(2017)的原始方案中,每个 token 被路由到前 k 个专家(通常 k=2)。被选中专家的输出按路由器分配的概率进行加权求和。

数学表达为,对于 token x,路由器计算:

pi(x)=eh(x)ij=1Neh(x)jp_i(x) = \frac{e^{h(x)_i}}{\sum_{j=1}^{N} e^{h(x)_j}}

其中 h(x) = W_r · x 是路由器的 logits。输出为:

y=itop-kpi(x)Ei(x)y = \sum_{i \in \text{top-k}} p_i(x) \cdot E_i(x)

为什么 MoE 没有被广泛采用:尽管有不错的成果,MoE 模型有三大问题:

  1. 复杂度高:路由到多个专家、合并输出、管理路由逻辑——整个流程很复杂。
  2. 通信开销大:在分布式训练中,token 必须通过网络发送到不同设备上的专家,产生昂贵的 all-to-all 通信。
  3. 训练不稳定:硬路由决策在训练过程中产生不连续性,导致不稳定,在大规模时尤其严重。

2.5 负载均衡:为什么重要

专家模型面临的最大挑战之一是负载不均衡。想象你有 128 个专家,但路由器学会把 90% 的 token 都发送到其中 3-4 个专家身上。剩下的 124 个专家闲置,浪费了内存和参数。更糟糕的是,热门专家不堪重负,导致 token 丢弃(无法处理的 token 被直接跳过)。

为了防止这种情况,MoE 模型使用辅助负载均衡损失——训练目标中的一个附加项,当 token 在专家间的分布不均匀时对模型进行惩罚。这个损失轻柔地推动路由器更均匀地分配 token。

调好这个平衡至关重要:负载均衡的压力太大,会覆盖路由器的自然专业化;太小则会出现灾难性的不均衡。

2.6 分布式训练与并行策略

训练大语言模型需要将工作分配到多个加速器(GPU 或 TPU)上。有几种实现方式:

  • 数据并行:每个加速器处理不同批次的数据,但各自持有模型的完整副本。每步训练后在所有加速器间平均梯度。简单但受限于模型大小。
  • 模型并行:模型参数被拆分到各加速器上。每个加速器只持有部分权重。前向和反向传播过程中需要通信。
  • 专家并行:MoE 模型特有——每个加速器持有一个或多个专家。token 通过 all-to-all 通信被路由到相应的加速器。
  • 流水线并行:模型的不同层运行在不同的加速器上,数据像流水线一样流过它们。

Switch Transformer 论文的一个亮点是对如何有效组合这些并行策略做了深入分析。

2.7 精度格式:float32、bfloat16 与混合精度

神经网络使用浮点数进行运算。标准格式 float32 每个数字占 32 位,精度高但消耗大量内存和计算资源。bfloat16(brain floating-point 16)只用 16 位,减半内存使用且提高速度,但代价是数值精度下降。

混合精度训练在大部分操作中使用 bfloat16,但在数值敏感的计算中切换到 float32。这是大模型训练的标准实践。

Switch Transformer 在这里面临特殊挑战:路由器的 softmax 计算对数值精度高度敏感。路由概率中的微小误差可能导致 token 被发送到错误的专家,从而使训练不稳定。论文的解决方案——选择性精度——是其关键技术贡献之一。

2.8 知识蒸馏

蒸馏是一种利用大型强力模型("教师")来训练小型可部署模型("学生")的技术。学生不仅在真实标签上训练,还要学习匹配教师的输出概率分布。

直觉是教师的概率分布包含"暗知识"——它不仅告诉学生正确答案,还告诉学生对其他选项应有多大信心。例如,如果语言模型看到"法国的首都是___",教师可能对"巴黎"赋予 95% 的概率,对"里昂"赋予 3%,对"马赛"赋予 1%。这些软目标传达了硬标签(仅"巴黎")无法传达的有用信息。

蒸馏对 MoE 模型尤其重要,因为其巨大的参数量使直接部署不切实际。能够将万亿参数的稀疏模型压缩为可管理的密集模型,同时保留大部分质量增益,是核心卖点。


3. Switch Transformer:核心方法

现在我们有了所有背景知识,让我们详细考察 Switch Transformer 的设计。

3.1 关键洞察:路由到一个专家,而非两个

Switch Transformer 最重要的设计选择是将每个 token 路由到恰好一个专家(top-1 路由),而非传统的 top-2 或 top-k。

之前的工作(Shazeer 等人,2017)推测路由到 k > 1 个专家对于路由器在训练中获得有意义的梯度是必要的。理由是:如果只路由到一个专家,模型就无法"比较"专家,也就无法学习哪些专家更适合哪些 token。Ramachandran 和 Le(2018)甚至发现在较低层中更高的 k 值特别重要。

Fedus 等人正面挑战了这一假设。他们证明 top-1 路由不仅可行——实际上效果更好。好处有三个方面:

  1. 路由计算减少:计算 top-1 选择比 top-2 更便宜。
  2. 专家容量减半:每个 token 只去一个专家(而非两个),每个专家需要的批次缓冲区减半,节省内存。
  3. 通信简化:只需要一组 token 到专家的传输而非两组,减少 all-to-all 通信开销。

3.2 架构:专家放在哪里?

Switch Transformer 将交替的 Transformer 层中的 FFN 层替换为 Switch FFN 层。具体流程如下:

  1. 一个 token x 经过自注意力和层归一化后进入该层。
  2. 路由器计算在 N 个专家上的概率分布:p(x) = softmax(W_r · x)。
  3. token 被发送到概率最高的专家:i* = argmax p(x)。
  4. 被选中的专家 E_{i*} 处理该 token(这就是一个标准的 FFN 计算)。
  5. 输出被乘以路由器的门控值 p_{i*}(x),按路由器的置信度缩放专家输出。
  6. 结果通过残差连接传递下去。

这个门控值的乘法很重要——它使得路由对路由器参数是可微分的,尽管 argmax 选择本身不可微。梯度通过门控值流动,使路由器能够学习。

放置位置:在标准的 Switch Transformer 配置中,专家替换每隔一层的 FFN。其他层使用标准(密集)FFN。这种交替模式减少了通信开销,同时仍提供显著的容量扩展。

3.3 专家容量与 Token 丢弃

由于在 TPU 上训练需要静态声明的张量形状,Switch Transformer 必须为每个专家预分配固定大小的缓冲区。这个缓冲区大小称为专家容量

专家容量=(批次中的 token 数专家数量)×容量因子\text{专家容量} = \left(\frac{\text{批次中的 token 数}}{\text{专家数量}}\right) \times \text{容量因子}

容量因子(CF)是一个超参数,控制在理论上完美(均匀)分布基础上额外分配多少缓冲:

  • CF = 1.0:每个专家恰好获得(tokens/experts)个槽位。如果分布不完美均匀,一些 token 会被丢弃。
  • CF = 1.25:每个专家多 25% 缓冲,可容忍一定不均衡。
  • CF = 2.0:100% 额外缓冲,容忍度很高但浪费资源。

Token 丢弃:当一个专家的缓冲区满了,额外路由到它的 token 会被丢弃——直接跳过专家层通过残差连接传递。这不是灾难性的(token 仍保有之前各层的表征),但过多丢弃会降低质量。

论文发现在正确的负载均衡下,token 丢弃率通常低于 1%,且较低的容量因子(1.0-1.25)对 Switch Transformer 实际效果更好,而传统 MoE 需要更高的容量因子(2.0)。这是只路由到一个专家的直接结果:每个专家负载的方差更低。

3.4 负载均衡损失

为鼓励 token 在专家间的均匀分布,Switch Transformer 在训练目标中添加辅助损失。对于 N 个专家和包含 T 个 token 的批次:

Laux=αNi=1NfiPi\mathcal{L}_{\text{aux}} = \alpha \cdot N \cdot \sum_{i=1}^{N} f_i \cdot P_i

其中:

  • f_i 是实际被分派到专家 i 的 token 比例(离散的、不可微的量)
  • P_i 是分配给专家 i 的路由概率比例(连续的、可微的)
  • α 是超参数(论文全程使用 α = 10⁻²)

为什么有效:在完美均匀路由下,对所有 i 都有 f_i = P_i = 1/N,损失值为 N × N × (1/N)² = 1。任何偏离均匀性的情况都会增大这个乘积(由 AM-GM 不等式保证)。乘以 N 使得无论专家数量如何变化,损失量级保持不变。

为什么 α = 10⁻²:作者在 10⁻¹ 到 10⁻⁵ 的范围内搜索,发现 10⁻² 足够大以实现良好的负载均衡,又足够小以不压倒主要的交叉熵训练目标。

3.5 训练稳定性技术

论文介绍了三种关键的稳定训练技术:

3.5.1 选择性精度

路由器的 softmax 计算对数值精度高度敏感。完全用 bfloat16 训练会导致模型发散(表 2 显示 bfloat16 模型的负对数困惑度是灾难性的 -3.780,而 float32 为 -1.718)。

解决方案:仅将路由器内部计算转换为 float32,其他一切保持 bfloat16。这行得通是因为:

  • 路由函数在每个设备上是局部的(不需要跨设备通信 float32 张量)
  • 分派和组合张量在路由器输出端被重新转换为 bfloat16
  • 计算开销几乎可以忽略

结果:选择性精度达到了与完全 float32 几乎相同的质量(-1.716 vs -1.718),同时保持与完全 bfloat16 相同的速度(1390 examples/sec vs float32 的 1160)。

3.5.2 更小的参数初始化

作者发现将标准 Transformer 权重初始化缩放因子减小 10 倍(从 s=1.0 到 s=0.1)显著改善了平均质量和训练稳定性。使用 s=1.0 时,3.5k 步后的平均负对数困惑度为 -3.60,标准差 0.68(非常不稳定)。使用 s=0.1 时,平均值改善到 -2.72,标准差仅 0.01(非常稳定)。

这项技术具有广泛的适用性,从 2.23 亿参数的基线模型到超过万亿参数的模型都有效。

3.5.3 专家 Dropout

在小规模下游任务上微调时,Switch Transformer 的巨大参数量会带来严重的过拟合风险。标准的 dropout(均匀应用于所有层)无法解决这一问题——将 dropout 从 0.1 增加到 0.3 实际上反而损害了性能。

论文提出专家 dropout:对所有非专家层使用低 dropout 率(0.1),但对专家 FFN 层专门使用高得多的 dropout 率(0.4)。这选择性地正则化了模型中最容易过拟合的部分(专家),同时保留了共享层中学到的表征。

结果(表 4):专家 dropout(d=0.1, ed=0.4)在 GLUE(85.2)、CNNDM(19.6)、SQuAD(83.7)和 SuperGLUE(73.0)上均达到最佳分数,超越了所有均匀 dropout 配置。


4. 实验结果:扩展性能

4.1 Switch vs. MoE vs. 密集模型(表 1)

论文的第一个实验是精心控制的正面对比。所有模型使用 128 个专家(如适用),在 32 个 TPUv3 核心上训练,运行相同步数。

模型 容量因子 100k步质量 达到阈值时间 速度
T5-Base -1.731 未达到 1600 ex/s
T5-Large -1.550 131.1 hrs 470 ex/s
MoE-Base (top-2) 2.0 -1.547 68.7 hrs 840 ex/s
Switch-Base 2.0 -1.554 72.8 hrs 860 ex/s
MoE-Base 1.25 -1.559 80.7 hrs 790 ex/s
Switch-Base 1.25 -1.553 65.0 hrs 910 ex/s
MoE-Base 1.0 -1.572 80.1 hrs 860 ex/s
Switch-Base 1.0 -1.561 62.8 hrs 1000 ex/s
Switch-Base+ 1.0 -1.534 67.6 hrs 780 ex/s

关键发现

  1. Switch 在每个容量因子上都超越 MoE,无论是质量还是速度。
  2. 在 CF=1.0 时,Switch 达到 1000 examples/sec(MoE 为 860)——快 16%。
  3. Switch-Base+(放大到匹配 MoE 速度)在所有模型中质量最优(-1.534)。
  4. T5-Base 在 100k 步内从未达到 -1.50 的质量阈值。

4.2 随专家数量扩展

论文的图 4 展示了一个优美的扩展关系:随着专家数量增加(2, 4, 8, 16, 32, 64, 128, 256),性能持续提升——尽管每个 token 的 FLOPs 保持不变。这验证了论文的核心假设:参数数量是一个有用的独立扩展维度,与计算量无关。

具体结果:Switch-Base 64 专家模型在 七分之一的训练步数(60k vs 450k 步)就达到了与 T5-Base 相同的质量——样本效率提升 7.5 倍。

4.3 时钟时间速度优势(图 5)

按实际时钟时间(考虑通信开销),Switch-Base 64 专家模型比 T5-Base 实现 7 倍加速。这比步数基础上的改进略低(因为路由和通信开销),但仍然非常惊人。

4.4 与更大的密集模型对比(图 6)

即使与 T5-Large(使用 3.5 倍更多 FLOPs 每 token)相比,Switch-Base 64 专家在时钟时间上仍然 快 2.5 倍。这是一个了不起的结果:一个使用更少 FLOPs 的模型,仅凭拥有更多(但稀疏激活的)参数,就超越了使用 3.5 倍更多计算的模型。


5. 下游微调结果

5.1 全面基准测试结果(表 5)

论文在广泛的 NLP 任务上进行评估。完整结果如下:

任务 T5-Base Switch-Base T5-Large Switch-Large
GLUE 84.3 86.7 (+2.4) 87.8 88.5 (+0.7)
SQuAD 85.5 87.2 (+1.7) 88.1 88.6 (+0.5)
SuperGLUE 75.1 79.5 (+4.4) 82.7 84.7 (+2.0)
Winogrande 66.6 73.3 (+6.7) 79.1 83.0 (+3.9)
XSum 18.7 20.3 (+1.6) 20.9 22.3 (+1.4)
ANLI (R3) 51.8 54.0 (+2.2) 56.6 58.6 (+2.0)
ARC Easy 56.7 61.3 (+4.6) 68.8 66.0 (-2.8)
ARC Challenge 35.5 32.8 (-2.7) 35.5 35.5 (0.0)
CB Web QA 26.6 27.4 (+0.8) 27.7 31.3 (+3.6)
CB Natural QA 25.8 26.8 (+1.0) 27.6 29.5 (+1.9)
CB Trivia QA 24.5 30.7 (+6.2) 29.5 36.9 (+7.4)

值得注意的发现

  • SuperGLUE 获得最大幅度的一致改进:Base +4.4,Large +2.0。
  • Winogrande(常识推理):Base +6.7,Large +3.9。
  • 闭卷 Trivia QA:Base +6.2,Large +7.4——单项最大改进,表明稀疏模型是出色的知识存储库。
  • ARC Challenge 是 Switch-Base 表现不如 T5-Base 的唯一任务(-2.7 分),暗示某些类型的多步推理可能并未从稀疏参数扩展中受益。

5.2 蒸馏结果

论文展示了大型稀疏模型可以通过蒸馏被压缩:

稀疏模型大小 压缩率 质量保留
1.1B → 223M 82% 37%
2.0B → 223M 90% 32%
3.8B → 223M 95% 30%
7.4B → 223M 97% 27%
14.7B → 223M 99% 28%

最佳蒸馏方案:用教师的非专家权重初始化学生 + 使用 75% 硬标签和 25% 软教师标签的混合损失。这保留了约 30% 的质量增益。

对于 SuperGLUE 上的微调蒸馏:Switch-Base(7.4B)达到 81.3;蒸馏后的 T5-Base(223M)达到 76.6,保留了基线 T5-Base(74.6)之上 6.7 分差距的 30%。

5.3 多语言结果

在覆盖 101 种语言的多语言 Common Crawl(mC4)上预训练时:

  • Switch Transformer 在全部 101 种语言上超越 mT5-Base(图 7)
  • 平均每步加速 5 倍(图 8)
  • 91% 的语言达到至少 4 倍加速
  • 改进覆盖不同语系、文字系统和资源水平

6. 扩展到万亿参数

6.1 模型配置(表 9)

论文设计了两个大规模模型:

模型 参数量 FLOPs/序列 专家数 dmodel dff 层数
T5-XXL 110 亿 6.3T 4096 10240 24
Switch-XXL 3950 亿 6.3T 64 4096 10240 24
Switch-C 15710 亿 890B 2048 2080 6144 15

Switch-C 的设计很有意思:它仅使用专家并行(不用模型并行),拥有 2048 个专家。这使得每个专家都很小,但总参数量巨大。每 token 的 FLOPs(890B)实际上远低于 T5-XXL(6.3T)。

Switch-XXL 与 T5-XXL FLOP 匹配,使用 64 个专家但每个维度更大。它对每个 token 施加约 10 倍于 Switch-C 的 FLOPs。

6.2 预训练结果

250k 步后:

  • Switch-XXL:-1.086 负对数困惑度
  • Switch-C:-1.096 负对数困惑度
  • T5-XXL:-1.147 负对数困惑度

500k 步后:

  • Switch-XXL:-1.008
  • Switch-C:-1.043
  • T5-XXL:-1.095

Switch-XXL 在达到同等困惑度时比 T5-XXL 快 4 倍

6.3 大规模下的训练不稳定

论文坦诚报告了在最大规模下训练稳定性仍是挑战。Switch-C(1.6 万亿参数,2048 个专家)没有出现任何不稳定——可能因为其每 token 计算量适中。Switch-XXL(3950 亿参数,但每 token FLOPs 高 10 倍)出现了偶发的不稳定,导致无法完成完整的 100 万步训练。

这是一个诚实且有价值的观察:高每 token 计算量与稀疏路由的组合会产生目前尚未完全理解的交互效应。

6.4 并行策略分析

论文第 5 节提供了异常深入的不同并行策略组合分析。核心洞察是每种策略都涉及权衡:

  • 纯数据并行:简单,前向/反向传播期间无通信,但受限于单设备能容纳的模型大小。
  • 纯模型并行:允许更大模型,但每层都需要 all-reduce 通信。
  • 纯专家并行:需要 all-to-all 通信进行 token 路由,但每个专家足够小可以放在一个设备上。
  • 组合使用:对最大模型是必要的,但最优配置需要根据具体硬件(TPU 拓扑、互联带宽、每设备内存)进行经验调优。

7. 讨论:优势、局限与边界条件

7.1 优势

  1. 简洁性:核心思想(路由到一个专家而非两个)优雅简洁却效果显著。
  2. 高效性:FLOP 匹配的模型实现了显著更好的质量,展示了 7 倍加速。
  3. 可扩展性:方法从单个 GPU 上的 2 个专家扩展到 TPU 集群上的 2048 个专家。
  4. 实用的训练技术:选择性精度、更小的初始化和专家 dropout 都有超越 Switch Transformer 的广泛适用性。
  5. 多语言普适性:测试的全部 101 种语言均获得改进,不仅限于高资源语言。

7.2 局限与边界条件

  1. 极端规模下的训练不稳定:Switch-XXL 模型经历了偶发的不稳定。训练稳定化技术虽然有帮助,但对最大配置仍不够充分。

  2. 推理任务的微调差距:尽管预训练质量优越,向下游推理任务的转化并不一致。Switch-C(1.6 万亿参数)在 SQuAD 上只达到 87.7,而更小的 Switch-XXL(3950 亿参数)达到 89.6。对于推理来说,每 token 更多的 FLOPs 似乎比更多参数更重要。

  3. ARC Challenge 退化:Switch-Base 在 ARC Challenge 上实际不如 T5-Base(32.8 vs 35.5),暗示某些推理任务可能不受益于(甚至被损害于)稀疏路由。

  4. 部署挑战:1.6 万亿参数的模型即使采用稀疏激活,仅存储权重就需要巨大内存。蒸馏有帮助但会损失 70% 的质量增益。

  5. 负载均衡敏感性:辅助损失超参数 α 需要仔细调整。论文使用 α = 10⁻²,但这可能无法推广到所有设置。

  6. Token 丢弃:虽然通常低于 1%,但 token 丢弃是密集模型中不存在的信息损失机制。最坏情况下,重要 token 可能从过载的专家中被丢弃。

  7. 静态专家容量:TPU 编译需要固定张量形状,意味着专家容量必须在编译时设定。这阻止了在推理时根据实际路由模式进行动态调整。

7.3 可重复性

论文提供了良好的可重复性支持:

  • JAX 代码和模型检查点公开可用
  • 所有模型在公开的 C4 数据集上训练
  • 超参数在表 9 中完全指定
  • 训练基础设施(TPUv3)有详细文档

但复现最大规模实验(Switch-C,2048 个专家)需要大量的 TPU 集群访问,这限制了独立验证。

7.4 与后续工作的对比视角

值得将 Switch Transformer 的设计与后来的 MoE 模型进行对比,以理解其贡献的持久性和局限性:

与 Mixtral 的对比:Mixtral(2024)选择了 top-2 路由而非 Switch 的 top-1,每次激活 2 个(共 8 个)专家。这说明后续研究发现适度的 k 值(k=2)在某些场景下可能优于纯粹的 top-1。但核心理念——稀疏激活、大参数量——直接继承自 Switch Transformer。

与 DeepSeek-V2 的对比:DeepSeek-V2 引入了更精细的专家分段(Fine-grained Expert Segmentation),将每个专家进一步分成更小的子专家,同时使用共享专家来稳定训练。这可以看作是对 Switch Transformer 中训练不稳定问题的一种工程化解决方案。

与 GShard 的对比:GShard(Lepikhin 等人,2020)是 Switch Transformer 的直接前身,使用 top-2 路由和 float32 全精度训练。Switch Transformer 的贡献正是简化了这些选择(top-1 + 选择性精度),使系统更简单且更高效。

容量因子的演进:Switch Transformer 发现 CF=1.0-1.25 效果最好。后续的工作(如 DeepSeek-V2)完全去掉了固定容量因子的概念,转而使用更灵活的路由策略。这反映了从 TPU 静态编译约束向更灵活的 GPU 推理的演进。

7.5 对实践者的具体建议

基于论文的发现,如果你要在自己的项目中使用 MoE 架构,以下是一些可操作的建议:

  1. 从少量专家开始:即使只有 2 个专家也能带来改进(附录 D)。不要因为没有大规模集群就放弃尝试。
  2. 使用选择性精度:始终在路由器内部使用 float32,其他地方用 bfloat16。这几乎没有成本但大幅提升稳定性。
  3. 缩小初始化:将权重初始化缩放因子设为默认值的 1/10。
  4. 微调时用专家 dropout:非专家层 dropout=0.1,专家层 dropout=0.4。
  5. 监控 token 丢弃率:如果超过 1-2%,增加容量因子或调整负载均衡系数 α。
  6. 考虑蒸馏:如果部署受限,训练一个大型稀疏教师然后蒸馏到小型密集学生可能是最优策略。

8. 遗产与影响

Switch Transformer 对整个领域产生了深远的影响:

  1. Mixtral(Mistral AI,2024):使用 top-2 稀疏 MoE 路由,直接受此系列工作启发,以远少于密集替代方案的计算达到了有竞争力的性能。

  2. DeepSeek-V2/V3:采用 MoE 并引入细粒度专家分段等创新,直接建立在 Switch Transformer 的基础之上。

  3. Google 的生产模型:Google 许多内部模型使用源自 Switch Transformer 设计的稀疏 MoE 架构。

  4. 高效 ML 生态系统:论文证明参数数量和计算可以解耦,影响了整个领域对模型扩展的思考方式。

  5. "稀疏即美"范式:在这篇论文之前,传统智慧是"放大密集模型"。在此之后,稀疏模型成为合法且流行的替代方案,催生了当前一代高效 MoE 模型。


9. 关键图表深度解读

图 1(扩展与样本效率)

左图展示了在保持每 token FLOPs 固定的情况下,随着稀疏模型参数的增加(通过增加专家数),测试损失持续下降。从 T5-Base(2.23 亿参数,左上角)到 256 专家模型(147 亿参数,右下角),这条曲线近似对数线性,表明参数效率有清晰的 scaling law。右图则在训练步数的维度上比较了不同专家数量的 Switch-Base 与密集 T5-Base。128 专家版本在约 6 万步就超越了 T5-Base 在 45 万步时的性能。

图 2(架构图)

这张图清晰地展示了 Switch Transformer 编码器块的结构。与标准 Transformer 的唯一区别在于:密集 FFN 层被替换为稀疏 Switch FFN 层。图中展示了两个 token("More"和"Parameters")如何被路由器独立路由到四个专家中。注意输出被路由器门控值(虚线)缩放——这是保持可微性的关键。

图 3(容量因子与 Token 丢弃)

这张图直观地解释了容量因子的概念。CF=1.0 时,缓冲区恰好够用,但不均匀分布会导致溢出(红色虚线标记的被丢弃 token)。CF=1.5 时,额外的缓冲(白色空槽位)缓解了溢出问题,但增加了计算和通信成本。这是一个经典的空间-质量权衡。

图 5(时钟时间速度优势)

这可能是论文中最有说服力的图。在固定的 32 个 TPUv3 核心和相同的计算预算下,Switch-Base 128 专家模型在约 50 小时内达到的性能水平,T5-Base 需要约 350 小时——7 倍的加速。64 专家版本在约 50 小时内达到同等质量,并在后续训练中持续改进。

图 7-8(多语言结果)

图 7 是一个跨 101 种语言的条形图,每种语言的 Switch 模型负对数困惑度都低于(=优于)密集基线。图 8 的直方图显示绝大多数语言获得了 4-6 倍的步数加速,均值为 5 倍,证明了稀疏架构在多语言场景下的普适有效性。

表 9(万亿参数模型配置)

这张表揭示了一个微妙但重要的设计权衡:Switch-C(1.6T 参数)只用了 2080 的 dmodel 和 15 层,远小于 T5-XXL 的 4096 dmodel 和 24 层。它依靠 2048 个专家来实现参数量的爆发。而 Switch-XXL 保持了与 T5-XXL 相同的维度和层数,仅添加 64 个专家。两种策略各有利弊:Switch-C 更稳定(无训练不稳定),Switch-XXL 在下游任务上更强(因为每 token FLOPs 更高)。


10. 结论

Switch Transformer 是那种罕见的既概念简洁又技术深刻的论文。其核心贡献——将每个 token 路由到单一专家——容易表述但对模型设计、训练效率和扩展性具有深远影响。

论文证明了通过将参数数量与计算成本解耦,我们可以构建同时比密集对应模型更大(参数方面)和更便宜(计算方面)的模型。比 T5-Base 快 7 倍、比 T5-XXL 快 4 倍——这不是渐进式改进,而是代表了我们对扩展方式思考的范式转变。

训练技术(选择性精度、更小初始化、专家 dropout)和并行策略的深入分析为构建大型稀疏模型提供了实用的操作手册。蒸馏结果则在完整稀疏模型过于庞大而无法实际部署时提供了一条出路。

对于从业者:如果你正在构建或微调大语言模型却没有考虑稀疏 MoE 架构,这篇论文给出了你应该考虑的有力理由。效率增益大到不可忽视,而现代生态系统(Mixtral、DeepSeek 等)已经证明这些想法在生产规模上是可行的。

对于研究者:论文对不工作的方面(训练不稳定、微调改进不一致、推理差距)保持了令人耳目一新的诚实。这些开放问题仍是活跃的研究方向,代表着有价值的未来工作方向。

Switch Transformer 不仅推进了技术水平——它开启了一种关于神经网络扩展的新思考方式,至今仍在塑造这个领域。

从技术层面来说,它教给我们三个持久的教训:第一,参数数量和计算是可以解耦的独立扩展维度;第二,简单的设计选择(top-1 vs top-2)可以产生深远的系统影响;第三,训练稳定性不是可选的——它需要在架构层面而非事后修补来解决。这些教训在今天构建任何大规模 AI 系统时都同样适用。


参考文献

  1. Fedus, W., Zoph, B., & Shazeer, N. (2022). Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity. JMLR, 23, 1-40.
  2. Shazeer, N., et al. (2017). Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer.
  3. Raffel, C., et al. (2019). Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer (T5).
  4. Vaswani, A., et al. (2017). Attention Is All You Need.
  5. Kaplan, J., et al. (2020). Scaling Laws for Neural Language Models.
  6. Brown, T., et al. (2020). Language Models are Few-Shot Learners (GPT-3).
  7. Lepikhin, D., et al. (2020). GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding.
  8. Hinton, G., et al. (2015). Distilling the Knowledge in a Neural Network.
  9. Jacobs, R., et al. (1991). Adaptive Mixtures of Local Experts.
  10. Wang, A., et al. (2019). SuperGLUE: A Stickier Benchmark for General-Purpose Language Understanding.