Python fn (eager)
│ ① TorchDynamo
│ · CPython frame intercept
│ · guard 생성
▼
FX Graph (torch operator level)
│ ② AOT Autograd
│ · fwd + bwd 분리
│ · joint graph → partitioner
▼
FX Graph (aten / prims, 2개: fwd, bwd)
│ ③ Decomposition
│ · _refs / _decomp 적용
▼
Inductor Input IR (lowered aten)
│ ④ Inductor
│ · Loops/Pointwise/Reduction
│ · Scheduler fusion
▼
Triton Python / C++/OpenMP
│ ⑤ Codegen → ptxas / gcc
▼
PTX + SASS / .so
source: torch/_dynamo/convert_frame.py · torch/_functorch/aot_autograd.py · torch/_inductor/compile_fx.py
| 단계 | 입력 | 출력 | 핵심 역할 |
|---|---|---|---|
| Dynamo | Python bytecode | FX Graph + guards | frame capture |
| AOT Autograd | FX (torch op) | fwd+bwd FX | diff graph 생성 |
| Decomp | aten 고수준 | prims/aten 저수준 | canonical form |
| Inductor | lowered FX | Inductor IR | fusion·schedule |
| Codegen | Inductor IR | Triton/C++ | kernel source |
import torch def f(x, y): return (x + y).relu() * 2 g = torch.compile(f, backend="inductor", mode="reduce-overhead", fullgraph=False, dynamic=None) g(x, y) # compile on first call g(x, y) # guard hit → cached kernel
| mode | 의미 |
|---|---|
default | Inductor fusion, no cudagraph |
reduce-overhead | + cudagraph capture (decode loop) |
max-autotune | + matmul/conv autotune, coord descent |
| layer | 경로 / key |
|---|---|
| Dynamo | torch._dynamo.config.cache_size_limit (per-fn) |
| Inductor FX | TORCHINDUCTOR_CACHE_DIR · default /tmp/torchinductor_$USER |
| Triton | ~/.triton/cache · key hash ↗ V11 §15 |
| cubin | Triton cache 내부 |
_PyInterpreterState_SetEvalFrameFuncsource: torch/_dynamo/convert_frame.py (v2.x)
Python frame
bytecode = co_code
for each op:
if op produces Tensor:
→ trace into FX
→ record Guard(expr)
elif op needs concrete value:
→ graph break
→ fallback to eager
else:
→ emit residual bytecode
| guard 종류 | 체크 대상 |
|---|---|
TENSOR_MATCH | dtype · device · stride · size |
CONSTANT_MATCH | Python int/bool/None 상수 |
TYPE_MATCH | object type identity |
ID_MATCH | 특정 object id (module param) |
DUPLICATE_INPUT | alias 관계 (x is y) |
DICT_KEYS | dict의 key 집합 |
각 guard는 C++로 lowering되어 call 당 ~ns overhead (v2.3+)
if x.item() > 0)print, input 등)torch.* op (목록 있음)@torch._dynamo.graph_break()fullgraph=True이면 graph break 시 raise. 디버깅용.
from torch._dynamo import explain exp = explain(f)(x, y) print(exp.graph_count, exp.graph_break_count, exp.break_reasons)
반환: graph 수, break 수, 각 break의 (op, reason, stack)
| 조건 | 재컴파일? |
|---|---|
| shape 변화 (static) | 예 |
| shape 변화 (dynamic=True) | 대개 아니오 |
| dtype 변화 | 예 |
| device 변화 | 예 |
| value 변화 (Tensor) | 아니오 |
| Python 상수 변화 | 예 |
cache_size_limit → disable
default 8. 초과 시 해당 frame은 eager 실행. 많은 shape 변화가 예상되면 dynamic=True를 먼저 시도.
source: torch/_functorch/aot_autograd.py
inputs: (x1, x2, ..., xk) primals
labels: (l1, ..., lm) (loss inputs)
outputs: (y1, ..., yn) forward outputs
+ (g1, ..., gk) backward grads
joint_fn(primals, tangents) -> (outs, grad_ins)
= torch.autograd.Function.apply
but traced into ONE FX graph
from torch._functorch.aot_autograd import aot_function compiled = aot_function( fn, fw_compiler=my_fw, # FX -> callable bw_compiler=my_bw, partition_fn=default_partition, decompositions=core_aten)
default_partition: autograd와 동일한 activation 저장min_cut_rematerialization_partition: recompute로 saved tensor 감소 (checkpointing) ↗ V17 §7add_, relu_) → out-of-place로 재작성compiled = aot_module( mod, fw_compiler=inductor_fw, bw_compiler=inductor_bw)
torch.no_grad() 또는 inference_mode이면 AOT가 forward graph만 생성requires_grad=True인 입력은 compile 경로로 들어가면 bwd graph가 항상 생성된다. inference 경로를 원하면 no_grad를 명시.
| 경로 | fwd | bwd |
|---|---|---|
| eager | immediate | runtime tape |
| no_grad compile | FX | 없음 |
| train compile | fwd FX | bwd FX |
print(gm.graph)로 즉시 관측| op | 의미 | target |
|---|---|---|
placeholder | graph input | arg name |
get_attr | module attribute | attr path |
call_function | free fn | callable |
call_method | self.method() | method name |
call_module | sub-module | module path |
output | graph return | tuple of nodes |
op, target, args, kwargs, namemeta['val'] — FakeTensor (shape/dtype/device)meta['stack_trace'] — 원본 코드 위치users, all_input_nodes — DAG edgedef f(x, W, b): y = x @ W z = y + b return z.relu() gm = torch.fx.symbolic_trace(f) print(gm.graph)
graph(): %x : [num_users=1] = placeholder[target=x] %W : [num_users=1] = placeholder[target=W] %b : [num_users=1] = placeholder[target=b] %matmul : call_function[target=torch.matmul](args=(%x, %W)) %add : call_function[target=operator.add](args=(%matmul, %b)) %relu : call_method[target=relu](args=(%add,)) return (%relu,)
# 1) node-by-node walk for n in gm.graph.nodes: if n.op == "call_function" and \ n.target is operator.add: with gm.graph.inserting_after(n): new = gm.graph.call_function( torch.add, n.args) n.replace_all_uses_with(new) gm.graph.erase_node(n) gm.recompile()
replace_pattern — subgraph rewriteInterpreter — 재실행으로 meta 재계산Transformer — node 변환 framework.recompile() 호출 시 Graph로부터 Python source를 생성해 forward()로 저장한다.
생성된 source는 gm.code로 확인 — eager fallback 및 디버깅의 핵심.
| 축 | symbolic_trace | Dynamo |
|---|---|---|
| 메커니즘 | Proxy tensor | bytecode |
| 제어흐름 | 불가 | graph break |
| dynamic shape | 제한적 | SymInt |
| guard | 없음 | 있음 |
symbolic_trace는 data-dependent if·list.append 등에서 trace 불가. torch.compile은 Dynamo만 사용.
| 층 | 예 | 개수 |
|---|---|---|
aten::* | softmax, addmm, layer_norm | ~2000+ |
prims::* | add, mul, broadcast_in_dim, reduction | ~100 |
source: torch/_refs · torch/_prims · torch/_decomp
torch._refs.* — op의 reference Python 구현 (prims 기반)torch._decomp.* — aten→aten 분해 (derivative-safe)core_aten_decompositions() — Inductor가 기본으로 받는 dictBEFORE (torch op level, Dynamo 직후):
%a : placeholder[target=x]
%sm = call_function[target=aten.softmax](
args=(%a, -1))
return (%sm,)
AFTER decomp (core/prims level):
%a = placeholder[target=x]
%mx = call_function[target=aten.amax](
args=(%a, [-1], True))
%sub = call_function[target=aten.sub](
args=(%a, %mx))
%ex = call_function[target=aten.exp](
args=(%sub,))
%sm = call_function[target=aten.sum](
args=(%ex, [-1], True))
%div = call_function[target=aten.div](
args=(%ex, %sm))
return (%div,)
stable softmax: subtract amax for overflow safety ↗ V09 §9
| 관점 | before | after |
|---|---|---|
| Node 수 | 적음 | 많음 |
| fusion 기회 | 없음 | 많음 |
| 의미 명확성 | 고수준 | 저수준 |
| backward 생성 | custom rule | auto |
from torch._decomp import register_decomposition from torch import Tensor @register_decomposition(aten.mse_loss) def mse_loss(x: Tensor, t: Tensor, reduction: int = 1): d = (x - t) loss = d * d if reduction == 1: return loss.mean() return loss
_refs (prims 기반)_decomp (aten→aten)aten::mm, aten::convolution — decomp 대신 extern kernel (cuBLAS/cuDNN) 호출로 남김max-autotune 모드는 Triton matmul 후보도 생성compile_fx(gm, example_inputs)로 진입한다. gm은 이미 decomp가 적용된 FX GraphModule, example_inputs는 FakeTensor 목록이다.
source: torch/_inductor/compile_fx.py :: compile_fx_inner
meta['val'] (FakeTensor) 존재| 필드 | 용도 |
|---|---|
shape | (s0, 64, 64) |
stride | layout inference |
dtype | fusion 호환성 |
device | codegen 분기 |
storage_offset | view 추적 |
for node in graph: fake_args = [a.meta['val'] for a in node.inputs] out_fake = node.target(*fake_args) # FakeTensor op node.meta['val'] = out_fake # shape/dtype 기록
Meta key로 라우팅make_contiguous_strides_for(shape) 기본for node in gm.graph.nodes:
lower = LOWERINGS[node.target]
ir_out = lower(*input_ir_tensors,
**kwargs)
env[node] = ir_out # TensorBox
| op 유형 | Inductor 처리 |
|---|---|
| pointwise | Pointwise IR |
| reduction | Reduction IR |
| matmul/conv | extern (cuBLAS/cuDNN) 또는 Triton template |
| 복잡 custom | eager fallback |
f(idx) → value로 서술된다.
source: torch/_inductor/ir.py
| class | 역할 |
|---|---|
TensorBox | lowering 결과 wrapper |
Loops | iteration 공간을 가진 base |
Pointwise | f(idx) → v (no reduction) |
Reduction | reduction axes + combine |
Scatter | indirect store |
MatMul | extern / Triton template |
Buffer | realized storage |
ComputedBuffer | Loops + Buffer pair |
def relu_add(a, b): def inner(idx): x = a.load(idx) y = b.load(idx) return ops.maximum(x + y, 0) return Pointwise( device=a.device, dtype=a.dtype, inner_fn=inner, ranges=a.size)
inner_fn(idx)는 symbolic index → valueReduction( device=x.device, dst_dtype=torch.float32, src_dtype=x.dtype, inner_fn=lambda idx, ridx: x.load(idx + ridx), ranges=[M], reduction_ranges=[N], reduction_type="sum")
rangesreduction_rangessum · prod · max · min · argmax · argmin · welford| 그룹 | op |
|---|---|
| arith | add · sub · mul · div |
| math | exp · log · sqrt · sin · cos |
| relu-like | maximum · minimum · where |
| compare | eq · lt · le · gt · ge |
| cast | to_dtype · to_dtype_bitcast |
| memory | load · store · index_expr |
ops는 backend에 의해 재해석 (Triton ops vs C++ ops)
InputBuffer — graph input / placeholderComputedBuffer — Inductor가 생성할 tensorExternKernel — cuBLAS/cuDNN 호출 nodesource: torch/_inductor/scheduler.py :: Scheduler
| 종류 | 의미 |
|---|---|
| vertical | producer→consumer 합치기 (inner_fn 합성) |
| horizontal | 동일 iteration space의 독립 node 합치기 |
| reduction+pointwise | reduction 뒤 elementwise 병합 |
| epilogue | matmul 뒤 pointwise 합치기 (Triton template) |
| producer | consumer | fuse? | 조건 |
|---|---|---|---|
| pointwise | pointwise | 예 | 동일 shape · elementwise index |
| pointwise | reduction | 예 | pointwise가 reduction 입력 |
| reduction | pointwise | 조건부 | reduction 출력이 broadcast 없이 사용 |
| reduction | reduction | 조건부 | axis 호환 · persistent 가능 |
| pointwise | matmul | 예(prologue) | Triton matmul template 한정 |
| matmul | pointwise | 예(epilogue) | Triton matmul template 한정 |
| scatter | any | 아니오 | indirect store → realize |
| any | mutation | 아니오 | inplace 경계 |
| dtype A | dtype B | 예 | cast op 자동 삽입 |
| device A | device B | 아니오 | kernel 경계 |
realize() hintnodes = topo_sort(IR graph)
for n in nodes:
for pred in n.deps:
if can_fuse(pred, n):
fuse(pred, n) # vertical
for n in nodes:
for sibling in same_range(n):
if can_fuse_h(n, sibling):
fuse_h(n, sibling) # horizontal
loop_ordering_after_fusion)@triton.jit을 통해 컴파일한다.
source: torch/_inductor/codegen/triton.py
@triton.jit def triton_poi_fused_add_relu_0( in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK: tl.constexpr): xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK) xmask = xindex < xnumel x0 = tl.load(in_ptr0 + xindex, xmask) x1 = tl.load(in_ptr1 + xindex, xmask) r = tl.maximum(x0 + x1, 0.0) tl.store(out_ptr0 + xindex, r, xmask)
triton_{kind}_fused_{ops}_{hash}kind: poi pointwise · red reduction · per persistent| 축 | 심볼 | 기본 |
|---|---|---|
| x (linearized) | XBLOCK | 256 / 1024 |
| y (2D) | YBLOCK | 16 / 32 |
| r (reduction) | RBLOCK | 8 / 16 / 32 .. |
autotune config는 triton_heuristics.py에서 후보 생성. size_hint에 따라 분기.
call_fn: args: input ptr, output ptr, shape grid: launch config stream: current CUDA stream triton_kernel[grid](... , XBLOCK=256)
tl.store 반복max-autotune은 Triton mm template(kernel/mm.py)을 사용해 tile 크기 후보를 brute-force로 autotune하고 최적 config를 cache한다.
load_scale · cast 등 fusebias · gelu · layernorm 등 fuse| prefix | 의미 |
|---|---|
triton_poi_* | pointwise |
triton_red_* | looped reduction |
triton_per_* | persistent reduction |
triton_tem_* | template kernel (mm/conv) |
이름으로 Inductor가 만든 어떤 kind인지 바로 읽힘.
gcc/clang으로 .so를 빌드한 뒤 dlopen으로 로드한다.
source: torch/_inductor/codegen/cpp.py · cpp_wrapper.py
extern "C" void kernel( const float* in0, const float* in1, float* out, long N) { #pragma omp parallel for for(long i=0; i<N; i+=16){ auto a = at::vec::Vectorized<float> ::loadu(in0+i); auto b = at::vec::Vectorized<float> ::loadu(in1+i); (a + b).relu().store(out+i); } }
| ISA | width | type |
|---|---|---|
| AVX2 | 8 × float32 | __m256 |
| AVX-512 | 16 × float32 | __m512 |
| NEON (ARM) | 4 × float32 | float32x4_t |
| SVE | scalable | svfloat32_t |
at::vec::Vectorized<T>가 ISA 추상화.
#pragma omp parallel fortorch.get_num_threads()omp parallel reduction(+:acc).py source → write /tmp/torchinductor_*/xx.cpp → compile (gcc/clang) -> .so → dlopen -> ctypes fn ptr → call via cpp_wrapper
config.cpp_wrapper = True → Python wrapper까지 C++로torch.export + AOTInductor)의 기반| 항목 | CPU | CUDA |
|---|---|---|
| matmul | MKL/oneDNN extern | Triton/cuBLAS |
| fusion 깊이 | 얕음 | 깊음 |
| dynamic shape | 지원 | 지원 |
| autotune | 제한 | 풍부 |
cudaGraphLaunch 단 한 번으로 replay하는 API. launch당 ~수 μs overhead를 ~ns 수준으로 내린다.
source: torch/_inductor/cudagraph_trees.py
mode="reduce-overhead" 또는 mode="max-autotune"config.triton.cudagraphs = True| 요구 | 이유 |
|---|---|
| static shape | kernel launch가 shape에 의존 |
| static pointer | input tensor address 고정 |
| no CPU sync | capture mid에 synchronize 금지 |
| no mem alloc | capture 중 new alloc 불가 |
| no host→device copy | non-captureable |
config.triton.cudagraph_skip_dynamic_graphs로 우회default ON in v2.2+ when mode uses cudagraph
| 상황 | 유리도 | 이유 |
|---|---|---|
| LLM decode loop | ★★★ | small kernel × 많음 |
| training step | ★★ | shape 거의 고정 |
| prefill | ★ | shape 다양 → re-capture 많음 |
| variable batch | △ | graph explosion |
.cpu() / .item() 호출| env | 효과 |
|---|---|
TORCH_LOGS=cudagraphs | capture 로그 |
TORCH_CUDAGRAPH_TRACE=1 | trace 상세 |
TORCHINDUCTOR_CUDAGRAPHS=0 | 비활성 |
torch._inductor.config.* — Python attrTORCHINDUCTOR_* env — env 변수 매핑source: torch/_inductor/config.py
| flag | 효과 |
|---|---|
triton.unique_kernel_names | kernel name 안에 op 이름 포함 (디버깅) |
triton.cudagraphs | cudagraph 켬 |
coordinate_descent_tuning | 추가 autotune |
epilogue_fusion | matmul epilogue fuse |
cpp_wrapper | C++ wrapper 생성 |
freezing | const folding · inference |
| flag | 뜻 |
|---|---|
max_fusion_size | fuse 할 최대 node 수 |
max_autotune_gemm | matmul 후보 수 상한 |
triton.max_block | block size 상한 (x/y/r) |
| flag / env | 산출물 |
|---|---|
TORCH_COMPILE_DEBUG=1 | IR dump dir 생성 |
TORCH_LOGS="inductor" | Inductor 단계 로그 |
TORCH_LOGS="dynamo" | Dynamo 로그 |
TORCH_LOGS="aot" | AOT Autograd 로그 |
TORCH_LOGS="output_code" | 생성된 Triton/C++ 소스 |
TORCH_LOGS="schedule" | scheduler 결정 |
/tmp/torchinductor_$USER/
└ <fxgraph_hash>/
├ output_code.py # Python wrapper
├ fx_graph_readable.py # pre-lowering
├ fx_graph_runnable.py # re-executable
├ fx_graph_transformed.py
├ triton_<hash>.py # kernel source
└ fx_graph.aot_inductor.so (cpp_wrapper)
| flag | 효과 |
|---|---|
dynamic_shapes | symbolic shape 허용 |
assume_static_by_default | 첫 호출을 static로 |
automatic_dynamic | 재컴파일 시 axis 자동으로 symbolic |
자세한 동작은 ↗ §15 Guard.
TORCHINDUCTOR_TRACE=1 → 각 kernel call에 NVTX range| 목적 | flag |
|---|---|
| 이름 해독 | unique_kernel_names=True |
| IR 확인 | TORCH_COMPILE_DEBUG=1 |
| kernel 소스 | TORCH_LOGS=output_code |
| 재현 | fx_graph_runnable.py 실행 |
torch.compile(fn, backend=...)의 backend는 Dynamo가 추출한 FX GraphModule을 callable로 컴파일하는 함수 또는 등록된 이름이다.
| name | 설명 |
|---|---|
inductor | default · TorchInductor |
eager | Dynamo만 적용 (graph 테스트) |
aot_eager | Dynamo + AOT Autograd, backend은 eager |
cudagraphs | Inductor 없이 cudagraph |
onnxrt | ONNX Runtime |
tvm | TVM ↗ V14 |
eageraot_eageraot_eager로 bisectfrom torch._dynamo import register_backend @register_backend def my_be(gm, example_inputs): # gm: torch.fx.GraphModule # example_inputs: list[FakeTensor] return lambda *args: gm(*args)
(gm, example_inputs) → callablefrom functorch.compile import \ aot_module_simplified def my_be(gm, inputs): return aot_module_simplified( gm, inputs, fw_compiler=my_compile, bw_compiler=my_compile)
backend이 fwd/bwd 모두 컴파일하려면 AOT를 명시 호출.
| 축 | inductor | onnxrt | tvm |
|---|---|---|---|
| IR 수준 | Inductor IR | ONNX | Relay/TIR |
| dynamic shape | 지원 | 제한 | 제한 |
| training | 지원 | 미지원 | 미지원 |
| cudagraph | 내장 | 외부 | 외부 |
상세 비교는 ↗ V14 §15.
import torch print(torch._dynamo.list_backends()) # 현재 설치된 backend 이름 목록
TORCHDYNAMO_REPRO_AFTER=dynamoTORCHDYNAMO_REPRO_AFTER=aot| 종류 | dispatch key | 용도 |
|---|---|---|
| CUDA impl | CUDA | 실제 GPU kernel |
| CPU impl | CPU | reference |
| Meta / Fake impl | Meta | shape inference |
| Autograd impl | AutogradCUDA | backward rule |
@torch.library.custom_op( "my::scaled_add", mutates_args=()) def scaled_add( x: torch.Tensor, y: torch.Tensor, alpha: float) -> torch.Tensor: return x + alpha * y
@scaled_add.register_fake def _(x, y, alpha): torch._check(x.shape == y.shape) return torch.empty_like(x)
def setup_ctx(ctx, inputs, output): x, y, alpha = inputs ctx.alpha = alpha def bwd(ctx, grad): return (grad, ctx.alpha * grad, None) scaled_add.register_autograd( bwd, setup_context=setup_ctx)
from torch._inductor.lowering import register_lowering @register_lowering( torch.ops.my.scaled_add) def _(x, y, alpha): return x + alpha * y # IR ops
| 단계 | 필수? |
|---|---|
| kernel impl (CUDA) | 필수 |
| fake impl | 필수 (compile path) |
| autograd rule | train 시 필수 |
| inductor lowering | 선택 (fusion 원할 때) |
| schema mutates_args | in-place 있을 때 |
Library("my", "DEF") — 새 op 선언Library("aten", "IMPL") — 기존 aten overridetorch.library.opcheck(fn, args) — 등록 정합성 검사torch.empty_like가 아닌 실제 compute를 수행 → Inductor가 shape prop에서 실데이터를 요구해 OOM·성능 저하.
| 축 | static | symbolic |
|---|---|---|
| 의미 | 특정 값으로 특화 | SymInt로 일반화 |
| guard | equality | range / constraint |
| codegen | 상수 embed | runtime arg |
| 성능 | ↑ (상수 fold) | 약간 ↓ |
| 재컴파일 | 값 바뀌면 | 거의 없음 |
SymInt = symbolic integer, 대수 연산 지원ShapeEnv가 constraint 보관trace: if x.shape[0] == 1024: # guard: s0 == 1024 (static) if x.shape[0] > 0: # guard: s0 > 0 (constraint) y = torch.empty(x.shape) # no guard (shape-passthrough) z = x.view(-1, 8) # guard: s0 mod 8 == 0
source: torch/fx/experimental/symbolic_shapes.py
torch._dynamo.mark_dynamic(x, 0) # axis 0만 symbolic, 나머지 static torch._dynamo.mark_static(y, 1) # axis 1은 반드시 특정 값으로 컴파일
dynamic=True 또는 mark_dynamic으로 symbolic화cache_size_limit 조정| 도구 | 출력 |
|---|---|
TORCH_LOGS="guards" | 각 guard 리스트 |
TORCH_LOGS="recompiles" | 왜 재컴파일인지 |
torch._dynamo.explain | break/guard 요약 |
dynamic=True를 먼저 시도.
| category | 의미 |
|---|---|
dynamo | frame capture · bytecode |
aot | joint graph · partition |
inductor | lowering · scheduler |
output_code | 최종 Triton/C++ 소스 |
schedule | fusion 결정 |
guards | guard 목록 |
recompiles | 재컴파일 이유 |
graph_breaks | break 이유 |
여러 카테고리 동시: TORCH_LOGS="+inductor,output_code"
debug_trace/ 하위 디렉토리에 step-by-step$TORCHINDUCTOR_CACHE_DIR or /tmp/torchinductor_$USER/
<fxhash>/
├ fx_graph_readable.py ← Dynamo 직후
├ fx_graph_transformed.py ← decomp 후
├ fx_graph_runnable.py ← 단독 실행 가능
├ output_code.py ← Python wrapper + triton src
├ triton_poi_*.py ← per-kernel source
└ debug_trace/
├ 0_before_pre_grad_graph.py
├ 1_after_decomp_graph.py
├ 2_after_post_grad_graph.py
└ *_ir_pre/post_fusion.txt
# 단독 재현 import torch exec(open("fx_graph_runnable.py").read()) # args0, mod 변수 사용 가능 mod(*args0)
버그를 최소 FX로 축소할 때 유용. minifier 출력과 동일 형식.
| path | 내용 |
|---|---|
torch/_dynamo/ | frame capture |
torch/_functorch/ | AOT Autograd |
torch/fx/ | FX Graph |
torch/_decomp/ | decomp rules |
torch/_refs/ | prims refs |
torch/_inductor/ir.py | Inductor IR |
torch/_inductor/scheduler.py | fusion |
torch/_inductor/codegen/triton.py | Triton codegen |
torch/_inductor/codegen/cpp.py | C++ codegen |
torch/_inductor/config.py | config · env |
triton_poi_* 이름으로 식별| 증상 | 먼저 확인 |
|---|---|
| 매 call 느림 | graph break / recompile 로그 |
| numerical 불일치 | decomp 비활성 + aot_eager |
| OOM | fusion 실패 → realize 과다 |
| cudagraph 안 함 | mode 및 static input 여부 |
| stage | dump env | 소스 |
|---|---|---|
| Dynamo | TORCH_LOGS=dynamo | _dynamo/ |
| AOT Autograd | TORCH_LOGS=aot | _functorch/ |
| Decomp | COMPILE_DEBUG=1 | _decomp/ |
| Inductor lowering | TORCH_LOGS=inductor | _inductor/ir.py |
| Scheduler | TORCH_LOGS=schedule | scheduler.py |
| Codegen | TORCH_LOGS=output_code | codegen/ |
| Guard | TORCH_LOGS=guards | symbolic_shapes.py |
| Recompile | TORCH_LOGS=recompiles | _dynamo/ |
debug_trace/에 *_graph.py로 남는다.
dynamic=True 또는 mark_dynamiccache_size_limit 확대inference · decode loop → reduce-overhead training · 큰 matmul → max-autotune 디버깅·정확도 bisect → aot_eager graph break 원인 찾기 → eager
register_fakeregister_autograd (학습 시)register_lowering (fusion 원할 때)opcheck로 정합성 확인| env | 용도 |
|---|---|
TORCH_LOGS | stage별 log |
TORCH_COMPILE_DEBUG | full dump |
TORCHINDUCTOR_CACHE_DIR | cache 위치 |
TORCHINDUCTOR_UNIQUE_KERNEL_NAMES | kernel name 해독 |
TORCHINDUCTOR_CUDAGRAPHS | cudagraph on/off |
TORCHDYNAMO_REPRO_AFTER | minifier |