FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning

文章标题: FlashAttention-2:通过更好的并行性和工作分区实现更快的注意力机制
作者: Tri Dao
机构: 1. 普林斯顿大学计算机科学系 2. 斯坦福大学计算机科学系

A1 主要贡献

本文的核心问题是解决 Transformer 模型在处理长序列时的性能瓶颈。注意力层的运行时间和内存占用随序列长度呈二次方增长,这限制了模型处理长文档、高分辨率图像或长视频的能力。尽管 FlashAttention 通过优化 GPU 内存使用,将内存占用从二次方降低到线性,并实现了2-4倍的加速,但其计算效率(FLOPs/s)仍远低于优化后的矩阵乘法(GEMM)操作,仅达到理论峰值的25-40%。

本文通过分析发现,FlashAttention 效率不高的主要原因是 GPU 上不同线程块和 warp 之间的工作分区不理想,导致了低占用率或不必要的共享内存读写。

为解决上述问题,本文提出了 FlashAttention-2,其研究目标是通过改进并行策略和工作分区来进一步提升注意力计算的效率。主要创新点和贡献如下:

  1. 算法调整以减少非矩阵乘法浮点运算:对 FlashAttention 算法进行了微调,减少了非矩阵乘法(non-matmul)的浮点运算次数。由于 GPU 上的专用计算单元(如 Tensor Cores)使得矩阵乘法的吞吐量远高于非矩阵乘法(可达16倍),这一优化能让计算时间更多地用于高效的矩阵乘法上。
  2. 增强并行性以提升GPU占用率:除了在批次大小和头数量维度上进行并行化,FlashAttention-2 还增加了沿序列长度维度的并行化。这在处理长序列(此时批次大小通常较小)的场景下,能显著提高 GPU 资源的利用率(即占用率)。
  3. 优化线程块内部工作分区:在每个线程块(thread block)内部,重新设计了不同 warp 之间的工作分配方式,以减少它们之间通过共享内存进行的通信和数据读写。

这些改进使得 FlashAttention-2 相比于 FlashAttention 实现了约2倍的速度提升,在 A100 GPU 上的前向传播计算效率达到了理论峰值的50-73%。在端到端的 GPT-style 模型训练中,每块 A100 GPU 的训练速度高达 225 TFLOPs/s,模型 FLOPs 利用率达到72%。

A3 背景知识

2.1 硬件特性

2.2 标准注意力实现

2.3 FlashAttention

A2 方法细节

我们描述了 FlashAttention-2 算法,它包含了对 FlashAttention 的几处调整以减少非矩阵乘法 FLOPs。然后,我们描述了如何在不同的线程块上并行化计算以充分利用 GPU 资源。最后,我们描述了在一个线程块内如何在不同的 warp 之间划分工作以减少共享内存的访问量。这些改进带来了2-3倍的加速,这在第4节中得到了验证。

3.1 算法

3.1.1 前向传播

3.1.2 反向传播

3.2 并行性

3.3 Warp之间的工作分区

A4 实验环境

A4 实验结果

注意力基准测试

端到端性能

A5 结论

FlashAttention-2 比 FlashAttention 快2倍,这意味着现在训练一个16k上下文长度的模型的成本与之前训练一个8k上下文长度的模型相当。这一进步有望推动模型在理解长篇书籍报告、高分辨率图像、音频和视频等领域的应用。同时,FlashAttention-2 也将加速现有模型的训练、微调和推理过程。

未来工作展望
1. 扩展硬件和数据类型支持:计划与研究人员和工程师合作,将 FlashAttention 推广到不同类型的设备(如 H100 GPU、AMD GPU)和新的数据类型(如 FP8)。
2. 针对 H100 的深度优化:下一步计划是优化 FlashAttention-2 以利用 H100 GPU 的新硬件特性(如 TMA、第四代 Tensor Cores、FP8)。
3. 结合高级算法:将 FlashAttention-2 的底层优化与高级算法(如局部注意力、扩张注意力、块稀疏注意力)相结合,可能使我们能够训练上下文更长的 AI 模型。
4. 提升可编程性:与编译器研究者合作,使这些优化技术更易于编程实现。

方法细节中的引用汇总