| Item | FLOP | HBM bytes |
|---|---|---|
| QKᵀ | 2N²d | 2Nd + N² |
| softmax | O(N²) | 3N² |
| P·V | 2N²d | N² + 2Nd |
| total | Θ(N²d) | Θ(N²) |
```
// 3 kernel launches
gemm(S = Q @ K.T * tau);          // N²d
softmax_row(P = exp(S - m) / d);  // N², row-wise
gemm(O = P @ V);                  // N²d
```
Writing the intermediate S/P to HBM and reading them back ⇒ a round trip of 3·N² bytes.
Fix → online softmax: keep a running (m, ℓ) → streaming fusion becomes possible. ↗ V07 §2
| N | S bytes (FP16) | FLOP / seq |
|---|---|---|
| 1K | 2 MB | 0.13 GF |
| 8K | 128 MB | 8.4 GF |
| 128K | 32 GB | 2.1 TF |
d=128, single head. At N=128K, HBM is the bottleneck and S cannot even be materialized.
```python
# pass 1: streaming (m, ℓ)
m, l = -inf, 0
for x in row:
    m_new = max(m, x)
    l = l * exp(m - m_new) + exp(x - m_new)
    m = m_new

# pass 2: output y
for i in range(len(row)):
    y[i] = exp(row[i] - m) / l
```
HBM traffic: 2N → saves 1N over the safe 3-pass version.
| Method | passes | HBM/row | fusable? |
|---|---|---|---|
| naive | 2 | 2N | exp overflow |
| safe 3-pass | 3 | 3N | no |
| online 2-pass | 2 | 2N | partial |
| fused 1-pass | 1 | 1N | yes (FA) |
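A runnable sanity check of the online 2-pass variant against a reference softmax (numpy; `online_softmax` is an illustrative name, not a library function):

```python
import numpy as np

def online_softmax(row):
    # pass 1: stream through the row keeping a running (m, l)
    m, l = -np.inf, 0.0
    for x in row:
        m_new = max(m, x)
        l = l * np.exp(m - m_new) + np.exp(x - m_new)
        m = m_new
    # pass 2: produce the output with the final (m, l)
    return np.exp(row - m) / l

row = np.array([3.0, -1.0, 2.5, 0.0, 4.2])
ref = np.exp(row - row.max()); ref /= ref.sum()
assert np.allclose(online_softmax(row), ref)
```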
log₂e ≈ 1.4427 is a compile-time constant; it can be pre-absorbed into τ.
For the FA backward pass, only the LSE is stored at the end of the forward: N·4 bytes instead of the full S (N²).
By the same shift-insertion argument, O is also combined with the same correction factor.
| Property | Holds? |
|---|---|
| commutative (a ⊕ b = b ⊕ a) | ✓ (max, +) |
| associative ((a ⊕ b) ⊕ c = a ⊕ (b ⊕ c)) | ✓ |
| identity (−∞, 0) | ✓ |
| tree reduction possible | ✓ |
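The monoid properties can be checked numerically — a sketch where `merge` is the ⊕ operator on (m, ℓ) pairs (names illustrative):

```python
import numpy as np

def merge(a, b):
    # (m, l) ⊕ (m', l'): rescale both partial sums to the joint max
    m = max(a[0], b[0])
    return (m, a[1] * np.exp(a[0] - m) + b[1] * np.exp(b[0] - m))

def state(xs):
    # softmax state of a chunk: (running max, rescaled exp-sum)
    return (xs.max(), np.exp(xs - xs.max()).sum())

x = np.random.default_rng(0).normal(size=12)
a, b, c = state(x[:4]), state(x[4:9]), state(x[9:])
left  = merge(merge(a, b), c)       # ((a ⊕ b) ⊕ c)
right = merge(a, merge(b, c))       # (a ⊕ (b ⊕ c))
ident = merge(left, (-np.inf, 0.0)) # ⊕ identity
assert np.allclose(left, right) and np.allclose(left, ident)
```

Associativity plus the identity element is exactly what makes tree reduction (and the split-K merge later) valid.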
d=64, M=192KB → d²/M ≈ 0.022 → ~45× HBM savings. Dao et al. 2022, Thm 2.
| Stored | naive | FA |
|---|---|---|
| S, P | N²·2B | 0 (SRAM) |
| O | N·d·2B | N·d·2B |
| LSE (bwd) | — | N·4B |
| total HBM | O(N²) | O(N) |
outer: for j in 0..N step Bc ← K/V tile
load K[j:j+Bc], V[j:j+Bc] to smem
inner: for i in 0..N step Br ← Q tile
load Q[i:i+Br], (m,ℓ,O)[i] from HBM
compute block, update (m,ℓ,O)
store back (m,ℓ,O)[i]
Q is reloaded on every outer iteration → this is inverted in FA2.
The update rule above = (m, ℓ, O) ⊕ (rowmax(Sij), rowsum(Pij), PijVj) — one merge per KV tile. ↗ V07 §3
| Metric | Value |
|---|---|
| FLOP | 4N²d (same as naive) |
| HBM bytes | Θ(N²d²/M) |
| SMEM | O(Brd + Bcd) |
| activation HBM | O(N) (LSE only) |
grid = (batch × head, 1)
1 CTA = 1 (batch, head)
Q tile loop (inner) inside each CTA
K/V tile loop (outer) ← not a parallel axis
Narrow parallelism axes → low SM utilization on short sequences.
| buffer | size |
|---|---|
| Q_tile | Br·d·2B |
| K_tile[stage] | Bc·d·2B |
| V_tile[stage] | Bc·d·2B |
| S_tile | Br·Bc·4B |
```python
# FA1: K/V outer, Q inner
for j in range(0, N, Bc):
    Kj = load_smem(K[j:j+Bc])   # cp.async
    Vj = load_smem(V[j:j+Bc])
    for i in range(0, N, Br):
        Qi = load_smem(Q[i:i+Br])
        mi = load(m[i:i+Br])    # from HBM
        li = load(l[i:i+Br])
        Oi = load(O[i:i+Br])
        Sij = (Qi @ Kj.T) * tau  # Br×Bc
        if causal: Sij += mask(i, j)
        m_new = maximum(mi, rowmax(Sij))
        Pij = exp(Sij - m_new[:, None])
        alpha = exp(mi - m_new)
        l_new = alpha*li + rowsum(Pij)
        O_new = (Oi * (alpha*li)[:, None] + Pij @ Vj) / l_new[:, None]
        store(m[i:i+Br] = m_new)
        store(l[i:i+Br] = l_new)
        store(O[i:i+Br] = O_new)
```
```python
# after loop: normalize
# FA1 already rescales every step → nothing extra to do
LSE[i] = m[i] + log(l[i])   # for bwd
# FA2 instead does a single O /= l here
```
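The tiled update can be verified end to end — a numpy reference model of the FA1 loop (not the kernel; `fa1_attn`/`naive_attn` are illustrative names):

```python
import numpy as np

def naive_attn(Q, K, V, tau):
    S = Q @ K.T * tau
    P = np.exp(S - S.max(-1, keepdims=True))
    return (P / P.sum(-1, keepdims=True)) @ V

def fa1_attn(Q, K, V, tau, Br=4, Bc=4):
    N, d = Q.shape
    m = np.full(N, -np.inf); l = np.zeros(N); O = np.zeros((N, d))
    for j in range(0, N, Bc):          # KV outer
        Kj, Vj = K[j:j+Bc], V[j:j+Bc]
        for i in range(0, N, Br):      # Q inner
            Sij = Q[i:i+Br] @ Kj.T * tau
            m_new = np.maximum(m[i:i+Br], Sij.max(-1))
            Pij = np.exp(Sij - m_new[:, None])
            alpha = np.exp(m[i:i+Br] - m_new)
            l_new = alpha * l[i:i+Br] + Pij.sum(-1)
            O[i:i+Br] = (O[i:i+Br] * (alpha * l[i:i+Br])[:, None]
                         + Pij @ Vj) / l_new[:, None]
            m[i:i+Br], l[i:i+Br] = m_new, l_new
    return O

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(16, 8)) for _ in range(3))
assert np.allclose(fa1_attn(Q, K, V, 0.5), naive_attn(Q, K, V, 0.5))
```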
Row-wise dot of O and dO → Di is precomputed first in a single pass.
| Method | Stored | bwd work |
|---|---|---|
| Naive | P (N²) | load only |
| FA | LSE (N) | recompute S, P |
```python
# Step 1: precompute D_i (all i)
for i:
    D[i] = sum(O[i] * dO[i])

# Step 2: KV-outer loop (accumulate dK, dV)
for j in range(0, N, Bc):
    dKj = 0; dVj = 0
    Kj = load(K[j:j+Bc])
    Vj = load(V[j:j+Bc])
    for i in range(0, N, Br):
        Qi = load(Q[i:i+Br])
        LSEi = load(LSE[i:i+Br])
        dOi = load(dO[i:i+Br])
        Di = load(D[i:i+Br])
        Sij = Qi @ Kj.T * tau
        Pij = exp(Sij - LSEi[:, None])   # recompute
        dPij = dOi @ Vj.T
        dSij = Pij * (dPij - Di[:, None]) * tau
        dVj += Pij.T @ dOi
        dKj += dSij.T @ Qi
        atomic_add(dQ[i:], dSij @ Kj)
    write(dK[j:] = dKj)
    write(dV[j:] = dVj)
```
| Method | Mem | bwd FLOP |
|---|---|---|
| Naive (save P) | O(N²) | 5N²d |
| FA recompute | O(N) | 7N²d |

~40% more FLOPs, but drastically less HBM traffic → FA wins on wall-clock time.
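Why Di = rowsum(O ⊙ dO) is the right precompute: it equals rowsum(P ⊙ dP), the term the softmax backward needs, so S never has to be stored. A numpy sanity check (illustrative shapes, τ=1):

```python
import numpy as np

rng = np.random.default_rng(1)
S = rng.normal(size=(6, 10))
V = rng.normal(size=(10, 4))
dO = rng.normal(size=(6, 4))

P = np.exp(S - S.max(-1, keepdims=True))
P /= P.sum(-1, keepdims=True)
O = P @ V
dP = dO @ V.T                      # dL/dP
D = (O * dO).sum(-1)               # what FA precomputes from O, dO
assert np.allclose(D, (P * dP).sum(-1))

# softmax backward: dS = P ⊙ (dP − D)
dS = P * (dP - D[:, None])
assert np.allclose(dS.sum(-1), 0)  # softmax-grad rows sum to 0
```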
FA1: for j(KV): for i(Q): ...
     Q reloaded every time
FA2: for i(Q): for j(KV): ...
     Q stays pinned in registers/smem
     (m,ℓ,O) also persist in registers
e.g. N=64K, batch=1, H=16, Br=64 → 1024×16 = 16K blocks → saturates 132 SMs.
```
FA1 (Q-split):
    warp0 ← Q[0..15], warp1 ← Q[16..31], ...
    each warp runs the full K loop
    O exchanged through smem (every KV tile)

FA2 (K-split):
    all warps share the same Q_i
    K/V split across warps → independent partial O
    one inter-warp O reduce via smem at loop end
```
Inter-warp communication count: K_steps → 1.
| Op | FA1 | FA2 |
|---|---|---|
| rescale | ~25% | ~10% |
| softmax exp | ~15% | ~15% |
| mma (GEMM) | ~60% | ~75% |
Dao 2023 Table 1 · A100 FP16 d=128.
| HW / len | FA1 | FA2 |
|---|---|---|
| A100 FP16 d=128 N=2K | 124 TF | 225 TF |
| H100 FP16 d=128 | — | 350 TF |
| speedup | 1.0× | 1.8× |
Author-reported · Dao 2023 Fig 5.
```
grid = (N/Br, batch × head)
1 CTA = 1 Q-tile × 1 (batch, head)
block threads = 128 or 256, warps = 4 or 8
each warp covers a slice of Kj·Vj (K-split)
```
| Location | Contents |
|---|---|
| register | Q_tile (Br×d) |
| register | (m, ℓ, O) per warp / thread |
| smem[stage] | K_tile, V_tile (Bc×d) |
| smem | inter-warp O reduce buffer |
```python
m = -inf   # (Br)
l = 0      # (Br)
O = 0      # (Br × d)
Qi = load_smem(Q[i_block])   # once
Qi = load_reg(Qi)
```
```python
# Q outer: Q_i is fixed for the entire CTA lifetime
for j in range(0, N, Bc):
    Kj = load_smem_async(K[j:j+Bc])   # cp.async prefetch
    Vj = load_smem_async(V[j:j+Bc])
    cp_async_wait()
    Sij = (Qi @ Kj.T) * tau
    if causal:
        if j > i_end:                 # block fully masked → skip
            break
        elif j + Bc > i_start:        # diagonal block
            Sij += causal_mask(i, j)
        # else: block fully below the diagonal, no mask
    m_new = maximum(m, rowmax(Sij))
    Pij = exp2f((Sij - m_new[:, None]) * log2e)
    alpha = exp2f((m - m_new) * log2e)
    l = alpha*l + rowsum(Pij)
    O = alpha[:, None]*O + Pij @ Vj   # no /l (deferred rescale)
    m = m_new

# end of KV loop: single rescale
O = O / l[:, None]
LSE = m + log(l)
write_hbm(O[i_block], LSE[i_block])
```
```python
# Each warp w holds a partial (m_w, l_w, O_w)
# after KV loop: reduce across warps
smem_write(m_w, l_w, O_w, warp_id=w)
__syncthreads()
if warp_id == 0:                 # warp 0 reduces
    m_all = max(m_w over all w)
    alpha_w = exp(m_w - m_all)   # per warp
    l_all = sum(alpha_w * l_w)
    O_all = sum(alpha_w * O_w) / l_all
    write_hbm(O_all)
```
| head d | Br | Bc | stages |
|---|---|---|---|
| 64 | 64 | 64 | 3 / 2 |
| 128 | 64 | 64 | 3 / 2 |
| 256 | 64 | 32 | 2 / 2 |
For A100 / H100. The smem limit is reached at d=256.
CTA threads = 384 (3 warpgroups, 12 warps)
WG0 (warp 0..3) = producer (TMA issue)
WG1 (warp 4..7) = consumer A (WGMMA)
WG2 (warp 8..11) = consumer B (WGMMA)
↑ pingpong
```python
# producer warp (WG0)
for s in range(stages):
    mbarrier.expect_tx(K_full[s], bytes)
    tma.load(Kj[s], K_full[s])
    mbarrier.expect_tx(V_full[s], bytes)
    tma.load(Vj[s], V_full[s])

# consumer warp (WG1/2)
for j_step:
    mbarrier.wait(K_full[s], parity)
    wgmma.mma_async(S += Qi @ Kj[s].T)
    wgmma.wait_group(0)
    mbarrier.arrive(K_empty[s])
    softmax_update(m, l, S)
    mbarrier.wait(V_full[s], parity)
    wgmma.mma_async(O += P @ Vj[s])
    wgmma.wait_group(1)
    mbarrier.arrive(V_empty[s])
```
| 연산 | A operand | B operand | Acc |
|---|---|---|---|
| QKT | Q (reg) | K (smem) | S (reg FP32) |
| PV | P (reg) | V (smem) | O (reg FP32) |
The B operand is read directly from smem (shape m64n*k16). ↗ V04 §8
| config | TFLOPS |
|---|---|
| H100 FP16 d=128 | 740 |
| H100 FP8 d=128 | 1,200 |
| H100 FP8 peak (WGMMA) | 1,979 |
Shah et al. 2024 · H100 SXM · author-reported.
```
time ►►►
WG1: |GEMM1|soft |GEMM2|GEMM1|soft |GEMM2|
WG2:       |GEMM1|soft |GEMM2|GEMM1|soft |
WG0: TMA load of K[j+1], V[j+1] runs continuously

GEMM1 = Q @ Kj.T, GEMM2 = P @ Vj, soft = online softmax update
```
Slot pipeline within a single WG:
```
slot 0 | GEMM1 | soft  | GEMM2 |
slot 1 |       | GEMM1 | soft  | GEMM2 |
slot 2 |       |       | GEMM1 | soft  | GEMM2 |
```
→ GEMM1 for S_{j+1} starts before the current iteration finishes.
```python
# WG1 and WG2 run symmetrically on different Q rows
# synchronization via named barriers
for j in range(0, N, Bc):
    wait_for(Kj[s])   # TMA done
    if wg == 1:
        barrier.wait(wg1_gemm_turn)
        wgmma.issue(S = Qi @ Kj.T)
        barrier.arrive(wg2_gemm_turn)
        softmax_update(...)
        barrier.wait(wg1_gemm_turn)
        wgmma.issue(O += P @ Vj)
        barrier.arrive(wg2_gemm_turn)
    else:  # wg == 2, interleaved order
        softmax_update(...)
        barrier.wait(wg2_gemm_turn)
        wgmma.issue(S2 = Qi2 @ Kj.T)
        ...
    arrive(Kj_empty[s])
```
| Technique | WGMMA util |
|---|---|
| FA2 style (1 WG) | ~55% |
| WS + TMA | ~75% |
| WS + Pingpong | ~90% |
Shah et al. 2024 · H100 FP16 d=128. Qualitative estimate.
| Aspect | v1 (2022) | v2 (2023) | v3 (2024) |
|---|---|---|---|
| core idea | tile + online softmax | + deferred rescale + parallelism over N | + TMA · WGMMA + WS · FP8 |
| loop order | KV outer, Q inner | Q outer, KV inner | Q outer + WS |
| rescale | O/ℓ every KV step | once at the end | once at the end |
| parallel axes | batch × head | + seq (N/Br) | + tile scheduler |
| warp split | Q-split | K-split | WS producer/consumer |
| async copy | cp.async (opt) | cp.async | TMA |
| MMA | mma.sync | mma.sync | WGMMA async |
| target SM | sm_80 | sm_80, sm_90 | sm_90 |
| FP8 support | — | — | E4M3 + incoherent |
| dQ atomic | yes (KV outer) | no (Q outer) | no |
| Hopper features | — | — | TMA, WGMMA, mbarrier, setmaxnreg |
| HW / config | FA1 | FA2 | FA3 |
|---|---|---|---|
| A100 FP16 d=128 N=2K | 124 TF | 225 TF | — |
| H100 FP16 d=128 | — | 350 TF | 740 TF |
| H100 FP8 d=128 | — | — | 1,200 TF |
| H100 peak FP16 WGMMA | — | — | 989 TF |
| H100 peak FP8 WGMMA | — | — | 1,979 TF |
Author-reported figures. Dao 2022/2023, Shah 2024.
| Op | FA1 | FA2 | FA3 |
|---|---|---|---|
| rescale | ~25% | ~10% | ~5% |
| exp (SFU) | ~15% | ~15% | ~12% |
| mma | ~60% | ~75% | ~83% |
All three versions store only the LSE (N) as the forward activation; none of them materializes S.
| Variant | FA1 | FA2 | FA3 |
|---|---|---|---|
| Causal | ✓ | ✓ | ✓ |
| GQA/MQA | — | ✓ | ✓ |
| Sliding window | — | ✓ | ✓ |
| ALiBi bias | — | ✓ | ✓ |
| Soft-cap (Gemma) | — | — | ✓ |
| varlen (packed) | — | ✓ | ✓ |
| FP8 | — | — | ✓ |
```
KV block j_start..j_end vs Q block i_start..i_end:
  case A: j_end ≤ i_start → full-keep, no mask needed
  case B: j_start > i_end → full-skip (all −∞)
  case C: overlap (diagonal) → per-element mask
```
```python
for j in range(0, N, Bc):
    j_end = j + Bc
    if j > i_end:            # case B
        continue             # full skip
    Sij = (Qi @ Kj.T) * tau
    if j_end > i_start:      # case C (diagonal)
        Sij = where(row >= col_in_global, Sij, -inf)
    # else: case A, no mask
    # continue FA update
```
| Method | Cost |
|---|---|
| additive (+ −∞) | 1 FP op |
| select (where) | predicate compare |
| precomputed mask in mem | smem/HBM load |

FA2 branches per block at compile time → the mask itself is never loaded.
| mask | Definition |
|---|---|
| causal | j ≤ i |
| sliding window | i−W ≤ j ≤ i |
| prefix-LM | full within the prefix, causal over the suffix |
| document-packed | only within the same doc_id |
| tree (spec decode) | based on the branch structure |
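The first two definitions translate directly into boolean predicates — a small numpy sketch (function names illustrative):

```python
import numpy as np

def causal_mask(N):
    i, j = np.indices((N, N))
    return j <= i                       # attend to past + self

def sliding_window_mask(N, W):
    i, j = np.indices((N, N))
    return (j <= i) & (j >= i - W)      # causal, last W+1 tokens only

N, W = 8, 3
c, s = causal_mask(N), sliding_window_mask(N, W)
assert (s <= c).all()                   # window mask ⊂ causal mask
assert s[5].sum() == W + 1              # row 5 sees tokens 2..5
assert c[0].sum() == 1                  # row 0 sees only itself
```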
time step 0: rank r holds (K_r, V_r)
             compute Q_r · K_r
     step 1: send K_r → r+1, recv K_{r−1}
             compute Q_r · K_{r−1}
     step 2: … repeat the rotation, P−1 rounds in total
     end:    Q_r has effectively seen the KV of every rank
```python
# per rank r
Qr = load_local()
Kr_cur, Vr_cur = Kr, Vr
(m, l, O) = (-inf, 0, 0)
for step in range(P):
    # async send/recv overlapped with compute
    isend(Kr_cur, Vr_cur, dst=r+1)
    irecv(K_next, V_next, src=r-1)
    # FA2-style inner update with K_cur, V_cur
    S = Qr @ Kr_cur.T * tau
    if causal: mask(S, step)
    (m, l, O) = online_update(m, l, O, S, Vr_cur)
    wait_comm()
    Kr_cur, Vr_cur = K_next, V_next
O = O / l
write(O_r)
```
compute: |FA(step 0)|FA(step 1)|FA(step 2)|
comm: |send/recv→|send/recv→|send/recv→|
↑ NCCL SendRecv async
d=128, N_p=8K, P=8 → compute time > comm time is the typical case.
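The rotation can be simulated on one process; each "rank" keeps its Q shard and merges one KV shard per step via the online update, and the result matches full attention (non-causal, illustrative names):

```python
import numpy as np

def online_update(m, l, O, S, V):
    m_new = np.maximum(m, S.max(-1))
    P = np.exp(S - m_new[:, None])
    alpha = np.exp(m - m_new)
    return m_new, alpha * l + P.sum(-1), alpha[:, None] * O + P @ V

rng = np.random.default_rng(0)
P_ranks, Np, d = 4, 8, 16
Q = rng.normal(size=(P_ranks, Np, d))
K = rng.normal(size=(P_ranks, Np, d))
V = rng.normal(size=(P_ranks, Np, d))

outs = []
for r in range(P_ranks):                 # each rank r
    m = np.full(Np, -np.inf); l = np.zeros(Np); O = np.zeros((Np, d))
    for step in range(P_ranks):          # KV shard rotating through
        src = (r - step) % P_ranks
        m, l, O = online_update(m, l, O, Q[r] @ K[src].T, V[src])
    outs.append(O / l[:, None])

# reference: full attention over the concatenated sequence
Kf, Vf = K.reshape(-1, d), V.reshape(-1, d)
Sf = Q.reshape(-1, d) @ Kf.T
Pf = np.exp(Sf - Sf.max(-1, keepdims=True)); Pf /= Pf.sum(-1, keepdims=True)
assert np.allclose(np.concatenate(outs), Pf @ Vf)
```

Order-independence of the monoid merge is what makes the rotation order irrelevant.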
| Method | Feature |
|---|---|
| Ring Attention (Liu '23) | Q fixed, KV rotates |
| Striped Attention | causal load balancing |
| Tree Attention (long-ctx) | tree-based reduce |
| Context Parallel | Megatron CP = ring variant |
| Name | Hkv | Notes |
|---|---|---|
| MHA | H | standard (GPT-2/3) |
| MQA | 1 | cache 1/H (PaLM) |
| GQA | g (e.g. 8) | in between (Llama-2 70B) |
| MLA | compressed | cache r/D (DeepSeek) ↗ V07 §15 |
Decode is memory-bound → MQA/GQA raise arithmetic intensity so the Tensor Cores can actually be used.
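Back-of-envelope arithmetic intensity for one decode step (an approximation counting only the KV-cache reads; `decode_ai` is an illustrative helper): sharing one KV head among G query heads multiplies FLOPs per byte by G.

```python
def decode_ai(N, d, H_q, H_kv, bytes_per_el=2):
    # FLOPs: qK^T and pV, 2·N·d each, per query head
    flops = H_q * 4 * N * d
    # bytes: K and V cache read once per kv head (FP16)
    bytes_moved = H_kv * 2 * N * d * bytes_per_el
    return flops / bytes_moved

mha = decode_ai(N=8192, d=128, H_q=32, H_kv=32)   # → 1.0 FLOP/B
gqa = decode_ai(N=8192, d=128, H_q=32, H_kv=8)    # → 4.0
mqa = decode_ai(N=8192, d=128, H_q=32, H_kv=1)    # → 32.0
assert (mha, gqa, mqa) == (1.0, 4.0, 32.0)
```

All of these are far below the FLOP/byte ratio of an A100/H100, which is why decode stays memory-bound.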
| Model | H | Hkv | Scheme |
|---|---|---|---|
| Llama-2 7B | 32 | 32 | MHA |
| Llama-2 70B | 64 | 8 | GQA |
| Llama-3 8B | 32 | 8 | GQA |
| Mistral 7B | 32 | 8 | GQA |
| PaLM | — | 1 | MQA |
```
MHA kernel layout:
    grid = (N/Br, batch × H)
    1 CTA = 1 Q-head, Br rows

GQA kernel layout:
    grid = (N/Br, batch × H_kv)
    1 CTA = a group of (H/g) Q-heads
    Q_tile = Br × (H/g) × d  ← registers
    K/V tile = Bc × d           (shared)
```
Hkv·d per page in the block table ⇒ GQA shrinks page size, so the same memory serves more requests. ↗ V07 §16
Compress to c (rank r ≪ d); the cache stores only c.
At attention time, the up-projections W_UK, W_UV are applied to reconstruct K and V.
| Method | per-token KV bytes |
|---|---|
| MHA (FP16) | 2·H·d·2 |
| GQA g=8 | 2·8·d·2 |
| MLA r=512 | 512·2 + RoPE |

DeepSeek-V2: r=512, H·d=16384 → ~32× reduction.
```python
# decode: 1 query token
q_nope = W_Q @ x                 # (H, d)
q_rope = apply_rope(W_QR @ x)
q_abs  = q_nope @ W_UK           # absorb trick

# attend to cached c, K_rope
for page in block_table:
    c_page      = load(c[page])        # (Bp, r)
    K_rope_page = load(K_rope[page])
    S_nope = q_abs @ c_page.T * tau
    S_rope = q_rope @ K_rope_page.T
    S = S_nope + S_rope
    # online softmax
    O += softmax_update(...)

# V is also reconstructed from c
O = O @ W_UV   # or absorbed
```
| Pros | Cons |
|---|---|
| 32× cache reduction | complex implementation (RoPE split) |
| favorable for long context | needs the absorb trick |
| quality on par with MHA | must be trained as MLA from the start |
┌─────────────────────────────────────────────────┐
│ physical block_pool (device HBM) │
│ ┌────┬────┬────┬────┬────┬────┬────┬────┐ │
│ │ p0 │ p1 │ p2 │ p3 │ p4 │ p5 │ p6 │ p7 │ ... │
│ └────┴────┴────┴────┴────┴────┴────┴────┘ │
│        each block = Bp×Hkv×d×2B (K,V)           │
└─────────────────────────────────────────────────┘
▲ ▲ ▲
│ │ │ (referenced by table)
┌─────────┴───────┴───────┴────────────┐
│ block_table[req_id] │
│ req A: [p3, p1, p5, ...] │
│ req B: [p3, p1, p2, p7, ...] │
│ req C: [p6, p4, p0, ...] │
│ ↑ ↑ │
│ └───┴─ shared prefix pages │
└──────────────────────────────────────┘
logical token t → page:
page_idx = block_table[req][t // Bp]
offset = t % Bp
K_ptr = block_pool[page_idx][offset]
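The translation above in runnable form (toy data structures, all names illustrative):

```python
Bp = 4                                    # tokens per page
block_table = {"reqA": [3, 1, 5]}         # logical page → physical page
block_pool = [[None] * Bp for _ in range(8)]
block_pool[1][2] = "K/V of reqA token 6"  # fill one physical slot

def kv_slot(req, t):
    # logical token t → (physical page, offset) → slot
    page_idx = block_table[req][t // Bp]
    offset = t % Bp
    return block_pool[page_idx][offset]

assert kv_slot("reqA", 6) == "K/V of reqA token 6"   # t=6 → page 1, offset 2
```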
| field | type | meaning |
|---|---|---|
| block_pool | Tensor [P, Bp, Hkv, d, 2] | pre-allocated |
| block_table | int32 [R, Lmax/Bp] | req → page list |
| seq_len | int32 [R] | current length |
| ref_count | int32 [P] | page refcount (CoW) |
| free_list | queue | available page indices |
| Method | internal frag | external frag |
|---|---|---|
| contiguous max_len | 60~80% | OOM as R grows |
| contiguous actual | ~0% | alloc/free on every request |
| paged Bp=16 | <Bp·bytes | 0 (page pool) |

vLLM paper: memory utilization >96%.
| Bp | Pros | Cons |
|---|---|---|
| 16 | frag ↓, fine-grained sharing | page table ↑ |
| 32 | middle ground | middle ground |
| 128 | HBM throughput ↑ | frag ↑, coarse sharing |
alloc_page(req):
if free_list empty:
evict() or preempt()
p = free_list.pop()
ref_count[p] = 1
return p
free_req(req):
for p in block_table[req]:
ref_count[p] -= 1
if ref_count[p] == 0:
free_list.push(p)
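The alloc/free logic above with CoW refcounts as a small runnable class (eviction/preemption stubbed out; all names illustrative):

```python
from collections import deque

class PagePool:
    def __init__(self, num_pages):
        self.free_list = deque(range(num_pages))
        self.ref_count = [0] * num_pages

    def alloc_page(self):
        if not self.free_list:
            raise MemoryError("evict() or preempt() would go here")
        p = self.free_list.popleft()
        self.ref_count[p] = 1
        return p

    def share_page(self, p):             # CoW: a fork adds a reference
        self.ref_count[p] += 1

    def release_page(self, p):
        self.ref_count[p] -= 1
        if self.ref_count[p] == 0:       # last reference → back to pool
            self.free_list.append(p)

pool = PagePool(4)
p = pool.alloc_page()
pool.share_page(p)                       # two requests share a prefix page
pool.release_page(p)                     # first request done: page stays
assert p not in pool.free_list
pool.release_page(p)                     # last reference gone: page freed
assert p in pool.free_list
```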
```
grid = (R_active, H_q)
1 CTA = 1 (request, q_head)
each CTA loops over every page of its own request
```
```python
# 1 thread block = 1 query token, 1 head
q = load(Q[req, head])          # (d,)
(m, l, O) = (-inf, 0, 0)
for t_block in range(0, seq_len, Bp):
    page = block_table[req][t_block // Bp]
    K_p = load(block_pool[page][.., K])
    V_p = load(block_pool[page][.., V])
    s = q @ K_p.T * tau          # (Bp,)
    if causal: mask_suffix(s, t_block)
    (m, l, O) = online_update(m, l, O, s, V_p)
O /= l
write(O_out[req, head], O)
```
stage 1: grid = (R × H × P)
CTA p가 partition p의 page만 loop
→ partial (m_p, l_p, O_p)
stage 2: grid = (R × H)
online softmax merge P partials:
M = max_p m_p
α_p = exp(m_p − M)
L = Σ_p α_p · l_p
O_out = Σ_p α_p · O_p / L
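Stage 2's merge of the P partials is exactly the softmax-state combine; a numpy sketch checking it against a single pass (illustrative names):

```python
import numpy as np

rng = np.random.default_rng(2)
d, n_parts = 8, 4
q = rng.normal(size=d)
Ks = [rng.normal(size=(16, d)) for _ in range(n_parts)]
Vs = [rng.normal(size=(16, d)) for _ in range(n_parts)]

# stage 1: per-partition partials (m_p, l_p, O_p), O_p unnormalized
parts = []
for Kp, Vp in zip(Ks, Vs):
    s = Kp @ q
    m_p = s.max()
    e = np.exp(s - m_p)
    parts.append((m_p, e.sum(), e @ Vp))

# stage 2: online-softmax merge of the partials
M = max(m for m, _, _ in parts)
L = sum(np.exp(m - M) * l for m, l, _ in parts)
O = sum(np.exp(m - M) * Op for m, _, Op in parts) / L

# reference: single pass over the full sequence
s = np.concatenate(Ks) @ q
p = np.exp(s - s.max()); p /= p.sum()
assert np.allclose(O, p @ np.concatenate(Vs))
```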
```python
# bulk load, one page at a time
for stage in stages:
    # TMA 2D: (Bp, H_kv, d)
    tma_load(smem[stage], block_pool[page_idx[stage]])
    mbarrier.expect_tx(...)

# compute consumer
for stage in stages:
    mbarrier.wait(stage)
    compute(smem[stage])
```
page_idx is loaded from the block_table once, at CTA start.
| Metric | Value |
|---|---|
| memory utilization | 96%+ |
| decode latency | ~1.1× vs FA |
| throughput (R↑) | 2~4× vs FA |
| page indirection overhead | ~5% |
Kwon et al. 2023 · vLLM paper Fig.
root
├─ "You are"───┬─ " a helpful..." [req A,B,C] ref=3
│ └─ " an expert..." [req D] ref=1
│
└─ "###user\n" ├─ "What is GPU..." [req E] ref=1
└─ "How to..." [req F,G] ref=2
each node = (token seq, KV page list, ref_count, LRU timestamp)
each edge = a shared token prefix
| op | behavior |
|---|---|
| match_prefix(tokens) | → (node, matched_len) |
| insert(tokens[matched:]) | add a new child |
| lock_ref(node) | ref_count ++ |
| unlock_ref(node) | ref_count −− |
| evict() | evict an LRU leaf with ref=0 |
match_prefix(tokens):
node = root
matched = 0
while matched < len(tokens):
child = node.find_child(tokens[matched])
if child is None: break
k = common_prefix_len(
child.tokens,
tokens[matched:])
matched += k
if k < len(child.tokens):
break # partial match in edge
node = child
return (node, matched)
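The pseudocode above in runnable form, with edges holding multi-token sequences (no edge splitting on a partial match, matching the early `break`; all names illustrative):

```python
class Node:
    def __init__(self, tokens=()):
        self.tokens = tuple(tokens)   # tokens on the edge into this node
        self.children = {}            # first token → child node

    def find_child(self, tok):
        return self.children.get(tok)

def common_prefix_len(a, b):
    n = 0
    for x, y in zip(a, b):
        if x != y: break
        n += 1
    return n

def match_prefix(root, tokens):
    node, matched = root, 0
    while matched < len(tokens):
        child = node.find_child(tokens[matched])
        if child is None: break
        k = common_prefix_len(child.tokens, tokens[matched:])
        matched += k
        if k < len(child.tokens):
            break                     # partial match inside an edge
        node = child
    return node, matched

root = Node()
sys_node = Node([10, 11, 12]); root.children[10] = sys_node
sys_node.children[20] = Node([20, 21])
_, m = match_prefix(root, [10, 11, 12, 20, 21, 99])
assert m == 5                         # shared prefix of 5 tokens
_, m = match_prefix(root, [10, 11, 99])
assert m == 2                         # partial match inside the first edge
```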
| Scenario | hit rate |
|---|---|
| independent questions (cold) | 0~10% |
| chat (shared system prompt) | 60~80% |
| agent (tool loop) | 70~90% |
| few-shot batch | 90%+ |

As reported by SGLang, 2024.
| Method | Pros | Cons |
|---|---|---|
| trie (radix) | partial prefix sharing | complex tree management |
| hash (vLLM prefix) | O(1) lookup | requires block alignment |
```
start
│
├─ prefill or training?
│    ├─ sm_90+? → FA3 (WGMMA, TMA, FP8)
│    ├─ sm_80?  → FA2
│    └─ older?  → FA1 (legacy)
│
├─ decode (1 query token)?
│    ├─ small R · long seq?  → PagedAttn v2 (split-K)
│    ├─ large R · short seq? → PagedAttn v1
│    └─ shared prefix?       → RadixAttn + Paged
│
├─ context > 100K · multi-GPU?
│    └─ Ring Attention (CP)
│
├─ model cache pressure?
│    ├─ can design from training time? → MLA
│    └─ otherwise → GQA g=8
│
└─ matching a speculative decoder? → tree mask FA3
```
| phase | bound | Optimization |
|---|---|---|
| prefill (long) | compute | FA3 FP8, WGMMA |
| decode | memory | MQA/GQA/MLA |
| long-ctx | HBM + comm | Ring, Paged |
| large batch | SM utilization | continuous batching |
| year | landmark |
|---|---|
| 2017 | Attention is all you need |
| 2018 | Online softmax (Milakov) |
| 2022 | FA v1, Ring Attn |
| 2023 | FA v2, GQA, PagedAttn |
| 2024 | FA v3, MLA, RadixAttn |