Batching Helpers: Optimizing Loss Computation

Roman Schaffert, DevTech | AI Open Day 2025/05/30

议程 (Agenda)

动机 (Motivation)

背景:低效的损失计算

下图展示了典型的非批处理(迭代式)代价矩阵计算过程:
Typical non-batched (iterative) cost matrices computation for samples 0, 1, 2 on Page 4

概览 (Overview)

批处理助手概览:简化损失计算的批处理

新颖的数据格式 RaggedBatch

Diagram of RaggedBatch structure on Page 6
Diagram of RaggedBatch structure on Page 6

高效的辅助函数

Diagram of an efficient helper function operation on Page 6
Diagram of an efficient helper function operation on Page 6

RaggedBatch:一种用于非均匀批处理的新颖数据格式

Detailed structure of RaggedBatch on Page 7
Detailed structure of RaggedBatch on Page 7

批处理助手概览:批处理图示

代价矩阵计算 (Cost Matrix Computation)

Comparison of cost matrix computation with and without batching on Page 8
Comparison of cost matrix computation with and without batching on Page 8

损失计算 (Loss Computation)

Comparison of loss computation with and without batching on Page 9
Comparison of loss computation with and without batching on Page 9

最小示例 (Minimum Example)

批处理助手 - 最小示例:完整的损失计算工作流

该工作流分为三个主要步骤,并附有高级API调用的代码示例。

Complete Loss Computation Workflow on Page 11
Complete Loss Computation Workflow on Page 11

步骤1:数据转换 (Data Conversion)
- 输入: 数据样本列表 (data sample list)。
- 操作: 将真实值列表(GT list)转换为 RaggedBatch 格式。
- 输出: RaggedBatch 格式的数据。

步骤2:对象匹配 (Object Matching)
- 输入:
- GT 对象 (RaggedBatch)
- 预测 (torch.Tensor)
- 操作: 匹配GT和预测对象。
- 输出:
- GT 的索引 (RaggedBatch)
- 预测的索引 (RaggedBatch)

步骤3:损失计算 (Loss Computation)
- 输入:
- GT (RaggedBatch)
- 预测 (torch.Tensor)
- 匹配的索引对 (RaggedBatch)
- 输出:
- 每个样本的损失 (torch.Tensor)

步骤2:对象匹配 (Object Matching)

步骤2:对象匹配 (续)

步骤3:损失计算 - 分类与回归损失

步骤3:损失计算 - 存在损失 (Existence Loss)

Code example for existence loss computation on Page 15
Code example for existence loss computation on Page 15

真实世界示例 (Real-world Example)

StreamPETR 优化

概述

[1] S. Wang http://et.al., "Exploring Object-Centric Temporal Modeling for Efficient Multi-View 3D Object Detection", arXiv:2303.11926 [http://cs.CV]
代码: https://github.com/exiawsh/StreamPETR


优化潜力


优化

gt_bboxes = gt_bboxes.unsqueeze_batch_dim(0)
gt_labels = gt_labels.repeat_samples(num_dec_layers, 0)
gt_labels = gt_labels.unsqueeze_batch_dim(0)
* 通过将批处理维度与其他维度分开,可以简化操作 (只需重塑,无需考虑批处理维度) * **减少一个批处理维度上的损失 (跨样本)** * **使用另一个辅助函数:使用索引对直接从GT匹配到预测** Page 19

实验设置

*仅供技术讨论


运行时评估

Page 21
Page 21

AccLib 概览 (AccLib Overview)

基本信息

如果您想尝试/使用我们的一些模块,请告诉我们!
无需等待开源版本!

* 源代码可能不属于 AccLib 版本的一部分


模块:GPU 上的高效匈牙利匹配

动机

Page 24
Page 24

匈牙利匹配基础

Page 25
Page 25

GPU 上的匈牙利匹配

与文献的比较:smatch
Page 27

*仅供技术讨论

比较:smatch 与 SciPy
Page 28

*仅供技术讨论


模块:高斯热图生成

Page 29
Page 29

[2] Z. Liu et. al., "BEVFusion: Multi-Task Multi-Sensor Fusion with Unified Bird's-Eye View Representation", arXiv:2205.13542 [http://cs.CV]
代码: https://github.com/mit-han-lab/bevfusion


模块:折线采样 (Polyline Sampling)

模块:优化测试工具

模块:高效的数据加载和预处理实现

核心信息 (Take Home Messages)

感谢!