PyTorch 随笔 (一): PyTorch Compile

2025/03/15

1. 前言

这是一个我完全不熟悉的领域,因为工作中遇到了一些和 torch compile 相关的问题,但是脑子里完全没有概念,于是我决定打开这个 black box 看一眼这究竟是怎么回事。希望我可以边写 blog 边学,写完之后我可以建立对 Pytorch Compile 的基本认知,并且在遇到问题时可以有办法去看代码,debug 和定位问题。

2. 什么是 PyTorch Compile

PyTorch Compile 是 Python 2.0 引入的,用于优化 PyTorch Model 的运行速度。大致的原理:原本 PyTorch 是以 eager 的模式执行的,每个 operation 会单独按顺序执行。而 PyTorch Compile 可以捕获模型的 computation graph,对其进行优化(例如 Kernel Fusion),然后编译成优化的 kernel,在 Runtime 直接调用,提升运行速度。 一个简单的调用方式:

compiled_model = torch.compile(model)

3. 实测

import torch
import time

# Define a simple model
class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(1024, 1024),
            torch.nn.ReLU(),
            torch.nn.Linear(1024, 1024),
            torch.nn.ReLU(),
            torch.nn.Linear(1024, 10),
        )

    def forward(self, x):
        return self.layers(x)

# Create model and input data
model = SimpleModel().cuda()
x = torch.randn(1024, 1024, device="cuda")

# Benchmark standard model
def benchmark_standard():
    torch.cuda.synchronize()
    start_time = time.time()
    for _ in range(100):
        model(x)
    torch.cuda.synchronize()
    return time.time() - start_time

# Benchmark compiled model
compiled_model = torch.compile(model)

def benchmark_compiled():
    # First run to compile
    compiled_model(x)
    torch.cuda.synchronize()

    start_time = time.time()
    for _ in range(100):
        compiled_model(x)
    torch.cuda.synchronize()
    return time.time() - start_time

# Run benchmarks
standard_time = benchmark_standard()
compiled_time = benchmark_compiled()

print(f"Standard model: {standard_time:.4f} seconds")
print(f"Compiled model: {compiled_time:.4f} seconds")
print(f"Speedup: {standard_time/compiled_time:.2f}x")

用一张 A100 的卡实测一下,可以看到用法还是非常简单的, 核心就一行代码: torch.compile(model)

# result
Standard model: 0.0356 seconds
Compiled model: 0.0315 seconds
Speedup: 1.13x

结果上来看使用它确实提升了 performance。Note:我测试的时候我一开始怎么跑都是 standard model 速度更快,后来去研究了一下才发现需要使用 torch.cuda.synchronize() 来得到更准确的结果 Reference

4. 尝试理解它是怎么工作的

4.1 运作机制

延伸阅读材料:

4.2 什么时候该用 PyTorch Compile?

当模型会被频繁调用或计算量较大时可以尝试使用一下,做一下测试看看有没有性能提升。而当模型很简单,只执行一次或几次的时候,使用这个说不定还会让性能变差,因为编译本身也是有成本的。

5. 常见问题以及如何初步 debug

5.1 常见问题

在真正使用中 PyTorch Compile 并不保证能够提升性能,这是一些常见的问题:

5.2 一些 debug 的方法

TORCH_LOGS="graph_breaks,recompiles" python run_model.py
import torch
import torch._dynamo as dynamo

def forward_func(x):
    return (x + 1).sin()

x = torch.randn(10)

# Explanation of compilation
explanation = dynamo.explain(forward_func, (x,))

print(explanation)

6. 结语

好了本文目的达到了,就是大概了解一下 torch compile,希望在以后工作中遇到这类问题可以不需要摸不着头脑,而是至少有个方向可以往下探索。之后应该还是会遇到更复杂的 case,希望到时候可以再学到一些新的技巧,有新的理解再写一个 follow up。