PyTorch core 의 Richard Zou 가 깐 torch.compile 의 내부 — TorchDynamo 의 graph 캡처, AOT autograd, FakeTensor, dynamic shape, Inductor 의 lowering, cudagraph 의 위치. “왜 graph break 가 일어나나, dynamic shape 가 왜 잘 안 되나, 어디서 디버깅을 시작하나” 의 실전 질문에 대한 답. 본 페이지는 transcript 가 실패해 PyTorch 공식 문서, dev-discuss 토론, Richard 의 공개 talk 으로 재구성됐다.
PyTorch 2.0 의 torch.compile() 은 한 줄 wrapping 으로 fusion 을 잡아주는 기적적인 도구로 광고되지만 — 실전에서 학습 코드를 감싸면 graph break 가 줄줄이 터지거나, recompile 이 매 step 일어나거나, dynamic shape 에서 행이 멈춘다. 이 강의는 왜 그런 일이 일어나는지 의 내부 답을 깐다.
강의의 출발 질문 셋.
본 노트는 transcript 실패로 — PyTorch 공식 문서, dev-discuss 의 Dynamo/Inductor 디자인 RFC, Richard Zou 의 다른 talk (PyTorch Conference, EuroPython) 을 종합한 재구성.
torch.compile 을 한 줄짜리 마법으로 보면 안 된다. 안에는 4개의 거의 독립적인 시스템 이 있다 — Dynamo (graph 캡처), AOT autograd (backward 합성), FakeTensor (shape 추론), Inductor (lowering). 각자 다른 자리에서 다른 이유로 깨진다.
TorchDynamo 는 Python bytecode 위에서 동작한다 — 함수가 호출되면 bytecode 를 walk 하면서 “이 op 가 graph 에 들어갈 수 있나” 를 체크한다. 못 들어가는 자리에서 graph break — 거기서 graph 를 끊고, 그 라인은 일반 Python 으로 돌고, 그 후 다시 graph 캡처 시작.
graph break 를 일으키는 흔한 패턴.
data-dependent control flow 라고 부름.
torch.where
graph break 가 한 번 나도 — 그 전후의 두 sub-graph 는 각자 컴파일된다. 단, fusion 이 break 경계를 넘지 못 한다는 점이 손해. “이 break 가 fusion 을 깨고 있는지” 가 진짜 질문. 학습 루프 시작/끝 의 break 는 거의 무해.
torch.compile 은 처음 호출될 때 그 입력의 shape 으로 graph 를 만든다. 그 다음 호출 — 다른 shape 이면? 두 가지 모드. (1) static (default) — recompile. (2) dynamic — symbolic shape 으로 graph 를 만들어 두고 다음 호출에서 재사용.
static 모드의 함정 — 매 step batch size 가 다른 학습 (가변 길이 sequence, dynamic batching) 에서 매번 recompile. compile 비용이 1~5초인데, 매 step 이 100ms 면 — compile 비용이 학습 비용보다 큼.
dynamic 모드의 함정 — 모든 차원이 dynamic 이라고 처리되면 fusion 이 잘 안 됨 (Inductor 가 shape 을 모르면 vectorization 결정 어려움). 보통 “batch dim 만 dynamic, 나머지 static” 이 좋음.
compile 된 graph 에는 guard 가 같이 박힌다 — “이 graph 는 batch=32, seq=512, dtype=fp16 일 때만 valid” 같은 조건. 호출 시 guard 가 false 면 recompile. 너무 많은 guard 가 박히면 recompile 폭주.
# dynamic shape 의 세 모드
torch.compile(fn) # default · static
torch.compile(fn, dynamic=True) # 모든 차원 symbolic
torch.compile(fn, dynamic=None) # automatic — 첫 recompile 후 dynamic
# 차원별 마킹
from torch import _dynamo as dynamo
dynamo.mark_dynamic(x, 0) # dim 0 만 dynamic
# compile log 보기
import os
os.environ["TORCH_LOGS"] = "recompiles"
graph 를 짜려면 — 매 op 의 출력 shape, dtype, device 를 알아야 한다. 그런데 실제 Tensor 를 만들면 메모리가 잡히고 GPU op 가 실행된다. compile 단계에서는 그게 싫다. 그래서 FakeTensor — “shape, dtype, device 만 들고 있고 storage 는 없는 Tensor”.
FakeTensor 가 풀어주는 것들.
torch.matmul(a, b) 의 출력 shape 을 실제 곱셈 안 하고 알 수 있음. metadata 만 다룸.if x.sum() > 0 같은 검사를 trace 시점에 못 함. 그래서 graph break.from torch._subclasses import FakeTensorMode
with FakeTensorMode():
a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")
c = a @ b
print(c.shape) # torch.Size([1024, 1024])
print(c.dtype) # torch.float32
print(c.device) # cuda:0
# 실제 메모리 0 — c 는 metadata 만
(1) data-dependent op — torch.unique, torch.nonzero, torch.repeat_interleave(x, n) 같이 출력 shape 이 input value 에 의존. fake tensor 가 추정 못 함. (2) custom op 의 fake impl 미등록 — 사용자 정의 op 이 fake tensor mode 에서 어떻게 동작하는지 별도 등록 필요 (@register_fake).
PyTorch 의 default autograd 는 tape-based — forward 가 실행되면서 동시에 grad function 을 tape 에 쌓고, backward 호출 시 그걸 거꾸로 실행. AOT autograd 는 다르다 — forward 가 실행되기 전 단계에서 backward graph 를 같이 만들어 둔다 (Ahead-Of-Time).
왜 AOT 가 필요한가.
backward graph 안에 forward op 을 다시 넣어 둔다 (activation checkpointing 의 정형화된 형태). 메모리 절약. 단, 그 결정을 자동으로 하는 게 어려움 — 강의에서 “min-cut partitioning” 이라는 알고리즘으로 forward 와 backward 사이의 “저장” 과 “재계산” 의 경계를 결정한다고 알려져 있다.
# AOT autograd 의 의사 흐름
joint_graph = trace_forward_and_backward(fn, inputs)
# joint_graph: forward + backward 가 한 graph 안에
fwd_graph, bwd_graph = partition(joint_graph)
# forward 가 무엇을 backward 에 넘겨줄지 결정
# (saved tensors vs recomputed)
fwd_compiled = inductor_compile(fwd_graph)
bwd_compiled = inductor_compile(bwd_graph)
# 호출 시 forward 만 실행, backward 는 backward() 시
def compiled_fn(*args):
out, saved = fwd_compiled(*args)
register_for_backward(saved, bwd_compiled)
return out
x → linear → gelu → linear → norm → out (fusion 가능)
Inductor 는 torch.compile 의 default backend. FX graph 를 받아서 Triton kernel(GPU) 또는 C++/cpp_wrapper(CPU) 로 lowering. 사용자가 거의 보지 않지만 fusion / scheduling / vectorization 의 결정을 여기서 한다.
Inductor 는 graph 의 op 들을 — pointwise (elementwise), reduction (sum/max), tile (matmul/conv) 의 세 카테고리로 분류. 같은 카테고리는 잘 fuse 되고 (특히 pointwise), reduction 다음의 pointwise 도 fuse 가능. tile 연산은 보통 자기 kernel 로 — Inductor 는 이 결정을 자동으로.
실전에서 Inductor 의 출력 코드를 보고 싶다면 — TORCH_LOGS=output_code 환경변수. 생성된 Triton 코드가 console 에 dump 된다. L001 에서 Mark 가 같은 트릭으로 새 커널의 “시작점” 을 얻는 방법을 깐다.
cudagraph 는 PyTorch 의 별도 기능 — 같은 sequence 의 GPU 호출을 한 번 캡처해서 “하나의 단위” 로 launch. launch overhead 가 크게 줄어든다 (특히 작은 op 이 많이 있는 경우).
cudagraph 가 도움 되는 자리.
cudagraph 와 torch.compile 은 서로 다른 layer 에서 작동한다. compile 이 op 들을 fuse 해서 큰 kernel 로 만들고, cudagraph 가 그 kernel 들의 sequence 를 한 번에 launch. 같이 쓸 수 있고 보통 같이 쓰는 게 best.
(1) shape 이 고정이어야 함 — graph 캡처 시점의 shape 으로 lock. (2) memory address 도 고정 — pre-allocated buffer 사용. (3) conditional 안 됨 — capture 한 sequence 만 그대로. (4) capture 시 “warmup” 한 번 필요 — 첫 호출은 캡처용.
# torch.compile + cudagraph 같이
fn = torch.compile(model, mode="reduce-overhead")
# reduce-overhead = inductor + cudagraph
큰 batch / 큰 모델에서는 — kernel 자체가 충분히 무거워서 launch overhead 가 무의미. cudagraph 가 잡는 메모리 (pre-allocated buffer) 가 부담일 수 있음. 그리고 dynamic shape 이 있으면 cudagraph 가 매번 capture 다시 — 오히려 손해.
torch.compile 디버깅의 가장 큰 함정은 — 어디서 깨졌는지가 안 보임. 4-stage 파이프라인 어디든 깨질 수 있음. 표준 시퀀스 를 거치며 layer 별로 좁혀간다.
torch._dynamo.explain(fn, *args) — graph break 의 횟수와 자리를 알려준다. 0 이 목표.TORCH_LOGS="recompiles". 학습 step 마다 recompile 하면 cache_size_limit 도달. dynamic 모드 활성화 또는 shape 고정.TORCH_LOGS="output_code". Inductor 의 Triton 출력. fusion 이 의도대로 됐는지 확인.TORCH_COMPILE_DEBUG=1. ./torch_compile_debug/ 디렉토리에 파일 dump.backend="aot_eager" 로 두면 AOT 까지만, Inductor 안 함. 둘 중 어느 단계가 문제인지 식별.torch._dynamo.explain(fn, x) — graph break 위치 확인. get_graph_break_reasons()
TORCH_LOGS="recompiles" python train.py — recompile 자주 나면 dynamic 모드.
TORCH_LOGS="output_code" — Triton 코드 dump. fusion 검증.
TORCH_COMPILE_DEBUG=1 — full trace 디렉토리 dump.
aot_eager, eager — 어느 단계 문제인지 식별.
“처음 torch.compile 을 쓸 때 — 먼저 graph break 0 을 만든다. fusion 최적화는 그 다음.” graph break 가 있으면 fusion 도 깨지니까 둘이 동시에 풀려야 함. 원본 영상 확인 필요 — 강의에서 Richard 가 같은 조언을 했는지.
강의의 실전 답 — “이런 패턴은 compile 가 잡는다”, “이런 패턴은 깨진다”. 코딩 습관 단위.
torch.where, mask 처리.if cond: x else: y 가 cond 가 tensor 면 break. torch.where(cond, x, y).x += y 도 가능하지만, autograd 와 잘 안 맞아서 일부 케이스에서 break. 가능하면 x = x + y.torch.library.custom_op 로 등록. fake impl 도 같이.torch.unique, torch.nonzero 같이 출력 shape 이 입력 값에 의존하는 op 은 graph 밖으로.torch._dynamo.mark_dynamic(x, 0).compile 친화적 코드의 황금률 — “tensor 가 graph 안에서 끝까지 살아 있게”. tensor → Python value → tensor 로 왔다갔다 하면 break 폭주. 모든 logic 을 tensor op 으로 표현.
TORCH_LOGS=output_code 로 코드 dump.본 노트의 모든 코드 스니펫과 동작 설명은 PyTorch 2.0+ 공식 문서와 dev-discuss RFC 를 토대로 한 재구성. PyTorch 의 발전 속도가 빠르므로 — 실제 사용 시 자기 PyTorch 버전의 docs 를 직접 확인. 특히 backend 옵션과 mode 이름은 자주 변경됨.