ESSAY 11 · 2026.04.20 · L4 · FINAL
300 줄짜리 Triton FA 가 cuDNN FA-2 의 80–90% 를 따라잡는 방법
LLaMA-7B shape 기준, L4 (sm_89), torch.compile(fullgraph=True).
세 가지 트릭의 합 — 4-D grid, IS_CAUSAL: tl.constexpr, torch.library.custom_op.
kernel · ~100 줄
op wrapper · 76 줄
결과 · FA-2 의 78–90% · 29–33× vs naïve
질문
레슨 8 의 Triton FA 는 2-D (N, d) non-causal 이었다. 실제 LLM 은 그 shape 가 아니다. 진짜 shape 는 (B, H, N, d) 4-D, 대부분 causal, 그리고 torch.compile 아래에서 그래프가 안 끊겨야 한다.
이 세 가지를 만족시키면서 100 줄 커널이 어디까지 밀리는가.
결과 먼저
| (B, H, N) | ours | SDPA (FA-2) | ours/SDPA | vs naïve |
| (1,32,512) d=128 | 0.100 ms · 21.5 TF | 0.100 | 1.00× | 10.97× |
| (1,32,1024) | 0.223 · 38.5 TF | 0.202 | 0.90× | 29.6× |
| (1,32,2048) | 0.784 · 43.8 TF | 0.613 | 0.78× | 31.4× |
| (1,32,4096) | 2.964 · 46.4 TF | 2.559 | 0.86× | 32.7× |
| (8,12,1024) d=64 | 0.302 · 42.7 TF | 0.303 | 1.00× | — |
| (16,12,512) d=64 | 0.249 · 25.9 TF | 0.282 | 1.13× | — |
d=128 LLaMA shape 에서 cuDNN FA-2 의 78–90%. d=64 GPT-2 shape 에선 동률 또는 13% 우세. naïve 대비 29–33×. kernel body + op wrapper = 300 줄 미만.
트릭 1 · 2-D → 4-D 는 stride 4 + grid 3
(batch, head) 짝마다 완전 독립적인 attention 문제 — 루프로 돌리지 않고 grid 축으로 표현한다.
grid = (triton.cdiv(N, BLOCK_M), H, B)
# kernel 내부
pid_m = tl.program_id(0) # BLOCK_M of queries
pid_h = tl.program_id(1) # which head
pid_b = tl.program_id(2) # which batch
q_base = Q_ptr + pid_b * stride_qb + pid_h * stride_qh
k_base = K_ptr + pid_b * stride_kb + pid_h * stride_kh
# 이 네 줄 아래는 레슨 8 과 동일
Triton 런치는 이 grid 전체를 "available 하면 동시 시작". L4 의 58 SM 이 최대한 채워진다. 루프로 풀었으면 SM 이 놀았을 것.
트릭 2 · Causal 의 속도는 mask 가 아니라 loop skip
mask 만 씌우면 상삼각 타일도 여전히 HBM 에서 로드 + -inf 로 치환. 진짜 속도는:
if IS_CAUSAL:
end_n = tl.minimum(N, (pid_m + 1) * BLOCK_M)
else:
end_n = N
for start_n in range(0, end_n, BLOCK_N):
...
평균 N/2 로 떨어지며 K/V load 와 2 번의 tl.dot 이 반으로.
검증. non-causal (1,32,2048,128) 이 2.643 ms. 같은 shape causal 이 0.784 ms. FLOP 은 절반인데 시간은 3.3× 빠름. 이유는 로딩 파이프라인이 타일 수에 비례해서 amortize 되기 때문. FA-v2 의 causal 최적화가 이 한 줄.
트릭 3 · tl.constexpr — 런타임 if 를 컴파일 타임으로 접기
일반 인자로 is_causal 을 넘기면 타이트한 K 루프 안에 브랜치 + warp scheduler 묶임. tl.constexpr 로 찍으면 Triton 이 True 와 False 두 개의 특수화 커널을 각각 컴파일. 런타임은 해당하는 것만 dispatch — 브랜치 자체가 존재 안 함.
추가 이득: @triton.autotune(key=[..., "IS_CAUSAL"]) 로 causal / non-causal 마다 autotune 도 독립적으로. 실측에서 causal 은 BM=64, BN=128, non-causal 은 더 큰 타일이 유리하다는 게 관찰됨.
트릭 4 · torch.library.custom_op — op 의 옷을 입히다
@custom_op("triton_training::flash_attention_mha",
mutates_args=(), device_types="cuda")
def flash_attention_mha(q, k, v, is_causal: bool = False) -> Tensor:
return triton_flash_attention_mha(q, k, v, is_causal=is_causal)
@flash_attention_mha.register_fake
def _(q, k, v, is_causal):
return torch.empty_like(q)
이 한 조각이 세 가지를 준다:
torch.compile 이 그래프를 안 끊는다. register_fake 가 shape inference 에 쓰여 Dynamo 가 unknown Python 함수로 보지 않음. fullgraph=True 로 LLaMA-스타일 AttentionBlock 이 통째로 한 그래프.
- 직렬화 노드로 남는다. ONNX/AOT/TorchScript 에서
triton_training::flash_attention_mha 이 그대로 기록.
torch.ops.<ns>.<op> 드롭인. 다운스트림은 Triton 임포트 없이 호출. vLLM 의 패턴 그대로.
bit-exact + graph break 0
[1] raw_wrapper vs torch.ops err = 0.00e+00 (causal True/False)
[2] eager vs compiled function err = 0.00e+00 — fullgraph=True 통과
[3] eager vs compiled block err = 0.00e+00 — AttentionBlock 전체가 1 그래프
[4] schema: triton_training::flash_attention_mha(
Tensor q, Tensor k, Tensor v, bool is_causal=False) -> Tensor
마지막 20% 는 Triton 의 갭이 아니다
FA-2 와의 78–90% 구간. 나머지 10–22% 는 cuDNN FA-2 가 쓰는:
- async copy + double/triple buffer — smem 로드와 이전 타일
mma 를 완전 오버랩
- persistent kernel — block 을 계속 살려두고 타일을 재분배
- warp specialization — 일부 warp 는 compute, 일부는 load-store 전담
Triton 에 persistent / warp specialization 이 오는 중이지만 아직 튜토리얼 레벨. 이 갭을 닫으려면 CUTLASS 3.x — 한 레이어 더 아래로 내려가야 한다. 다음 라운드의 주제.
함정 기록 (요약)
- L4 stockout — us-west4-a → us-east4-c zone rotation.
Python.h not found — libpython3.10-dev 설치.
- git 히스토리 없음 — 레슨 9 중간에 발견. 레슨 0 에서
git init + 첫 커밋이 make vector_add 보다 먼저다.
- naive OOM —
(B,H,N,N) score 텐서가 6 GB. naive_mem_bytes < 4 GB 가드 + "(skipped)" 로 벤치.
- "Not enough SMs" 경고 — L4 58 SM 이
max_autotune_gemm 요구치 미달. 기능 영향 없음.
포지션
cuDNN 의 80–90% 속도를 300 줄로. 같은 일을 CUDA 로 하면 레슨 6 의 FA v1 (~650 줄) 의 5–10 배 + 새 GPU 마다 수작업 재튜닝. Triton 을 배우는 이유 자체가 이 ROI 에 있다.
다음
Backward + autograd (register_autograd) → GQA (MQA/grouped) → persistent kernel + async copy 로 마지막 20% 닫기 → vLLM PagedAttention 포팅. Phase 2 의 재료.
ESSAY 11 · 2026.04.20 · L4 · FINAL
How a 300-line Triton FA closes in on 80–90% of cuDNN FA-2
LLaMA-7B shapes, L4 (sm_89), under torch.compile(fullgraph=True). Three tricks combined — 4-D grid, IS_CAUSAL: tl.constexpr, torch.library.custom_op.
kernel · ~100 lines
op wrapper · 76 lines
result · 78–90% of FA-2 · 29–33× vs naïve
The question
Lesson 8's Triton FA was 2-D (N, d) and non-causal. Real LLMs don't look like that. The real shapes are (B, H, N, d) 4-D, mostly causal, and the graph must not break under torch.compile.
Meeting all three, how far can a 100-line kernel push?
Results first
| (B, H, N) | ours | SDPA (FA-2) | ours/SDPA | vs naïve |
| (1,32,512) d=128 | 0.100 ms · 21.5 TF | 0.100 | 1.00× | 10.97× |
| (1,32,1024) | 0.223 · 38.5 TF | 0.202 | 0.90× | 29.6× |
| (1,32,2048) | 0.784 · 43.8 TF | 0.613 | 0.78× | 31.4× |
| (1,32,4096) | 2.964 · 46.4 TF | 2.559 | 0.86× | 32.7× |
| (8,12,1024) d=64 | 0.302 · 42.7 TF | 0.303 | 1.00× | — |
| (16,12,512) d=64 | 0.249 · 25.9 TF | 0.282 | 1.13× | — |
78–90% of cuDNN FA-2 on LLaMA-like d=128 shapes. Tied or 13% ahead on GPT-2 d=64. 29–33× over naïve. kernel body + op wrapper = under 300 lines.
Trick 1 · 2-D → 4-D = 4 strides + 3-D grid
Every (batch, head) pair is a fully independent attention problem — don't loop over them; express the fan-out on the grid axes.
grid = (triton.cdiv(N, BLOCK_M), H, B)
# inside the kernel
pid_m = tl.program_id(0) # BLOCK_M of queries
pid_h = tl.program_id(1) # which head
pid_b = tl.program_id(2) # which batch
q_base = Q_ptr + pid_b * stride_qb + pid_h * stride_qh
k_base = K_ptr + pid_b * stride_kb + pid_h * stride_kh
# everything below these four lines is identical to Lesson 8
Triton's launch starts the whole grid "as soon as resources allow." L4's 58 SMs fill up. With a loop, SMs would sit idle.
Trick 2 · Causal speed is loop skip, not the mask
Masking alone still loads the upper triangle from HBM and flips it to -inf. Real speed comes from:
if IS_CAUSAL:
end_n = tl.minimum(N, (pid_m + 1) * BLOCK_M)
else:
end_n = N
for start_n in range(0, end_n, BLOCK_N):
...
Average drops to N/2, so both the K/V loads and the two tl.dot calls are halved.
Verified. non-causal (1,32,2048,128) = 2.643 ms. Same shape causal = 0.784 ms. Half the FLOPs, 3.3× faster. The reason: the load pipeline amortizes proportionally to the number of tiles. FA-v2's causal optimization is this one line.
Trick 3 · tl.constexpr — fold a runtime if into compile time
Pass is_causal as a normal argument and you get a branch inside the tight K loop, tangled with the warp scheduler. Mark it tl.constexpr and Triton compiles two specialized kernels (True and False). At runtime, only one is dispatched — the branch simply isn't there.
Extra win: @triton.autotune(key=[..., "IS_CAUSAL"]) separates autotune per case. Empirically causal prefers BM=64, BN=128, non-causal prefers bigger tiles.
Trick 4 · torch.library.custom_op — dress it as an op
@custom_op("triton_training::flash_attention_mha",
mutates_args=(), device_types="cuda")
def flash_attention_mha(q, k, v, is_causal: bool = False) -> Tensor:
return triton_flash_attention_mha(q, k, v, is_causal=is_causal)
@flash_attention_mha.register_fake
def _(q, k, v, is_causal):
return torch.empty_like(q)
This one block gives you three things:
torch.compile doesn't break the graph. register_fake supplies shape inference so Dynamo doesn't see an unknown Python function. Under fullgraph=True, the whole LLaMA-style AttentionBlock stays in one graph.
- Serialization-preserving. ONNX/AOT/TorchScript record
triton_training::flash_attention_mha as-is.
- Drop-in
torch.ops.<ns>.<op>. Downstreams call without importing Triton. The same pattern vLLM uses.
bit-exact + zero graph breaks
[1] raw_wrapper vs torch.ops err = 0.00e+00 (causal True/False)
[2] eager vs compiled function err = 0.00e+00 — fullgraph=True passes
[3] eager vs compiled block err = 0.00e+00 — AttentionBlock is a single graph
[4] schema: triton_training::flash_attention_mha(
Tensor q, Tensor k, Tensor v, bool is_causal=False) -> Tensor
The last 20% is not Triton's gap
In the 78–90% of FA-2 range, the remaining 10–22% is what cuDNN FA-2 has:
- async copy + double/triple buffering — full overlap of smem loads with the prior tile's
mma
- persistent kernel — keep blocks alive and re-dispatch tiles
- warp specialization — some warps compute, others do load-store
Triton is getting persistent / warp specialization but they're tutorial-grade. Closing the gap means going one layer deeper — CUTLASS 3.x. Next round's topic.
Trap log (digest)
- L4 stockout — us-west4-a → us-east4-c zone rotation.
Python.h not found — install libpython3.10-dev.
- No git history — discovered mid-Lesson 9.
git init + first commit should come before make vector_add.
- naive OOM — the
(B,H,N,N) score tensor was 6 GB. Guard with naive_mem_bytes < 4 GB and tag "(skipped)" in the bench.
- "Not enough SMs" warning — L4's 58 SMs fall short of
max_autotune_gemm's requirement. No functional impact.
Position
80–90% of cuDNN in 300 lines. Doing the same in CUDA would be 5–10× Lesson 6's FA v1 (~650 lines) plus hand-tuning on every new GPU. This ROI is the whole reason to learn Triton.
What's next
Backward + autograd (register_autograd) → GQA (MQA/grouped) → closing the last 20% with persistent kernel + async copy → porting vLLM's PagedAttention. The material for Phase 2.