LESSON 09 · 2026.04.20 · L4
MHA + Causal FA — 300 줄로 cuDNN 의 80%
2-D Triton FA 를 (B, H, N, d) + causal mask 로 확장.
torch.library.custom_op 로 등록, torch.compile(fullgraph=True)
로 LLaMA 스타일 attention block 이 한 그래프에 들어간다.
GPU · L4 · fp16
kernel · 192 줄 (docstring 포함)
op wrapper · 76 줄
바뀐 다섯 가지
| 변경 | 위치 | 핵심 |
| 2-D → 4-D stride | kernel 시그니처 | (B,H,N,d) 의 4 stride × 3 tensor 전부 런치에서 전달 |
| 3-D grid | (cdiv(N, BM), H, B) | 한 program 이 (batch, head) Q 블록 하나 담당 |
| causal 특수화 | IS_CAUSAL: tl.constexpr | causal / non-causal 두 커널 컴파일, 런타임 브랜치 제거 |
| loop skip (FA-v2) | end_n = min(N, (pid_m+1)*BM) | 상삼각 타일 자체를 iteration 에서 배제 |
| Diagonal mask | offs_m[:,None] >= offs_n[None,:] | 대각 걸친 타일 한 개에만 실질 적용 |
정확도 (fp32 기준 vs max rel_err)
| 단계 | shape 수 | fp32 worst | fp16 worst |
| non-causal | 57 | 3.5e-3 | 3.4e-4 |
| causal (+ N=129/513 edge) | 60 | 1.1e-3 | 3.2e-4 |
속도 — LLaMA-7B causal, d=128, fp16
| (B, H, N) | ours | SDPA (FA-2) | ours/SDPA | vs naive |
| (1,32,512) | 0.100 ms · 21.5 TF | 0.100 | 1.00× | 10.97× |
| (1,32,1024) | 0.223 · 38.5 TF | 0.202 | 0.90× | 29.6× |
| (1,32,2048) | 0.784 · 43.8 TF | 0.613 | 0.78× | 31.4× |
| (1,32,4096) | 2.964 · 46.4 TF | 2.559 | 0.86× | 32.7× |
| (2,32,2048) | 1.565 · 43.9 TF | 1.372 | 0.88× | 31.3× |
GPT-2 (d=64) 짧은 shape 에선 (16,12,512) 에서 1.13× 우세, (8,12,1024) 에서 동률.
교훈 1 · Causal 의 속도는 mask 가 아니다
mask 만 씌우면 상삼각도 여전히 로드 + -inf 로 변환만. 진짜 속도는 loop skip:
if IS_CAUSAL:
end_n = tl.minimum(N, (pid_m + 1) * BLOCK_M)
else:
end_n = N
for start_n in range(0, end_n, BLOCK_N):
...
평균 N/2 로 떨어지며 K/V 로딩 + tl.dot 두 번이 반으로. 검증: non-causal (1,32,2048,128) 2.643 ms vs causal 0.784 ms — FLOP 절반인데 시간은 3.3× 빠름 (파이프라인 amortize).
교훈 2 · tl.constexpr = 런타임 if 접기
일반 인자로 is_causal 넘기면: 타이트 K 루프 안에 브랜치 + warp scheduler 묶임. tl.constexpr → Triton 이 True/False 두 특수화 커널 컴파일. 런타임은 하나만 dispatch, 브랜치 자체가 존재 안 함. autotune(key=[...,"IS_CAUSAL"]) 로 causal/non-causal autotune 도 독립 — causal 은 BM=64/BN=128, non-causal 은 더 큰 타일이 유리하다는 게 실측됨.
교훈 3 · custom_op 가 주는 세 가지
@custom_op("triton_training::flash_attention_mha",
mutates_args=(), device_types="cuda")
def flash_attention_mha(q, k, v, is_causal: bool = False) -> Tensor: ...
@flash_attention_mha.register_fake
def _(q, k, v, is_causal): return torch.empty_like(q)
torch.compile 이 그래프를 안 끊는다. register_fake 가 shape inference 에 쓰여 Dynamo 가 unknown op 로 보지 않음. fullgraph=True 에서 LLaMA AttentionBlock 전체가 한 그래프.
- 직렬화 노드로 남는다. ONNX/AOT/TorchScript 에서
triton_training::flash_attention_mha 이 그대로 기록.
torch.ops.<ns>.<op> path 드롭인. vLLM 같은 다운스트림이 Triton 임포트 없이 호출. 실제 vLLM 이 커스텀 커널 노출하는 패턴.
교훈 4 · 300 줄의 ROI
- kernel 100 줄 + op wrapper 76 줄 + autotune/register_fake ≈ 300 줄
- d=128 LLaMA-7B causal 에서 cuDNN FA-2 의 78–90%
- d=64 GPT-2 에선 동률 or 13% 우세
- naïve 대비 29–33×
torch.compile(fullgraph=True) 로 attention block 에 드롭인
같은 일을 CUDA 로 하면 레슨 6 의 FA v1 (~500 줄 + 150 줄 host) 대비 5–10 배 분량 + 새 GPU 마다 수작업 재튜닝. Triton 을 배우는 이유 자체.
마지막 20% 갭의 정체
FA-2 의 남은 22% 는 Triton 의 갭이 아니라 cuDNN FA-2 의:
- async copy + double/triple buffer (smem load ↔ 이전 타일 mma 오버랩)
- persistent kernel 스케줄링
- warp specialization (compute warp vs load-store warp)
Triton 에 persistent / warp-spec 이 오는 중이지만 튜토리얼 수준. 이 갭을 닫으려면 CUTLASS 3.x — 다음 라운드 주제.
포지션 요약
cuDNN 의 80% 속도를 300 줄로. 나머지 20% 는 한 단계 더 아래 레이어에서만 얻어진다.
LESSON 09 · 2026.04.20 · L4
MHA + Causal FA — 80% of cuDNN in 300 lines
Extend the 2-D Triton FA to (B, H, N, d) + a causal mask. Register with torch.library.custom_op, and with torch.compile(fullgraph=True) an entire LLaMA-style attention block lands in a single graph.
GPU · L4 · fp16
kernel · 192 lines (incl. docstring)
op wrapper · 76 lines
Five things that changed
| change | location | core idea |
| 2-D → 4-D stride | kernel signature | Pass all 4 strides × 3 tensors for (B,H,N,d) at launch |
| 3-D grid | (cdiv(N, BM), H, B) | One program handles a single (batch, head) Q block |
| causal specialization | IS_CAUSAL: tl.constexpr | Compile two kernels (causal / non-causal); remove the runtime branch |
| loop skip (FA-v2) | end_n = min(N, (pid_m+1)*BM) | Drop upper-triangle tiles from iteration entirely |
| Diagonal mask | offs_m[:,None] >= offs_n[None,:] | Only the tile that straddles the diagonal actually applies it |
Accuracy (max rel_err vs fp32 reference)
| stage | # shapes | fp32 worst | fp16 worst |
| non-causal | 57 | 3.5e-3 | 3.4e-4 |
| causal (+ N=129/513 edge) | 60 | 1.1e-3 | 3.2e-4 |
Speed — LLaMA-7B causal, d=128, fp16
| (B, H, N) | ours | SDPA (FA-2) | ours/SDPA | vs naive |
| (1,32,512) | 0.100 ms · 21.5 TF | 0.100 | 1.00× | 10.97× |
| (1,32,1024) | 0.223 · 38.5 TF | 0.202 | 0.90× | 29.6× |
| (1,32,2048) | 0.784 · 43.8 TF | 0.613 | 0.78× | 31.4× |
| (1,32,4096) | 2.964 · 46.4 TF | 2.559 | 0.86× | 32.7× |
| (2,32,2048) | 1.565 · 43.9 TF | 1.372 | 0.88× | 31.3× |
On GPT-2 (d=64) short shapes: 1.13× ahead at (16,12,512), tied at (8,12,1024).
Lesson 1 · Causal speed isn't from the mask
Just masking still loads the upper triangle and flips it to -inf. The real speed comes from loop skip:
if IS_CAUSAL:
end_n = tl.minimum(N, (pid_m + 1) * BLOCK_M)
else:
end_n = N
for start_n in range(0, end_n, BLOCK_N):
...
On average it drops to N/2, so both the K/V loads and the two tl.dots are halved. Verified: non-causal (1,32,2048,128) 2.643 ms vs causal 0.784 ms — half the FLOPs, but 3.3× faster (pipeline amortization).
Lesson 2 · tl.constexpr = fold away the runtime if
Pass is_causal as a plain arg and you get a branch inside the tight K loop plus warp-scheduler tangles. With tl.constexpr, Triton compiles two specialized kernels (True/False). At runtime, only one is dispatched — the branch doesn't even exist. Using autotune(key=[...,"IS_CAUSAL"]) also separates the autotune space: empirically, causal prefers BM=64 / BN=128 while non-causal likes bigger tiles.
Lesson 3 · Three gifts from custom_op
@custom_op("triton_training::flash_attention_mha",
mutates_args=(), device_types="cuda")
def flash_attention_mha(q, k, v, is_causal: bool = False) -> Tensor: ...
@flash_attention_mha.register_fake
def _(q, k, v, is_causal): return torch.empty_like(q)
torch.compile doesn't break the graph. register_fake supplies shape inference so Dynamo doesn't see an unknown op. Under fullgraph=True, the entire LLaMA AttentionBlock lives in one graph.
- Serialization-preserving. ONNX/AOT/TorchScript record
triton_training::flash_attention_mha as-is.
- Drop-in
torch.ops.<ns>.<op> path. Downstreams like vLLM call without importing Triton. The same pattern vLLM uses to expose custom kernels.
Lesson 4 · The ROI of 300 lines
- kernel 100 lines + op wrapper 76 lines + autotune/register_fake ≈ 300 lines
- 78–90% of cuDNN FA-2 on d=128 LLaMA-7B causal
- Tied or 13% ahead on GPT-2 (d=64)
- 29–33× over naive
- Drops into an attention block under
torch.compile(fullgraph=True)
Doing the same in CUDA: 5–10× the code (Lesson 6's FA v1 was ~500 lines + 150 lines of host), plus manual re-tuning on every new GPU. This is the whole reason to learn Triton.
The remaining 20% gap
The 22% left to FA-2 isn't Triton's gap — it's what cuDNN FA-2 has:
- async copy + double/triple buffering (smem load ↔ prior tile mma overlap)
- persistent kernel scheduling
- warp specialization (compute warps vs load-store warps)
Triton is getting persistent / warp-spec but it's tutorial-stage. Closing this gap takes CUTLASS 3.x — next round.
Position summary
80% of cuDNN in 300 lines. The remaining 20% only comes from the layer below.