gpumode · 강의 아카이브
《GPU Mode》 L014 2024 · APR · 06 High priority transcript · available

A Practitioner's Guide to Triton

Triton 은 Python 안의 DSL 인데 — “Python 같은 코드” 만으로 추측해서 짜면 거의 항상 느린 결과가 나온다. Triton 의 내부 모델 (program_id, tile 단위 op, autotune 의 의미) 을 정확히 알고 짜야 한다. Umer Adil 이 직접 만든 notebook 위에서 — 첫 vector add 부터 본격 matmul 까지 단계적으로 깐 실전 가이드. 디버깅 (interpret 모드), benchmarking, autotuning 의 함정, Triton vs CUDA 의사결정.

Triton DSL program_id · arange tl.load · tl.store autotune · BLOCK_SIZE interpret 모드 num_warps · num_stages masking Triton vs CUDA PTX dump
U
Speaker
Umer Adil (UmerHA)
Independent · Triton 학습 자료 제작 · GPU Mode community
강의 번호
L014
스피커
Umer Adil
학습 우선순위
High · 정독
자료
notebook · triton_util.py
§ 01강의가 풀려는 문제· why this lecture exists

“Python 같으니까 Python 처럼 짜면 되겠지” 의 착각

Triton 의 syntax 는 numpy 와 비슷해 보인다 — tl.load, tl.store, +, *. 그래서 처음 시작하는 사람은 “이건 그냥 Python 으로 GPU 코드 짜는 거구나” 라고 생각하기 쉽다. 그러나 Triton 의 내부 모델은 — element 단위가 아니라 tile 단위이고, 한 program (kernel instance) 이 한 tile 의 출력을 만든다. 이 mental model 이 안 잡히면 디버깅이 안 되고 성능도 안 나온다.

Umer 의 강의가 풀려는 질문 셋.

  1. Triton 의 tile/program 모델을 어떻게 손에 잡힐 정도로 익히는가 — 실제 vector add 부터 matmul 까지 코드 위에서.
  2. Triton 의 디버깅이 GPU 코드 디버깅과 무엇이 다른가 — interpret 모드, breakpoint, print.
  3. 같은 코드가 BLOCK_SIZE/num_warps/num_stages 로 10× 차이 나는데, 어떻게 sweep 하나 — autotune 의 의미와 함정.
강의의 인지적 frame

L001 (Mark) 이 “Triton 을 처음 만나는 자리” 였다면 — 이 강의는 “Triton 으로 일을 시작한 사람이 빠지는 함정 카탈로그”. 코드는 짧지만 패턴이 깊다. notebook 자체가 강의 자료이자 실습 — Umer 가 일부러 bug 를 박아두기도 한다.

“Triton 의 첫 커널은 1 시간 안에 짤 수 있다. 그 커널을 빠르게 만드는 데는 그 후로 몇 주가 걸린다.”Umer Adil · 강의 paraphrase
§ 02tile-level 추상화의 가치· SIMT 와의 차이

thread 가 아니라 tile 을 짠다 — 한 단계 위의 추상

CUDA 가 thread 단위로 “나는 i 번째 thread 다, i 번째 element 를 처리한다” 의 SIMT 모델이라면 — Triton 은 tile 단위로 “나는 j 번째 program 이고, j 번째 tile 의 출력을 만든다”의 모델. 그 사이에서 컴파일러가 tile → thread 매핑을 자동으로 결정한다.

CUDA — SIMTthread 가 한 element

  • 32 thread 가 warp 으로 묶여 같은 instruction 을 도는데, 사용자가 thread 별 코드를 짠다.
  • shared memory layout, bank conflict, coalesced load 모두 사용자 책임.
  • 장점: 모든 hardware 디테일을 control 가능.
  • 단점: 같은 알고리즘을 여러 hardware 에 옮길 때마다 재튜닝.

Triton — tile-levelprogram 이 한 tile

  • 한 program 은 한 tile (BLOCK_SIZE 크기) 의 출력을 만드는 일감.
  • tile 안의 element-thread 매핑은 컴파일러가 결정 (layout pass).
  • shared memory, swizzling, coalesced load 도 컴파일러가 자동 — 보통 80-90% 성능까지.
  • 단점: 가장 안쪽 layout 을 사용자가 직접 지정 못 함 — bank conflict 미세조정 어려움.
tile 단위가 가져오는 것

tile 추상이 좋은 이유 — “tensor cores 위에서 도는 모양” 과 자연스럽게 맞는다. MMA instruction 은 16×16 같은 tile 위에서 도는데, Triton 의 BLOCK_SIZE 가 그 단위를 직접 표현. CUDA 처럼 thread 를 16×16 으로 “손으로 묶을” 필요가 없다.

§ 03program_id · arange · load · store· 최소 4개 op

vector add 한 페이지 — Triton 의 mental model 이 다 들어 있다

Umer 의 notebook 첫 코드. 단순한 vector add 안에 Triton 의 4개 핵심 op 가 모두 등장한다. 이 4개를 정확히 이해하면 다른 모든 Triton 코드가 같은 패턴의 변형이라는 게 보인다.

import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n,
               BLOCK_SIZE: tl.constexpr):
    # 1) program_id — 나는 몇 번째 program?
    pid    = tl.program_id(0)
    # 2) arange — 내가 다룰 tile 의 element index
    offs   = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # 3) mask — 마지막 tile 의 tail 처리
    mask   = offs < n
    # 4) load — HBM 에서 SRAM 으로
    x      = tl.load(x_ptr + offs, mask=mask)
    y      = tl.load(y_ptr + offs, mask=mask)
    # 5) compute — tile 단위 numpy-like
    out    = x + y
    # 6) store — SRAM 에서 HBM 으로
    tl.store(out_ptr + offs, out, mask=mask)

각 op 의 의미.

  • program_id(0) — 1D grid 위의 program index. 0, 1, 2, … N-1 까지 N 개 program 이 launch 되고 각자 자기 pid 만 본다. CUDA 의 blockIdx 에 가장 가까움.
  • tl.arange(0, BLOCK_SIZE) — 0, 1, …, BLOCK_SIZE-1 의 vector 를 만든다. 이게 Triton 의 가장 핵심 — “tile 의 모든 element 에 대한 인덱스 vector”.
  • tl.load(ptr + offs, mask=mask) — pointer + offset vector 위치에서 BLOCK_SIZE 개 element 를 한 번에 read. mask 가 false 인 자리는 default 값 (0).
  • tl.store(ptr + offs, val, mask=mask) — 반대로 BLOCK_SIZE 개 write.

호출하는 자리.

n = x.numel()
grid = (triton.cdiv(n, BLOCK_SIZE),)
add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)
2D / matmul 로 확장

matmul 은 같은 패턴의 2D 버전 — pid_m, pid_n = tl.program_id(0), tl.program_id(1), offs_m = pid_m*BM + tl.arange(0, BM), offs_n = pid_n*BN + tl.arange(0, BN), 그 후 K 차원으로 inner loop. vector add 의 두 차원 확장이 matmul — Umer 의 notebook 이 그 확장 과정을 step by step 으로 보여줌.

§ 04interpret 모드로 디버깅· CPU 시뮬레이션

GPU 커널 안에서 breakpoint() 를 거는 자리

Triton 의 가장 큰 학습 곡선 단축 도구. TRITON_INTERPRET=1 환경변수 또는 @triton.jit(interpret=True) 데코레이터를 켜면 — kernel 이 CPU 위에서 시뮬레이션되어 Python breakpoint() 가 그대로 동작한다. tl.load 의 결과 모양을 직접 print, step 마다 변수 검사. GPU kernel 디버깅의 의미가 한 단계 바뀐다.

import os
os.environ['TRITON_INTERPRET'] = '1'

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n,
               BLOCK_SIZE: tl.constexpr):
    pid    = tl.program_id(0)
    offs   = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask   = offs < n
    x      = tl.load(x_ptr + offs, mask=mask)
    breakpoint()                      # ← CPU 에서 멈춤
    y      = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x+y, mask=mask)

interpret 모드의 한계.

  • 실행 속도: 매우 느림. 큰 N 으로 돌리면 안 됨.
  • 일부 op 가 시뮬레이션이 정확하지 않을 수 있음 — async copy, atomic 같은 부분.
  • numerical 결과가 GPU 와 약간 다른 경우가 드물게 있음 (rounding 모드).

Umer 의 helper — triton_util.pybreakpoint_ifprint_if. 특정 program_id 에서만 breakpoint. “pid_0=0, pid_1=1 일 때만 멈춰라”.

def breakpoint_if(conds, pid_0=[0], pid_1=[0],
                  pid_2=[0]):
    from IPython.core.debugger import set_trace
    if test_pid_conds(conds, pid_0, pid_1, pid_2):
        set_trace()
“GPU 커널 디버깅 시 변수를 일부러 글로벌 메모리에 적던 시절을 끝낸다 — interpret 모드는 GPU 커널 디버깅의 의미를 한 단계 바꾼다.”Andreas Köpf · 32:18 (L001 의 인용)
§ 05autotune 과 BLOCK_SIZE 사다리· 3 변수 sweep

같은 source 가 BLOCK_SIZE 로 10× 빠르거나 느리다 — autotune 이 그걸 자동으로

Triton 의 launch 설정 3개 — BLOCK_SIZE (tile 크기), num_warps (한 program 안의 warp 수), num_stages (pipeline 깊이) — 이 같은 코드의 성능을 크게 흔든다. Umer 의 강의에서 직접 보여준 예: matmul 한 코드가 설정에 따라 1ms 와 10ms 사이에 흩어진다.

FIG · BLOCK_SIZE sweep — 같은 코드, 다른 시간matmul 4096×4096
BLOCK=32, warps=2
9.0 ms
slow
BLOCK=64, warps=4
5.5 ms
BLOCK=128, warps=4
3.0 ms
BLOCK=128, warps=8, stages=3
1.2 ms
best
BLOCK=256, warps=8
1.8 ms
register pressure
최적값 주변에서 작은 차이. 너무 작으면 SM 이 비어 (occupancy 부족), 너무 크면 register spill / shared memory 한계. 값은 강의의 패턴 재구성.
# autotune — 사용자가 후보 set 을 선언
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M':64, 'BLOCK_N':64,
                       'BLOCK_K':32},
                      num_warps=4, num_stages=3),
        triton.Config({'BLOCK_M':128, 'BLOCK_N':128,
                       'BLOCK_K':32},
                      num_warps=8, num_stages=3),
        # … 보통 5-30 개 candidate
    ],
    key=['M', 'N', 'K'],   # 입력 모양에 따라 cache
)
@triton.jit
def matmul_kernel(...): ...

autotune 의 동작.

  1. 처음 호출 시 모든 config 를 직접 실행해 측정.
  2. 가장 빠른 config 를 key 별로 cache.
  3. 이후 같은 key 의 호출은 cached config 사용.

함정 두 가지.

  • 첫 호출의 wall time 이 매우 길다 — 30 config × 5ms = 150 ms 시 측정에 들어간다. test 환경에서 이상하게 느려보임.
  • shape 가 매번 다르면 cache 가 안 됨 — dynamic shape 모델에서는 autotune 의 효과가 거의 없거나 매번 재측정. key 를 잘 정의해서 비슷한 shape 를 묶기.
§ 06pipeline 깊이의 trade-off· num_stages

num_stages = 3 의 의미 — 동시에 도는 작업이 3개

num_stages 는 — Triton 컴파일러가 inner loop 의 iteration 들을 software pipeline 으로 분리하는 깊이. 3 이면 3 iteration 이 동시에 “다른 stage” 에 있다.

matmul 의 inner loop (K 축 reduction) 한 iteration 의 일.

  1. K tile load (HBM → SRAM)
  2. MMA 누산 (SRAM → register)
  3. 다음 iteration 준비

num_stages = 3 이면 — iteration n 의 load 가 iteration n−1 의 MMA 와, iteration n−2 의 다음 K tile pre-load 와 같은 시간에 도는 식. HBM latency 를 숨기는 핵심 메커니즘.

trade-off

num_stages 를 늘리면 — shared memory 사용량이 증가한다. 각 stage 가 자기 K tile buffer 를 별도로 가져야 하니까. SRAM 한계에 부딪히면 occupancy 떨어짐. 3 이 보통 sweet spot, 2 로 떨어뜨려야 occupancy 가 잡히는 경우도 있음.

Hopper 의 새 변수

Triton 3.0+ 부터 Hopper 용 추가 hint — TMA 를 통한 비동기 copy, warpgroup-level MMA. 강의 시점 (2024 April) 이후 변화. 별도 추적 필요.

§ 07masking 패턴· tail · padding

마지막 tile 이 BLOCK_SIZE 의 배수가 아닐 때 — mask 가 정답

N = 1000, BLOCK_SIZE = 256 일 때 — 마지막 program (pid=3) 은 256 개를 처리하려는데 실제로 232 개만 남았다 (3*256=768, 1000-768=232). 그 24 개의 가짜 자리를 어떻게 처리하는가. 답은 mask.

FIG · masking — 마지막 tile 의 tail 처리N=1000, BLOCK_SIZE=256
0
1
2
3
253
254
255
256
511
512
767
768
999
mask
mask
mask
mask
mask
파란 자리는 실제 데이터. 회색 자리는 N 을 넘는 fake offsets — mask = offs < N 으로 false 가 되어 load 시 0 이 들어가고 store 가 일어나지 않는다.
2D mask

matmul 같은 2D 코드에서는 두 차원 모두 mask. mask_m = offs_m < M, mask_n = offs_n < N 을 만들고 mask = mask_m[:, None] & mask_n[None, :] 의 broadcasting. 잊으면 silent corruption — 가장 자주 만나는 Triton 버그.

§ 08흔한 함정과 rough edges· Umer 의 메모

“Triton 은 아직 새 프로젝트” — Umer 가 직접 만난 함정 카탈로그

강의 곳곳에 Umer 가 “여기는 좀 이상하다” 라고 메모를 남긴 자리들. Triton 사용자가 빠지는 패턴.

silent OOB without mask

경계 element 에 mask 없이 load 하면 SEGV 가 아니라 그냥 random 값이 나온다 (또는 다음 launch 가 죽는다). 항상 mask 를 명시.

tl.constexpr 안 붙이면 autotune 망함

BLOCK_SIZE 같은 매개변수에 : tl.constexpr 안 붙이면 — 매 호출마다 새 kernel 컴파일이 일어나서 cache miss. 성능 폭망.

tl.dot 의 dtype 제약

tl.dot 은 fp16/bf16 input + fp32 accumulator 가 보통. fp32 input 은 throughput 이 한 자릿수 떨어진다 (Tensor Core 가 fp32 dense 를 잘 안 함).

일부 control flow 가 안 됨

if/else 안에서 tile 단위 op 가 잘 안 도는 경우가 있다. 차라리 mask 로 우회.

autotune 의 재컴파일

config 후보를 너무 많이 두면 첫 호출이 매우 느림. notebook 에서 cell 재실행하면 cache 가 날아감 — 매번 측정.

Triton 버전 사이의 변경

2.x → 3.0 사이 API 가 살짝 바뀐다 (특히 num_stages 의 의미). version pin 권장.

“Triton 은 빠르게 진화 중인 프로젝트다 — rough edges 가 있다. encountered 하면 노트해두는 게 다음 사람의 시간을 아끼는 길.”Umer Adil · 강의 paraphrase
§ 09Triton vs CUDA 의사결정· 언제 더 내려갈 것인가

Triton 으로 충분한 자리, 직접 CUDA 로 내려가야 하는 자리

L001 에서 Mark 가 깐 의사결정 사다리 (torch → Triton → CUDA) 의 가운데 — Triton 자리. Umer 가 강의에서 자기 경험으로 정리한 trigger.

Triton 으로 멈출 신호

  • NCU 가 “tail effect / occupancy / launch shape” 만 hint 로 줌. 이건 BLOCK_SIZE/num_warps 조정으로 잡힘.
  • memory-bound op (elementwise, normalize, cross-entropy 등). HBM 트래픽이 거의 peak. compute 의 디테일이 무의미.
  • fast iteration 이 중요 — 코드 변경 후 다시 컴파일이 1초.

직접 CUDA 로 내려가는 신호

  • NCU 가 “shared memory bank conflict”, “uncoalesced reads”, “register spill” hint. Triton 의 자동 layout 으로는 못 잡는 자리.
  • vendor library (cuBLAS, cuDNN) 의 80-90% 까지 가야 하는 GEMM/conv. CUTLASS template 위로.
  • warp-level primitive 가 필요 (warp shuffle, ballot 등). Triton 은 표현 못 함.
  • 같은 cycle 에 여러 thread 가 같은 SRAM bank 를 치는 자리를 손으로 swizzle.
실용적 분포

Umer 의 경험상 — 대부분의 ML kernel 은 Triton 으로 충분하다. CUDA 로 내려가야 하는 자리는 보통 (1) GEMM 자체 (그것도 CUTLASS 로 충분한 경우 다수), (2) FA 처럼 register accounting 이 critical 한 곳. 나머지는 Triton + autotune 으로 vendor 의 90% 까지 가능.

§ 10기억할 메모와 코드· key takeaways · repo
tile-level 모델
한 program 이 한 tile 의 출력. tile 안의 thread 매핑은 컴파일러가. CUDA 의 SIMT 와 한 단계 다름.
4 개 핵심 op
program_id, arange, tl.load(mask=), tl.store(mask=). 모든 Triton 코드는 이 4개의 변형.
interpret 모드
TRITON_INTERPRET=1 또는 @triton.jit(interpret=True). CPU 시뮬레이션 + Python breakpoint().
3 launch 변수
BLOCK_SIZE (tile 크기), num_warps (병렬도), num_stages (pipeline 깊이). 같은 코드 10× 차이.
autotune 의 함정
첫 호출 매우 느림 (모든 config 측정). dynamic shape 에서는 cache 안 됨. key 를 잘 정의.
masking 의무
tail 처리 + 2D 경계. mask 잊으면 silent corruption — Triton 가장 자주 만나는 버그.
tl.constexpr
launch 변수에 항상 : tl.constexpr. 안 붙이면 매 호출마다 재컴파일.
의사결정 trigger
NCU 가 bank conflict / register spill 하면 직접 CUDA. tail/occupancy 면 Triton 으로 충분.
Slides 별도 슬라이드 없음 — notebook 자체가 강의 자료

손에 새기기 — 실습 시퀀스

  1. vector add 첫 커널 — § 03 의 코드를 그대로. BLOCK_SIZE ∈ {256, 512, 1024} 로 측정. torch 의 단순 a+b 와 비교.
  2. interpret 모드 디버깅 — 일부러 mask 를 잘못 짠 add kernel 을 만든다 (예: offs <= n). interpret 모드에서 breakpoint() 로 마지막 tile 의 offs 를 직접 본다.
  3. matmul 한 페이지 — Umer notebook 의 matmul 섹션 그대로. naive Triton matmul → torch.matmul 과 정확도 + 시간 비교.
  4. autotune 적용 — 같은 matmul 에 5-10 개 config 의 autotune. 첫 호출 시간과 두 번째 호출 시간 비교. cache 동작 확인.
  5. num_stages sweep — 1, 2, 3, 4 로 같은 matmul 측정. shared memory 사용량과 throughput 의 trade-off 시각화.
  6. output_code 비교TORCH_LOGS=output_codetorch.compile 이 같은 op 에 만든 Triton 과, 직접 짠 Triton 비교. 두 코드의 구조 차이.
  7. PTX dump — 한 kernel 의 .asm attribute 로 PTX 를 print. register 사용량과 instruction 수 직접 확인.
  8. silent corruption 직접 만들기 — mask 를 일부러 빼고 결과가 어디서부터 다른지 검증. 이 함정의 형태를 손에 새긴다.
§ 12열린 질문· open questions
  • Triton 3.0+ 의 새 기능 — Hopper 용 TMA / WGMMA / async barrier hint. 강의 시점 (2024 April) 이후. Triton 공식 docs 직접 확인 필요.
  • autotune 의 더 영리한 strategy — Bayesian 또는 gradient-based search. naive 모든-config-측정 외 alternatives.
  • interpret 모드의 한계 정확한 목록 — 어떤 op 가 시뮬레이션 안 되는지. Triton issue tracker 참조 필요.
  • sm_120 (Blackwell) 위에서의 변화 — 강의 시점 이후 새 hardware. Triton 이 어떻게 따라잡고 있는가.
  • Triton 의 layout pass 디버깅 — 직접 layout 을 보고 싶을 때의 도구. 보통 컴파일러 내부 — 정확한 진입점 확인 필요.
  • tl.dot 의 fp32 input 지원 — TF32 모드의 영향. 정확도 vs 속도의 trade-off 측정 필요.
  • cross-platform — AMD ROCm 의 Triton — 같은 코드가 ROCm 위에서 도는가. 실제 sweet spot 의 hardware 별 차이.
검증 메모

이 노트의 § 05 timing (9 ms, 1.2 ms 등) 은 강의의 패턴 재구성. 자기 GPU + matmul 크기에서 직접 sweep 해봐야 본인의 sweet spot 을 안다. triton.testing.do_bench 로 정확한 측정.

← Lecture 013 Ring Attention — distributed 의 자리에서 single GPU Triton 으로 다시 Lecture 015 → CUTLASS — Triton 보다 한 단계 아래, CUDA template + CuTe layout