Triton 은 Python 안의 DSL 인데 — “Python 같은 코드” 만으로 추측해서 짜면 거의 항상 느린 결과가 나온다. Triton 의 내부 모델 (program_id, tile 단위 op, autotune 의 의미) 을 정확히 알고 짜야 한다. Umer Adil 이 직접 만든 notebook 위에서 — 첫 vector add 부터 본격 matmul 까지 단계적으로 깐 실전 가이드. 디버깅 (interpret 모드), benchmarking, autotuning 의 함정, Triton vs CUDA 의사결정.
Triton 의 syntax 는 numpy 와 비슷해 보인다 — tl.load, tl.store, +, *. 그래서 처음 시작하는 사람은 “이건 그냥 Python 으로 GPU 코드 짜는 거구나” 라고 생각하기 쉽다. 그러나 Triton 의 내부 모델은 — element 단위가 아니라 tile 단위이고, 한 program (kernel instance) 이 한 tile 의 출력을 만든다. 이 mental model 이 안 잡히면 디버깅이 안 되고 성능도 안 나온다.
Umer 의 강의가 풀려는 질문 셋.
L001 (Mark) 이 “Triton 을 처음 만나는 자리” 였다면 — 이 강의는 “Triton 으로 일을 시작한 사람이 빠지는 함정 카탈로그”. 코드는 짧지만 패턴이 깊다. notebook 자체가 강의 자료이자 실습 — Umer 가 일부러 bug 를 박아두기도 한다.
CUDA 가 thread 단위로 “나는 i 번째 thread 다, i 번째 element 를 처리한다” 의 SIMT 모델이라면 — Triton 은 tile 단위로 “나는 j 번째 program 이고, j 번째 tile 의 출력을 만든다”의 모델. 그 사이에서 컴파일러가 tile → thread 매핑을 자동으로 결정한다.
tile 추상이 좋은 이유 — “tensor cores 위에서 도는 모양” 과 자연스럽게 맞는다. MMA instruction 은 16×16 같은 tile 위에서 도는데, Triton 의 BLOCK_SIZE 가 그 단위를 직접 표현. CUDA 처럼 thread 를 16×16 으로 “손으로 묶을” 필요가 없다.
Umer 의 notebook 첫 코드. 단순한 vector add 안에 Triton 의 4개 핵심 op 가 모두 등장한다. 이 4개를 정확히 이해하면 다른 모든 Triton 코드가 같은 패턴의 변형이라는 게 보인다.
import triton
import triton.language as tl
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n,
BLOCK_SIZE: tl.constexpr):
# 1) program_id — 나는 몇 번째 program?
pid = tl.program_id(0)
# 2) arange — 내가 다룰 tile 의 element index
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
# 3) mask — 마지막 tile 의 tail 처리
mask = offs < n
# 4) load — HBM 에서 SRAM 으로
x = tl.load(x_ptr + offs, mask=mask)
y = tl.load(y_ptr + offs, mask=mask)
# 5) compute — tile 단위 numpy-like
out = x + y
# 6) store — SRAM 에서 HBM 으로
tl.store(out_ptr + offs, out, mask=mask)
각 op 의 의미.
호출하는 자리.
n = x.numel()
grid = (triton.cdiv(n, BLOCK_SIZE),)
add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)
matmul 은 같은 패턴의 2D 버전 — pid_m, pid_n = tl.program_id(0), tl.program_id(1), offs_m = pid_m*BM + tl.arange(0, BM), offs_n = pid_n*BN + tl.arange(0, BN), 그 후 K 차원으로 inner loop. vector add 의 두 차원 확장이 matmul — Umer 의 notebook 이 그 확장 과정을 step by step 으로 보여줌.
Triton 의 가장 큰 학습 곡선 단축 도구. TRITON_INTERPRET=1 환경변수 또는 @triton.jit(interpret=True) 데코레이터를 켜면 — kernel 이 CPU 위에서 시뮬레이션되어 Python breakpoint() 가 그대로 동작한다. tl.load 의 결과 모양을 직접 print, step 마다 변수 검사. GPU kernel 디버깅의 의미가 한 단계 바뀐다.
import os
os.environ['TRITON_INTERPRET'] = '1'
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n,
BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offs < n
x = tl.load(x_ptr + offs, mask=mask)
breakpoint() # ← CPU 에서 멈춤
y = tl.load(y_ptr + offs, mask=mask)
tl.store(out_ptr + offs, x+y, mask=mask)
interpret 모드의 한계.
Umer 의 helper — triton_util.py 의 breakpoint_if 와 print_if. 특정 program_id 에서만 breakpoint. “pid_0=0, pid_1=1 일 때만 멈춰라”.
def breakpoint_if(conds, pid_0=[0], pid_1=[0],
pid_2=[0]):
from IPython.core.debugger import set_trace
if test_pid_conds(conds, pid_0, pid_1, pid_2):
set_trace()
Triton 의 launch 설정 3개 — BLOCK_SIZE (tile 크기), num_warps (한 program 안의 warp 수), num_stages (pipeline 깊이) — 이 같은 코드의 성능을 크게 흔든다. Umer 의 강의에서 직접 보여준 예: matmul 한 코드가 설정에 따라 1ms 와 10ms 사이에 흩어진다.
# autotune — 사용자가 후보 set 을 선언
@triton.autotune(
configs=[
triton.Config({'BLOCK_M':64, 'BLOCK_N':64,
'BLOCK_K':32},
num_warps=4, num_stages=3),
triton.Config({'BLOCK_M':128, 'BLOCK_N':128,
'BLOCK_K':32},
num_warps=8, num_stages=3),
# … 보통 5-30 개 candidate
],
key=['M', 'N', 'K'], # 입력 모양에 따라 cache
)
@triton.jit
def matmul_kernel(...): ...
autotune 의 동작.
key 별로 cache.함정 두 가지.
key 를 잘 정의해서 비슷한 shape 를 묶기.num_stages 는 — Triton 컴파일러가 inner loop 의 iteration 들을 software pipeline 으로 분리하는 깊이. 3 이면 3 iteration 이 동시에 “다른 stage” 에 있다.
matmul 의 inner loop (K 축 reduction) 한 iteration 의 일.
num_stages = 3 이면 — iteration n 의 load 가 iteration n−1 의 MMA 와, iteration n−2 의 다음 K tile pre-load 와 같은 시간에 도는 식. HBM latency 를 숨기는 핵심 메커니즘.
num_stages 를 늘리면 — shared memory 사용량이 증가한다. 각 stage 가 자기 K tile buffer 를 별도로 가져야 하니까. SRAM 한계에 부딪히면 occupancy 떨어짐. 3 이 보통 sweet spot, 2 로 떨어뜨려야 occupancy 가 잡히는 경우도 있음.
Triton 3.0+ 부터 Hopper 용 추가 hint — TMA 를 통한 비동기 copy, warpgroup-level MMA. 강의 시점 (2024 April) 이후 변화. 별도 추적 필요.
N = 1000, BLOCK_SIZE = 256 일 때 — 마지막 program (pid=3) 은 256 개를 처리하려는데 실제로 232 개만 남았다 (3*256=768, 1000-768=232). 그 24 개의 가짜 자리를 어떻게 처리하는가. 답은 mask.
mask = offs < N 으로 false 가 되어 load 시 0 이 들어가고 store 가 일어나지 않는다.matmul 같은 2D 코드에서는 두 차원 모두 mask. mask_m = offs_m < M, mask_n = offs_n < N 을 만들고 mask = mask_m[:, None] & mask_n[None, :] 의 broadcasting. 잊으면 silent corruption — 가장 자주 만나는 Triton 버그.
강의 곳곳에 Umer 가 “여기는 좀 이상하다” 라고 메모를 남긴 자리들. Triton 사용자가 빠지는 패턴.
경계 element 에 mask 없이 load 하면 SEGV 가 아니라 그냥 random 값이 나온다 (또는 다음 launch 가 죽는다). 항상 mask 를 명시.
BLOCK_SIZE 같은 매개변수에 : tl.constexpr 안 붙이면 — 매 호출마다 새 kernel 컴파일이 일어나서 cache miss. 성능 폭망.
tl.dot 은 fp16/bf16 input + fp32 accumulator 가 보통. fp32 input 은 throughput 이 한 자릿수 떨어진다 (Tensor Core 가 fp32 dense 를 잘 안 함).
if/else 안에서 tile 단위 op 가 잘 안 도는 경우가 있다. 차라리 mask 로 우회.
config 후보를 너무 많이 두면 첫 호출이 매우 느림. notebook 에서 cell 재실행하면 cache 가 날아감 — 매번 측정.
2.x → 3.0 사이 API 가 살짝 바뀐다 (특히 num_stages 의 의미). version pin 권장.
L001 에서 Mark 가 깐 의사결정 사다리 (torch → Triton → CUDA) 의 가운데 — Triton 자리. Umer 가 강의에서 자기 경험으로 정리한 trigger.
Triton 으로 멈출 신호
직접 CUDA 로 내려가는 신호
Umer 의 경험상 — 대부분의 ML kernel 은 Triton 으로 충분하다. CUDA 로 내려가야 하는 자리는 보통 (1) GEMM 자체 (그것도 CUTLASS 로 충분한 경우 다수), (2) FA 처럼 register accounting 이 critical 한 곳. 나머지는 Triton + autotune 으로 vendor 의 90% 까지 가능.
TRITON_INTERPRET=1 또는 @triton.jit(interpret=True). CPU 시뮬레이션 + Python breakpoint().key 를 잘 정의.mask 잊으면 silent corruption — Triton 가장 자주 만나는 버그.: tl.constexpr. 안 붙이면 매 호출마다 재컴파일.BLOCK_SIZE ∈ {256, 512, 1024} 로 측정. torch 의 단순 a+b 와 비교.offs <= n). interpret 모드에서 breakpoint() 로 마지막 tile 의 offs 를 직접 본다.TORCH_LOGS=output_code 로 torch.compile 이 같은 op 에 만든 Triton 과, 직접 짠 Triton 비교. 두 코드의 구조 차이..asm attribute 로 PTX 를 print. register 사용량과 instruction 수 직접 확인.이 노트의 § 05 timing (9 ms, 1.2 ms 등) 은 강의의 패턴 재구성. 자기 GPU + matmul 크기에서 직접 sweep 해봐야 본인의 sweet spot 을 안다. triton.testing.do_bench 로 정확한 측정.