DEFT: Decoding with Flash Tree-Attention for Efficient Tree-Structured LLM Inference

  • 文章标题: DEFT:利用闪存树注意力机制为树状结构大语言模型实现高效解码
  • 作者/机构: Jinwei Yao, Kaiqi Chen, Kexun Zhang, Jiaxuan You, Binhang Yuan, Zeke Wang, Tao Lin
  • 机构: Westlake University, Zhejiang University, Carnegie Mellon University, University of Illinois Urbana-Champaign, Hong Kong University of Science and Technology

A1 主要贡献

大型语言模型(LLMs)正越来越多地用于处理具有共享前缀的树状结构中的多个生成调用,例如少样本提示、多步推理和推测解码等。然而,现有的推理系统在处理这些树状应用时效率低下,主要因为在注意力计算过程中对查询(queries)和键值缓存(KV cache)的划分不当,导致了两个主要问题:(1)共享前缀的 KV 缓存缺乏内存访问(IO)复用;(2)负载均衡性差。这造成了 GPU 全局内存和共享内存之间冗余的 KV 缓存 IO,并降低了 GPU 的利用率。

为了应对这些挑战,本文提出了 DEFT (Decoding with Flash Tree-Attention),一种具有前缀感知和负载均衡 KV 缓存分区的硬件高效注意力算法。DEFT 的核心贡献如下:

  • 核心问题与目标: 当前基于序列解码优化的 LLM 推理系统在处理具有树状结构(如图1所示)的应用时,存在计算、存储和内存访问(IO)三个层面的冗余,尤其是在内存访问上,共享前缀的 KV 缓存被反复加载,成为性能瓶颈。此外,直接将序列优化的 KV 缓存切分策略应用于树状结构会导致严重的负载不均衡,降低 GPU 利用率。本文旨在设计一种能感知前缀共享并实现负载均衡的树状注意力算法,以加速树状结构的 LLM 推理。

  • 创新点与解决方案:

    1. KV-Guided Grouping(KV 引导分组): 为解决前缀 KV 缓存的冗余加载问题(挑战 C1),DEFT 提出了一种新的分组策略。与传统方法(Q-Guided Grouping)中每个查询(Query)与其所有对应的 KV 缓存分组不同,KV-Guided Grouping 将每个前缀节点的 KV 缓存与所有共享该前缀的查询进行分组。这确保了共享前缀的 KV 缓存仅被加载一次,从而显著减少了 IO,而重新加载查询所带来的额外 IO 开销可以忽略不计。
    2. Flattened Tree KV Splitting(扁平化树 KV 切分): 为解决树状 KV 缓存切分时的负载不均衡问题(挑战 C2),DEFT 提出了一种扁平化切分机制。该机制将整个树状结构的 KV 缓存逻辑上扁平化为一个序列,然后将其均匀地切分成多个块(chunks)。这种方法确保了每个分区中的 KV 缓存长度基本相等,从而实现了负载均衡,提高了 GPU 在注意力计算期间的利用率。同时,使用位因果掩码(bit causal masks)来高效地表示查询和 KV 缓存之间的因果关系。
    3. 硬件高效的注意力核: DEFT 在 OpenAI Triton 上实现了一个高效的注意力核,通过精确控制内存访问,并将所有注意力操作(包括分块计算、掩码应用和全局归约)融合成一个单一的 GPU 内核,避免了计算中间结果(如 $QK^T$ 和 Softmax)时产生的大量 IO。
  • 主要成果:

    • 理论上,DEFT 在 IO 复杂度上优于现有的注意力算法。
    • 实验证明,DEFT 能够减少 73-99% 的 KV 缓存 IO 和近 100% 的注意力计算中间结果 IO。
    • 在少样本提示、多步推理和推测解码这三种实际的树状工作负载上,与最先进的注意力算法相比,DEFT 实现了高达 2.23倍的解码延迟加速3.59倍的注意力计算延迟加速


图 1: 序列化解码与树状解码的图示。

表 1 展示了在一个推理任务中,基于序列的思维链(CoT)与基于树的思维树(ToT)在解码效率上的对比。ToT 生成的 token 数量远超 CoT,导致其端到端延迟和 IO(包括 KV 缓存 IO 和注意力计算中间结果 IO)开销巨大。

A3 背景知识与设计原则

3.1 预备知识

LLM 推理及其瓶颈。LLM 推理包括两个阶段:(1)预填充(prefill)和(2)解码(decoding)。在预填充阶段,模型处理输入的提示(prompt)以初始化。其输出成为解码阶段的输入。解码阶段是自回归的,前一步的输出 token 作为下一步的输入 token。由于自回归解码的顺序性,LLM 推理是内存密集型(memory-bound)的,每次前向传播都需要将所有模型参数和 KV 缓存从速度较慢但容量大的高带宽内存(HBM)传输到速度快但容量小得多的 GPU 共享内存中。另一个潜在瓶颈是 GPU 利用率低,当并行度(通常受限于批处理大小)远小于 GPU 上的流式多处理器(SMs)数量时,操作只会利用一小部分 GPU 资源。

GPU 上的注意力算法执行模式。我们可以将注意力算法的执行分为两个主要阶段:(1)QKV 准备阶段:将查询(Query)、键(Key)和值(Value)(QKV)逻辑上分组成分区,并将这些 QKV 组映射到 GPU 的不同流式多处理器(SMs)上;(2)注意力计算阶段:将 QKV 分区加载到不同 SM 的共享内存中,并对每个组应用注意力算法以得到最终的注意力结果。

分段注意力的 QKV 分区。在基于序列的解码中,当并行度(通常受限于批处理大小)远小于 GPU 上的流式多处理器(SMs)数量时,QKV 分区至关重要。为了实现高 GPU 利用率,Flash-Decoding【7,Flash-decoding for long-context inference,2023,https://pytorch.org/blog/flash-decoding/】通过对查询 和 KV 缓存进行分区,然后并行计算注意力。具体细节如下:(1)QKV 准备阶段:对于批处理中的每个查询,将其顺序的 KV 缓存分割成块(chunks)作为不同的 QKV 分区。(2)注意力计算阶段:它分别计算三个分段上的注意力 $A_0, A_1, A_2$,然后通过在线 Softmax 合并【6,Flashattention: Fast and memoryefficient exact attention with io-awareness,2022,Advances in Neural Information Processing Systems】【7,Flash-decoding for long-context inference,2023,https://pytorch.org/blog/flash-decoding/】基于不 同 QKV 分区的分段注意力结果得到最终的注意力。具体过程如下:

  • 假设我们有键张量 $K \in R^{l_{kv} \times d}$,值张量 $V \in R^{l_{kv} \times d}$,以及查询张量 $Q \in R^{l_q \times d}$。在一般情况下,K 和 V 沿着序列(行)维度被划分为三个部分以进行并行计算:$K = K_0 \parallel K_1 \parallel K_2$ 和 $V = V_0 \parallel V_1 \parallel V_2$,其中“$\parallel$”表示沿行轴的拼接。
  • 我们在 GPU 的不同流式多处理器(SMs)中计算 KV 块上的注意力 $A_0, A_1, A_2$,其中 $A_0 = \langle Q, K_0, V_0 \rangle, A_1 = \langle Q, K_1, V_1 \rangle, A_2 = \langle Q, K_2, V_2 \rangle$,并且 $\langle q, k, v \rangle = \text{Softmax}(\frac{qk^T}{\sqrt{d}})v$。
  • 我们计算 LogSumExp (LSE) 作为合并 $A_0, A_1, A_2$ 的权重。我们定义 $\text{LSE}(q, k) = \log \sum \exp(\frac{qk^T}{\sqrt{d}})$。
  • 我们有 $\langle Q, K, V \rangle = \text{SegAttn}(A_0, A_1, A_2)$,这意味着通过在线 Softmax【6,Flashattention: Fast and memoryefficient exact attention with io-awareness,2022,Advances in Neural Information Processing Systems】进行分段注意力计算。

3.2 DEFT 概览

QKV 分区的重要性。对于树状解码,逻辑上对 QKV 进行分区对于实现高并行度的注意力计算是必要的。当树状结构 KV 缓存中的 token 数量很大时,由于内存容量限制,树状生成请求的分支数量可能不足以充分利用 GPU。例如,一个对128个数字进行排序的推理任务,在 Llama2-7B 模型中涉及约40K个 token,其 KV 缓存占用20GB,这意味着一个80GB的 A100 最多只能处理4个具有这样 token 数量的请求。

DEFT 的动机。DEFT 旨在解决 LLM 推理在处理树状结构 KV 序列时遇到的两个潜在瓶颈(即 IO 和 GPU 利用率)。如图2左侧所示的一个简单的两级级联树:对于两个查询 $Q_a$ 和 $Q_b$,对应的键满足 $K_a = K_0 \parallel K_1$ 和 $K_b = K_0 \parallel K_2$,值也遵循相同的规则。DEFT 的设计目标是:(1)通过消除对共享前缀 KV 缓存($K_0$ 和 $V_0$)的冗余内存访问来最小化 IO;(2)确保工作负载均衡以实现高 GPU 利用率,使得计算每个分段注意力 $A_i$ 的开销几乎相同。因为公式1中的全局归约需要所有部分注意力,如果计算 $A_i$ 的开销显著大于 $A_j$,负责计算 $A_j$ 的 SM 将会长时间空闲。

DEFT 技术概述。DEFT 旨在通过减少内存访问和确保树状解码的负载均衡,成为一种硬件高效的注意力算法。详细信息见图2:
1. QKV 准备阶段:为了实现前缀感知和负载均衡的 QKV 分区,我们引入了 KV-Guided Grouping 策略来复用共享前缀的 KV 缓存 IO,以及 Flattened Tree KV Splitting 策略,通过均衡并行的注意力计算实现高 GPU 利用率。详见第3.3节。
2. 注意力计算阶段:我们设计了 DEFT 注意力核,以内存高效的方式加载由 QKV 准备阶段逻辑分组的 QKV 分片,然后执行注意力计算。关键技术如下,详细信息推迟到附录A.4:
1. 通用核融合与分块策略:避免了对中间结果(即 $QK^T$ 和 Softmax)的显著 IO 操作,这是 Tree Attention-Medusa【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】所缺乏的。
2. 树拓扑感知的全局归约:扩展了 Flash-Decoding【7,Flash-decoding for long-context inference,2023,https://pytorch.org/blog/flash-decoding/】中的全局归约机制。该方法通过聚合来 自 QKV 组的部分注意力结果,并考虑树状结构,高效地计算每个查询的最终注意力。

DEFT 的系统框架。除了高效的 DEFT 注意力核,我们的 DEFT 系统还有另外两个优点:1)高效管理树状结构中的 KV 缓存,2)通过任意用户定义的函数灵活控制树解码过程,决定何时以及如何分支/剪枝。系统中关键组件及其协调的详细信息请参考附录A.1。


图 2: DEFT 概览。输入元数据在附录 A.1 详细阐述的系统中准备。在 QKV 准备阶段(见 3.3 节),QKV 将被逻辑地分组到分区中,同时兼顾共享前缀 KV 缓存的 IO 感知和负载均衡。这些分区将指导注意力计算阶段(见附录 A.4)的 QKV 加载,在该阶段将执行注意力计算。

A2 方法细节

3.3 前缀感知和负载均衡的树状结构 KV 缓存分区

从一个简单的解码例子开始。我们从一个使用图2所示的树状结构 KV 缓存的解码例子开始。如果我们将整个树状结构的 KV 缓存和查询分组成一个 G0 而不进行任何分区,我们可以参考图3(a)部分所示的 Vanilla Tree Attention 方法。这种方法在单个流式多处理器(SM)中,借助密集因果掩码(DCM),同时计算所有查询的注意力。然而,由于 GPU 利用率低,这种方法效率低下,正如3.2节所讨论的。为了解决这种低效问题,有效的 QKV 分区是必不可셔的。这个过程涉及两个关键考虑因素:(1)前缀感知以最小化对 KV 缓存的内存访问,(2)负载均衡以确保工作负载在 GPU 间的均匀分布。


图 3: DEFT-Node/Node-Chunk/Flatten 与不同注意力算法基线在 QKV 准备阶段的 QKV 分区策略比较。请注意,分区是逻辑设计的,不会产生任何 QKV 的数据移动成本。每个组所需的 GPU HBM 和共享内存之间的 IO 量已在红色矩形中突出显示。部分 (a) 展示了一个两级级联解码树示例的数据流和三类 QKV 分区策略:无分区(Vanilla Tree Attention)、Q-Guided Grouping 和 KV-Guided Grouping。分区策略将指导后续注意力计算阶段的 QKV 加载,其中每个 QKV 组 Gi 将被加载到 GPU 上的 SMi 中。部分 (b) 显示了 Q-Guided Grouping 和 KV-Guided Grouping 的比较,后者可以感知前缀 KV 缓存 KV0 的 IO,并只加载一次。DEFT-Node-Chunk 是 DEFT-Node 的一个弱负载均衡改进,通过将大节点(如 KV0)拆分为块。部分 (c) 展示了 DEFT-Flatten 中用于负载均衡分区的扁平化树 KV 切分细节(在备注 3.1 中讨论),包括深度优先扁平化策略、均匀分块策略和位掩码。有关基线和 DEFT 的摘要,请参见表 2。有关树注意力基线(Cai 等人,2024;Miao 等人,2023)的分析,请参见备注 3.2。

Q-Guided 与 KV-Guided Grouping。大多数现有的内存高效注意力算法【6,Flashattention: Fast and memoryefficient exact attention with io-awareness,2022,Advances in Neural Information Processing Systems】【7,Flash-decoding for long-context inference,2023,https://pytorch.org/blog/flash-decoding/】【54 ,Efficiently programming large language models using sglang,2023,arXiv preprint】采用 Q-Guided Grouping 进行 QKV 分区,其中每个查询作为分区的指示符,与其对应的 KV 缓存分组。然而,这种方法不是前缀感知的,例如,在 Flash-Attention 中(如图3所示),KV0 被加载了两次,即一次为 Qa,一次为 Qb。我们转而采用另一种 KV-Guided Grouping 方法:通过将每个节点的 KV 缓存与所有共享它的查询分组,可以使分区具有前缀感知能力,从而减少对 KV 缓存的内存访问。例如,DEFT-Node(如图3所示)在注意力计算中只加载一次前缀 KV 缓存 KV0。查询的额外 IO 成本可以忽略不计,因为每个查询只包含一个 token,而 KV 缓存可能包含数千个 token。

表 2: 基线(大部分如图3所示)和 DEFT 的 QKV 分区策略比较。对于 IO 冗余,显著问题用红色突出显示,可忽略的问题用蓝色突出显示。“Q” 指查询,“KV” 指 KV 缓存。“DCM” 代表密集因果掩码(一个矩阵),“BCM” 指位因果掩码(一组64位整数)。“PA” 代表注意力计算过程中的部分结果,包括 QKT、Softmax 等。更多 ⋆ 符号表示 QKV 分区的工作负载更均衡。IO 复杂度的详细信息可在附录 A.5 中找到。

Tree KV Splitting 与负载均衡。得益于 KV-Guided Grouping,DEFT-Node 对 KV 缓存 IO 是前缀感知的。然而,它引入了一个潜在的瓶颈:不同 SM 之间的工作负载不均衡。例如,在图3的 DEFT-Node 中可以看到,KV0 可能包含1000个 token,而 KV1 只包含2个 token。如果 G0 和 G1 分别分配给 SM0 和 SM1,SM1 会早得多完成计算并保持空闲,导致 SM 利用率低下。

一种直接的负载均衡方法。为了解决这个问题,我们需要更均匀地平衡 QKV 分区。一种直接的方法是在物理层面将 K0、K1 和 K2 分块,同时在逻辑层面保持节点级分区,如 DEFT-Node-Chunk(图3所示)。然而,这种负载均衡策略较弱:它只将大节点(例如,约1k token 的提示)分解为更小的 KV 块,而不能处理包含许多小节点的情况(例如,推测性解码),这可能由于需要更多轮次的 GPU 执行来处理额外的 QKV 组而减慢推理速度。

提出 DEFT-Flatten。由于 KV 缓存加载是注意力计算的主要瓶颈【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】【35,Quest: Query-aware sparsity for efficient long-context llm inference,2024,arXiv preprint】,因此在每个 KV 缓存分区中实现均匀的 token 长度非常重要。因此,我们提出了 DEFT-Flatten,在备注 3.1 中详细阐述。

备注 3.1(扁平化树 KV 切分技术)。如图3(c)部分所示,有三个关键组成部分:

  • 深度优先扁平化策略。这种方法通过利用父子 KV 节点之间的层级关系,最小化了冗余的查询 IO 和计算。例如,父节点 KV0 的查询(如 Qa 和 Qb)包含了子节点 KV1 的查询(如 Qa)。与广度优先相比,深度优先扁平化最大化了来自不同节点但分配到同一块的 KV 缓存之间的查询重叠,减少了像 $QK^T$ 中被掩码部分的冗余计算。
  • 均匀分块策略。这是切分的核心,它确保每个 QKV 组中 KV 的长度相等,从而为 GPU 中的流式多处理器(SMs)提供均衡的工作负载。
  • 位掩码【28,Specinfer: Accelerating generative llm serving with speculative inference and token tree verification,2023,arXiv preprint】。它是一组64位整数,用于记录树中 token 的因果信息。因此,与密集因果掩码【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】相比,其 IO 开销(例如,图3(c)部分 KV-BCM1 中的两个64位整数)可以忽略不计。

备注 3.2(关于树注意力算法的讨论)。现有的注意力算法是为推测性解码设计的,其中注意力是为整个树状结构的查询计算的。然而,这些方法在内存效率上并不高。有关分区细节,请参见附录A.5中的图11。
* Tree Attention-Medusa【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】。该方法基于 Vanilla Tree Attention(如图3左侧所示),使用 PyTorch 的通用矩阵乘法(GEMM)对 Q 和 KV 张量进行分区。它的内存效率不高,原因有二:(1)它没有利用 Flash-Attention 来减少计算中间结果(如 Softmax)时的内存访问;(2)它引入了一个密集的因果掩码,其内存访问量很大。
* Tree Attention-SpecInfer【28,Specinfer: Accelerating generative llm serving with speculative inference and token tree verification,2023,arXiv preprint】。该算法采用基于 Vanilla Tree Attention 的 Q-Guided Grouping,并通过 Flash-Decoding 对 KV 序列进行分区。它在内存效率上不高,因为会为每个查询冗余地加载整个树状结构的 KV 缓存。

IO 复杂度分析。我们证明了 DEFT-Flatten 在 IO 复杂度上优于现有的注意力算法,包括 Flash-Decoding【7,Flash-decoding for long-context inference,2023,https://pytorch.org/blog/flash-decoding/】 、Tree Attention-Medusa【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】和 Tree Attention-SpecInfer【28,Specinfer: Accelerating generative llm serving with speculative inference and token tree verification,2023,arXiv preprint】。详见附录 A.5。

实现细节。我们使用 OpenAI Triton【36,Triton: an intermediate language and compiler for tiled neural network computations,2019,Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages】实现 DEFT 注意力核,这使我们能够以线程块的粒度控制从全局内存到共享内存的内存访问和注意力计算。DEFT-Node 和 DEFT-Flatten 算法的两个阶段的 Python 风格实现分别见附录 A.8 和附录 A.9。

A4 实验环境

  • 硬件配置: NVIDIA A100 (80GB)。
  • 模型架构: Llama3-8B【38,Llama 2: Open foundation and fine-tuned chat models,2023b,arXiv preprint】。
  • 软件配置: DEFT 注意力核通过 OpenAI Triton【36,Triton: an intermediate language and compiler for tiled neural network computations,2019,Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages】实现。
  • 基线系统:
    • Flash-Decoding: SOTA 的序列解码注意力算法,采用非分页内存管理,通过 Triton 实现。
    • Tree Attention-Medusa: SOTA 的树状解码注意力算法之一,采用非分页内存管理,通过 PyTorch 实现。
    • Radix Attention: SOTA 的树状解码注意力算法之一,采用分页内存管理,通过 Triton 实现。
    • 我们没有将 SpecInfer【28,Specinfer: Accelerating generative llm serving with speculative inference and token tree verification,2023,arXiv preprint】的树注意力算子作为基线,因为其内核最多只支持 token 树中的64个 token,不适用于具有树状结构 KV 的通用树状解码场景。

表 3: 基线与 DEFT 的比较。基线的注意力核实现与其内存管理方式相适应。因此,为了与基线进行公平比较,我们实现了同时适用于分页(Kwon et al., 2023)/非分页内存管理的 DEFT-Node 和 DEFT-Flatten。

  • 数据集/工作负载: 为了确保公平性,我们从真实的多步推理和推测解码任务中重建解码树作为工作负载,如表4所示。
    • 少样本提示: 使用 APPS 数据集【12,Measuring coding challenge competence with apps,2021,arXiv preprint】,将提示填充到4000个 token,执行400次迭代。
    • 多步推理: 使用了 Graph-of-Thoughts (GoT)【4,Graph of thoughts: Solving elaborate problems with large language models,2023,arXiv preprint】中的四项任务(128个数字排序、文档合并、关键词计数、集合交集),并从 ToT-BFS【47,Tree of thoughts: Deliberate problem solving with large language models,2023,arXiv preprint】交互记录中重建解码树。
    • 推测解码: 使用 APPS 数据集作为提示,解码树的拓扑结构来源于 Medusa【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】。

表 4: 工作负载生成。ToT-BFS 代表使用广度优先搜索的思维树(Yao et al., 2023)。APPS(Hendrycks et al., 2021)是一个竞争性编程问题数据集。Medusa(Cai et al., 2024)是一个推测性解码框架。“GoT” 代表思维图(Besta et al., 2023),其中包含使用 GPT-3.5 在 ToT-BFS 中进行复杂推理任务的迭代记录。更多细节见表 13。

A5 实验结果

内存管理与系统瓶颈分析

  • 实验内容: 分析不同内存管理策略(分页/非分页)对系统性能的影响,并确定不同配置下的性能瓶颈。
  • 实验结果与结论: 如图4所示,采用非分页(unpaged)内存管理时,系统的主要瓶颈是将树状 KV 缓存物化(materialize)为连续张量所需的数据移动开销,占总延迟的 69.1%-83.4%。而采用分页(paged)内存管理时,由于避免了大规模数据移动,注意力计算成为新的瓶颈,占总延迟的 51.1%-58.3%。这表明,在高效的内存管理下,优化注意力计算本身对于提升树状解码性能至关重要。


图 4: 推测性解码(32个查询的 token 树)的延迟分解,树拓扑来自 Medusa (Cai et al., 2024)。U 表示非分页内存。

解码延迟与 IO 对比

  • 实验内容: 在少样本提示、多步推理和推测解码三种任务上,比较 DEFT-Flatten 与基线方法的平均解码延迟和 IO。
  • 实验结果与结论 (表5):
    • 少样本提示: DEFT-Flatten 实现了 1.33倍的解码速度提升,这得益于 1.70倍的注意力计算加速和约 90% 的 IO 减少。
    • 推测解码: DEFT-Flatten 实现了高达 2.23倍的解码速度提升,注意力计算加速高达 3.59倍。这是因为所有查询(整个 token 树)可以共享长前缀的 IO。
    • 多步推理: DEFT-Flatten 的注意力速度提升高达 1.36倍,但解码加速不明显。原因有二:(1) 树宽度较窄(10)限制了 KV 缓存的复用;(2) 树中 token 数量较少,使得注意力计算仅占解码延迟的30%左右。
    • 总体而言,DEFT-Flatten 在所有测试场景中均优于包括 Radix Attention 在内的所有基线方法。

表 5: DEFT-Flatten 与基线在树状解码平均解码延迟(秒)上的比较。b 代表树的宽度,t 代表 token 树的大小(即树状结构查询的数量)。最快的方法用粗体表示,次快的方法用下划线表示。Radix Attention 是解码延迟方面的最佳基线。⋆ 表示 A100 80GB GPU 出现内存不足(OOM)错误。Speedup Upper-bound (no attention) 指的是如果我们排除注意力计算,只运行包括 MLP 在内的其他组件,Radix Attention 所能达到的最大加速比。更多关于注意力加速的细节,请参见表 16。

消融研究

  • KV 切分策略的影响 (表6):
    • 实验内容: 比较 DEFT-Node、DEFT-Node-Chunk 和 DEFT-Flatten 三种切分策略的注意力延迟。
    • 实验结果与结论: DEFT-Flatten 在所有树状结构设置中始终表现最佳。DEFT-Node-Chunk 通常优于 DEFT-Node,因为它能将大节点分块以实现更均衡的计算。然而,当存在大量小节点时(如推测解码 t=256),DEFT-Node-Chunk 会因产生过多的 QKV 组而需要更多轮 GPU 执行,导致性能下降。这证明了 DEFT-Flatten 的扁平化和均匀分块策略在实现负载均衡上的优越性。

表 6: [不同 KV 切分策略] DEFT-Node、DEFT-Node-Chunk 和 DEFT-Flatten 在 Llama3-8B 模型(GQA)上使用 NVIDIA A100(80GB)的平均注意力延迟(秒)比较。此表为表 16 的补充。最快的方法用粗体表示,次快的方法用下划线表示。Radix Attention 是解码延迟方面的最佳基线。更多基线的详细信息请参见表 16。

  • 提示长度的影响 (表7):
    • 实验内容: 在多步推理任务中,将提示长度从1k增加到10k,评估 DEFT-Flatten 相对于 Radix Attention 的性能变化。
    • 实验结果与结论: 随着提示长度增加,DEFT-Flatten 的加速效果更加显著(解码加速从1.09倍提升至1.67倍)。这是因为注意力计算的开销与解码树中的 token 总数成正比,而 FFN 的开销基本不变,因此更长的提示使得注意力成为更大的瓶颈,DEFT 的优化效果也更明显。

表 7: [不同提示长度] DEFT-Flatten 和 Radix Attention 在多步推理任务“排序”中的效率比较。原始提示长度约为 1K token,我们将其填充到 5K、8K 或 10K token 的长度。

  • 模型大小的影响 (表8):
    • 实验内容: 在 Codellama-7B 和 Codellama-34B 模型上比较 DEFT-Flatten 的性能。
    • 实验结果与结论: 在更大的 Codellama-34B 模型上,DEFT-Flatten 的解码加速比略有下降,但仍然显著(高达 1.78倍)。性能下降的原因是更大模型的隐藏维度更大,导致 FFN 的开销增加,从而降低了注意力/FFN 延迟比(A/F-LR)。

表 8: [不同模型大小] DEFT 和 Radix Attention 在 Codellama-34B 和 Codellama-7B 模型上的解码延迟加速比和注意力/FFN 延迟比(简称 A/F-LR)的比较。Radix Attention 是解码延迟方面的最佳基线。b 代表树的宽度,t 代表 token tree 的大小。对于多步推理,我们测试了提示长度约为 1k token 的排序任务。

A6 结论

本文提出了 DEFT-Flatten,一种为树状结构 LLM 推理优化的硬件高效注意力算法。它通过复用共享前缀的 KV 缓存和均匀分配工作负载,有效解决了内存访问和 GPU 利用率的瓶颈问题。DEFT-Flatten 的核心优势在于其前缀共享感知能力和负载均衡性,使其能够广泛适用于各种树状结构任务,并能很好地扩展到更大的搜索空间和更多的分支。实验结果表明,DEFT-Flatten 在解码和注意力延迟方面实现了高达 2.23倍/3.59倍的加速,在少样本提示、多步推理和推测解码等任务中均优于基线方法。我们的消融研究强调:(1)均衡的分区至关重要;(2)DEFT-Flatten 在各种 LLM 模型和 GPU 架构上都能提供显著的加速;(3)随着树状请求中 token 数量的增加(例如更长的提示)和分支数量的增多,DEFT-Flatten 的加速效果会更大。

A7 附录

A.1 DEFT 系统支持组件

图5左侧展示了高效灵活的树状解码中不同组件的协调工作。DEFT 系统组件的功能详情如下:

  1. 分支控制器 (Branch Controller): 它通过用户定义的函数强制执行树解码过程(例如,图5右侧示例中,每3次迭代分支为两个子节点)。基于树搜索的算法可在此处应用,利用解码树的拓扑信息。
  2. 序列树管理器 (Sequence Tree Manager): 它根据分支控制器的树操作和 token 维护解码树的拓扑。剪枝和分支等树操作将由该组件中的树处理器 (Tree Handler) 执行。分支结果存储 (Branch Result Storage) 将记录解码树中所有分支的 token 生成结果,并在解码停止时输出。
  3. KV 缓存管理器 (KV cache Manager): 它将维护一个树状结构的 KV 缓存。解码树中的序列 ID 与 KV 缓存索引之间的映射关系会被保留,并根据序列树管理器的 KV 操作进行更新。我们在此部分同时提供了分页式【21,Efficient memory management for large language model serving with pagedattention,2023,arXiv preprint】和非分页式内存管理,以适应不同的注意力核。
  4. 模型接口 (Model Interface): 将输入元数据传递给 DEFT 注意力核和 MLP 模块,然后返回 logits 和更新后 KV 缓存的内存指针。


图 5: DEFT 图解。(左) 系统概览。(右) 使用一个解码树示例展示 DEFT-Node 的数据流 (DEFT-Flatten 类似,仅 QKV 分区不同)。

图5右侧通过一个解码树示例进一步展示了系统的关键数据流。为简化起见,我们在此展示了 DEFT-Node,DEFT-Flatten 的流程类似,只是 QKV 分区方式不同。输入元数据将由我们上面提到的三个组件提取,然后在3.3节讨论的 QKV 准备阶段之后,以分组方式从 HBM 加载到共享内存。接着,QKV 组将由 DEFT 注意力核在 DEFT 的注意力计算阶段进行处理。关于这两个阶段技术的详细信息,请参阅附录A.4。

A.2 关于树状解码的讨论

(a) (左) 序列 KV 与树状查询用于并行解码【28,Specinfer: Accelerating generative llm serving with speculative inference and token tree verification,2023,arXiv preprint】【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】,其中应用因果掩码记录树状 token 查询间的因果信息。(右) 树状 KV 与并行查询用于多步推理中的共享前缀。


图 6: 关于使用树状查询【28,Specinfer: Accelerating generative llm serving with speculative inference and token tree verification,2023,arXiv preprint】和树状 KV 的树状解码的讨论。

(b) SpecInfer【28,Specinfer: Accelerating generative llm serving with speculative inference and token tree verification,2023,arXiv preprint】中的位掩码,用于记录树状结构中查询 token 间的因果信息。解码树位于 6a 的左侧部分。

树状解码的两种模式。树状解码可以采用树状结构的 KV 缓存来存储并感知共享前缀【54,Efficiently programming large language models using sglang,2023,arXiv preprint】,或者在并行/推测解码中采用树状结构的查询【28,Specinfer: Accelerating generative llm serving with speculative inference and token tree verification,2023,arXiv preprint】【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】,如图6所示。一个通用的解码过程可以同时处理树状 KV 和树状查询,这既可以减少共享前缀的冗余(例如 IO、存储、计算等),也可以增加每次解码迭代生成的 token 数量。

现有框架的局限性。现有的针对树状解码效率优化的推理框架【54,Efficiently programming large language models using sglang,2023,arXiv preprint】【9,Prompt cache: Modular attention reuse for low-latency inference,2023,arXiv preprint】主要目标是:(1)减少内存占用【54,Efficiently programming large language models using sglang,2023,arXiv preprint】以支持更大的批处理大小,从而提高吞吐量;(2)复用提示缓存【9,Prompt cache: Modular attention reuse for low-latency inference,2023,arXiv preprint】以避免 KV 缓存的重新计算,从而加快首个 token 的生成时间(TTFT)。然而,它们的设计并未专门针对减少整个解码过程的延迟。我们观察到,LLM 推理的树状结构特性可以为我们提供一些加速解码本身的优势。

树状解码的加速潜力分析。在树状解码中,KV 缓存和查询可以组织成树状结构。我们不仅可以以树状结构存储 KV 缓存,还可以在注意力计算期间以感知树拓扑的方式加载 QKV,以最小化 HBM 和 GPU 片上共享内存之间昂贵的 IO。我们通过两个复杂交互场景的案例研究来解释这一点:(1)多步推理;(2)推测解码。

案例研究1:多步推理。如图7左侧所示,多步推理【11,Reasoning with language model is planning with world model,2023,arXiv preprint】【47,Tree of thoughts: Deliberate problem solving with large language models,22,2023,arXiv preprint】【4,Graph of thoughts: Solving elaborate problems with large language models,2023,arXiv preprint】过程可概括为三个阶段:(1)思想生成:基于生成提示 $P_g$ 和先前步骤 $S$ 生成下一步的 k 个候选思想;(2)思想评估:LLM 作为状态评估器,基于评估提示 $P_e$ 评估先前的思想 $S$ 对解决问题的贡献,作为搜索算法的启发式信息;(3)基于树搜索的扩展:采用不同的搜索算法探索搜索空间。在阶段(1)和(2)中,我们都可以在树注意力计算期间共享 $P_g/P_e$ 和 $S$ 的 KV 缓存 IO。


图 7: 树状解码两个案例研究的分析。(左) 多步推理。(右) 推测解码。蓝色框表示在树注意力计算期间可在存储和内存访问中共享的过去 KV 缓存,而黄色框表示生成上下文的 KV 缓存。

案例研究2:推测解码。如图7右侧所示,推测解码【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】【28,Specinfer: Accelerating generative llm serving with speculative inference and token tree verification,2023,arXiv preprint】过程可概括为三个阶段:(1)Token 树生成:多个小型草稿模型或微调的头基于提示 $P$ 生成多个 token 序列,然后合并成一个推测的 token 树 $T_t$;(2)Token 验证:基于这些树状结构的候选 token $T_t$,对照 LLM 的输出验证其正确性,其中树注意力计算是该过程的瓶颈;(3)保留 KV 缓存。在阶段(2)中,我们可以在树注意力计算期间共享 $P$ 和 $S$ 的 KV 缓存 IO。

为什么现有的树注意力算法不够用? 现有的树注意力算法要么在内存访问上效率低下【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】【28,Specinfer: Accelerating generative llm serving with speculative inference and token tree verification,2023,arXiv preprint】,要么不适用于 token 树中超过64个 token 的通用树状解码【28,Specinfer: Accelerating generative llm serving with speculative inference and token tree verification,2023,arXiv preprint】。
* SpecInfer【28,Specinfer: Accelerating generative llm serving with speculative inference and token tree verification,2023,arXiv preprint】中,如图6b所示,使用位掩码记录 token 树查询间的因果信息。每个查询 token $t_i$ 都有一个64位整数作为位掩码,其中第j位表示 $t_i$ 的查询与 $t_j$ 的 KV 缓存之间的因果关系。这种设计的优点是大大减少了 IO,但导致树中 token 的最大数量仅为64,这对于具有树状结构 KV 缓存的场景不实用。此外,它对 KV 缓存的 IO 不敏感,因为它会为每个查询加载整个树的 KV 缓存。
* Medusa【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】适用于通用的树状解码,但由于密集的因果掩码和注意力计算过程中的部分结果(如 softmax)产生的大量 IO,其硬件效率不高。

A.3 关于同期工作的讨论

与同期工作的共同点与不同点。有一些同期工作【3,Bifurcated attention for single-context large-batch sampling,2024,arXiv preprint】【48,Chunkattention: Efficient self-attention with prefix-aware kv cache and two-phase partition,2024a,arXiv preprint】【19,Hydragen: High-throughput llm inference with shared prefixes,2024,arXiv preprint】【55,Relayattention for efficient large language model serving with long system prompts,2024,Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)】关注于单上下文大批量采样场景的注意力算法设计,其目标是从单个上下文(如系统提示或少样本示例)生成多个序列,这是树状解码深度为1的特例。它们的算法设计基于这一特性,这意味着它们无法高效地适用于具有超过两级前缀的树的注意力计算。

共同的见解和技术。同期工作和 DEFT 都认识到内存访问是 LLM 推理的瓶颈,并通过分解注意力来减少前缀 KV 的内存访问:(1)分别计算前缀和后缀上的注意力 $A_p$ 和 $A_s$;(2)基于 $A_p$ 和 $A_s$ 通过在线 softmax 合并【6,Flashattention: Fast and memoryefficient exact attention with io-awareness,2022,Advances in Neural Information Processing Systems】【7,Flash-decoding for long-context inference,2023,https://pytorch.org/blog/flash-decoding/】得到最终的注意力。其正确性证明如下 :
- 设键张量 $K \in R^{(l_{kv},d)}$,值张量 $V \in R^{(l_{kv},d)}$,查询张量 $Q \in R^{(l_q,d)}$。考虑一般情况,K 和 V 沿序列(行)维度被划分为前缀和后缀两部分:$K = K_p \parallel K_s$ 和 $V = V_p \parallel V_s$。
- 我们分别计算前缀和后缀上的注意力 $A_p = \langle Q, K_p, V_p \rangle$ 和 $A_s = \langle Q, K_s, V_s \rangle$。
- 根据公式1,我们可以得到分段注意力 $\langle Q, K, V \rangle = \text{SegAttn}(A_p, A_s)$。

表 9: DEFT 与同期工作在单上下文大批量采样场景中的比较,包括 Chunk-Attention (Ye et al., 2024a)、Hygragen (Juravsky et al., 2024) 和 Bifurcated-Attention (Athiwaratkun et al., 2024)。RelayAttention (Zhu et al., 2024) 和 Cascade-inference (Ye et al., 2024b) 类似于 Hygragen。更多的 ⋆ 表示树切分后工作负载更均衡,也显示了加速对树拓扑的不敏感程度。

差异比较。现有的单上下文大批量采样工作对于通用的树状解码在硬件上不够高效,原因有二,如表9所示:
- 仅为两级树设计。它们专为只有两层的解码树设计——根部的所有前缀和深度为1的所有后缀。对于具有多级前缀的解码树,它们的算法只能减少树根部提示的 IO。然而,在多步推理等场景中,非根前缀的 token 长度也可能很长(例如,数千个 token),而它们的 KV 缓存 IO 并未被复用。DEFT 可以复用通用解码树中所有非叶前缀的 KV IO,提供了更大的加速潜力。
- 未解决负载不均衡问题。它们没有解决树状解码中的工作负载不均衡问题。解码树中的节点大小差异可能很大,因此以一种能确保每个 QKV 组计算均衡的方式切分树和分组 QKV 至关重要。仅仅根据深度进行划分是不够的。

A.4 高效注意力算法设计中的技术讨论

本小节总结并讨论了现有高效注意力算法和核设计中的常用技术,并解释了 DEFT 注意力核的设计细节。

表 10: DEFT 的技术列表。我们提出的技术用红色标出。前四项技术的细节在第 3.3 节,后续技术的细节在本章讨论。

核融合与分块策略核融合 (Kernel Fusion) 是一种常见的 IO 减少技术:如果多个操作在相同的输入上执行,那么从 HBM 一次性加载输入比为每个操作多次加载更高效。为了将所有注意力操作融合到一个 GPU 核中,我们进一步利用了常用的 分块策略 (Tiling strategy)【6,Flashattention: Fast and memoryefficient exact attention with io-awareness,2022,Advances in Neural Information Processing Systems】【7,Flash-decoding for long-context inference,2023,https://pytorch.org/blog/flash-decoding/】:将每 个 QKV 组内的查询和 KV 缓存分割成小块,通过在有限的共享内存中计算注意力,避免了在 HBM 中物化注意力矩阵,然后根据公式1增量地执行 softmax 归约来重构注意力。
备注 A.1: Tree Attention in Medusa【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】没有采用分块和核融合,导致了中间结果(如 $QK^T$ 和 Softmax)的大量 IO,如图8所示。而 Flash Decoding、Tree Attention in SpecInfer 和我们的 DEFT 都采用了分块和融合核。


图 8: Tree Attention-Medusa (Cai et al., 2024) 的操作。未应用核融合或分块策略,这导致了像 QK⊤、DCM 和 Softmax 等部分结果在 GPU 全局内存和片上共享内存之间的显著 IO。

树拓扑感知的因果掩码因果掩码 (Causal Mask) 在推测解码工作【28,Specinfer: Accelerating generative llm serving with speculative inference and token tree verification,2023,arXiv preprint】【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】中被引入,通过记录解码树中查询和 KV 缓存之间的因果关系,实现在单个 GPU 核内计算整个解码树的注意力。
备注 A.2: 因果掩码带来了两部分冗余:
- 内存访问:Medusa【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】在 HBM 中物化了密集的因果掩码(DCM),带来了巨大的 IO 成本。SpecInfer【28,Specinfer: Accelerating generative llm serving with speculative inference and token tree verification,2023,arXiv preprint】引入了位因果掩码(BCM),IO 成本极小,但仅限于64个 token。DEFT-Flatten 采用了受 SpecInfer 启发的位因果掩码,以最小化掩码的 IO。
- 计算:除了生成掩码本身的计算成本外,许多 $QK^T$ 的矩阵乘法结果被掩码掉而从未使用,造成了计算冗余。

切分与全局归约切分 (Splitting) 被引入以提高 GPU 在序列解码中的利用率【16,Flashdecoding++: Faster large language model inference on gpus,2023,arXiv preprint】。Flash-Decoding 首先基于 Q 对长 KV 进行切分和分组,然后将这些组分配给不同的 SM 计算部分注意力。为了获得准确的最终注意力,需要对来自具有相同查询的 QKV 组的部分注意力进行全局归约 (Global Reduction)。类似地,DEFT 也将解码树切分成不同的 QKV 组(即我们提出的扁平化树 KV 切分策略),以实现 SM 的负载均衡。DEFT 也需要一个全局归约,但 Flash-Decoding 的归约是为序列解码设计的,不能感知树拓扑。因此,我们提出了树拓扑感知的全局归约,如图10b所示。

DEFT 注意力核设计。基于上述技术,我们设计了具有两个阶段的 DEFT 注意力核,如图9所示。
- 阶段1:计算部分注意力。基于 QKV 准备阶段的分组结果,每个 QKV 组 ($G_i$) 被分配到一个线程块,使用 Flash Attention【6,Flashattention: Fast and memoryefficient exact attention with io-awareness,2022,Advances in Neural Information Processing Systems】计算部分注意力 ($PA_i$) 和 LogSumExp ($LSE_i$)。
- 阶段2:全局归约。DEFT 执行树拓扑感知的全局归约,根据树拓扑逻辑上重映射部分注意力和 LogSumExp,从而为每个查询获得正确的最终注意力。


图 9: DEFT 注意力核两个阶段的概览(以 DEFT-Node 为例,DEFT-Flatten 类似)。阶段 1 – 计算部分注意力。基于上述 KV-Guided Grouping 策略和 Tree Split 后的 QKV 分组结果,每个 QKV 组 (Gi) 将被分配到一个线程块,用于使用通用的核融合和分块策略进行 Flash Attention (Dao et al., 2022) 计算。与 Flash-Decoding (Dao et al., 2023) 类似,我们不仅得到部分注意力 (PAi),还返回 “LogSumExp” (LSEi) 作为下一阶段归约的权重参数。阶段 2 – 全局归约。在收到每个 QKV 组 Gi 的 PAi 和 LSEi 后,DEFT 现在执行树拓扑感知的全局归约 (DeFT_reduction)。在解码树中 KV 序列节点间的树拓扑指导下,DEFT 逻辑上重映射注意力和 LogSumExp 的部分结果,以在归约后为每个查询获得正确的最终注意力。解码树与图 3 左侧的相同。SMi 表示 GPU 中的流式多处理器 i。

(a) 左:DEFT 注意力核两阶段图示。右:图 10b 中 DEFT 阶段 2 调用的全局归约核。QKV 组 G0、G1 和 G2 来自图 3 中的 DEFT QKV 组。


图 10: DEFT 核的详细注意力操作(以 DEFT-Node 为例,DEFT-Flatten 类似)。基于图 3 中的相同解码树。

(b) DEFT 阶段 2:全局归约(以 DEFT-Node 为例)。基于图 3 中的树拓扑,我们可以根据查询对 LogSumExp 和部分注意力进行分组,然后调用图 10a 右侧的全局归约核来获得最终的注意力。

A.5 DEFT 的 IO 复杂度分析

分析设定。本节分析 DEFT 的 IO 复杂度,显示其相比现有注意力算法在 HBM 访问上显著减少。我们基于单次迭代中的解码树快照来比较 IO,使用的符号如表11所示。

表 11: 符号表。

IO 复杂度对比 (表12)
- 序列解码方法 (Naive Attention, Flash-Decoding): 由于缺乏树拓扑感知,其 KV 缓存的内存访问开销比 DEFT 和 Tree Attention-Medusa 高出 $F_s$ 倍。
- Tree Attention-Medusa【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】: 尽管 KV 缓存 IO 较少,但由于缺乏分块和核融合,其在部分结果(如 $QK^T$ 和 Softmax)上的 IO 开销很高。此外,其引入的密集掩码也带来了显著的 IO 成本。
- Tree Attention-SpecInfer【28,Specinfer: Accelerating generative llm serving with speculative inference and token tree verification,2023,arXiv preprint】: 如图11所示,该方法采用 Q-Guided Grouping,导致每个查询都需要加载整个树的 KV 缓存,KV 缓存的 IO 开销巨大。
- DEFT-Flatten: 通过 KV-Guided Grouping 和核融合,DEFT-Flatten 避免了 KV 缓存和部分结果的冗余 IO。其采用的位掩码 IO 成本极低,使其在所有方法中具有最低的 IO 复杂度。


图 11: (图 3 补充) Tree Attention (Cai et al., 2024; Miao et al., 2023) 的 QKV 分区和内存访问。Tree Attention-Medusa (Cai et al., 2024) 通过 PyTorch 中的通用矩阵乘法 (GEMM) 对 QKV 进行分区。Tree Attention-SpecInfer (Miao et al., 2023) 采用 Q-Guided Grouping。QBCM 是 SpecInfer 的 Q-Guided 位因果掩码,其中每个位表示查询与 KV 缓存中一段 token 之间的因果关系。例如,Qa 的 Q-BCM 是 "110",意味着 KV 缓存的前两段 KV0 和 KV1 对 Qa 的注意力有效。图中的 Qi 和 Kj 与图 3 中的相同。

表 12: 各种方法的 IO 复杂度分解。O(1) 表示张量中单个数据在所有层和头上的 IO 成本,相当于 #heads * #layer * dtype_size。表中所有方法中的最佳者用红色标出,(潜在的)最差者用蓝色标出。查询 IO 被省略,因为它对所有方法都是 O(k * ln * d_head),其中 k 是 QKV 组的数量。对于 DEFT-Node,k = #node;对于 DEFT-Node-Chunk,k = Σ(#node_i=1) ceil(ni/bs);对于 DEFT-Flatten,k = N_tree / bs。Tree Attention-S 中的 S 代表 SpecInfer (Miao et al., 2023)。

A.6 工作负载生成讨论

工作负载设置的合理性。为了验证 DEFT 在不同解码树拓扑下的加速效果,我们从真实任务中编译了解码树,涵盖以下三个方面:
- 少样本提示:这是一个两级树,包含一个提示前缀和多个用于后缀生成的分支。作为案例研究,我们将提示长度固定在约4000个 token,并改变分支数量。
- 多步推理【47,Tree of thoughts: Deliberate problem solving with large language models,22,2023,arXiv preprint】【11,Reasoning with language model is planning with world model,2023,arXiv preprint】【4,Graph of thoughts: Solving elaborate problems with large language models,2023,arXiv preprint】: 我们记录了真实推理任务交互【4,Graph of thoughts: Solving elaborate problems with large language models,2023,arXiv preprint】中的树形、提示和所有思想的长度,并以此为指导进行树解码,以验证 DEFT 在推理的思想生成阶段的加速效果。生成细节见图12。
- 推测解码【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】【28,Specinfer: Accelerating generative llm serving with speculative inference and token tree verification,2023,arXiv preprint】: 我们使用了 Medusa【5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,arXiv preprint】的 token 树拓扑,并记录了与 APPS【12,Measuring coding challenge competence with apps,2021,arXiv preprint】的真实交互数据,包括每一步接受的 token 长度。这作为指导来模拟推测解码的瓶颈——token 验证阶段的注意力计算。


图 12: 重建多步推理树模板的详细过程。(左) 从实际推理记录中重建推理树,如 (Besta et al., 2023) 所述,涉及捕获以下方面:(1) 树的结构,以其深度 d 和宽度 w 为特征;(2) 与每个思想相关的 token 长度;(3) 每个深度的最佳思想及其对应的分数。对于文档合并任务,树的深度设置为 d=3,每个深度的宽度为 w=10。对于排序 128 个数字,深度减少到 d=10,同时保持相同的宽度 w=10。有关其他多步推理任务的树拓扑详细信息,请参见表 13。(右) 利用从左侧提取的思想信息,我们可以生成用于解码的树模板,包括分支记录和剪枝记录。这些记录有助于指导树解码过程,以生成忠实复制思想树结构的解码树。

实验范式的合理性。我们的实验范式包括:首先,从真实的树状解码任务中获取解码树;其次,通过强制 LLM 推理在同一框架内精确复制这些解码树,以研究注意力加速对实际运行时间性能的影响。这种范式有两个优点:
- 我们可以在一个统一的系统中使用来自真实任务的解码树作为基准,从而能够在解码延迟方面对不同的注意力算法进行公平比较。
- 我们既考虑了具有不同树结构的任务的独特性,也考虑了通用树状解码的更广泛适用性。

表 13: 生成的工作负载详情。对于多步推理,我们包括了来自 Besta et al. (2023) 的这 4 个任务:(1) 排序 128 个数字(简称 sorting);(2) 文档合并(简称 document);(3) 关键词计数(简称 keyword);(4) 集合交集(简称 set)。d 和 w 分别表示树的深度和宽度。t 表示推测性解码的 token 树大小,其树拓扑来自 Medusa (Cai et al., 2024)。

A.7 附加结果

DEFT-Node 的 GPU 利用率微基准测试。如表14所示,我们测试了 DEFT-Node 和 DEFT-Flatten 在一个有64个查询和4k token 提示的推测解码任务上的表现。对于 DEFT-Node,QKV 分区是不均衡的。结果显示,DEFT-Flatten 在内存利用率(内存吞吐率)和计算利用率(计算吞吐率和低利用率时间比)方面都优于 DEFT-Node。

表 14: [GPU 利用率微基准测试] 在 NVIDIA A100 (80GB) 上使用 LLama3-8B 模型 (GQA) 时,DEFT 单层注意力的延迟(μs)、SM 计算吞吐率、内存吞吐率和低利用率时间比。工作负载为具有 64 个查询和 4k token 提示的推测性解码。计算吞吐率指的是 GPU 中流式多处理器(SM)的利用率。内存吞吐率表示实际内存吞吐量与最大带宽的比率。低利用率时间比定义为计算吞吐率低于 5% 的时间比例。

注意力延迟和 IO 分解。注意力延迟和 IO 的详细比较分别见表16和表17。

推理准确性。如表15所示,DEFT 的注意力分数与 Huggingface Transformers 中的原始注意力相比可能略有差异(约0.4%的相对误差),但生成的 token 和 PPL(困惑度)几乎没有差别。这种差异是由于 GPU 上的浮点运算不遵循结合律造成的,这在其他引入在线 Softmax 和归约的方法中也很常见。

表 15: DEFT 在注意力分数和困惑度(PPL)方面的推理准确性。PPL 是在 400 次解码迭代后计算的。Vanilla Attention 是 Huggingface Transformers 的实现。

动态行为:逐迭代延迟。图13展示了在多步推理任务“排序”中,DEFT-Node 和 DEFT-Flatten 的逐迭代延迟。我们观察到,DEFT-Node 相对于 DEFT-Flatten 的加速比与树节点大小的离散程度呈强正相关。这是因为 DEFT-Flatten 的性能相对稳定,而 DEFT-Node 的性能受树拓扑影响更大。


图 13: 在排序任务中比较切分策略 DEFT-Node 和 DEFT-Flatten。加速比指的是 DEFT-Node 和 DEFT-Flatten 的每次迭代延迟之比。树节点长度标准差表示每次迭代的树节点长度的标准差。

消融研究:解码树宽度的影响。如图14所示,我们固定提示长度为4000,改变解码树的宽度。随着树宽度的增加,DEFT-Flatten 的解码加速效果更显著(从宽度10的几乎无加速到宽度50的1.33倍加速),因为更宽的树意味着提示前缀的 KV 缓存 IO 被更频繁地复用。


图 14: 不同树宽度的少样本提示任务的每次迭代延迟。e2e 表示解码延迟(最优端到端延迟),而 Attn 表示仅注意力开销。

消融研究:KV 切分中块大小的影响。如图15所示,块大小的选择是在 IO 冗余和线程块调度之间的权衡。更大的块大小意味着查询 IO 冗余更少,但可能因线程块较少而导致 GPU SM 空闲。结论是:(1)最佳块大小受序列长度和查询数量的影响;(2)DEFT-Flatten 在所有测试的块大小上都优于 DEFT-Node-Chunk。


图 15: DEFT 的 KV 块大小消融研究。t 是推测性解码中的 token 树大小。

消融研究:提示长度、不同GPU、模型架构的影响。图16-18显示,随着提示长度增加,DEFT 的优势更明显。表19显示 DEFT 在 RTX 4090 上也有明显加速。表20和表21显示 DEFT 对 MHA 和 GQA 架构的模型都能显著加速。

表 16: 树状解码的平均注意力延迟(秒)及其对解码延迟的影响。b 代表树宽,t 代表 token 树大小。Attention Speedup over the best attention 指的是 DEFT-Flatten 相对于最佳基线(通常是 Tree Attention-Medusa)在注意力计算上的加速比。Radix Attention 是解码延迟的最佳基线。注意,KV 缓存管理不包含在注意力延迟内。⋆ 表示 A100 80GB GPU 内存不足。更多解码延迟细节见表 5。

表 17: 解码期间的平均端到端 IO (TB)。数据格式为左/右:(左) KV 缓存 IO;(右) 部分结果 IO,包括 QKT、QK⊤/sc、Mask M、M + QK⊤/sc 和 Softmax。b 表示树宽。t 表示 token 树大小。⋆ 表示 A100 80GB 内存不足。


图 16: 推测性解码中不同提示长度下 DEFT 的每输出 token 时间 (TPOT)。


图 17: 推测性解码中不同提示长度下 DEFT 的解码延迟。


图 18: 推测性解码中不同提示长度下 DEFT 的注意力延迟。

表 18: [模型大小和提示长度的消融研究] 在排序推理任务中,不同提示长度下,DEFT 和 Radix Attention 在 Codellama-34B 和 Codellama-7B 上的解码加速比和注意力/FFN 延迟比 (A/F-LR) 的比较。Radix Attention 是解码延迟方面的最佳基线。

表 19: [不同 GPU] 在 NVIDIA RTX 4090 (24GB) 上使用 LLama3-8B 模型 (GQA) 时,DEFT 的平均注意力延迟(秒)加速。Radix Attention 是解码延迟方面的最佳基线。

表 20: [不同模型架构 (GQA)] 在 NVIDIA A100(80GB) 上使用 Codellama-34B 模型 (GQA) 时,DEFT 的平均注意力延迟(秒)加速。Radix Attention 是解码延迟方面的最佳基线。

表 21: [不同模型架构(MHA)] 在 NVIDIA A100(80GB) 上使用 Codellama-7B 模型(MHA) 的 DEFT 平均注意力延迟(秒)加速。Radix Attention 是解码延迟方面的最佳基线。

A.8 DEFT-Node 算法

算法1 DEFT-Node 算法-阶段1:QKV 准备

<blockquote>

输入: 查询 $Q \in R^{(b_q, d)}$,键缓存列表 $KL = (K_0, ..., K_{N-1})$,值缓存列表 $VL = (V_0, ..., V_{N-1})$(对应树中每个序列节点),树 $T$ 及其拓扑信息。
处理:
1. 对于 $Q$ 中的每个查询 $q$,获取其所有前缀的 KV 索引,存入 QMapKV
2. 对于每个序列的 KV 缓存 $K_i, V_i$,将其与所有共享它的查询进行分组,得到 $Q_i$,存入 KVMapQ
返回: QMapKV, KVMapQ

</blockquote>

算法2 DEFT-Node 算法-阶段2:注意力计算

<blockquote>

输入: 查询 $Q$, 键缓存列表 $KL$, 值缓存列表 $VL$, 树 $T$, 以及从阶段1得到的 QKV 分组信息 QMapKV, KVMapQ
处理:
1. 阶段1:计算部分注意力
- 为每个 QKV 组 $(Q_i, K_i, V_i)$ 并行调用 FlashAttention,得到部分注意力 $o_i$ 和 LogSumExp $lse_i$。
- 将每个 $o_i, lse_i$ 根据其对应的原始查询索引映射回全局存储。
2. 阶段2:全局归约
- 对于每个查询 $q$,在收集完其所有相关的部分注意力和 LogSumExp 后,执行全局归约(类似在线 Softmax 合并),计算出最终的注意力结果 $FO[idx]$。
返回: 最终的注意力输出 $FO$

</blockquote>

A.9 DEFT-Flatten 算法

DEFT-Node 采用节点粒度的切分策略,简单但可能导致负载不均衡。因此,我们提出 DEFT-Flatten,以更均衡的子树粒度进行切分。

DEFT-Flatten 算法-阶段1:QKV 准备

<blockquote>

输入: 与 DEFT-Node 相同,额外增加子树大小 $S_t$。
处理:
1. 均匀切分树 KV 缓存: 将整个树的 KV 缓存按深度优先顺序扁平化,然后均匀地切分为多个子树,每个子树最多包含 $S_t$ 个 token。得到子树信息 SubInfo,子树键缓存 KSub,子树值缓存 VSub
2. 分组与掩码生成:
- 对于每个子树的 KV 缓存,将其与所有共享它的查询分组,得到 $Q_i$。
- 为每个子树生成一个位因果掩码 CausalMask,用于记录子树内不同节点与查询之间的因果关系。
返回: QMapKV, KVMapQ, CausalMask, SubInfo

</blockquote>

DEFT-Flatten 算法-阶段2:注意力计算

<blockquote>

该阶段与 DEFT-Node 类似,但有关键区别:
1. 阶段1:计算部分注意力
- 在调用 FlashAttention 之前,需要根据 CausalMaskSubInfo 重构出适用于当前子树的注意力掩码
- 其余部分与 DEFT-Node 相同,计算每个子树 QKV 组的部分注意力 $o_i$ 和 LogSumExp $lse_i$。
2. 阶段2:全局归约
- 与 DEFT-Node 完全相同,对每个查询收集到的所有部分结果执行全局归约。
返回: 最终的注意力输出 $FO$

</blockquote>