CUDA 18권 · T2 KERNEL 패턴 · CONTENT-FIRST · A4 LANDSCAPE · 20p

Attention Kernel 계보

Online Softmax · FlashAttention v1·v2·v3 · Ring · GQA/MQA · MLA · PagedAttention · RadixAttention
Volume V07 / 18
Tier T2 Kernel 패턴
선행 V01 · V03 · V04 · V06
용도 attention kernel reading · kernel 설계 인출

목차 (19 content pages)

1. Attention naive — QKᵀ→softmax→AV, O(N²)p.2
2. Online Softmax 유도 (3-pass → 1-pass)p.3
3. Online Softmax 결합 공식 증명p.4
4. FA v1 핵심 아이디어 — SRAM tilep.5
5. FA1 Forward 의사코드p.6
6. FA1 Backward — recomputep.7
7. FA2 개선점 — loop 교환·rescale 절감p.8
8. FA2 Forward 의사코드p.9
9. FA3 Hopper — WGMMA·TMA·FP8p.10
10. FA3 Pingpong warpgroup 스케줄p.11
11. FA v1/v2/v3 정량 비교표p.12
12. Causal Mask block skipp.13
13. Ring Attention (context parallel)p.14
14. GQA / MQA — KV head 공유p.15
15. MLA (DeepSeek-V2) — KV latent 압축p.16
16. PagedAttention 자료구조p.17
17. PagedAttention kernel (gather-scatter)p.18
18. RadixAttention (SGLang prefix trie)p.19
19. Cheat Sheet — 변종 선택 결정트리p.20

범례

핵심 용어 (노랑)
매우 중요·표 헤더
정의·수식 박스
예시
빨강주의·실수 지점
면접·kernel reading 핵심
(!)니모닉 (권당 ≤ 5)
m / ℓrow max · row sum (denominator)
Br · BcQ-row · KV-col block
↗ V##다른 권 cross-ref
Out of Scope vLLM scheduler 자체 (↗ V16) · speculative decoding 알고리즘 상세 (↗ V08)
Milakov 2018 · FA1/2/3 papers · Ring Attn · DeepSeek-V2 · PagedAttn · SGLang

1 Scaled Dot-Product Attention 표기

Q, K, V ∈ ℝN×d
S = QKT / √d ∈ ℝN×N
P = softmax(S, axis=1)
O = P·V ∈ ℝN×d τ = 1/√d head_dim normalization · row-wise softmax

2 3-단계 분해

  1. QKT: 2·N·N·d FLOP · write S (N×N)
  2. softmax: 3 pass over S row → max, exp+sum, div
  3. P·V: 2·N·N·d FLOP → O

3 시공간 복잡도

항목FLOPHBM bytes
QKT2N2d2Nd + N2
softmaxO(N2)3N2
P·V2N2dN2 + 2Nd
totalΘ(N2d)Θ(N2)

4 Arithmetic Intensity ★ memory bound 증명

AI = FLOP / bytes
= 4N2d / (4N2 + 4Nd)
≈ d   (N ≫ d) d=64 → AI≈64 · A100 balance point ≈ 100 (19.5 TF / 1.5 TB/s · 4B FP32) → memory bound

5 S 저장의 비용

  • S, P ∈ ℝN×NO(N2) HBM
  • N=8K FP16 → 128 MB ⇒ batch×head 곱해 GB 단위
  • 학습 backward 위해 보존 필요 → activation checkpointing 요구

6 Naive CUDA 커널 분할

// 3 kernel launches
gemm(S = Q @ K.T * tau);     // N2d
softmax_row(P = exp(S - m) / d);
gemm(O = P @ V);              // N2d

중간 S/P를 HBM에 write+read ⇒ round-trip 3·N2 bytes.

7 Kernel fusion이 어려운 이유

관찰 softmax는 row 전체 max / sum 의존 → reduction barrier. row 전체 S를 다 계산해야 P 구성. 단일 streaming pass로 불가 (겉보기엔).

해법 → online softmax: running (m, ℓ) 유지 → streaming fusion 가능. ↗ V07 §2

8 수치 안정성 이슈

  • naive softmax: exp(x) overflow (x > 88 in FP32)
  • 해법: x − max 먼저 빼기 (safe softmax)
  • FP16 exp는 [−10, 10] 밖 부정확 → FP32 accumulator

9 길이 스케일링

NS bytes (FP16)FLOP / seq
1K2 MB0.13 GF
8K128 MB8.4 GF
128K32 GB2.1 TF

d=128, single head. N=128K → HBM이 병목, S 저장 불가능.

핵심 통찰: attention은 O(N2d) compute이지만 HBM O(N2) material-ization이 메모리 bound 원인. S를 쓰지 않으면 compute bound로 전환 가능.

1 Step 0: Naive softmax

yi = exi / Σj exj
  • FP32 exp overflow: x > 88
  • FP16: x > 11 overflow
  • 실전에서 그대로 쓰면 NaN

2 Step 1: Safe softmax (3-pass)

m = maxj xj
ℓ = Σj exj−m
yi = exi−m / ℓ 3 pass: m → ℓ → y · row당 3N HBM 트래픽

3 Step 2: Running 갱신식 ★

m(k) = max(m(k−1), xk)
(k) = ℓ(k−1)·em(k−1)−m(k) + exk−m(k) 새 xk가 더 크면 m 갱신 → 기존 ℓ는 eΔm < 1로 rescale

4 Step 3: 2-pass online

# pass 1: streaming (m, ℓ)
m, l = -inf, 0
for x in row:
    m_new = max(m, x)
    l     = l * exp(m - m_new) + exp(x - m_new)
    m     = m_new

# pass 2: output y
for i: y[i] = exp(x[i] - m) / l

HBM 트래픽 2N → safe보다 1N 절감.

5 Step 4: 1-pass fused (attention용) ★

핵심 softmax 출력만 필요한 경우 (OK로 정규화 지연) → (m, ℓ, O) 를 함께 streaming. O = Σ pi·vi 누산 중 m 갱신마다 rescale.
O(k) = O(k−1)·em(k−1)−m(k) + exk−m(k)·vk

6 Pass 수 비교 표

방식passHBM/rowfusion 가능?
naive22Nexp overflow
safe 3-pass33Nno
online 2-pass22Npartial
fused 1-pass11Nyes (FA)

7 Associativity 결합법칙

정의 결합자 ⊕: (ma, ℓa) ⊕ (mb, ℓb) = (M, L) 이 교환·결합 성립. ⇒ tree reduction / warp shuffle OK. 증명은 ↗ V07 §3.

8 exp2f scale 트릭 SFU 가속

x' = (x − m)·log2e
P = exp2f(x') __expf 대비 SFU throughput ↑ · FA3 표준

log2e ≈ 1.4427 compile-time 상수. τ 에 사전 흡수 가능.

9 LSE (log-sum-exp) 보존

LSE = log Σ exj = m + log ℓ

FA backward 위해 forward 끝에 LSE만 저장: N · 4 bytes → S 전체 (N2) 대신.

초기값: m0 = −∞, ℓ0 = 0. exp(−∞ − m_new) = 0 (IEEE 정확). 첫 step에서 safe.
online softmax 3-tuple: m·ℓ·O (max · denom · weighted sum)

1 문제 설정 ★

정의 두 disjoint 부분 집합 Sa, Sb (S = Sa ∪ Sb) 에 대해 각각 (ma, ℓa)와 (mb, ℓb) 를 이미 계산. 전체 S의 (mab, ℓab) 를 재계산 없이 구하기.

2 Partial 정의

ma = maxi∈Sa xi
a = Σi∈Sa exi−ma
mb = maxi∈Sb xi
b = Σi∈Sb exi−mb

3 결합 공식 ★★

mab = max(ma, mb)
ab = ℓa·ema−mab + ℓb·emb−mab correction factor: αa = ema−mab ∈ (0, 1]

4 증명 단계 1 — mab

mab = maxi∈S xi
= max( maxi∈Sa xi, maxi∈Sb xi )
= max(ma, mb)   □ max의 분배성 (sup over union)

5 증명 단계 2 — ℓab 확장

ab = Σi∈S exi−mab
= Σi∈Sa exi−mab + Σi∈Sb exi−mab 합의 분해 (Sa ∩ Sb = ∅)

6 증명 단계 3 — shift 삽입

Σi∈Sa exi−mab
= Σi∈Sa exi−ma+ma−mab
= ema−mab·Σi∈Sa exi−ma
= ema−mab·ℓa 지수의 곱셈 분리 → 상수 factor 반출

7 증명 단계 4 — 종합 ★

ab = ema−mab·ℓa + emb−mab·ℓb
= αa·ℓa + αb·ℓb   □ αa, αb ∈ (0, 1], 둘 중 하나는 정확히 1 (m_ab 선택된 쪽)

8 Output O 결합 ★

Oab = αa·Oa + αb·Ob Oa = Σi∈Sa exi−ma·vi (unnormalized numerator)

같은 shift 삽입 argument로 O도 동일 correction factor로 결합.

9 결합·교환성

성질성립?
교환 (a ⊕ b = b ⊕ a)✓ (max, +)
결합 ((a ⊕ b) ⊕ c = a ⊕ (b ⊕ c))
항등원 (−∞, 0)
tree reduction 가능
FP 정확성: (a ⊕ b) ⊕ c ≠ a ⊕ (b ⊕ c) in FP (부동소수점은 결합 아님). 수학적으로는 성립 · 수치적으로 미세한 차이 발생. determinism 문제는 ↗ V09 §7.
merge 3-rule: max·rescale·add (M → αa, αb → Σ)

1 3대 아이디어 ★

  1. Tile into SRAM: Q·K·V 블록을 SMEM에 올려 S materialization 회피
  2. Online softmax: (m, ℓ, O) streaming → S 저장 불필요 ↗ V07 §2
  3. Recomputation: backward에서 LSE만 저장, S/P 재계산

2 Tile 정의

  • Q를 Br row 단위로 분할 (N/Br tile)
  • K, V를 Bc col 단위로 분할 (N/Bc tile)
  • 1 CTA = 1 Q-tile 담당 (FA1)
smem ≈ (Br·d + 2·Bc·d)·2B
예: 64·64·2 + 2·64·64·2 = 24 KB A100 smem 192KB 여유 · double buffering 가능

3 HBM 트래픽 비교 ★

Naive: Θ(N·d + N2) bytes
FA: Θ(N2·d2 / M) bytes M = on-chip SRAM · d2/M « 1 이면 FA 우세

d=64, M=192KB → d2/M ≈ 0.022 → 45× HBM 절감. Dao et al. 2022 Thm 2.

4 Memory O(N) 성립 ★

저장 대상naiveFA
S, PN2·2B0 (SRAM)
ON·d·2BN·d·2B
LSE (bwd)N·4B
total HBMO(N2)O(N)

5 Loop 구조 (FA1)

outer: for j in 0..N step Bc   ← K/V tile
  load K[j:j+Bc], V[j:j+Bc] to smem
  inner: for i in 0..N step Br ← Q tile
    load Q[i:i+Br], (m,ℓ,O)[i] from HBM
    compute block, update (m,ℓ,O)
    store back (m,ℓ,O)[i]

Q를 매 outer iter마다 reload → FA2에서 역전.

6 Block-wise softmax 원리

Sij = QiKjT·τ (Br×Bc)
mnew = max(m, rowmax(Sij))
α = em−mnew
ℓ ← α·ℓ + rowsum(eSij−mnew)
O ← α·O + eSij−mnew·Vj

7 결합 공식과의 관계

위 갱신식 = (m, ℓ, O) ⊕ (rowmax(Sij), rowsum(Pij), PijVj). 매 KV tile당 1회 merge. ↗ V07 §3

8 복잡도 요약

지표
FLOP4N2d (naive 동일)
HBM bytesΘ(N2d2/M)
SMEMO(Brd + Bcd)
activation HBMO(N) (LSE만)
compute 증가 없음: FLOP는 naive와 동일. 차이는 HBM 트래픽과 활성값 저장량. FA는 IO-aware 설계.

1 입력 · 출력

  • in: Q, K, V ∈ HBM  (N×d)
  • out: O ∈ HBM (N×d), LSE ∈ HBM (N)
  • τ = 1/√d

2 Grid / CTA 구성 (FA1)

grid = (batch × head, 1)
1 CTA = 1 (batch, head)
각 CTA 안에서 Q tile loop (inner)
                K/V tile loop (outer) ← 병렬축 아님

병렬축 좁음 → short-seq에서 SM 활용률 낮음.

3 SMEM 배치

buffersize
Q_tileBr·d·2B
K_tile[stage]Bc·d·2B
V_tile[stage]Bc·d·2B
S_tileBr·Bc·4B

4 Forward 의사코드 ★★

# FA1: K/V outer, Q inner
for j in range(0, N, Bc):
    Kj = load_smem(K[j:j+Bc])     # cp.async
    Vj = load_smem(V[j:j+Bc])
    for i in range(0, N, Br):
        Qi  = load_smem(Q[i:i+Br])
        mi  = load(m[i:i+Br])     # from HBM
        li  = load(l[i:i+Br])
        Oi  = load(O[i:i+Br])

        Sij = (Qi @ Kj.T) * tau             # Br×Bc
        if causal: Sij += mask(i, j)

        m_new = maximum(mi, rowmax(Sij))
        Pij   = exp(Sij - m_new[:, None])
        alpha = exp(mi - m_new)
        l_new = alpha*li + rowsum(Pij)
        O_new = (Oi * (alpha*li)[:,None]
                   + Pij @ Vj) / l_new[:,None]

        store(m[i:i+Br] = m_new)
        store(l[i:i+Br] = l_new)
        store(O[i:i+Br] = O_new)

5 Forward 수식 (per block update)

Sij = QiKjT·τ
mnew = max(mi, rowmax(Sij))
α = emi−mnew, β = e−mnew·rowmax shift
new = α·ℓi + rowsum(eSij−mnew)
Onew = (α·ℓi·Oi + eSij−mnewVj) / ℓnew

6 비효율 지점 (FA2가 고침)

  • 매 KV tile마다 O/ℓ rescale (divide) → non-mma FLOP ~25%
  • Q_tile을 매 outer iter reload
  • (m, ℓ, O) HBM round-trip 반복

7 End-state 저장

# after loop: normalize
# FA1 rescale 이미 매 step → 추가 없음
LSE[i] = m[i] + log(l[i])   # for bwd
# FA2는 여기서 한번에 O /= l
FP16 accumulator 금지: P·V 누산은 반드시 FP32. FP16으로 누적하면 NaN 빈발. ↗ V09 §7

1 Gradient 수식 ★

dV = PT·dO
dP = dO·VT
dSij = Pij·(dPij − Di)
dQ = dS·K·τ
dK = dST·Q·τ Di = Σk Pik·dPik (row-wise dot)

2 Di precompute 동치

Di = Σj Oij·dOij
= Σj Pij·dPij O = P·V 관계에서 미분 체인 적용하면 동일

O와 dO의 row-wise dot → Di 먼저 1-pass 계산.

3 저장 vs 재계산

방식저장bwd 작업
NaiveP (N2)load only
FALSE (N)S, P recompute

4 Backward 의사코드 ★

# Step 1: D_i precompute (all i)
for i: D[i] = sum(O[i] * dO[i])

# Step 2: KV-outer loop (accumulate dK, dV)
for j in range(0, N, Bc):
    dKj = 0; dVj = 0
    Kj  = load(K[j:j+Bc])
    Vj  = load(V[j:j+Bc])

    for i in range(0, N, Br):
        Qi    = load(Q[i:i+Br])
        LSEi  = load(LSE[i:i+Br])
        dOi   = load(dO[i:i+Br])
        Di    = load(D[i:i+Br])

        Sij   = Qi @ Kj.T * tau
        Pij   = exp(Sij - LSEi[:, None])    # recompute
        dPij  = dOi @ Vj.T
        dSij  = Pij * (dPij - Di[:, None]) * tau

        dVj  += Pij.T @ dOi
        dKj  += dSij.T @ Qi
        atomic_add(dQ[i:], dSij @ Kj)

    write(dK[j:] = dKj)
    write(dV[j:] = dVj)

5 dQ atomic 이슈 ★

  • 같은 i가 여러 j-block에서 갱신 → atomicAdd
  • 해법 A: KV outer, Q inner (현재 표준)
  • 해법 B: j축 reduction buffer in smem (FA2 선호)
  • 해법 C: split-K + 2-stage reduce

6 복잡도 비교

방식MemFLOP bwd
Naive (P save)O(N2)5N2d
FA recomputeO(N)7N2d

FLOP 40% ↑ but HBM 트래픽 대폭 ↓ → 실행시간 FA 빠름.

7 Causal in bwd

  • j > i 영역 P=0 → dV, dK 기여 0
  • upper-triangular block 전체 skip
  • FLOP ~ 1/2 절감
recompute 부작용: forward의 (m, ℓ) 즉 LSE가 bwd S 재계산으로 정확히 재현되어야 함. 비결정적 dropout 금지 (mask state 저장 필요).

1 3대 변경점 ★

  1. Q outer / KV inner: 병렬축이 seq에 직접
  2. Deferred rescale: 매 step O/ℓ 대신 최후 1회
  3. Warp 분배 K-split: warp간 통신 1회로 감소

2 Loop 교환 효과

FA1: for j(KV): for i(Q): ...
     Q를 매번 reload

FA2: for i(Q):   for j(KV): ...
     Q는 register/smem에 고정
     (m,ℓ,O)도 register 영속
  • Q 로드 횟수: N/Bc → 1
  • (m, ℓ, O) HBM round-trip 제거

3 병렬축 확장

grid = (N/Br, batch·head) N축이 추가되어 short-seq에서도 SM 충분 활용

예: N=64K, BH=1, H=16, Br=64 → 1024×16 = 16K blocks → 132 SM 포화.

4 Deferred rescale 수식 ★

loop 중: ℓ' ← α·ℓ + rowsum(P)
O' ← α·O + P·V  (/ℓ 없이)
---
loop 끝: O ← O / ℓ rescale 횟수: N/Bc → 1 · non-mma FLOP 25% → 10%

5 Warp 분할 비교 ★

FA1 (Q-split):
  warp0 ← Q[0..15]    warp1 ← Q[16..31] ...
  각 warp가 전체 K loop
  smem로 O 통신 (매 KV tile)

FA2 (K-split):
  모든 warp가 같은 Q_i
  K/V를 warp별 분할 → 독립 partial O
  loop 끝에 warp-간 O reduce (smem 1회)

warp 간 통신 횟수: K_steps → 1회.

6 non-mma FLOP 비중

연산FA1FA2
rescale~25%~10%
softmax exp~15%~15%
mma (GEMM)~60%~75%

Dao 2023 Table 1 · A100 FP16 d=128.

7 Shared memory 재구성

  • K/V를 pipeline stage로 (num_stages=2~3)
  • cp.async prefetch → GEMM과 overlap ↗ V03 §6
  • Q는 register-heavy 배치 가능

8 dQ atomic 제거

핵심 FA2 backward: Q가 outer → 같은 Q_i에 대해 KV loop 안에서 누적. atomicAdd 불필요 (thread 단위 register 누적).

9 성능 정량

HW / lenFA1FA2
A100 FP16 d=128 N=2K124 TF225 TF
H100 FP16 d=128350 TF
속도 ratio1.0×1.8×

저자 보고 · Dao 2023 Fig 5.

FA2도 Ampere (sm_80) 타겟. Hopper의 WGMMA/TMA는 FA3에서 활용 → ↗ V07 §9.

1 Grid / CTA (FA2)

grid = (N/Br, batch × head)
1 CTA = 1 Q-tile × 1 (batch, head)

block threads = 128 or 256
  warps = 4 or 8
  각 warp가 Kj·Vj의 일부 담당 (K-split)

2 레지스터·SMEM 배치

위치내용
registerQ_tile (Br×d)
register(m, ℓ, O) per warp / thread
smem[stage]K_tile, V_tile (Bc×d)
smemwarp-간 O reduce buffer

3 초기화

m = -inf (Br)
l = 0   (Br)
O = 0   (Br × d)
Qi = load_smem(Q[i_block]) # once
Qi = load_reg(Qi)

4 FA2 Forward 의사코드 ★★

# Q outer: Q_i is fixed for entire CTA life
for j in range(0, N, Bc):
    Kj = load_smem_async(K[j:j+Bc])  # cp.async prefetch
    Vj = load_smem_async(V[j:j+Bc])
    cp_async_wait()

    Sij = (Qi @ Kj.T) * tau
    if causal:
        if j_end <= i_start: skip
        else: Sij += causal_mask(i, j)

    m_new = maximum(m, rowmax(Sij))
    Pij   = exp2f((Sij - m_new[:, None]) * log2e)
    alpha = exp2f((m - m_new) * log2e)
    l     = alpha*l + rowsum(Pij)
    O     = alpha[:, None]*O + Pij @ Vj  # no /l
    m     = m_new

# end of KV loop: single rescale
O = O / l[:, None]
LSE = m + log(l)
write_hbm(O[i_block], LSE[i_block])

5 K-split warp reduction

# Each warp w holds partial (m_w, l_w, O_w)
# after KV loop: reduce across warps
smem_write(m_w, l_w, O_w, warp_id=w)
__syncthreads()

# warp 0 reduces
if warp_id == 0:
    m_all = max over w
    for w: alpha_w = exp(m_w - m_all)
    l_all = sum(alpha_w * l_w)
    O_all = sum(alpha_w * O_w) / l_all
    write_hbm(O_all)

6 per-head_dim block 추천 ★

head dBrBcstages
6464643 / 2
12864643 / 2
25664322 / 2

A100 / H100 기준. d=256에서 smem 한계 도달.

head_dim ≤ 256: smem에 Br·d + 2·Bc·d 올려야 함. d=512 이상은 Q를 추가 tile로 나눠야 하는 별도 경로 필요.
FA2 변경 3키: Q외·defer·K분 (Q outer · deferred rescale · K-split warps)

1 Hopper가 요구한 재설계 ★

  1. WGMMA asynchronous → GEMM 발행 후 softmax 가능
  2. TMA async bulk copy → prefetch 복잡도 감소 ↗ V04 §4
  3. Thread Block Cluster / DSM → 추가 여지
  4. FP8 Tensor Core E4M3/E5M2 ↗ V09 §3

2 Warp Specialization 구성

CTA threads = 384 (3 warpgroup, 12 warp)
  WG0 (warp 0..3)   = producer (TMA issue)
  WG1 (warp 4..7)   = consumer A (WGMMA)
  WG2 (warp 8..11)  = consumer B (WGMMA)
    ↑ pingpong
  • WG0 reg = 40 (setmaxnreg.dec)
  • WG1, WG2 reg = 232 (setmaxnreg.inc)
  • WG1·WG2는 다른 Q-row 처리 (M축 분담)

3 TMA + mbarrier 파이프라인 ★

# producer warp (WG0)
for s in range(stages):
    mbarrier.expect_tx(K_full[s], bytes)
    tma.load(Kj[s], K_full[s])
    mbarrier.expect_tx(V_full[s], bytes)
    tma.load(Vj[s], V_full[s])

# consumer warp (WG1/2)
for j_step:
    mbarrier.wait(K_full[s], parity)
    wgmma.mma_async(S += Qi @ Kj[s].T)
    wgmma.wait_group(0)
    mbarrier.arrive(K_empty[s])
    softmax_update(m, l, S)

    mbarrier.wait(V_full[s], parity)
    wgmma.mma_async(O += P @ Vj[s])
    wgmma.wait_group(1)
    mbarrier.arrive(V_empty[s])

4 WGMMA operand 배치

연산A operandB operandAcc
QKTQ (reg)K (smem)S (reg FP32)
PVP (reg)V (smem)O (reg FP32)

B operand는 smem 직접 접근 (shape m64n*k16). ↗ V04 §8

5 FP8 incoherent processing ★

아이디어 무작위 직교 R 곱: Q' = Q·R, K' = K·R ⇒ S = Q·KT 불변. R이 outlier 분산 → block-quant scale 효율 ↑.
  • R = Hadamard (±1, 구현 단순)
  • V는 row/col block-scale (per-token · per-channel)
  • FP8 E4M3 · acc FP32

6 성능 수치

configTFLOPS
H100 FP16 d=128740
H100 FP8 d=1281,200
H100 FP8 peak (WGMMA)1,979

Shah et al. 2024 · H100 SXM · 저자 보고.

FP8 함정: forward 중 scale 동적 갱신하면 비결정. 미리 calibration scale 사용. ↗ V10 §10

1 Pingpong 동기 ★

동기 WGMMA 1회 async 발행 후 issue slot 여유 → 같은 warpgroup이 softmax 하는 동안 다른 warpgroup의 GEMM 발행. WGMMA unit 이용률 ↑.

2 2-WG 타임라인

time ►►►
WG1: |GEMM1|soft|GEMM2|GEMM1|soft|GEMM2|
WG2:        |GEMM1|soft|GEMM2|GEMM1|soft|
WG0:  TMA load K[j+1], V[j+1] 지속

GEMM1 = Q @ Kj.T
GEMM2 = P @ Vj
soft  = online softmax update

3 3-stage intra-WG

WG 1개 안 slot 파이프:
slot 0 | GEMM1 |  soft  | GEMM2 |
slot 1 |       | GEMM1  |  soft | GEMM2 |
slot 2 |       |        | GEMM1 |  soft | GEMM2 |

→ 다음 iter Sj+1 GEMM1 미리 시작

4 의사코드 (2-WG pingpong)

# WG1 and WG2 run symmetrically on different Q rows
# 조건 synchronization via named barriers
for j in range(0, N, Bc):
    wait_for(Kj[s])         # TMA done
    if wg == 1:
        barrier.wait(wg1_gemm_turn)
        wgmma.issue(S = Qi @ Kj.T)
        barrier.arrive(wg2_gemm_turn)
        softmax_update(...)
        barrier.wait(wg1_gemm_turn)
        wgmma.issue(O += P @ Vj)
        barrier.arrive(wg2_gemm_turn)
    else:   # wg == 2
        # 교차된 순서
        softmax_update(...)
        barrier.wait(wg2_gemm_turn)
        wgmma.issue(S2 = Qi2 @ Kj.T)
        ...
    arrive(Kj_empty[s])

5 mbarrier 운용

  • K_full[s], K_empty[s], V_full[s], V_empty[s]
  • parity bit로 같은 barrier 재사용 (named barrier 16개 한도)
  • WG0 expect_tx에 byte 등록 → TMA 완료 시 자동 arrive
  • WG1, WG2 간 named barrier 4개 추가 (gemm turn, softmax turn)

6 성능 분석

기법WGMMA util
FA2 style (1 WG)~55%
WS + TMA~75%
WS + Pingpong~90%

Shah et al. 2024 · H100 FP16 d=128. 정성적 추정.

7 수치 트릭 (exp2f + clamp)

x' = (S − m)·log2e·τ
P = exp2f(x')
P = clamp(P, 0, 1) (FP8 overflow 방지)
named barrier 충돌: 16개 한도. K/V stage × full/empty × 2 WG = 쉽게 도달. parity trick 필수.

1 3-column 비교표 ★★

관점v1 (2022)v2 (2023)v3 (2024)
핵심tile + online softmax+ deferred rescale
+ N축 병렬
+ TMA · WGMMA
+ WS · FP8
loop orderKV outer, Q innerQ outer, KV innerQ outer + WS
rescale매 KV step O/ℓ최후 1회최후 1회
병렬축batch × head+ seq (N/Br)+ tile scheduler
warp splitQ-splitK-splitWS producer/consumer
async copycp.async (opt)cp.asyncTMA
MMAmma.syncmma.syncWGMMA async
target SMsm_80sm_80, sm_90sm_90
FP8 지원E4M3 + incoherent
dQ atomicyes (KV outer)no (Q outer)no
Hopper featuresTMA, WGMMA, mbarrier, setmaxnreg

2 성능 수치 비교

HW / configFA1FA2FA3
A100 FP16 d=128 N=2K124 TF225 TF
H100 FP16 d=128350 TF740 TF
H100 FP8 d=1281,200 TF
H100 peak FP16 WGMMA989 TF
H100 peak FP8 WGMMA1,979 TF

저자 보고치. Dao 2022·2023, Shah 2024.

3 non-mma FLOP 비중

FA1FA2FA3
rescale~25%~10%~5%
exp (SFU)~15%~15%~12%
mma~60%~75%~83%

4 Memory O(N) 유지

세 버전 모두 forward activation은 LSE (N) 만 저장. S materialization 없음 공통.

5 지원 변형

변형FA1FA2FA3
Causal
GQA/MQA
Sliding window
ALiBi bias
Soft-cap (Gemma)
varlen (packed)
FP8

6 한계 공통

  • head_dim ≤ 256 (smem)
  • FA3 FP8: 정확도 손실 ~0.5% → incoherent로 보정
  • variable seq는 ragged/packed 처리 필요
  • Blackwell (sm_100)는 FA3 kernel 재타겟 중

7 진화 3-단계 요약

v1: IO-aware (SRAM tile + online softmax)
v2: Parallel-aware (Q-outer + seq 병렬)
v3: Hopper-aware (TMA + WGMMA + WS + FP8)
FA 진화: T · PR · WSF (Tile · Partition+Rescale · WarpSpec+FP8)

1 Causal 정의

Sij = −∞   if j > i token i는 자기 이전 토큰만 참조 · language modeling 기본

2 Block 분류

KV block j_start..j_end vs Q block i_start..i_end:

case A: j_end ≤ i_start     → full-keep, mask 불필요
case B: j_start > i_end     → full-skip (전부 −∞)
case C: overlap (diagonal)  → per-element mask
  • A: mask 연산 0
  • B: 전체 KV tile skip (GEMM 자체 생략)
  • C: diagonal block만 mask 적용

3 Compute 절감

keep ratio = 1/2 + Bc/(2N)
→ 거의 정확히 1/2 전체 N²/2 + diagonal strip

4 의사코드 ★

for j in range(0, N, Bc):
    j_end = j + Bc
    if j > i_end:        # case B
        continue         # full skip

    Sij = (Qi @ Kj.T) * tau

    if j_end > i_start: # case C
        Sij = where(
            row >= col_in_global,
            Sij, -inf
        )
    # else: case A, no mask

    # continue FA update

5 Mask 구현 패턴

방법비용
additive (+ −∞)FP op 1회
select (where)predicate compare
precomputed mask memsmem/HBM load

FA2는 compile-time per-block 분기 → mask 자체 load 없음.

6 Tile scheduler 효과

  • naive: 모든 (i, j) pair 할당 → upper-triangle warp도 대기
  • causal-aware scheduler: lower-triangle + diagonal 만 워크큐
  • SM간 load balance ↑

7 변종 mask

mask정의
causalj ≤ i
sliding windowi−W ≤ j ≤ i
prefix-LMprefix 내부 full, suffix causal
document-packed같은 doc_id 내에서만
tree (spec decode)branch 구조 기반

8 ALiBi · Soft-cap

ALiBi: Sij += −mh·|i−j|
Soft-cap: Sij ← cap·tanh(Sij/cap) mh = head-specific slope · soft-cap은 Gemma-2 사용
negative infinity: FP16 exp(−65504) 는 0이지만 일부 GPU에서 NaN. 대체로 큰 음수 (−1e4) 사용이 안전.

1 동기 ★

  • long-context: N = 1M → single GPU KV 저장 불가
  • seq축을 여러 GPU에 sharding 필요
  • FA의 online softmax 결합 공식 활용 → GPU간 partial merge

2 Sharding 구조

N = P·Np
rank r 은 Qr, Kr, Vr ∈ ℝNp×d 보유 P개 GPU가 ring 위에 위치 · Q는 rank 고정, KV는 순환

3 Ring rotation

time step 0:   rank r holds (K_r, V_r)
            compute Q_r · K_r
step 1:  send K_r → r+1, recv K_{r−1}
         compute Q_r · K_{r−1}
step 2:  순환 P−1 회
end:    Q_r 가 모든 rank의 KV 본 셈

4 의사코드 ★

# per rank r
Qr = load_local()
Kr_cur = Kr; Vr_cur = Vr
(m, l, O) = (-inf, 0, 0)

for step in range(P):
    # async send/recv while compute
    isend(Kr_cur, Vr_cur, dst=r+1)
    irecv(K_next, V_next, src=r-1)

    # FA2-style inner update with K_cur, V_cur
    S = Qr @ Kr_cur.T * tau
    if causal: mask(S, step)
    (m, l, O) ← online_update(m, l, O, S, Vr_cur)

    wait_comm()
    Kr_cur, Vr_cur = K_next, V_next

O = O / l; write(O_r)

5 Causal 처리 in ring

문제 rank r의 Qr는 자기보다 낮은 index만 attend → step에 따라 mask 조건 다름. 일부 step은 full skip, 일부는 diagonal.
  • step s에서 KV는 rank r−s의 것
  • r−s > r ⇒ 전부 future ⇒ skip
  • load imbalance 발생 → striped 배치로 완화

6 Comm-Compute overlap

compute: |FA(step 0)|FA(step 1)|FA(step 2)|
comm:     |send/recv→|send/recv→|send/recv→|
          ↑ NCCL SendRecv async

d=128, Np=8K, P=8 → compute 시간 > comm 시간이 일반.

7 변형 · 확장

방식특징
Ring Attention (Liu '23)Q 고정, KV 순환
Striped Attentioncausal load balance
Tree Attention (long-ctx)tree-based reduce
Context ParallelMegatron CP = ring variant
분산 맥락 상세 ↗ V15 §12 · §13. 여기서는 kernel 관점.

1 정의 ★

이름Hkv특징
MHAH표준 (GPT-2/3)
MQA1cache 1/H배 (PaLM)
GQAg (e.g. 8)중간 (Llama-2 70B)
MLAcompressedcache r/D (DeepSeek) ↗ V07 §15

2 Group 수식

Q ∈ H heads · d
K, V ∈ g heads · d (g | H)
group size = H/g
head h → KV head ⌊h/(H/g)⌋ 여러 Q-head가 같은 KV-head 공유

3 KV cache 크기

KV_bytes = 2·N·Hkv·d·size(dtype)
MHA → MQA: H× 절감
MHA → GQA(g=8, H=32): 4× 절감

4 Kernel 구현 차이

  • MHA: 각 head 독립 attention kernel
  • MQA/GQA: Q를 H/g 그룹으로 묶어 한 CTA에 처리
  • KV는 1회 load → smem 재사용 (H/g배)
  • arithmetic intensity ↑ (compute-bound 전환)

5 Decode AI 분석 ★

FLOP = 2·N·H·d (Q 고정, K full)
KV bytes = 2·N·Hkv·d·dtype
AI = H / (Hkv·dtype) MHA: AI=H/(H·2)=0.5 · MQA: AI=H/(1·2)=H/2

decode는 memory bound → MQA/GQA로 AI 끌어올려 Tensor Core 활용.

6 모델별 설정

모델HHkv방식
Llama-2 7B3232MHA
Llama-2 70B648GQA
Llama-3 8B328GQA
Mistral 7B328GQA
PaLM1MQA

7 FA kernel tiling 영향

MHA kernel layout:
  grid = (N/Br, batch × H)
  1 CTA = 1 Q-head, B_r rows

GQA kernel layout:
  grid = (N/Br, batch × H_kv)
  1 CTA = (H/g) Q-heads group
  Q_tile = Br × (H/g) × d   ← register ↑
  K/V tile = Bc × d (shared)

8 품질 영향

  • MQA: perplexity 소폭 ↑ (~0.1~0.3)
  • GQA g=8: MHA 대비 무의미한 차이
  • 학습 시점 GQA 스타일로 훈련 가능 (Llama-2 paper)

9 PagedAttention과의 상호작용

block table의 Hkv·d per page ⇒ GQA → page 크기 축소. 같은 메모리로 더 많은 요청 수용. ↗ V07 §16

g | H 제약: g는 H의 약수여야 함 (보통 1, 2, 4, 8). 비정수 ratio는 구현 복잡.

1 아이디어 ★

정의 KV를 저차원 latent c (rank r « d) 로 압축. cache는 c만 저장. attention 시 up-projection WUK, WUV 곱해 복원.

2 수식 ★

ci = WDKV·xi ∈ ℝr
Ki = WUK·ci ∈ ℝH·d
Vi = WUV·ci ∈ ℝH·d
cache = {ci}, size = N·r (ll N·H·d)

3 Cache 압축률

방식per-token KV bytes
MHA (FP16)2·H·d·2
GQA g=82·8·d·2
MLA r=512512·2 + RoPE

DeepSeek-V2: r=512, H·d=16384 → ~32× 절감.

4 Absorb trick ★

핵심 attention 수식에서 WQTWUK를 미리 곱해 하나의 행렬로 변환. decode 시 Ki 복원 없이 ci와 직접 attention 가능.
S = (WQx)·(WUKc)T·τ
= xT·(WQTWUK)·c·τ
= xT·WQKabs·c·τ per-head Wabs 미리 곱 · runtime GEMM 1회 감소

5 RoPE 분리 (decoupled) ★

이슈 RoPE는 position에 의존하는 회전 → c → K의 linear map에 흡수 불가. 해법: RoPE를 별도 채널에 둠.
  • K = concat(Knope from c, Krope raw)
  • Krope ∈ ℝdr에만 RoPE 적용
  • cache = c (absorb용) + Krope (decoupled)

6 Kernel 구조

# decode: 1 query token
q_nope = W_Q @ x             # (H, d)
q_rope = apply_rope(W_QR @ x)
q_abs  = q_nope @ W_UK       # absorb

# attend to cached c, K_rope
for page in block_table:
    c_page      = load(c[page])       # (Bp, r)
    K_rope_page = load(K_rope[page])
    S_nope = q_abs @ c_page.T * tau
    S_rope = q_rope @ K_rope_page.T
    S = S_nope + S_rope
    # online softmax
    O += softmax_update(...)

# V도 c 기반 복원
O = O @ W_UV  # or absorbed

7 장단점

cache 32× 절감구현 복잡 (RoPE 분리)
long context 유리absorb trick 필요
품질 MHA 수준학습부터 MLA로 해야
prefill vs decode: prefill에서는 absorb가 오히려 불리 (K 전체 복원 후 GEMM이 Tensor Core 이용). decode에서만 absorb 경로 선택.
MLA 3-키: 압축·absorb·분리 (compress c · absorb WQK · decouple RoPE)

1 동기 ★

  • seq 길이 가변 · 동시 R개 요청
  • contiguous KV 할당 → 내부 단편화 60%+
  • OS virtual memory 패턴 차용
  • KV를 고정 크기 page로 쪼개 관리

2 기본 수식

Bp = page_size (token 수)
per-page bytes = Bp·Hkv·d·2 (K+V) vLLM 기본 Bp = 16

3 자료구조 시각화 ★★

┌─────────────────────────────────────────────────┐
│  physical block_pool (device HBM)               │
│  ┌────┬────┬────┬────┬────┬────┬────┬────┐      │
│  │ p0 │ p1 │ p2 │ p3 │ p4 │ p5 │ p6 │ p7 │ ...  │
│  └────┴────┴────┴────┴────┴────┴────┴────┘      │
│   각 block = Bp×Hkv×D×2B (K,V)                   │
└─────────────────────────────────────────────────┘
          ▲       ▲       ▲
          │       │       │  (referenced by table)
┌─────────┴───────┴───────┴────────────┐
│  block_table[req_id]                 │
│  req A: [p3, p1, p5,  ...]           │
│  req B: [p3, p1, p2, p7, ...]        │
│  req C: [p6, p4, p0,  ...]           │
│          ↑   ↑                       │
│          └───┴─ shared prefix pages  │
└──────────────────────────────────────┘

logical token t → page:
  page_idx  = block_table[req][t // Bp]
  offset    = t % Bp
  K_ptr     = block_pool[page_idx][offset]

4 메타데이터 필드

field타입의미
block_poolTensor [P, Bp, Hkv, d, 2]pre-allocated
block_tableint32 [R, Lmax/Bp]req → page list
seq_lenint32 [R]current length
ref_countint32 [P]page refcount (CoW)
free_listqueue사용 가능 page idx

5 Fragmentation 비교 ★

방식internal fragexternal frag
contiguous max_len60~80%R↑ 시 OOM
contiguous actual~0%매 요청 alloc/free
paged Bp=16<Bp·bytes0 (page 풀)

vLLM 논문: 메모리 활용률 >96%.

6 Block size 영향

Bp장점단점
16frag ↓, 공유 세밀page table ↑
32중간중간
128HBM throughput ↑frag ↑, 공유 거침

7 Prefix sharing (CoW) ★

  • 여러 seq가 같은 system prompt → 같은 page 가리킴
  • ref_count ≥ 2인 page에 write 발생 시 copy-on-write
  • 분기 시점에만 새 page alloc
  • 4K system prompt · R=100 → 100× 메모리 절감

8 Beam · parallel sampling

  • beam K개가 같은 prefix page 공유
  • 분기 시 새 page만 alloc
  • fork overhead ≈ O(1) page

9 Allocator 정책

alloc_page(req):
    if free_list empty:
        evict() or preempt()
    p = free_list.pop()
    ref_count[p] = 1
    return p

free_req(req):
    for p in block_table[req]:
        ref_count[p] -= 1
        if ref_count[p] == 0:
            free_list.push(p)
scheduler 자체 동작 ↗ V16 §5 · §6. 여기서는 kernel 입력 자료구조만.

1 변화 요지 ★

  • K/V가 비연속 (page 단위 흩어짐)
  • iteration index → page lookup → gather load
  • TMA 2D bulk copy 유리 (per-page descriptor)

2 Grid / CTA (v1)

grid = (R_active, H_q)
1 CTA = 1 (request, q_head)
1 CTA가 자기 req의 모든 page를 loop
  • 짧은 seq에서 효율
  • 긴 seq에서 SM 활용률 저조

3 Decode 의사코드

# 1 thread block = 1 query token, 1 head
q = load(Q[req, head])          # (d,)
(m, l, O) = (-inf, 0, 0)

for t_block in range(0, seq_len, Bp):
    page = block_table[req][t_block // Bp]
    K_p  = load(block_pool[page][.., K])
    V_p  = load(block_pool[page][.., V])

    s  = q @ K_p.T * tau        # (Bp,)
    if causal: mask_suffix(s, t_block)

    (m, l, O) = online_update(m, l, O, s, V_p)

O /= l
write(O_out[req, head], O)

4 v2: Split-K ★

변경 K축을 P partition으로 split → 여러 CTA가 병렬. 긴 seq에서 SM 활용 ↑. partial (m, ℓ, O) → 2nd-stage reduce kernel.
stage 1: grid = (R × H × P)
   CTA p가 partition p의 page만 loop
   → partial (m_p, l_p, O_p)

stage 2: grid = (R × H)
   online softmax merge P partials:
     M     = max_p m_p
     α_p   = exp(m_p − M)
     L     = Σ_p α_p · l_p
     O_out = Σ_p α_p · O_p / L

5 Gather-scatter 패턴 ★

# page 단위 bulk load
for stage in stages:
    # TMA 2D: (Bp, H_kv, d)
    tma_load(smem[stage],
             block_pool[page_idx[stage]])
    mbarrier.expect_tx(...)

# compute consumer
for stage in stages:
    mbarrier.wait(stage)
    compute(smem[stage])

page_idx는 block_table에서 CTA 시작 시점에 한 번에 로드.

6 성능 특성

지표
memory 활용96%+
decode latencyFA 대비 ~1.1×
throughput (R↑)FA 대비 2~4×
page indirection 오버헤드~5%

Kwon et al. 2023 · vLLM paper Fig.

7 FP8 KV cache

  • page 단위 per-token scale 저장
  • scale layout: 각 page에 appended Bp scalar
  • load 시 dequant fuse

8 GQA / MLA와의 조합

  • GQA: Hkv·d per-page 줄어 → Bp ↑ 가능
  • MLA: page에 c (r-dim) + Krope 분리 저장
  • kernel은 per-head 대신 per-group 루프
실수: kv_cache_dtype=fp8 때 page 단위 scale 저장 위치 잘못 잡으면 정확도 무너짐. stride 계산 주의.
PagedAttn v1/v2: 1CTA·splitK·reduce

1 정의 ★

정의 모든 요청의 prefix를 radix trie로 관리. 새 요청은 longest-match prefix 노드 찾아 KV 재사용 → compute · memory 동시 절감.

2 자료구조 ★

root
 ├─ "You are"───┬─ " a helpful..."  [req A,B,C]  ref=3
 │              └─ " an expert..."  [req D]      ref=1
 │
 └─ "###user\n" ├─ "What is GPU..." [req E]      ref=1
                └─ "How to..."      [req F,G]    ref=2

각 노드 = (token seq, KV page list, ref_count, LRU timestamp)
각 edge = token prefix (shared)

3 주요 연산

op동작
match_prefix(tokens)(node, matched_len)
insert(tokens[matched:])새 child 추가
lock_ref(node)ref_count ++
unlock_ref(node)ref_count −−
evict()LRU leaf with ref=0

4 Match 알고리즘

match_prefix(tokens):
    node = root
    matched = 0
    while matched < len(tokens):
        child = node.find_child(tokens[matched])
        if child is None: break
        k = common_prefix_len(
              child.tokens,
              tokens[matched:])
        matched += k
        if k < len(child.tokens):
            break          # partial match in edge
        node = child
    return (node, matched)

5 LRU eviction 정책

  • leaf (children 없음) + ref_count = 0 대상
  • oldest LRU timestamp부터 evict
  • evict된 page → block_pool free_list 반환
  • non-leaf는 children 먼저 evict되어야 reachable

6 Hit rate 효과

시나리오hit rate
독립 질문 (cold)0~10%
chat (system prompt)60~80%
agent (tool loop)70~90%
few-shot batch90%+

SGLang 2024 보고.

7 PagedAttention과의 관계

  • trie 노드 내부 KV = PagedAttention page 리스트
  • CoW trigger: 분기 시점에 새 page만 alloc
  • ref_count ≥ 2 edge에서 write 시 split

8 Kernel 영향

  • prefill: matched prefix tokens의 KV 이미 존재 → compute skip
  • 새 suffix만 forward
  • attention kernel 자체는 동일 (paged)
  • 변경은 scheduler + cache 관리 쪽 ↗ V16 §14

9 Radix vs hash prefix cache

방식
trie (radix)부분 prefix 공유tree 관리 복잡
hash (vLLM prefix)O(1) lookupblock align 필요
token-align: radix는 edge를 token 단위 자름. Bp와 불일치하면 partial page 발생. 실무에서는 Bp 경계 기준으로 edge split.

1 공식 1-liner 모음 ★

attention: O = softmax(QKT/√d)·V
safe sm: yi = exi−m/ℓ, m=max, ℓ=Σex−m
online update: m'=max(m,x), ℓ'=αℓ+ex−m'
merge: ℓabaabb, α=em−mab
LSE = m + log ℓ
FA HBM: Θ(N2d2/M) bytes
FA mem (act): O(N) (LSE only)
causal keep: ~ 1/2 of blocks
MQA cache: 1/H · MHA
MLA cache: r / (H·d) ≈ 1/32
Paged: logical t → (block_table[r][t/Bp], t%Bp)

2 Recompute 조건

  • forward activation −95% (N2 → N)
  • backward FLOP +40%
  • 실행시간 여전히 빠름 (HBM 병목 완화)

3 변종 선택 결정트리 ★★

start
 │
 ├─ prefill or training?
 │    ├─ sm_90+?        → FA3 (WGMMA, TMA, FP8)
 │    ├─ sm_80?         → FA2
 │    └─ older?         → FA1 (legacy)
 │
 ├─ decode (1 query token)?
 │    ├─ R 작음 · 긴 seq? → PagedAttn v2 (split-K)
 │    ├─ R 큼 · 짧은 seq? → PagedAttn v1
 │    └─ prefix 공유?    → RadixAttn + Paged
 │
 ├─ context > 100K · multi-GPU?
 │    └─ Ring Attention (CP)
 │
 ├─ model cache 압박?
 │    ├─ train부터 설계 가능? → MLA
 │    └─ 그 외                → GQA g=8
 │
 └─ speculative decode 맞추기? → tree mask FA3

4 성능 bottleneck 표

phasebound최적화
prefill (긴)computeFA3 FP8, WGMMA
decodememoryMQA/GQA/MLA
long-ctxHBM + commRing, Paged
batch 큼SM 활용continuous batch

5 흔한 실수 5선

  1. P·V 누산을 FP16으로 (→ FP32 필수)
  2. −∞ 대신 −inf float 값 쓴 채 FP16 exp
  3. dropout 재현 위해 mask state 누락 (bwd recompute 불일치)
  4. FA3 FP8 scale 동적 갱신 (비결정)
  5. PagedAttn page 경계에서 stride 계산 off-by-one

6 연관 권

  • Online softmax 수치 ↗ V09 §9
  • FA3 WGMMA · TMA ↗ V04 §4·§7
  • FA kernel as GEMM pattern ↗ V06
  • Ring Attention → CP ↗ V15 §12
  • PagedAttn in vLLM path ↗ V16 §7
  • Speculative tree mask ↗ V08 §12·§13
  • FP8 quant ↗ V10 §10

7 진화 타임라인

yearlandmark
2017Attention is all you need
2018Online softmax (Milakov)
2022FA v1, Ring Attn
2023FA v2, GQA, PagedAttn
2024FA v3, MLA, RadixAttn
attention 5-축: compute·mem·mask·cache·share (FA · MLA/GQA · causal · Paged · Radix)