LESSON 08 · 2026.04.19 · L4
Triton vs CUDA — 추상화의 비용은 어디서 나타나는가
레슨 1–6 의 네 커널 (reduction, softmax, matmul, flash attention) 을
Triton 40–130 줄짜리로 재작성. L4 (sm_89) 에서 측정한 네 구간.
GPU · L4 (sm_89)
stack · torch 2.6 + triton 3.2
ported · 4 kernels
하드웨어 업그레이드 — 왜 L4 인가
T4 는 FP16 WMMA 한 종류, TF32 없음. Triton tl.dot 가 TF32 를 자동 쓰는 걸 측정하려면 TF32 TC 가 있는 GPU 필요. L4 (Ada) = TF32 121 · FP16 242 · FP8 485 TFLOPS, L2 48 MB (T4 의 8×).
메모리 바운드 — 3 접근이 10% 안에서 동률
| task | CUDA | torch | Triton |
| Reduction 67M fp32 | 258 GB/s | 254 GB/s | 245 GB/s |
| Softmax 4096² fp32 | 237 GB/s | 240 GB/s | 221 GB/s |
Triton 이 손 CUDA 의 93–95%. HBM 이 bottleneck 이면 추상화 비용은 증발한다.
컴퓨트 바운드 — Triton 이 cuBLAS 를 미세하게 이김
| task | CUDA | torch | Triton |
| matmul 4096³ FP32 (TF32) | 3.9 TF | 25.8 | 28.9 |
| matmul 4096³ FP16 | 18.5 TF | 51.8 | 54.0 |
우리 WMMA 대비 Triton 2.9×. cuBLAS 대비 FP32 +12%, FP16 +4%. autotune 이 사람 손 튜닝을 근소하게 넘은 지점.
Flash Attention — Triton fp16 이 polyglot
| N | CUDA FA fp32 | Triton fp32 | Triton fp16 | SDPA fp16 |
| 1024 | 0.324 | 0.148 | 0.122 | 0.076 |
| 2048 | 0.638 | 0.196 | 0.138 | 0.076 |
| 4096 | 1.256 | 0.358 | 0.207 | 0.127 |
| 8192 | 3.045 | 1.118 | 0.496 | 0.394 |
N=8192: Triton fp16 이 우리 CUDA FP32 대비 6.14×, SDPA (cuDNN FA-2) 의 79%. 100 줄짜리 Triton 이 cuDNN 의 80% 에 도달. Tri Dao 가 FA-2 를 Triton 으로 쓴 이유.
한 줄 = 수십 줄 이 네 번 반복
tl.sum(x) = 레슨 3 의 warp shuffle boilerplate 15 줄
tl.dot(a,b) = 레슨 5 의 WMMA fragment + mma_sync 50 줄 (+ dtype 으로 TC 자동 선택)
- Grouped program-id swizzle 9 줄 = CUDA 에선 짜기 자체가 고역
- Online softmax 50 줄 (레슨 6) →
tl.max + tl.maximum + tl.exp + tl.sum 15 줄
추상화의 비용이 실존하는 구간
| 구간 | Triton vs CUDA | 원인 |
| 작은 N (< 4 MB) | 3–12× 뒤짐 | Launch floor 50–100 µs (Python → autotune 캐시 → JIT → cuLaunch) |
| HBM 바운드 | 95% | 거의 없음 |
| 큰 matmul/FA | 이김 | autotune 이 사람보다 나은 config 선택 |
실무: Transformer layer 가 ≥1 ms 면 100 µs 오버헤드 10% 미만 — 인내 가능. element-wise 30 개를 각각 Triton 런치하면 망함.
두 개의 footgun
(1) TF32 벤치 거짓말. torch.matmul(fp32) 는 기본 TF32 미사용. tl.dot 는 사용. 그대로 비교하면 "Triton 이 torch 를 2× 이김" 처럼 보임. 공정 비교엔:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
(2) Autotune stale write. @triton.autotune 이 trial 중 output 버퍼에 오염된 write. reset_to_zero=["partial_ptr"] + 호출 후 best_config.kwargs["BLOCK_SIZE"] 로 slicing.
CUDA 왜 계속 배우나
- Triton 이 막히는 순간 CUDA 로 떨어짐 (Blackwell mma, persistent kernel, async copy).
- Triton 이 낸 PTX 를 읽을 줄 알아야 perf bug 추적.
- vLLM, FA-3, Mamba 는 여전히 CUDA.
- "왜 느림" 의 답이 bank conflict, register spill, occupancy — CUDA 개념.
CUDA = 어셈블리, Triton = C. 대부분은 C 로 짜고, 핫패스만 어셈블리.
LESSON 08 · 2026.04.19 · L4
Triton vs CUDA — where does the cost of abstraction show up?
Four kernels from lessons 1–6 (reduction, softmax, matmul, flash attention) rewritten in 40–130 lines of Triton. Four regimes measured on L4 (sm_89).
GPU · L4 (sm_89)
stack · torch 2.6 + triton 3.2
ported · 4 kernels
Hardware upgrade — why L4
T4 has only one FP16 WMMA flavor and no TF32. To measure Triton tl.dot automatically picking TF32, you need a GPU with TF32 Tensor Cores. L4 (Ada) = TF32 121 · FP16 242 · FP8 485 TFLOPS, L2 48 MB (8× T4's).
Memory-bound — three approaches tie within 10%
| task | CUDA | torch | Triton |
| Reduction 67M fp32 | 258 GB/s | 254 GB/s | 245 GB/s |
| Softmax 4096² fp32 | 237 GB/s | 240 GB/s | 221 GB/s |
Triton lands at 93–95% of hand-written CUDA. When HBM is the bottleneck, the abstraction tax evaporates.
Compute-bound — Triton narrowly beats cuBLAS
| task | CUDA | torch | Triton |
| matmul 4096³ FP32 (TF32) | 3.9 TF | 25.8 | 28.9 |
| matmul 4096³ FP16 | 18.5 TF | 51.8 | 54.0 |
2.9× over our WMMA. Over cuBLAS: FP32 +12%, FP16 +4%. The spot where autotune narrowly edges human hand-tuning.
Flash Attention — Triton fp16 is polyglot
| N | CUDA FA fp32 | Triton fp32 | Triton fp16 | SDPA fp16 |
| 1024 | 0.324 | 0.148 | 0.122 | 0.076 |
| 2048 | 0.638 | 0.196 | 0.138 | 0.076 |
| 4096 | 1.256 | 0.358 | 0.207 | 0.127 |
| 8192 | 3.045 | 1.118 | 0.496 | 0.394 |
N=8192: Triton fp16 is 6.14× over our CUDA FP32, and 79% of SDPA (cuDNN FA-2). 100 lines of Triton reaches 80% of cuDNN. This is why Tri Dao wrote FA-2 in Triton.
One line = dozens of lines, repeated four times
tl.sum(x) = 15 lines of warp shuffle boilerplate from Lesson 3
tl.dot(a,b) = 50 lines of WMMA fragment + mma_sync from Lesson 5 (plus automatic TC selection by dtype)
- Grouped program-id swizzle — 9 lines here, painful to write in CUDA
- Online softmax 50 lines (Lesson 6) →
tl.max + tl.maximum + tl.exp + tl.sum, 15 lines
Regions where the abstraction cost is real
| region | Triton vs CUDA | cause |
| Small N (< 4 MB) | 3–12× slower | Launch floor 50–100 µs (Python → autotune cache → JIT → cuLaunch) |
| HBM-bound | 95% | almost none |
| Large matmul / FA | wins | autotune picks a better config than a human |
Practical line: if the Transformer layer is ≥1 ms, a 100 µs overhead is <10% — tolerable. If you launch 30 element-wise ops individually in Triton, you're cooked.
Two footguns
(1) TF32 benchmark lie. torch.matmul(fp32) doesn't use TF32 by default. tl.dot does. Compare them head-to-head and "Triton beats torch 2×" — which looks great but is wrong. For a fair fight:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
(2) Autotune stale writes. @triton.autotune trials can leave tainted writes in the output buffer. Use reset_to_zero=["partial_ptr"] and slice with best_config.kwargs["BLOCK_SIZE"] after the call.
Why keep learning CUDA
- Triton hits a wall and you fall back to CUDA (Blackwell mma, persistent kernels, async copy).
- You need to read the PTX Triton emits to chase perf bugs.
- vLLM, FA-3, Mamba are still CUDA.
- The answer to "why is this slow" is bank conflict, register spill, occupancy — all CUDA concepts.
CUDA = assembly, Triton = C. Write most of it in C; reserve assembly for hot paths.