1. 前言
这是一个我完全不熟悉的领域,因为工作中遇到了一些和 torch compile 相关的问题,但是脑子里完全没有概念,于是我决定打开这个 black box 看一眼这究竟是怎么回事。希望我可以边写 blog 边学,写完之后我可以建立对 Pytorch Compile 的基本认知,并且在遇到问题时可以有办法去看代码,debug 和定位问题。
2. 什么是 PyTorch Compile
PyTorch Compile 是 Python 2.0 引入的,用于优化 PyTorch Model 的运行速度。大致的原理:原本 PyTorch 是以 eager 的模式执行的,每个 operation 会单独按顺序执行。而 PyTorch Compile 可以捕获模型的 computation graph,对其进行优化(例如 Kernel Fusion),然后编译成优化的 kernel,在 Runtime 直接调用,提升运行速度。 一个简单的调用方式:
compiled_model = torch.compile(model)
3. 实测
import torch
import time
# Define a simple model
class SimpleModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.layers = torch.nn.Sequential(
torch.nn.Linear(1024, 1024),
torch.nn.ReLU(),
torch.nn.Linear(1024, 1024),
torch.nn.ReLU(),
torch.nn.Linear(1024, 10),
)
def forward(self, x):
return self.layers(x)
# Create model and input data
model = SimpleModel().cuda()
x = torch.randn(1024, 1024, device="cuda")
# Benchmark standard model
def benchmark_standard():
torch.cuda.synchronize()
start_time = time.time()
for _ in range(100):
model(x)
torch.cuda.synchronize()
return time.time() - start_time
# Benchmark compiled model
compiled_model = torch.compile(model)
def benchmark_compiled():
# First run to compile
compiled_model(x)
torch.cuda.synchronize()
start_time = time.time()
for _ in range(100):
compiled_model(x)
torch.cuda.synchronize()
return time.time() - start_time
# Run benchmarks
standard_time = benchmark_standard()
compiled_time = benchmark_compiled()
print(f"Standard model: {standard_time:.4f} seconds")
print(f"Compiled model: {compiled_time:.4f} seconds")
print(f"Speedup: {standard_time/compiled_time:.2f}x")
用一张 A100 的卡实测一下,可以看到用法还是非常简单的, 核心就一行代码: torch.compile(model) 。
# result
Standard model: 0.0356 seconds
Compiled model: 0.0315 seconds
Speedup: 1.13x
结果上来看使用它确实提升了 performance。Note:我测试的时候我一开始怎么跑都是 standard model 速度更快,后来去研究了一下才发现需要使用 torch.cuda.synchronize() 来得到更准确的结果 Reference。
4. 尝试理解它是怎么工作的
4.1 运作机制
- Graph Capture:TorchDynamo 会从 pytorch 代码捕获计算图 (FX graph)。
- Graph Lowering: TorchInductor 把计算图转化为 IR, 在这个过程中,TorchInductor 会做一些优化,例如算子融合等。
- Graph Compilation:如果代码 target 是 gpu,Triton 会把优化过的 IR 图转化成 GPU kernels。如果 target 是 cpu,TorchInductor 会使用 Existing ATen Kernels 或者 使用 LLVM 生成 CPU kernels。
延伸阅读材料:
4.2 什么时候该用 PyTorch Compile?
当模型会被频繁调用或计算量较大时可以尝试使用一下,做一下测试看看有没有性能提升。而当模型很简单,只执行一次或几次的时候,使用这个说不定还会让性能变差,因为编译本身也是有成本的。
5. 常见问题以及如何初步 debug
5.1 常见问题
在真正使用中 PyTorch Compile 并不保证能够提升性能,这是一些常见的问题:
- Dynamic Shape: 如果模型的每次的 input shape 一直在变化,就容易导致 recompile,反而让 performance 更差了
- Graph Breaks: 有时候 TorchDynamo 没法捕捉或者编译一些 code,会让这部分代码执行的时候还是用 eager mode 执行
- 其他可能的问题,等遇到再说吧
5.2 一些 debug 的方法
- TORCH_LOGS: 可以使用 TORCH_LOGS 来输出自己想要查看的 events, 例如加上 “TORCH_LOGS=‘graph_breaks,recompiles’” 可以查看有没有这两个 event 发生来 further debug
TORCH_LOGS="graph_breaks,recompiles" python run_model.py
- _dynamo.explain: 用 _dynamo.explain 来分析 torch.compile 会怎么 process,也可以看到 potential 的 graph breaks 等问题
import torch
import torch._dynamo as dynamo
def forward_func(x):
return (x + 1).sin()
x = torch.randn(10)
# Explanation of compilation
explanation = dynamo.explain(forward_func, (x,))
print(explanation)
- 应该还有一些其他的方式,以后遇到再说吧
6. 结语
好了本文目的达到了,就是大概了解一下 torch compile,希望在以后工作中遇到这类问题可以不需要摸不着头脑,而是至少有个方向可以往下探索。之后应该还是会遇到更复杂的 case,希望到时候可以再学到一些新的技巧,有新的理解再写一个 follow up。