TORCH.FX: PRACTICAL PROGRAM CAPTURE AND TRANSFORMATION FOR DEEP LEARNING IN PYTHON

文章标题:TORCH.FX:Python中深度学习的实用程序捕获与转换
作者/机构:James K Reed, Zachary DeVito, Horace He, Ansley Ussery, Jason Ansel (均为 Facebook AI)

A1 主要贡献

本文旨在解决现代深度学习框架中的一个核心矛盾:即时执行(eager execution)模式虽然提升了开发效率和用户体验,但却牺牲了对程序结构的访问能力,而这种能力对于性能优化、可视化、分析和硬件集成等高级转换至关重要。为了在即时模式框架(如PyTorch)中重新获得这种能力,需要一种程序捕获机制。然而,现有的系统(如TorchScript)为了完全忠实地模拟Python的复杂语义(包括可变状态、控制流、复杂数据类型),其捕获技术和生成的中间表示(IR)都变得异常复杂,给转换(transform)的编写带来了巨大困难。

本文提出,可以通过专注于深度学习的典型用例(大多数神经网络模型的高层有向无环图(DAG)结构)而非长尾的复杂情况,来设计一个更简单、更高效的程序捕获与转换框架。基于这一理念,本文介绍了torch.fx,一个完全用Python编写的、为PyTorch设计的程序捕获与转换库,其核心目标是为机器学习从业者提供极高的开发生产力。

本文的主要贡献如下:

  1. 实用性分析:对深度学习程序中重要的程序捕获与转换特性进行了实用性分析。
  2. 纯Python程序捕获库:实现了一个纯Python的程序捕获库,该库可被定制以捕获不同层次的程序细节。
  3. 简单的6指令IR:提出了一种仅包含6个指令的简单中间表示(IR),其设计重点在于易于理解和进行静态分析。
  4. 代码生成系统:构建了一个代码生成系统,能够将转换后的代码无缝地返回到宿主语言(Python)的生态系统中。
  5. 案例研究:展示了torch.fx在实践中如何被用于性能优化、程序分析、设备适配(device lowering)等场景,实现了PyTorch生态系统中以前难以完成的工作流。

A3 背景知识与设计原则

背景知识

程序捕获、特化与IR设计的权衡。无论是即时模式还是图模式框架,在捕获和转换程序时都必须在程序结构的捕获、程序的特化(specialization)以及中间表示(IR)的设计之间做出选择。这些选择共同决定了框架能表示的程序范围、编写转换的难易程度以及转换后程序的性能。通常,为了支持更多程序并实现高性能,需要更复杂的捕获框架和IR,这反过来又使得转换的编写更加困难。

2.1 捕获程序结构

2.2 程序特化

2.3 中间表示(IR)设计

设计原则

现有框架的设计大多倾向于支持更广泛的深度学习程序,但牺牲了实现的简洁性。当捕获的程序是运行的唯一方式时,高保真度至关重要。但PyTorch主要作为即时执行框架使用,程序捕获仅用于特定转换,无需对整个程序都有效。此外,目标用户是机器学习从业者,他们更习惯使用Python而非编译器设计。

通过为典型的深度学习模型而非长尾用例进行设计,可以创建一个更易于使用和实现的框架。torch.fx的设计原则体现了这一理念:

A2 方法细节

TORCH.FX 概述

torch.fx采用符号追踪来捕获程序,使用一个简单的包含6个指令且基于Python的IR来表示它们,并从IR重新生成Python代码来执行。为避免JIT特化带来的重捕获复杂性,torch.fx本身不尝试特化程序,而是依赖于转换过程来决定需要执行何种特化。符号追踪过程是可配置的,用户可以定制以处理更特殊的用例。

from torch.fx import Graph
def replace_activation(g: Graph, old, new):
    for n in g.nodes:
        if n.op == 'call_function' and n.target == old:
            # create IR to call new activate
            with g.inserting_after(n):
                new_n = g.call_function(new, n.args)
            n.replace_all_uses_with(new_n)
            g.erase_node(n)
            # or for this simplified case: 'n.target = new'

replace_activation(traced.graph, torch.relu, torch.nn.functional.gelu)
traced.recompile()
图1. torch.fx使用符号追踪将程序捕获到一个简单的IR中,并从该IR生成Python代码。图2. 变换,如此处替换激活函数的变换,是直接用Python编写的。

4.1 程序捕获

4.2 中间表示

4.3 源码到源码的转换

设计决策

torch.fx融合并扩展了先前工作中的方法,提供了一个易于使用、实现简单且可配置的库。

5.1 符号追踪

5.2 可配置的程序捕获

5.3 预先(AoT)捕获而不进行特化

5.4 基于Python的IR和变换

5.5 IR内部无控制流

def loop_shapes(x, itr):
  # x is an input tensor of size [1, N]
  for _ in range(itr):
    x = torch.cat((x, x), dim=0)

  # Depending on the number of loop iterations, x may have an
  # arbitrary leading dimension i.e. x \in [*dynamic*, N]
  return x

IR本身不包含控制流,并不妨碍变换在更大模型中的基本块子图上工作;如何组合这些子图的细节留给变换的编写者或用户来决定。

5.6 函数式图与有状态模块

A4 实验环境

A5 实验结果

6.1 IR 复杂度

6.2 性能优化

6.3 程序分析

torch.fx已被应用于多种程序分析场景:

6.4 设备和运行时导出/编译

A6 结论

本文介绍了torch.fx,一个纯Python的系统,用于捕获和转换PyTorch程序。通过分析相关系统(如控制流、可变性、数据模型)的复杂性来源,本文展示了torch.fx如何通过专注于常见用例和提供可定制性来避免这些复杂性。通过对优化、分析和设备下沉等多个用例的研究,本文证明了torch.fx的API设计如何成功地实现了这些功能。

A7 附录

A. TORCH.FX 节点语义

A.1 操作码(Opcode)含义

下表描述了torch.fx中每个Nodeopcode的含义。

A.2 args/kwargs 行为

下表描述了不同opcodeargskwargs字段的预期行为。

B 量化评估数值数据

下表为第6.2.1节量化实验的详细运行时间数据(单位:秒)。

C 融合评估数值数据

下表为第6.2.2节融合实验的详细运行时间数据(单位:秒)。

D TensorRT评估数值数据

下表为第6.4节TensorRT实验的详细运行时间数据(单位:秒)。