FlashInfer-Bench: Building the Virtuous Cycle for AI-Driven LLM Systems

作者/机构: Shanli Xing, Yiyan Zhai, Alexander Jiang, Yixin Dong, Yong Wu, Zihao Ye, Charlie Ruan, Yingyi Huang, Yineng Zhang, Liangsheng Yin, Aksara Bayyapu, Luis Ceze, Tianqi Chen


A1 主要贡献

核心问题:尽管大型语言模型(LLM)作为能够生成GPU内核的自主代理展现了前景,但将这些由AI生成的内核集成到实际的推理系统中仍然充满挑战。

研究目标:为了弥合AI生成与实际部署之间的鸿沟,本文提出了FlashInfer-Bench,一个标准化的闭环框架,旨在连接内核的生成、基准测试和部署。

创新点
* FlashInfer Trace:提出了一种用于标准化描述AI生成工作负载的任务、负载和解决方案的规范。
* FlashInfer-Bench Dataset:整理了一个基于真实世界工作负载的数据集,为评估AI生成的内核提供了一个丰富的平台。
* 实用的操作流程:提出了一个持续生成AI内核并将其直接应用到实际生产系统中的工作流。
* 综合分析:对LLM生成的内核在LLM系统上的表现进行了全面的分析。

FlashInfer-Bench的整体架构如图1所示。其核心是FlashInfer Trace,它为指定内核契约和语义、以及交流实现和评估结果提供了一个标准模式。FlashInfer-Bench Dataset则整理了生产环境中的LLM服务工作负载。通过flashinfer bench.apply()机制,可以将经过验证的最快实现直接部署到LLM推理引擎中。

图1. FlashInfer-Bench 架构。FlashInfer Trace 为指定内核契约和语义、交流实现和评估结果提供了标准模式;FlashInfer-Bench Dataset 整理了生产 LLM 服务工作负载;flashinfer bench.apply() 将经过验证的最快实现直接部署到 LLM 推理引擎中。
图1. FlashInfer-Bench 架构。FlashInfer Trace 为指定内核契约和语义、交流实现和评估结果提供了标准模式;FlashInfer-Bench Dataset 整理了生产 LLM 服务工作负载;flashinfer bench.apply() 将经过验证的最快实现直接部署到 LLM 推理引擎中。

A3 背景知识

2.1 LLM 推理流水线与 GPU 内核

现代LLM推理的构成。现代LLM推理由LLM服务引擎驱动,这些引擎负责处理批处理、调度和并行化,并由GPU内核调用和CPU逻辑组成。由于GPU内核占据了大部分执行时间,因此优化它们能直接降低LLM引擎的延迟。

通用GPU内核集。尽管模型多种多样,但大多数模型共享一小组GPU内核,包括:
1. GEMM:输入和输出可以是bf16或低比特(如fp8)。它需要使用张量核心指令来实现最高速度。低比特变体需要额外的量化/反量化逻辑。
2. Attention及其变体:例如,分页(paged)、分组(grouped)、基数(radix)和多潜(multi-latent)注意力。这需要张量核心,并且需要特殊的优化,如FlashAttention来实现。
3. 融合的专家混合(MoE):一个融合内核,处理MoE的路由逻辑和对应于多个专家的多个MLP。
4. 采样与后处理:例如top-p、top-k、温度(temperature)。这些是非确定性算子,其结果取决于输入分布和随机数。

2.2 内核优化方法

内核优化的敏感性。内核优化对硬件(SM数量、内存层次结构、张量核心代数)、数值格式(FP16/BF16/FP8/INT8)和工作负载形状(序列/批次长度、缓存布局、稀疏性)高度敏感,这使得通用的“一刀切”内核难以实现。

系统构建者的三种技术。系统构建者依赖三类技术进行内核优化:
1. 内核库和模板。高度优化的库提供了强大的基线,但通常无法在没有自定义内核的情况下利用特定于工作负载的结构(例如,不规则序列、融合的尾声)(Thakkar et al., 2023)。
2. 基于搜索的自动调度。像模板自动调优器这样的系统在固定的搜索空间内探索参数化的调度方案,以找到好的分块(tiling)、线程/块映射和融合策略。它们功能强大但受限于模板的表达能力,并且当需要在不同硬件或形状上重新访问搜索空间时,搜索成本可能过高(Chen et al., 2018; Zheng et al., 2020; Shao et al., 2022)。
3. 生成式程序合成。最近的LLM可以直接编写低级GPU代码,有时能发现超越现有模板的新颖融合和数据流模式。这开启了一个巨大的设计空间,但没有严格的验证也引入了正确性和安全风险;像Triton这样的轻量级DSL使得自定义内核变得易于访问(Tillet et al., 2019)。

FlashInfer-Bench的综合方法。FlashInfer-Bench利用了后两种方法的优点:它使生成模型能够提出结构上新颖的内核,同时用一个严格的、生产级的评估工具来包围它,以防止性能回归和奖励作弊(reward hacking)。

2.3 用于GPU代码生成的LLM

LLM在GPU内核合成中的潜力与挑战。最近的进展表明,当提供接口描述和反馈循环时,LLM可以合成重要的GPU内核和融合算子。像KernelBench(Ouyang et al., 2025)这样的公开评估主要评估生成能力——模型是否能生成一个编译通过的内核,在选定的输入上与参考实现匹配,并实现合理的加速。

从能力到生产的实际需求。在实践中,从能力转向生产需要额外的要素:精确的任务规范(API语义、支持的形状/数据类型、内存布局约束)、对奖励作弊和非确定性的防御、对现实工作负载分布的覆盖,以及部署或回滚候选内核的敏捷路径。

FlashInfer-Bench的贡献。像Triton(Tillet et al., 2019)这样的轻量级DSL使得人类和AI代理都能轻松编写自定义内核,而FlashInfer-Bench则提供了缺失的操作性支架——用于标准化任务交换的FlashInfer Trace,确保安全性和正确性的健壮验证器,以及无需重写引擎即可实现系统级收益的即时动态替换机制。


A2 方法细节

3.1 FlashInfer Trace

图2. FlashInfer Trace 模式设计。Definition 描述内核任务。Workload 描述内核的真实世界输入。Solution 描述 AI 生成的解决方案。Evaluation 描述基准测试的评估结果。每个组件还包括用于分组和过滤的辅助字段。
图2. FlashInfer Trace 模式设计。Definition 描述内核任务。Workload 描述内核的真实世界输入。Solution 描述 AI 生成的解决方案。Evaluation 描述基准测试的评估结果。每个组件还包括用于分组和过滤的辅助字段。

标准化语言的设计。为了打通从内核生成到评估再到部署的闭环,需要一种人类和AI代理都能理解的标准化语言。FlashInfer Trace担当了这一通用语言的角色。它清晰地阐述了内核的语义契约、实现和具体的评估结果。该抽象被刻意设计得极为精简(例如,在内核Definition中不暴露与实现相关的系统元数据),但又足够充分(不同的算子引入了它们所需的关键维度和约束)。

FlashInfer Trace的四个核心组件。FlashInfer Trace模式包含四个组件,如图2所示。这四个组件共同构成一个自包含的Trace对象,确保了可移植性和可复现性。
* Definition(定义):一个JSON规范,描述了算子的输入/输出张量及其数据类型(dtypes)、维度轴(可以是静态值const或由工作负载决定的值var),以及一个基于纯PyTorch的参考函数作为数学语义的唯一来源。可选的约束(constraints)用于编码维度轴之间的关系。
* Workload(工作负载):一个绑定到特定内核Definition的具体测试输入。所有var类型的轴都被赋予整数值,每个输入通过以下方式之一具体化:记录的safetensors文件、随机生成或字面标量值。
* Solution(解决方案):一个满足所选Definition接口和语义的具体实现。它提供源文件和一个可调用的入口点,以及兼容性元数据(例如,目标GPU架构和软件版本)。还包括可扩展的语言/DSL支持。
* Evaluation(评估):一个不可变的基准测试记录,它精确地绑定了一个特定的Definition × Solution × workload组合,并报告运行状态、正确性和性能摘要,以及执行环境的快照。

对动态形状和不规则输入的支持。FlashInfer Trace模式原生支持动态和静态内核形状,其中每个轴可以定义为var类型(其值由工作负载决定)或const类型(其值在编译时固定)。这使得AI能够针对特定形状优化内核。它还支持不规则输入,例如Attention中使用的页表。为了实现这一点,完整的页表张量和存储索引指针的整数张量都可以作为输入提供,从而允许系统精确地描述不规则张量输入。具体示例见附录A。

3.2 FlashInfer-Bench 数据集

图3. FlashInfer-Bench 数据集收集流程。我们使用默认的通用配置,针对真实世界的流量为主要模型提供服务,并整理内核定义和工作负载。
图3. FlashInfer-Bench 数据集收集流程。我们使用默认的通用配置,针对真实世界的流量为主要模型提供服务,并整理内核定义和工作负载。

数据集的构建理念。在FlashInfer Trace的基础上,我们整理了FlashInfer-Bench数据集,这是一个不断发展的标准,它将常见的服务内核Definition与代表性的Workload配对,持续跟踪和收集可供人类和AI代理优化的目标。

数据收集的真实世界相关性。我们的目标是关注LLM服务的真实世界相关性。我们覆盖了DeepSeek-V3、Llama-3.1-8B、Qwen3-30B-A3B等模型,涵盖了GEMM、Attention、Normalization、Sampling和MoE等算子家族。工作负载是通过在SGLang中使用默认、常用的配置(例如,为DeepSeek-V3使用原生FP8量化和8路张量并行)运行这些模型,并为它们提供ShareGPT提示来收集的。

内核定义的分类标准。在收集内核Definition时,我们将两个内核调用归类于同一个Definition,当且仅当它们:(i) 共享相同的I/O规范和运行参考语义;(ii) 暴露相同的一组轴,且具有相同的const/var角色;(iii) 在所有const轴的值上都一致。

定义的设计哲学。我们特意倾向于使用特定的Definition而非宽泛的定义,理想情况下具体到一个特定模型层的内核调用,以实现尽力而为的内核优化和明确的调度。我们不允许可选输入或行为标志,并将默认行为编码在运行参考函数中,使其成为契约的一部分。当行为必须不同时,我们引入一个新的Definition,而不是添加运行时开关。

工作负载的整理与去重。我们通过一种性能感知、保持多样性的规约方法来整理和去重收集到的工作负载。当输入值对内核性能(例如,采样概率分布)或正确性(例如,极端边缘情况)有重大影响时,我们转储完整的张量;否则,我们使用带种子的随机运行时张量来节省存储空间。

最终数据集的构成。然后,我们沿着对性能敏感的轴(如批处理大小)和张量统计数据(如attention的平均序列长度)进行去重,在保留多样性和代表性的同时精简集合。最终,每个Definition保留约50个工作负载用于评估。

3.3 稳健的内核基准测试

基准测试子系统的目标。为了保证整个内核评估过程的稳健性和效率,我们构建了一个基准测试子系统,它提供严格的正确性验证和稳健、可复现的计时,并能在多设备环境中原生运行。

确定性内核的正确性验证。对于期望产生确定性输出的操作(例如GEMM、归一化、注意力等),我们直接将内核的输出与参考输出进行逐元素比较。如果内核输出ysol的每个元素都满足以下条件,则认为其是正确的:

$$|y_{\text{sol}}-y_{\text{ref}}| \le \epsilon_{\text{abs}} + \epsilon_{\text{rel}} \cdot |y_{\text{ref}}|$$

其中yref是参考输出。我们也会拒绝任何包含非有限值(NaN或Inf)的输出。对于每个内核,我们记录在所有测试元素和试验中观察到的最大误差。一个确定性内核只有在其每个输出元素都落在允许的误差范围内时,才能通过正确性检查。

低精度内核的正确性验证。使用较低精度算术(例如FP8)的内核会引入比全精度基线系统性更大的误差。我们不使用单一宽松的全局容忍度,而是采用匹配比率规则:如果至少有ρ比例的输出满足标准误差准则,则内核是正确的。例如,当ρ = 0.95时,我们要求95%的输出元素通过严格的误差界限,允许一小部分离群元素存在。

随机内核的正确性验证。对于像采样这样的随机算子,逐元素比较是无效的,因为每次运行的输出都不同。取而代之,我们验证采样输出是否遵循正确的概率分布。我们从输入概率p和可选掩码M(例如,top-k或nucleus/top-p)推导出基准分布q,对其进行归一化,并重复执行内核以获得经验分布ˆf。最后,我们计算经验分布和期望分布之间的总变差距离(TVD),并要求TVD小于选定的阈值τTVD。我们使用TVD因为它直接给出了任何事件上最坏情况概率误差的上限。除了TVD检查,我们还验证每个样本是否被阈值掩码接受;任何掩码违规(例如,采样了一个本应被给定top-k排除的索引)都会导致立即失败。

性能测量。我们维护一个每个GPU、多进程可见的设备锁。为防止同一设备上进程/任务之间的干扰,计时例程仅在获取设备锁后运行。每个内核执行w次不计时的预热运行,然后进行m次计时运行。我们使用基于CUDA事件的设备端计时,并报告m次测量运行的平均延迟。

隔离机制。为了最小化解决方案之间的交叉干扰,以及防止LLM“攻击”基准测试(例如,通过读取残留内存来推断参考输出),我们提供了一种完全隔离的基准测试模式:每个解决方案在其自己的子进程中运行,并在完成或超时后终止,同时拆除CUDA上下文以避免跨运行的状态遗留。我们还提供了一种持久模式,每个GPU有一个长生命周期的工作进程和一个小的预热备用工作进程池,这大大减少了子进程和CUDA上下文的初始化开销,并能在解决方案失败并损坏上下文时快速恢复。这两种模式共同平衡了大规模扫描所需的效率与完全隔离所提供的稳健性和安全保证。

系统支持。随着数据集的扩展,高效的基准测试对于我们工作流的及时和可持续运作变得至关重要。基于这种效率意识,我们构建了一个可扩展、容错的多设备基准测试服务。对于每个就绪的Solution×Workload作业,我们有一个调度器,它构建一个成本矩阵,该矩阵考虑了基线驻留时间和热编译缓存与可用设备工作进程的关系。然后,它使用匈牙利算法分配微批次,并通过指数加权移动平均在线更新成本模型,然后再解决下一个批次。调度器执行工作进程健康检查并处理故障恢复。执行默认为持久模式;重复失败的解决方案会被推迟到隔离模式下运行。输入和参考输出在目标设备上具体化,并在可能时重用。编译后的解决方案缓存在内存中,必要时持久化到磁盘(例如,CUDA二进制文件)以防止重复构建。

3.4 公开排行榜和持续评估

排行榜的构建与功能。我们基于第3.3节的基准测试堆栈,托管了一个公开排行榜(见图4)。它接受FlashInfer Trace格式的提交,在真实工作负载上进行评估,并报告按内核和设备分层的指标,包括正确性、相对于加速阈值的性能曲线、每个工作负载的延迟以及端到端延迟增量。

可复现性与反作弊。为确保可引用性和可复现性,我们定期发布带有版本化数据集的冻结快照,而滚动排行榜则反映最新的评估结果。该服务强制执行反奖励作弊防御措施(运行时隔离、隐藏的工作负载,以及用于确定性、低精度和采样内核的专用验证器)。我们将在第4节分析当前快照。

3.5 在生产中动态替换内核

问题的提出。以前的工作流程需要在服务引擎内部手动更改代码才能部署优化的内核,这为AI代理驱动的自动化创造了瓶颈,并阻碍了评估-部署循环。

解决方案:flashinfer-bench.apply()。为了浮现经过验证的优化内核并闭合此循环,我们引入了flashinfer-bench.apply()。它提供了一种零侵入、低开销的路由机制,可动态地将服务请求映射到数据集中性能最佳的实现。

图4. FlashInfer-Bench 排行榜。在 fast0.95 指标下表现最好的模型是 gemini-2.5-pro、gpt-o3 和 gpt-5-2025-08-07。在正确性方面表现最好的模型是 gpt-5-2025-08-07(83.9% 通过)、gpt-o3(71.3% 通过)和 gemini-2.5-pro(48.8% 通过)。
图4. FlashInfer-Bench 排行榜。在 fast0.95 指标下表现最好的模型是 gemini-2.5-pro、gpt-o3 和 gpt-5-2025-08-07。在正确性方面表现最好的模型是 gpt-5-2025-08-07(83.9% 通过)、gpt-o3(71.3% 通过)和 gemini-2.5-pro(48.8% 通过)。

接口与用法flashinfer bench.apply()提供了两种API。装饰器API包装一个算子;被包装的函数作为后备(fallback)方案。它接受一个固定的Definition名称或一个将绑定的运行时参数映射到Definition名称的解析器。命令式API允许从任意位置进行自定义的内核调用;它为给定输入解析出性能最佳的解决方案并立即返回结果。我们提供与FlashInfer的一流集成,因此只需启用环境变量FIB_ENABLE_APPLY=1即可路由常见的操作,无需更改代码。当全局禁用flashinfer bench.apply()时,调用会透明地传递给原始实现。

离线缓存预构建。为了最小化apply()对服务性能的影响,我们引入了一种用于动态调度的预编译(AOT)索引,将在线调度计算减少为几次O(1)的索引查找。

索引构建过程。在服务引擎启动前,apply()运行时会根据本地数据集初始化一个索引。它首先根据可配置的错误阈值筛选trace,然后从trace中的工作负载中提取特征(如形状)形成一个键,并为每个键选择最快的解决方案作为索引值。在所有选定的解决方案中,我们按照可配置的比例,选择被选用次数最多的那些,并进行预编译(AOT)成可执行文件。其余的将进行即时编译(JIT),以平衡构建成本和运行时开销。

图5. flashinfer bench.apply() 的工作流程。它是一个动态调度器,在运行时根据内核输入检索最佳的 Solution,并返回执行结果。
图5. flashinfer bench.apply() 的工作流程。它是一个动态调度器,在运行时根据内核输入检索最佳的 Solution,并返回执行结果。

在线轻量级调度。在运行时,我们从内核的输入参数构建键,在索引中执行一次O(1)查找,并找到或编译一个有效的内核来执行。当启用CUDA图并进行适当的预热后,其开销可以忽略不计(见第4.5节)。


A4 实验环境与结果

实验环境总结

Algorithm 1 Feedback-loop Agent

Input :Definition, Language, Hardware
Output : Sol* (best solution)
S ← ∅
Agent ← CodeAgent.Initialize(Definition, Language, Hardware)
Sol0 ← Agent.Generate()
for i ← 0 to N − 1 do
    Tracei ← FlashInfer-Bench.Benchmark(Definition, Soli)
    if Tracei.Status = PASSED then
        S ← S ∪ {(Soli, Tracei)}
    Soli+1 ← Agent.Optimize(Tracei)
Sol* ← arg max(Soli,Tracei)∈S Tracei.Speedup
return Sol

实验结果总结

AI代理能力分析

AI代理生成内核案例研究
* GEMM - 编译器的帮助:GPT-5生成的Triton内核比其生成的CUDA内核快4.5倍(0.11ms vs 0.5ms)。主要性能差异源于张量核心原语的选择:Triton编译器通过tl.dot()自动选择了最新的tcgen05指令,而CUDA代码生成则默认使用了较旧的WMMA指令。这表明高级DSL通过抽象底层复杂性,使AI代理能更有效地探索和应用高级优化技术。
* GQA Paged Decode - 优化困难:GPT-5生成的GQA paged decode CUDA实现仅包含在线softmax等少量优化,缺乏块级分块、异步执行等SOTA技术。即使明确提示这些优化策略,LLM代理也未能在10次尝试中生成利用这些优化的正确内核。这表明对于复杂内核,预训练LLM难以利用CUDA提供的细粒度硬件控制能力。

端到端系统中的内核替换
* apply()引入的开销极小(图8):通过将原生内核替换为相同的实现来隔离apply()的固定成本,实验显示apply()为每次内核调用引入1-2微秒的开યો,端到端开销在所有批次大小下均低于0.8%。
* apply()将内核收益转化为端到端延迟改进(图8):实验证明,内核级别的性能提升能转化为可测量的端到端收益。将基准内核(FlashInfer)替换为性能更好(Gemini-2.5-Pro)和更差(GPT-5)的AI生成内核后,端到端请求延迟与内核的独立性能表现一致。这验证了动态内核替换机制的有效性,即替换为更好的内核能带来更好的端到端效率。

图8. 内核延迟与端到端延迟比较。(上)三种实现的 fused_add_rmsnorm_h4096 内核在批大小为 1、16 和 64 时的延迟。(下)比较原始基准与不同内核替换机制的端到端请求延迟。所有测量单位为毫秒;越低越好。
图8. 内核延迟与端到端延迟比较。(上)三种实现的 fused_add_rmsnorm_h4096 内核在批大小为 1、16 和 64 时的延迟。(下)比较原始基准与不同内核替换机制的端到端请求延迟。所有测量单位为毫秒;越低越好。

A7 补充细节

5.1 用于CUDA生成的LLM

与现有基准测试的区别。近年来,LLM在生成GPU内核方面展现出卓越的能力,诸如KernelBench (Ouyang et al., 2025) 和TritonBench (Li et al., 2025) 等基准测试系统地评估了这些能力。这些基准测试主要侧重于评估模型的能力。BackendBench (Saroufim et al., 2025) 研究了如何使用LLM为PyTorch生成内核并将其集成。相比之下,FlashInfer-Bench专注于LLM系统中的主要工作负载,并提供一个端到端的生产系统,将评估、验证和部署集成到一个统一的框架中。

与LLM训练和代理设计的互补性。近期的进展探索了训练LLM和设计代理以生成高效内核。Kevin (Baronio et al., 2025) 和KernelLLM (Fisches et al., 2025) 开发了用于内核生成的后训练模型。Dong等人 (2025) 为内核转译任务设计了一个代理,而Wei等人 (2025) 研究了基于LLM的CUDA内核生成代理设计。这些代理和模型设计与FlashInfer-Bench是互补的,FlashInfer-Bench可以进一步评估这些代理和模型在真实LLM系统内核任务上的性能。

5.2 用于系统优化的机器学习

与传统搜索方法的对比。机器学习早已应用于编译器和系统优化。TVM (Chen et al., 2018) 开创了张量程序自动优化的先河,随后的AutoTVM和Ansor (Zheng et al., 2020) 引入了基于搜索的调优和学习成本模型。最近,Meta-Schedule (Shao et al., 2022) 通过设计空间抽象和策略学习将这些方法泛化。这些系统在由人工设计的模板定义的固定优化空间内运行。

FlashInfer-Bench的生成式方法。FlashInfer-Bench采用一种基于生成的方法:LLM直接提出候选内核(可能超出任何预定义的调度空间),然后对这些内核进行严格的功能验证和性能基准测试。这将有效的探索边界从在固定模板集内搜索扩展到合成新的实现模式。

5.3 内核库和自定义内核DSL

生态系统中的定位。高度优化的内核库,如CUTLASS (Thakkar et al., 2023)、cuBLAS (NVI, 2025) 和FlashInfer (Ye et al., 2025),为LLM系统中的内核提供了强大的基线。像Triton (Tillet et al., 2019) 这样的领域特定语言降低了编写自定义GPU内核的门槛。FlashInfer-Bench将这些生态系统视为代理的实现目标,并专注于将候选内核连接到生产系统。

对推理工作负载的特化。一些工作专门针对推理工作负载。FlashAttention (Dao et al., 2022) 引入了IO感知的精确注意力机制,显著减少了内存流量并提高了吞吐量。Multi-Query Attention (Shazeer, 2019) 通过在多个头之间共享键/值来减少KV缓存占用并改善解码延迟。服务引擎采用分页KV缓存和调度策略(如PagedAttention)以在不规则、多租户工作负载下保持高利用率 (Kwon et al., 2023)。FlashInfer-Bench通过将任务基于真实追踪(形状分布、缓存布局)并评估对正确部署至关重要的数值和批处理属性,来捕捉这些现实情况。

5.4 LLM推理系统

与推理框架的共生关系。诸如vLLM (Kwon et al., 2023)、SGLang (Zheng et al., 2025)、TensorRT-LLM (NVIDIA, 2025) 和MLC-LLM (MLC team, 2023-2025) 等框架展示了适用于大型模型的可扩展推理基础设施。这些系统指导了FlashInfer-Bench的内核选择——优先考虑现代LLM操作,如注意力和MoE,而不是传统的卷积等操作——并提供了参考实现。反过来,FlashInfer-Bench使这些框架能够快速评估优化内核并将其部署到生产中,从而创造了一个互利的生态系统。


A5 结论

核心贡献。我们介绍了FlashInfer-Bench,一个系统性的方法,它打通了从AI内核生成到生产影响的闭环。其核心是FlashInfer Trace模式,它标准化了算子契约、真实服务工作负载、候选实现和不可变的评估。在此基础上,我们的基准测试可以衡量确定性、低精度和采样内核,并通过apply()机制,可以将经过验证的最佳内核替换到SGLang和vLLM等引擎中,而无需更改任何代码。

主要发现。我们的实时排行榜持续追踪前沿模型在真实世界和LLM工作负载上的GPU编程能力。我们的评估得出了三个实践性的结论:(1) 编译是主要的失败模式;(2) 模型难以利用硬件特性;(3) 语言选择是一种权衡——Triton提供了高正确性和可用性,而CUDA在成功时能达到更高的峰值性能。在端到端方面,动态替换增加了微不足道的开销,并可靠地将内核级别的增益转化为LLM服务中更低的延迟和更高的吞吐量。

局限性与未来工作。我们当前的工作范围尚未覆盖多GPU或通信内核,并且支持的模型、硬件设备和编程语言范围仍然有限。未来的工作可以进一步扩展FlashInfer Trace数据集的广度,改进内核正确性验证以防止奖励作弊并确保可靠的基准测试结果,以及基于FlashInfer-Bench反馈循环为LLM系统开发内核代理和微调模型。


A6 附录

A FLASHINFER TRACE 示例

本附录提供了FlashInfer Trace格式的具体示例。

A.1 GEMM

我们从一个通用矩阵乘法内核的trace定义开始。

"name": "gemm_n128_k2048",
"description": "GEMM C = A @ B.T. Captured from Qwen 3 30B A3B moe.gate.",
"op_type": "gemm",
"tags": ["status:verified", "model:qwen3-30b-a3b"],
"axes": {
    "M": { "type": "var" },
    "N": { "type": "const", "value": 128 },
    "K": { "type": "const", "value": 2048 }
},
"inputs": {
    "A": { "shape": ["M", "K"], "dtype": "float16" },
    "B": { "shape": ["N", "K"], "dtype": "float16" }
},
"outputs": {
    "C": { "shape": ["M", "N"], "dtype": "float16" }
},
"reference": "import torch

def run(A, B):
    C = torch.matmul(A, B.T)
    return C"
}

一个相应的生成解决方案对象可能如下所示:

"name": "claude-opus-4-1-20250805_triton_a20c42",
"definition": "gemm_n128_k2048",
"author": "claude-opus-4-1-20250805",
"spec": {
    "language": "triton",
    "target_hardware": ["B200"],
    "entry_point": "main.py::run",
    "dependencies": []
},
"sources": [
    { "path": "main.py", "content": "<source code omitted>" }
],
"description": "claude-opus-4-1-20250805 optimized kernel for gemm_n128_k2048 (round 1)"

一个实例化此定义的工作负载示例如下:

"definition": "gemm_n128_k2048",
"solution": null,
"workload": {
    "uuid": "6ba7c7de-dc5a-48d2-8ada-1382feb5ceac",
    "axes": { "M": 6 },
    "inputs": {
        "A": { "type": "random" },
        "B": { "type": "random" }
    }
},
"evaluation": null

在NVIDIA B200 GPU上评估此解决方案对此工作负载,会产生以下trace记录:

"definition": "gemm_n128_k2048",
"workload": {
    "axes": { "M": 6 },
    "inputs": {
        "A": { "type": "random" },
        "B": { "type": "random" }
    },
    "uuid": "6ba7c7de-dc5a-48d2-8ada-1382feb5ceac"
},
"solution": "claude-opus-4-1-20250805_triton_a20c42",
"evaluation": {
    "status": "PASSED",
    "environment": {
        "hardware": "NVIDIA B200",
        "libs": {
            "torch": "2.8.0+cu128",
            "triton": "3.4.0",
            "cuda": "12.8"
        }
    },
    "timestamp": "2025-10-16T01:10:32.241021",
    "log": "",
    "correctness": {
        "max_relative_error": 0,
        "max_absolute_error": 0,
        "extra": null
    },
    "performance": {
        "latency_ms": 0.023046740692633086,
        "reference_latency_ms": 0.025240250456929125,
        "speedup_factor": 1.0951765715399921
    }
}

A.2 Attention

接下来,我们展示一个基于分页式分组查询注意力(paged grouped-query attention)解码算子的更复杂的示例。与GEMM案例相比,该算子具有更复杂的接口和非平凡的参考实现。

"name": "gqa_paged_decode_h32_kv4_d128_ps1",
"description": "Batched Grouped Query Attention decode",
"op_type": "gqa_paged",
"tags": ["stage:decode", "status:verified", "model:qwen3-30b-a3b"],
"axes": {
    "batch_size": { "type": "var", "description": "Total number of query tokens." },
    "num_qo_heads": { "type": "const", "value": 32 },
    "num_kv_heads": { "type": "const", "value": 4 },
    "head_dim": { "type": "const", "value": 128 },
    "num_pages": { "type": "var" },
    "page_size": { "type": "const", "value": 1 },
    "len_indptr": { "type": "var", "description": "Length of kv_indptr array." },
    "num_kv_indices": { "type": "var", "description": "Total number of KV page indices." }
},
"constraints": [
    "len_indptr == batch_size + 1",
    "num_kv_indices == kv_indptr[-1].item()"
],
"inputs": {
    "q": { "shape": ["batch_size", "num_qo_heads", "head_dim"], "dtype": "bfloat16" },
    "k_cache": { "shape": ["num_pages", "page_size", "num_kv_heads", "head_dim"], "dtype": "bfloat16" },
    "v_cache": { "shape": ["num_pages", "page_size", "num_kv_heads", "head_dim"], "dtype": "bfloat16" },
    "kv_indptr": { "shape": ["len_indptr"], "dtype": "int32", "description": "KV page offsets for each sequence." },
    "kv_indices": { "shape": ["num_kv_indices"], "dtype": "int32", "description": "Page IDs for KV cache lookups." },
    "sm_scale": { "shape": null, "dtype": "float32", "description": "Softmax scale. Default is (1/sqrt(head_dim))." }
},
"outputs": {
    "output": { "shape": ["batch_size", "num_qo_heads", "head_dim"], "dtype": "bfloat16" },
    "lse": { "shape": ["batch_size", "num_qo_heads"], "dtype": "float32", "description": "The 2-based log-sum-exp of attention logits." }
},
"reference": "<reference code shown below>"

相应的PyTorch参考实现如下所示:

import torch
import math

@torch.no_grad()
def run(q, k_cache, v_cache, kv_indptr, kv_indices, sm_scale):
    batch_size, num_qo_heads, head_dim = q.shape
    _, page_size, num_kv_heads, _ = k_cache.shape
    len_indptr = kv_indptr.shape[0]
    num_kv_indices = kv_indices.shape[0]
    
    # Check constants
    assert num_qo_heads == 32
    assert num_kv_heads == 4
    assert head_dim == 128
    assert page_size == 1
    
    # Check constraints
    assert len_indptr == batch_size + 1
    assert num_kv_indices == kv_indptr[-1].item()
    
    device = q.device
    output = torch.zeros(
        (batch_size, num_qo_heads, head_dim), dtype=torch.bfloat16, device=device
    )
    lse = torch.full(
        (batch_size, num_qo_heads), -float("inf"), dtype=torch.float32, device=device
    )
    
    gqa_ratio = num_qo_heads // num_kv_heads
    
    k_cache_flat = k_cache.squeeze(1).to(
        torch.float32
    )  # [num_pages, num_kv_heads, head_dim]
    v_cache_flat = v_cache.squeeze(1).to(
        torch.float32
    )  # [num_pages, num_kv_heads, head_dim]
    
    for b in range(batch_size):
        page_start = int(kv_indptr[b].item())
        page_end = int(kv_indptr[b + 1].item())
        
        if page_start >= page_end:
            # No KV cache for this batch element
            output[b].zero_()
            continue
        
        # Pages are the token indices for page_size=1
        token_indices = kv_indices[page_start:page_end].to(torch.long)
        
        # Number of tokens is the number of pages for page_size=1
        num_tokens = token_indices.shape[0]
        if num_tokens == 0:
            output[b].zero_()
            continue
        
        # Get Q, K, V for this batch
        k_batch = k_cache_flat[token_indices]  # [num_tokens, num_kv_heads, head_dim]
        v_batch = v_cache_flat[token_indices]  # [num_tokens, num_kv_heads, head_dim]
        q_batch = q[b].to(torch.float32)  # [num_qo_heads, head_dim]
        
        for h in range(num_qo_heads):
            # Find corresponding KV head for GQA
            kv_head = h // gqa_ratio
            q_head = q_batch[h]  # [head_dim]
            k_head = k_batch[:, kv_head]  # [num_tokens, head_dim]
            v_head = v_batch[:, kv_head]  # [num_tokens, head_dim]
            
            logits = torch.matmul(q_head, k_head.T)  # [num_tokens]
            logits_scaled = logits * sm_scale
            
            # Compute 2-base LSE
            lse[b, h] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0)
            
            attn = torch.softmax(logits_scaled, dim=-1)  # [num_tokens]
            out_head = torch.matmul(attn, v_head)  # [head_dim]
            
            output[b, h] = out_head.to(torch.bfloat16)
            
    return output, lse

一个可能生成的Triton解决方案可以表示如下:

"name": "claude-opus-4-1_triton_de54a2",
"definition": "gqa_paged_decode_h32_kv4_d128_ps1",
"description": "claude-opus-4-1-20250805 optimized kernel (round 5)",
"author": "claude-opus-4-1-20250805",
"spec": {
    "language": "triton",
    "target_hardware": ["B200"],
    "entry_point": "main.py::run",
    "dependencies": []
},
"sources": {
    "path": "main.py",
    "content": "<source code omitted>"
}

与GEMM案例一样,我们可以为此定义捕获一个具体的工作负载实例。在此示例中,我们有一个标量输入和一些从safetensors转储中加载的其他输入:

"definition": "gqa_paged_decode_h32_kv4_d128_ps1",
"solution": null,
"workload": {
    "uuid": "0c2489b2-f878-428b-b1bd-d0c6d4c39338",
    "axes": {
        "batch_size": 1,
        "num_pages": 8,
        "len_indptr": 2,
        "num_kv_indices": 7
    },
    "inputs": {
        "q": { "type": "random" },
        "k_cache": { "type": "random" },
        "v_cache": { "type": "random" },
        "kv_indptr": { "type": "safetensors", "path": "/path/to/safetensor", "tensor_key": "kv_indptr" },
        "kv_indices": { "type": "safetensors", "path": "/path/to/safetensor", "tensor_key": "kv_indices" },
        "sm_scale": { "type": "scalar", "value": 0.0883883461356163 }
    }
},
"evaluation": null

在此工作负载上评估上述解决方案会产生以下trace记录:

"definition": "gqa_paged_decode_h32_kv4_d128_ps1",
"workload": {
    "axes": {
        "batch_size": 1,
        "num_pages": 8,
        "len_indptr": 2,
        "num_kv_indices": 7
    },
    "inputs": {
        "q": { "type": "random" },
        "k_cache": { "type": "random" },
        "v_cache": { "type": "random" },
        "kv_indptr": { "type": "safetensors", "path": "/path/to/safetensor", "tensor_key": "kv_indptr" },
        "kv_indices": { "type": "safetensors", "path": "/path/to/safetensor", "tensor_key": "kv_indices" },
        "sm_scale": { "type": "scalar", "value": 0.0883883461356163 }
    },
    "uuid": "0c2489b2-f878-428b-b1bd-d0c6d4c39338"
},
"solution": "claude-opus-4-1_triton_de54a2",
"evaluation": {
    "status": "PASSED",
    "environment": {
        "hardware": "NVIDIA B200",
        "libs": {
            "torch": "2.8.0+cu128",
            "triton": "3.4.0",
            "cuda": "12.8"
        }
    },
    "timestamp": "2025-10-16T01:24:16.694452",
    "log": "",
    "correctness": {
        "max_relative_error": 0.01480561401695013,
        "max_absolute_error": 0.00048828125,
        "extra": null
    },
    "performance": {
        "latency_ms": 0.02266162589486805,
        "reference_latency_ms": 29.439284915015815,
        "speedup_factor": 1299.0808802329861
    }
}

B AI代理生成的内核

本附录包含案例研究中讨论的AI代理生成的内核的完整源代码。

B.1 Triton GEMM 内核

$$\text{TFLOPs} = \frac{2MNK}{t \cdot 10^{12}}.$$
import math
import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_warps=8, num_stages=4),
    ],
    key=['M'],
)
@triton.jit
def _gemm_n4096_k4096_kernel(
    A_ptr, B_ptr, C_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bn, stride_bk,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    tl.static_assert(BLOCK_K % 16 == 0, "BLOCK_K must be a multiple of 16 for tensor cores")

    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    M_mask = offs_m[:, None] < M
    N_mask = offs_n[None, :] < N

    for k0 in range(0, K, BLOCK_K):
        offs_k = k0 + tl.arange(0, BLOCK_K)

        # Pointers
        a_ptrs = A_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
        b_ptrs = B_ptr + (offs_n[:, None] * stride_bn + offs_k[None, :] * stride_bk)
        
        a = tl.load(a_ptrs, mask=M_mask & (offs_k[None, :] < K), other=0.0)
        b = tl.load(b_ptrs, mask=(offs_n[:, None] < N) & (offs_k[None, :] < K), other=0.0)
        
        acc += tl.dot(a, tl.trans(b))

    c_ptrs = C_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
    tl.store(c_ptrs, acc.to(tl.float16), mask=M_mask & N_mask)

def run(A, B, **kwargs):
    """
    Compute C = A @ B.T where:
    - A: [M, 4096] float16
    - B: [4096, 4096] float16
    Returns C: [M, 4096] float16
    """
    if not isinstance(A, torch.Tensor) or not isinstance(B, torch.Tensor):
        raise TypeError("A and B must be torch.Tensor")

    if A.ndim != 2 or B.ndim != 2:
        raise ValueError(f"Expected 2D tensors, got A.ndim={A.ndim}, B.ndim={B.ndim}")

    M, K_a = A.shape
    N_b, K_b = B.shape

    if K_a != 4096 or K_b != 4096 or N_b != 4096:
        raise ValueError(f"Invalid shapes: A is {A.shape}, B is {B.shape}; expected A: [M,4096], B: [4096,4096]")

    # Dtype checks/conversions
    if A.dtype != torch.float16: A = A.to(torch.float16)
    if B.dtype != torch.float16: B = B.to(torch.float16)

    # Device management
    orig_dev_A = A.device
    orig_dev_B = B.device
    any_cuda_input = (A.is_cuda or B.is_cuda)

    if not torch.cuda.is_available():
        if any_cuda_input:
            raise RuntimeError("CUDA is not available but one or more inputs are CUDA tensors.")
        # Triton requires CUDA; no CPU fallback provided
        raise RuntimeError("No CUDA device is available.")

    # Choose compute device
    if A.is_cuda:
        compute_device = A.device
    elif B.is_cuda:
        compute_device = B.device
    else:
        compute_device = torch.device(f"cuda:{torch.cuda.current_device()}")

    # Move to compute device and make contiguous for optimal access
    A_dev = A.to(device=compute_device, non_blocking=True).contiguous()
    B_dev = B.to(device=compute_device, non_blocking=True).contiguous()

    # Allocate output on compute device
    N = 4096
    K = 4096
    C_dev = torch.empty((M, N), dtype=torch.float16, device=compute_device)

    # Kernel launch parameters
    def grid(meta):
        return (triton.cdiv(M, meta['BLOCK_M']), triton.cdiv(N, meta['BLOCK_N']))

    # Call kernel
    _gemm_n4096_k4096_kernel[grid](
        A_dev, B_dev, C_dev,
        M, N, K,
        A_dev.stride(0), A_dev.stride(1),
        B_dev.stride(0), B_dev.stride(1),
        C_dev.stride(0), C_dev.stride(1),
    )

    # Decide output device: preserve original locations; if both were CPU, return CPU; otherwise prefer A's device if CUDA, else B's
    if orig_dev_A.type == 'cpu' and orig_dev_B.type == 'cpu':
        out_device = torch.device('cpu')
    elif orig_dev_A.type == 'cuda':
        out_device = orig_dev_A
    elif orig_dev_B.type == 'cuda':
        out_device = orig_dev_B
    else:
        out_device = torch.device('cpu')

    C_out = C_dev if C_dev.device == out_device else C_dev.to(out_device, non_blocking=True)
    return C_out

B.2 CUDA GEMM 内核

kernel.h:

#pragma once
// Include guard
#ifndef GEMM_N4096_K4096_KERNEL_H
#define GEMM_N4096_K4096_KERNEL_H

#include <cuda_fp16.h>
#include <cuda_runtime.h>

/**
* @brief Host-side launcher for the GEMM operation C = A * B.T.
*
* This function launches a cuBLAS-based GEMM kernel optimized for
* B200-class GPUs with N=4096 and K=4096. It computes C = A * B.T
* where A is [M, K], B is [N, K], and C is [M, N]. All tensors are
* expected to have row-major layout.
*
* @param M The variable dimension of matrix A and C.
* @param A Pointer to the input matrix A [M, 4096] in device memory (dtype=half).
* @param B Pointer to the input matrix B [4096, 4096] in device memory (dtype=half).
* @param C Pointer to the output matrix C [M, 4096] in device memory (dtype=half).
* @param stream The CUDA stream for kernel execution.
*/
void gemm_n4096_k4096_launcher(
    int M,
    const half* A,
    const half* B,
    half* C,
    cudaStream_t stream
);

#endif // GEMM_N4096_K4096_KERNEL_H

http://kernel.cu:

#include "kernel.h"
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <stdexcept>
#include <string>

// Helper for checking cuBLAS status and converting it to a descriptive error
#define CUBLAS_CHECK(status)                                                   \
  do {                                                                         \
    cublasStatus_t err = (status);                                             \
    if (err != CUBLAS_STATUS_SUCCESS)                                          \
      throw std::runtime_error("cuBLAS error: " + std::to_string(err) + " at " \
                               __FILE__ ":" + std::to_string(__LINE__));       \
  } while (0)

// Manages a singleton cuBLAS handle for efficiency.
// This avoids the overhead of creating/destroying the handle on every call.
cublasHandle_t get_cublas_handle() {
  static struct CublasHandle {
    cublasHandle_t handle;
    CublasHandle() {
      CUBLAS_CHECK(cublasCreate(&handle));
      // Enable Tensor Core operations for peak performance on B200
      CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
    }
    ~CublasHandle() {
      if (handle) cublasDestroy(handle);
    }
  } singleton_handle;
  return singleton_handle.handle;
}

/**
* @brief Host-side launcher implementation using cuBLAS.
*/
void gemm_n4096_k4096_launcher(
    int M,
    const half* A,
    const half* B,
    half* C,
    cudaStream_t stream
) {
    cublasHandle_t handle = get_cublas_handle();
    CUBLAS_CHECK(cublasSetStream(handle, stream));

    const int N = 4096;
    const int K = 4096;

    const float alpha = 1.0f;
    const float beta = 0.0f;

    // The key to using cuBLAS (column-major) with row-major PyTorch tensors is
    // to rephrase the operation in a way that cuBLAS understands and that results
    // in the correct memory layout for the output.
    //
    // 1. Goal (Row-Major): C_rm[M, N] = A_rm[M, K] * B_rm.T[K, N]
    //
    // 2. cuBLAS View (Column-Major): cuBLAS interprets the memory of a row-major
    // matrix X_rm[rows, cols] as a column-major matrix X_cm[cols, rows].
    //    A_rm[M, K] is seen as A_cm[K, M].
    //    B_rm[N, K] is seen as B_cm[K, N].
    //    C_rm[M, N] is seen as C_cm[N, M].
    //
    // 3. Transformation: The equation C_rm = A_rm * B_rm.T is equivalent to
    //    C_cm.T = A_cm.T * (B_cm.T).T => C_cm.T = A_cm.T * B_cm.
    //    Taking the transpose of the whole equation gives us what cuBLAS should compute:
    //    C_cm = (A_cm.T * B_cm).T = B_cm.T * A_cm.
    //
    // 4. cuBLAS Call: We ask cuBLAS to compute D = op1 * op2, where the result D
    //    is written into the memory of C.
    //    - op1 = B_cm.T. This means the first matrix is B, and transa=CUBLAS_OP_T.
    //    - op2 = A_cm.  This means the second matrix is A, and transb=CUBLAS_OP_N.
    //
    // 5. Dimensions for cuBLAS:
    //    - m = rows of op1 (B.T) = N
    //    - n = cols of op2 (A)   = M
    //    - k = common dimension    = K
    //    The output matrix will be [m, n] = [N, M] in column-major layout, which
    //    perfectly matches the memory layout of our desired row-major C_rm[M, N].
    //    This resolves the illegal memory access and ensures correctness.
    const int lda = K;  // Leading dimension of A_rm[M, K] is K
    const int ldb = K;  // Leading dimension of B_rm[N, K] is K
    const int ldc = N;  // Leading dimension of C_rm[M, N] is N

    CUBLAS_CHECK(cublasGemmEx(
        handle,
        CUBLAS_OP_T,       // transa: Corresponds to first matrix (B), transposed
        CUBLAS_OP_N,       // transb: Corresponds to second matrix (A), not transposed
        N,                 // m: rows of op(B.T)
        M,                 // n: columns of op(A)
        K,                 // k: common dimension
        &alpha,            // alpha
        B,                 // Pointer to the first matrix (B)
        CUDA_R_16F,        // Btype
        ldb,               // ldb (leading dimension of B)
        A,                 // Pointer to the second matrix (A)
        CUDA_R_16F,        // Atype
        lda,               // lda (leading dimension of A)
        &beta,             // beta
        C,                 // Pointer to C
        CUDA_R_16F,        // Ctype
        ldc,               // ldc (leading dimension of C)
        CUDA_R_32F,        // computeType: Use FP32 accumulators for precision
        CUBLAS_GEMM_DEFAULT_TENSOR_OP // algorithm: Use default heuristic for Tensor Cores
    ));
}

main.cpp:

#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>
#include "kernel.h"
#include <stdexcept>
#include <string>

// Helper macros for concise tensor validation
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_HALF(x) TORCH_CHECK(x.scalar_type() == torch::kFloat16, #x " must be a float16 tensor")

/**
* @brief PyTorch extension entry point for the GEMM operation.
*
* This function validates input tensors and calls the CUDA kernel launcher
* to perform the computation C = A * B.T on the GPU.
*
* @param A A torch::Tensor of shape [M, 4096] and dtype float16.
* @param B A torch::Tensor of shape [4096, 4096] and dtype float16.
* @return A torch::Tensor C of shape [M, 4096] and dtype float16 containing the result.
*/
torch::Tensor run(torch::Tensor A, torch::Tensor B) {
    // - Input Validation
    CHECK_CUDA(A);
    CHECK_CUDA(B);
    CHECK_CONTIGUOUS(A);
    CHECK_CONTIGUOUS(B);
    CHECK_HALF(A);
    CHECK_HALF(B);

    TORCH_CHECK(A.dim() == 2, "A must be a 2D tensor");
    TORCH_CHECK(B.dim() == 2, "B must be a 2D tensor");

    // Dimension Checks
    const int M = A.size(0);
    const int K_A = A.size(1);
    const int N_B = B.size(0);
    const int K_B = B.size(1);
    const int N_spec = 4096;
    const int K_spec = 4096;

    TORCH_CHECK(K_A == K_spec, "A must have shape [M, 4096], but K is ", K_A);
    TORCH_CHECK(N_B == N_spec, "B must have shape [4096, 4096], but N is ", N_B);
    TORCH_CHECK(K_B == K_spec, "B must have shape [4096, 4096], but K is ", K_B);
    TORCH_CHECK(A.device() == B.device(), "Tensors must be on the same CUDA device");

    // Output Tensor Allocation
    auto C_options = torch::TensorOptions()
        .device(A.device())
        .dtype(A.scalar_type());
    auto C = torch::empty({M, N_spec}, C_options);

    // Kernel Execution
    try {
        // Get the current CUDA stream from PyTorch's context to ensure proper synchronization
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();
        
        // Get raw data pointers. at::Half is compatible with cuda_fp16.h::half
        const half* A_ptr = reinterpret_cast<const half*>(A.data_ptr<at::Half>());
        const half* B_ptr = reinterpret_cast<const half*>(B.data_ptr<at::Half>());
        half* C_ptr = reinterpret_cast<half*>(C.data_ptr<at::Half>());
        
        // Launch the cuBLAS-based kernel
        gemm_n4096_k4096_launcher(M, A_ptr, B_ptr, C_ptr, stream);

    } catch (const std::exception& e) {
        // Propagate exceptions from the CUDA/cuBLAS calls to Python
        throw std::runtime_error(std::string("CUDA kernel execution failed: ") + e.what());
    }
    
    // Check for any asynchronous errors from the kernel launch. This is good practice.
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        throw std::runtime_error(std::string("CUDA asynchronous error: ") + cudaGetErrorString(err));
    }

    return C;
}

// Pybind11 module definition to expose the 'run' function to Python
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("run", &run, "GEMM (A * B.T) for N=4096, K=4096 using a B200-optimized cuBLAS kernel");
}

B.3 Triton GQA Paged Decode Attention

$$\mathrm{TFLOPs}=\frac{\mathrm{FLOPs}_{\text {step }}}{t \cdot 10^{12}}=\frac{4 H_{q} d \text { num\_kv\_indices }}{t \cdot 10^{12}}$$
import math
import torch
import triton
import triton.language as tl

@triton.jit
def gqa_paged_decode_kernel(
    q_ptr,           # *bf16 [B, 32, 128]
    k_ptr,           # *bf16 [N_pages, 8, 128] (page_size squeezed)
    v_ptr,           # *bf16 [N_pages, 8, 128] (page_size squeezed)
    kv_indptr_ptr,   # *int32 [B + 1]
    kv_indices_ptr,  # *int32 [num_kv_indices]
    sm_scale,        # fp32 scalar
    out_ptr,         # *bf16 [B, 32, 128]
    lse_ptr,         # *fp32 [B, 32]
    BLOCK_T: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    NUM_QO_HEADS: tl.constexpr,
    NUM_KV_HEADS: tl.constexpr,
):
    pid = tl.program_id(0)
    batch_idx = pid // NUM_QO_HEADS
    qo_head = pid % NUM_QO_HEADS

    gqa_ratio = NUM_QO_HEADS // NUM_KV_HEADS
    kv_head = qo_head // gqa_ratio

    # ---- strides (in elements, not bytes)
    stride_q_batch = NUM_QO_HEADS * HEAD_DIM
    stride_q_head = HEAD_DIM
    stride_k_page = NUM_KV_HEADS * HEAD_DIM  # page_size = 1
    stride_k_kv_head = HEAD_DIM
    stride_v_page = stride_k_page
    stride_v_kv_head = HEAD_DIM

    # ---- load query vector
    d_offs = tl.arange(0, HEAD_DIM)
    q_ptr_head = q_ptr + batch_idx * stride_q_batch + qo_head * stride_q_head + d_offs
    q_vec = tl.cast(tl.load(q_ptr_head), tl.float32)

    # ---- sequence token range
    start = tl.load(kv_indptr_ptr + batch_idx)
    end = tl.load(kv_indptr_ptr + batch_idx + 1)
    num_tokens = end - start

    # -- streaming softmax vars
    m_val = tl.full([], -1e30, tl.float32)  # running max
    d_val = tl.zeros([], tl.float32)       # running sum exp
    o_vec = tl.zeros([HEAD_DIM], tl.float32)  # running output vector

    offset = tl.zeros([], tl.int32)
    while offset < num_tokens:
        t_offs = tl.arange(0, BLOCK_T)
        remain = num_tokens - offset
        tok_mask = t_offs < remain

        # - load page indices
        pages = tl.load(kv_indices_ptr + start + offset + t_offs, mask=tok_mask, other=0)

        # gather K / V
        k_ptrs = k_ptr + pages[:, None] * stride_k_page + kv_head * stride_k_kv_head + d_offs[None, :]
        v_ptrs = v_ptr + pages[:, None] * stride_v_page + kv_head * stride_v_kv_head + d_offs[None, :]
        
        k_block = tl.cast(tl.load(k_ptrs, mask=tok_mask[:, None], other=0), tl.float32)
        v_block = tl.cast(tl.load(v_ptrs, mask=tok_mask[:, None], other=0), tl.float32)

        # logits
        logits = tl.sum(k_block * q_vec[None, :], axis=1) * sm_scale
        logits = tl.where(tok_mask, logits, -1e30)

        # - - block softmax
        m_block = tl.max(logits, axis=0)
        exp_logits = tl.exp(logits - m_block)
        sum_exp_block = tl.sum(exp_logits, axis=0)
        weighted_v = tl.sum(exp_logits[:, None] * v_block, axis=0)

        # merge with running values
        new_m = tl.maximum(m_val, m_block)
        alpha_prev = tl.exp(m_val - new_m)
        alpha_blk = tl.exp(m_block - new_m)

        o_vec = o_vec * alpha_prev + weighted_v * alpha_blk
        d_val = d_val * alpha_prev + sum_exp_block * alpha_blk
        m_val = new_m

        offset += BLOCK_T

    inv_d = tl.where(d_val == 0, 0.0, 1.0 / d_val)
    out_vec = o_vec * inv_d

    log2e = 1.4426950408889634
    lse_val = tl.where(d_val == 0, -1e30, (tl.log(d_val) + m_val) * log2e)

    # store
    out_ptr_head = out_ptr + batch_idx * stride_q_batch + qo_head * stride_q_head + d_offs
    tl.store(out_ptr_head, tl.cast(out_vec, tl.bfloat16))

    lse_ptr_head = lse_ptr + batch_idx * NUM_QO_HEADS + qo_head
    tl.store(lse_ptr_head, lse_val)


def run(q, k_cache, v_cache, kv_indptr, kv_indices, sm_scale: float | None = None):
    """
    Entry point for gqa_paged_decode_h32_kv8_d128_ps1.
    Returns (output, lse).
    """
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(128.0)

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA device is required to run Triton kernels.")

    # move tensors to GPU if necessary
    tensors = [q, k_cache, v_cache, kv_indptr, kv_indices]
    device_tensors = [t.cuda() if not t.is_cuda else t for t in tensors]
    q_dev, k_dev, v_dev, iptr_dev, idx_dev = [t.contiguous() for t in device_tensors]

    batch_size = q_dev.shape[0]
    num_qo_heads = 32
    head_dim = 128

    # squeeze page dimension (=1)
    k_dev_flat = k_dev.squeeze(1).contiguous()
    v_dev_flat = v_dev.squeeze(1).contiguous()

    out_dev = torch.empty((batch_size, num_qo_heads, head_dim), dtype=torch.bfloat16, device=q_dev.device)
    lse_dev = torch.empty((batch_size, num_qo_heads), dtype=torch.float32, device=q_dev.device)

    # launch kernel
    BLOCK_T = 128
    grid = (batch_size * num_qo_heads,)
    gqa_paged_decode_kernel[grid](
        q_dev, k_dev_flat, v_dev_flat, iptr_dev, idx_dev,
        sm_scale,
        out_dev, lse_dev,
        BLOCK_T=BLOCK_T,
        HEAD_DIM=128,
        NUM_QO_HEADS=32,
        NUM_KV_HEADS=8,
        num_warps=4,
        num_stages=4,
    )

    # move back to original device if needed
    if not q.is_cuda:
        return out_dev.cpu(), lse_dev.cpu()
    return out_dev, lse_dev

B.4 CUDA GQA Paged Decode Attention

#include "kernel.h"
#include <ATen/cuda/CUDAContext.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <math_constants.h>
#include <cmath>
#include <limits>
#include <cstdio>

// Error checking macro
#ifndef CUDA_CHECK
#define CUDA_CHECK(expr)                                                       \
  do {                                                                         \
    cudaError_t __err = (expr);                                                \
    if (__err != cudaSuccess) {                                                \
      printf("CUDA Error %s at %s:%d
", cudaGetErrorString(__err), __FILE__,   \
             __LINE__);                                                        \
    }                                                                          \
  } while (0)
#endif

// Warp utilities
static inline __device__ float warp_reduce_sum(float val) {
    unsigned mask = 0xffffffffu;
    val += __shfl_down_sync(mask, val, 16);
    val += __shfl_down_sync(mask, val, 8);
    val += __shfl_down_sync(mask, val, 4);
    val += __shfl_down_sync(mask, val, 2);
    val += __shfl_down_sync(mask, val, 1);
    return val;
}

static inline __device__ float warp_broadcast(float val, int src_lane) {
    unsigned mask = 0xffffffffu;
    return __shfl_sync(mask, val, src_lane);
}

// Load 4 BF16s as 4 floats (contiguous) from ptr[offset .. offset+3]
static inline __device__ void load_bf16x4_to_f32(
    const __nv_bfloat16* __restrict__ ptr, int offset, float out[4]) {
    out[0] = __bfloat162float(ptr[offset + 0]);
    out[1] = __bfloat162float(ptr[offset + 1]);
    out[2] = __bfloat162float(ptr[offset + 2]);
    out[3] = __bfloat162float(ptr[offset + 3]);
}

// Store 4 floats as BF16s to ptr[offset .. offset+3]
static inline __device__ void store_f32x4_to_bf16(
    __nv_bfloat16* __restrict__ ptr, int offset, const float in[4]) {
    ptr[offset + 0] = __float2bfloat16(in[0]);
    ptr[offset + 1] = __float2bfloat16(in[1]);
    ptr[offset + 2] = __float2bfloat16(in[2]);
    ptr[offset + 3] = __float2bfloat16(in[3]);
}

template <int kBlockThreads>
__launch_bounds__(kBlockThreads, 2) __global__ void gqa_paged_decode_h32_kv8_d128_ps1_kernel(
    const __nv_bfloat16* __restrict__ q,       // [B, 32, 128]
    const __nv_bfloat16* __restrict__ k_cache, // [num_pages, 1, 8, 128] -> flat [num_pages*8, 128]
    const __nv_bfloat16* __restrict__ v_cache, // [num_pages, 1, 8, 128] -> flat [num_pages*8, 128]
    const int32_t* __restrict__ kv_indptr,   // [B+1]
    const int32_t* __restrict__ kv_indices,  // [num_kv_indices]
    float sm_scale,
    __nv_bfloat16* __restrict__ out,         // [B, 32, 128]
    float* __restrict__ lse_out,             // [B, 32]
    int num_batches,
    int num_pages_total
) {
    // Block mapping:
    // grid.x = batch index
    // grid.y = kv_head index in [0, 8)
    const int b = blockIdx.x;
    const int kv_head = blockIdx.y; // 0..7

    if (b >= num_batches || kv_head >= kNumKVHeads) {
        return;
    }

    // Thread mapping:
    const int tid = threadIdx.x;            // 0..127
    const int warp_id = tid >> 5;           // 0..3 (4 warps per block)
    const int lane_id = tid & 31;           // 0..31

    // The 4 query heads attached to this KV head
    const int q_head = kv_head * kGQARatio + warp_id; // 0..31

    // Pointers advance helpers
    const int q_stride_h = kHeadDim;
    const int q_stride_head = kNumQOHeads * kHeadDim;

    // Input sequence token range for this batch item
    const int32_t page_start = kv_indptr[b];
    const int32_t page_end = kv_indptr[b + 1];
    const int32_t num_tokens = page_end - page_start;

    // Shared buffers for one token's K and V vector for this kv_head
    extern __shared__ float smem[];
    float* sh_k = smem;               // [128]
    float* sh_v = smem + kHeadDim;    // [128]
    __shared__ int s_page;

    // Preload Q (each warp for its own q_head)
    // Each lane holds 4 elements to cover 128 dims: 32 lanes * 4 = 128
    const int q_base_offset = (b * q_stride_head) + (q_head * q_stride_h);
    const int d_base = lane_id * 4;
    float q_reg[4];
    load_bf16x4_to_f32(q + q_base_offset, d_base, q_reg);
    
    // Accumulators per warp/head
    float out_acc[4] = {0.f, 0.f, 0.f, 0.f};
    float m = -CUDART_INF_F; // running max of logits (scaled)
    float s = 0.f;           // running sum of exp(logit - m)
    
    // If no tokens: write zeros and lse = -inf and return
    if (num_tokens <= 0) {
        float zeros[4] = {0.f, 0.f, 0.f, 0.f};
        store_f32x4_to_bf16(out + (b * kNumQOHeads + q_head) * kHeadDim, d_base, zeros);
        if (lane_id == 0) {
            lse_out[b * kNumQOHeads + q_head] = -CUDART_INF_F;
        }
        return;
    }

    // Iterate over tokens
    for (int t = 0; t < num_tokens; ++t) {
        if (tid == 0) {
            s_page = kv_indices[page_start + t];
        }
        __syncthreads();

        // Bounds check for safety (though constraints guarantee validity)
        int page_id = s_page;
        if (page_id < 0) page_id = 0;
        if (page_id >= num_pages_total) page_id = (num_pages_total - 1);

        // Flattened (page_size=1): [num_pages, 1, 8, 128] -> [num_pages*8, 128]
        // Base index for this token and kv_head
        size_t base_idx = (static_cast<size_t>(page_id) * kNumKVHeads + kv_head) * kHeadDim;
        
        // Cooperatively load K and V vectors into shared memory as float
        if (tid < kHeadDim) {
            sh_k[tid] = __bfloat162float(k_cache[base_idx + tid]);
            sh_v[tid] = __bfloat162float(v_cache[base_idx + tid]);
        }
        __syncthreads();

        // Each warp computes its logit: dot(q, k) using 4 elements per lane
        float partial = 0.f;
        partial += q_reg[0] * sh_k[d_base + 0];
        partial += q_reg[1] * sh_k[d_base + 1];
        partial += q_reg[2] * sh_k[d_base + 2];
        partial += q_reg[3] * sh_k[d_base + 3];
        float sum = warp_reduce_sum(partial);
        float logit = warp_broadcast(sum, 0) * sm_scale;

        // Streaming softmax update
        float m_new = fmaxf(m, logit);
        float e1 = __expf(m - m_new); // scale for previous accumulators
        float e2 = __expf(logit - m_new); // new contribution
        s = s * e1 + e2;

        // Update vector accumulator
        out_acc[0] = out_acc[0] * e1 + e2 * sh_v[d_base + 0];
        out_acc[1] = out_acc[1] * e1 + e2 * sh_v[d_base + 1];
        out_acc[2] = out_acc[2] * e1 + e2 * sh_v[d_base + 2];
        out_acc[3] = out_acc[3] * e1 + e2 * sh_v[d_base + 3];
        m = m_new;

        __syncthreads();
    }
    
    // Finalize: normalize output by s, write lse base-2
    float inv_s = 1.f / s;
    float out_final[4] = {
        out_acc[0] * inv_s, out_acc[1] * inv_s,
        out_acc[2] * inv_s, out_acc[3] * inv_s
    };
    
    // Store output
    store_f32x4_to_bf16(out + (b * kNumQOHeads + q_head) * kHeadDim, d_base, out_final);

    // lse = logsumexp(logits_scaled) base 2 = (log(s) + m) / ln(2)
    if (lane_id == 0) {
        constexpr float ln2 = 0.693147180559945309417232121458176568f;
        float lse_val = (logf(s) + m) / ln2;
        lse_out[b * kNumQOHeads + q_head] = lse_val;
    }
}

// Host wrapper: validate inputs, set up launch config, and launch kernel
void gqa_paged_decode_h32_kv8_d128_ps1_cuda(
    const torch::Tensor& q,         // [B, 32, 128] bfloat16
    const torch::Tensor& k_cache,   // [num_pages, 1, 8, 128] bfloat16
    const torch::Tensor& v_cache,   // [num_pages, 1, 8, 128] bfloat16
    const torch::Tensor& kv_indptr, // [B+1] int32
    const torch::Tensor& kv_indices,// [num_kv_indices] int32
    float sm_scale,
    torch::Tensor& output,        // [B, 32, 128] bfloat16
    torch::Tensor& lse            // [B, 32] float32
) {
    TORCH_CHECK(q.is_cuda(), "q must be CUDA tensor");
    TORCH_CHECK(k_cache.is_cuda(), "k_cache must be CUDA tensor");
    TORCH_CHECK(v_cache.is_cuda(), "v_cache must be CUDA tensor");
    TORCH_CHECK(kv_indptr.is_cuda(), "kv_indptr must be CUDA tensor");
    TORCH_CHECK(kv_indices.is_cuda(), "kv_indices must be CUDA tensor");
    TORCH_CHECK(output.is_cuda(), "output must be CUDA tensor");
    TORCH_CHECK(lse.is_cuda(), "lse must be CUDA tensor");

    TORCH_CHECK(q.dim() == 3, "q must be [B, 32, 128]");
    TORCH_CHECK(q.size(1) == kNumQOHeads && q.size(2) == kHeadDim, "q must be [B, 32, 128]");
    TORCH_CHECK(q.scalar_type() == at::kBFloat16, "q must be bfloat16");
    
    TORCH_CHECK(k_cache.dim() == 4, "k_cache must be [num_pages, 1, 8, 128]");
    TORCH_CHECK(k_cache.size(1) == 1 && k_cache.size(2) == kNumKVHeads && k_cache.size(3) == kHeadDim, "k_cache must be [num_pages, 1, 8, 128]");
    TORCH_CHECK(k_cache.scalar_type() == at::kBFloat16, "k_cache must be bfloat16");
    
    TORCH_CHECK(v_cache.dim() == 4, "v_cache must be [num_pages, 1, 8, 128]");
    TORCH_CHECK(v_cache.size(1) == 1 && v_cache.size(2) == kNumKVHeads && v_cache.size(3) == kHeadDim, "v_cache must be [num_pages, 1, 8, 128]");
    TORCH_CHECK(v_cache.scalar_type() == at::kBFloat16, "v_cache must be bfloat16");
    
    TORCH_CHECK(kv_indptr.dim() == 1, "kv_indptr must be 1D");

C 解决方案生成提示

本附录包含用于生成数据集和案例研究中内核解决方案的提示。这些提示侧重于生成语法正确、可解析的代码,同时允许代理自主探索实现策略,而不是提供规范性的内核优化建议。基础提示用于初始内核提议,优化提示用于迭代改进。

C.1 Triton 基础提示

为{target_gpu} GPU生成一个优化的Triton内核

{definition}

Triton版本:3.3.1

要求:
- 为{target_gpu}架构编写干净、高效的Triton代码
- 使用现代Triton语法,具有正确的网格计算和语言特性
- 包含必要的导入(torch, triton, triton.language as tl)
- 实现规范中描述的确切功能
- 参考代码提供了数学规范但未经优化 - 你的Triton实现应在计算精度上与其匹配,同时提供高性能
- 使用定义的张量形状、数据类型和轴信息来指导内存访问模式和优化策略
- 针对{target_gpu} GPU特性(内存层次结构、计算单元等)进行优化

包装函数必须处理完整的设备管理:
- 如果需要,将CPU张量移动到GPU(当torch.cuda.is_available()时使用.to('cuda')或.cuda())
- 如果CUDA对GPU张量不可用,则引发明确的错误
- 使用GPU张量调用triton内核
- 将结果移回输入张量的原始设备
- 正确处理args和kwargs
- 保留原始张量设备,并为输出恢复它们

重要:仅使用有效的Python/Triton语法:
- 禁止十六进制浮点文字(0x1.234p5)- 使用十进制等价物
- 禁止C/CUDA特定语法 - 这是Python/Triton代码
- 所有代码必须是能通过ast.parse()的有效Python代码
- 暴露一个名为"run"的入口点函数,可以调用以执行内核
- 仅返回代码,不带解释或markdown格式

仅生成完整、可运行的代码 - 不会有框架添加设备处理包装代码。

生成实现:

C.2 Triton 优化提示

你正在为{target_gpu} GPU优化一个Triton内核。当前实现存在需要修复的问题。

原始规范:{definition}

当前实现状态:{trace_logs}

当前实现:{current_code}

优化策略:
1. 确保正确性:如果存在编译错误、运行时错误或不正确的输出,完全专注于修复这些问题
- 分析编译错误并修复语法/API用法
- 修复运行时错误,如形状不匹配、内存访问冲突
- 确保数值正确性与参考实现相匹配

  1. 优化性能:如果当前内核功能正确,则专注于性能优化
    • 优化{target_gpu}的内存访问模式
    • 调整块大小和网格维度
    • 使用适当的Triton语言特性进行矢量化
    • 最小化全局内存事务

优化实现的要求:

包装函数必须处理完整的设备管理:
- 如果需要,将CPU张量移动到GPU(当torch.cuda.is_available()时使用.to('cuda')或.cuda())
- 如果CUDA对GPU张量不可用,则引发明确的错误
- 使用GPU张量调用triton内核
- 将结果移回输入张量的原始设备
- 正确处理args和kwargs
- 保留原始张量设备,并为输出恢复它们

重要:仅使用有效的Python/Triton语法:
- 禁止十六进制浮点文字(0x1.234p5)- 使用十进制等价物
- 禁止C/CUDA特定语法 - 这是Python/Triton代码
- 所有代码必须是能通过ast.parse()的有效Python代码
- 暴露一个名为"run"的入口点函数,可以调用以执行内核
- 仅返回改进后的代码,不带解释或markdown格式

生成修正和优化的实现:

C.3 CUDA 基础提示

你是一个代码生成器。为以下规范生成一个为{target_gpu} GPU优化的CUDA内核实现。

规范:{definition}

要求:
- 为{target_gpu}架构编写干净、高效的CUDA C++代码
- 使用为{target_gpu}优化的正确CUDA语法和内存管理
- 实现规范中描述的确切功能
- 参考代码提供了数学规范但未经优化 - 你的CUDA实现应在计算精度上与其匹配,同时提供高性能
- 使用定义的张量形状、数据类型和轴信息来指导内存访问模式和优化策略
- 针对{target_gpu} GPU特性(内存层次结构、计算单元等)进行优化
- 对于固定的轴值,应针对这些常量进行特定优化,而不是通用情况

重要:以XML格式生成代码,包含且仅包含3个具有以下严格名称的文件:
<header_file name="kernel.h">
- 所有CUDA内核函数声明
- 主机函数声明
- 任何必要的结构/类型定义
- 包含保护和必要的头文件
</header_file>

<cuda_file name="http://kernel.cu">
- 所有__global__内核实现
- 所有__device__辅助函数
- CUDA特定的优化和内存模式
- 正确的错误检查和内存管理
</cuda_file>

<cpp_file name="main.cpp">
- 启动内核的主机函数
- 内存分配和数据传输管理
- 设备管理和错误处理
- 名为"run"的入口点函数,可以调用以执行实现
- 正确处理args和kwargs
- 将CPU数据移动到GPU,执行内核,并将结果返回到CPU
- 使用PYBIND11_MODULE包含PyTorch C++扩展绑定
- "run"函数必须通过绑定暴露给Python
- 包含PyTorch张量和CUDA指针之间正确的类型转换
- 包含所有必要的PyTorch头文件:#include <torch/extension.h>
</cpp_file>

代码生成指南:
- 使用适用于{target_gpu}的现代CUDA特性
- 优化内存合并并减少bank冲突
- 有效利用共享内存以重用数据
- 考虑占用率和寄存器使用
- 使用cudaGetLastError()实现正确的错误检查
- 为问题大小使用适当的网格和块维度
- 对频繁访问的只读数据使用常量内存
- 在"run"函数中对所有张量参数使用PyTorch张量API(torch::Tensor)
- 使用.data_ptr<T>()将PyTorch张量转换为CUDA指针,并使用适当的类型(例如,float, double, int)
- 确保正确的CUDA流同步和错误处理

生成实现:

C.4 CUDA 优化提示

你正在为{target_gpu} GPU优化一个CUDA内核。当前实现存在需要修复的问题。

原始规范:{definition}

当前实现状态:{trace_logs}

当前实现:{current_code}

优化策略:
1. 确保正确性:如果存在编译错误、运行时错误或不正确的输出,完全专注于修复这些问题
- 分析编译错误并修复语法/API用法
- 修复运行时错误,如形状不匹配、内存访问冲突、内核启动失败
- 确保数值正确性与参考实现相匹配
- 验证正确的CUDA内存管理和同步

  1. 优化性能:如果当前内核功能正确,则专注于性能优化
    • 为{target_gpu}优化内存访问模式和合并
    • 调整块大小和网格维度以最大化占用率
    • 有效利用共享内存以减少全局内存事务
    • 优化寄存器使用并最小化发散分支
    • 如果有益,考虑使用专门的库
    • 利用常量轴值进行编译时优化

优化实现的要求:

重要:以XML格式生成代码,包含且仅包含3个具有以下严格名称的文件:
<header_file name="kernel.h">
- 所有CUDA内核函数声明
- 主机函数声明
- 任何必要的结构/类型定义
- 包含保护和必要的头文件
</header_file>

<cuda_file name="http://kernel.cu">
- 所有__global__内核实现
- 所有__device__辅助函数
- CUDA特定的优化和内存模式
- 正确的错误检查和内存管理
</cuda_file>

<cpp_file name="main.cpp">
- 启动内核的主机函数
- 内存分配和数据传输管理
- 设备管理和错误处理
- 名为"run"的入口点函数,可以调用以执行实现
- 正确处理args和kwargs
- 将CPU数据移动到GPU,执行内核,并将结果返回到CPU
- 使用PYBIND11_MODULE包含PyTorch C++扩展绑定
- "run"函数必须通过绑定暴露给Python
- 包含PyTorch张量和CUDA指针之间正确的类型转换
- 包含所有必要的PyTorch头文件:#include <torch/extension.h>
</cpp_file>

代码生成指南:
- 使用适用于{target_gpu}的现代CUDA特性
- 优化内存合并并减少bank冲突
- 有效利用共享内存以重用数据
- 考虑占用率和寄存器使用
- 使用cudaGetLastError()实现正确的错误检查
- 为问题大小使用适当的网格和块维度
- 对频繁访问的只读数据使用常量内存
- 在"run"函数中对所有张量参数使用PyTorch张量API(torch::Tensor)
- 使用.data_ptr<T>()将PyTorch张量转换为CUDA指针,并使用适当的类型(例如,float, double, int)
- 确保正确的CUDA流同步和错误处理

生成修正和优化的实现: