专栏名称: 极市平台
极市平台是由深圳极视角推出的专业的视觉算法开发与分发平台,为视觉开发者提供多领域实景训练数据库等开发工具和规模化销售渠道。本公众号将会分享视觉相关的技术资讯,行业动态,在线分享信息,线下活动等。 网站: http://cvmart.net/
目录
相关文章推荐
北美留学生观察  ·  “哪吒闹海”的马斯克距离被暗杀,还有多久? ·  昨天  
北美留学生观察  ·  暴力、药物、奴役:泰国“娱乐业”盯上日本风俗 ... ·  2 天前  
北美留学生观察  ·  输了还打人?中国香港冰球队获胜后遭对方“围殴 ... ·  2 天前  
北美留学生观察  ·  又双叒叕是韩国队!伸手犯规,完事儿还笑?中国 ... ·  3 天前  
51好读  ›  专栏  ›  极市平台

一文搞懂 TorchDynamo 原理

极市平台  · 公众号  ·  · 2024-08-12 22:00

正文

↑ 点击 蓝字 关注极市平台
作者丨Fei kong
来源丨https://fkong.tech/posts/2023-05-20-dynamo/
编辑丨极市平台

极市导读

本文详细介绍了TorchDynamo的工作原理和使用方法,它是PyTorch 2.0中用于捕获计算图的组件之一。文章通过简单示例展示了TorchDynamo的基本用法,解释了其在捕获Python字节码时的灵活性和可靠性,并与TorchScript和TorchFX进行了比较。 >> 加入极市CV技术交流群,走在计算机视觉的最前沿

简介

PyTorch 2.0 的使命是更快、更 Pythonic 以及一如既往地支持动态特性。为了达到这个目的,PyTorch 2.0 引入了 torch.compile ,在解决 PyTorch 固有的性能问题的同时,把部分用 C++ 实现的东西引入 Python 中。PyTorch 2.0 利用了 4 个组件: TorchDynamo,AOTAutograd,PrimTorch 和 TorchInductor。本文以几个简单的案例讲解 TorchDynamo 的使用方法和实现原理。

PyTorch 2.0

TorchDynamo 的作用是从 PyTorch 应用中抓取计算图 ,相比于 TorchScript 和 TorchFX,TorchDynamo 更加灵活、可靠性更高。用过 TorchScript 的朋友知道,通过 jit.trace 或者 jit.script 把模型转化为 TorchScript 的过程困难重重,往往需要修改大量源代码。而 TorchFX 在捕获计算图时,遇到不支持的算子会直接报错,最常见的就是 if 语句。TorchDynamo 克服了 TorchScript 和 TorchFX 的缺点,使用起来极为方便,用户体验相比于 TorchScript 和 TorchFX 大幅提升。配合 TorchInductor 等后端编译器,经 TorchDynamo 捕获的计算图只需要几行代码的改动就可以观测到不错的性能提升。

用法

使用 TorchDynamo 的方法非常简单,可以通过 torch.compile() 或者 torch._dynamo.optimize() ,其中可以指定 backend 'inductor' 'eager' ,或者以用户自定义的 Python 函数作为 graph compiler。在下面的代码片段中,我们以自定义的 Python 函数 my_compiler 作为编译器:

from typing import List  
import torch  
  
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):  
    print(">>> my_compiler() invoked:")  
    print(">>> FX graph:")  
    gm.graph.print_tabular()  
    print(f">>> Code:\n{gm.code}")  
    return gm.forward  # return a python callable  
  
@torch.compile(backend=my_compiler)  
def foo(x, y):  
    return (x + y) * x  
  
if __name__ == "__main__":  
    a, b = torch.randn(10), torch.ones(10)  
    foo(a, b)  

执行上面的代码,可以看到 TorchDynamo 把从函数 foo() 中捕获到一张计算图,TorchDynamo 以 FX Graph 保存捕获到的计算图:

>>> FX graph:  
opcode         name    target                   args       kwargs  
-------------  ------  -----------------------  ---------  --------  
placeholder    x       x                        ()         {}  
placeholder    y       y                        ()         {}  
call_function  add     function add>  (x, y)     {}  
call_function  mul     function mul>  (add, x)   {}  
output         output  output                   ((mul,),)  {}  
  
>>> Code:  
def forward(self, x : torch.Tensor, y : torch.Tensor):  
    add = x + y;  y = None  
    mul = add * x;  add = x = None  
    return (mul,)  

Python 字节码

TorchDynamo 捕获计算图是在翻译 Python 字节码的过程中实现的 。Python 函数在执行前会被 Python 虚拟机编译为字节码 (bytecode),每一个 Python 函数的实例都对应一个 frame,其中保存着运行该函数所需要的全局变量、局部变量、字节码等等。为了便于理解 Python 虚拟机、字节码和 TorchDynamo 的行为,下面用 hello() 函数简要介绍下 Python 字节码的行为。我们可以用 dis 包查看 Python 函数的字节码:

import dis  
  
def hello():  
    print("Hello, world!")  
  
for k in ["co_names""co_varnames""co_consts"]:  
    print(k, getattr(hello.__code__, k))  
print(dis.dis(hello))  

执行上面的代码,我们得到下面的结果:

co_names ('print',)  
co_varnames ()  
co_consts (None, 'Hello, world!')  
  
0 LOAD_GLOBAL              0 (print)  
2 LOAD_CONST               1 ('Hello, world!')  
4 CALL_FUNCTION            1  
6 POP_TOP  
8 LOAD_CONST               0 (None)  
10 RETURN_VALUE  

其中包含了 6 条 Python 字节码,它们的功能如下:

  • LOAD_GLOBAL 0 : 从 f_builtins f_globals 中加载由下标 0 所引用的全局对象,把它压到数据栈上;
  • LOAD_CONST 1 : 从 co_consts 中加载由下标 1 所引用的常量,把它压到数据栈上;
  • CALL_FUNCTION 1 : 从栈顶出栈 1 个元素作为函数参数,再出栈一个元素作为被调函数,调用该函数并把返回值压到数据栈上;
  • POP_TOP : 从栈顶移除一个元素;
  • LOAD_CONST 0 : 从 co_consts 中加载由下标 0 所引用的常量,把它压到数据栈上;
  • RETURN_VALUE : 从栈顶出栈 1 个元素,把它作为返回值返回给主调函数;

Python 虚拟机是 Stack Machine ,它维护了 3 个 stack:

  • Call Stack : 其中的条目是 Python frame,类似 C 的函数调用栈;
  • Evaluation Stack (or Data Stack) : 每个 Python frame 都有一个 evaluation stack,执行 Python 字节码时的数据由该 stack 管理,这与常见的 Register Machine 有所区别;
  • Block Stack : 每个 Python frame 都有一个 block stack,目的是跟踪 Python 中的控制结构,例如循环、 try / except with 语句等,进入/退出这类控制结构时会有对应的条目被 push/pop。Block stack 帮助 Python 在任意时刻都知道当前活跃的 block, continue break 会影响当前活跃的 block;

更多 Python 字节码和虚拟机的细节可以参考 _PyEval_EvalFrameDefault。

实现原理

TorchDynamo 的 编译过程发生在将要执行前 ,它是一个 JIT 编译器。在 Python 将要执行函数时,TorchDynamo 开始翻译字节码并捕获计算图。在 Python 虚拟机 (PVM) 中有一个非常重要的函数 _PyEval_EvalFrameDefault,它的功能是在 PVM 中逐条执行编译好的字节码。TorchDynamo 的入口是 PEP-523 提供的 CPython Frame Evaluation API,它可以让用户通过 回调函数(callback function) 获取字节码,并把修改过后的字节码返回给解释器执行,或者执行预先编译好的目标代码,从而可以在 Python 中实现 即时编译器 (JIT Compiler) 的功能。TorchDynamo 正是通过 PEP-523 把 TorchDynamo 的核心逻辑引入到 Python 虚拟机中,从而在函数将要运行前获取字节码。下图展示了 TorchDynamo 的核心原理:

TorchDynamo

TorchDynamo 实现了一个 Python 虚拟机的模拟器,在模拟 Python 字节码执行的过程中构建出对应的计算图 。仍以 foo() 为例:

@torch.compile(backend=my_compiler)  
def foo(x, y):  
    return (x + y) * x  

foo() 对应的字节码如下,TorchDynamo 在翻译字节码 BINARY_ADD BINARY_MULTIPLY 时在 FX Graph 中建立了 operator.add operator.mul 两个 FX Node,最后形成一张完整的计算图:

 0 LOAD_FAST                0 (x)  
 2 LOAD_FAST                1 (y)  
 4 BINARY_ADD  
 6 LOAD_FAST                0 (x)  
 8 BINARY_MULTIPLY  
10 RETURN_VALUE  

为了检验 TorchDynamo 捕获的计算图在下次执行时还是否有效,TorchDynamo 会为被编译的函数创建 Guard 。从 Guard 生成的 Python 可执行函数 check_fn ,在 TorchDynamo 中 负责检测被编译函数的输入属性是否发生变化 ,如果没有发生变化则可以重用此前编译好的函数,否则当前输入对此前编译好的函数无效,需要 重新编译 (graph recompilation) 该函数。 TENSOR_MATCH 是检测张量信息的 Guard ,在默认情况下,主要负责检查输入的张量 device、shape、stride 等属性是否改变。 foo() 函数对应的 check_fn 如下,它会调用 C++ 函数检查张量 x y 的信息是否发生变化,进而决定是否能重用此前编译好的函数:

GUARDS ___guarded_code.valid and ___check_tensors(x, y)  

经 TorchDynamo 编译好的函数被保存在 frame 的 cache 中 ,从而避免再次编译相同的函数和输入。默认情况下 cache 大小为 64,也就是说,对于同一个 Python 函数,它的输入最多可以有 64 种变化,超过这个限制后 TorchDynamo 不再编译该函数。

Graph Break

TorchDynamo 并不能把所有的函数都捕获到一张计算图中。 TorchDynamo 碰到无法支持的算子时会创建 graph break,把计算图切分成它可以支持的几张子图,并返回 Python 解释器执行它无法处理的算子 。最常见的导致 graph break 的案例是用张量的值作为 if 语句的条件,以下面的函数为例:

def toy_example(a, b):  
    x = a / (torch.abs(a) + 1)  
    if b.sum()         b = b * -1  
    return x * b  

TorchDynamo 会把 toy_example() 拆分为 3 张子图,不能处理的 if 语句由 Python 解释器执行。编译后对应的 Python 函数如下,执行完编译好的子图 __compiled_fn_0() 后,程序返回到 Python 解释器,根据 if 语句的结果选择执行还未编译的子图 __resume_at_30_1() __resume_at_38_2() :

def compiled_toy_example(a, b):  
    x, lt = __compiled_fn_0(a, b)  
    if lt:  
        return __resume_at_30_1(b, x)  
    else:  
        return __resume_at_38_2(b, x)  

其中包含了 3 个函数:

  • __compiled_fn_0() : TorchDynamo 编译好的子图,对应 if 语句前面的部分:
def __compiled_fn_0(a, b):  
    x = a / (torch.abs(a) + 1)  
    return b.sum() 
  • __resume_at_30_1() : TorchDynamo 未编译的子图,对应 if 分支 (TorchDynamo 直接操纵字节码,为了方便解释这里用了 Python 伪代码,并假设 Python 中支持 goto 和 label):
# pseudo python code with goto and label  
def __resume_at_30_1(b, x):  
    goto if_next  
    x = a / (torch.abs(a) + 1)  
    if b.sum()         label if_next  
        b = b * -1  
    return x * b  

该函数会在首次执行时被 TorchDynamo 捕获并编译。

  • __resume_at_38_2() : TorchDynamo 未编译的子图,对应 else 分支,该函数也会在首次执行时被 TorchDynamo 捕获并编译:
# pseudo python code with goto and label  
def __resume_at_38_2(b, x):  
    goto if_jump  
    x = a / (torch.abs(a) + 1)  
    if b.sum()         b = b * -1  
    label if_jump  
    return x * b  

Dynamic Shape

默认情况下 TorchDynamo 为 static shape 模式,捕获计算图时张量的 shape stride 被特化并记录在 Guard 中。捕获计算图结束时会生成 Guard 对应的 check_fn ,用于 检查该计算图中的输入信息有没有发生变化 。如果没有发生变化则重用已经编译好的计算图,否则重新捕获并编译计算图 (graph recompilation)。当设置环境变量 TORCHDYNAMO_DYNAMIC_SHAPES 为 1 时,此时 TorchDynamo 以 dynamic shape 模式捕获计算图,张量的 shape stride 不会被特化、不会被记录在 Guard 中,生成的 check_fn 也不检查 shape stride 。因此,以不同 shape stride 的张量执行编译好的计算图时,不会重新捕获计算图和重新编译。下面的代码片段中, test() 调用了两次 toy_example() ,两次不同的调用之间 tensor 的 shape 不同,所以会触发重新编译:

@torch.compile(backend=my_compiler)  
def toy_example(x):  
    x = x / (torch.abs(x) + 1)  
    return x  
  
def test():  
    x = torch.randn(10)  
    toy_example(x)  
    x = torch.randn(20)  
    toy_example(x)  

使用 torch.compile() 编译 toy_example() 并运行,可以看到这里触发了两次 toy_example() 的编译。这是因为第二次调用 toy_example() 时,张量 x 没能通过 Guard 检查。相关函数调用栈:

  • [C011] > torch/csrc/dynamo/guards.cpp#L49 [New]
  • [C010] > torch/csrc/dynamo/guards.cpp#L207 [New]
  • [P009] >#L2:
  • [C008] > torch/csrc/dynamo/eval_frame.c#L355
  • [C007] > torch/csrc/dynamo/eval_frame.c#L621
  • [C006] > torch/csrc/dynamo/eval_frame.c#L346
  • [C005] > torch/csrc/dynamo/eval_frame.c#L505
  • [C004] > torch/csrc/dynamo/eval_frame.c#L640
  • [C003] > torch/csrc/dynamo/eval_frame.c#L621
  • [C002] > torch/csrc/dynamo/eval_frame.c#L346
  • [P001] > torch/_dynamo/eval_frame.py#L233
  • [P000] > test.py#L18:test [New]

循环展开

TorchDynamo 把 Python 中的循环捕获为循环展开的计算图,即捕获的计算图中不再包含循环 。例如下面的代码片段,其中的 for 循环迭代了 4 次、每次执行一次乘法操作:

@torch.compile  
def toy_example(x, n):  
    for i in range(1, n+1):  
        x = x * i  
    return x  
  
def test():  
    x = torch.randn(10)  
    toy_example(x, 4)  

捕获到的计算图对应的 Python 函数为:

def forward(self, x : torch.Tensor):  
    mul = x * 1;  x = None  
    mul_1 = mul * 2;  mul = None  
    mul_2 = mul_1 * 3;  mul_1 = None  
    mul_3 = mul_2 * 4;  mul_2 = None  
    return (mul_3,)  

这个过程的原理是 TorchDynamo 在它的 Python 虚拟机模拟器中模拟运行了 FOR_ITER 这条字节码指令,然后捕获在每次迭代中出现的运算,而不是把 for 循环本身捕获到计算图中。这个过程的函数调用栈如下:

  • [P053] > torch/_dynamo/symbolic_convert.py#L911 [New]
  • [P049] > torch/_dynamo/symbolic_convert.py#L537
  • [P045] > torch/_dynamo/symbolic_convert.py#L590 [New]
  • [P041] > torch/_dynamo/symbolic_convert.py#L1838 [New]
  • [P037] > torch/_dynamo/convert_frame.py#L298 [New]
  • [P033] > torch/_dynamo/bytecode_transformation.py#L488 [New]
  • [P029] > torch/_dynamo/convert_frame.py#L279 [New]
  • [P025] > torch/_dynamo/utils.py#L158 [New]
  • [P021] > torch/_dynamo/convert_frame.py#L200 [New]
  • [P017] > torch/_dynamo/convert_frame.py#L96 [New]
  • [P013] > torch/_dynamo/convert_frame.py#L403 [New]
  • [P009] > torch/_dynamo/eval_frame.py#L362
  • [C008] > torch/csrc/dynamo/eval_frame.c#L355
  • [C007] > torch/csrc/dynamo/eval_frame.c#L621
  • [C006] > torch/csrc/dynamo/eval_frame.c#L346
  • [C005] > torch/csrc/dynamo/eval_frame.c#L399
  • [C004] > torch/csrc/dynamo/eval_frame.c#L640
  • [C003] > torch/csrc/dynamo/eval_frame.c#L621
  • [C002] > torch/csrc/dynamo/eval_frame.c#L346
  • [P001] > torch/_dynamo/eval_frame.py#L233 [New]
  • [P000] > test.py#L19:test [New]

内联函数

针对用户函数调用,TorchDynamo 会尝试内联 (inline) 被调函数,从而生成更大的计算图。但如果被掉函数中存在 graph break,那么内联就会失败,此时函数调用栈中的每个函数都会产生一个 graph break。 下面的代码片段中 test() 调用了递归函数 toy_example() :

@torch.compile  
def toy_example(x, n):  
    if n > 0:  
        return toy_example(x, n-1) * n  
    else:  
        return x  
  
def test():  
    x = torch.randn(10)  
    toy_example(x, 4)  

TorchDynamo 在捕获 toy_example(x, 4) 的计算图时,会尝试内联 toy_example(x, 3) 的计算图,依次类推,直到成功内联 toy_example(x, 0) 的计算图。最终生成一个大的计算图,其中的函数调用被展开:

def forward(self, x : torch.Tensor):  
    mul = x * 1;  x = None  
    mul_1 = mul * 2;  mul = None  
    mul_2 = mul_1 * 3;  mul_1 = None  
    mul_3 = mul_2 * 4;  mul_2 = None  
    return (mul_3,)  

但在下面的代码片段中,用户函数 baz() 无法被 TorchDynamo 内联,因为其中的 if 条件依赖于张量的值,只有在运行时才能确定执行哪个分支,故而存在一个 graph break。这个 graph break 导致其调用者 bar() foo 都产生了 graph break,最后总共生成 7 个计算图, baz() 中包含 3 个:

def baz(x):  
    return -x if x > 0 else x - 1  
  
def bar(x):  
    return x * baz(x - 1)  
  
@torch.compile  
def foo(x):  
    return x * bar(2 * x)  
  
def test():  
    x = torch.tensor([4])  
    foo(x)  

TorchDynamo 通过字节码指令 CALL_FUNCTION 实现内联函数,其中识别用户函数调用并尝试内联,内联失败时恢复主调函数的状态并创建 graph break,子图编译完后返回解释器执行子函数调用。这个过程通过 InliningInstructionTranslator 实现,它不支持子图编译,函数调用栈如下:

  • [P034] > torch/_dynamo/exc.py#L69 [New]
  • [P033] > torch/_dynamo/symbolic_convert.py#L234 [New]
  • [P032] > torch/_dynamo/symbolic_convert.py#L537
  • [P031] > torch/_dynamo/symbolic_convert.py#L590
  • [P030] > torch/_dynamo/symbolic_convert.py#L1956
  • [P029] > torch/_dynamo/symbolic_convert.py#L1930
  • [P028] > torch/_dynamo/symbolic_convert.py#L524
  • [P027] > torch/_dynamo/variables/functions.py#L90
  • [P026] > torch/_dynamo/variables/functions.py#L251
  • [P025] > torch/_dynamo/symbolic_convert.py#L469
  • [P024] > torch/_dynamo/symbolic_convert.py#L988
  • [P023] > torch/_dynamo/symbolic_convert.py#L341
  • [P022] > torch/_dynamo/symbolic_convert.py#L537
  • [P021] > torch/_dynamo/symbolic_convert.py#L590
  • [P020] > torch/_dynamo/symbolic_convert.py#L1956 [New]
  • [P019] > torch/_dynamo/symbolic_convert.py#L1930 [New]
  • [P018] > torch/_dynamo/symbolic_convert.py#L524 [New]
  • [P017] > torch/_dynamo/variables/functions.py#L90 [New]
  • [P016] > torch/_dynamo/variables/functions.py#L251 [New]
  • [P015] > torch/_dynamo/symbolic_convert.py#L469 [New]






请到「今天看啥」查看全文