Over-Tokenized Transformer: Vocabulary is Generally Worth Scaling

发表时间: 2025-01 · arXiv:2501.16975 (ByteDance Seed)

A1 主要贡献(总结)

本文旨在探究分词(Tokenization)对大型语言模型(LLM)扩展和性能的影响,这是一个尚未被充分研究的领域。


图1:在OLMo2上,Over-Encoded模型与基线模型的扩展趋势。图中绘制了训练400B令牌后的损失。对于过编码,输入词汇表大小从基线的10万扩展到120万和1280万(比基线大12倍和128倍),分别称为OE-1.2M和OE-12.8M。我们观察到,400M参数的OE-12.8M模型与1B参数的基线模型性能相当。

A3 背景知识/关键Observation/设计原则(缩写)

2. 相关工作

2.1. 分词设计

2.2. 扩展词汇表
* 词汇表大小与模型性能的关系:近期的实证研究系统地探讨了词汇表大小与模型性能之间的关系。【【19】,Scaling laws with vocabulary: Larger models deserve larger vocabularies,2024,arXiv:2407.13623】的研究表明,扩大的词汇表能提升训练效率和模型性能,尤其是在大型架构中。基于这些发现,本文主张应将嵌入(输入)和去嵌入(输出)的词汇表分开研究。嵌入层仅产生查找成本,而去嵌入层的计算成本会随词汇表大小而扩展。更重要的是,本文发现输入和输出词汇表表现出不同的扩展行为,这凸显了为优化模型设计而采取独立扩展策略的必要性。

2.3. 多令牌预测与n-gram建模
* 多令牌预测(MTP)与本文工作的关联:多令牌预测(MTP)【【7】,Better & faster large language models via multi-token prediction,2024,International Conference on Machine Learning】通过引入同时预测多个令牌的辅助目标,推动了下一令牌预测领域的发展。该方法论与本文的n-gram建模框架在基本原则上是相通的,其中多令牌预测在理论上可以被表述为n-gram去嵌入(即过度解码模型)的一种近似。本文进一步比较了MTP和过度编码模型的有效性,证明了MTP和过度编码模型带来的性能增益是互补的,可以结合起来以实现更大的改进。

3.1. 来自合成实验的洞见

洞见产生的起点:为了研究分词器对模型性能的影响,我们首先在一个合成语言建模任务上展开研究。

实验设置:我们遵循先前研究【【1】,Physics of language models: Part 1, learning hierarchical language structures,2024,https://arxiv.org/abs/2305.13673】的实验设置,使用上下文无关文法 (Context-Free Grammar, CFG)作为目标语言来生成由3个不同字符组成的序列,序列长度最长为729。在这种设置下,真实的语言分布是完全已知的,从而可以精确评估语言模型。我们使用GPT-2模型【【16】,Language models are unsupervised multitask learners,2019】的不同尺寸版本,在CFG生成的样本上进行下一令牌预测损失的训练,并根据模型生成序列的准确率(即有效生成的比例)来评估它们。更多实验设置细节见附录B.2。


图2:在CFG数据上训练的模型的性能比较。左图比较了1-gram和3-gram分词器,显示3-gram分词器提升了较大模型(85M参数)的性能,但损害了较小模型(2.4M参数)的性能。右图检验了在编码器和解码器中使用3-gram的情况,揭示了无论模型大小,使用3-gram编码器都能带来持续的增益,而3-gram解码器则会降低较小模型的性能。

初次实验:不同粒度分词器的比较:我们的第一个实验旨在比较使用不同粒度分词器的语言模型性能。基线分词器使用CFG定义的三个终端字符构建词汇表,逐字符地对句子进行分词,我们称之为1-gram分词器。我们进一步定义了n-gram分词器,其词汇表包含所有$3^n$种可能的n个连续字符组合。我们分别使用1-gram和3-gram分词器训练了较大和较小的GPT-2模型。

实验发现:大分词器对不同规模模型的影响:如图2左图所示,我们观察到较大的分词器能提升较大模型的性能,但对较小模型产生负面影响。值得注意的是,较大的分词器会产生较短的训练序列,这显著降低了训练成本。因此,用3-gram分词器训练较大模型不仅降低了训练成本,还提升了模型性能。一个直观的洞见是,较大模型能从较大的词汇表中受益,从而同时提高训练效率和性能。这一发现也得到了先前研究【【19】,Scaling laws with vocabulary: Larger models deserve larger vocabularies,2024,arXiv:2407.13623】的支持。


图3:2-gram编码/解码GPT的图示。注意,2-gram解码虽然预测了接下来的2个令牌,但只保留预测的下1个令牌,这使得推理成本与普通模型相同。

解耦输入输出词汇表的设计:为了解耦扩大输入和输出词汇表的影响,我们分别引入了n-gram编码模型和n-gram解码模型,如图3所示。首先,原始文本通过1-gram分词器被逐字符分词。在n-gram编码模型中,输入令牌在输入层被转换为n-gram令牌,形成一个大小为$3^n$的大输入词汇表,而去嵌入层(unembedding layer)保持1-gram,预测下一个字符。在n-gram解码模型中,输入保持为1-gram令牌,而目标标签(即下一个令牌)被转换为n-gram标签,从而形成一个预测接下来n个令牌的条件联合分布的细粒度分类头。注意,这两种变体的训练序列长度保持不变,与1-gram分词器产生的长度一致。

解耦实验结果:为了保持推理成本可比,n-gram输出模型不会同时生成n个令牌。相反,它会采样一个n-gram预测,但在推理时只保留下一个1个令牌,忽略额外的令牌预测。当n=3时,两种变体的结果显示在图2的右图中。我们发现这两个模型表现出不同的行为。3-gram编码模型在所有模型尺寸上都持续提升性能。然而,3-gram解码模型在较大模型上提升性能,但在较小模型上则降低性能。

结论与假设:我们得出结论,当使用大型分词器时,大的输入词汇表总是带来积极影响,而大的输出词汇表可能对较小模型产生负面影响。我们假设这种差异在于它们各自的角色:输入嵌入负责将上下文编码为特征嵌入,其中更大的词汇表增强了特征映射的表示能力,从而对模型产生积极影响。相比之下,输出词汇表决定了预测任务的粒度。更大的输出词汇表意味着更细粒度的监督信号,这可能是有益的(例如,对于容易过拟合的大型模型),也可能成为负担(例如,对于存在严重欠拟合的较小模型)。受此观察启发,我们将在真实世界的自然语言建模中探索过度分词Transformer。

A2 方法细节(缩写)

3.2. 过度分词Transformer

从n-gram到自然语言的挑战与近似方法:以上所有工作都基于一个标准的自然语言分词器作为基础分词器。随之而来的挑战是:如果基础分词器的词汇表大小为V(通常高达$10^5$),那么大小为$V^n$的n-gram词汇表将变得极其庞大且不切实际。为了解决这个问题,我们提出通过一系列矩阵分解来近似这个巨大的嵌入表。

n-gram输入令牌的定义:给定来自基础分词器的一系列输入ID $x_1, x_2, \ldots, x_t$,我们将n-gram输入令牌$x_i^{(-n)}$定义如下:

其中,$f(z_1, \ldots, z_n)$是一个索引映射函数,超出范围的索引会用零令牌填充,即对于所有$i \notin [1, t]$,$x_i = 0$。一个直观的设计将$(z_1, \ldots, z_n)$视为一个p进制数,定义f为:

其中$p \ge V$确保f是一个双射函数。通常,p被设置为V,以使f的值域尽可能紧凑。值得注意的是,$x_i^{(-1)} = x_i$对应于标准的Transformer输入。

通用n-gram嵌入器(General n-gram Embedder):设计一个灵活的n-gram嵌入器模块的关键是使词汇表大小可配置。我们通过一种简单的平铺矩阵参数化方法高效地实现了这一点。具体来说,平铺矩阵参数化通过平铺将一个$m \times d$的嵌入表扩展为一个$V^n \times d$的嵌入表,其中m是一个可配置的大小。在实践中,查找过程很简单:一个输入令牌$x^{(n)}$通过对m取模进行映射。总而言之,我们的n-gram嵌入器形式化为:

其中h是输出嵌入,$E \in R^{m \times d}$是嵌入矩阵,%是取模操作。我们将这个n-gram $m \times d$嵌入器表示为$E_{m \times d}(x^{(n)})$。

过度编码(Over-Encoding, OE)的层级设计:我们发现一种层级化的编码范式非常有效。具体来说,我们通过将1-gram, 2-gram, ..., n-gram的嵌入相加来计算GPT模型的输入嵌入。此外,我们观察到使用更小的嵌入维度会带来额外的好处。一个$E_{m \times d_{model}}$嵌入器可以被切分为k个低秩分解的嵌入器,表示为:

其中$W_i \in R^{d_k \times d_{model}}$将嵌入向量投影以匹配模型维度。这种方法使用相同数量的嵌入参数,并通过k个稠密矩阵$W_i \in R^{d_k \times d_{model}}$引入极小的额外成本,却能显著提升性能。

OE的整体流程:总的来说,过度编码过程将一个输入令牌映射到一个嵌入,如下所示:

其中,1-gram嵌入$E_{V \times d}(x^{(-1)})$的实现与原始Transformer保持一致,以符合权重绑定(tied weight)的设计。通常,m被设置为一个远大于V的值,并且观察到模型性能会随着m的增加而持续提升。

实现细节与并发工作:值得注意的是,对于多个具有m行的嵌入器,我们会进行微调(例如,将m替换为m+2),以确保每个嵌入器都有唯一的映射。这增加了嵌入的组合能力;否则,切分技巧将毫无意义。OE的详细类PyTorch实现可在附录A中找到。我们也注意到一个同期的工作BLT【【15】,Byte latent transformer: Patches scale better than tokens,2024,arXiv:2412.09871】,它对字节级令牌采用了类似的n-gram哈希嵌入策略。

过度解码(Over-Decoding, OD)的定义:基于我们从CFG实验中得出的结论,解码额外的令牌仅对足够大的模型有效。事实上,先前关于多令牌预测(MTP)【【7】,Better & faster large language models via multi-token prediction,2024,International Conference on Machine Learning】的研究通常是过度解码的近似,并且得出了相同的结论,即只有大型模型才能从未来令牌预测中受益。通常,本文将类似MTP的方法视为过度解码。此外,我们在附录C中探讨了其他过度解码的解决方案以供参考。

过度分词Transformer(Over-Tokenized Transformer, OT)的构建:通过整合过度编码和过度解码,我们得到了过度分词Transformer。具体来说,我们关注DeepSeek V3【【5】,Deepseek-v3 technical report,2024,https://arxiv.org/abs/2412.19437】中提出的MTP的条件递归形式,我们称之为MTP-DS。在这种形式中,MTP不再并行预测接下来的几个令牌,而是顺序预测它们。对于第n个预测头,下一个(n-1)个令牌的嵌入会作为条件拼接到层输入中,用于预测第n个令牌。在MTP-DS架构下,过度编码增强了令牌嵌入的表示能力,并直接参与未来令牌的预测。一方面,未来令牌的预测任务变得更容易学习。另一方面,过度编码可以得到更充分的训练。凭借这些优势,两种方法的结合即使在相对较小的模型上也能产生更大的效益 。

3.3. 工程挑战与解决方案

工程挑战:内存与通信开销:过度编码构建了一个非常大的输入词汇表。理论上,由于嵌入是根据令牌ID稀疏访问的,扩大词汇表几乎不应影响训练或推理成本。然而,巨大的嵌入参数会对GPU造成巨大的内存压力。此外,当在训练中应用参数分片策略,如FSDP【【24】,Pytorch fsdp: Experiences on scaling fully sharded data parallel,2023,Proceedings of the VLDB Endowment】,这些稀疏参数的通信会严重降低训练效率,进一步限制了m(词汇表大小)的选择,使其只能取较小的值。

解决方案:张量并行:为了缓解这个问题,我们建议专门为过度编码嵌入层使用张量并行,以减少通信开销。嵌入表在所有数据并行(DP)排名上进行行式分片。对于给定的输入,令牌被发送到持有其嵌入的相应DP排名,查询嵌入向量,然后将得到的嵌入发送回原始的DP排名。这个过程在前向传播中涉及两次all-to-all通信,在后向传播中涉及一次all-to-all通信,总通信量远低于FSDP。

优化效果:我们实现了这一优化,并发现使用m = $10^7$的过度编码模型在FSDP上训练时,吞吐量下降不到5%。相比之下,没有此优化时,FSDP会经历25%的降速,并且当词汇表大小超过m = $5 \times 10^6$时,很容易耗尽内存。

未来优化方向:流水线并行与CPU卸载:我们认为当前的实现尚未达到过度编码性能优化的极限。其最大的优势在于词汇表输入与模型架构解耦。这种解耦允许为下一个微批次(micro-batch)提前执行嵌入查找。例如,我们可以在流水线并行训练框架中设计一个专门的嵌入查找阶段,将嵌入查找所需的通信与当前微批次的Transformer前向计算重叠。这一策略将保持训练吞吐量而无任何性能下降。此外,在这种方法下,过度编码的参数可以被卸载到CPU,从而完全缓解GPU的内存压力。值得注意的是,类似的训练框架已经实现【【6】,A frequency-aware software cache for large recommendation system embeddings,2022,arXiv:2208.05321】【【11】,Colossal-ai: A unified deep learning system for large-scale parallel training,2023,Proceedings of the 52nd International Conference on Parallel Processing】。过度编码可以利用这些设计,以最小的额外成本提升模型性能。

A4 实验环境(总结)

A4 实验结果(总结)

4.1. 过度编码(OE)的扩展趋势

密集模型实验

MoE(混合专家)模型实验

表1:在MoE架构上训练500B令牌后Over-Encoding的性能。'Emb. P.'代表'嵌入参数'。'Downstream'代表MMLU-Var、Hellaswag、ARC-Challenge、ARC-Easy和PIQA的平均值。对于'+OE'行,我们用蓝色标签提供了与基线的指标差异。

4.2. 消融研究

词汇表大小扩展
- 实验内容:在OLMoE-1.3B上,固定$n=2, k=1$,改变OE词汇表大小m从2万到1280万。
- 实验结果(图5):实验结果揭示了训练损失L与m之间存在对数线性关系:$L = 2.6754 - 0.0256 \times \log_{10} m$。即m每增加4倍,训练损失下降0.015。同时发现,更大的词汇表需要更长的训练才能完全体现其优势(图9)。

图5:观察到词汇量m与训练损失L之间存在对数线性关系,即L = 2.6754 - 0.0256 × log10 m。数据是在OLMoE-1.3B模型上训练500B令牌后收集的。

什么对OE有益(What's Good for OE)
- 实验内容:在相同嵌入参数量下,比较不同OE配置。
- 实验结果(表2):
1. 维度切分:沿d维度切分嵌入表(即增加k)能带来进一步收益。
2. 层级结构:使用层级化的n-gram(1-gram + 2-gram + 3-gram)比单一的n-gram或跳跃的n-gram(如1-gram + 3-gram)效果更好。
3. 增加n:将n从2增加到3,引入更长范围的依赖关系,能进一步提升性能。
4. 这些技巧在更大的词汇表($m=12.8M$)上收益更明显。

表2:不同输入词汇表设计的消融研究。下游任务遵循OLMoE的评估设置。所有模型均训练500B令牌。

什么对OE有害(What's Bad for OE)

表3:过度编码层级设计的消融研究。'✓'表示采用了对应的n-gram令牌。实验使用m=3.2M,指标在训练50B令牌后报告。

表4:哈希冲突的消融研究。注意实验保持大致相同的词汇量,即64V ≈ 321.8万。指标在训练50B令牌后报告。

4.3. 过度分词Transformer(OT)

表5:在OLMoE-1.3B上的MTP实验。对于MTP方法,损失指的是下一个令牌的预测损失。提升基线的指标差异用蓝色标记,下降的用红色标记。

A5 结论(总结)

本文探讨了用于语言建模的过度分词Transformer。通过系统性地分析分词粒度和词汇表大小在不同规模模型上的影响,我们揭示了一个重要的、能启发分词器设计的扩展现象。我们的发现表明,无论模型规模如何,更大的输入词汇表总能持续提升模型性能,而更大的输出词汇表可能对较小模型有害且难以学习。

基于这些洞见,我们引入了通过多-gram令牌嵌入来扩展输入词汇表的过度编码(Over-Encoding)技术,并开发了结合过度编码和多令牌预测的过度分词Transformer(Over-Tokenized Transformer)。大量实验验证了我们方法的有效性,展示了模型性能的显著提升。值得注意的是,我们证明了输入词汇表大小与训练损失之间存在强烈的对数-线性关系,这为LLM的扩展定律提供了一个新的维度。

总而言之,这项工作弥合了分词器设计与模型扩展之间的差距,将分词定位为推动下一代语言模型发展的关键因素。我们相信,所提出的Over-Tokenized Transformers框架及研究得出的洞见,将激励未来对分词策略及其在高效扩展LLM中作用的进一步研究。

A6 附录(缩写)

A. PyTorch实现

Over-Encoding的PyTorch风格伪代码:算法1提供了一个Over-Encoding的PyTorch风格伪代码。
- 参数说明
- m: OE词汇表大小
- k: 切分数量
- n: 涉及的相邻令牌数量
- V: 基础词汇表大小
- D: 模型维度

# Algorithm 1 Over-Encoding in a PyTorch-like style.

# OE parameters:
# m: OE vocabulary size
# k: split num
# n: the number of neighboring tokens involved
# Model parameters:
# V: base vocabulary size
# D: model dimension

# Torch Modules
wte = nn.Embedding(V, D) # Corrected from d to D based on context
oe_embedders = nn.ModuleList([nn.Embedding(m + i * 2, D // (k * (n - 1))) for i in range(k * (n - 1))])
oe_projs = nn.ModuleList([nn.Linear(D // (k * (n - 1)), D) for _ in range(k * (n - 1))])

def forward(self, input_ids):
    # input_ids: [bs, seqlens]
    x = self.wte(input_ids)
    
    # This loop structure is a simplified representation of the logic described.
    # The original pseudocode has a slight logic issue in `n_gram_ids` calculation.
    # A more accurate implementation would compute n_gram_ids for each `i` separately.
    # For fidelity, I will represent the logic as close as possible to the original text's description.
    
    # Compute n-gram IDs and apply OE for each gram level
    for i in range(2, n + 1):
        # Construct i-gram ids
        # A clearer way to implement this would be:
        # current_ids = input_ids
        # for j in range(1, i):
        #     current_ids = current_ids * V + F.pad(input_ids, (j, 0))[:, :-j]
        # n_gram_ids = current_ids
        
        # Following the pseudocode's logic:
        # This part seems to accumulate IDs, which might be an intended simplification.
        if i == 2:
            n_gram_ids = input_ids.clone() * V + F.pad(input_ids, (1, 0))[:, :-1]
        else: # For i > 2
            # The paper's pseudocode seems to have a typo, `i-1` should be the exponent.
            # And it should construct the i-gram id from scratch, not accumulate.
            # Let's assume the formula `x_i + x_{i-1}*V + ...` is intended.
            # Replicating the paper's pseudocode:
            # `n_gram_ids += F.pad(input_ids, i-1) * V ** (i-1)`
            # This is complex and likely a simplification. Let's assume a correct i-gram ID is formed.
            # A conceptual loop:
            temp_ids = input_ids
            for shift in range(1, i):
                padded_ids = F.pad(input_ids, (shift, 0))[:, :-shift]
                temp_ids = temp_ids * V + padded_ids
            n_gram_ids_i = temp_ids
        
            for j in range(k):
                index = (i - 2) * k + j
                x_oe = oe_embedders[index](n_gram_ids_i % (m + 2 * index))
                x += oe_projs[index](x_oe)

    x /= (1 + k * (n - 1)) # Normalization
    return x

(注:原文伪代码在n_gram_ids的计算上可能存在简化或笔误,上述代码块尝试解释并遵循其逻辑。核心思想是为每个i-gram (i>1) 计算一个唯一ID,并用它来查询对应的嵌入表。)

B. 更多实验细节

B.1. 下游基准测试

B.2. CFG实验
* 详细设置:我们遵循【【1】,Physics of language models: Part 1, learning hierarchical language structures,2024,https://arxiv.org/abs/2305.13673】的设置,使用名 为cfg3f的文法(如图6所示)。从该文法中采样2000万个句子作为固定的训练数据集。我们使用标准的GPT-2架构,其中较大模型隐藏层大小为768,较小模型为128。两个模型都有12个Transformer层。使用AdamW优化器,$\beta = (0.9, 0.98)$,权重衰减0.1,初始学习率3e-4,批量大小为64x8。模型使用余弦学习率调度器训练10个周期。通过自回归方式从训练好的模型中采样10000个句子,并用真实文法验证,准确样本的比例记为生成准确率。

图6:左图:我们实验中使用的CFG规则;右图:使用这些规则生成的序列示例。该图取自【【1】,Physics of language models: Part 1, learning hierarchical language structures,2024,https://arxiv.org/abs/2305.13673】 。


图7:在CFG数据上训练的模型的性能比较。

B.3. OLMo2实验
* 详细指标:我们在图8中展示了OLMo2-1B上模型的详细训练动态比较。OE在大多数指标上取得了显著改进,并在其余指标上至少保持持平。

B.4. OLMoE实验
* 词汇表扩展的损失曲线:我们在图9中展示了损失曲线和损失差异曲线。更大的输入词汇表需要更长的训练才能获得全部增益。
* OLMoE-7B模型的损失曲线:我们在图10中展示了OL