核心问题: 深度神经网络(DNNs)在处理长上下文任务时会产生巨大的中间张量,导致严重的内存开销。现有的DNN优化方法由于对张量属性的感知不足,无法有效进行内存优化,并可能导致计算效率低下。具体来说,当一个巨大的中间张量被多个具有不同归约(reduction)维度的操作符使用时,现有的融合策略会因复杂的依赖关系而导致并行度降低,性能严重下降。
研究目标: 本文旨在提出一个名为FlashTensor的DNN优化系统,通过利用细粒度的张量属性来减少内存开销并提高推理性能。
创新点与主要贡献:
* 总结了四种关键张量属性: 本文总结了四种对优化至关重要的张量属性,分别是归约依赖(reduce dependency)、广播性(broadcastability)、尺寸(size)和值(value)。
* 提出了张量属性识别器 (Tensor Property Identifier): 该模块系统地分析整个计算图,并捕获每个张量的细粒度属性。
* 提出了张量属性感知优化 (Tensor Property-Aware Optimization): 该模块基于属性感知的转换规则和核映射策略来搜索最优的计算核(kernels),以实现高计算效率和低内存访问开销。
* 设计并实现了FlashTensor系统: 这是一个利用细粒度张量属性优化张量程序的系统。实验表明,与八个最先进的工作相比,FlashTensor在H100上平均取得了1.50倍的端到端加速和3.24倍的核心模块性能加速(在A100上分别为1.86倍和3.70倍)。
近期的一个研究趋势是使用新的注意力变体来减少Vanilla Attention的巨大计算量。例如,像H2O【45,H2O: heavy-hitter oracle for efficient generative inference of large language models,2024,NIPS】,RoCo【28,On the Efficacy of Eviction Policy for Key-Value Constrained Generative Language Model Inference,2024,arXiv】和Keyformer【1,Keyformer: KV Cache reduction through key tokens selection for Efficient Generative Inference,2024,MLSys】这样的模型,可以通过丢弃一些不重要的令牌来减少总计算量。以图2中展示的H2O为例,它在SoftMax(由Exp、Reduce 0和Div组成)之后使用了一个额外的Reduce算子(Reduce 1)来计算令牌的重要性。随后,一个TopK算子和一个Gather算子一起选择并缓存最重要的令牌,同时丢弃其余的。
在H2O的推理过程中,包含预填充(prefill)和解码(decode)两个阶段。预填充阶段逐块处理输入提示,选择并缓存重要令牌,尤其是在输入超过缓存容量时。解码阶段顺序生成输出令牌,并进一步更新缓存。因此,增加的算子会影响预填充和解码阶段的性能。然而,对于长上下文的文档摘要等任务,预填充阶段是主要的性能瓶颈。例如,在InfiniteBench【44,∞Bench: Extending Long Context Evaluation Beyond 100K Tokens,2024,ACL】中,提示可以达到442K个令牌,而生成的令牌只有0.7K个,相差631倍,导致H2O中预填充与解码的执行时间比为4.51。主要瓶颈来自于预填充阶段核心模块中创建的巨大张量,如图2所示。这个尺寸为$O(seqlen^2)$的张量由MatMul 0产生,带来了巨大的内存访问开销。即使经过TensorRT【36,NVIDIA TensorRT,2017,https://developer.nvidia.com/tensorrt】优化,H2O的核心模块仍占总预填充时间的约57.62%,但仅达到10.77 TFLOP/s(A100 F16 TensorCore峰值性能的3.45%)。
关键挑战在于两个归约操作,MatMul 1和Reduce 1,它们对Div的输出张量(称为DivOut)具有不同的归约维度。为了提高效率,DivOut的每一行必须分配给同一个并行单元以进行MatMul 1的计算,而DivOut的每一列也必须驻留在单个并行单元内。因此,除了使用单个并行单元外,没有其他可行的分区策略能够满足这些约束。结果是,像TensorRT这样的现有方法必须通过缓慢的全局内存跨多个核来处理这个大张量。此外,这两个具有不同归约维度的归约算子是H2O和Vanilla Attention之间的主要区别,这使得FlashAttention【10,FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning,2023,arXiv】【11,FLASHATTENTION: fast and memory-efficient exact attention with IO-awareness,2024,NIPS】【30,FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision,2024,arXiv】由于这些结构差异而难以优化。这种局限性源于它们缺乏对细粒度张量属性(如每个维度上的归约依赖)的感知,从而错失了将MatMul 1与前面的算子融合以避免访问大张量的机会。
FlashTensor通过利用细粒度的张量属性来减少内存开销。在介绍我们的系统之前,我们首先介绍对关键张量属性的观察。这里我们列出一些细粒度属性,并使用H2O来说明它们为何重要。
基于以上观察,我们提出了具备这些细粒度张量属性感知的FlashTensor。FlashTensor专注于优化预填充阶段,这是长上下文任务的瓶颈,如动机部分所讨论。FlashTensor的概览如图4所示,它由两个主要模块组成:张量属性识别器(Tensor Property Identifier)和张量属性感知优化(Tensor Property-Aware Optimization)。
1. 首先,张量属性识别器模块将一个以计算图(节点表示算子,边表示张量)形式表示的张量程序作为输入,并捕获计算图中每个张量的所有细粒度属性,包括归约依赖、广播、尺寸和值(第4节)。
2. 然后,张量属性感知优化模块基于图转换规则和核映射策略搜索最优计划。图转换规则旨在通过特定的限制来减小中间张量的尺寸,而核映射策略则通过考虑内存访问、计算强度以及并行性来选择高效的候选核。最终生成优化的张量程序,准备作为高效代码执行(第5节)。
在本节中,我们首先正式定义我们观察到的对后续分析和优化至关重要的属性,然后介绍FlashTensor如何在计算图中识别这些属性。
如表1所示,张量属性被分为两大类:1) 逐维度属性 (Per-Dimension Property) 关注张量单个维度的特定属性,提供对归约依赖和广播性等方面的洞察。2) 整体张量属性 (Entire Tensor Property) 包含描述张量整体的属性,如其总尺寸和可能的常量值。
归约依赖 (Reduce Dependency)。这是一个多值属性,描述了张量维度与归约操作之间的相互依赖关系,对优化计算效率至关重要。它有三个值:
广播 (Broadcast)。这是一个属性,指示张量维度是否会被后续算子广播。如果张量是某个广播支持算子在该维度上的操作数,则该维度被认为是Broadcasted。图6(a)(b)展示了维度如何被扩展以满足算子形状要求的示例。
尺寸 (Size)。它表示张量中的元素总数,计算为所有维度的乘积,如图7所示。
值 (Value)。它指示张量在计算过程中是否保持一个常量值,如图7所示。与传统编译器将常量视为标量不同,FlashTensor在张量级别处理它,从而实现了更广泛的优化机会。具有常量值的张量可以被预计算或高效存储,从而最大限度地减少执行期间的动态更新。在FlashTensor中,常量值信息表示为一个枚举而不是实际值,以平衡信息的有效性和开销。
一些基本属性,如尺寸,可以很容易地从张量中提取。然而,像归约依赖这样的复杂属性需要复杂的分析才能准确标注。例如,在单个算子中,输入和输出张量的归约依赖可以根据其计算语义进行初步标注,即使没有完全实现该算子,如图5所示。然而,在具有多个算子的计算图中,张量的归约依赖属性受到后续和先前算子的双重影响,这使得确定其最终值具有挑战性。
为了解决这个问题,我们提出了一种基于数据流的两阶段属性识别算法。
1. 属性传播 (Property Propagation):根据每个算子的计算语义,张量属性会进行前向和后向传播。
2. 属性聚合 (Property Aggregation):从不同算子传播来的属性会为每个张量进行聚合,以确定最终的属性值。
该算法利用了固有的单调优先级(例如,归约依赖类型:NonPara > Reuse > Batch)。当不同类型汇集到同一个张量维度上时,会保留优先级更高的值,确保在整个图上属性表示的准确性。这两个阶段会迭代执行,直到属性稳定为止。
图8提供了归约依赖识别的示例以供进一步说明。图8(b)说明了MatMul的输出如何通过后向传播将右侧操作数在维度N'上的归约依赖从Reuse更新为NonPara。图8(a)展示了属性如何在MatMul和Reduce之间进行聚合,其中中间张量从两个算子接收到不同的归约依赖值。
FlashTensor基于已识别的张量属性进一步优化张量程序。优化基于两条规则:代数等价图变换和非凸核映射,以及一种轻量级计划搜索方法。代数等价图变换提供了一系列变换规则,允许在等价变换中改变中间张量的大小。非凸核映射提出了一种新的候选核类型以实现高效执行。轻量级计划搜索通过属性约束剪枝和低成本性能模型实现高效的计划生成。
中间张量的大小主导了内存访问开销。因此,我们提出了一种变换方法,主要通过在计算图上执行代数等价变换来关注中间张量大小的变化。这使得后续步骤能够在整个搜索空间中搜索具有最小中间张量大小的计算图,从而减少内存开销。广播是带来中间张量大小变化的本质。
广播属性感知的变换规则。我们提出了一系列不仅考虑张量大小,还考虑每个维度上广播属性的变换规则。具体来说,表2列出了所有的变换规则,可分为两大部分:
值属性感知的变换规则。我们利用值属性通过消除不必要的迭代来优化循环。具体来说,它专注于跳过For循环中的某些迭代而不影响最终输出。如图10所示,某些迭代可能会从Mask算子产生常量值,如-∞。由于这些常量值在For循环中传播时保持不变,我们可以识别它们并安全地跳过这些迭代。这不仅保持了输出的正确性,还减少了计算量。
核映射涉及将计算图中的算子分配给GPU核,以便在现代硬件平台上执行【15,Optimal Kernel Orchestration for Tensor Programs with Korch,2024,ASPLOS】。对于给定的计算图,简单地将所有算子融合成一个单一的核通常会导致性能次优,因为存在复杂的归约依赖和有限的并行性。识别合适的核以实现高效率至关重要。
我们首先给出最先进工作【15】中核的正式定义。
定义1 (核)。对于一个计算图 $G = (V, E)$,一个节点集合 $S \subseteq V$ 形成一个核,如果不存在节点 $v_1, v_2 \in S$ 和另一个节点 $u \in V \setminus S$ 使得 $v_1 \rightarrow u$ 且 $u \rightarrow v_2$,其中 $x \rightarrow y$ 表示在G中存在从x到y的路径。
这个定义将计算图中的一个凸(Convex)算子子集视为一个核。凸意味着这种核不能包含通过外部算子依赖于自身的算子,如图11(a)所示。Korch【15】认为这种核由于循环依赖(即Div依赖于Reduce,而Reduce又依赖于Exp)而无法执行。然而,这种循环依赖可以通过其他核来解决。具体来说,一个核不必生成其所有输入;只要有另一个核提供这些输入就足够了。例如,为了解决图11(a)中的循环依赖,另一个包含Exp和Reduce的核可以提供Reduce的输出。
通过放宽凸的要求,即支持非凸(Non-Convex)的核,一个核不再需要包含内部算子路径上的任何算子,这带来了以下两个好处:
1. 更少的计算:与凸核相比,非凸核由于放宽了要求,能够包含更少的算子,直接避免了相应的计算。
2. 潜在的更宽松的归约依赖:排除一些算子可能会使子图的依赖关系变得更简单。如图11(a)和(b)所示,如果我们排除Reduce来形成一个包含Exp和Div的非凸核,其归约依赖要简单得多,因为该核中只剩下逐元素的算子。
下面我们给出FlashTensor中考虑的非凸核的正式定义。
定义2 (非凸核)。对于一个计算图 $G = (V, E)$,一个节点集合 $S \subseteq V$ 形成一个非凸核,如果存在节点 $v_1, v_2 \in S$ 和另一个节点 $u \in V \setminus S$ 使得 $v_1 \rightarrow u$ 且 $u \rightarrow v_2$。
基于此定义,我们可以讨论核的输入和输出张量。
1. 核输入。核的输入张量是根据核内所有算子的依赖关系确定的。具体来说,如果一个张量不是核内任何算子的输出,它就被分类为该核的输入。
2. 核输出。相反,识别核的输出张量不像推断输入那样直接。输出张量不能仅通过检查数据依赖来确定,因为核内的中间张量可能被指定为输出来解决其他核的循环依赖。因此,有必要明确指定哪些中间张量应被视为核的输出。
基于提出的规则,下一步是在巨大的搜索空间中高效地搜索一个高性能计划。我们主要在变换和核映射阶段采用两阶段搜索。下面我们分别介绍每个阶段的快速搜索方法。
阶段1:变换。此阶段的搜索目标是尽可能减小中间张量的大小,以减少内存访问开销。然而,由于搜索空间巨大,解决这个问题具有挑战性。一个直接的贪心搜索方法是不够的,因为它倾向于收敛到局部最优解。例如,一些变换可能不会立即影响输入和输出张量的大小,但可以为未来大幅减小张量大小铺平道路。这个限制导致贪心算法错失全局最优解。为了应对这一挑战,我们提出了一种基于模拟退火(Simulated Annealing)【3,Simulated annealing,1993,Statistical science】的张量大小最小化搜索算法。我们概率性地探索各种变换方案,偶尔接受较大或相同大小的张量以跳出局部最优,并随着算法的进行逐渐降低这种接受的概率。虽然该算法不保证找到最优计划,但它提供了对执行时间的控制,使其在实际应用中可行。
阶段2:核映射。由于FlashTensor引入了非凸核,核的搜索空间进一步增大。我们使用一个基于屋顶线(roofline)模型的性能模型来预测每个计划的性能,以便从剪枝后的空间中找到最优计划。为了快速识别高性能核,我们还基于属性约束进行了大量的剪枝,以消除性能差的候选核。
整个搜索算法如算法1所示。它首先确定可用的并行执行单元数量,如GPU流式多处理器(SMs)。然后,我们迭代计算图的连通子集,剪枝那些并行度或算术强度低于性能阈值的核。一旦候选核被剪枝,我们遵循先前的工作【15】将最优候选搜索形式化为一个二元线性规划(BLP)任务并求解。
FlashTensor基于MLIR【18,MLIR: A Compiler Infrastructure for the End of Moore’s Law,2020,CoRR】和Triton【37,Triton: an intermediate language and compiler for tiled neural network computations,2019,MAPL】实现,包含1万行C++代码和2千行Python代码。FlashTensor接受ONNX格式的张量程序,并将其转换为有效的MLIR代码。我们实现了两个MLIR Dialect:FT和FTTriton。FT定义了张量算子及相应的MLIR pass实现的变换。FTTriton作为从FT到Triton DSL的桥梁。所有优化应用后,最终的MLIR代码将被转换为有效的Triton DSL。
如图12(a)所示,FlashTensor在A100上最高实现了2.22倍的加速,在H100上最高实现了1.62倍的加速,优于所有基线。性能提升主要源于对核心模块的高度优化。图12(b)展示了核心模块(如注意力变体)的性能,FlashTensor在A100上最高实现了4.52倍的加速,在H100上最高实现了5.43倍的加速。我们观察到,一些结构相似的模型在使用FlashTensor优化后性能大致相当,但经过SOTA编译器优化后性能差异巨大。例如,Gemma2对Vanilla Attention做了微小修改,在A100上,TensorRT处理Gemma2和V.A.的推理时间分别为5.13ms和2.25ms,因为TensorRT的预写规则(类FlashAttention)匹配了V.A.但未能匹配Gemma2。而FlashTensor处理两者的时间分别为1.14ms和1.20ms,性能相当且均优于TensorRT,因为它能自动搜索最优方案并利用值感知优化消除因果掩码带来的冗余计算。
图13展示了H2O在不同序列长度下的计算效率和内存占用。
* 计算效率: 随着序列长度增加,FlashTensor的FLOP/s更高,因为其内存访问开销更低。TorchInductor和TensorRT需要在其CUDA核中从全局内存读写大小为$O(seqlen^2)$的中间张量,导致高内存开销。
* 内存占用: FlashTensor的全局内存占用显著小于PyTorch和TensorRT,尤其是在序列长度增加时。PyTorch和TensorRT在慢速全局内存上分配了随序列长度二次方增长($O(seqlen^2)$)的大型中间张量,而FlashTensor旨在最小化这类张量的分配。
我们将FlashTensor和TensorRT与手动调优的算子库FlashAttention和FlashInfer进行了比较,这两者仅支持Vanilla Attention和Gemma2。Gemma2【34】在Vanilla Attention中引入了逻辑软上限(logic soft capping),如图14(a)所示。在图14(b)中,FlashTensor的性能与FlashInfer和FlashAttention相当,后两者均采用了领域专家手动调优的算子。虽然FlashTensor没有超越这些算子库,但它提供了无需手动调优即可在各种场景中泛化和优化的优势。
我们进行了分解分析以展示FlashTensor各组件的性能影响,如图15所示。基线是PyTorch(加速比为1)。
1. Fission (算子分裂): 将Softmax等复杂算子分解为Exp、Reduce、Div等基本算子。这引入了更多小算子,增加了内存访问量,导致性能下降。
2. B.T. (广播属性感知变换): 最小化中间张量的大小,提升性能并为后续融合满足硬件约束创造条件。
3. K.M. w/o Fusion (无融合的核映射): 识别具有高并行效率的凸核和非凸核。由于识别出的核需要重计算一些算子来解决非凸核的循环依赖,在没有融合的情况下性能仍然会下降。
4. Fusion (融合): 带来最显著的性能提升,通过减少内存访问开销和利用已识别核的高并行效率,最高实现了14.2倍的加速。
5. V.T. (值属性感知变换): 利用因果掩码消除冗余计算,在所有模型上进一步实现了近2倍的额外加速。
我们以H2O为例,在A100 GPU上展示FlashTensor如何利用多种关键属性进行优化。
本文为解决长上下文带来的显著内存开销问题,提出了一个名为FlashTensor的DNN优化系统,该系统利用细粒度的张量属性来优化整体性能。我们总结了四个关键的张量属性,包括归约依赖、广播、尺寸和值,并从计算图中识别它们。FlashTensor进一步采用属性感知的变换和核映射来实现最优性能。实验结果表明,FlashTensor在典型的GPU平台上优于最先进的工作。
FlashTensor是一个开源项目,可公开访问:https://github.com/monellz/FlashTensor。复现论文中报告结果的说明可在以下地址找到:https://github.com/monellz/FlashTensor-AE 和 https://zenodo.org/records/14220175。