给模型加上一行 torch.compile(),推理速度有时能飙升数倍甚至十倍。这背后不是魔法,而是 GPU 执行模型的一个根本性优化——内核融合(Kernel Fusion)。理解它,你才能判断什么时候该用 compile、什么时候它帮不上忙,以及如何写出更容易被融合的代码。
没有 Compile 时,GPU 在做什么
PyTorch 的 eager 模式下,每遇到一个算子,框架就向 GPU 发射一个 kernel——一段在 GPU 上运行的函数。一个简单的两层 MLP:
y = relu(linear(x))
eager 模式下,GPU 先执行 linear 的 kernel,再把结果写回显存;接着再执行 relu 的 kernel,再写回显存。两次 kernel launch,两次显存读写。
问题不在计算本身,而在 launch 开销和显存往返:
- Kernel launch 有固定成本。每次向 GPU 提交一个 kernel,都要走一遍驱动栈,大约在几微秒量级。算子多了,这些微秒叠加起来就很可观。
- 中间结果必须写回显存再读出来。GPU 的计算单元和显存之间带宽有限。一个
linear输出写回显存、relu再从显存读进来,这个来回比在寄存器或 SRAM 里直接传递慢得多。
一个真实模型可能有几百个算子。eager 模式下,GPU 在不停地"发射 kernel → 等完成 → 读中间结果 → 发射下一个 kernel",大量时间花在等待和搬运上,计算单元反而经常空闲。
内核融合:把多步计算压进一个 Kernel
torch.compile() 的核心动作之一就是把多个算子合并成一个 kernel,让 GPU 一次 launch 就完成整段计算。
上面的 relu(linear(x)),融合后变成一个 kernel:先做矩阵乘法,结果留在 GPU 的寄存器或片上 SRAM 里,紧接着就地做 ReLU,最终只把激活后的结果写回显存。一次 launch,一次显存写入。
收益来自两方面:
- Launch 开销从 N 次降到 1 次。几十个算子融合后,驱动栈只走一遍。
- 中间结果不再往返显存。数据在 GPU 片上高速存储中流转,带宽瓶颈大幅缓解。
这对带宽受限的算子(比如 element-wise 的激活函数、归一化)效果最明显——它们本身计算量不大,但在 eager 模式下每次都要从显存读一遍、写一遍,大部分时间在等内存。融合后,这些算子"搭车"进了前一个计算密集型 kernel 的尾部,几乎零额外开销。
用代码看差异
下面这段脚本可以直接运行,对比 eager 和 compiled 模式下同一个模型的执行时间和内核数量。需要一张 NVIDIA GPU 和 PyTorch 2.0+。
import torch
import torch.nn as nn
import time
class SmallMLP(nn.Module):
def __init__(self, dim=512):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, dim),
nn.ReLU(),
nn.Linear(dim, dim),
nn.ReLU(),
nn.Linear(dim, dim),
)
def forward(self, x):
return self.net(x)
dim = 512
model = SmallMLP(dim).cuda()
x = torch.randn(256, dim, device="cuda")
# ---------- eager 模式 ----------
torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(200):
_ = model(x)
torch.cuda.synchronize()
eager_time = (time.perf_counter() - t0) / 200
# ---------- compiled 模式 ----------
compiled = torch.compile(model, mode="reduce-overhead")
# 首次调用会触发编译,先 warm up
_ = compiled(x)
torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(200):
_ = compiled(x)
torch.cuda.synchronize()
compiled_time = (time.perf_counter() - t0) / 200
print(f"eager: {eager_time * 1000:.3f} ms/iter")
print(f"compile: {compiled_time * 1000:.3f} ms/iter")
print(f"speedup: {eager_time / compiled_time:.2f}x")
运行后你大概率会看到 compile 版本快 1.5x–3x。模型越大、算子越多,收益越高。
如果想直观看到 kernel 数量的变化,可以用 torch.profiler:
from torch.profiler import profile, ProfilerActivity
with profile(activities=[ProfilerActivity.CUDA]) as prof:
model(x)
print("eager kernel count:", len(prof.key_averages()))
with profile(activities=[ProfilerActivity.CUDA]) as prof:
compiled(x)
print("compiled kernel count:", len(prof.key_averages()))
eager 模式下你会看到每个 Linear 和 ReLU 各自对应独立的 CUDA kernel;compiled 模式下,多个算子被合并,kernel 数量显著减少。
融合不是万能的——边界与取舍
内核融合有几个天然限制,了解它们能避免踩坑:
1. 融合需要算子之间数据流连续。 如果两个算子之间有显存上的分支(比如同一中间结果被两条路径分别消费),编译器可能无法把所有路径塞进同一个 kernel,只能部分融合。
2. 数据依赖打破融合。 比如 argmax 的结果作为下一个算子的索引输入——这种动态依赖在编译期无法静态确定,融合会中断。
3. 编译本身有开销。 首次调用 torch.compile() 需要追踪(trace)模型、分析图、生成内核代码,可能要几秒到几十秒。对只跑一次的推理场景,编译时间可能比省下来的执行时间还长。适合重复执行的训练循环或高频推理服务。
4. 动态形状会触发重编译。 如果每次输入的 tensor shape 不同(比如 batch size 变化、序列长度不固定),编译器会按新 shape 重新编译。可以用 dynamic=True 告诉编译器预期动态形状,但生成的 kernel 通常比静态形状版本稍慢。
写更容易被融合的代码
几个实用习惯能让编译器更好地发挥融合能力:
- 避免在 forward 中插入 Python 控制流。
if、for循环依赖数据值时,编译器要么断开融合,要么反复重编译。尽量用torch.where、torch.masked_fill等张量级操作替代。 - 减少不必要的
.item()或.cpu()调用。这些操作把数据从 GPU 拉回主机,强制打断计算图的连续性。 - 保持输入 shape 稳定。训练时固定 batch size,推理时 pad 到统一长度,减少重编译。
- 优先用 PyTorch 原生算子组合,而非手写 CUDA kernel。编译器理解原生算子的语义,能自动融合;自定义 kernel 对编译器来说是黑盒,只能单独发射。
什么时候该用 Compile
| 场景 | 建议 |
|---|---|
| 训练循环(数百到数千步) | 强烈推荐,编译开销很快摊薄 |
| 固定 shape 的推理服务 | 推荐,一次编译后持续受益 |
| 变长输入的推理(如 LLM 不同 token 数) | 用 dynamic=True,但预期收益比静态 shape 低 |
| 一次性脚本 / 快速实验 | 不推荐,编译时间占比太高 |
| 模型大量使用自定义 CUDA extension | 收益有限,黑盒算子无法融合 |
一行 torch.compile(model) 背后的核心机制就是内核融合——把多个 GPU kernel 压成一个,砍掉 launch 开销和显存往返。理解这个原理,你就能在合适的场景下拿到真实的加速,而不是在不适用的地方浪费编译时间。