Batching Helpers: Optimizing Loss Computation
Batching Helpers: Optimizing Loss Computation
Roman Schaffert, DevTech | AI Open Day 2025/05/30
议程 (Agenda)
- 动机 (Motivation)
- 概览 (Overview)
- 最小示例 (Minimum Example)
- 真实世界示例 (Real-world Example)
- AccLib 概览 (AccLib Overview)
- 核心信息 (Take Home Messages)
动机 (Motivation)
背景:低效的损失计算
- 背景: 在高级驾驶辅助系统(ADAS)中,损失计算效率低下。
- 例如,StreamPETR模型(批大小为8)的训练时间中,约有35%用于损失计算。
- 这主要是由于大量的CPU(Python)开销和较低的GPU利用率。
- 通常,这部分计算没有批处理(batching)或部分批处理。
- 目标: 实现易于使用的批处理功能。
- 处理非均匀大小的样本。
- 实现计算高效。
- 易于与现有的损失函数集成。
- 解决方案: 批处理助手 (Batching Helpers)
- 为非均匀大小的样本设计了
RaggedBatch格式。 - 构建了辅助函数以在典型用例中应用
RaggedBatch。 - 实现为一个Python包。
- 例如,在StreamPETR(adaptation WD-7)中,当批大小为8时,训练速度提升了1.22倍。
下图展示了典型的非批处理(迭代式)代价矩阵计算过程:
概览 (Overview)
批处理助手概览:简化损失计算的批处理
新颖的数据格式 RaggedBatch
- 用于批处理非均匀大小的样本。
- 数据存储为PyTorch张量,允许:
- 使用内置的PyTorch操作。
- 复用现有的自定义实现(例如损失函数)。
- 集成了基本功能:
- 例如,在非均匀维度上求平均值。
- 将数据拆分为单个样本。
高效的辅助函数
- 在
RaggedBatch格式上操作。 - 专为损失计算设计。
- 例如:
- 将张量(嵌套)合并到
RaggedBatch中。 - 索引写入 (indexed writing)。
- 部分功能通过 C++/CUDA 扩展实现。
RaggedBatch:一种用于非均匀批处理的新颖数据格式
- 数据结构:
tensor:- 包含数据。
- 使用填充值(filler values)来获得统一的尺寸。
- 在填充值之前是有效值。
- 形状允许多个维度。
mask:- 指示有效元素。
- 形状包括批处理维度和非均匀维度。
sample_sizes:- 对于批处理中的每个元素,包含非均匀维度的大小。
- 形状为批处理维度。
- 与
mask存在冗余,但在某些操作中更受欢迎。
RaggedBatch具有3种不同功能的维度:- 批处理维度 (batch dimensions): 数量 ≥ 1,用于迭代样本、解码器层、相机图像等。
- 非均匀维度 (non-uniform dimension): 数量 = 1,用于迭代样本中的对象。
- 数据维度 (data dimensions): 数量 ≥ 0,取决于数据表示。
批处理助手概览:批处理图示
代价矩阵计算 (Cost Matrix Computation)
- 输入用于真实值(GT)到预测的对象匹配(例如,匈牙利算法)。
- 每个元素代表一个 GT-预测 对。
- 无批处理:
- 每个样本被单独、迭代地处理。
- 导致开销。
- 有批处理:
- 对所有代价矩阵执行单步处理。
- 处理的数据更多(包括填充元素)。
- 开销的减少远大于数据量增加带来的影响。
损失计算 (Loss Computation)
- 索引GT和预测以获得成对的对应关系。
- 使用索引数据计算成对损失。
- 无批处理:
- 每个样本被单独、迭代地处理。
- 导致开销。
- 有批处理:
- 对所有样本执行单步处理。
- 一次性处理整个批次。
最小示例 (Minimum Example)
批处理助手 - 最小示例:完整的损失计算工作流
该工作流分为三个主要步骤,并附有高级API调用的代码示例。
步骤1:数据转换 (Data Conversion)
- 输入: 数据样本列表 (data sample list)。
- 操作: 将真实值列表(GT list)转换为 RaggedBatch 格式。
- 输出: RaggedBatch 格式的数据。
步骤2:对象匹配 (Object Matching)
- 输入:
- GT 对象 (RaggedBatch)
- 预测 (torch.Tensor)
- 操作: 匹配GT和预测对象。
- 输出:
- GT 的索引 (RaggedBatch)
- 预测的索引 (RaggedBatch)
步骤3:损失计算 (Loss Computation)
- 输入:
- GT (RaggedBatch)
- 预测 (torch.Tensor)
- 匹配的索引对 (RaggedBatch)
- 输出:
- 每个样本的损失 (torch.Tensor)
步骤2:对象匹配 (Object Matching)
- 成本函数实现“照旧”
- 需要支持批处理。
- 但不需要直接支持非均匀批处理。
- 创建
RaggedBatch sample_sizes与参考RaggedBatch相同。- 目标(非均匀)维度不同。
- 在GT输入中,非均匀维度
dim==1。 - 在成本矩阵中,
dim==1迭代预测,dim==2迭代GTs。
步骤2:对象匹配 (续)
- 匹配本身在CPU上进行且非批处理
- 首先将数据传输到CPU。
- 拆分到单个样本
- 执行匹配(循环处理样本)
- 将获得的数据合并到批次中
- 移动到GPU:注意,对于最终结果,将小张量复制到GPU然后合并,比在GPU上合并更高效。
步骤3:损失计算 - 分类与回归损失
- 获取匹配的GT和预测
- 损失计算
- 是批处理的。
- 在每个对象的基础上执行。
- 不区分有效对象和填充物。
- 将损失结果封装为
RaggedBatch实例 sample_sizes与匹配的GT数据相同。- 共享
mask和sample_sizes张量。 - 对对象求和,忽略填充值
步骤3:损失计算 - 存在损失 (Existence Loss)
- 为现有预测元素获取掩码
- 存在损失计算
torch.mean()不能直接使用,因为存在填充值。shorthand用于处理RaggedBatch的张量(可选地,掩码也可以作为输入)。- 为非存在对象的匹配对象设置权重。
- 存在损失的计算是基于每个对象进行的。
- 最终损失
- 对所有损失求和并返回。
torch.sum()可以直接使用。
真实世界示例 (Real-world Example)
StreamPETR 优化
概述
- 选择 StreamPETR 的动机 [1]
- ADAS 中一种常见的训练/损失计算方法
- 开源
- 配置
- 基础: stream_petr_r50_flash_704_bs2_seq_24e
- 批量大小: 8, 原始为 2
- 数据加载器工作线程数: 8, 原始为 4
- 单 GPU 训练
- 注意:使用了更大的批量大小
- 突显了批处理优化的效果
- 代表了感兴趣的用例
[1] S. Wang http://et.al., "Exploring Object-Centric Temporal Modeling for Efficient Multi-View 3D Object Detection", arXiv:2303.11926 [http://cs.CV]
代码: https://github.com/exiawsh/StreamPETR
优化潜力
- HungarianAssigner3D & HungarianAssigner2D: 匹配器在每个样本上独立操作
- 成本矩阵计算 (实际匹配的先决条件) → 可优化
- 匹配本身 (CPU上的SciPy实现) → 不变
- StreamPETRHead
- 损失计算本身是按样本批处理的
- 也可以在解码器层上进行批处理 → 已优化 (使用多批次维度)
- 可以使用批处理分配器 → 已优化
- FocalHead
- 损失计算本身是按样本和相机图像批处理的
- 可以使用批处理分配器 → 已优化
- 注意:两个基线和优化实现都使用自定义 CUDA 实现来生成高斯热图
- 简化了批处理热图的生成
- 确保了公平的比较
优化
- 虽然实现很大程度上遵循了我们的示例,但仍存在一些显著差异
- 在解码器层上进行批处理
- 使用 StreamPETRHead (& HungarianMatcher3D) 中的多个批处理维度
- 提供了易于操作批处理维度的功能
- 在解码器层上进行批处理
gt_bboxes = gt_bboxes.unsqueeze_batch_dim(0)
gt_labels = gt_labels.repeat_samples(num_dec_layers, 0)
gt_labels = gt_labels.unsqueeze_batch_dim(0)
* 通过将批处理维度与其他维度分开,可以简化操作 (只需重塑,无需考虑批处理维度)
* **减少一个批处理维度上的损失 (跨样本)**
* **使用另一个辅助函数:使用索引对直接从GT匹配到预测**
- 对使用预先实现的非批处理损失函数的调整
- 不要在损失函数内部进行归约
- 而是编写用于重塑和归约的包装器
实验设置
- 硬件
- 训练进行了1000次迭代 (其中50次为预热迭代)
*仅供技术讨论
运行时评估
- 基线: StreamPETR 经过初步优化
torch.nan_to_num()移至 GPU 以进行更公平的比较- 无初步优化的运行时:
- 训练迭代: 895 ms
- 损失计算: 309 ms
- 优化: 在基线之上使用批处理助手
AccLib 概览 (AccLib Overview)
基本信息
- 由我们团队开发的 Python 包
- 促进 ADAS 中高效的 AI 实现,目前专注于训练
- 库包含
- 常用操作的高效实现,例如匈牙利匹配*,高斯热图生成
- 数据加载和预处理助手
- 优化助手,例如批处理助手
- 计划在几个月内开源
如果您想尝试/使用我们的一些模块,请告诉我们!
无需等待开源版本!
* 源代码可能不属于 AccLib 版本的一部分
模块:GPU 上的高效匈牙利匹配
动机
- 对损失计算至关重要,例如在目标检测中
- CPU 实现引入了开销
- 对于更大的批量大小效率较低
- 数据需要在 GPU 和 CPU 之间复制
匈牙利匹配基础
- 基于成对权重在两个集合之间进行元素配对匹配
-
目标检测损失计算中的重要步骤
-
输入:成本矩阵
- 两个集合中元素对之间的匹配成本 (例如,真实值 GT 和预测对象)
- 对应于连接图中每个节点到另一集合中每个节点的边的权重
-
任务
- 对于矩阵中的每一列,选择一个唯一的行
- 使得所选矩阵元素的累积权重得到优化
- 例如,在目标检测训练的背景下:
- 为每个 GT 对象找到一个唯一的预测匹配
- 最大化 GT 和匹配预测之间的整体相似性
-
输出
- 为每列选择的行
- 等效于:列和行的索引对
GPU 上的匈牙利匹配
-
smatch
- 匹配少量元素 (128 x 128)
- 计划进一步优化
-
mmatch (正在进行中)
- 匹配更多数量的元素
- 专为矩形成本矩阵设计
- 较小维度:最多 128 或 256 个条目
- 较大维度:512 个或更多条目
与文献的比较:smatch
*仅供技术讨论
比较:smatch 与 SciPy
*仅供技术讨论
模块:高斯热图生成
-
主要用例: 目标中心检测
- Focal loss 应用于高斯热图
- 直觉:
- 预测目标中心
- 确保损失不会过于稀疏
- 例如,在 BEVFusion [2] 中,实现了 1.17 倍的训练加速
-
实现
- 提供高效的 GPU 实现
- 与 RaggedBatch 兼容
[2] Z. Liu et. al., "BEVFusion: Multi-Task Multi-Sensor Fusion with Unified Bird's-Eye View Representation", arXiv:2205.13542 [http://cs.CV]
代码: https://github.com/mit-han-lab/bevfusion
模块:折线采样 (Polyline Sampling)
-
功能
- 给定一个由点序列 pᵢ 构成的折线 S(pᵢ),
- 获取折线上沿线距离为 d 的点 qⱼ = S(pᵢ)(dⱼ)
* 示例用例: 用于损失计算的车道 GT 预处理
* 实现
* 提供 CPU 实现
* 计划实现 GPU 版本
模块:优化测试工具
- 用于在优化过程中轻松进行性能测量及其他测试的工具。
- 包括:
- Stopwatch:轻松、轻量级地测量平均运行时间。
- NVTX 范围包装器:包装 NVTX 范围并提供附加功能。
- 在范围的
pop和push操作时轻松启用/禁用 CUDA 同步。 - 检查
push-pop不匹配(用于调试,会增加开销)。 - 待扩展...
- 在范围的
- 设计:
- 易于启用/禁用和配置(可全局完成)。
- 可在代码的任何部分使用,状态会自动保持一致(工具是单例对象)。
- 禁用时开销极小(例如,调用空函数)。
模块:高效的数据加载和预处理实现
- 包含
- 图像训练管道框架 (Image Training Pipeline Framework)
- 视频加载插件和管道 (Video Loading Plugin & Pipeline)
- 单独呈现
核心信息 (Take Home Messages)
-
批处理助手 (Batching Helpers) 包简化了 ADAS 中用于损失计算的高效批处理
- 包含
- RaggedBatch:用于非均匀样本大小的数据格式。
- 用于常用操作的辅助函数。
- 部分用 C++/CUDA 实现。
- 专注于基于对象/目标的损失计算。
- 易于使用
- 简化了批处理匹配和损失计算的实现。
- 允许重用现有实现(PyTorch 算子、损失函数等)。
- 包含
-
使用批处理助手的示例
- 一个最小化的实现。
- 真实世界用例:StreamPETR。
-
运行时优化分析 (基于 StreamPETR)
- 损失计算加速:x 4.3
- 整体训练加速:x 1.2
-
AccLib 概述:促进 ADAS 中高效的人工智能实现