大型语言模型(LLMs)正越来越多地用于处理具有共享前缀的树状结构中的多个生成调用,例如少样本提示、多步推理和推测解码等。然而,现有的推理系统在处理这些树状应用时效率低下,主要因为在注意力计算过程中对查询(queries)和键值缓存(KV cache)的划分不当,导致了两个主要问题:(1)共享前缀的 KV 缓存缺乏内存访问(IO)复用;(2)负载均衡性差。这造成了 GPU 全局内存和共享内存之间冗余的 KV 缓存 IO,并降低了 GPU 的利用率。
为了应对这些挑战,本文提出了 DEFT (Decoding with Flash Tree-Attention),一种具有前缀感知和负载均衡 KV 缓存分区的硬件高效注意力算法。DEFT 的核心贡献如下:
核心问题与目标: 当前基于序列解码优化的 LLM 推理系统在处理具有树状结构(如图1所示)的应用时,存在计算、存储和内存访问(IO)三个层面的冗余,尤其是在内存访问上,共享前缀的 KV 缓存被反复加载,成为性能瓶颈。此外,直接将序列优化的 KV 缓存切分策略应用于树状结构会导致严重的负载不均衡,降低 GPU 利用率。本文旨在设计一种能感知前缀共享并实现负载均衡的树状注意力算法,以加速树状结构的 LLM 推理。
创新点与解决方案:
主要成果:
图 1: 序列化解码与树状解码的图示。
表 1 展示了在一个推理任务中,基于序列的思维链(CoT)与基于树的思维树(ToT)在解码效率上的对比。ToT 生成的 token 数量远超 CoT,导致其端到端延迟和 IO(包括 KV 缓存 IO 和注意力计算中间结果 IO)开销巨大。
LLM 推理及其瓶颈。LLM 推理包括两个阶段:(1)预填充(prefill)和(2)解码(decoding)。在预填充阶段,模型处理输入的提示(prompt)以初始化。其输出成为解码阶段的输入。解码阶段是自回归的,前一步的输出 token 作为下一步的输入 token。由于自回归解码的顺序性,LLM 推理是内存密集型(memory-bound)的,每次前向传播都需要将所有模型参数和 KV 缓存从速度较慢但容量大的高带宽内存(HBM)传输到速度快但容量小得多的 GPU 共享内存中。另一个潜在瓶颈是 GPU 利用率低,当并行度(通常受限于批处理大小)远小于 GPU 上的流式多处理器(SMs)数量时,操作只会利用一小部分 GPU 资源。
GPU 上的注意力算法执行模式。我们可以将注意力算法的执行分为两个主要阶段:(1)QKV 准备阶段:将查询(Query)、键(Key)和值(Value)(QKV)逻辑上分组成分区,并将这些 QKV 组映射到 GPU 的不同流式多处理器(SMs)上;(2)注意力计算阶段:将 QKV 分区加载到不同 SM 的共享内存中,并对每个组应用注意力算法以得到最终的注意力结果。
分段注意力的 QKV 分区。在基于序列的解码中,当并行度(通常受限于批处理大小)远小于 GPU 上的流式多处理器(SMs)数量时,QKV 分区至关重要。为了实现高 GPU 利用率,Flash-Decoding【7,Flash-decoding for long-context inference,2023,https://pytorch.org/blog/flash-decoding/】通过对查询 和 KV 缓存进行分区,然后并行计算注意力。具体细节如下:(1)QKV 准备阶段:对于批处理中的每个查询,将其顺序的 KV 缓存分割成块(chunks)作为不同的 QKV 分区。(2)注意力计算阶段:它分别计算三个分段上的注意力 $A_0, A_1, A_2$,然后通过在线 Softmax 合并【6,Flashattention: Fast and memoryefficient exact attention with io-awareness,2022,Advances in Neural Information Processing Systems】【7,Flash-decoding for long-context inference,2023,https://pytorch.org/blog/flash-decoding/】基于不 同 QKV 分区的分段注意力结果得到最终的注意力。具体过程如下:
QKV 分区的重要性。对于树状解码,逻辑上对 QKV 进行分区对于实现高并行度的注意力计算是必要的。当树状结构 KV 缓存中的 token 数量很大时,由于内存容量限制,树状生成请求的分支数量可能不足以充分利用 GPU。例如,一个对128个数字进行排序的推理任务,在 Llama2-7B 模型中涉及约40K个 token,其 KV 缓存占用20GB,这意味着一个80GB的 A100 最多只能处理4个具有这样 token 数量的请求。
DEFT 的动机。DEFT 旨在解决 LLM 推理在处理树状结构 KV 序列时遇到的两个潜在瓶颈(即 IO 和 GPU 利用率)。如图2左侧所示的一个简单的两级级联树:对于两个查询 $Q_a$ 和 $Q_b$,对应的键满足 $K_a = K_0 \parallel K_1$ 和 $K_b = K_0 \parallel K_2$,值也遵循相同的规则。DEFT 的设计目标是:(1)通过消除对共享前缀 KV 缓存($K_0$ 和 $V_0$)的冗余内存访问来最小化 IO;(2)确保工作负载均衡以实现高 GPU 利用率,使得计算每个分段注意力 $A_i$ 的开销几乎相同。因为公式1中的全局归约需要所有部分注意力,如果计算 $A_i$ 的开销显著大于 $A_j$,负责计算 $A_j$ 的 SM 将会长时间空闲。
DEFT 技术概述。DEFT 旨在通过减少内存访问和确保树状解码的负载均衡,成为一种硬件高效的注意力算法。详细信息见图2:
1. QKV 准备阶段:为了实现前缀感知和负载均衡的 QKV 分区,我们引入了 KV-Guided Grouping 策略来复用共享前缀的 KV 缓存 IO,以及 Flattened Tree KV Splitting 策略,通过均衡并行的注意力计算实现高 GPU 利用率。详见第3.3节。
2. 注意力计算阶段:我们设计了 DEFT 注意力核,以内存高效的方式加载由 QKV 准备阶段逻辑分组的 QKV 分片,然后执行注意力计算。关键技术如下,详细信息推迟到附录A.4:
1. 通用核融合与分块策略:避免了对中间结果(即 $QK^T$ 和 Softmax)的显著 IO 操作,这是 Tree Attention-Medusa【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】所缺乏的。
2. 树拓扑感知的全局归约:扩展了 Flash-Decoding【7,Flash-decoding for long-context inference,2023,https://pytorch.org/blog/flash-decoding/】中的全局归约机制。该方法通过聚合来 自 QKV 组的部分注意力结果,并考虑树状结构,高效地计算每个查询的最终注意力。
DEFT 的系统框架。除了高效的 DEFT 注意力核,我们的 DEFT 系统还有另外两个优点:1)高效管理树状结构中的 KV 缓存,2)通过任意用户定义的函数灵活控制树解码过程,决定何时以及如何分支/剪枝。系统中关键组件及其协调的详细信息请参考附录A.1。
图 2: DEFT 概览。输入元数据在附录 A.1 详细阐述的系统中准备。在 QKV 准备阶段(见 3.3 节),QKV 将被逻辑地分组到分区中,同时兼顾共享前缀 KV 缓存的 IO 感知和负载均衡。这些分区将指导注意力计算阶段(见附录 A.4)的 QKV 加载,在该阶段将执行注意力计算。
从一个简单的解码例子开始。我们从一个使用图2所示的树状结构 KV 缓存的解码例子开始。如果我们将整个树状结构的 KV 缓存和查询分组成一个 G0 而不进行任何分区,我们可以参考图3(a)部分所示的 Vanilla Tree Attention 方法。这种方法在单个流式多处理器(SM)中,借助密集因果掩码(DCM),同时计算所有查询的注意力。然而,由于 GPU 利用率低,这种方法效率低下,正如3.2节所讨论的。为了解决这种低效问题,有效的 QKV 分区是必不可셔的。这个过程涉及两个关键考虑因素:(1)前缀感知以最小化对 KV 缓存的内存访问,(2)负载均衡以确保工作负载在 GPU 间的均匀分布。
图 3: DEFT-Node/Node-Chunk/Flatten 与不同注意力算法基线在 QKV 准备阶段的 QKV 分区策略比较。请注意,分区是逻辑设计的,不会产生任何 QKV 的数据移动成本。每个组所需的 GPU HBM 和共享内存之间的 IO 量已在红色矩形中突出显示。部分 (a) 展示了一个两级级联解码树示例的数据流和三类 QKV 分区策略:无分区(Vanilla Tree Attention)、Q-Guided Grouping 和 KV-Guided Grouping。分区策略将指导后续注意力计算阶段的 QKV 加载,其中每个 QKV 组 Gi 将被加载到 GPU 上的 SMi 中。部分 (b) 显示了 Q-Guided Grouping 和 KV-Guided Grouping 的比较,后者可以感知前缀 KV 缓存 KV0 的 IO,并只加载一次。DEFT-Node-Chunk 是 DEFT-Node 的一个弱负载均衡改进,通过将大节点(如 KV0)拆分为块。部分 (c) 展示了 DEFT-Flatten 中用于负载均衡分区的扁平化树 KV 切分细节(在备注 3.1 中讨论),包括深度优先扁平化策略、均匀分块策略和位掩码。有关基线和 DEFT 的摘要,请参见表 2。有关树注意力基线(Cai 等人,2024;Miao 等人,2023)的分析,请参见备注 3.2。
Q-Guided 与 KV-Guided Grouping。大多数现有的内存高效注意力算法【6,Flashattention: Fast and memoryefficient exact attention with io-awareness,2022,Advances in Neural Information Processing Systems】【7,Flash-decoding for long-context inference,2023,https://pytorch.org/blog/flash-decoding/】【54 ,Efficiently programming large language models using sglang,2023,arXiv preprint】采用 Q-Guided Grouping 进行 QKV 分区,其中每个查询作为分区的指示符,与其对应的 KV 缓存分组。然而,这种方法不是前缀感知的,例如,在 Flash-Attention 中(如图3所示),KV0 被加载了两次,即一次为 Qa,一次为 Qb。我们转而采用另一种 KV-Guided Grouping 方法:通过将每个节点的 KV 缓存与所有共享它的查询分组,可以使分区具有前缀感知能力,从而减少对 KV 缓存的内存访问。例如,DEFT-Node(如图3所示)在注意力计算中只加载一次前缀 KV 缓存 KV0。查询的额外 IO 成本可以忽略不计,因为每个查询只包含一个 token,而 KV 缓存可能包含数千个 token。
表 2: 基线(大部分如图3所示)和 DEFT 的 QKV 分区策略比较。对于 IO 冗余,显著问题用红色突出显示,可忽略的问题用蓝色突出显示。“Q” 指查询,“KV” 指 KV 缓存。“DCM” 代表密集因果掩码(一个矩阵),“BCM” 指位因果掩码(一组64位整数)。“PA” 代表注意力计算过程中的部分结果,包括 QKT、Softmax 等。更多 ⋆ 符号表示 QKV 分区的工作负载更均衡。IO 复杂度的详细信息可在附录 A.5 中找到。
Tree KV Splitting 与负载均衡。得益于 KV-Guided Grouping,DEFT-Node 对 KV 缓存 IO 是前缀感知的。然而,它引入了一个潜在的瓶颈:不同 SM 之间的工作负载不均衡。例如,在图3的 DEFT-Node 中可以看到,KV0 可能包含1000个 token,而 KV1 只包含2个 token。如果 G0 和 G1 分别分配给 SM0 和 SM1,SM1 会早得多完成计算并保持空闲,导致 SM 利用率低下。
一种直接的负载均衡方法。为了解决这个问题,我们需要更均匀地平衡 QKV 分区。一种直接的方法是在物理层面将 K0、K1 和 K2 分块,同时在逻辑层面保持节点级分区,如 DEFT-Node-Chunk(图3所示)。然而,这种负载均衡策略较弱:它只将大节点(例如,约1k token 的提示)分解为更小的 KV 块,而不能处理包含许多小节点的情况(例如,推测性解码),这可能由于需要更多轮次的 GPU 执行来处理额外的 QKV 组而减慢推理速度。
提出 DEFT-Flatten。由于 KV 缓存加载是注意力计算的主要瓶颈【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】【35,Quest: Query-aware sparsity for efficient long-context llm inference,2024,arXiv preprint】,因此在每个 KV 缓存分区中实现均匀的 token 长度非常重要。因此,我们提出了 DEFT-Flatten,在备注 3.1 中详细阐述。
备注 3.1(扁平化树 KV 切分技术)。如图3(c)部分所示,有三个关键组成部分:
备注 3.2(关于树注意力算法的讨论)。现有的注意力算法是为推测性解码设计的,其中注意力是为整个树状结构的查询计算的。然而,这些方法在内存效率上并不高。有关分区细节,请参见附录A.5中的图11。
* Tree Attention-Medusa【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】。该方法基于 Vanilla Tree Attention(如图3左侧所示),使用 PyTorch 的通用矩阵乘法(GEMM)对 Q 和 KV 张量进行分区。它的内存效率不高,原因有二:(1)它没有利用 Flash-Attention 来减少计算中间结果(如 Softmax)时的内存访问;(2)它引入了一个密集的因果掩码,其内存访问量很大。
* Tree Attention-SpecInfer【28,Specinfer: Accelerating generative llm serving with speculative inference and token tree verification,2023,arXiv preprint】。该算法采用基于 Vanilla Tree Attention 的 Q-Guided Grouping,并通过 Flash-Decoding 对 KV 序列进行分区。它在内存效率上不高,因为会为每个查询冗余地加载整个树状结构的 KV 缓存。
IO 复杂度分析。我们证明了 DEFT-Flatten 在 IO 复杂度上优于现有的注意力算法,包括 Flash-Decoding【7,Flash-decoding for long-context inference,2023,https://pytorch.org/blog/flash-decoding/】 、Tree Attention-Medusa【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】和 Tree Attention-SpecInfer【28,Specinfer: Accelerating generative llm serving with speculative inference and token tree verification,2023,arXiv preprint】。详见附录 A.5。
实现细节。我们使用 OpenAI Triton【36,Triton: an intermediate language and compiler for tiled neural network computations,2019,Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages】实现 DEFT 注意力核,这使我们能够以线程块的粒度控制从全局内存到共享内存的内存访问和注意力计算。DEFT-Node 和 DEFT-Flatten 算法的两个阶段的 Python 风格实现分别见附录 A.8 和附录 A.9。
表 3: 基线与 DEFT 的比较。基线的注意力核实现与其内存管理方式相适应。因此,为了与基线进行公平比较,我们实现了同时适用于分页(Kwon et al., 2023)/非分页内存管理的 DEFT-Node 和 DEFT-Flatten。
表 4: 工作负载生成。ToT-BFS 代表使用广度优先搜索的思维树(Yao et al., 2023)。APPS(Hendrycks et al., 2021)是一个竞争性编程问题数据集。Medusa(Cai et al., 2024)是一个推测性解码框架。“GoT” 代表思维图(Besta et al., 2023),其中包含使用 GPT-3.5 在 ToT-BFS 中进行复杂推理任务的迭代记录。更多细节见表 13。
图 4: 推测性解码(32个查询的 token 树)的延迟分解,树拓扑来自 Medusa (Cai et al., 2024)。U 表示非分页内存。
表 5: DEFT-Flatten 与基线在树状解码平均解码延迟(秒)上的比较。b 代表树的宽度,t 代表 token 树的大小(即树状结构查询的数量)。最快的方法用粗体表示,次快的方法用下划线表示。Radix Attention 是解码延迟方面的最佳基线。⋆ 表示 A100 80GB GPU 出现内存不足(OOM)错误。Speedup Upper-bound (no attention) 指的是如果我们排除注意力计算,只运行包括 MLP 在内的其他组件,Radix Attention 所能达到的最大加速比。更多关于注意力加速的细节,请参见表 16。
表 6: [不同 KV 切分策略] DEFT-Node、DEFT-Node-Chunk 和 DEFT-Flatten 在 Llama3-8B 模型(GQA)上使用 NVIDIA A100(80GB)的平均注意力延迟(秒)比较。此表为表 16 的补充。最快的方法用粗体表示,次快的方法用下划线表示。Radix Attention 是解码延迟方面的最佳基线。更多基线的详细信息请参见表 16。
表 7: [不同提示长度] DEFT-Flatten 和 Radix Attention 在多步推理任务“排序”中的效率比较。原始提示长度约为 1K token,我们将其填充到 5K、8K 或 10K token 的长度。
表 8: [不同模型大小] DEFT 和 Radix Attention 在 Codellama-34B 和 Codellama-7B 模型上的解码延迟加速比和注意力/FFN 延迟比(简称 A/F-LR)的比较。Radix Attention 是解码延迟方面的最佳基线。b 代表树的宽度,t 代表 token tree 的大小。对于多步推理,我们测试了提示长度约为 1k token 的排序任务。
本文提出了 DEFT-Flatten,一种为树状结构 LLM 推理优化的硬件高效注意力算法。它通过复用共享前缀的 KV 缓存和均匀分配工作负载,有效解决了内存访问和 GPU 利用率的瓶颈问题。DEFT-Flatten 的核心优势在于其前缀共享感知能力和负载均衡性,使其能够广泛适用于各种树状结构任务,并能很好地扩展到更大的搜索空间和更多的分支。实验结果表明,DEFT-Flatten 在解码和注意力延迟方面实现了高达 2.23倍/3.59倍的加速,在少样本提示、多步推理和推测解码等任务中均优于基线方法。我们的消融研究强调:(1)均衡的分区至关重要;(2)DEFT-Flatten 在各种 LLM 模型和 GPU 架构上都能提供显著的加速;(3)随着树状请求中 token 数量的增加(例如更长的提示)和分支数量的增多,DEFT-Flatten 的加速效果会更大。
图5左侧展示了高效灵活的树状解码中不同组件的协调工作。DEFT 系统组件的功能详情如下:
图 5: DEFT 图解。(左) 系统概览。(右) 使用一个解码树示例展示 DEFT-Node 的数据流 (DEFT-Flatten 类似,仅 QKV 分区不同)。
图5右侧通过一个解码树示例进一步展示了系统的关键数据流。为简化起见,我们在此展示了 DEFT-Node,DEFT-Flatten 的流程类似,只是 QKV 分区方式不同。输入元数据将由我们上面提到的三个组件提取,然后在3.3节讨论的 QKV 准备阶段之后,以分组方式从 HBM 加载到共享内存。接着,QKV 组将由 DEFT 注意力核在 DEFT 的注意力计算阶段进行处理。关于这两个阶段技术的详细信息,请参阅附录A.4。
(a) (左) 序列 KV 与树状查询用于并行解码【28,Specinfer: Accelerating generative llm serving with speculative inference and token tree verification,2023,arXiv preprint】【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】,其中应用因果掩码记录树状 token 查询间的因果信息。(右) 树状 KV 与并行查询用于多步推理中的共享前缀。
图 6: 关于使用树状查询【28,Specinfer: Accelerating generative llm serving with speculative inference and token tree verification,2023,arXiv preprint】和树状 KV 的树状解码的讨论。
(b) SpecInfer【28,Specinfer: Accelerating generative llm serving with speculative inference and token tree verification,2023,arXiv preprint】中的位掩码,用于记录树状结构中查询 token 间的因果信息。解码树位于 6a 的左侧部分。
树状解码的两种模式。树状解码可以采用树状结构的 KV 缓存来存储并感知共享前缀【54,Efficiently programming large language models using sglang,2023,arXiv preprint】,或者在并行/推测解码中采用树状结构的查询【28,Specinfer: Accelerating generative llm serving with speculative inference and token tree verification,2023,arXiv preprint】【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】,如图6所示。一个通用的解码过程可以同时处理树状 KV 和树状查询,这既可以减少共享前缀的冗余(例如 IO、存储、计算等),也可以增加每次解码迭代生成的 token 数量。
现有框架的局限性。现有的针对树状解码效率优化的推理框架【54,Efficiently programming large language models using sglang,2023,arXiv preprint】【9,Prompt cache: Modular attention reuse for low-latency inference,2023,arXiv preprint】主要目标是:(1)减少内存占用【54,Efficiently programming large language models using sglang,2023,arXiv preprint】以支持更大的批处理大小,从而提高吞吐量;(2)复用提示缓存【9,Prompt cache: Modular attention reuse for low-latency inference,2023,arXiv preprint】以避免 KV 缓存的重新计算,从而加快首个 token 的生成时间(TTFT)。然而,它们的设计并未专门针对减少整个解码过程的延迟。我们观察到,LLM 推理的树状结构特性可以为我们提供一些加速解码本身的优势。
树状解码的加速潜力分析。在树状解码中,KV 缓存和查询可以组织成树状结构。我们不仅可以以树状结构存储 KV 缓存,还可以在注意力计算期间以感知树拓扑的方式加载 QKV,以最小化 HBM 和 GPU 片上共享内存之间昂贵的 IO。我们通过两个复杂交互场景的案例研究来解释这一点:(1)多步推理;(2)推测解码。
案例研究1:多步推理。如图7左侧所示,多步推理【11,Reasoning with language model is planning with world model,2023,arXiv preprint】【47,Tree of thoughts: Deliberate problem solving with large language models,22,2023,arXiv preprint】【4,Graph of thoughts: Solving elaborate problems with large language models,2023,arXiv preprint】过程可概括为三个阶段:(1)思想生成:基于生成提示 $P_g$ 和先前步骤 $S$ 生成下一步的 k 个候选思想;(2)思想评估:LLM 作为状态评估器,基于评估提示 $P_e$ 评估先前的思想 $S$ 对解决问题的贡献,作为搜索算法的启发式信息;(3)基于树搜索的扩展:采用不同的搜索算法探索搜索空间。在阶段(1)和(2)中,我们都可以在树注意力计算期间共享 $P_g/P_e$ 和 $S$ 的 KV 缓存 IO。
图 7: 树状解码两个案例研究的分析。(左) 多步推理。(右) 推测解码。蓝色框表示在树注意力计算期间可在存储和内存访问中共享的过去 KV 缓存,而黄色框表示生成上下文的 KV 缓存。
案例研究2:推测解码。如图7右侧所示,推测解码【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】【28,Specinfer: Accelerating generative llm serving with speculative inference and token tree verification,2023,arXiv preprint】过程可概括为三个阶段:(1)Token 树生成:多个小型草稿模型或微调的头基于提示 $P$ 生成多个 token 序列,然后合并成一个推测的 token 树 $T_t$;(2)Token 验证:基于这些树状结构的候选 token $T_t$,对照 LLM 的输出验证其正确性,其中树注意力计算是该过程的瓶颈;(3)保留 KV 缓存。在阶段(2)中,我们可以在树注意力计算期间共享 $P$ 和 $S$ 的 KV 缓存 IO。
为什么现有的树注意力算法不够用? 现有的树注意力算法要么在内存访问上效率低下【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】【28,Specinfer: Accelerating generative llm serving with speculative inference and token tree verification,2023,arXiv preprint】,要么不适用于 token 树中超过64个 token 的通用树状解码【28,Specinfer: Accelerating generative llm serving with speculative inference and token tree verification,2023,arXiv preprint】。
* SpecInfer【28,Specinfer: Accelerating generative llm serving with speculative inference and token tree verification,2023,arXiv preprint】中,如图6b所示,使用位掩码记录 token 树查询间的因果信息。每个查询 token $t_i$ 都有一个64位整数作为位掩码,其中第j位表示 $t_i$ 的查询与 $t_j$ 的 KV 缓存之间的因果关系。这种设计的优点是大大减少了 IO,但导致树中 token 的最大数量仅为64,这对于具有树状结构 KV 缓存的场景不实用。此外,它对 KV 缓存的 IO 不敏感,因为它会为每个查询加载整个树的 KV 缓存。
* Medusa【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】适用于通用的树状解码,但由于密集的因果掩码和注意力计算过程中的部分结果(如 softmax)产生的大量 IO,其硬件效率不高。
与同期工作的共同点与不同点。有一些同期工作【3,Bifurcated attention for single-context large-batch sampling,2024,arXiv preprint】【48,Chunkattention: Efficient self-attention with prefix-aware kv cache and two-phase partition,2024a,arXiv preprint】【19,Hydragen: High-throughput llm inference with shared prefixes,2024,arXiv preprint】【55,Relayattention for efficient large language model serving with long system prompts,2024,Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)】关注于单上下文大批量采样场景的注意力算法设计,其目标是从单个上下文(如系统提示或少样本示例)生成多个序列,这是树状解码深度为1的特例。它们的算法设计基于这一特性,这意味着它们无法高效地适用于具有超过两级前缀的树的注意力计算。
共同的见解和技术。同期工作和 DEFT 都认识到内存访问是 LLM 推理的瓶颈,并通过分解注意力来减少前缀 KV 的内存访问:(1)分别计算前缀和后缀上的注意力 $A_p$ 和 $A_s$;(2)基于 $A_p$ 和 $A_s$ 通过在线 softmax 合并【6,Flashattention: Fast and memoryefficient exact attention with io-awareness,2022,Advances in Neural Information Processing Systems】【7,Flash-decoding for long-context inference,2023,https://pytorch.org/blog/flash-decoding/】得到最终的注意力。其正确性证明如下 :
- 设键张量 $K \in R^{(l_{kv},d)}$,值张量 $V \in R^{(l_{kv},d)}$,查询张量 $Q \in R^{(l_q,d)}$。考虑一般情况,K 和 V 沿序列(行)维度被划分为前缀和后缀两部分:$K = K_p \parallel K_s$ 和 $V = V_p \parallel V_s$。
- 我们分别计算前缀和后缀上的注意力 $A_p = \langle Q, K_p, V_p \rangle$ 和 $A_s = \langle Q, K_s, V_s \rangle$。
- 根据公式1,我们可以得到分段注意力 $\langle Q, K, V \rangle = \text{SegAttn}(A_p, A_s)$。
表 9: DEFT 与同期工作在单上下文大批量采样场景中的比较,包括 Chunk-Attention (Ye et al., 2024a)、Hygragen (Juravsky et al., 2024) 和 Bifurcated-Attention (Athiwaratkun et al., 2024)。RelayAttention (Zhu et al., 2024) 和 Cascade-inference (Ye et al., 2024b) 类似于 Hygragen。更多的 ⋆ 表示树切分后工作负载更均衡,也显示了加速对树拓扑的不敏感程度。
差异比较。现有的单上下文大批量采样工作对于通用的树状解码在硬件上不够高效,原因有二,如表9所示:
- 仅为两级树设计。它们专为只有两层的解码树设计——根部的所有前缀和深度为1的所有后缀。对于具有多级前缀的解码树,它们的算法只能减少树根部提示的 IO。然而,在多步推理等场景中,非根前缀的 token 长度也可能很长(例如,数千个 token),而它们的 KV 缓存 IO 并未被复用。DEFT 可以复用通用解码树中所有非叶前缀的 KV IO,提供了更大的加速潜力。
- 未解决负载不均衡问题。它们没有解决树状解码中的工作负载不均衡问题。解码树中的节点大小差异可能很大,因此以一种能确保每个 QKV 组计算均衡的方式切分树和分组 QKV 至关重要。仅仅根据深度进行划分是不够的。
本小节总结并讨论了现有高效注意力算法和核设计中的常用技术,并解释了 DEFT 注意力核的设计细节。
表 10: DEFT 的技术列表。我们提出的技术用红色标出。前四项技术的细节在第 3.3 节,后续技术的细节在本章讨论。
核融合与分块策略。核融合 (Kernel Fusion) 是一种常见的 IO 减少技术:如果多个操作在相同的输入上执行,那么从 HBM 一次性加载输入比为每个操作多次加载更高效。为了将所有注意力操作融合到一个 GPU 核中,我们进一步利用了常用的 分块策略 (Tiling strategy)【6,Flashattention: Fast and memoryefficient exact attention with io-awareness,2022,Advances in Neural Information Processing Systems】【7,Flash-decoding for long-context inference,2023,https://pytorch.org/blog/flash-decoding/】:将每 个 QKV 组内的查询和 KV 缓存分割成小块,通过在有限的共享内存中计算注意力,避免了在 HBM 中物化注意力矩阵,然后根据公式1增量地执行 softmax 归约来重构注意力。
备注 A.1: Tree Attention in Medusa【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】没有采用分块和核融合,导致了中间结果(如 $QK^T$ 和 Softmax)的大量 IO,如图8所示。而 Flash Decoding、Tree Attention in SpecInfer 和我们的 DEFT 都采用了分块和融合核。
图 8: Tree Attention-Medusa (Cai et al., 2024) 的操作。未应用核融合或分块策略,这导致了像 QK⊤、DCM 和 Softmax 等部分结果在 GPU 全局内存和片上共享内存之间的显著 IO。
树拓扑感知的因果掩码。因果掩码 (Causal Mask) 在推测解码工作【28,Specinfer: Accelerating generative llm serving with speculative inference and token tree verification,2023,arXiv preprint】【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】中被引入,通过记录解码树中查询和 KV 缓存之间的因果关系,实现在单个 GPU 核内计算整个解码树的注意力。
备注 A.2: 因果掩码带来了两部分冗余:
- 内存访问:Medusa【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】在 HBM 中物化了密集的因果掩码(DCM),带来了巨大的 IO 成本。SpecInfer【28,Specinfer: Accelerating generative llm serving with speculative inference and token tree verification,2023,arXiv preprint】引入了位因果掩码(BCM),IO 成本极小,但仅限于64个 token。DEFT-Flatten 采用了受 SpecInfer 启发的位因果掩码,以最小化掩码的 IO。
- 计算:除了生成掩码本身的计算成本外,许多 $QK^T$ 的矩阵乘法结果被掩码掉而从未使用,造成了计算冗余。
切分与全局归约。切分 (Splitting) 被引入以提高 GPU 在序列解码中的利用率【16,Flashdecoding++: Faster large language model inference on gpus,2023,arXiv preprint】。Flash-Decoding 首先基于 Q 对长 KV 进行切分和分组,然后将这些组分配给不同的 SM 计算部分注意力。为了获得准确的最终注意力,需要对来自具有相同查询的 QKV 组的部分注意力进行全局归约 (Global Reduction)。类似地,DEFT 也将解码树切分成不同的 QKV 组(即我们提出的扁平化树 KV 切分策略),以实现 SM 的负载均衡。DEFT 也需要一个全局归约,但 Flash-Decoding 的归约是为序列解码设计的,不能感知树拓扑。因此,我们提出了树拓扑感知的全局归约,如图10b所示。
DEFT 注意力核设计。基于上述技术,我们设计了具有两个阶段的 DEFT 注意力核,如图9所示。
- 阶段1:计算部分注意力。基于 QKV 准备阶段的分组结果,每个 QKV 组 ($G_i$) 被分配到一个线程块,使用 Flash Attention【6,Flashattention: Fast and memoryefficient exact attention with io-awareness,2022,Advances in Neural Information Processing Systems】计算部分注意力 ($PA_i$) 和 LogSumExp ($LSE_i$)。
- 阶段2:全局归约。DEFT 执行树拓扑感知的全局归约,根据树拓扑逻辑上重映射部分注意力和 LogSumExp,从而为每个查询获得正确的最终注意力。
图 9: DEFT 注意力核两个阶段的概览(以 DEFT-Node 为例,DEFT-Flatten 类似)。阶段 1 – 计算部分注意力。基于上述 KV-Guided Grouping 策略和 Tree Split 后的 QKV 分组结果,每个 QKV 组 (Gi) 将被分配到一个线程块,用于使用通用的核融合和分块策略进行 Flash Attention (Dao et al., 2022) 计算。与 Flash-Decoding (Dao et al., 2023) 类似,我们不仅得到部分注意力 (PAi),还返回 “LogSumExp” (LSEi) 作为下一阶段归约的权重参数。阶段 2 – 全局归约。在收到每个 QKV 组 Gi 的 PAi 和 LSEi 后,DEFT 现在执行树拓扑感知的全局归约 (DeFT_reduction)。在解码树中 KV 序列节点间的树拓扑指导下,DEFT 逻辑上重映射注意力和 LogSumExp 的部分结果,以在归约后为每个查询获得正确的最终注意力。解码树与图 3 左侧的相同。SMi 表示 GPU 中的流式多处理器 i。
(a) 左:DEFT 注意力核两阶段图示。右:图 10b 中 DEFT 阶段 2 调用的全局归约核。QKV 组 G0、G1 和 G2 来自图 3 中的 DEFT QKV 组。
图 10: DEFT 核的详细注意力操作(以 DEFT-Node 为例,DEFT-Flatten 类似)。基于图 3 中的相同解码树。
(b) DEFT 阶段 2:全局归约(以 DEFT-Node 为例)。基于图 3 中的树拓扑,我们可以根据查询对 LogSumExp 和部分注意力进行分组,然后调用图 10a 右侧的全局归约核来获得最终的注意力。
分析设定。本节分析 DEFT 的 IO 复杂度,显示其相比现有注意力算法在 HBM 访问上显著减少。我们基于单次迭代中的解码树快照来比较 IO,使用的符号如表11所示。
表 11: 符号表。
IO 复杂度对比 (表12)。
- 序列解码方法 (Naive Attention, Flash-Decoding): 由于缺乏树拓扑感知,其 KV 缓存的内存访问开销比 DEFT 和 Tree Attention-Medusa 高出 $F_s$ 倍。
- Tree Attention-Medusa【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】: 尽管 KV 缓存 IO 较少,但由于缺乏分块和核融合,其在部分结果(如 $QK^T$ 和 Softmax)上的 IO 开销很高。此外,其引入的密集掩码也带来了显著的 IO 成本。
- Tree Attention-SpecInfer【28,Specinfer: Accelerating generative llm serving with speculative inference and token tree verification,2023,arXiv preprint】: 如图11所示,该方法采用 Q-Guided Grouping,导致每个查询都需要加载整个树的 KV 缓存,KV 缓存的 IO 开销巨大。
- DEFT-Flatten: 通过 KV-Guided Grouping 和核融合,DEFT-Flatten 避免了 KV 缓存和部分结果的冗余 IO。其采用的位掩码 IO 成本极低,使其在所有方法中具有最低的 IO 复杂度。
图 11: (图 3 补充) Tree Attention (Cai et al., 2024; Miao et al., 2023) 的 QKV 分区和内存访问。Tree Attention-Medusa (Cai et al., 2024) 通过 PyTorch 中的通用矩阵乘法 (GEMM) 对 QKV 进行分区。Tree Attention-SpecInfer (Miao et al., 2023) 采用 Q-Guided Grouping。QBCM 是 SpecInfer 的 Q-Guided 位因果掩码,其中每个位表示查询与 KV 缓存中一段 token 之间的因果关系。例如,Qa 的 Q-BCM 是 "110",意味着 KV 缓存的前两段 KV0 和 KV1 对 Qa 的注意力有效。图中的 Qi 和 Kj 与图 3 中的相同。
表 12: 各种方法的 IO 复杂度分解。O(1) 表示张量中单个数据在所有层和头上的 IO 成本,相当于 #heads * #layer * dtype_size。表中所有方法中的最佳者用红色标出,(潜在的)最差者用蓝色标出。查询 IO 被省略,因为它对所有方法都是 O(k * ln * d_head),其中 k 是 QKV 组的数量。对于 DEFT-Node,k = #node;对于 DEFT-Node-Chunk,k = Σ(#node_i=1) ceil(ni/bs);对于 DEFT-Flatten,k = N_tree / bs。Tree Attention-S 中的 S 代表 SpecInfer (Miao et al., 2023)。
工作负载设置的合理性。为了验证 DEFT 在不同解码树拓扑下的加速效果,我们从真实任务中编译了解码树,涵盖以下三个方面:
- 少样本提示:这是一个两级树,包含一个提示前缀和多个用于后缀生成的分支。作为案例研究,我们将提示长度固定在约4000个 token,并改变分支数量。
- 多步推理【47,Tree of thoughts: Deliberate problem solving with large language models,22,2023,arXiv preprint】【11,Reasoning with language model is planning with world model,2023,arXiv preprint】【4,Graph of thoughts: Solving elaborate problems with large language models,2023,arXiv preprint】: 我们记录了真实推理任务交互【4,Graph of thoughts: Solving elaborate problems with large language models,2023,arXiv preprint】中的树形、提示和所有思想的长度,并以此为指导进行树解码,以验证 DEFT 在推理的思想生成阶段的加速效果。生成细节见图12。
- 推测解码【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】【28,Specinfer: Accelerating generative llm serving with speculative inference and token tree verification,2023,arXiv preprint】: 我们使用了 Medusa【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】的 token 树拓扑,并记录了与 APPS【12,Measuring coding challenge competence with apps,2021,arXiv preprint】的真实交互数据,包括每一步接受的 token 长度。这作为指导来模拟推测解码的瓶颈——token 验证阶段的注意力计算。
图 12: 重建多步推理树模板的详细过程。(左) 从实际推理记录中重建推理树,如 (Besta et al., 2023) 所述,涉及捕获以下方面:(1) 树的结构,以其深度 d 和宽度 w 为特征;(2) 与每个思想相关的 token 长度;(3) 每个深度的最佳思想及其对应的分数。对于文档合并任务,树的深度设置为 d=3,每个深度的宽度为 w=10。对于排序 128 个数字,深度减少到 d=10,同时保持相同的宽度 w=10。有关其他多步推理任务的树拓扑详细信息,请参见表 13。(右) 利用从左侧提取的思想信息,我们可以生成用于解码的树模板,包括分支记录和剪枝记录。这些记录有助于指导树解码过程,以生成忠实复制思想树结构的解码树。
实验范式的合理性。我们的实验范式包括:首先,从真实的树状解码任务中获取解码树;其次,通过强制 LLM 推理在同一框架内精确复制这些解码树,以研究注意力加速对实际运行时间性能的影响。这种范式有两个优点:
- 我们可以在一个统一的系统中使用来自真实任务的解码树作为基准,从而能够在解码延迟方面对不同的注意力算法进行公平比较。
- 我们既考虑了具有不同树结构的任务的独特性,也考虑了通用树状解码的更广泛适用性。
表 13: 生成的工作负载详情。对于多步推理,我们包括了来自 Besta et al. (2023) 的这 4 个任务:(1) 排序 128 个数字(简称 sorting);(2) 文档合并(简称 document);(3) 关键词计数(简称 keyword);(4) 集合交集(简称 set)。d 和 w 分别表示树的深度和宽度。t 表示推测性解码的 token 树大小,其树拓扑来自 Medusa (Cai et al., 2024)。
DEFT-Node 的 GPU 利用率微基准测试。如表14所示,我们测试了 DEFT-Node 和 DEFT-Flatten 在一个有64个查询和4k token 提示的推测解码任务上的表现。对于 DEFT-Node,QKV 分区是不均衡的。结果显示,DEFT-Flatten 在内存利用率(内存吞吐率)和计算利用率(计算吞吐率和低利用率时间比)方面都优于 DEFT-Node。
表 14: [GPU 利用率微基准测试] 在 NVIDIA A100 (80GB) 上使用 LLama3-8B 模型 (GQA) 时,DEFT 单层注意力的延迟(μs)、SM 计算吞吐率、内存吞吐率和低利用率时间比。工作负载为具有 64 个查询和 4k token 提示的推测性解码。计算吞吐率指的是 GPU 中流式多处理器(SM)的利用率。内存吞吐率表示实际内存吞吐量与最大带宽的比率。低利用率时间比定义为计算吞吐率低于 5% 的时间比例。
注意力延迟和 IO 分解。注意力延迟和 IO 的详细比较分别见表16和表17。
推理准确性。如表15所示,DEFT 的注意力分数与 Huggingface Transformers 中的原始注意力相比可能略有差异(约0.4%的相对误差),但生成的 token 和 PPL(困惑度)几乎没有差别。这种差异是由于 GPU 上的浮点运算不遵循结合律造成的,这在其他引入在线 Softmax 和归约的方法中也很常见。
表 15: DEFT 在注意力分数和困惑度(PPL)方面的推理准确性。PPL 是在 400 次解码迭代后计算的。Vanilla Attention 是 Huggingface Transformers 的实现。
动态行为:逐迭代延迟。图13展示了在多步推理任务“排序”中,DEFT-Node 和 DEFT-Flatten 的逐迭代延迟。我们观察到,DEFT-Node 相对于 DEFT-Flatten 的加速比与树节点大小的离散程度呈强正相关。这是因为 DEFT-Flatten 的性能相对稳定,而 DEFT-Node 的性能受树拓扑影响更大。
图 13: 在排序任务中比较切分策略 DEFT-Node 和 DEFT-Flatten。加速比指的是 DEFT-Node 和 DEFT-Flatten 的每次迭代延迟之比。树节点长度标准差表示每次迭代的树节点长度的标准差。
消融研究:解码树宽度的影响。如图14所示,我们固定提示长度为4000,改变解码树的宽度。随着树宽度的增加,DEFT-Flatten 的解码加速效果更显著(从宽度10的几乎无加速到宽度50的1.33倍加速),因为更宽的树意味着提示前缀的 KV 缓存 IO 被更频繁地复用。
图 14: 不同树宽度的少样本提示任务的每次迭代延迟。e2e 表示解码延迟(最优端到端延迟),而 Attn 表示仅注意力开销。
消融研究:KV 切分中块大小的影响。如图15所示,块大小的选择是在 IO 冗余和线程块调度之间的权衡。更大的块大小意味着查询 IO 冗余更少,但可能因线程块较少而导致 GPU SM 空闲。结论是:(1)最佳块大小受序列长度和查询数量的影响;(2)DEFT-Flatten 在所有测试的块大小上都优于 DEFT-Node-Chunk。
图 15: DEFT 的 KV 块大小消融研究。t 是推测性解码中的 token 树大小。
消融研究:提示长度、不同GPU、模型架构的影响。图16-18显示,随着提示长度增加,DEFT 的优势更明显。表19显示 DEFT 在 RTX 4090 上也有明显加速。表20和表21显示 DEFT 对 MHA 和 GQA 架构的模型都能显著加速。
表 16: 树状解码的平均注意力延迟(秒)及其对解码延迟的影响。b 代表树宽,t 代表 token 树大小。Attention Speedup over the best attention 指的是 DEFT-Flatten 相对于最佳基线(通常是 Tree Attention-Medusa)在注意力计算上的加速比。Radix Attention 是解码延迟的最佳基线。注意,KV 缓存管理不包含在注意力延迟内。⋆ 表示 A100 80GB GPU 内存不足。更多解码延迟细节见表 5。
表 17: 解码期间的平均端到端 IO (TB)。数据格式为左/右:(左) KV 缓存 IO;(右) 部分结果 IO,包括 QKT、QK⊤/sc、Mask M、M + QK⊤/sc 和 Softmax。b 表示树宽。t 表示 token 树大小。⋆ 表示 A100 80GB 内存不足。
图 16: 推测性解码中不同提示长度下 DEFT 的每输出 token 时间 (TPOT)。
图 17: 推测性解码中不同提示长度下 DEFT 的解码延迟。
图 18: 推测性解码中不同提示长度下 DEFT 的注意力延迟。
表 18: [模型大小和提示长度的消融研究] 在排序推理任务中,不同提示长度下,DEFT 和 Radix Attention 在 Codellama-34B 和 Codellama-7B 上的解码加速比和注意力/FFN 延迟比 (A/F-LR) 的比较。Radix Attention 是解码延迟方面的最佳基线。
表 19: [不同 GPU] 在 NVIDIA RTX 4090 (24GB) 上使用 LLama3-8B 模型 (GQA) 时,DEFT 的平均注意力延迟(秒)加速。Radix Attention 是解码延迟方面的最佳基线。
表 20: [不同模型架构 (GQA)] 在 NVIDIA A100(80GB) 上使用 Codellama-34B 模型 (GQA) 时,DEFT 的平均注意力延迟(秒)加速。Radix Attention 是解码延迟方面的最佳基线。
表 21: [不同模型架构(MHA)] 在 NVIDIA A100(80GB) 上使用 Codellama-7B 模型(MHA) 的 DEFT 平均注意力延迟(秒)加速。Radix Attention 是解码延迟方面的最佳基线。
算法1 DEFT-Node 算法-阶段1:QKV 准备
<blockquote>输入: 查询 $Q \in R^{(b_q, d)}$,键缓存列表 $KL = (K_0, ..., K_{N-1})$,值缓存列表 $VL = (V_0, ..., V_{N-1})$(对应树中每个序列节点),树 $T$ 及其拓扑信息。
处理:
1. 对于 $Q$ 中的每个查询 $q$,获取其所有前缀的 KV 索引,存入 QMapKV。
2. 对于每个序列的 KV 缓存 $K_i, V_i$,将其与所有共享它的查询进行分组,得到 $Q_i$,存入 KVMapQ。
返回: QMapKV, KVMapQ
算法2 DEFT-Node 算法-阶段2:注意力计算
<blockquote>输入: 查询 $Q$, 键缓存列表 $KL$, 值缓存列表 $VL$, 树 $T$, 以及从阶段1得到的 QKV 分组信息 QMapKV, KVMapQ。
处理:
1. 阶段1:计算部分注意力
- 为每个 QKV 组 $(Q_i, K_i, V_i)$ 并行调用 FlashAttention,得到部分注意力 $o_i$ 和 LogSumExp $lse_i$。
- 将每个 $o_i, lse_i$ 根据其对应的原始查询索引映射回全局存储。
2. 阶段2:全局归约
- 对于每个查询 $q$,在收集完其所有相关的部分注意力和 LogSumExp 后,执行全局归约(类似在线 Softmax 合并),计算出最终的注意力结果 $FO[idx]$。
返回: 最终的注意力输出 $FO$
DEFT-Node 采用节点粒度的切分策略,简单但可能导致负载不均衡。因此,我们提出 DEFT-Flatten,以更均衡的子树粒度进行切分。
DEFT-Flatten 算法-阶段1:QKV 准备
<blockquote>输入: 与 DEFT-Node 相同,额外增加子树大小 $S_t$。
处理:
1. 均匀切分树 KV 缓存: 将整个树的 KV 缓存按深度优先顺序扁平化,然后均匀地切分为多个子树,每个子树最多包含 $S_t$ 个 token。得到子树信息 SubInfo,子树键缓存 KSub,子树值缓存 VSub。
2. 分组与掩码生成:
- 对于每个子树的 KV 缓存,将其与所有共享它的查询分组,得到 $Q_i$。
- 为每个子树生成一个位因果掩码 CausalMask,用于记录子树内不同节点与查询之间的因果关系。
返回: QMapKV, KVMapQ, CausalMask, SubInfo
DEFT-Flatten 算法-阶段2:注意力计算
<blockquote>该阶段与 DEFT-Node 类似,但有关键区别:
1. 阶段1:计算部分注意力
- 在调用 FlashAttention 之前,需要根据 CausalMask 和 SubInfo 重构出适用于当前子树的注意力掩码。
- 其余部分与 DEFT-Node 相同,计算每个子树 QKV 组的部分注意力 $o_i$ 和 LogSumExp $lse_i$。
2. 阶段2:全局归约
- 与 DEFT-Node 完全相同,对每个查询收集到的所有部分结果执行全局归约。
返回: 最终的注意力输出 $FO$