LESSON 05 · 2026.04.18 · T4
Softmax & Fusion — Flash Attention 의 수학적 절반
커널 3 개를 1 개로 합치면 정확히 2 배 빨라진다. 그리고 그 옆에서
태어난 online softmax 는 Flash Attention 의 심장이 된다.
GPU · T4
peak HBM · 320 GB/s
sweep · 12 runs
세 버전
- v1 naive — 3 개 커널 순차 (max → sum → divide). HBM 4 trips/elem
- v2 fused — 1 커널, shared memory 에 행 캐싱. HBM 2 trips/elem
- v3 online — 1 커널 + normalize pass. 임의 N 지원. HBM 3 trips/elem
결과
| N | v1 ms | v1 GB/s | v2 ms | v2 GB/s | v3 ms | v3 GB/s |
| 1024 | 0.293 | 229 | 0.145 | 231 | 0.141 | 356 |
| 2048 | 0.530 | 253 | 0.285 | 236 | 0.279 | 361 |
| 4096 | 1.067 | 252 | 0.556 | 241 | 0.764 | 264 |
| 8192 | 2.212 | 243 | 1.470 | 183 ↓ | 1.669 | 241 |
교훈 1 · Fusion 의 약속은 관찰 가능
이론 HBM trips: v1 = 4, v2 = 2 → v2 가 2 배 빨라야 한다. 관찰: 2.02× (N=1024), 1.86× (N=2048), 1.92× (N=4096). 거의 이론치. LLM inference 최적화의 절반 이상이 fusion 인 이유가 이 숫자에 있다.
교훈 2 · Occupancy cliff @ N=8192
v2 smem usage = N × 4 bytes. N=8192 면 32 KB. T4 의 SM 당 64 KB smem 에서 블록이 2 개 만 상주 → threads/SM 이 1024 에서 512 로 떨어지며 occupancy 50%.
bandwidth 가 부족한 게 아니라 latency hiding 할 warp 가 부족한 것. 이게 shared memory 많이 쓰는 커널의 보편적 함정이다.
교훈 3 · L2 가 v3 의 "extra read" 를 흡수
이론상 v3 는 3 trips 로 v2 보다 1.5 배 느려야 한다. 그런데 N=1024 에서 v3 가 살짝 더 빠르다. 이유: 행 = 4 KB → L2 (4 MB) 에 여유 → pass 1 의 입력이 pass 2 에서 L2 hit. 유효 GB/s 356 (이론의 111%) 이 그 증거. N 이 커지면 L2 가 밀려나며 혜택 감소, v2 가 재우세.
교훈 4 · v3 의 진짜 가치 — online update 공식
v3 자체가 우리 크기에선 v2 보다 빠른 게 아니다. 하지만:
- 임의의 N 에 대응 가능 (v2 는 N > 12288 에서 실패)
- online update 공식이 Flash Attention 의 output 업데이트 공식과 같다
new_max = max(m1, m2)
new_sum = s1 * exp(m1 - new_max) + s2 * exp(m2 - new_max)
FA 는 여기에 tiled matmul fusion 을 얹어 중간 행렬 (P = softmax(Q@K^T)) 을 HBM 에 내리지 않는다. v3 을 구현 = FA 의 수학적 절반 이해. 나머지 절반은 attention-specific tiling — 다음 레슨.
Regime 지도
작은 N (smem 여유) v2 fusion 완승 (2×)
중간 N (L2 히트) v2 ≈ v3 L2 가 v3 의 3번째 read 흡수
큰 N (smem 포화) v3 / FA v2 occupancy 붕괴
매우 큰 N (attention) FA 중간 행렬 materialize 불가
LLM serving 번역
Decode phase: seq_len 짧음 → v2 같은 단순 fused softmax 로 충분. 병목은 KV cache HBM 로드.
Prefill phase: seq_len 수천~수만. Attention score 행렬이 거대 → Flash Attention 필수. 우리 v3 online 수식이 FA 의 뼈대.
LESSON 05 · 2026.04.18 · T4
Softmax & Fusion — the mathematical half of Flash Attention
Fuse three kernels into one and you get exactly 2× speedup. And online softmax, born next to it, becomes the heart of Flash Attention.
GPU · T4
peak HBM · 320 GB/s
sweep · 12 runs
Three versions
- v1 naive — 3 kernels sequential (max → sum → divide). HBM: 4 trips/elem
- v2 fused — 1 kernel, row cached in shared memory. HBM: 2 trips/elem
- v3 online — 1 kernel + normalize pass. Any N. HBM: 3 trips/elem
Results
| N | v1 ms | v1 GB/s | v2 ms | v2 GB/s | v3 ms | v3 GB/s |
| 1024 | 0.293 | 229 | 0.145 | 231 | 0.141 | 356 |
| 2048 | 0.530 | 253 | 0.285 | 236 | 0.279 | 361 |
| 4096 | 1.067 | 252 | 0.556 | 241 | 0.764 | 264 |
| 8192 | 2.212 | 243 | 1.470 | 183 ↓ | 1.669 | 241 |
Lesson 1 · The promise of fusion is observable
Theoretical HBM trips: v1 = 4, v2 = 2 → v2 should be 2× faster. Observed: 2.02× (N=1024), 1.86× (N=2048), 1.92× (N=4096). Very close to theory. This is why more than half of LLM-inference optimization is fusion.
Lesson 2 · Occupancy cliff @ N=8192
v2 smem usage = N × 4 bytes. At N=8192, that's 32 KB. Against T4's 64 KB smem per SM, only 2 blocks reside → threads/SM drops from 1024 to 512 → 50% occupancy.
It's not a bandwidth shortage — it's a shortage of warps to hide latency. This is the generic trap of shared-memory-heavy kernels.
Lesson 3 · L2 absorbs v3's "extra read"
In theory v3 should be 1.5× slower than v2 because of 3 trips. Yet at N=1024, v3 is actually slightly faster. Reason: a row is 4 KB → fits comfortably in L2 (4 MB) → pass 1's input hits L2 in pass 2. Effective 356 GB/s (111% of theoretical) is the evidence. As N grows, L2 gets evicted and the benefit fades, v2 retakes the lead.
Lesson 4 · v3's real value — the online update formula
v3 isn't faster than v2 at our sizes. But:
- Supports arbitrary N (v2 fails for N > 12288)
- Its online update formula is the output-update formula of Flash Attention
new_max = max(m1, m2)
new_sum = s1 * exp(m1 - new_max) + s2 * exp(m2 - new_max)
FA layers tiled matmul fusion on top of this so the intermediate matrix (P = softmax(Q@K^T)) never lands in HBM. Implementing v3 = understanding the mathematical half of FA. The other half is attention-specific tiling — next lesson.
Regime map
small N (smem slack) v2 fusion clean win (2×)
mid N (L2 hits) v2 ≈ v3 L2 absorbs v3's 3rd read
large N (smem saturated) v3 / FA v2's occupancy collapses
very large N (attention) FA intermediate matrix cannot materialize
LLM-serving translation
Decode phase: short seq_len → a simple fused softmax like v2 suffices. The bottleneck is loading KV cache from HBM.
Prefill phase: seq_len in the thousands to tens of thousands. The attention score matrix is huge → Flash Attention is mandatory. Our v3 online formula is the skeleton of FA.