cudatraining · 학습 기록
《The Stack》 EP.43 2026 · APR · 20 42 MIN IMAGINARY

Triton 이 숨기는 것,
노출하는 것

파일 일곱 개를 탁자에 펼쳐놓고 한 줄씩 읽어본다. 같은 커널들이 CUDA 에서 Triton 으로 옮겨갈 때 코드가 어떻게 접히는지 — 어떤 디테일이 컴파일러 밑으로 숨고, 어떤 것이 여전히 네 손에 남는지.

S
Host
샘 (Sam)
"그래서 왜 이게 편한 거지?"
J
Guest
젠슨 (Jensen)
속으론 반기지 않았던, 어쩔 수 없이 받아들인 물건
00 · 00:28Cold Open

파일 일곱 개. 한 줄씩.

SAM
지난 편에서 내가 "5 년 뒤에 매트릭스곱 말고 뭘 최적화할 거냐"고 물었는데 네가 다음 편으로 미뤘지.
JENSEN
(웃음) 오늘도 안 말해줄 거야. 대신 더 재밌는 거. 그 친구가 이번엔 같은 커널들을 Triton 으로 다시 짰어. 파일 일곱 개 열어놓고 한 줄씩 읽어보자고.
FIG 0 · 오늘의 자료 · triton_kernels/7 files
01 smoke_vector_add 02 reduction 03 softmax 04 matmul 05 flash_attention 06 flash_attention_mha 07 …_mha_op
01 · 01:20가장 작은 Triton 프로그램

program = 한 쓰레드가 아니라, 한 블록.

CUDA 의 threadIdx.x 는 사라지지 않았다. 컴파일러 아래로 숨었을 뿐. Triton 은 block-level SPMD — 블록 안의 쓰레드 병렬성은 num_warps 만 보고 알아서 결정한다.

smoke_vector_add.py · Tritonblock-level SPMD
@triton.jit
def vector_add_kernel(x_ptr, y_ptr, out_ptr, n,
                      BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)
같은 주소 계산 · CUDAthread-level
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) out[idx] = x[idx] + y[idx];
SAM
CUDA 에서 threadIdx.x 갖고 뭘 하던 건 Triton 에선 어디 갔어?
JENSEN
사라졌어. 정확히는 컴파일러가 숨겼어. tl.arange(0, 1024) 는 "이 블록이 처리할 1024개 인덱스 전체" 를 한 번에 가리키는 벡터. 각 lane 을 어느 쓰레드가 처리할진 Triton 이 결정해.
SAM
SIMD 같네.
JENSEN
정확히 그거야. "block-level SPMD." 각 쓰레드가 아니라 각 블록이 프로그램. 그 안의 쓰레드 병렬성은 컴파일러 몫.
두 패턴이 같은 PTX 로 내려가.
다만 Triton 코드는 "인덱스 공간" 수준에서 사고하게 해줘.— Jensen, 04:12
FIG 1 · 추상화의 높이where does threadIdx.x live?
PyTorch / JAX 텐서 연산 · "Python" Triton 블록 수준 · "C" CUDA 쓰레드 수준 · "어셈블리" PTX / SASS "기계어" Triton 은 C 의 자리 — 대부분은 여기서 쓰고, 핫패스만 어셈블리로
02 · 06:48tl.sum 한 줄이 warp shuffle 전체를 대체

근데 대가가 세 가지 있다.

FIG 2 · 같은 reduction — CUDA vs Triton~15 lines → 2 lines
// CUDA — warp shuffle reduction for (int o = 16; o > 0; o >>= 1) local += __shfl_down_sync(0xFFFFFFFFu, local, o); if (lane == 0) sdata[wid] = local; __syncthreads(); if (wid == 0) { val = (tid < nwarps) ? sdata[tid] : 0.f; for (int o = 16; o > 0; o >>= 1) val += __shfl_down_sync(0xFFFFFFFFu, val, o); if (tid == 0) atomicAdd(out, val); } ~15 lines # Triton partial = tl.sum(x, axis=0) tl.store(partial_ptr + pid, partial) 컴파일러가 내려보낸 SASS 에 동일한 shfl.sync.bfly 가 들어감 → TRITON_PRINT_PTX=1 로 검증 2 lines · 같은 PTX
SAM
그럼 진짜 공짜네?
JENSEN
(웃음) 아니야. 세 가지 대가.
① launch overhead

67M 원소 reduce — CUDA v4 = 1.039 ms, Triton = 1.097 ms (5% 느림). Python → autotune 캐시 → JIT 캐시 → argument binding → cuLaunchKernel 까지 ~50–100 µs. 작은 커널에선 이 overhead 가 연산 시간보다 커질 수 있어. → element-wise 30 개를 각각 Triton 커널로 짜면 망한다.

② autotune 의 footgun

autotune 은 각 config 를 순차 실행해 같은 output 버퍼에 쓴다. 이전 시도가 남긴 stale partial sum 이 결과에 섞임. 해법 — reset_to_zero=["partial_ptr"]. 문서에 희미하게 있고, 안 읽으면 몇 시간 디버깅.

③ autotune 과 2-pass reduce 의 관용구

BLOCK_SIZE 가 바뀌면 num_programs 도 바뀐다. 최소 블록 기준으로 최대 크기 partial 버퍼를 잡고, 실제 선택된 config 의 prefix 만 슬라이싱한다.

"Triton 은 고수준이지만 얇아."
추상화가 얕아서 내부 동작이 자꾸 새어나와. 이 누출을 잘 다루는 감각이 Triton 엔지니어의 가치.— Jensen, 12:04
reduction.py · footgun 처방reset_to_zero
@triton.autotune(
    configs=AUTOTUNE_CONFIGS,
    key=["n_elements"],
    reset_to_zero=["partial_ptr"],  # ← 이거 중요
)
03 · 13:02한 프로그램 = 한 행

마스크 로직이 데이터 값 에 자연스럽게 녹아든다.

SAM
그럼 N=1000 인 행은 어떻게 처리해?
JENSEN
BLOCK_SIZE=1024 로 잡고, 마스크로 뒤쪽 24 개를 걸러. 여기 other=-float("inf") 가 핵심이야. OOB lane 이 -inf 면 tl.max 에 영향 없고, exp(-inf)=0 이라 sum 에도 기여 안 해. 마스크 로직이 데이터 값에 녹아드는 거.
SAM
autotune 키가 BLOCK_SIZE 네. N 을 직접 키로 안 잡은 게 영리한 거네.
JENSEN
그래. BLOCK_SIZE = _next_pow2(N) 이라 N=513~1024 가 모두 1024 로 bucket 돼. 캐시 효율적인 autotune 키 설계 — Triton 배울 때 제일 늦게 배우는 기술이야.
FIG 3 · N=1000 행에서 OOB lane 다루기mask = data value
offs x max/sum 1000 valid · 0..999 24 OOB real data -inf max(x) · sum(exp(x-m)) OOB 의 -inf 는 max 에도, exp 합에도 영향 없음 → 분기 없이 청결
softmax.py · 핵심 세 줄
offs = tl.arange(0, BLOCK_SIZE)       # 0..1023
mask = offs < n_cols                  # 앞 1000 만 True
x = tl.load(in_row + offs, mask=mask,
            other=-float("inf"))       # OOB → -inf
04 · 18:41Triton 이 CUDA 를 이기는 한 가지

Grouped Program ID swizzling — CUDA 에서도 짤 수 있지만, 안 짠다.

출력 C 의 타일을 도는 순서를 바꾸는 것만으로 L2 재사용률이 크게 달라진다. row-major 로 돌면 B 의 열들이 캐시에서 쓸려나간다. 그룹 단위로 돌면 같은 B 열이 여러 번 재사용된다.

FIG 4A · row-major · B 열이 매번 교체naive
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31

타일 0→1→2→… A 의 같은 행, B 의 다른 열. B 가 L2 에 안 들어가면 쓸어버림.

FIG 4B · grouped · GROUP_SIZE_M=4swizzled
0
4
8
12
16
20
24
28
1
5
9
13
17
21
25
29
2
6
10
14
18
22
26
30
3
7
11
15
19
23
27
31

타일 0→1→2→3 이 B 의 같은 열 을 4 번 재사용. L2 효율 ↑.

SAM
이거 CUDA 에서도 짤 수 있잖아. 왜 CUDA 엔 없다고 해?
JENSEN
짤 수 있지, 근데 안 짜. 세 가지 이유 — ① CUDA 엔 blockIdx.x 가 그냥 하드웨어 순서대로. 수식을 커널 맨 앞에 손으로 풀어써야 함. ② 그 수식이 읽기가 정말 나빠. ③ GROUP_SIZE_M 이 바뀌면 재컴파일. Triton 에선 autotune 파라미터고 표준 관용구야.
FIG 4C · 측정치 — 4096³ matmulL4 sm_89
variantTFLOPSnote
우리 CUDA v3 (FMA only)3.9register blocking
torch.matmul (cuBLAS + TF32)25.8NVIDIA 수년 튜닝
Triton fp3228.9cuBLAS + 12%
우리 CUDA v4 (WMMA fp16)18.5직접 짠 mma
cuBLAS fp1651.8
Triton fp1654.0cuBLAS + 4% · 40 줄
20 년 묵은 cuBLAS 가 40 줄짜리 Python 에 진다.
autotune 이 사람보다 config 공간을 잘 탐색해. 측정이 이론을 이기는 전형적인 경우.— Jensen, 23:58
05 · 26:10Flash Attention 이 40 줄

절반. 그리고 6.1× 빨라.

FIG 5A · line count커널 본체만
CUDA FA v1 80 lines Triton FA 40 lines · 50% 같은 일 — 절반의 코드, 6.1× 빠름
acc update · 한 줄로 표현된 online + P@V
acc = acc * alpha[:, None] + tl.dot(p.to(v.dtype), v)

CUDA 버전에선 이 로직이 30+ 줄에 걸쳐 있다. 추상화가 맞는 자리에 있으면 복잡도가 죽는다.

FIG 5B · 성능 — N=8192, seq · single headL4
impltime (ms)speedup
CUDA FA v1 (fp32)3.0451.00×
Triton FA (fp16)0.4966.14×
6 배 차이는 어디서 왔나

tl.dot 이 Tensor Core 씀 (CUDA v1 은 fp32 FMA). autotune 이 (BLOCK_M, BLOCK_N, num_warps, num_stages) 6 개 config 탐색 — CUDA 로 sweep 하려면 재컴파일 6 번. tl.trans, 2-D 포인터 브로드캐스트, swizzled smem layout 모두 자동.

06 · 31:35constexpr 한 줄의 마법

causal mask 는 "채우는 것" 이 아니라 "루프에서 빼는 것."

마스크로 -inf 채워도 QKᵀ 는 전체를 계산한다. FA-v2 의 실제 이득은 상삼각 전체에 들어가는 K 타일을 이터레이션 자체에서 뺀다는 데 있다.

FIG 6A · causal 루프 · pid_m 별로 상한이 다름skipped tiles ↓ 절반
pid_m ↓ start_n → m=0 m=1 m=2 m=3 m=4 m=5 computed skipped · causal end_n = (pid_m+1) * BLOCK_M → 평균 N/2 타일만 돌면 됨 N=2048: non-causal · 2.643 ms causal · 0.784 ms (3.3×) 대각선 타일 하나만 per-element mask 필요 — 나머지 전부는 iter 에서 제외
constexpr · 커널 두 개로 JITno runtime branch
def flash_attention_mha_fwd_kernel(..., 
    IS_CAUSAL: tl.constexpr,     # ← 핵심
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr):
    if IS_CAUSAL:
        end_n = tl.minimum(N, (pid_m+1) * BLOCK_M)
    else:
        end_n = N

→ IS_CAUSAL=True 용 커널과 =False 용 커널 별도로 컴파일. 런타임에 if 자체가 없음. C++ template specialization 과 동치.

FIG 6B · 우리 Triton vs SDPA (cuDNN FA-2)LLaMA-7B shape
(B, H, N, d)oursSDPAratio
(1, 32, 2048, 128)0.7840.6130.78×
(1, 32, 4096, 128)2.9642.5590.86×
(16, 12, 512, 64)0.2490.2821.13×

268 줄 파일 하나로 cuDNN 의 78–90%. 못 따라잡는 이유 — async copy, persistent kernel, warp specialization 이 아직 experimental.

07 · 37:20torch.ops.* 로 올라가기

70 줄. 이게 Triton 커널을 부품 으로 만드는 마지막 접착제.

flash_attention_mha_op.py · 70 lines
@custom_op(
    "triton_training::flash_attention_mha",
    mutates_args=(),
    device_types="cuda",
)
def flash_attention_mha_op(q, k, v, is_causal=False):
    return triton_flash_attention_mha(q, k, v, is_causal=is_causal)

@flash_attention_mha_op.register_fake
def _fake(q, k, v, is_causal=False):
    return torch.empty_like(q)   # ← Dynamo 용 shape 선언
왜 register_fake 가 핵심인가

torch.compile 이 모델을 트레이싱할 때 FakeTensor (shape, dtype, device 만) 를 쓴다. 진짜 데이터가 없으니 우리 Triton 커널은 못 돌린다. 대신 "이 op 의 출력 shape 은 이것이다" 를 선언 → Dynamo 가 그래프를 안 끊는다. 없으면 fullgraph=True fail.

FIG 7 · 이 70 줄이 여는 것vLLM pattern
AttentionBlock · torch.compile(fullgraph=True) 그래프 브레이크 0 건 · err = 0.00e+00 torch.ops.triton_training.flash_attention_mha our Triton FA 268 lines · L4 Python + Triton vLLM PagedAttention 수백 lines · H100 C++ + CUDA 같은 패턴. 다른 구현. 같은 접점.
SAM
그러니까 Brian 이 지금 만든 게 거의 vLLM 스타일 프로덕션 op.
JENSEN
거의 그래. 빠진 건 backward (autograd) 랑 GQA 지원. 둘 다 설계가 명확해서 다음 레슨에서 추가 가능. 그리고 그 두 개 붙이면 vLLM PagedAttention 포팅 이 다음 목표.

Triton 이 숨기는 것 · 노출하는 것

"고수준 DSL" 이 아니라 "추상화 높이가 딱 그 자리에 있는 언어". 아래 다섯 개는 컴파일러 밑으로 숨고, 위 다섯 개는 여전히 네 손에 남는다.

숨기는 것 · HIDES컴파일러 아래
  1. 쓰레드 수준 병렬성 — threadIdx.x 사라짐
  2. warp shuffle, smem tree reduction — tl.sum
  3. Tensor Core 인스트럭션 선택 — tl.dot 이 dtype 보고 자동
  4. smem layout swizzle — bank conflict 자동 회피
  5. launch config 튜닝 — autotune 에 위임
노출하는 것 · EXPOSES여전히 네 손
  1. 블록 크기, grid 구조 — program_id, grid=lambda meta
  2. 메모리 계층 의식 — tl.load(mask=...) 의 HBM 패턴
  3. 컴파일 타임 vs 런타임 경계 — tl.constexpr
  4. autotune 키 설계 — 너무 넓으면 튜닝 폭발, 좁으면 놓침
  5. 수치적 동작 — online softmax, fp16 vs fp32 accumulator
CUDA 가 어셈블리, Triton 이 C, PyTorch 가 Python.
대부분은 C 로 짜고, 핫패스만 어셈블리로 내려간다.— Sam & Jensen, 41:02
CUDA 를 계속 배워야 하는 이유 · 네 가지
  1. Triton 이 막히는 순간 — 새 mma (Blackwell FP4), persistent kernel, async copy 정밀 제어 — 아직 CUTLASS/CUDA.
  2. Triton 이 생성한 PTX 를 읽을 줄 알아야 디버그 — TRITON_CACHE_DIR*.ptx.
  3. vLLM, FlashAttention-3, Mamba 커널이 아직 CUDA 기반. 이 코드 읽으려면 CUDA 가 모국어.
  4. "왜 이 Triton 이 느린가" 추적 — bank conflict, register spill, occupancy. 답은 CUDA 개념에.

코드 참조 · triton_kernels/ 전체 7 파일 · L4 sm_89, CUDA 13.0, PyTorch 2.11, Triton 3.6.

← Ep.42 너희 GPT는 왜 내 GPU를 이렇게 쓰게 됐을까 Back → Index · 11편 기록