MAGIS: Memory Optimization via Coordinated Graph Transformation and Scheduling for DNN

文章标题与作者/机构

文章标题:MAGIS: 通过协同图变换和调度的深度神经网络内存优化
作者与机构
- Renze Chen (北京大学)
- Zijian Ding (加州大学洛杉矶分校)
- Size Zheng (北京大学)
- Chengrui Zhang (北京大学)
- Jingwen Leng (上海交通大学)
- Xuanzhe Liu (北京大学)
- Yun Liang (北京大学)


A1 主要贡献

深度神经网络(DNN)的内存消耗随着其拓扑结构和规模的复杂化而持续增长,这主要归因于两个因素:一是大量张量(如模型参数、训练前向传播中的激活值、复杂网络中的中间张量)具有较长的生命周期;二是许多张量(如大批量大小、长序列长度、高分辨率图像)具有较大的尺寸。这给服务器和移动设备的计算带来了巨大挑战。

核心问题: 现有的内存优化技术主要分为两类,各有其局限性。
1. 图调度 (Graph Scheduling):包括重计算(rematerialization)、交换(swapping)和重排序(re-ordering)。这类技术通过操纵张量的生命周期来减少峰值内存,但通常会显著损害性能,并且无法改变张量的形状,从而限制了优化空间。
2. 图变换 (Graph Transformation):通过等价变换改变图的结构来优化性能,主要分为聚合变换(Aggregation Transformation, A-Trans)和中间变换(Interim Transformation, I-Trans)。然而,现有工作主要关注性能优化,而未充分利用图变换进行内存优化。特别是聚合变换的对偶操作——分裂变换 (Fission Transformation, F-Trans),虽然能通过拆分大算子有效减少内存占用,但会引入两大挑战:
* 复杂性:F-Trans 会导致计算图规模急剧增长,增加了后续优化的难度,且其自身的搜索空间巨大。
* 协同优化:图变换和图调度之间存在复杂的权衡关系,需要进行高效的协同优化,但这两种优化本身都非常复杂。

研究目标与创新点:
为了解决上述挑战,本文提出了 MAGIS,一个通过协同图变换和图调度来优化 DNN 内存的框架。
- 设计并实现了 MAGIS:一个基于协同图变换和图调度的内存优化框架。
- 形式化并优化了图分裂变换:本文形式化了图分裂变换(F-Trans),并提出使用分裂层次树 (Fission Hierarchy Tree, F-Tree) 来表示它,通过图结构分析来减少其巨大的搜索空间,从而在不实际增加图复杂度的前提下探索 F-Trans 的优化潜力。
- 提出了高效的协同优化算法
1. 将重计算和交换等调度技术分解为图变换和重排序,从而将内存与性能的权衡完全移至变换阶段,简化了调度过程。
2. 设计了一种增量图调度算法,在每次图变换后,能够利用先前的调度信息高效地生成新的调度方案,大幅降低了调度开销。

实验结果表明,与现有技术相比,MAGIS 在相同的延迟约束下,峰值内存使用量仅为它们的 15%~85%,并在内存和延迟的双目标优化中获得了更优的帕累托边界。


图 1. 图变换示例。(a) 和 (b) 是从 TASO 【25,Zhihao Jia et al. TASO: optimizing deep learning computation with automatic generation of graph substitutions. SOSP 2019】中借鉴的变换,用于优化性能。(c) 是聚合变换的对偶变换,可以有效地用性能换取内存。


A3 背景知识与动机

表 1. 符号表

2.1 计算图

2.2 图调度与变换

2.3 动机


图 2. 内存限制为100的动机示例。(a) 无任何优化。(b) 使用交换。(c) 使用分裂变换。(d)(e) 使用分裂变换和交换。


A2 方法细节

3 设计概览


图 3. MAGIS 概览。

4 M-Analyzer

4.1 维度图


(a) 自注意力机制的图G


(b) G中每个节点的形状


批处理维度


头维度


序列维度
(c) G的D-Graph的一些子图

图 4. D-Graph示例。N, T, C, H, h分别代表批大小、序列长度、隐藏维度、头数量、头维度。

4.2 分裂变换


图 5. 图G中的F-Trans f = (S, D, n) (n = 2),该图是从一个MLP的训练图中简化的。(a) 子图 S = {v3, v4, v5, v6, v7, v8}。(b) D-Graph D,表示S的激活值的批处理维度。(c) F-Trans之后的结果图。

4.3 分裂层次树

$$ M_0 = \sum_{v \in U} |v| \quad M_f \approx \sum_{v \in (H \setminus S) \cup U} |v| + \sum_{v \in H \cap S} \frac{|v|}{n} $$
$$M_0 - M_f = \sum_{v \in H \cap S} \left(1 - \frac{1}{n}\right)|v| - \sum_{v \in F \setminus H} |v|$$

$heat(v) = \sum_{w \in H \cap T.des(v)} |w|$
$$ \text{score}(v) = \left(1 - \frac{1}{n} \text{heat}(v) - \sum_{u \in G, \text{mps}(T, \text{des}(v)), H} |u| \right) (4) $$

算法 1: M-Analyzer: F-Tree 构建
输入: 图: G; 最大层级: L
输出: 分裂层次树: F
1 F := ∅;
2 H := MemoryHotspots(G);
3 for D ∈ D(G) 的连通分量 do
4     G' := 从 D 诱导的 G 的子图;
5     T := T(G');
6     S := GetScores(G', T, H);
7     S_max = max_{v∈V(G')} S[v];
8     if S_max ≤ 0 then continue;
9     for l ∈ {1, 2, ..., L} do
10        C := {v ∈ V(G&#39;) | l/L ≤ S[v]/S_max < (l + 1)/L};
11        for v_dom ∈ {v ∈ C | T.des(v) ∩ C = ∅} do
12            S := T.des(v_dom) \ {v_dom};
13            S&#39; := 从 S 诱导的 G&#39; 的子图;
14            f := (S&#39;, D, 1);
15            if f is valid then F := F ∪ {f};
16 return F;


图 6. 基于算法1(L=5)的F-Tree构建示例。每个张量的大小为1。(a) 算法1第4行的G'。(b) 支配树 Dom G' = T(G')。(c) 基于公式(3)(4)计算的分数,橙色框中的节点是选定的支配者(算法1第11行的v_dom)。(d) 选定的子图(算法1第12行的S)。(e) 构建的F-Tree。

5 M-Rules
5.1 F-Tree 突变规则


图 7. F-Tree突变规则图示。(a) 启用一个F-Tree节点。(b) 提升一个F-Tree节点。(c) 禁用一个F-Tree节点。(d) 增加分裂数n(维度长度d=12)。

5.2 基于调度的规则


图 8. 基于调度的规则,表示从图调度分解出的变换。标有星号(*)的边代表零条或多条边。

6 M-Optimizer

6.1 增量调度
算法 2: M-Optimizer: 增量调度
输入: 旧图、新图: G_old, G_new;
      旧的突变子图节点: S_old;
      旧图的调度: ψ_old
输出: 新图的调度: ψ_new
1 function GetRescheduleInterval(G, S, ψ):
2   function ExtendBound(i, d):
3     ñ := 0; l := 0; v := ψ[i];
4     while l < 20 ∧ (ñ > 10 ∨ nw(v) < 4) ∧ nw(v) < ñ do
5       ñ := nw(v); i := i + d; v := ψ[i]; l := l + 1;
6     return i;
7   I_S := {i | i = 1, ..., |ψ| if ψ[i] ∈ S};
8   return ExtendBound(min I_S, -1), ExtendBound(max I_S, 1);
9 beg, end = GetRescheduleInterval(G_old, S_old, ψ_old);
10 S_new := V(G_new) \ (ψ_old[:beg] ∪ ψ_old[end:]);
11 Ψ := {DpSchedule(S) | S ∈ GraphPartition(S_new)};
12 return Merge(ψ_old[:beg], MergeSubSched(Ψ), ψ_old[end:]);
6.2 顶层搜索算法
算法 3: M-Optimizer: 搜索算法
输入: 输入图 G; 内存约束 M; F-Tree 最大层级 L
输出: 优化的 M-State μ_best
1 function BetterThan(μ_1, μ_2, δ = 1):
2   return (max(μ_1.mem, M), μ_1.lat) < (max(δ × μ_2.mem, M), δ × μ_2.lat);
3 function GraphHash(G):
4   for v ∈ topo-order(G) do
5     x_v := hash(v) ⊕ (⊕_{u∈G.pre(v)} x_u);
6   return hash(Σ_{v∈G} x_v);
7 μ_best := InitState(G); X := ∅;
8 Q := PriorityQueue({μ_best}, BetterThan);
9 while Q ≠ ∅ do
10  μ := Q.pop(); x := GraphHash(μ.G);
11  if x ∈ X then continue;
12  X := X ∪ {x};
13  if μ&#39;s F-Tree needs update then
14    μ := Analyze(μ, l); # 算法 1
15  for μ&#39; ∈ ApplyTransformRules(μ) do
16    μ&#39; := ApplyIncrementalSchedule(μ&#39;); # 算法 2
17    if BetterThan(μ&#39;, μ_best) then μ_best := μ&#39;;
18    if BetterThan(μ&#39;, μ_best, 1.1) then Q.push(μ&#39;);
19 return μ_best;

A4 实验

实验环境 (总结)

表 2. 评估使用的工作负载

* 硬件平台:Intel工作站,配备20核Intel(R) Xeon(R) Silver 4210R CPU,一张NVIDIA GeForce RTX 3090 GPU。
* 软件配置:CUDA 11.6, cuDNN 8.4.0, PyTorch 2.1.0, MegEngine 1.12.3, TensorFlow 2.15.0, TVM 0.14.0。MAGIS优化时间预算为3分钟。为公平比较,所有基线都先应用TASO规则进行初步优化。

实验结果 (总结)

内存优化(带延迟约束)


图 9. 与未优化的PyTorch相比的峰值内存比率(越低越好)。"OOM"表示内存使用超出实验平台内存限制。

延迟优化(带内存约束)


图 10. 与未优化的PyTorch相比的延迟开销(越低越好)。"FAILURE"表示无法将内存比率优化到满足约束。

延迟与内存的权衡曲线


图 11. MAGIS与基线的延迟和内存曲线。MAGIS几乎在所有情况下都能实现帕累托最优。

与微批次(Micro-batching)的比较


图 12. MAGIS与POFO的比较。POFO使用的网络已经过微批次预处理(使用不同因子)。

启发式策略消融实验


图 13. 在3分钟内优化BERT工作负载时MAGIS的启发式策略分解,约束条件与§7.2.1和§7.2.2相同。曲线上的菱形"⋄"是其优化结果满足约束的时间点。曲线上的方形"□"是获得最佳优化结果的时间点。

增量调度评估


图 14. 增量调度(IS)与完整调度(FS)的比较。(a) IS相对于FS的调度时间加速比。(b) IS与FS的调度结果质量比较。

优化时间成本

图 15. MAGIS优化ViT (batch 64) 1分钟的优化时间成本分解。"Filtered"表示被哈希测试过滤掉的重复图。

案例研究:UNet


图 16. UNet的执行时间与内存使用情况。


A7 相关工作


A5 结论

我们提出了MAGIS,一个用于内存和延迟优化的DNN优化器,其系统化地设计了分裂变换,并实现了图变换和调度之间的有效协调。实验结果表明,与最先进的方法相比,MAGIS在相同的延迟约束下仅使用15%~85%的内存,并获得了更好的内存与延迟帕累托边界。