Batching Helpers: Optimizing Loss Computation

Roman Schaffert, DevTech | AI Open Day 2025/05/30

议程 (Agenda)

  • 动机 (Motivation)
  • 概览 (Overview)
  • 最小示例 (Minimum Example)
  • 真实世界示例 (Real-world Example)
  • AccLib 概览 (AccLib Overview)
  • 核心信息 (Take Home Messages)

动机 (Motivation)

背景:低效的损失计算

  • 背景: 在高级驾驶辅助系统(ADAS)中,损失计算效率低下。
  • 例如,StreamPETR模型(批大小为8)的训练时间中,约有35%用于损失计算。
  • 这主要是由于大量的CPU(Python)开销和较低的GPU利用率。
  • 通常,这部分计算没有批处理(batching)或部分批处理。
  • 目标: 实现易于使用的批处理功能。
  • 处理非均匀大小的样本。
  • 实现计算高效。
  • 易于与现有的损失函数集成。
  • 解决方案: 批处理助手 (Batching Helpers)
  • 为非均匀大小的样本设计了 RaggedBatch 格式。
  • 构建了辅助函数以在典型用例中应用 RaggedBatch
  • 实现为一个Python包。
  • 例如,在StreamPETR(adaptation WD-7)中,当批大小为8时,训练速度提升了1.22倍。

下图展示了典型的非批处理(迭代式)代价矩阵计算过程:
Typical non-batched (iterative) cost matrices computation for samples 0, 1, 2 on Page 4

概览 (Overview)

批处理助手概览:简化损失计算的批处理

新颖的数据格式 RaggedBatch

  • 用于批处理非均匀大小的样本。
  • 数据存储为PyTorch张量,允许:
  • 使用内置的PyTorch操作。
  • 复用现有的自定义实现(例如损失函数)。
  • 集成了基本功能:
  • 例如,在非均匀维度上求平均值。
  • 将数据拆分为单个样本。
Diagram of RaggedBatch structure on Page 6
Diagram of RaggedBatch structure on Page 6

高效的辅助函数

  • RaggedBatch 格式上操作。
  • 专为损失计算设计。
  • 例如:
  • 将张量(嵌套)合并到 RaggedBatch 中。
  • 索引写入 (indexed writing)。
  • 部分功能通过 C++/CUDA 扩展实现。
Diagram of an efficient helper function operation on Page 6
Diagram of an efficient helper function operation on Page 6

RaggedBatch:一种用于非均匀批处理的新颖数据格式

  • 数据结构:
  • tensor:
    • 包含数据。
    • 使用填充值(filler values)来获得统一的尺寸。
    • 在填充值之前是有效值。
    • 形状允许多个维度。
  • mask:
    • 指示有效元素。
    • 形状包括批处理维度和非均匀维度。
  • sample_sizes:
    • 对于批处理中的每个元素,包含非均匀维度的大小。
    • 形状为批处理维度。
    • mask 存在冗余,但在某些操作中更受欢迎。
Detailed structure of RaggedBatch on Page 7
Detailed structure of RaggedBatch on Page 7
  • RaggedBatch 具有3种不同功能的维度:
  • 批处理维度 (batch dimensions): 数量 ≥ 1,用于迭代样本、解码器层、相机图像等。
  • 非均匀维度 (non-uniform dimension): 数量 = 1,用于迭代样本中的对象。
  • 数据维度 (data dimensions): 数量 ≥ 0,取决于数据表示。

批处理助手概览:批处理图示

代价矩阵计算 (Cost Matrix Computation)

  • 输入用于真实值(GT)到预测的对象匹配(例如,匈牙利算法)。
  • 每个元素代表一个 GT-预测 对。
Comparison of cost matrix computation with and without batching on Page 8
Comparison of cost matrix computation with and without batching on Page 8
  • 无批处理:
  • 每个样本被单独、迭代地处理。
  • 导致开销。
  • 有批处理:
  • 对所有代价矩阵执行单步处理。
  • 处理的数据更多(包括填充元素)。
  • 开销的减少远大于数据量增加带来的影响。

损失计算 (Loss Computation)

  • 索引GT和预测以获得成对的对应关系。
  • 使用索引数据计算成对损失。
Comparison of loss computation with and without batching on Page 9
Comparison of loss computation with and without batching on Page 9
  • 无批处理:
  • 每个样本被单独、迭代地处理。
  • 导致开销。
  • 有批处理:
  • 对所有样本执行单步处理。
  • 一次性处理整个批次。

最小示例 (Minimum Example)

批处理助手 - 最小示例:完整的损失计算工作流

该工作流分为三个主要步骤,并附有高级API调用的代码示例。

Complete Loss Computation Workflow on Page 11
Complete Loss Computation Workflow on Page 11

步骤1:数据转换 (Data Conversion)
- 输入: 数据样本列表 (data sample list)。
- 操作: 将真实值列表(GT list)转换为 RaggedBatch 格式。
- 输出: RaggedBatch 格式的数据。

步骤2:对象匹配 (Object Matching)
- 输入:
- GT 对象 (RaggedBatch)
- 预测 (torch.Tensor)
- 操作: 匹配GT和预测对象。
- 输出:
- GT 的索引 (RaggedBatch)
- 预测的索引 (RaggedBatch)

步骤3:损失计算 (Loss Computation)
- 输入:
- GT (RaggedBatch)
- 预测 (torch.Tensor)
- 匹配的索引对 (RaggedBatch)
- 输出:
- 每个样本的损失 (torch.Tensor)

步骤2:对象匹配 (Object Matching)

  • 成本函数实现“照旧”
  • 需要支持批处理。
  • 但不需要直接支持非均匀批处理。
  • 创建 RaggedBatch
  • sample_sizes 与参考 RaggedBatch 相同。
  • 目标(非均匀)维度不同。
  • 在GT输入中,非均匀维度 dim==1
  • 在成本矩阵中,dim==1 迭代预测,dim==2 迭代GTs。

步骤2:对象匹配 (续)

  • 匹配本身在CPU上进行且非批处理
  • 首先将数据传输到CPU。
  • 拆分到单个样本
  • 执行匹配(循环处理样本)
  • 将获得的数据合并到批次中
  • 移动到GPU:注意,对于最终结果,将小张量复制到GPU然后合并,比在GPU上合并更高效。

步骤3:损失计算 - 分类与回归损失

  • 获取匹配的GT和预测
  • 损失计算
  • 是批处理的。
  • 在每个对象的基础上执行。
  • 不区分有效对象和填充物。
  • 将损失结果封装为RaggedBatch实例
  • sample_sizes 与匹配的GT数据相同。
  • 共享 masksample_sizes 张量。
  • 对对象求和,忽略填充值

步骤3:损失计算 - 存在损失 (Existence Loss)

  • 为现有预测元素获取掩码
  • 存在损失计算
  • torch.mean() 不能直接使用,因为存在填充值。
  • shorthand 用于处理 RaggedBatch 的张量(可选地,掩码也可以作为输入)。
  • 为非存在对象的匹配对象设置权重。
  • 存在损失的计算是基于每个对象进行的。
  • 最终损失
  • 对所有损失求和并返回。
  • torch.sum() 可以直接使用。
Code example for existence loss computation on Page 15
Code example for existence loss computation on Page 15

真实世界示例 (Real-world Example)

StreamPETR 优化

概述

  • 选择 StreamPETR 的动机 [1]
    • ADAS 中一种常见的训练/损失计算方法
    • 开源
  • 配置
    • 基础: stream_petr_r50_flash_704_bs2_seq_24e
    • 批量大小: 8, 原始为 2
    • 数据加载器工作线程数: 8, 原始为 4
    • 单 GPU 训练
  • 注意:使用了更大的批量大小
    • 突显了批处理优化的效果
    • 代表了感兴趣的用例

[1] S. Wang http://et.al., "Exploring Object-Centric Temporal Modeling for Efficient Multi-View 3D Object Detection", arXiv:2303.11926 [http://cs.CV]
代码: https://github.com/exiawsh/StreamPETR


优化潜力

  • HungarianAssigner3D & HungarianAssigner2D: 匹配器在每个样本上独立操作
    • 成本矩阵计算 (实际匹配的先决条件) → 可优化
    • 匹配本身 (CPU上的SciPy实现) → 不变
  • StreamPETRHead
    • 损失计算本身是按样本批处理的
    • 也可以在解码器层上进行批处理 → 已优化 (使用多批次维度)
    • 可以使用批处理分配器 → 已优化
  • FocalHead
    • 损失计算本身是按样本和相机图像批处理的
    • 可以使用批处理分配器 → 已优化
    • 注意:两个基线和优化实现都使用自定义 CUDA 实现来生成高斯热图
      • 简化了批处理热图的生成
      • 确保了公平的比较

优化

  • 虽然实现很大程度上遵循了我们的示例,但仍存在一些显著差异
    • 在解码器层上进行批处理
      • 使用 StreamPETRHead (& HungarianMatcher3D) 中的多个批处理维度
      • 提供了易于操作批处理维度的功能
gt_bboxes = gt_bboxes.unsqueeze_batch_dim(0)
gt_labels = gt_labels.repeat_samples(num_dec_layers, 0)
gt_labels = gt_labels.unsqueeze_batch_dim(0)
* 通过将批处理维度与其他维度分开,可以简化操作 (只需重塑,无需考虑批处理维度) * **减少一个批处理维度上的损失 (跨样本)** * **使用另一个辅助函数:使用索引对直接从GT匹配到预测** Page 19
  • 对使用预先实现的非批处理损失函数的调整
    • 不要在损失函数内部进行归约
    • 而是编写用于重塑和归约的包装器

实验设置

  • 硬件
    Page 20
  • 训练进行了1000次迭代 (其中50次为预热迭代)

*仅供技术讨论


运行时评估

  • 基线: StreamPETR 经过初步优化
    • torch.nan_to_num() 移至 GPU 以进行更公平的比较
    • 无初步优化的运行时:
      • 训练迭代: 895 ms
      • 损失计算: 309 ms
  • 优化: 在基线之上使用批处理助手
Page 21
Page 21

AccLib 概览 (AccLib Overview)

基本信息

  • 由我们团队开发的 Python 包
  • 促进 ADAS 中高效的 AI 实现,目前专注于训练
  • 库包含
    • 常用操作的高效实现,例如匈牙利匹配*,高斯热图生成
    • 数据加载和预处理助手
    • 优化助手,例如批处理助手
  • 计划在几个月内开源

如果您想尝试/使用我们的一些模块,请告诉我们!
无需等待开源版本!

* 源代码可能不属于 AccLib 版本的一部分


模块:GPU 上的高效匈牙利匹配

动机

  • 对损失计算至关重要,例如在目标检测中
  • CPU 实现引入了开销
    • 对于更大的批量大小效率较低
    • 数据需要在 GPU 和 CPU 之间复制
Page 24
Page 24

匈牙利匹配基础

  • 基于成对权重在两个集合之间进行元素配对匹配
  • 目标检测损失计算中的重要步骤

  • 输入:成本矩阵

    • 两个集合中元素对之间的匹配成本 (例如,真实值 GT 和预测对象)
    • 对应于连接图中每个节点到另一集合中每个节点的边的权重
  • 任务

    • 对于矩阵中的每一列,选择一个唯一的行
    • 使得所选矩阵元素的累积权重得到优化
    • 例如,在目标检测训练的背景下:
      • 为每个 GT 对象找到一个唯一的预测匹配
      • 最大化 GT 和匹配预测之间的整体相似性
  • 输出

    • 为每列选择的行
    • 等效于:列和行的索引对
Page 25
Page 25

GPU 上的匈牙利匹配

  • smatch

    • 匹配少量元素 (128 x 128)
    • 计划进一步优化
  • mmatch (正在进行中)

    • 匹配更多数量的元素
    • 专为矩形成本矩阵设计
      • 较小维度:最多 128 或 256 个条目
      • 较大维度:512 个或更多条目

与文献的比较:smatch
Page 27

*仅供技术讨论

比较:smatch 与 SciPy
Page 28

*仅供技术讨论


模块:高斯热图生成

  • 主要用例: 目标中心检测

    • Focal loss 应用于高斯热图
    • 直觉:
      • 预测目标中心
      • 确保损失不会过于稀疏
    • 例如,在 BEVFusion [2] 中,实现了 1.17 倍的训练加速
  • 实现

    • 提供高效的 GPU 实现
    • 与 RaggedBatch 兼容
Page 29
Page 29

[2] Z. Liu et. al., "BEVFusion: Multi-Task Multi-Sensor Fusion with Unified Bird's-Eye View Representation", arXiv:2205.13542 [http://cs.CV]
代码: https://github.com/mit-han-lab/bevfusion


模块:折线采样 (Polyline Sampling)

  • 功能

    • 给定一个由点序列 pᵢ 构成的折线 S(pᵢ),
    • 获取折线上沿线距离为 d 的点 qⱼ = S(pᵢ)(dⱼ)

    Page 30 Diagram
    * 示例用例: 用于损失计算的车道 GT 预处理
    * 实现
    * 提供 CPU 实现
    * 计划实现 GPU 版本

模块:优化测试工具

  • 用于在优化过程中轻松进行性能测量及其他测试的工具。
  • 包括
    • Stopwatch:轻松、轻量级地测量平均运行时间。
    • NVTX 范围包装器:包装 NVTX 范围并提供附加功能。
      • 在范围的 poppush 操作时轻松启用/禁用 CUDA 同步。
      • 检查 push-pop 不匹配(用于调试,会增加开销)。
      • 待扩展...
  • 设计
    • 易于启用/禁用和配置(可全局完成)。
    • 可在代码的任何部分使用,状态会自动保持一致(工具是单例对象)。
    • 禁用时开销极小(例如,调用空函数)。

模块:高效的数据加载和预处理实现

  • 包含
    • 图像训练管道框架 (Image Training Pipeline Framework)
    • 视频加载插件和管道 (Video Loading Plugin & Pipeline)
  • 单独呈现

核心信息 (Take Home Messages)

  • 批处理助手 (Batching Helpers) 包简化了 ADAS 中用于损失计算的高效批处理

    • 包含
      • RaggedBatch:用于非均匀样本大小的数据格式。
      • 用于常用操作的辅助函数。
      • 部分用 C++/CUDA 实现。
      • 专注于基于对象/目标的损失计算。
    • 易于使用
      • 简化了批处理匹配和损失计算的实现。
      • 允许重用现有实现(PyTorch 算子、损失函数等)。
  • 使用批处理助手的示例

    • 一个最小化的实现。
    • 真实世界用例:StreamPETR
  • 运行时优化分析 (基于 StreamPETR)

    • 损失计算加速:x 4.3
    • 整体训练加速:x 1.2
  • AccLib 概述:促进 ADAS 中高效的人工智能实现

感谢!