文章标题:TORCH.FX:Python中深度学习的实用程序捕获与转换
作者/机构:James K Reed, Zachary DeVito, Horace He, Ansley Ussery, Jason Ansel (均为 Facebook AI)
本文旨在解决现代深度学习框架中的一个核心矛盾:即时执行(eager execution)模式虽然提升了开发效率和用户体验,但却牺牲了对程序结构的访问能力,而这种能力对于性能优化、可视化、分析和硬件集成等高级转换至关重要。为了在即时模式框架(如PyTorch)中重新获得这种能力,需要一种程序捕获机制。然而,现有的系统(如TorchScript)为了完全忠实地模拟Python的复杂语义(包括可变状态、控制流、复杂数据类型),其捕获技术和生成的中间表示(IR)都变得异常复杂,给转换(transform)的编写带来了巨大困难。
本文提出,可以通过专注于深度学习的典型用例(大多数神经网络模型的高层有向无环图(DAG)结构)而非长尾的复杂情况,来设计一个更简单、更高效的程序捕获与转换框架。基于这一理念,本文介绍了torch.fx,一个完全用Python编写的、为PyTorch设计的程序捕获与转换库,其核心目标是为机器学习从业者提供极高的开发生产力。
本文的主要贡献如下:
torch.fx在实践中如何被用于性能优化、程序分析、设备适配(device lowering)等场景,实现了PyTorch生态系统中以前难以完成的工作流。程序捕获、特化与IR设计的权衡。无论是即时模式还是图模式框架,在捕获和转换程序时都必须在程序结构的捕获、程序的特化(specialization)以及中间表示(IR)的设计之间做出选择。这些选择共同决定了框架能表示的程序范围、编写转换的难易程度以及转换后程序的性能。通常,为了支持更多程序并实现高性能,需要更复杂的捕获框架和IR,这反过来又使得转换的编写更加困难。
jit.trace (DeVito et al.【TorchScript, 2018】) 采用了此方法。一种稍微复杂的方式是符号追踪,即使用抽象值而非具体示例输入进行追踪。MXNet的Gluon (Chen et al.【Mxnet: A flexible and efficient machine learning library for heterogeneous distributed systems, 2015】) 和TensorFlow的tf.function (Moldovan et al.【AutoGraph: Imperative-style Coding with Graphbased Performance, 2018】) 实现了这种方法。符号追踪的优点是不需要用户提供示例输入,并且能暴露输入值依赖的Python控制流位置。追踪通常只记录张量和少数其他数据结构(如张量列表)上的操作,因此其可见性有限,但对于深度学习计算(通常是扁平的张量操作序列,即基本块程序)来说已经足够。tf.function通过一个轻量级模块化分阶段(Lightweight Modular Staging)系统 (Rompf & Odersky【Lightweight modular staging: a pragmatic approach to runtime code generation and compiled DSLs, 2010】) 来增强符号追踪,该系统使用Python AST变换将命令式控制流结构转换为可被追踪的高阶Python函数。a + b)非常抽象。ML框架在捕获程序时,通常会对其进行特化,使其仅对特定类型或张量形状有效。特化程度越高,适用输入范围越窄。不同方法在特化程度、时机(预先/即时)和安全性上有所不同。torch.jit.trace (DeVito et al.【TorchScript, 2018】) 会根据示例输入的形状进行特化。这种捕获方式是非侵入式的,但可能导致追踪到的表示形式是形状特化的,即仅对追踪时使用的值形状有效,对其他形状可能失败。jit组合子 (Frostig et al.【Compiling machine learning programs via high-level tracing, 2018】) 通过要求输入是纯函数来改进JIT特化,这强制了非张量计算(如形状表达式)的引用透明性,从而可以根据输入形状判断是否需要重新捕获。JIT特化的缺点是代码执行更难推理(例如print或pdb语句只在重新追踪时执行),且重新追踪和转换可能导致难以预测的性能抖动。x[i] = y会修改x。由于PyTorch支持这些别名和突变语义,程序修改必须依赖于安全性分析(如别名分析 (Andersen【Program analysis and specialization for the C programming language, 1994】))。TorchScript实现了别名分析,但代价高昂:所有操作都需标注别名和突变行为,保守的假设会阻碍优化。而JAX的函数式方法将状态管理的负担转移到框架之外,模型必须是纯函数,参数作为输入传递。这使得需要同时修改状态和代码的转换(如批量归一化折叠)变得更加复杂。现有框架的设计大多倾向于支持更广泛的深度学习程序,但牺牲了实现的简洁性。当捕获的程序是运行的唯一方式时,高保真度至关重要。但PyTorch主要作为即时执行框架使用,程序捕获仅用于特定转换,无需对整个程序都有效。此外,目标用户是机器学习从业者,他们更习惯使用Python而非编译器设计。
通过为典型的深度学习模型而非长尾用例进行设计,可以创建一个更易于使用和实现的框架。torch.fx的设计原则体现了这一理念:
torch.fx采用符号追踪来捕获程序,使用一个简单的包含6个指令且基于Python的IR来表示它们,并从IR重新生成Python代码来执行。为避免JIT特化带来的重捕获复杂性,torch.fx本身不尝试特化程序,而是依赖于转换过程来决定需要执行何种特化。符号追踪过程是可配置的,用户可以定制以处理更特殊的用例。
torch.fx捕获代码的示例。symbolic_trace函数接收一个函数或torch.nn.Module,并将其结构捕获到一个Graph对象中。该Graph对象与模块参数结合成一个GraphModule,这是一个torch.nn.Module的子类,其forward方法运行被捕获的图。打印图的节点可以看到捕获的IR:placeholder节点代表输入,output节点代表结果,call_function节点直接引用要调用的Python函数,call_method节点调用其第一个参数的方法。图2展示了一个简单的变换示例,该变换将代码中所有的relu激活函数替换为gelu。from torch.fx import Graph
def replace_activation(g: Graph, old, new):
for n in g.nodes:
if n.op == 'call_function' and n.target == old:
# create IR to call new activate
with g.inserting_after(n):
new_n = g.call_function(new, n.args)
n.replace_all_uses_with(new_n)
g.erase_node(n)
# or for this simplified case: 'n.target = new'
replace_activation(traced.graph, torch.relu, torch.nn.functional.gelu)
traced.recompile()
torch.fx的符号追踪机制使用一个Proxy数据结构来记录流经程序的值上的操作。Proxy是一个鸭子类型(duck-typed)的Python类,它记录对其的属性访问和方法调用,充当具体程序值的抽象替代品。Proxy利用__torch_function__协议 (Abbasi et al.【Improving subclassing Tensor by propagating subclass instances, 2020】) 来拦截并记录PyTorch算子(它们是自由函数)的派发。此外,torch.fx重写了PyTorch的Module抽象,以记录对使用代理值(proxied values)的Module的调用。整个符号追踪过程可以通过一个Tracer类进行配置,用户可以重写其方法来控制哪些值应保持为Proxy,哪些值在追踪期间应被部分求值。torch.fx在一个基于DAG的IR中表示程序,这适用于深度学习中常见的基本块程序。程序表示为一个Graph对象,其中包含一系列线性的Node对象,每个Node代表一个操作。Node具有以下属性:opcode (字符串):描述节点代表的操作类型(具体语义见附录A.1)。target:对于调用节点(call_module、call_function、call_method),这是调用的目标。args 和 kwargs:共同表示在追踪期间观察到的Python调用约定中的参数(各操作码的具体语义见附录A.2)。args和kwargs中对其他节点的引用来表示。为简化IR,torch.fx的IR没有用于建模数据结构构造或突变的原始操作。然而,args和kwargs支持立即值(immediate values):Python内置类型(如int、float)和递归集合类型(如tuple、list)可以作为节点参数出现,而无需单独的对象构造节点。这使得IR非常干净,节点与张量操作近似一一对应。torch.fx将程序的状态存储在GraphModule类中。GraphModule是转换后程序的容器,它暴露了转换后生成的代码,并提供了nn.Module中熟悉的参数管理API。GraphModule可以像普通的nn.Module一样在任何地方使用,确保了转换后的代码与PyTorch生态系统的其他部分具有互操作性。torch.fx的IR提供了两个操作码来访问模块层次结构中的状态:call_module(调用子模块的forward方法)和get_attr(从模块中获取参数)。这在可变参数和与之交互的函数式Graph之间提供了自然的分离,同时将它们保留在单个对象中以便于同时对两者进行转换。torch.fx转换流程的最后阶段是代码生成。torch.fx不退出Python生态系统进入一个定制的运行时,而是从转换后的IR生成有效的Python源代码。这些转换后的代码随后被加载到Python中,生成一个可调用的Python对象,并被安装为GraphModule实例的forward方法。使用代码生成允许torch.fx转换的结果被安装在模型中,并用于进一步的转换。如图3所示,可以将一个程序的追踪结果安装为一个新模块的激活函数,然后对结果进行符号追踪以进行进一步的转换。torch.fx融合并扩展了先前工作中的方法,提供了一个易于使用、实现简单且可配置的库。
torch.fx使用带有Proxy对象的符号追踪,而不是嵌入式语言技术,因为前者使用Python灵活的对象模型更容易直接在Python中实现。其实现足够简单,用户在追踪行为异常时可以阅读和单步调试源代码。此外,追踪有助于消除模型中不依赖于输入的控制流,例如torch.nn.Sequential中对顺序模块的循环,这对于穿透各种抽象以获取实际运行的算子至关重要。符号追踪对常见模型效果很好,代价是无法捕获那些真正包含输入依赖控制流的长尾模型,但这一限制通过使追踪过程可定制来弥补。torch.fx的符号追踪是可定制的。Tracer类控制着fx.symbolic_trace的行为,其方法可以被重写以改变追踪过程。is_leaf_module方法可以被重写,以指定哪些PyTorch Module实例在追踪期间应被视为不透明调用。默认情况下,torch.fx会保留PyTorch内置模块(如nn.Conv2d)的完整性,同时追踪用户自定义模块,以创建由标准、可理解的原语组成的轨迹。定制此行为可以屏蔽模型中包含不支持语言特性的部分,或修改用于转换的表示级别。create_proxy方法可以被重写,以自定义在图中创建节点及关联的运行时Proxy值的行为。例如,这可以用于在节点上安装自定义元数据以用于转换,或支持将自定义数据结构作为可追踪值。shape和ndim)在符号追踪期间作为Proxy值返回,对这些值的操作可以被记录下来。当这些Proxy对象被用于不可追踪的操作(如转换为Python内置类型int或bool)时,用户会收到一个错误消息和指向问题位置的堆栈跟踪。torch.fx的IR完全在Python中表示和实现,而不是使用如Protocol Buffers之类的跨语言格式。用户可以轻松地调用、阅读或重写它,无需理解C++或Protocol Buffers。变换也用Python编写。此外,变换的结果也是Python代码,这使得它易于检查正确性、用pdb调试、提供给其他库,并传递给进一步的变换。转换后的代码被封装在一个GraphModule中,可以像其他nn.Module一样在PyTorch中使用,例如,用户可以将其用TorchScript编译以进行部署,或在PyTorch的DistributedDataParallel库中使用。这种方式将torch.fx进一步整合到Python生态系统中,而不是将转换后的代码隔离到一个定制的、更难使用的运行时中。def loop_shapes(x, itr):
# x is an input tensor of size [1, N]
for _ in range(itr):
x = torch.cat((x, x), dim=0)
# Depending on the number of loop iterations, x may have an
# arbitrary leading dimension i.e. x \in [*dynamic*, N]
return x
IR本身不包含控制流,并不妨碍变换在更大模型中的基本块子图上工作;如何组合这些子图的细节留给变换的编写者或用户来决定。
torch.fx省略了此类分析,而是将可变操作定义为未定义行为,并可在追踪期间捕获时选择性地引发错误。在IR中避免可变性极大地简化了深度学习程序的分析和转换。大多数模型不受此限制,因为大多数可变性都局限于模型的参数。torch.fx仍然保留了PyTorch的分层nn.Module结构,并可以表示从此结构中的模块调用和属性获取。像torch.nn.Conv2d这样的模块对用户来说是易于理解的,有详细记录的参数,并且将参数的有状态使用隐藏在模块内部,因此保留这些对象使编写转换更容易。torch.fx、torch.jit.script和torch.jit.trace为 canonical ResNet50 模型生成的IR的复杂性(以操作数量衡量)。torch.fx IR包含445个操作,而torch.jit.trace为860个,torch.jit.script为2614个。torch.fx通过追踪并展开与输入无关的控制流,并将简单的常量和数据结构内联为节点参数,显著简化了典型模型的IR,使其比torch.jit.trace小近一半,比torch.jit.script小一个数量级。简化的IR降低了编写和维护程序转换的复杂性。torch.fx的量化带来了高达3.3倍的运行时性能提升,且性能方差很小,显示了预先转换带来的可预测性。torch.fx不仅提供了预期的性能提升,其开发效率也比基于TorchScript的实现高出一个数量级。这得益于其简化的表示、Python API以及与原生PyTorch生态的融合。torch.fx的预先、基于图的特性为这种非局部程序转换提供了必要的上下文和状态修改能力。整个转换和测试工具仅用不到150行Python代码实现,展示了其API在实现简洁、快速开发的程序转换方面的强大能力。torch.fx将阻塞调用替换为非阻塞调用和单独的等待调用,并将非阻塞调用尽可能提前。torch.fx可用于实现复杂的程序调度优化,以重叠网络调用和本地计算。torch.fx已被应用于多种程序分析场景:
fx.passes.shape_prop包提供了一个通过解释图来记录形状的朴素实现。其他更高级的形状传播系统(如通过符号表达式或渐进类型语义)也正在开发中。fx.graph_drawer包使用户能够使用Graphviz可视化torch.fx图,提供了一种直观理解深度学习程序DAG结构的方式。torch.fx-to-TensorRT转换系统,将PyTorch ResNet50模型和LearningToPaint模型下沉(lower)到NVIDIA TensorRT,并在V100 GPU上进行评估。torch.fx为编译器栈(如TensorRT)与PyTorch的集成提供了一条高效路径。该项目的开发效率很高,利用torch.fx的Python API可以快速构建转换层、模型自动分割等功能,最终为用户提供了易于使用、检查和调试的API。本文介绍了torch.fx,一个纯Python的系统,用于捕获和转换PyTorch程序。通过分析相关系统(如控制流、可变性、数据模型)的复杂性来源,本文展示了torch.fx如何通过专注于常见用例和提供可定制性来避免这些复杂性。通过对优化、分析和设备下沉等多个用例的研究,本文证明了torch.fx的API设计如何成功地实现了这些功能。
下表描述了torch.fx中每个Node的opcode的含义。
args/kwargs 行为下表描述了不同opcode下args和kwargs字段的预期行为。
下表为第6.2.1节量化实验的详细运行时间数据(单位:秒)。
下表为第6.2.2节融合实验的详细运行时间数据(单位:秒)。
下表为第6.4节TensorRT实验的详细运行时间数据(单位:秒)。