cudatraining · 학습 기록

LESSON 05 · 2026.04.18 · T4

Softmax & Fusion — Flash Attention 의 수학적 절반

커널 3 개를 1 개로 합치면 정확히 2 배 빨라진다. 그리고 그 옆에서 태어난 online softmax 는 Flash Attention 의 심장이 된다.

GPU · T4 peak HBM · 320 GB/s sweep · 12 runs

세 버전

결과

Nv1 msv1 GB/sv2 msv2 GB/sv3 msv3 GB/s
10240.2932290.1452310.141356
20480.5302530.2852360.279361
40961.0672520.5562410.764264
81922.2122431.470183 ↓1.669241

교훈 1 · Fusion 의 약속은 관찰 가능

이론 HBM trips: v1 = 4, v2 = 2 → v2 가 2 배 빨라야 한다. 관찰: 2.02× (N=1024), 1.86× (N=2048), 1.92× (N=4096). 거의 이론치. LLM inference 최적화의 절반 이상이 fusion 인 이유가 이 숫자에 있다.

교훈 2 · Occupancy cliff @ N=8192

v2 smem usage = N × 4 bytes. N=8192 면 32 KB. T4 의 SM 당 64 KB smem 에서 블록이 2 개 만 상주 → threads/SM 이 1024 에서 512 로 떨어지며 occupancy 50%.

bandwidth 가 부족한 게 아니라 latency hiding 할 warp 가 부족한 것. 이게 shared memory 많이 쓰는 커널의 보편적 함정이다.

교훈 3 · L2 가 v3 의 "extra read" 를 흡수

이론상 v3 는 3 trips 로 v2 보다 1.5 배 느려야 한다. 그런데 N=1024 에서 v3 가 살짝 더 빠르다. 이유: 행 = 4 KB → L2 (4 MB) 에 여유 → pass 1 의 입력이 pass 2 에서 L2 hit. 유효 GB/s 356 (이론의 111%) 이 그 증거. N 이 커지면 L2 가 밀려나며 혜택 감소, v2 가 재우세.

교훈 4 · v3 의 진짜 가치 — online update 공식

v3 자체가 우리 크기에선 v2 보다 빠른 게 아니다. 하지만:

new_max = max(m1, m2)
new_sum = s1 * exp(m1 - new_max) + s2 * exp(m2 - new_max)

FA 는 여기에 tiled matmul fusion 을 얹어 중간 행렬 (P = softmax(Q@K^T)) 을 HBM 에 내리지 않는다. v3 을 구현 = FA 의 수학적 절반 이해. 나머지 절반은 attention-specific tiling — 다음 레슨.

Regime 지도

작은 N (smem 여유)     v2       fusion 완승 (2×)
중간 N (L2 히트)       v2 ≈ v3  L2 가 v3 의 3번째 read 흡수
큰 N (smem 포화)       v3 / FA  v2 occupancy 붕괴
매우 큰 N (attention)  FA       중간 행렬 materialize 불가
LLM serving 번역

Decode phase: seq_len 짧음 → v2 같은 단순 fused softmax 로 충분. 병목은 KV cache HBM 로드.
Prefill phase: seq_len 수천~수만. Attention score 행렬이 거대 → Flash Attention 필수. 우리 v3 online 수식이 FA 의 뼈대.