파일 일곱 개를 탁자에 펼쳐놓고 한 줄씩 읽어본다. 같은 커널들이 CUDA 에서 Triton 으로 옮겨갈 때 코드가 어떻게 접히는지 — 어떤 디테일이 컴파일러 밑으로 숨고, 어떤 것이 여전히 네 손에 남는지.
program = 한 쓰레드가 아니라, 한 블록.CUDA 의 threadIdx.x 는 사라지지 않았다. 컴파일러 아래로 숨었을 뿐. Triton 은 block-level SPMD — 블록 안의 쓰레드 병렬성은 num_warps 만 보고 알아서 결정한다.
@triton.jit def vector_add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis=0) offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offsets < n x = tl.load(x_ptr + offsets, mask=mask) y = tl.load(y_ptr + offsets, mask=mask) tl.store(out_ptr + offsets, x + y, mask=mask)
int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < n) out[idx] = x[idx] + y[idx];
threadIdx.x 갖고 뭘 하던 건 Triton 에선 어디 갔어?tl.arange(0, 1024) 는 "이 블록이 처리할 1024개 인덱스 전체" 를 한 번에 가리키는 벡터. 각 lane 을 어느 쓰레드가 처리할진 Triton 이 결정해.67M 원소 reduce — CUDA v4 = 1.039 ms, Triton = 1.097 ms (5% 느림). Python → autotune 캐시 → JIT 캐시 → argument binding → cuLaunchKernel 까지 ~50–100 µs. 작은 커널에선 이 overhead 가 연산 시간보다 커질 수 있어. → element-wise 30 개를 각각 Triton 커널로 짜면 망한다.
autotune 은 각 config 를 순차 실행해 같은 output 버퍼에 쓴다. 이전 시도가 남긴 stale partial sum 이 결과에 섞임. 해법 — reset_to_zero=["partial_ptr"]. 문서에 희미하게 있고, 안 읽으면 몇 시간 디버깅.
BLOCK_SIZE 가 바뀌면 num_programs 도 바뀐다. 최소 블록 기준으로 최대 크기 partial 버퍼를 잡고, 실제 선택된 config 의 prefix 만 슬라이싱한다.
@triton.autotune(
configs=AUTOTUNE_CONFIGS,
key=["n_elements"],
reset_to_zero=["partial_ptr"], # ← 이거 중요
)
BLOCK_SIZE=1024 로 잡고, 마스크로 뒤쪽 24 개를 걸러. 여기 other=-float("inf") 가 핵심이야. OOB lane 이 -inf 면 tl.max 에 영향 없고, exp(-inf)=0 이라 sum 에도 기여 안 해. 마스크 로직이 데이터 값에 녹아드는 거.BLOCK_SIZE 네. N 을 직접 키로 안 잡은 게 영리한 거네.BLOCK_SIZE = _next_pow2(N) 이라 N=513~1024 가 모두 1024 로 bucket 돼. 캐시 효율적인 autotune 키 설계 — Triton 배울 때 제일 늦게 배우는 기술이야.offs = tl.arange(0, BLOCK_SIZE) # 0..1023 mask = offs < n_cols # 앞 1000 만 True x = tl.load(in_row + offs, mask=mask, other=-float("inf")) # OOB → -inf
출력 C 의 타일을 도는 순서를 바꾸는 것만으로 L2 재사용률이 크게 달라진다. row-major 로 돌면 B 의 열들이 캐시에서 쓸려나간다. 그룹 단위로 돌면 같은 B 열이 여러 번 재사용된다.
타일 0→1→2→… A 의 같은 행, B 의 다른 열. B 가 L2 에 안 들어가면 쓸어버림.
타일 0→1→2→3 이 B 의 같은 열 을 4 번 재사용. L2 효율 ↑.
blockIdx.x 가 그냥 하드웨어 순서대로. 수식을 커널 맨 앞에 손으로 풀어써야 함. ② 그 수식이 읽기가 정말 나빠. ③ GROUP_SIZE_M 이 바뀌면 재컴파일. Triton 에선 autotune 파라미터고 표준 관용구야.| variant | TFLOPS | note |
|---|---|---|
| 우리 CUDA v3 (FMA only) | 3.9 | register blocking |
| torch.matmul (cuBLAS + TF32) | 25.8 | NVIDIA 수년 튜닝 |
| Triton fp32 | 28.9 | cuBLAS + 12% |
| 우리 CUDA v4 (WMMA fp16) | 18.5 | 직접 짠 mma |
| cuBLAS fp16 | 51.8 | — |
| Triton fp16 | 54.0 | cuBLAS + 4% · 40 줄 |
acc = acc * alpha[:, None] + tl.dot(p.to(v.dtype), v)
CUDA 버전에선 이 로직이 30+ 줄에 걸쳐 있다. 추상화가 맞는 자리에 있으면 복잡도가 죽는다.
| impl | time (ms) | speedup |
|---|---|---|
| CUDA FA v1 (fp32) | 3.045 | 1.00× |
| Triton FA (fp16) | 0.496 | 6.14× |
① tl.dot 이 Tensor Core 씀 (CUDA v1 은 fp32 FMA). ② autotune 이 (BLOCK_M, BLOCK_N, num_warps, num_stages) 6 개 config 탐색 — CUDA 로 sweep 하려면 재컴파일 6 번. ③ tl.trans, 2-D 포인터 브로드캐스트, swizzled smem layout 모두 자동.
마스크로 -inf 채워도 QKᵀ 는 전체를 계산한다. FA-v2 의 실제 이득은 상삼각 전체에 들어가는 K 타일을 이터레이션 자체에서 뺀다는 데 있다.
def flash_attention_mha_fwd_kernel(..., IS_CAUSAL: tl.constexpr, # ← 핵심 BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): if IS_CAUSAL: end_n = tl.minimum(N, (pid_m+1) * BLOCK_M) else: end_n = N
→ IS_CAUSAL=True 용 커널과 =False 용 커널 별도로 컴파일. 런타임에 if 자체가 없음. C++ template specialization 과 동치.
| (B, H, N, d) | ours | SDPA | ratio |
|---|---|---|---|
| (1, 32, 2048, 128) | 0.784 | 0.613 | 0.78× |
| (1, 32, 4096, 128) | 2.964 | 2.559 | 0.86× |
| (16, 12, 512, 64) | 0.249 | 0.282 | 1.13× |
268 줄 파일 하나로 cuDNN 의 78–90%. 못 따라잡는 이유 — async copy, persistent kernel, warp specialization 이 아직 experimental.
@custom_op( "triton_training::flash_attention_mha", mutates_args=(), device_types="cuda", ) def flash_attention_mha_op(q, k, v, is_causal=False): return triton_flash_attention_mha(q, k, v, is_causal=is_causal) @flash_attention_mha_op.register_fake def _fake(q, k, v, is_causal=False): return torch.empty_like(q) # ← Dynamo 용 shape 선언
torch.compile 이 모델을 트레이싱할 때 FakeTensor (shape, dtype, device 만) 를 쓴다. 진짜 데이터가 없으니 우리 Triton 커널은 못 돌린다. 대신 "이 op 의 출력 shape 은 이것이다" 를 선언 → Dynamo 가 그래프를 안 끊는다. 없으면 fullgraph=True fail.
"고수준 DSL" 이 아니라 "추상화 높이가 딱 그 자리에 있는 언어". 아래 다섯 개는 컴파일러 밑으로 숨고, 위 다섯 개는 여전히 네 손에 남는다.
threadIdx.x 사라짐tl.sumtl.dot 이 dtype 보고 자동program_id, grid=lambda metatl.load(mask=...) 의 HBM 패턴tl.constexprTRITON_CACHE_DIR 밑 *.ptx.코드 참조 · triton_kernels/ 전체 7 파일 · L4 sm_89, CUDA 13.0, PyTorch 2.11, Triton 3.6.
Spread seven files on the table and read them line by line. When the same kernels move from CUDA to Triton — how does the code fold? What details disappear under the compiler, and what stays in your hands?
program = one block, not one thread.CUDA's threadIdx.x didn't vanish — it just slipped under the compiler. Triton is block-level SPMD: the thread parallelism inside a block is decided by the compiler from num_warps.
@triton.jit def vector_add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis=0) offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offsets < n x = tl.load(x_ptr + offsets, mask=mask) y = tl.load(y_ptr + offsets, mask=mask) tl.store(out_ptr + offsets, x + y, mask=mask)
int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < n) out[idx] = x[idx] + y[idx];
threadIdx.x in CUDA — where did it go in Triton?tl.arange(0, 1024) is a vector pointing to "all 1024 indices this block handles." Which thread handles which lane — Triton decides.Reducing 67M elements — CUDA v4 = 1.039 ms, Triton = 1.097 ms (5% slower). Python → autotune cache → JIT cache → argument binding → cuLaunchKernel eats ~50–100 µs. For tiny kernels that overhead can exceed the compute time. → Launch 30 element-wise ops as separate Triton kernels and you're done.
Autotune runs configs sequentially against the same output buffer. Leftover stale partial sums from previous attempts mix into the result. Fix — reset_to_zero=["partial_ptr"]. It's faintly mentioned in the docs; miss it and you debug for hours.
Change BLOCK_SIZE and num_programs changes too. Size the partial buffer for the maximum case, then slice it to the prefix that matches the chosen config.
@triton.autotune(
configs=AUTOTUNE_CONFIGS,
key=["n_elements"],
reset_to_zero=["partial_ptr"], # ← matters
)
BLOCK_SIZE=1024, mask out the last 24. The trick is other=-float("inf"). If OOB lanes hold -inf, tl.max is unaffected and exp(-inf)=0 so sum doesn't get contributions either. Mask logic melts into the data values.BLOCK_SIZE. Clever not to use N directly.BLOCK_SIZE = _next_pow2(N), N=513–1024 all bucket into 1024. Cache-friendly autotune key design — the last skill you pick up when learning Triton.offs = tl.arange(0, BLOCK_SIZE) # 0..1023 mask = offs < n_cols # only first 1000 True x = tl.load(in_row + offs, mask=mask, other=-float("inf")) # OOB → -inf
Change only the order in which output C tiles are visited, and L2 reuse changes dramatically. Row-major eviction sweeps B's columns. Group-wise traversal lets the same B columns be reused.
Tile 0→1→2→… same row of A, different column of B. If B doesn't fit in L2, it gets swept out.
Tiles 0→1→2→3 reuse the same B column four times. L2 efficiency ↑.
blockIdx.x is just linear hardware order. You have to write the math by hand at the top of the kernel. ② That math is painful to read. ③ Change GROUP_SIZE_M and you recompile. In Triton it's an autotune parameter and a standard idiom.| variant | TFLOPS | note |
|---|---|---|
| our CUDA v3 (FMA only) | 3.9 | register blocking |
| torch.matmul (cuBLAS + TF32) | 25.8 | years of NVIDIA tuning |
| Triton fp32 | 28.9 | cuBLAS + 12% |
| our CUDA v4 (WMMA fp16) | 18.5 | hand-written mma |
| cuBLAS fp16 | 51.8 | — |
| Triton fp16 | 54.0 | cuBLAS + 4% · 40 lines |
acc = acc * alpha[:, None] + tl.dot(p.to(v.dtype), v)
In CUDA this logic stretches over 30+ lines. Complexity collapses when the abstraction is at the right height.
| impl | time (ms) | speedup |
|---|---|---|
| CUDA FA v1 (fp32) | 3.045 | 1.00× |
| Triton FA (fp16) | 0.496 | 6.14× |
① tl.dot uses Tensor Cores (our CUDA v1 is fp32 FMA). ② Autotune sweeps 6 configs of (BLOCK_M, BLOCK_N, num_warps, num_stages) — sweeping that by hand in CUDA means 6 recompiles. ③ tl.trans, 2-D pointer broadcasts, swizzled smem layouts — all automatic.
Filling with -inf still computes the full QKᵀ. FA-v2's real win comes from pulling the upper-triangle K tiles out of the iteration itself.
def flash_attention_mha_fwd_kernel(..., IS_CAUSAL: tl.constexpr, # ← key BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): if IS_CAUSAL: end_n = tl.minimum(N, (pid_m+1) * BLOCK_M) else: end_n = N
→ Compile separate kernels for IS_CAUSAL=True and =False. The if doesn't exist at runtime. Equivalent to C++ template specialization.
| (B, H, N, d) | ours | SDPA | ratio |
|---|---|---|---|
| (1, 32, 2048, 128) | 0.784 | 0.613 | 0.78× |
| (1, 32, 4096, 128) | 2.964 | 2.559 | 0.86× |
| (16, 12, 512, 64) | 0.249 | 0.282 | 1.13× |
One 268-line file at 78–90% of cuDNN. What's missing — async copy, persistent kernel, warp specialization — all still experimental in Triton.
@custom_op( "triton_training::flash_attention_mha", mutates_args=(), device_types="cuda", ) def flash_attention_mha_op(q, k, v, is_causal=False): return triton_flash_attention_mha(q, k, v, is_causal=is_causal) @flash_attention_mha_op.register_fake def _fake(q, k, v, is_causal=False): return torch.empty_like(q) # ← shape decl for Dynamo
When torch.compile traces a model, it uses FakeTensors (shape, dtype, device — no data). Our Triton kernel can't run on those. Instead, we declare "this op's output shape is this" → Dynamo doesn't break the graph. Without it, fullgraph=True fails.
Not a "high-level DSL." A language that sits at exactly the right abstraction height. The five below go under the compiler; the five above stay in your hands.
threadIdx.x is gonetl.sumtl.dot picks by dtypeprogram_id, grid=lambda metatl.load(mask=...)tl.constexpr*.ptx under TRITON_CACHE_DIR.Code ref · triton_kernels/ all 7 files · L4 sm_89, CUDA 13.0, PyTorch 2.11, Triton 3.6.