How Does Critical Batch Size Scale in Pretraining?

A1 主要贡献

本文旨在研究自回归Transformer语言模型预训练中关键批量大小(Critical Batch Size, CBS)的缩放定律。核心问题是,在扩大模型规模时,如何有效利用数据并行,而CBS是决定并行效率的关键阈值。研究目标是系统性地解耦模型大小(N)和数据大小(D)对CBS的影响,这在以往的研究中常被混淆。

为实现这一目标,研究者提出了一种衡量CBS的指标,并对一系列从8500万到12亿参数的自回归语言模型在C4数据集上进行预训练。通过广泛的超参数搜索和对批量大小、动量、学习率及其调度等因素的仔细控制,文章系统地研究了规模对CBS的影响,并拟合了关于模型和数据大小的缩放定律。

核心发现与创新点:

  1. 实证发现CBS主要随数据量扩展

    • 在模型和数据量按比例(Chinchilla设置)共同扩展时,CBS会增加。
    • 通过控制变量实验发现:当固定模型大小(N)而增加数据大小(D)时,CBS同样显著增加;但当固定数据大小(D)而增加模型大小(N)时,CBS几乎保持不变。这表明CBS的增长主要归因于数据量的增加,而非模型大小。
    • 这一发现意味着,在扩展训练数据时,可以通过增大数据并行度来减少串行训练时间,而不会损失计算效率(以FLOPs衡量)。
    图1:Chinchilla(左)和受控(中、右)设置中优化效率与关键批量大小的缩放。为了研究CBS在不同模型大小下的影响,我们跟踪了达到某个目标验证损失所需的相对步数。在Chinchilla设置中(左),我们保持数据与模型大小之比D/N = CChin为常数,并观察到CBS随规模增大而增加。然而,当控制模型大小(中)或数据大小(右)时,目标损失的增长主要取决于数据大小而非模型大小(第3节)。
    图1:Chinchilla(左)和受控(中、右)设置中优化效率与关键批量大小的缩放。为了研究CBS在不同模型大小下的影响,我们跟踪了达到某个目标验证损失所需的相对步数。在Chinchilla设置中(左),我们保持数据与模型大小之比D/N = CChin为常数,并观察到CBS随规模增大而增加。然而,当控制模型大小(中)或数据大小(右)时,目标损失的增长主要取决于数据大小而非模型大小(第3节)。
  2. 理论解释

    • 无限宽度极限:基于最大更新参数化(Maximal Update Parameterization),理论分析表明,当网络宽度增加到一定程度后,其训练动态和性能变得与模型大小无关。因此,CBS在模型大小超过某个点后将保持不变。
    • 最小二乘回归分析:通过分析带有小批量SGD的最小二乘回归问题,文章为CBS随数据量D增长提供了理论依据。分析表明,在特定条件下,最优CBS为 $B^*(D) = \Theta(D^c)$,其中指数 $c \geq 0$,这证实了CBS随数据量增长。
  3. 方法论贡献

    • 超越固定训练时长的研究策略:为解决在达到目标损失时预定义总训练时长的困难,本文提出使用指数权重平均(EWA)策略。实验证明,该策略性能可与余弦调度等方法相媲美,且无需预设训练时长。
    • 超参数影响分析:文章对Adam优化器的动量、二阶矩衰减率以及Transformer的上下文长度、宽度与深度缩放等常见超参数和配置进行了系统性消融实验,为大规模预训练提供了实践指导。

A2 方法细节

2 实验设计与实证发现

本节描述了实验设置,更多细节请参阅附录D。在本文中,我们使用缩写‘M’代表百万,‘B’代表十亿,‘T’代表万亿。

2.1 实验设置

模型和训练细节。我们训练了一系列上下文长度为512的自回归语言模型(LM),模型大小从85M、151M、302M、604M到1.2B不等(见附录D,表2),训练数据为C4(【索引20,Exploring the limits of transfer learning with a unified text-to-text transformer,2020,Journal of machine learning research】)。我们使用Adam优化器(【索引34,Adam: A method for stochastic optimization,2014,arXiv】),其特定的超参数见表3。我们采用了Eleuther AI的gpt-neox-20b的分词器,词汇表大小为50280。在大多数消融研究中,我们使用小型的151M代理模型来分析超参数。我们将微批量大小设置得比全局批量大小小,并使用梯度累积来模拟大全局批量大小的效果。我们专注于完全同步的分布式数据并行场景,其中通信频繁,这简化了评估并将实际的墙钟时间节省抽象为总优化步数。关于优化器配置和评估策略的更多细节包含在附录D中。

实验设计和纲要。为了研究Chinchilla设置中的CBS,我们需要考虑在留出验证集上的目标损失,并测量达到该损失所需的优化步数。我们将线性缩放区域中的最优批量大小视为$B_{opt}$,它不会带来效率开销,具体细节见附录C。我们考虑对于每个模型大小N和批量大小B,最优批量大小$B_{opt} = 256$在步骤$t_{Chin} = C_{Chin} \times N / (\text{ctx len} \times B)$时的验证损失,其中上下文长度ctx len设为512,CChin是Chinchilla系数。

联合扩展模型和数据规模的挑战与应对策略。当模型大小与数据大小联合扩展时,上述方法意味着每个模型大小都将有不同的目标损失。通过上述程序实现这一目标是具有挑战性的,不仅因为超参数组合众多,还因为每个模型大小和批量大小的训练动态是未知的。下面,我们概述了几个关键方面来达成目标:

  1. 由于我们关注的是达到目标验证损失所需的训练步数,学习率衰减策略通常需要预定义总训练时长(【索引40,Sgdr: Stochastic gradient descent with warm restarts,2022,International Conference on Learning Representations】,【索引28,MiniCPM: Unveiling the potential of small language models with scalable training strategies,2024】,【索引29,Scaling laws and compute-optimal training beyond fixed training durations,2024】,【索引10,The road less scheduled,2024,arXiv】)。为了解决这个问题,我们提出使用指数权重平均(EWA)(【索引51,Acceleration of stochastic approximation by averaging,1992,SIAM journal on control and optimization】)来达到期望的目标验证损失,这是一个简单的方法,能与其他流行选择相媲美(图2)。这使得训练可以超越固定的时长或数据大小,允许从检查点恢复训练,直到达到目标验证损失(第2.2节)。
  2. 使用适当的超参数进行训练:确保对动量和学习率进行适当的扫描(附录A);为每个批量大小量身定制,采用精心调整的$\beta_2$参数和指数权重平均衰减率$\tau$(附录B)。

2.2 超越固定时长的训练以达到目标验证损失

学习率调度器基准测试。在实践中,语言模型通常使用固定的token预算进行训练(【索引27,An empirical analysis of compute-optimal large language model training,2022,Advances in Neural Information Processing Systems】),这决定了训练将经历的总迭代次数。这个训练过程可以很容易地分解为学习率预热和衰减阶段,以便在训练结束时保持较低的学习率,从而实现更好的优化。然而,我们的目标是找到在各种超参数和优化条件下表现最佳的运行。这意味着在选择最大训练时长方面需要做出一个不平凡的决定(【索引10,The road less scheduled,2024,arXiv】,【索引29,Scaling laws and compute-optimal training beyond fixed training durations,2024】)。由于超越固定时长的训练在许多大规模预训练场景中特别有利,我们对最近提出的方法进行了基准测试,如免调度优化器(schedule-free optimizer)(【索引10,The road less scheduled,2024,arXiv】)、余弦调度、预热-稳定-衰减调度(WSD)(【索引28,MiniCPM: Unveiling the potential of small language models with scalable training strategies,2024】)(或梯形调度(【索引75,Scaling vision transformers,2022,Proceedings of the IEEE/CVF conference on computer vision and pattern recognition】))以及我们提出的常数+EWA策略。该策略通过维持模型权重的移动平均$\xi_{t+1} = \tau \cdot \xi_t + (1 - \tau) \cdot \theta_t$来改善优化,其中$\theta_t$是第t步的实际模型参数,我们使用$\xi$进行评估。为了确保基准接近最优,我们首先从我们的常数+EWA运行中获得最优步数,然后通过测试这些步数的[0.1, 0.2, 0.3]倍来评估WSD调度。同时,对于余弦调度器,我们探索了相同步数的[0.9, 1.0, 1.1]倍。我们表明,我们的常数+EWA策略可以匹配余弦调度和WSD的效率,特别是在大批量大小的情况下(图2)。它们之间的联系在(【索引46,Connections between schedulefree optimizers, ademamix, and accelerated sgd variants,2025,arXiv】)中有详细解释。

图2:比较和解释训练动态。在整个研究中,我们采用Constant+EWA,因为它在大批量大小下表现最佳,并且避免了为达到目标损失而预先设定固定的训练时长。
图2:比较和解释训练动态。在整个研究中,我们采用Constant+EWA,因为它在大批量大小下表现最佳,并且避免了为达到目标损失而预先设定固定的训练时长。

EWA在语言模型预训练中的应用价值。先前的研究已经证明了EWA的泛化(【索引30,Averaging weights leads to wider optima and better generalization,2018,arXiv】)和优化(【索引33,Analyzing and improving the training dynamics of diffusion models,2024,Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition】)优势,我们的发现进一步揭示,EWA可以在LM预训练中用内存换取优化效率,特别是在大批量训练场景中。这在必须达到目标损失,但从业者不确定为设置学习率调度而需要确切的最大数据量的情况下非常有用。

关于学习率调度的结论。与不使用EWA的恒定学习率相比,EWA持续提高了模型训练效率。事实证明,EWA是与其他采用衰减方案的基准方法相比的一种有效方法,它提供了有竞争力的性能,同时消除了预定义训练时长的需要。

2.3 模型上下文长度的消融实验

不同上下文长度对训练效率和CBS缩放的影响。我们在所有实验中采用512作为语言模型的上下文长度,但尚不清楚当我们扩大上下文长度时,它将如何影响训练效率以及CBS的缩放是否会变化。因此,我们对几个更大的窗口$2^{10}$、$2^{11}$、$2^{12}$进行了扫描(第2.4节)。总体而言,四种不同上下文长度的所有模型在各种批量大小下的相对优化效率都非常相似,从而证明了我们在所有实验中使用512是合理的。

关于模型上下文长度的结论:不同的上下文长度($2^9 \sim 2^{12}$)在批量大小方面表现出相似的缩放行为。

2.4 模型宽度和深度的消融实验

模型扩展策略:宽度与深度。模型大小通常可以通过两种主要方式进行扩展:一是增加宽度,即增大多层感知机(MLP)的隐藏层大小;二是增加深度,即向网络中添加更多层。由于图1中的主要结果仅涉及一种模型扩展方式(表2),例如,604M模型比151M模型的宽度增加了2倍。为了探索替代的扩展策略,我们研究了当我们将604M模型的深度增加4倍时模型行为的变化(详细配置见表5)。

图3:对于计算最优训练,扩展宽度和深度获得了相似的效率增益。
图3:对于计算最优训练,扩展宽度和深度获得了相似的效率增益。

宽度与深度扩展对CBS的影响。首先,如图3(左、中)所示,在Chinchilla缩放(数据和模型大小成比例增长)下,增加模型深度或宽度对CBS有相似的影响。值得注意的是,根据之前的结果,计算最优缩放中CBS的上升应归因于数据量的增加。然后通过受控比较(图3,右),我们看到使用两种不同的方式将151M模型扩展到604M模型在效率上是等效的,因为两条曲线重叠。我们的发现可能为在固定token预算下(该预算与模型大小成比例分配)扩展模型提供实用的见解。这一点尤其重要,因为扩展模型宽度通常比增加深度更受青睐,因为更宽的模型往往更容易并行化,而不会产生额外的延迟开销(【索引60,Megatron-lm: Training multi-billion parameter language models using model parallelism,2019,arXiv】,【索引64,Llama: Open and ` efficient foundation language models,2023,arXiv】,【索引15,Data movement bottlenecks to large-scale model training: Scaling past 1e28 flop,2024】,【索引45,Gemstones: A model suite for multi-faceted scaling laws,2025,arXiv】)。

图4:使用151M模型在上下文长度上的消融实验结果。
图4:使用151M模型在上下文长度上的消融实验结果。

关于在计算最优场景下扩展Transformer宽度和深度的结论:在计算最优的预训练中,增加宽度和深度对关键批量大小有相似的影响。

3 关键批量大小缩放定律

3.1 关键批量大小的正式定义

CBS的直观定义。回想一下,CBS是一个转折点,在该点将批量大小增加k倍,导致所需训练步数的减少因子小于k。我们现在将CBS定义为导致与线性缩放相比产生20%开销的批量大小。首先,定义$R(N, D, B)$为使用大小为N的模型,在D个token上单次遍历,批量大小为B时可达到的最佳损失。这将通过在保持N、D、B固定的情况下,优化调整所有其他优化器参数来获得。以下是CBS的正式定义:

CBS的数学定义
定义1. 定义$R_{opt}(N, D) = \min_B R(N, D, B)$为通过优化批量大小达到的最小损失,而$B_{opt}(N, D) = \arg \min_B R(N, D, B)$为最优批量大小。我们定义$f_{N,D}(B)$为达到$R_{opt}(N, D)$所需的步数,作为批量大小B的函数。显然$f_{N,D}(B_{opt}) = D/B_{opt}$。为了定义关键批量大小$B^*(N, D)$,我们可以定义一个线性缩放曲线$f^*_{N,D}(B) = D/B$。$f^*$在$B_{opt}$处与f匹配,然后随着批量大小的增加而线性下降。$B^*(N, D)$被定义为满足$f_{N,D}(B') \leq 1.2f^*_{N,D}(B')$的最大批量大小$B' > B_{opt}(N, D)$。

CBS的图示说明。如图5所示,$B^*(N, D)$是这样一个批量大小,在该批量大小下,步数比从最优批量大小线性外推预测的步数高出20%。请注意,这里的20%可以替换为任何其他合适的衡量偏离线性缩放的指标。

图5:关键批量大小的图示,其中 B* = 2^11.87,默认上下文长度为512。
图5:关键批量大小的图示,其中 B* = 2^11.87,默认上下文长度为512。

3.2 针对Chinchilla最优预训练的模型大小缩放定律

大模型与批量大小效率的关系。正如在上述所有结果中所观察到的,对于更大的模型,加倍批量大小使其能更有效地减少达到目标损失所需的相对步数。我们想知道这些增加的效率是否可以通过一个缩放定律来预测。

CBS的推导过程。我们第一步是通过拟合批量大小(B)与达到目标损失所需的绝对步数(Y)的幂律关系$\log(Y) = \log(a+ bB^\alpha)$,然后推导出关键批量大小。接着,我们通过$B^* = \sqrt[\alpha]{\frac{b}{5a} + 1.2B_{opt}}$来推导CBS,这是由一个转折点所隐含的,即在该批量大小下,总数据量会比线性缩放产生20%的开销:$D_{total} = (a + b/B^\alpha_{opt}) \times 1.2B_{opt} = (a + b/B^\alpha) \times B$,其中$\alpha=1$,$B_{opt}$被设定为256,这个值被选定在线性缩放区域内,如附录C所示。我们在附录表6中报告了拟合步数Y和批量大小B之间幂律关系的参数。我们采用固定的$\alpha=1$的解,因为两种策略产生的预测结果几乎相同。

CBS关于模型大小的缩放定律。其次,我们拟合了一个关于模型大小N(单位:百万)的幂律$\log(B^*) = \log(c + dN^\beta)$。常数项c默认设置为0(因为当N=0时,B*应为0),这导出了$B^* = 93.20 \times N^{0.47}$。我们在图1(左)中可视化了拟合曲线,并在附录的表7中报告了更多的预测。

Chinchilla最优训练中CBS的增长趋势。总的来说,我们观察到在计算最优训练中进行扩展时,CBS会增加:在图1(左)中,我们根据Chinchilla步数选择了目标损失,并拟合了关键批量大小相对于模型大小N(单位:百万)的幂律为$B^* = 93.20 \times N^{0.47}$。我们的结果表明,对于在Chinchilla最优的token数量上优化小于1B的模型,大约$2^9$到$2^{11}$的关键批量大小将有助于高效优化,以便研究其他经验性问题。然而,训练的token数量通常与模型的参数数量成比例扩展。因此,尚不清楚CBS的增长是因为(1)模型大小的增加,还是(2)数据大小/训练时长的增加,我们将在下一小节探讨这个问题。

3.3 解耦CBS关于数据大小和模型大小的缩放定律

固定数据大小的对照比较。首先,我们使用151M模型的Chinchilla token大小3.072B来记录每个模型大小的目标验证损失,并再次以较短的训练时长训练所有302M、604M、1.2B模型以达到这些目标损失。为了在用较少token训练时优化性能,我们还相应地调整了预热步骤。图1(右上)显示所有曲线的行为相似,并且我们观察到在扩大模型大小时CBS几乎没有增加。此外,我们随后拟合了关于模型大小的缩放定律,如图1(右下)所示:保持数据大小固定导致了一个弱依赖于模型大小的缩放定律$B^* = 621.341 \times N^{0.087}$。

固定模型大小的对照比较。此外,我们专注于302M模型,通过选择批量大小为256的运行在Chinchilla步数的0.28倍、0.5倍、2倍和4倍处的目标损失来进行额外的实验。这个设置导致了两种欠训练和两种过训练配置。为了在过训练场景中实现最佳性能,我们相应地增加了预热比例,而对于欠训练情况,我们按比例减少了预热比例。图1(中)的结果显示,随着我们增加训练的token数量,我们看到CBS的增加,这与我们观察到的在Chinchilla目标损失上训练大模型的情况相似。这也可以从图1(中)所示的预测CBS曲线中看出,该曲线显示随着我们增加训练的token数量,CBS会增加,这与在Chinchilla设置中模型和数据大小成比例扩展时观察到的情况相似。

模型大小与数据大小对CBS影响的综合分析。我们还将同时扩展N和D(图1,左)以及仅扩展D(图1,右)的结果一并绘制在图6中。在并排比较中,我们观察到以下趋势:(i)在Chinchilla设置中(由图例第一列表示),在不同token数量上训练的各种大小的模型(85M、151M、604M、1.2B)显示,随着规模的增长,关键批量大小增加。(ii)此外,每对相同颜色的曲线显著重叠,表明在相同token数量上训练的不同大小的模型往往具有相似的关键批量大小。(iii)最后,当模型大小保持不变,仅数据大小(图例第二列)变化时,我们也观察到关键批量大小随规模增加。因此,我们可以定性地理解,CBS的增加很可能与模型大小无关,而是由于训练时长的增加。

图6:通过在不同数量的token上训练302M模型,然后与其他在相似数量token上训练的模型大小进行受控比较。相同颜色或位于图例同一行的模型代表了这种比较。对于固定的3.072B token计数,我们测量了每个模型大小在该步骤的目标损失。
图6:通过在不同数量的token上训练302M模型,然后与其他在相似数量token上训练的模型大小进行受控比较。相同颜色或位于图例同一行的模型代表了这种比较。对于固定的3.072B token计数,我们测量了每个模型大小在该步骤的目标损失。

关于关键批量大小缩放定律的结论:基于缩放定律和受控比较,我们得出结论,在Chinchilla最优训练中,CBS的增加更强烈地归因于扩展的数据大小或训练时长,而非模型大小的增加。

A3 背景知识/关键Observation/设计原则

4 关于关键批量大小缩放的理论

我们的实验结果表明,CBS随着数据量的增大而增加,但在扩大模型大小时(几乎)保持不变。我们现在通过理论分析来正式研究这两种情况下的观察结果。

4.1 固定数据大小并扩大模型规模

神经网络无限宽度极限的理论基础。以往的多项研究已经确立了神经网络的无限宽度极限(【索引67,Tensor programs iv: Feature learning in infinite-width neural networks,2021】,【索引3,Self-consistent dynamical field theory of kernel evolution in wide neural networks,2022】)。对于遵循这些极限的初始化和架构,我们可以从理论上断言,在固定的训练时长和批量大小下,神经网络的性能会随着宽度的增加而渐近收敛。正式陈述如下:

定理2:无限宽度下的性能收敛性
定理2. 对于给定批量大小B(或对于梯度下降,即B→∞)、训练迭代次数t、误差容限$\epsilon > 0$、固定的学习率调度和数据顺序,对于任何满足Yang & Hu (2021)中主定理(定理G.4)的网络和初始化,存在一个宽度w,使得对于任何两个宽度为w1, w2 > w的网络M1, M2,有$|R(M1, t) - R(M2, t)| \leq \epsilon$,其中R(M, t)表示网络M在时间t的损失。

定理2的证明。证明源于这样一个事实:当宽度趋于无穷大时,网络的轨迹会趋近一个极限。因此,根据极限的定义,存在一个宽度w,使得任何宽度大于w的两个网络在时间t的损失差异最多为$\epsilon$。

最大更新参数化(μP)与超参数传递。请注意,随着宽度增加而固定学习率调度的假设可能看起来很强,但最近的研究(【索引68,Tensor programs v: Tuning large neural networks via zero-shot hyperparameter transfer,2022,arXiv】)表明,其中一种被称为最大更新参数化(µP)的初始化方法,在宽度上表现出超参数传递的特性。这种初始化方案也因其这一特性最近变得流行,并被许多开源实现所使用(【索引12,Cerebras-gpt: Open compute-optimal language models trained on the cerebras wafer-scale cluster,2023】,【索引13,Btlm-3b-8k: 7b parameter performance in a 3b parameter model,2023】,【索引38,Llm360: Towards fully transparent open-source llms,2023】,【索引28,MiniCPM: Unveiling the potential of small language models with scalable training strategies,2024】)。此外,有研究(【索引68,Tensor programs v: Tuning large neural networks via zero-shot hyperparameter transfer,2022,arXiv】,【索引66,Feature-learning networks are consistent across widths at realistic scales,2023】)通过经验证明,使用µP,网络在实际宽度下开始展现出一致的损失曲线。

理论的普适性。此外,由于上述定理对于固定的批量大小B以及B→∞都成立,我们期望存在一个有限的宽度w,使得上述定理对所有批量大小B都成立。因此,对于固定的训练token数,我们预期关键批量大小在超过某个点后不会随模型宽度扩展。尽管我们主要讨论了扩展模型宽度,但请注意,一些最近的结果也为无限深度的ResNets和transformers建立了此类极限(【索引69,Tensor programs VI: Feature learning in infinite depth neural networks,2024】,【索引5,Infinite limits of multi-head transformer dynamics,2024】,【索引6,Depthwise hyperparameter transfer in residual networks: Dynamics and scaling limit,2024】),因此上述论点也适用于这些网络。

4.2 固定模型大小并扩大数据规模

高斯线性回归问题的设置。我们现在转向研究在指定明确的高斯线性回归问题中,小批量SGD中数据大小的影响。设(x, y)是从一个总体分布中抽取的协变量和响应对。设总体风险和总体分布为

$$\mathcal{R}(\mathbf{w}) := \mathbb{E}(\mathbf{x}^\top \mathbf{w} - y)^2, \quad \mathbf{x} \sim \mathcal{N}(0, \mathbf{H}), \quad y|\mathbf{x} \sim \mathcal{N}(\mathbf{x}^\top \mathbf{w}^*, \sigma^2)$$

其中w是可训练参数,期望是在总体分布上计算的,并且($H, w^*, \sigma^2$)指定了总体分布。给定来自总体分布的D个独立样本${(x_i, y_i)}_{i=1}^D$,我们考虑由小批量SGD给出的估计,

$$\mathbf{w}_{0}=0, \quad \mathbf{w}_{t+1}=\mathbf{w}_{t}-\gamma \frac{1}{B} \sum_{j=t B}^{(t+1) B-1}\left(\mathbf{x}_{j}^{\top} \mathbf{w}_{t}-y_{j}\right) \mathbf{x}_{j}, \quad t=0, \ldots, n-1,$$

其中$\gamma > 0$是恒定学习率,B是批量大小,$n := D/B$是步数,$w_0 = 0$是初始化(不失一般性),输出是迭代的平均值,$\bar{w} := \frac{1}{n} \sum_{t=0}^{n-1} w_t$。那么,以下定理为小批量SGD迭代平均值所达到的超额风险提供了一个紧密的界。

符号约定。我们记$f(D) \lesssim g(D)$,如果存在一个正常数c使得对于每个$D \geq 1$都有$f(D) \leq cg(D)$。我们记$f(D) \approx g(D)$,如果$f(D) \lesssim g(D) \lesssim f(D)$。所有证明都推迟到附录G。

定理3:小批量SGD的超额风险界
定理3. 令$(\lambda_i)_{i>0}$为H的非增排序的特征值。假设$\|w_0 - w^*\|_H^2 \lesssim \sigma^2$。那么对于每个$\gamma \lesssim \min\{B/\text{tr}(H), 1/\|H\|_2\}$,我们有

$$\mathbb{E} \mathcal{R}(\bar{\mathbf{w}})-\sigma^{2} \simeq\left(\frac{B}{D \gamma}\right)^{2}\left\|\mathbf{w}_{0}-\mathbf{w}^{*}\right\|_{\mathbf{H}_{0: k^{*}}^{-1}}^{2}+\left\|\mathbf{w}_{0}-\mathbf{w}^{*}\right\|_{\mathbf{H}_{k^{*}: \infty}}^{2}+\sigma^{2} \frac{k^{*}+(D \gamma / B)^{2} \sum_{i>k^{*}} \lambda_{i}^{2}}{D},$$

其中$k^* := \max\{k : \lambda_k \geq B/(D\gamma)\}$,期望是关于$\bar{w}$的随机性计算的。

定理3的背景与意义。定理3的证明受到Zou等人(2023)的启发。为简单起见,我们关注于明确指定的高斯数据分布,但这可以放宽到遵循Zou等人(2023)结果的四阶矩条件下的模型误设情况。定理3表明,对于固定的数据大小D,超额风险仅通过比率$\gamma/B$依赖于批量大小B和学习率$\gamma$。此外,较大的$\gamma/B$倾向于减小偏差误差(依赖于$w^*$的项),但会增加方差误差(依赖于$\sigma^2$的项),反之亦然。这一观察在下面的推论中被利用,我们计算了在不牺牲所达到超额风险率的情况下最小化顺序运行时间的CBS。

推论2:最优超参数与CBS
推论2. 在定理3的设定下,额外假设$\sigma^2 \approx 1$以及以下容量和源条件:

$$for\ a, b > 1: \quad \lambda_{i} \simeq i^{-a}, \quad \mathbb{E} \lambda_{i} \langle \mathbf{v}_{i}, \mathbf{w}_{i}^{*} \rangle^{2} \simeq i^{-b}, \quad \mathbb{E} \lambda_{i} \langle \mathbf{v}_{i}, \mathbf{w}_{i}^{*} \rangle \langle \mathbf{v}_{j}, \mathbf{w}_{j}^{*} \rangle = 0 \ for\ i \neq j,$$

其中$(\lambda_i, v_i)_{i>0}$是H的特征值和相应的特征向量,期望是关于$w^*$的先验计算的。那么我们有

  1. 当$b \leq a$时,最优超参数(使期望超额风险最小化到常数因子)是$\gamma^* \approx 1$和$B^* = 1$。
  2. 当$b > a$时,最优超参数是$\gamma^*$和$B^*$,使得

$$0<\gamma^{*} \lesssim 1, \quad 1 \leq B^{*} \leq D, \quad \gamma^{*} / B^{*} \simeq D^{\frac{a}{\min \{b, 2 a+1\}}-1} .$$

因此,CBS是$B^* \approx D^{1-a/\min\{b,2a+1\}}$,这(连同$\gamma^* \approx 1$)允许小批量SGD输出$\bar{w}$在数据大小D增长时以最小的步数n达到期望超额风险的最优速率。

推论2的解释。容量和源条件来自非参数线性回归文献(【索引7,Optimal rates for the regularized least-squares algorithm,2007,Foundations of Computational Mathematics】),最近被用于研究缩放定律理论(【索引4,A dynamical model of neural scaling laws,2024,Forty-first International Conference on Machine Learning】,【索引37,Scaling laws in linear regression: Compute, parameters, and data,2024,The Thirty-eighth Annual Conference on Neural Information Processing Systems】,【索引48,4+ 3 phases of computeoptimal neural scaling laws,2024,The Thirty-eighth Annual Conference on Neural Information Processing Systems】)。根据推论2,当$b \leq a$时,偏差误差倾向于主导方差误差,在这种情况下,CBS是$B^*=1$以允许最大数量的优化步骤。当$b>a$时,方差误差倾向于主导偏差误差,并且批量大小和学习率的最优选择平衡了这两个误差。虽然可以使用$B^*=1$和一个小的$\gamma^*$来达到最佳的超额风险率,但这会导致次优的顺序运行时间($n=D/B$)。在这种情况下,CBS是$B^* \approx D^{1-a/\min\{b,2a+1\}}$,它在最小化顺序运行时间的同时达到了最优的超额风险率。

关于数据和模型规模扩展时批量大小缩放的理论总结
- 当我们在保持数据大小固定的情况下扩展模型大小时,µP理论表明,关键批量大小在超过某个点后不会随模型宽度扩展。
- 固定模型大小,关键批量大小随训练时长的增加而增加。在高维线性回归的背景下,当方差误差主导偏差误差时,可以为小批量SGD选择一个大的批量大小(作为数据大小的函数),从而在不损害超额风险最小化速率的情况下减少顺序运行时间。

5 相关工作

缩放定律。缩放定律描述了训练神经网络中关键因素之间的参数关系:模型大小N,数据集大小D,训练成本C,以及最终训练损失R。这些定律使得能够根据可用资源预测训练损失R,从而可以优化资源分配以实现高效的模型训练。例如,Hestness等人(2017)发现$R \propto D^{-\alpha}$,其中$\alpha \in [0.07, 0.35]$。在他们改变的因素中,只有任务能改变指数$\alpha$。改变架构、优化器、正则化器和损失函数只会改变比例因子,而不会改变指数;Henighan等人(2020)研究了N、D、C、R在广泛取值范围内的统计关系,并发现了相似的缩放定律,范围涵盖$N \in [10^3, 10^9]$,$C \in [10^{12}, 10^{21}]$,以及多种模态(文本、视频、图像、文本到图像等)。(Kaplan等人,2020)指出N的扩展速度应快于D。然而,Chinchilla缩放(Hoffmann等人,2022a)发现模型训练不足,并建议在给定增加的预算(以FLOPs计)时,为达到计算最优,模型大小N和数据大小D应以大致相等的比例扩展。最近的努力(Pearce & Song, 2024; Besiroglu et al., 2024; Porian et al., 2024)致力于复现(Hoffmann等人,2022a)和(Kaplan等人,2020)的缩放定律。与我们关注衡量CBS效率概念不同,他们大多数关注在给定固定计算预算FLOPs ≈ 6ND的情况下,从小型训练中推导出包括学习率和批量大小在内的最优超参数(Bi et al., 2024; Porian et al., 2024),而没有解耦模型大小和数据大小的影响。

优化与关键批量大小。先前的研究表明,在小规模场景下,增加批量大小可以通过相应调整学习率来抵消(【索引44,An empirical model of large-batch training,2018,arXiv】,【索引76,Which algorithmic choices matter at which batch sizes? insights from a noisy quadratic model,2019,Advances in neural information processing systems】,【索引32,Scaling laws for neural language models,2020,arXiv】,【索引36,On the validity of modeling sgd with stochastic differential equations (sdes),2021,Advances in Neural Information Processing Systems】)。McCandlish等人(2018)引入了梯度噪声尺度,这是一个捕捉不同训练样本间梯度变化的度量,有助于预测关键批量大小(CBS)。他们的发现也表明,小批量训练在计算上更高效,而大批量训练需要更少的优化器步骤。基于动量的方法将缩放扩展到更大的批量大小,但在较小的批量大小时其性能收敛于标准SGD(Shallue等人,2019)。此外,Zhang等人(2019)使用一个带噪声的二次模型分析了曲率对CBS的影响,证明了预处理技术可以增加CBS。Golmant等人(2018)表明,与模型架构和数据复杂性等因素相比,数据集的大小在决定训练效率方面扮演着较小的角色。相反,Hilton等人(2022)研究了如何在较小的批量大小下保持性能。同时,Smith等人(2017);Smith & Le(2017)凭经验研究了最优学习率如何根据动量和训练集大小变化。理论工作进一步试图通过分析最小二乘线性回归中SGD的行为来刻画CBS,特别是在过参数化设置中(Jain等人,2018;Ma等人,2018)。Filatov等人(2024)同时发现最优批量大小和CBS随数据大小扩展。然而,他们没有探讨CBS如何随超过3.54亿参数的模型大小扩展,也没有提供理论依据或解决在广泛超参数范围内选择最优运行的挑战。我们的工作通过形式化CBS并量化其相对于数据大小的增长,以及强调常见超参数选择的重要性,推动了优化文献的发展。它还为研究超越固定训练时长的大规模预训练提供了策略。

A4 实验环境

A4 实验结果

本研究通过一系列实验,系统地分析了关键批量大小(CBS)的缩放行为,核心结论是CBS主要受数据量而非模型大小的影响。

  1. 主要发现:CBS随数据量而非模型大小扩展

    • 在模型大小(N)和数据量(D)按比例共同增加的Chinchilla设置下,CBS显著增长。拟合的缩放定律为 $B^* = 93.20 \times N^{0.47}$(图1左)。
    • 通过控制变量法解耦了N和D的影响:

      • 固定数据量,改变模型大小:当所有模型(85M至1.2B)在相同的数据量(3.07B tokens)上训练时,CBS几乎保持不变。此时的缩放定律为 $B^* = 621.341 \times N^{0.087}$,指数接近0,表明CBS对模型大小依赖性很弱(图1右)。
      • 固定模型大小,改变数据量:当固定模型(302M)而增加训练数据量时,CBS的增长趋势与Chinchilla设置非常相似(图1中)。
    • 这些结果共同证实,CBS的增长主要由训练数据量的增加驱动。

  2. 学习率调度策略的有效性

    • 本文提出的“恒定学习率+指数权重平均(EWA)”策略,在达到目标损失方面,其效率与经过精心调优的余弦(Cosine)调度和预热-稳定-衰减(WSD)调度相当,尤其在大批量大小下表现优异。此策略的优势在于无需预先设定总训练步数(图2)。
  3. 模型架构的消融研究

    • 上下文长度:将上下文长度从512增加到4096,对不同批量大小下的相对优化效率影响甚微,证明了使用512作为默认上下文长度的合理性(图4)。
    • 宽度 vs. 深度:将模型从151M扩展到604M时,无论是通过增加宽度还是增加深度,对CBS的影响都非常相似。这表明在计算最优的训练设置下,这两种模型扩展方式在优化效率上是等效的(图3)。
  4. 优化器超参数的影响

    • Adam动量$\beta_1$:实验表明,较大的动量值(如0.95)对于提高训练效率至关重要,而无动量或动量过小/过大都会损害性能(附录图12)。
    • Adam二阶矩衰减率$\beta_2$:$\beta_2$的最优值依赖于训练时长和批量大小。对于短期、大批量训练,较小的$\beta_2$(如0.95)更优;而对于长期训练或小批量训练,较大的$\beta_2$(如0.99或更高)能显著改善优化效果(附录图8)。
  5. 线性缩放区域的验证

    • 实验确认,在较小的批量大小范围内(例如$2^6$到$2^{10}$),所有模型都表现出近似线性的缩放行为,即批量大小加倍,达到目标损失的步数减半(附录图13)。这验证了实验中用于定义CBS基准的$B_{opt}$选择的合理性。

A5 结论

总而言之,本研究对大规模自回归语言模型预训练中关键批量大小(CBS)的缩放定律进行了广泛的考察。通过系统分析模型大小、数据大小和CBS之间的关系,我们发现,虽然CBS随数据大小的增加而增加,但它对模型大小相对不敏感。这一发现表明,在更多数据上进行训练可能允许在预训练中实现更大的数据并行性。我们进一步强调了关键超参数和指数权重平均的作用,后者可以在不需要固定训练时长的情况下匹配余弦调度的性能。这些见解为在资源受限的情况下扩展模型同时保持效率提供了实用的策略。

A6 附录

A 训练动态

一种设置预热步数的简单策略。为了进一步证明关键批量大小确实存在,并且大批量大小的饱和并非由于未使用适当超参数进行良好训练的假象,我们同样考虑了训练中的预热比例:

预热步数比例的消融实验。我们对预热步数比例(即学习率从零线性增加所需的训练步数分数)进行了扫描,范围为0.25和0.1,发现0.25对85M模型效果最佳。因此,我们将此预热步数固定为未来实验中$t_{Chin}$的0.25。对于151M模型,我们扫描了{0.15, 0.25, 0.35}的预热步数分数。我们在图7中显示,使用0.25的预热比例是一个合理的设计选择,因为它始终比0.15表现更好,且仅略逊于0.35。在我们发现根据这一启发式方法设置预热步数后,我们将其比例应用于所有其他模型大小。这一策略在(【索引52,Resolving discrepancies in compute-optimal scaling of language models,2024,arXiv】)中也显示出其有效性。

图7:在线性学习率预热阶段使用的大批量大小的预热步数消融实验。
图7:在线性学习率预热阶段使用的大批量大小的预热步数消融实验。

审视训练的最后阶段。通过仔细审视训练过程的最后阶段(图8),可以明显看出,应用指数加权平均(EWA)有助于平滑噪声,使优化能更有效地收敛到目标损失。例如,即使对于一个1.2B模型和1024的中等批量大小,也需要一个非常高的EWA衰减率。此外,我们观察到优化过程受训练最后阶段的显著影响。例如,到第10,000步时,大多数运行的验证损失都低于3.2(图8a),同样,到第30,000步时,损失低于2.8(图8b)。然而,为了达到2.736的目标损失,最佳运行和次佳运行之间的差距大幅增加,最佳运行所需步数减少了5,000多步。

图8:足够大的EWA衰减率τ和Adam β2对于长时间训练至关重要。我们绘制了1.2B模型的评估曲线,因为在Chinchilla设置中,我们按模型大小成比例地扩展数据大小。当增加训练token数量时,仔细设置β2和τ的适当值以有效考虑效率至关重要。
图8:足够大的EWA衰减率τ和Adam β2对于长时间训练至关重要。我们绘制了1.2B模型的评估曲线,因为在Chinchilla设置中,我们按模型大小成比例地扩展数据大小。当增加训练token数量时,仔细设置β2和τ的适当值以有效考虑效率至关重要。

其他批量大小的调度器比较。在图9中,我们包含了更多关于不同调度器的比较,这些调度器在正文(图1)中有所报告。总的来说,我们的Constant+EWA与余弦调度表现相当,并且优于WSD调度,尤其是在大批量大小的情况下。请注意,我们对WSD调度的衰减步骤进行了扫描,其值为总训练步骤的0.1、0.2、0.3倍。我们通过对各种最大优化步骤进行扫描来调整余弦调度以确定最优值,然后使用该步数重新运行训练。这种方法确保模型在训练接近结束时达到目标损失,从而优化学习率衰减的性能。对于免调度优化器,我们调整了$\beta_1$为0.9、0.95、0.98。在小批量大小下,免调度优化器(【索引10,The road less scheduled,2024,arXiv】)是一个有竞争力的基准,但对于大于1024的批量大小,其表现明显更差。

图9:不同批量大小的调度器比较。所有模型大小均为151M。
图9:不同批量大小的调度器比较。所有模型大小均为151M。

更长的训练需要更高的EWA衰减率τ。在整篇论文中,我们在大多数实验中采用0.00316的学习率,但尚不清楚这是否会是次优的,特别是因为更长时间的训练可能需要更低的学习率,正如(【索引9,Deepseek llm: Scaling open-source language models with longtermism,2024】)所建议的。因此,我们通过在一些151M模型上进行以下实验,来证明我们在不同训练时长上调整EWA衰减率以模拟学习率衰减的设计决策是合理的:(a) 批量大小256,0.5倍Chinchilla tokens;(b) 批量大小256,20倍Chinchilla tokens;(c) 批量大小2048,20倍Chinchilla tokens,所有实验的学习率都在{0.00316, 0.00158, 0.01264, 0.00632, 0.00075}中扫描,EWA衰减率τ在{0.99, 0.9968, 0.999, 0.99968, 0.9999}中扫描。我们设置(a)的预热步数为总步数的0.25,(b)和(c)为0.05。图10中的结果表示训练结束时的验证损失,它一致显示,在每组实验中,我们在整篇论文中使用的0.00316的学习率始终是最好的。此外,当将训练数据大小从0.5倍Chinchilla tokens增加到20倍时,最优的EWA衰减率值τ也会增加。这在图8和表4的结果中也得到了证实,表明更长的训练时长可能受益于更高的EWA衰减率以提高优化性能。

图10:学习率和EWA衰减率在不同训练时长下的影响。我们报告了每个超参数组合在训练结束时的验证损失。N表示模型大小。最佳损失由一个符号标记。如(b)和(c)所示,更长的训练时长需要给定学习率下更高的EWA衰减率。
图10:学习率和EWA衰减率在不同训练时长下的影响。我们报告了每个超参数组合在训练结束时的验证损失。N表示模型大小。最佳损失由一个符号标记。如(b)和(c)所示,更长的训练时长需要给定学习率下更高的EWA衰减率。

权重衰减的影响。虽然权重衰减(WD)对预训练没有提供泛化优势,但以往的研究表明它可能改善语言模型训练的收敛性(【索引27,An empirical analysis of compute-optimal large language model training,2022,Advances in Neural Information Processing Systems】,【索引35,Rotational equilibrium: How weight decay balances learning across neural networks,2023,arXiv】)。我们在PyTorch中采用了默认的解耦权重衰减(【索引39,Decoupled weight decay regularization,2017,arXiv】)实现,并对学习率为0.01时权重衰减率{0.01, 0.0316, 0.1}和学习率为0.00316时{0.0316, 0.1, 0.316}进行了扫描。我们在图11中显示,对于带EWA的恒定学习率,虽然权重衰减提供了轻微的性能提升,但它对关键批量大小的影响很小,而这仍然是我们的主要关注点。因此,我们在整篇论文中禁用了权重衰减。

B 研究Adam优化器的额外消融研究

在整篇论文中,我们采用Adam作为大规模模型训练的默认优化器。

在本节中,我们关注两个显著影响优化效率的关键超参数,并详细考察它们的影响。

Adam动量$\beta_1$对CBS的影响。我们对Adam中的几个动量$\beta_1$值在所有学习率和批量大小下进行了扫描:[0, 0.8, 0.9, 0.95, 0.975]。总的来说,图12显示语言模型预训练可能需要一个较大的动量值才能高效,并且对于批量大小而言,$\beta_1=0.95$比0.9略好(在评估损失上有<0.02的增益)。我们观察到,在像26这样的小批量大小范围内,带与不带$\beta_1$动量优化的性能差距很小,而随着我们加倍批量大小,这个差距会增加(【索引59,Measuring the effects of data parallelism on neural network training,2019,Journal of Machine Learning Research】)。此外,我们表明动量0.9和0.975对于达到目标验证损失所需的步数和关键批量大小有相似的影响。另一方面,小动量,尤其是不使用动量,会损害优化。这与动量在带动量的SGD中被广泛研究的加速效果非常吻合(【索引17,Why momentum really works,2017,Distill】)。

图11:有无权重衰减的效率比较
图11:有无权重衰减的效率比较

目标验证损失3.3
图12:动量的消融实验结果。所有数据点均使用151M模型进行训练,并报告了达到固定目标损失所需的总优化步数。

Adam中二阶矩衰减率$\beta_2$的影响。如附录表4所报告,我们发现Adam中的$\beta_2$,即用于平滑模型更新的梯度二阶动量估计的指数衰减率,对于小批量大小的训练也有显著影响。这可能是因为小批量训练中的梯度更稀疏。具体来说,我们对所有模型大小和批量大小在[64, 128, 256, 512]范围内对$\beta_2 \in [0.95, 0.99, 0.999]$进行了消融实验。我们发现,以往工作中为数百万tokens批量大小训练设置的默认值0.95可能是次优的(【索引62,Using deepspeed and megatron to train megatron-turing nlg 530b, a large-scale generative language model,2022,arXiv】,【索引67,Small-scale proxies for large-scale transformer training instabilities,2023】,【索引22,Olmo: Accelerating the science of language models,2024,arXiv】)。对于大批量大小[1024, 2048, 4096, 8192],我们用151M模型大小实验了一个较小的$\beta_2=0.9$,发现它比我们选择的默认值0.95要差。当用更长的时长训练一个更大的模型时(例如附录图8b中的Chinchilla设置),一个足够高的$\beta_2$是必要的。

关于Adam优化器的结论

C 包含小批量大小的结果

小批量范围内的线性缩放行为。为了完整性,我们展示了所有模型大小在小批量范围内的线性缩放行为(图13)。这表明所有模型在批量大小从$2^6$到$2^{10}$的范围内都表现出线性缩放(有合理的偏差),其中批量大小加倍大致使达到目标验证损失所需的步数减半,该目标损失由批量大小为256的最优运行在Chinchilla步数确定。

图13:线性缩放区域:加倍批量大小可以将达到目标损失的优化步骤减半。
图13:线性缩放区域:加倍批量大小可以将达到目标损失的优化步骤减半。

包含极小批量大小的完整结果。此外,我们包含了包含最小几个批量大小的所有结果(图14)。注意,现在的分母是批量大小为64时达到目标损失的步数,而不是256。现在我们可以观察到所有模型大小的清晰线性缩放直到大约$2^{10}$,而最大的三个模型大小的线性缩放几乎维持到$2^{11}$。由于在像64这样非常小的批量大小下进行优化的困难,与图1中的主图存在微小差异,但这不影响我们想要传达的结论和要点。由于我们的重点主要在大批量大小上,我们在正文中始终使用$2^9$作为起始批量大小。此外,回想一下我们在选择目标损失时设置了$B_{opt}=2^8$。图13和图14证实了$2^8$处于线性缩放区域内,这证明了我们的设计选择是合理的。

图14:Chinchilla设置中不同模型的完整结果。我们既包括了最大批量大小2^14,也从几个小批量大小2^6, 2^7, 2^8开始绘图。报告了相对于批量大小2^6的步数。
图14:Chinchilla设置中不同模型的完整结果。我们既包括了最大批量大小2^14,也从几个小批量大小2^6, 2^7, 2^8开始绘图。报告了相对于批量大小2^6的步数。

D 实验的额外细节

优化器设置。对于优化器,我们尝试了SGD(【索引54,A stochastic approximation method,1951,The annals of mathematical statistics】)和Adam(【索引34,Adam: A method for stochastic optimization,2014,arXiv】),发现没有动量的SGD效果明显更差,因此我们在所有实验中仅使用Adam。我们在Adam中禁用了权重衰减,因为我们观察到它对关键批量大小没有显著影响(图11)。为了一般性,尽管训练集C4可能包含低质量或重复的文档,这可能导致训练不稳定(【索引47,Scaling data-constrained language models,2023,arXiv】,【索引67,Small-scale proxies for large-scale transformer training instabilities,2023】),我们观察到这些问题并未影响我们的主要关注目标——即最终的优化效率。因此,我们没有明确采用额外的归一化,如QK归一化(【索引11,Scaling vision transformers to 22 billion parameters,2023,International Conference on Machine Learning】,【索引75,Stabilizing transformer training by preventing attention entropy collapse,2023,Proceedings of the 40th International Conference on Machine Learning】)或z-loss(【索引8,Palm: Scaling language modeling with pathways,2022】)来减轻损失尖峰。我们默认设置$\epsilon$为1e-8,并在整篇论文中默认将Adam中的动量称为$\beta_1$。

用于确定目标损失的Chinchilla步数(批量大小为256)。对于每个模型大小,我们的目标是通过在Chinchilla最优的token数量上使用256的全局批量大小进行训练来建立一个目标验证损失。鉴于整个实验中使用的上下文长度为512,我们可以根据下表1确定所需的训练步数。我们使用约20.34的token与模型大小比率$C_{Chin}$来研究加倍批量大小的减半效应,并观察其对关键批量大小的影响。

表1:确定各模型大小目标损失的Chinchilla步数。
表1:确定各模型大小目标损失的Chinchilla步数。

评估数据大小和频率。为了确保在留出的C4验证集上频繁进行模型评估,保持可靠性和效率之间的平衡至关重要。较大的评估集大小可以提供更稳定和可靠的性能指标,但它也必须是高效的,以在每次运行中保持实用性。使用151M模型,我们评估了不同token数量下的方差:327,680个token为2.17e-4,1,638,400个token为4.53e-5,3,276,800个token为7.65e-6。基于这些结果,我们将默认的评估批次数设置为100。需要注意的是,不同批量大小的总训练步数是不同的。为了解决这个问题,我们实现了一个混合评估协议:模型在$2^i$(其中$i \in Z$)的间隔、每1,000步以及在总步数n的最后30%期间的0.7n, 0.75n, 0.8n, ..., 直到n时进行评估。这种方法确保在训练结束时进行更频繁的评估,从而更准确地评估达到目标评估损失所需的总训练步数。

超参数搜索细节。由于计算限制,我们无法对所有超参数配置进行详尽搜索。相反,正如正文中的消融研究所建议的,我们通过训练较小的代理模型(151M参数)来获得对超参数的洞察。我们按顺序优化以下超参数:学习率、动量($\beta_1$)、预热步数、调度器和上下文长度。此外,我们为每个模型大小和批量大小调整$\beta_2$和$\tau$。具体来说,对于大批量大小(>1024),较小的$\beta_2$和较大的$\tau$往往更有效,而对于较小的批量大小则相反,这与(【索引52,Resolving discrepancies in compute-optimal scaling of language models,2024,arXiv】,【索引77,Adam can converge without any modification on update rules,2022,Neural Information Processing Systems】)中的发现一致。

超参数与模型配置表。下面我们展示了我们主要图中用于研究CBS与模型大小关系的超参数选择(表3)和最优选择(表4)。此外,表5展示了各种模型大小配置和扩展方法,其中加粗的模型表示在我们的受控实验中使用的模型。

表2:模型架构细节。
表2:模型架构细节。

表3:扫描实验设置。超参数搜索后的默认值以粗体显示。粗体表示无需大量调整即可紧密复现我们结果的默认超参数。非粗体意味着对每个模型规模进行了全面扫描。括号中的值并未用于每次扫描:对于151M模型,我们测试了3.16e-4和1e-2的学习率,但发现使用EWA时,这些值比3.16e-3表现更差。EWA衰减率0.99995仅用于长时间的1.2B运行。
表3:扫描实验设置。

表4:不同模型大小的最优超参数。最优指的是达到目标验证损失所需的步数。我们称Adam中二阶矩估计的指数衰减率为β2,EWA中插值参数为τ (ξt+1 = τ · ξt + (1 - τ) · θt)。所有最优运行均使用动量β1 = 0.95和学习率3.16e-3进行训练。
表4:不同模型大小的最优超参数。

表5:关于深度和宽度缩放的消融研究的模型架构。仅使用加粗突出的模型,因为它们在模型大小方面更具可比性。
表5:关于深度和宽度缩放的消融研究的模型架构。

E 缩放定律的额外细节

步数与批量大小的幂律关系。我们首先展示了达到目标损失所需的优化步数与批量大小之间的拟合幂律关系(表6)。所有结果都是通过scipy.optimize.fsolve使用默认超参数求解第3.2节中的方程得到的。

更大规模的CBS预测结果。我们报告了各种模型和token大小的预测结果,超出了正文中呈现的图表范围(表7)。对于每一行,增加模型大小或token大小都显示预测结果保持可比性。

CBS定义与现有工作的关联。请注意,我们对CBS及其缩放定律的定义与(【索引44,An empirical model of large-batch training,2018,arXiv】,【索引32,Scaling laws for neural language models,2020,arXiv】)中的$E_{min}S_{min}$有相似的解释,$S_{min}$表示达到目标损失所需的最少步数,$E_{min}$是达到目标损失所需处理的最少训练样本数。特别地,回想一下关键批量大小可以解析地推导为$B^* = \sqrt[\alpha]{\frac{b}{a} + 1.2B_{opt}}$。这个关系反映了当批量大小加倍时,批量大小缩放产生20%开销的点(McCandlish et al., 2018)。在这里,参数b扮演着类似于$E_{min}$的角色,而a对应于$S_{min}$,这取决于为描述增加批量大小带来的递减回报而选择的具体开销。我们还注意到递减回报开销可以变化,导致以下观察结果:10% : $B^* = 20.67 \times D^{0.48}$, 20% : $B^* = 22.91 \times D^{0.47}$, 50% : $B^* = 30.50 \times D^{0.44}$。

表6:在固定α = 1时,Chinchilla设置下拟合的缩放定律参数:log(Y) = log(a + b/B^α),其中Y是达到Chinchilla目标损失的步数,B表示批量大小,关键批量大小解为$B^* = (\frac{b+5a \times 1.2 \times B_{opt}}{5a})^{1/\alpha}$,B_opt = 256。
(a) 固定α = 1(默认)
表格 (a) 固定α = 1(默认)

(b) 拟合α
表格 (b) 拟合α

表7:更大规模的额外预测CBS结果。回想一下,我们拟合了$B^* = 93.20 \times N^{0.47}$,$B^* = 22.91 \times D^{0.47}$,其中模型大小N以百万计,数据大小D以十亿计。
表7:更大规模的额外预测CBS结果。

F 复现性

复现性验证。在我们的训练环境中,我们验证了在多个模型大小(2.4M, 9.4M, 19M, 42M, 85M, 151M, 302M)上,我们可以(近似地)复现(Wortsman等人,2023)中图1的最终评估损失。我们使用配备8个A100 GPU的节点进行模型训练,每个GPU有80GiB内存。我们使用Olmo训练套件(Groeneveld等人,2024)构建了我们的训练框架。

G 第4.2节的完整证明

定理3的证明。Zou等人的工作(2023)研究了批量大小为1的线性回归SGD,并建立了超额风险的匹配( jusqu'à un facteur constant près)上下界。我们的定理通过进一步考虑批量大小的影响来推广他们的工作。我们的分析通过适当的简化使用了他们的中间结果。我们首先定义一组关于PSD矩阵的操作如下:

$$ \mathcal{I}=\mathbf{I} \otimes \mathbf{I}, \quad \mathcal{M}^{\mathrm{B}}=\mathbb{E}\left[\left(\frac{1}{B} \sum_{i \in \mathcal{I}} \mathbf{x}_{i} \mathbf{x}_{i}^{\top}\right) \otimes\left(\frac{1}{B} \sum_{i \in \mathcal{I}} \mathbf{x}_{i} \mathbf{x}_{i}^{\top}\right)\right], \quad \widetilde{\mathcal{M}}=\mathbf{H} \otimes \mathbf{H}, $$

$$ \mathcal{T}^{\mathrm{B}}=\mathbf{H} \otimes \mathbf{I}+\mathbf{I} \otimes \mathbf{H}-\gamma \mathcal{M}^{\mathrm{B}}, \quad \widetilde{\mathcal{T}}=\mathbf{H} \otimes \mathbf{I}+\mathbf{I} \otimes \mathbf{H}-\gamma \mathbf{H} \otimes \mathbf{H}, $$

其中I是B个独立数据的索引集。请注意

$$(\mathcal{M}^{\mathrm{B}}-\widetilde{\mathcal{M}}) \circ \mathbf{A}=\operatorname{Cov}\left(\frac{1}{B} \sum_{i \in \mathcal{I}} \mathbf{x}_{i} \mathbf{x}_{i}^{\top} \mathbf{A}^{1 / 2}\right)=\frac{1}{B} \operatorname{Cov}\left(\mathbf{x x}^{\top} \mathbf{A}^{1 / 2}\right).$$

高斯数据的协方差计算。对于高斯数据$x \in N(0, H)$,我们有

$$\operatorname{Cov}(\mathbf{x} \mathbf{x}^{\top} \mathbf{A}^{1 / 2})=\mathbb{E}_{\mathbf{x} \in \mathcal{N}(0, \mathbf{H})}[\mathbf{x} \mathbf{x}^{\top} \mathbf{A} \mathbf{x} \mathbf{x}^{\top}]-\mathbf{H A} \mathbf{H}=2 \operatorname{tr}(\mathbf{H A}) \mathbf{H}.$$

综合结果。综合起来,我们得到

$$(\mathcal{M}^{\mathrm{B}}-\widetilde{\mathcal{M}}) \circ \mathbf{A}=\frac{2}{B} \operatorname{tr}(\mathbf{H A}) \mathbf{H}.$$

SGD步骤中的误差传播。现在我们计算沿SGD步骤的误差传播。令$\eta_t = w_t - w^*$为误差向量。为方便起见,令$G_t = \frac{1}{B} \sum_{i \in I_t} x_i x_i^\top$为独立批次的经验协方差。然后我们可以定义偏差和方差的迭代如下

$$\boldsymbol{\eta}_{t}^{\text {bias }}=\left(\mathbf{I}-\gamma \mathbf{G}_{t}\right) \boldsymbol{\eta}_{t-1}^{\text {bias }}, \quad t=1, \ldots, n-1, \quad \boldsymbol{\eta}_{0}^{\text {bias }}=\mathbf{w}_{0}-\mathbf{w}^{*}$$

$$\boldsymbol{\eta}_t^{\text{variance}} = (\mathbf{I} - \gamma \mathbf{G}_t) \boldsymbol{\eta}_{t-1}^{\text{variance}} + \gamma \cdot \frac{1}{B} \sum_{i \in \mathcal{I}_t} \xi_i \mathbf{x}_i, \quad t=1, \dots, n-1, \quad \boldsymbol{\eta}_0^{\text{variance}} = \mathbf{0},$$

误差迭代的协方差矩阵。其中$\xi_i = y_i - x_i^\top w^* \sim N(0, \sigma^2)$。我们接着计算这两个误差迭代的协方差矩阵

$$\mathbf{B}_{t}^{\mathrm{B}}:=\mathbb{E}[\boldsymbol{\eta}_{t}^{\text {bias }} \otimes \boldsymbol{\eta}_{t}^{\text {bias }}], \quad \mathbf{C}_{t}^{\mathrm{B}}:=\mathbb{E}[\boldsymbol{\eta}_{t}^{\text {variance }} \otimes \boldsymbol{\eta}_{t}^{\text {variance }}] .$$

协方差矩阵的迭代更新。使用这些算子,这些协方差矩阵采用以下迭代更新:

$$\begin{aligned} \begin{aligned} \mathbf{B}_{0}^{\mathrm{B}} & =\boldsymbol{\eta}_{0} \otimes \boldsymbol{\eta}_{0}, \quad \mathbf{B}_{t}^{\mathrm{B}}=\mathbb{E}_{\mathbf{G}_{t}}\left[\left(\mathbf{I}-\gamma \mathbf{G}_{t}\right) \mathbf{B}_{t-1}^{\mathrm{B}}\left(\mathbf{I}-\gamma \mathbf{G}_{t}\right)\right]=\left(\mathcal{I}-\gamma \mathcal{T}^{\mathrm{B}}\right) \circ \mathbf{B}_{t-1}^{\mathrm{B}}, \\ \mathbf{C}_{0}^{\mathrm{B}} & =\mathbf{0}, \quad \mathbf{C}_{t}^{\mathrm{B}}=\mathbb{E}_{\mathbf{G}_{t}}\left[\left(\mathbf{I}-\gamma \mathbf{G}_{t}\right) \mathbf{C}_{t-1}^{\mathrm{B}}\left(\mathbf{I}-\gamma \mathbf{G}_{t}\right)\right]+\frac{\gamma^{2}}{B^{2}} \mathbb{E}\left[\left(\sum_{i \in \mathcal{I}_{t}} \xi_{i} \mathbf{x}_{i}\right)\left(\sum_{i \in \mathcal{I}_{t}} \xi_{i} \mathbf{x}_{i}\right)^{\top}\right] \\ & =\left(\mathcal{I}-\gamma \mathcal{T}^{\mathrm{B}}\right) \circ \mathbf{C}_{t-1}^{\mathrm{B}}+\frac{\gamma^{2} \sigma^{2}}{B} \mathbf{H}, \end{aligned} \end{aligned}$$

最后一个方程的推导。其中最后一个方程是因为

$$\mathbb{E}\left[\left(\sum_{i \in \mathcal{I}_{t}} \xi_{i} \mathbf{x}_{i}\right)\left(\sum_{i \in \mathcal{I}_{t}} \xi_{i} \mathbf{x}_{i}\right)^{\top}\right]=\mathbb{E}\left[\sum_{i \in \mathcal{I}_{t}} \xi_{i}^{2} \mathbf{x}_{i} \mathbf{x}_{i}^{\top}\right]=\sigma^{2} B \mathbf{H}.$$

偏差-方差分解。回想一下$\bar{w} = \frac{1}{n} \sum_{t=0}^{n-1} w_t$。首先,使用(Zou等人,2023)中的引理B.3和C.1,我们得到以下偏差-方差分解(注意我们的设置是明确指定的):

$$\mathbb{E}[\mathcal{R}(\bar{\mathbf{w}})]-\min \mathcal{R}(\cdot)=\text { bias }+\text { variance },$$

其中

$$\begin{aligned} \begin{aligned} \text{bias} := \frac{1}{2} \langle \mathbf{H}, \mathbb{E}[\bar{\boldsymbol{\eta}}^{\text{bias}} \otimes \bar{\boldsymbol{\eta}}^{\text{bias}}] \rangle & \begin{cases} \leq \displaystyle \frac{1}{n^2} \sum_{t=0}^{n-1} \sum_{k=t}^{n-1} \langle (\mathbf{I} - \gamma \mathbf{H})^{k-t} \mathbf{H}, \mathbf{B}_t^{\text{B}} \rangle, \\ \geq \displaystyle \frac{1}{2n^2} \sum_{t=0}^{n-1} \sum_{k=t}^{n-1} \langle (\mathbf{I} - \gamma \mathbf{H})^{k-t} \mathbf{H}, \mathbf{B}_t^{\text{B}} \rangle, \end{cases} \\ \text{variance} := \frac{1}{2} \langle \mathbf{H}, \mathbb{E}[\bar{\boldsymbol{\eta}}^{\text{variance}} \otimes \bar{\boldsymbol{\eta}}^{\text{variance}}] & \begin{cases} \leq \displaystyle \frac{1}{n^2} \sum_{t=0}^{n-1} \sum_{k=t}^{n-1} \langle (\mathbf{I} - \gamma \mathbf{H})^{k-t} \mathbf{H}, \mathbf{C}_t^{\text{B}} \rangle, \\ \geq \displaystyle \frac{1}{2n^2} \sum_{t=0}^{n-1} \sum_{k=t}^{n-1} \langle (\mathbf{I} - \gamma \mathbf{H})^{k-t} \mathbf{H}, \mathbf{C}_t^{\text{B}} \rangle, \end{cases} \end{aligned} \end{aligned}$$

其中

$$\bar{\boldsymbol{\eta}}^{\text{bias}} := \frac{1}{n} \sum_{t=0}^{n-1} \bar{\boldsymbol{\eta}}_t^{\text{bias}}, \quad \bar{\boldsymbol{\eta}}^{\text{variance}} := \frac{1}{n} \sum_{t=0}^{n-1} \bar{\boldsymbol{\eta}}_t^{\text{variance}} .$$

偏差与方差部分的推导。剩下的工作是为批量大小B刻画$BB_t$和$CB_t$。对于偏差部分,我们有

$$\begin{aligned} \begin{aligned} \mathbf{B}_{t}^{\mathrm{B}} & =\left(\mathcal{I}-\gamma \mathcal{T}^{\mathrm{B}}\right) \circ \mathbf{B}_{t-1}^{\mathrm{B}} \\ & =(\mathcal{I}-\gamma \tilde{\mathcal{T}}) \circ \mathbf{B}_{t-1}^{\mathrm{B}}+\gamma^{2}\left(\mathcal{M}^{\mathrm{B}}-\tilde{\mathcal{M}}\right) \circ \mathbf{B}_{t-1}^{\mathrm{B}} \\ & =(\mathcal{I}-\gamma \tilde{\mathcal{T}}) \circ \mathbf{B}_{t-1}^{\mathrm{B}}+\frac{2 \gamma^{2}}{B} \operatorname{tr}\left(\mathbf{H B}_{t-1}^{\mathrm{B}}\right) \mathbf{H}, \quad t=1, \ldots, n-1 . \end{aligned} \end{aligned}$$

方差部分的推导。对于方差部分,我们有

$$\begin{aligned} \begin{aligned} \mathbf{C}_{t}^{\mathrm{B}} & =(\mathcal{I}-\gamma \mathcal{T}^{\mathrm{B}}) \circ \mathbf{C}_{t-1}^{\mathrm{B}}+\frac{\gamma^{2} \sigma^{2}}{B} \mathbf{H} \\ & =(\mathcal{I}-\gamma \tilde{\mathcal{T}}) \circ \mathbf{C}_{t-1}^{\mathrm{B}}+\gamma^{2}(\mathcal{M}^{\mathrm{B}}-\tilde{\mathcal{M}}) \circ \mathbf{C}_{t-1}+\frac{\gamma^{2} \sigma^{2}}{B} \mathbf{H} \\ & =(\mathcal{I}-\gamma \tilde{\mathcal{T}}) \circ \mathbf{C}_{t-1}^{\mathrm{B}}+\frac{2 \gamma^{2}}{B} \operatorname{tr}(\mathbf{H C}_{t-1}^{\mathrm{B}}) \mathbf{H}+\frac{\gamma^{2} \sigma^{2}}{B} \mathbf{H}, \quad t=1, \ldots, n-1 . \end{aligned} \end{aligned}$$

超额风险的上下界推导。为了获得超额风险的上界,我们将Zou等人(2023)的假设2.2中的$\alpha$替换为$2/B$,步数替换为$n := D/B$,噪声水平$\sigma^2$替换为$\sigma^2/B$,然后应用定理2.1的证明。类似地,对于超额风险的下界,我们将假设2.4中的$\beta$替换为$2/B$,步数替换为$n := T/B$,噪声水平$\sigma^2$替换为$\sigma^2/B$,并应用定理2.2的证明。通过这样做,我们获得了以下关于小批量SGD超额风险的匹配( jusqu'à un facteur constant près)上下界:

$$\begin{aligned} \begin{aligned} \mathbb{E} \mathcal{R}(\bar{\mathbf{w}})-\min \mathcal{R}(\cdot) & \simeq\left(\frac{1}{n \gamma}\right)^{2}\left\|\mathbf{w}_{0}-\mathbf{w}^{*}\right\|_{\mathbf{H}_{0: k^{*}}^{-1}}^{2}+\left\|\mathbf{w}_{0}-\mathbf{w}^{*}\right\|_{\mathbf{H}_{k^{*}: \infty}}^{2} \\ & +\frac{1 / B\left(\left\|\mathbf{w}_{0}-\mathbf{w}^{*}\right\|_{\mathbf{I}_{0: k^{*}}}^{2}+n \gamma\left\|\mathbf{w}_{0}-\mathbf{w}^{*}\right\|_{\mathbf{H}_{k^{*}: \infty}}^{2}\right)}{n \gamma} \cdot \frac{k^{*}+(n \gamma)^{2} \sum_{i>k^{*}} \lambda_{i}^{2}}{n} \\ & +\frac{\sigma^{2}}{B} \cdot \frac{k^{*}+(n \gamma)^{2} \sum_{i>k^{*}} \lambda_{i}^{2}}{n}, \end{aligned} \end{aligned}$$

步长条件的推导。其中$k^* = \max\{k : \lambda_k \geq 1/(n\gamma)\}$,并且一个充分的步长条件(见Zou等人(2023)中的引理4.1,定理2.1和2.2)是

$$0 < \gamma \lesssim \min \left\{ \frac{1}{\alpha \text{tr}(\mathbf{H})}, \frac{1}{\|\mathbf{H}\|_2} \right\} \simeq \min \left\{ \frac{B}{\text{tr}(\mathbf{H})}, \frac{1}{\|\mathbf{H}\|_2} \right\}.$$

简化超额风险界。假设$\|w_0 - w^*\|_H^2 \lesssim \sigma^2$意味着

$$\frac{\left\|\mathbf{w}_{0}-\mathbf{w}^{*}\right\|_{\mathbf{I}_{0: k^{*}}}^{2}+n \gamma\left\|\mathbf{w}_{0}-\mathbf{w}^{*}\right\|_{\mathbf{H}_{k^{*}: \infty}}^{2}}{n \gamma} \leq\left\|\mathbf{w}_{0}-\mathbf{w}^{*}\right\|_{\mathbf{H}}^{2} \lesssim \sigma^{2},$$

这进一步将超额风险界简化为

$$ \mathbb{E}\mathcal{R}(\bar{\mathbf{w}}) - \min \mathcal{R}(\cdot) \simeq \left( \frac{1}{n\gamma} \right)^2 \|\mathbf{w}_0 - \mathbf{w}^*\|^2_{\mathbf{H}_{0:k^*}^{-1}} + \|\mathbf{w}_0 - \mathbf{w}^*\|^2_{\mathbf{H}_{k^*:\infty}} + \frac{\sigma^2}{B} \cdot \frac{k^* + (n\gamma)^2 \sum_{i>k^*} \lambda_i^2}{n} $$

完成证明。最后,在界中替换$n = D/B$即可完成我们的证明。

推论2的证明。通过$\lambda_i \approx i^{-a}$,我们可以解出$k^*$以获得$k^* \approx (D\gamma/B)^{1/a}$。然后我们利用定理3和容量及源条件计算期望超额风险:

$$\begin{aligned} \begin{aligned} \mathbb{E} \mathcal{R}(\overline{\mathbf{w}})-\sigma^{2} & \simeq \mathbb{E}\left(\left(\frac{B}{D \gamma}\right)^{2}\left\|\mathbf{w}^{*}\right\|_{\mathbf{H}_{0: k^{*}}^{-1}}^{2}+\left\|\mathbf{w}^{*}\right\|_{\mathbf{H}_{k^{*}: \infty}}^{2}\right)+\frac{k^{*}+(D \gamma / B)^{2} \sum_{i>k^{*}} \lambda_{i}^{2}}{D} \\ & \simeq\left(\frac{B}{D \gamma}\right)^{2} \sum_{i \leq k^{*}} i^{-b+2 a}+\sum_{i>k^{*}} i^{-b}+\frac{1}{D}\left(k^{*}+\left(\frac{D \gamma}{B}\right)^{2} \sum_{i>k^{*}} i^{-2 a}\right) \\ & \simeq\left(\frac{B}{D \gamma}\right)^{2} \max \left\{\left(k^{*}\right)^{1-b+2 a}, 1\right\}+\left(k^{*}\right)^{1-b}+\frac{1}{D}\left(k^{*}+\left(\frac{D \gamma}{B}\right)^{2}\left(k^{*}\right)^{1-2 a}\right) \\ & \simeq \max \left\{\left(\frac{D \gamma}{B}\right)^{(1-b) / a},\left(\frac{D \gamma}{B}\right)^{-2}\right\}+\frac{1}{D}\left(\frac{D \gamma}{B}\right)^{1 / a} . \end{aligned} \end{aligned}$$

分情况讨论。我们接着讨论三种情况。

  1. 当$b \leq a$时,我们有

$$\mathbb{E} \mathcal{R}(\bar{\mathbf{w}})-\sigma^2 \simeq\left(\frac{D \gamma}{B}\right)^{(1-b) / a}+\frac{1}{D}\left(\frac{D \gamma}{B}\right)^{1 / a} \simeq\left(\frac{D \gamma}{B}\right)^{(1-b) / a},$$

其中最后一个等式成立是因为$\gamma/B \lesssim 1$,所以第一项主导第二项。因此最优超参数是$\gamma^* \approx 1$和$B^* = 1$。

  1. 当$a < b < 2a + 1$时,我们有

$$\mathbb{E} \mathcal{R}(\bar{\mathbf{w}})-\sigma^2 \simeq\left(\frac{D \gamma}{B}\right)^{(1-b) / a}+\frac{1}{D}\left(\frac{D \gamma}{B}\right)^{1 / a},$$

因此最优超参数是

$$0 < \gamma^* \lesssim 1, \quad 1 \le B^* \le D, \quad \gamma^*/B^* \simeq D^{a/b-1}.$$
  1. 当$b > 2a + 1$时,我们有
$$\mathbb{E} \mathcal{R}(\bar{\mathbf{w}})-\sigma^2 \simeq\left(\frac{D \gamma}{B}\right)^{-2}+\frac{1}{D}\left(\frac{D \gamma}{B}\right)^{1 / a},$$

因此最优超参数是

$$0 < \gamma^* \lesssim 1, \quad 1 \leq B^* \leq D, \quad \gamma^* / B^* \asymp D^{a / (2 a+1)-1}$$

证明完成。结合第二和第三种情况即可完成证明。