ESSAY 10 · 2026.04.19 · L4
50 줄의 Python 이 5000 줄의 CUDA 를 대체할 수 있는가
추상화의 비용은 어디서 나타나는가. L4 에서 네 개의 커널로 받은 네 구간의 답.
질문 · 추상화의 ROI
대상 · Triton / CUDA / cuBLAS·cuDNN
길이 · essay
물음
Triton 은 "CUDA 를 안 쓰고도 CUDA 수준의 성능" 이라 선전한다. 이건 마케팅에 가깝게 들리지만, Tri Dao 는 실제로 Flash Attention-2 를 Triton 으로 썼다. OpenAI 의 Kernel Gym 도 Triton 이다. 그 말은 참인가?
이 질문을 숫자로만 풀기로 했다. 레슨 1–6 에서 쓴 네 개의 커널 — reduction, softmax, matmul, flash attention — 을 Triton 으로 재작성. 같은 L4 (sm_89) 에서 CUDA / torch 내장 / Triton 세 버전을 나란히 돌린다. 질문은 한 가지다. 어느 구간에서 추상화가 비싸고, 어느 구간에서 공짜인가.
네 구간
| 구간 | Triton vs CUDA | 원인 |
| 작은 N (< 4 MB) | 3–12× 뒤짐 | Python → autotune → JIT → cuLaunch 의 런치 floor ~50–100 µs |
| HBM 바운드 중간 | 95% | 거의 없음 — HBM 이 bottleneck 이면 컴파일러가 사람을 크게 못 넘음 |
| HBM 바운드 큰 | 동률 | 없음 |
| 컴퓨트 바운드 큰 matmul/FA | 이김 | autotune 이 사람 손 튜닝을 근소하게 넘음 |
증거 1 · HBM 바운드는 전부 동률
| task | CUDA | torch | Triton |
| Reduction 67M | 258 GB/s · 86% | 254 · 85% | 245 · 82% |
| Softmax 4096² | 237 GB/s | 240 | 221 |
HBM 바닥 (300 GB/s) 의 82–86%. 세 접근의 격차는 노이즈 수준이다. 이 구간에선 "누가 더 빨리 HBM 을 빨아들이는가" 가 전부라 JIT 가 손 CUDA 를 못 따라가는 이유가 없다.
증거 2 · 컴퓨트 바운드에선 Triton 이 cuBLAS 를 이긴다
| matmul 4096³ | CUDA | cuBLAS | Triton |
| FP32 (TF32 TC) | 3.9 TF (우리 v3) | 25.8 | 28.9 (+12%) |
| FP16 | 18.5 TF (우리 WMMA) | 51.8 | 54.0 (+4%) |
우리 WMMA 대비 2.9×. 더 중요한 건 cuBLAS 조차 근소하게 넘음. 이유는 autotune config 가 사람보다 더 많은 점을 본다는 것. 그리고 TF32/FP16 선택이 tl.dot 의 dtype 으로 자동 결정.
증거 3 · Flash Attention (N=8192)
Triton fp16 vs 우리 CUDA FA fp326.14× 빠름
Triton fp32 vs SDPA fp322.35× 빠름 (torch fp32 가 L4 에서 FA 안 탐)
Triton fp16 vs SDPA fp16 (cuDNN FA-2)0.79× (우리가 25% 느림)
100 줄 Triton 이 cuDNN FA-2 의 79%. Tri Dao 가 FA-2 를 Triton 으로 쓴 이유. 이 숫자를 보면 Triton 은 더 이상 실험이 아니다.
한 줄 = 수십 줄 이 네 번 반복
tl.sum(x, axis=0) = warp shuffle boilerplate 15 줄
tl.dot(a, b) = WMMA fragment + load_matrix_sync + mma_sync 50 줄 (dtype 으로 TC 자동)
- Grouped program-id swizzle 9 줄 = CUDA 에선 짜기 자체가 고역
- Online softmax 50 줄 (레슨 6) →
tl.max + tl.maximum + tl.exp + tl.sum 15 줄
비용이 실존하는 구간
레슨 3 의 reduction 에서 n=2²⁰ (4 MB) 일 때 Triton 이 CUDA 의 3–12 배 느렸다. 원인은 하나, launch floor 50–100 µs. Python 인터프리터 → autotune 캐시 조회 → JIT (처음 한 번) → cuLaunchKernel. 이 floor 는 커널이 짧을수록 상대 비중이 커진다.
Transformer layer 하나가 ≥1 ms 면 100 µs 는 10% 미만 — 인내 가능. 하지만 element-wise 30 개를 각각 Triton 런치하면 망한다. 작은 연산은 PyTorch eager 나 torch.compile 이 낫다.
두 개의 footgun
(1) TF32 벤치 거짓말
torch.matmul(fp32) 는 기본 TF32 미사용. tl.dot 는 사용. 그대로 비교하면 "Triton 이 torch 를 2× 이김" 처럼 보임. 공정 비교:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
이 두 줄을 세팅한 뒤에야 두 쪽 다 TF32 TC 경로. 세팅 전 Phase 3 결과는 2× 거짓이었음.
(2) Autotune stale write
@triton.autotune 이 trial 중 partial buffer 를 다른 config 로 오염. reduction 의 부분합이 섞여 rel_err = 1.75 로 터짐. 해결:
@triton.autotune(configs=..., reset_to_zero=["partial_ptr"])
@triton.jit
def kernel(...): ...
best = kernel.best_config
block = best.kwargs["BLOCK_SIZE"]
return partial[:cdiv(n, block)].sum()
그래서 왜 여전히 CUDA
- Triton 이 막힘 — Blackwell 의 BF8/FP4 mma, persistent kernel, async copy 세부 제어.
- PTX 독해 — Triton 이 낸 PTX 를 읽어야 perf bug 추적이 된다.
TRITON_CACHE_DIR 에 *.ptx, *.cubin.
- 레퍼런스 코드 — vLLM, FA-3, Mamba 커널은 아직 CUDA. 그 코드 읽으려면 CUDA 가 모국어여야.
- 성능 진단 어휘 — "왜 느림?" 의 답은 bank conflict, register spill, occupancy. 이 개념은 Triton 에서도 그대로 통용.
비유
CUDA = 어셈블리, Triton = C. 대부분은 C 로 짜고, 핫패스만 어셈블리로 남긴다. 그리고 C 컴파일러 버그가 생기면 어셈블리를 읽어야 한다.
ESSAY 10 · 2026.04.19 · L4
Can 50 lines of Python replace 5000 lines of CUDA?
Where does the cost of abstraction show up? Four kernels on L4, four regimes with the answer.
question · the ROI of abstraction
subjects · Triton / CUDA / cuBLAS·cuDNN
length · essay
The question
Triton's pitch is "CUDA-level performance without CUDA." Sounds like marketing — but Tri Dao actually wrote Flash Attention-2 in Triton. OpenAI's Kernel Gym is Triton. Is the claim true?
I decided to answer only with numbers. The four kernels from Lessons 1–6 — reduction, softmax, matmul, flash attention — rewritten in Triton. Run CUDA / torch built-in / Triton side by side on the same L4 (sm_89). One question: in which regime is abstraction expensive, and in which is it free?
Four regimes
| regime | Triton vs CUDA | cause |
| Small N (< 4 MB) | 3–12× slower | Python → autotune → JIT → cuLaunch; launch floor ~50–100 µs |
| HBM-bound medium | 95% | nearly none — if HBM is the bottleneck, a compiler can't overshoot a human by much |
| HBM-bound large | tie | none |
| Compute-bound large matmul / FA | wins | autotune narrowly edges out human tuning |
Evidence 1 · HBM-bound is a tie
| task | CUDA | torch | Triton |
| Reduction 67M | 258 GB/s · 86% | 254 · 85% | 245 · 82% |
| Softmax 4096² | 237 GB/s | 240 | 221 |
82–86% of HBM floor (300 GB/s). The spread among the three approaches is noise. In this regime, everything is "who can drink HBM fastest," and there's no reason a JIT would fall behind hand-written CUDA.
Evidence 2 · Triton edges out cuBLAS in compute-bound
| matmul 4096³ | CUDA | cuBLAS | Triton |
| FP32 (TF32 TC) | 3.9 TF (our v3) | 25.8 | 28.9 (+12%) |
| FP16 | 18.5 TF (our WMMA) | 51.8 | 54.0 (+4%) |
2.9× over our WMMA. More importantly, narrowly over cuBLAS. The reason: autotune configs explore more points than a human. And TF32/FP16 selection happens automatically through tl.dot's dtype.
Evidence 3 · Flash Attention (N=8192)
Triton fp16 vs our CUDA FA fp326.14× faster
Triton fp32 vs SDPA fp322.35× faster (torch fp32 doesn't take the FA path on L4)
Triton fp16 vs SDPA fp16 (cuDNN FA-2)0.79× (we're 25% slower)
100 lines of Triton at 79% of cuDNN FA-2. This is why Tri Dao wrote FA-2 in Triton. Look at the number and Triton stops being an experiment.
One line = dozens of lines, repeated four times
tl.sum(x, axis=0) = 15 lines of warp shuffle boilerplate
tl.dot(a, b) = 50 lines of WMMA fragment + load_matrix_sync + mma_sync (with automatic TC selection by dtype)
- Grouped program-id swizzle in 9 lines — writing the same thing in CUDA is painful
- Online softmax 50 lines (Lesson 6) →
tl.max + tl.maximum + tl.exp + tl.sum, 15 lines
Where the cost is real
At n=2²⁰ (4 MB) in the reduction of Lesson 3, Triton was 3–12× slower than CUDA. One cause: launch floor 50–100 µs. Python interpreter → autotune cache lookup → JIT (once) → cuLaunchKernel. The floor's relative weight grows as the kernel shrinks.
If a Transformer layer is ≥1 ms, 100 µs is <10% — fine. But launch 30 element-wise ops individually in Triton and you're done. For small ops, PyTorch eager or torch.compile is better.
Two footguns
(1) TF32 benchmark lie
torch.matmul(fp32) doesn't use TF32 by default. tl.dot does. Compare them head-to-head and you see "Triton beats torch 2×." A fair comparison:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
With those two lines, both sides take the TF32 TC path. Before I set them, Phase 3's 2× lead was a lie.
(2) Autotune stale writes
@triton.autotune can leave a partial buffer polluted by other configs mid-trial. Reduction partials got mixed, rel_err blew up to 1.75. Fix:
@triton.autotune(configs=..., reset_to_zero=["partial_ptr"])
@triton.jit
def kernel(...): ...
best = kernel.best_config
block = best.kwargs["BLOCK_SIZE"]
return partial[:cdiv(n, block)].sum()
So why still CUDA
- Triton hits walls — Blackwell's BF8/FP4 mma, persistent kernel, async-copy fine-grained control.
- Reading PTX — chasing perf bugs means reading the PTX Triton emits.
*.ptx, *.cubin in TRITON_CACHE_DIR.
- Reference code — vLLM, FA-3, Mamba kernels are still CUDA. Reading them requires CUDA as your first language.
- Perf diagnosis vocabulary — "why is it slow?" answers are bank conflict, register spill, occupancy. Those concepts carry into Triton unchanged.
Analogy
CUDA = assembly, Triton = C. Most code in C, hot paths in assembly. And when the C compiler has a bug, you have to read the assembly.