cudatraining · 학습 기록

ESSAY 11 · 2026.04.20 · L4 · FINAL

300 줄짜리 Triton FA 가 cuDNN FA-2 의 80–90% 를 따라잡는 방법

LLaMA-7B shape 기준, L4 (sm_89), torch.compile(fullgraph=True). 세 가지 트릭의 합 — 4-D grid, IS_CAUSAL: tl.constexpr, torch.library.custom_op.

kernel · ~100 줄 op wrapper · 76 줄 결과 · FA-2 의 78–90% · 29–33× vs naïve

질문

레슨 8 의 Triton FA 는 2-D (N, d) non-causal 이었다. 실제 LLM 은 그 shape 가 아니다. 진짜 shape 는 (B, H, N, d) 4-D, 대부분 causal, 그리고 torch.compile 아래에서 그래프가 안 끊겨야 한다.

이 세 가지를 만족시키면서 100 줄 커널이 어디까지 밀리는가.

결과 먼저

(B, H, N)oursSDPA (FA-2)ours/SDPAvs naïve
(1,32,512) d=1280.100 ms · 21.5 TF0.1001.00×10.97×
(1,32,1024)0.223 · 38.5 TF0.2020.90×29.6×
(1,32,2048)0.784 · 43.8 TF0.6130.78×31.4×
(1,32,4096)2.964 · 46.4 TF2.5590.86×32.7×
(8,12,1024) d=640.302 · 42.7 TF0.3031.00×
(16,12,512) d=640.249 · 25.9 TF0.2821.13×

d=128 LLaMA shape 에서 cuDNN FA-2 의 78–90%. d=64 GPT-2 shape 에선 동률 또는 13% 우세. naïve 대비 29–33×. kernel body + op wrapper = 300 줄 미만.

트릭 1 · 2-D → 4-D 는 stride 4 + grid 3

(batch, head) 짝마다 완전 독립적인 attention 문제 — 루프로 돌리지 않고 grid 축으로 표현한다.

grid = (triton.cdiv(N, BLOCK_M), H, B)

# kernel 내부
pid_m = tl.program_id(0)   # BLOCK_M of queries
pid_h = tl.program_id(1)   # which head
pid_b = tl.program_id(2)   # which batch

q_base = Q_ptr + pid_b * stride_qb + pid_h * stride_qh
k_base = K_ptr + pid_b * stride_kb + pid_h * stride_kh
# 이 네 줄 아래는 레슨 8 과 동일

Triton 런치는 이 grid 전체를 "available 하면 동시 시작". L4 의 58 SM 이 최대한 채워진다. 루프로 풀었으면 SM 이 놀았을 것.

트릭 2 · Causal 의 속도는 mask 가 아니라 loop skip

mask 만 씌우면 상삼각 타일도 여전히 HBM 에서 로드 + -inf 로 치환. 진짜 속도는:

if IS_CAUSAL:
    end_n = tl.minimum(N, (pid_m + 1) * BLOCK_M)
else:
    end_n = N
for start_n in range(0, end_n, BLOCK_N):
    ...

평균 N/2 로 떨어지며 K/V load 와 2 번의 tl.dot 이 반으로.

검증. non-causal (1,32,2048,128) 이 2.643 ms. 같은 shape causal 이 0.784 ms. FLOP 은 절반인데 시간은 3.3× 빠름. 이유는 로딩 파이프라인이 타일 수에 비례해서 amortize 되기 때문. FA-v2 의 causal 최적화가 이 한 줄.

트릭 3 · tl.constexpr — 런타임 if 를 컴파일 타임으로 접기

일반 인자로 is_causal 을 넘기면 타이트한 K 루프 안에 브랜치 + warp scheduler 묶임. tl.constexpr 로 찍으면 Triton 이 True 와 False 두 개의 특수화 커널을 각각 컴파일. 런타임은 해당하는 것만 dispatch — 브랜치 자체가 존재 안 함.

추가 이득: @triton.autotune(key=[..., "IS_CAUSAL"]) 로 causal / non-causal 마다 autotune 도 독립적으로. 실측에서 causal 은 BM=64, BN=128, non-causal 은 더 큰 타일이 유리하다는 게 관찰됨.

트릭 4 · torch.library.custom_op — op 의 옷을 입히다

@custom_op("triton_training::flash_attention_mha",
           mutates_args=(), device_types="cuda")
def flash_attention_mha(q, k, v, is_causal: bool = False) -> Tensor:
    return triton_flash_attention_mha(q, k, v, is_causal=is_causal)

@flash_attention_mha.register_fake
def _(q, k, v, is_causal):
    return torch.empty_like(q)

이 한 조각이 세 가지를 준다:

  1. torch.compile 이 그래프를 안 끊는다. register_fake 가 shape inference 에 쓰여 Dynamo 가 unknown Python 함수로 보지 않음. fullgraph=True 로 LLaMA-스타일 AttentionBlock 이 통째로 한 그래프.
  2. 직렬화 노드로 남는다. ONNX/AOT/TorchScript 에서 triton_training::flash_attention_mha 이 그대로 기록.
  3. torch.ops.<ns>.<op> 드롭인. 다운스트림은 Triton 임포트 없이 호출. vLLM 의 패턴 그대로.

bit-exact + graph break 0

[1] raw_wrapper vs torch.ops      err = 0.00e+00  (causal True/False)
[2] eager vs compiled function    err = 0.00e+00  — fullgraph=True 통과
[3] eager vs compiled block       err = 0.00e+00  — AttentionBlock 전체가 1 그래프
[4] schema: triton_training::flash_attention_mha(
        Tensor q, Tensor k, Tensor v, bool is_causal=False) -> Tensor

마지막 20% 는 Triton 의 갭이 아니다

FA-2 와의 78–90% 구간. 나머지 10–22% 는 cuDNN FA-2 가 쓰는:

Triton 에 persistent / warp specialization 이 오는 중이지만 아직 튜토리얼 레벨. 이 갭을 닫으려면 CUTLASS 3.x — 한 레이어 더 아래로 내려가야 한다. 다음 라운드의 주제.

함정 기록 (요약)

포지션

cuDNN 의 80–90% 속도를 300 줄로. 같은 일을 CUDA 로 하면 레슨 6 의 FA v1 (~650 줄) 의 5–10 배 + 새 GPU 마다 수작업 재튜닝. Triton 을 배우는 이유 자체가 이 ROI 에 있다.

다음

Backward + autograd (register_autograd) → GQA (MQA/grouped) → persistent kernel + async copy 로 마지막 20% 닫기 → vLLM PagedAttention 포팅. Phase 2 의 재료.