Linear Attention

韩广云,NVIDIA GPU 加速计算专家团队 高级工程师 | AI Open Day/2025-11-07

目录

议程 (Agenda)

线性注意力 (Linear Attention)

当前状态 (Current Status)

基础 (Basics)

并行性 - 回归 Flash Attention (Parallelism - Back to Flash Attention)

Page 6, Flash Attention 调度示意图
Page 6, Flash Attention 调度示意图

并行性 - 完全融合的线性注意力 (Parallelism - Fully Fused Linear Attention)

Page 7, 完全融合线性注意力调度示意图
Page 7, 完全融合线性注意力调度示意图

算法分析 (Algorithm Analysis)

Page 8, 算法分析的公式和表格
Page 8, 算法分析的公式和表格

Hopper GPU 上的完全融合线性注意力 (Fully Fused Linear Attention On Hopper GPUs)

Page 9, Hopper GPU 上的流水线示意图
Page 9, Hopper GPU 上的流水线示意图

上图展示了一个理想的指令顺序,旨在实现高效的流水线操作,包括数据加载(Load Q, K, V)、数学计算(Math WG1, WG2)和结果存储(Store O)。图中展示了 Acquire/Release 各种 pipe(Q, K, V, O pipe)以协调数据流,最终计算出完整的 tile。

Hopper GPU 上的完全融合线性注意力 - 实现细节

Page 10, Hopper GPU 实现细节图
Page 10, Hopper GPU 实现细节图

Hopper GPU 上的完全融合线性注意力 - 映射到硬件的细节

Page 11, WGMMA 配置和硬件映射细节
Page 11, WGMMA 配置和硬件映射细节

基准测试 (Benchmarks)

Page 12, 性能基准测试结果
Page 12, 性能基准测试结果

并行性问题 (Parallelism Problem)

并行性问题 - 上下文并行 (Parallelism Problem - Context Parallelism)

扩展至 Delta Rule (Extension to Delta Rule)

基础 (Basics)

Page 15, Delta Rule 公式
Page 15, Delta Rule 公式

全融合Delta法则 (Fully Fused Delta Rule)

在Hopper GPU上的实现

下图展示了在Hopper GPU上实现全融合Delta法则的时间线图。该图解了不同工作组(Math WG1, Math WG2)中各种计算任务(如加载Q/K/V,计算T=KK,O1=QS,V-SK等)的并行与依赖关系。

Page 16: 全融合Delta法则在Hopper GPU上的执行流程图
Page 16: 全融合Delta法则在Hopper GPU上的执行流程图

针对下三角矩阵的矩阵求逆

下图展示了分块求逆的过程,其中绿色部分代表原始矩阵,浅绿色部分代表求逆后的矩阵。

Page 17: 下三角矩阵的分块求逆示意图
Page 17: 下三角矩阵的分块求逆示意图

3阶段流水线 (3-Stages Pipelining)

下图展示了3阶段流水线的执行流程,对比了不同工作组(Aux Math WG1, Math WG2, and WG3)的任务调度。

Page 18: 全融合Delta法则的3阶段流水线示意图
Page 18: 全融合Delta法则的3阶段流水线示意图

基准测试 (Benchmarks)

下表展示了在固定序列长度和固定批大小两种情况下的基准测试结果,对比了fla和我们(Ours)的实现。

Page 19: 基准测试结果表格


<font size="1">1. For technical discussion and reference only, perf. may vary based on different product portfolio.</font>
<font size="1">2. Flash attention performance is tested with version v2.5.3 commit 49b3c3b</font>
<font size="1">3. Our kernel is still in development</font>

开放性问题 (Open Question)