cudatraining · 학습 기록

LESSON 09 · 2026.04.20 · L4

MHA + Causal FA — 300 줄로 cuDNN 의 80%

2-D Triton FA 를 (B, H, N, d) + causal mask 로 확장. torch.library.custom_op 로 등록, torch.compile(fullgraph=True) 로 LLaMA 스타일 attention block 이 한 그래프에 들어간다.

GPU · L4 · fp16 kernel · 192 줄 (docstring 포함) op wrapper · 76 줄

바뀐 다섯 가지

변경위치핵심
2-D → 4-D stridekernel 시그니처(B,H,N,d) 의 4 stride × 3 tensor 전부 런치에서 전달
3-D grid(cdiv(N, BM), H, B)한 program 이 (batch, head) Q 블록 하나 담당
causal 특수화IS_CAUSAL: tl.constexprcausal / non-causal 두 커널 컴파일, 런타임 브랜치 제거
loop skip (FA-v2)end_n = min(N, (pid_m+1)*BM)상삼각 타일 자체를 iteration 에서 배제
Diagonal maskoffs_m[:,None] >= offs_n[None,:]대각 걸친 타일 한 개에만 실질 적용

정확도 (fp32 기준 vs max rel_err)

단계shape 수fp32 worstfp16 worst
non-causal573.5e-33.4e-4
causal (+ N=129/513 edge)601.1e-33.2e-4

속도 — LLaMA-7B causal, d=128, fp16

(B, H, N)oursSDPA (FA-2)ours/SDPAvs naive
(1,32,512)0.100 ms · 21.5 TF0.1001.00×10.97×
(1,32,1024)0.223 · 38.5 TF0.2020.90×29.6×
(1,32,2048)0.784 · 43.8 TF0.6130.78×31.4×
(1,32,4096)2.964 · 46.4 TF2.5590.86×32.7×
(2,32,2048)1.565 · 43.9 TF1.3720.88×31.3×

GPT-2 (d=64) 짧은 shape 에선 (16,12,512) 에서 1.13× 우세, (8,12,1024) 에서 동률.

교훈 1 · Causal 의 속도는 mask 가 아니다

mask 만 씌우면 상삼각도 여전히 로드 + -inf 로 변환만. 진짜 속도는 loop skip:

if IS_CAUSAL:
    end_n = tl.minimum(N, (pid_m + 1) * BLOCK_M)
else:
    end_n = N
for start_n in range(0, end_n, BLOCK_N):
    ...

평균 N/2 로 떨어지며 K/V 로딩 + tl.dot 두 번이 반으로. 검증: non-causal (1,32,2048,128) 2.643 ms vs causal 0.784 ms — FLOP 절반인데 시간은 3.3× 빠름 (파이프라인 amortize).

교훈 2 · tl.constexpr = 런타임 if 접기

일반 인자로 is_causal 넘기면: 타이트 K 루프 안에 브랜치 + warp scheduler 묶임. tl.constexpr → Triton 이 True/False 두 특수화 커널 컴파일. 런타임은 하나만 dispatch, 브랜치 자체가 존재 안 함. autotune(key=[...,"IS_CAUSAL"]) 로 causal/non-causal autotune 도 독립 — causal 은 BM=64/BN=128, non-causal 은 더 큰 타일이 유리하다는 게 실측됨.

교훈 3 · custom_op 가 주는 세 가지

@custom_op("triton_training::flash_attention_mha",
           mutates_args=(), device_types="cuda")
def flash_attention_mha(q, k, v, is_causal: bool = False) -> Tensor: ...

@flash_attention_mha.register_fake
def _(q, k, v, is_causal): return torch.empty_like(q)
  1. torch.compile 이 그래프를 안 끊는다. register_fake 가 shape inference 에 쓰여 Dynamo 가 unknown op 로 보지 않음. fullgraph=True 에서 LLaMA AttentionBlock 전체가 한 그래프.
  2. 직렬화 노드로 남는다. ONNX/AOT/TorchScript 에서 triton_training::flash_attention_mha 이 그대로 기록.
  3. torch.ops.<ns>.<op> path 드롭인. vLLM 같은 다운스트림이 Triton 임포트 없이 호출. 실제 vLLM 이 커스텀 커널 노출하는 패턴.

교훈 4 · 300 줄의 ROI

같은 일을 CUDA 로 하면 레슨 6 의 FA v1 (~500 줄 + 150 줄 host) 대비 5–10 배 분량 + 새 GPU 마다 수작업 재튜닝. Triton 을 배우는 이유 자체.

마지막 20% 갭의 정체

FA-2 의 남은 22% 는 Triton 의 갭이 아니라 cuDNN FA-2 의:

Triton 에 persistent / warp-spec 이 오는 중이지만 튜토리얼 수준. 이 갭을 닫으려면 CUTLASS 3.x — 다음 라운드 주제.

포지션 요약

cuDNN 의 80% 속도를 300 줄로. 나머지 20% 는 한 단계 더 아래 레이어에서만 얻어진다.