CUDA 18VOL · CONTENT-FIRST · A4 LANDSCAPE · 16p

기타 LLM 커널 단권화

MoE · KV Cache Ops · RoPE · Norm · Sampling · Speculative Decoding
Volume V08/18
Tier T2 Kernel 패턴
선행 V01 · V05 · V06 · V07
용도 LLM forward pass 부품 인출 트리거

목차

1. MoE 구조 개요p.2
2. MoE Permute / Unpermutep.3
3. MoE Grouped GEMM (ragged)p.4
4. Dropless MoE (Megablocks)p.5
5. KV Cache layoutp.6
6. KV Cache quantizationp.7
7. Rotary Position Embeddingp.8
8. RMSNorm / LayerNorm kernelp.9
9. SwiGLU / GELU fusionp.10
10. Sampling kernelp.11
11. Beam Search kernelp.12
12. Speculative Decodingp.13
13. Medusa · EAGLE · Tree attnp.14
14. Masked fill · attention biasp.15
15. Cheat Sheetp.16

범례

핵심 용어 (노란 배경)
매우 중요 · 표 헤더
정의·공식 박스
예시·워크드 박스
빨강주의·실수 포인트
인출 핵심
(!)니모닉 (권당 ≤5)
타 권 cross-ref
흐름·인과
∵∴이유·결론
인쇄 A4 가로 / 여백 없음 / 배경 그래픽 포함 · 범위 제외 vLLM 시스템 (↗ V16) · 학습 관련 (↗ V17)
Megablocks · Mixtral · Llama 2/3 · vLLM · TRT-LLM · Leviathan'23 · Medusa · EAGLE

1 MoE 정의 sparse FFN

정의 Mixture-of-Experts: FFN 블록을 N개 expert로 분할. token당 top-k(보통 k=1,2)개 expert만 활성. 파라미터 N×, 계산 k/N×.
  • dense FFN: 모든 token이 모든 W₁·W₂ 통과
  • sparse MoE: token별 조건부 경로 선택
  • Mixtral 8×7B · DeepSeek-V3 · Qwen-MoE

2 Forward 5단 라탑게전합

  1. Router: logits = x · Wr ∈ ℝN
  2. Top-k: 상위 k index · score 선택
  3. Gate: softmax(top-k) · normalize
  4. Expert FFN ej: yj = W2,j·σ(W1,j·x)
  5. Combine: y = Σj∈top-k gj·yj

3 Router 수식

logits = x · Wr  Wr ∈ ℝd×N
idx, s = topk(logits, k)
g = softmax(s)   (또는 sigmoid · noisy) x : [T, d] token embedding   N : expert 수   k : 활성 expert (1~2)

4 Top-k Router 커널

// 한 block = 한 token · thread N개 병렬
__shared__ float s_logits[N];
int t = threadIdx.x;
s_logits[t] = dot(x[tid], Wr[:,t]);
__syncthreads();
// warp-level bitonic/heap top-k
topk_warp(s_logits, k, idx, val);
g = softmax(val);

N ≤ warp 시 single-warp top-k, 그 외 block-wide merge tree.

5 MoE 변형 표

모델Nk특징
Switch1281hard routing
GShardE2capacity cap
Mixtral-8×7B82softmax gate
DeepSeek-V32568shared+routed
Qwen-MoE604fine-grained

N 수치는 각 모델 카드 기준 · 2024~2025.

6 Capacity & Drop

capacity = ⌈T · k / N · Cf
Cf : capacity factor (1.0~1.25)
  • expert 하나에 몰린 token 수 > capacity → drop
  • 학습 시만 주로 사용 · inference는 dropless 선호 (↗ §4)

7 Load balancing loss 학습용

Laux = N · Σj fj · pj
fj : expert j로 라우팅된 token 비율
pj : expert j의 평균 router prob

Inference 시는 aux loss 불필요 (학습 ↗ V17).

8 MoE 커널 Pipeline

token [T,d]
   │ router
   ▼
logits [T,N] → top-k → idx [T,k], g [T,k]
   │ permute (↗ §2)
   ▼
grouped GEMM W1 (↗ §3)
   │ act (SwiGLU ↗ §9)
   ▼
grouped GEMM W2
   │ unpermute + weighted sum
   ▼
y [T,d]

9 Out-of-scope 범위

  • expert parallel · all-to-all 통신 → ↗ V15 §4
  • MoE 서빙 배치 스케줄링 → ↗ V16 §5
  • routing loss 역전파 → ↗ V17 §6
흔한 실수: top-k 이후 softmax가 아닌 전체 N에 softmax 후 top-k를 하면 gradient sparse 성질 깨짐.

1 왜 permute? ★

동기 grouped GEMM은 expert별로 연속된 token을 요구. router 결과는 token 순서 → expert 순서로 재배치해야 GEMM이 가능.
  • k>1 이면 token 하나가 k번 복제됨
  • 역변환(unpermute)은 weighted sum으로 복원

2 자료구조 변화도 ★

(k=1, T=6, N=3 예시)

[1] router 결과 (token order)
  t0→E1  t1→E0  t2→E2  t3→E1  t4→E0  t5→E2

[2] histogram (expert별 카운트)
  E0: 2  E1: 2  E2: 2

[3] exclusive prefix sum (offset)
  E0: 0  E1: 2  E2: 4

[4] permuted (expert-major)
  idx:   0  1  2  3  4  5
  slot: E0 E0 E1 E1 E2 E2
  src : t1 t4 t0 t3 t2 t5
      └──┬──┘└──┬──┘└──┬──┘
         group0  group1 group2
         (→ grouped GEMM ↗ §3)

[5] unpermute (역매핑 · weighted)
  y[t_src] = g · expert_out[slot]
  (k>1 이면 k개 slot 합산)

3 Permute 3-stage 커널 히프재

  1. Histogram: count[e] = Σt [assign(t)=e]   (atomic add)
  2. Exclusive scan: offset[e] = Σe'<e count[e']   (↗ V05 §3)
  3. Scatter: slot = offset[e] + atomicAdd(local[e],1)

4 의사코드

// kernel 1: histogram
__global__ hist(idx, cnt){
  int t = global_tid;
  int e = idx[t];
  atomicAdd(&cnt[e], 1);
}
// kernel 2: exclusive scan (block-wide)
scan_exclusive(cnt, offset, N);
// kernel 3: scatter
__global__ scatter(idx, offset, perm, inv){
  int t = global_tid, e = idx[t];
  int slot = offset[e] + atomicAdd(&loc[e],1);
  perm[slot] = t;   // src index
  inv[t]     = slot; // reverse map
}

5 Permute 메모리 형태

버퍼shape의미
idx[T,k]token→expert
cnt[N]expert별 token
offset[N+1]group 시작
perm[T·k]slot→src token
inv[T·k]src→slot

6 Unpermute 역매핑 + 가중합

__global__ unpermute(out, expert_y, inv, g, T, k, d){
  int t = blockIdx.x;
  int c = threadIdx.x;
  float acc = 0.f;
  for(int j=0; j<k; j++){
    int slot = inv[t*k + j];
    acc += g[t*k + j] * expert_y[slot*d + c];
  }
  out[t*d + c] = acc;
}

7 성능 주의

  • histogram: expert 수 N 작음 → atomic contention
  • 완화: block-local privatization → global reduce (↗ V01 §10)
  • scatter의 src load는 random → coalesce 깨짐 정상
  • d가 작을 땐 permute + GEMM1 fuse (CUTLASS epilogue)

8 Sort-based variant

대안 (expert_id, token_id) pair를 radix sort로 정렬 → prefix sum 없이 permute 완료.

N 큰 경우(N>64) radix sort가 histogram+scatter보다 안정.

결정론 필요 시 atomicAdd 순서 비결정 → sort-based 선택.

1 Ragged batch 문제

문제 permute 후 expert별 token 수 me가 서로 다름. standard batched GEMM은 모든 m 동일해야 함.
expert 0: m=17  × [d, H]
expert 1: m= 5  × [d, H]
expert 2: m=23  × [d, H]
 ...
expert N: m=11  × [d, H]

2 Grouped GEMM 정의 ★

Ye = Xe · W1,eT  e = 0..N-1
Xe ∈ ℝme×d, W1,e ∈ ℝH×d 각 e는 독립 GEMM · 한 kernel launch로 처리 · GEMM 본체 ↗ V06 §5

3 두 접근

방식아이디어pros
Groupedexpert별 독립 GEMMsimple
Block-sparse하나의 sparse matmuldropless (↗ §4)

4 CUTLASS Grouped GEMM ↗ V06

problem_sizes [N]: {(m_e, n, k)}
ptr_A [N]: X_0, X_1, ... X_{N-1}
ptr_B [N]: W_0, W_1, ... W_{N-1}
ptr_C [N]: Y_0, Y_1, ... Y_{N-1}
lda/ldb/ldc [N]

kernel: 각 CTA가 (e, tile_m, tile_n) 담당
  → tile scheduler가 ragged m_e 고려
  → 빈 expert(m_e=0) skip

↗ V06 §10 CUTLASS GroupedGemm & Persistent kernel 구조.

5 Tile scheduling

  • Per-expert: expert별 grid(⌈me/BM⌉, ⌈n/BN⌉) · load imbalance
  • Persistent: CTA 수 고정 · work queue에서 (e, tile) pull · balance ↑
  • Hopper: ↗ V04 §4 TMA descriptor per expert

6 Work estimate

FLOPtotal = 2 · H · d · Σe me = 2 · H · d · T · k T · k 가 total 활성 token 수 · dense FFN 대비 정확히 k/N · (단, H는 보통 dense의 1/k배로 설정)

7 두 GEMM 구조 (FFN)

X[m_e,d]
  │ GEMM1 W1_e  [d, 2H]  (gate·up)
  ▼
Z[m_e,2H]
  │ SwiGLU  (↗ §9)
  ▼
A[m_e, H]
  │ GEMM2 W2_e  [H, d]
  ▼
Y[m_e, d]

8 Expert 병렬 축

분할비고
expertCTA.zragged me
mCTA.ytileM
nCTA.xtileN
ksplit-K optme 작으면

9 Small-m 문제

me < 16 인 expert: Tensor Core tile 작아 under-util. → split-K 또는 GEMV 경로 fallback.

Mixtral-8×7B decode 단계는 T=1 (batch=1 per seq) → GEMV 위주.

quant MoE: expert별 scale 다름 → 각 ptr_B에 scale pointer도 함께 (↗ V10 §7).

1 왜 dropless ★

모티브 capacity cap → token drop, padding, 품질 저하. Megablocks: capacity 없이 block-sparse GEMM 하나로 모든 expert 처리.
  • quality: drop 0
  • load balance 압력 ↓ (학습 속도에 영향 없음)
  • inference에서 기본으로 선호

2 Block-sparse 표현

full FFN weight W ∈ [N·H, d]
  W_stack: expert 0 ... expert N 을 행 방향 concat
  mask: [T·k, N]  block-diagonal-ish

X' [T·k, d]  (permuted)
W_stack [N·H, d]

result [T·k, H] 은 block-sparse matmul:
  row r ∈ expert e 이면 W_e 구간만 접촉

3 비교 표

방식droppadkernel
Token choicebatched
Grouped有(tile)N grouped
Megablocksblock-sparse 1

4 Block-sparse layout (BCSR)

BCSR Block Compressed Sparse Row: block 단위의 CSR. row_ptr · col_idx · block_val.
block tile: BM × BK (e.g. 128×128)

row_ptr  [T/BM + 1]  : row당 active block 누적
col_idx  [nnz]       : 각 block의 열 tile idx
values   [nnz,BM,BK] : block data

Megablocks : SDD (Sparse-Dense→Dense) · DSD · DDS 세 변형.

5 Kernel 구조

  1. CTA = (row block tile)
  2. row_ptr로 active col block 스캔
  3. 각 active block → shmem load → MMA
  4. dense accumulator → epilogue

6 Formula

Y[i,j] = Σb∈active(i) X[i, Kb] · W[Kb, j]
active(i) = { b : block (i,b) ≠ 0 }

7 Padding 자동화

  • expert 당 token 수 me가 BM 배수가 아니어도 block 내부만 masked
  • token drop 없이 tile 경계 처리 → correctness
  • load imbalance → persistent CTA로 완화

8 구현 노트

단계커널
topologypermute 결과에서 block meta
SDDblock-sparse matmul
DSD/DDSbackward / dW (↗ V17)

9 Inference 전용 단순화

단순화 학습 backward 불필요 → SDD만 필요. 사실상 permute + grouped GEMM과 등가하지만 block-level masking으로 padding 자동.
학습 세부 (routing loss · backward SDD) → ↗ V17 §6.

1 왜 KV cache ★

정의 autoregressive decode에서 past K · V를 재계산 대신 저장. prefill 이후 decode는 step당 한 token만 forward.
  • prefill: O(L²d) attention
  • decode: O(L·d) attention per step (cache hit)

2 기본 dim 라배헤시디

KV ∈ ℝ[Llayer, B, Hkv, S, Dh] L : layer 수   B : batch   Hkv : KV head (MHA=H, GQA=G, MQA=1)   S : seq len   Dh : head dim

3 메모리 예산

모델LHkvDhKB/token
Llama3-8B328128128
Llama3-70B808128320
Mistral-7B328128128

FP16 · 2(K,V) 포함 · KB/token = 2·L·Hkv·Dh·2B.

4 Layout 변형 ★

layoutorderaccess 특성
head-major[B,H,S,D]head별 연속 · attention kernel 친화
seq-major[B,S,H,D]append 쉬움 · scatter 효율
paged[page, H, D]vLLM · ↗ V16

vLLM PagedAttention 세부 → ↗ V16 §3.

5 Update kernel (decode)

// new token k,v 1개 삽입
__global__ kv_update(K_cache, V_cache,
                         k_new, v_new,
                         b, s, h, D){
  int tid = threadIdx.x;
  if(tid < D){
    K_cache[b,h,s,tid] = k_new[b,h,tid];
    V_cache[b,h,s,tid] = v_new[b,h,tid];
  }
}

보통 QKV projection epilogue와 fuse.

6 Contiguous vs Paged

  • contiguous: 한 sequence가 연속 S slot · 재배치 필요 시 memcpy 큼
  • paged: block(=page) 단위 할당 · block_table로 논리↔물리 mapping
  • fragmentation ↓ / attention kernel의 간접 주소 한 단계

7 Head-sharing (GQA/MQA)

HqHkvcache 절감
MHAHH
GQA-8HH/G
MQAH1
MLAH~D/r

attention kernel 연결 → ↗ V07 §6.

8 Prefill vs Decode

prefill: K,V ← x · WK, x · WV (한 번에 S개)
decode:  k,v ← xnew · WK,V (1개)
append: cache[S_cur : S_cur+1] ← (k,v)

9 실패 모드

flash attn 호환: head-major [B,H,S,D]가 표준. seq-major 쓰면 attention kernel이 stride 재해석 필요.
beam search 분기 시 cache copy 필요 (↗ §11).

1 왜 quantize

모티브 long context에서 KV cache가 메모리 1순위. INT8 → 2×, INT4 → 4× 절감. attention은 memory-bound → speedup 직결.
  • weight quant는 compute-bound 완화 (↗ V10 §5)
  • KV quant는 memory-bound 완화

2 Quantization 공식

q = round(clamp(x / s, qmin, qmax))
x̂ = q · s
s = amax(|x|) / qmax INT8 : qmax=127   INT4 : 7   per-token · per-head scale 독립

3 Scale granularity

granularity메모리정확도
per-tensor1낮음
per-channel (head)H
per-tokenS높음
per-token per-headS·H최고

KV는 보통 per-token (outlier가 token 단위로 나타남).

4 저장 layout

K_q : [L, B, H_kv, S, D_h]  INT8/INT4
K_s : [L, B, H_kv, S]       FP16/BF16

decode 시:
  fetch K_q[b,h,:S,:] 와 K_s[b,h,:S]
  on-the-fly dequant:
    K̂[t,d] = K_q[t,d] · K_s[t]

5 Fused dequant attention 핵심

  1. K_q shmem 로드 (INT8) — 대역폭 ½
  2. score = Q · dequant(K_q, K_s) — MMA 입력에 scale 곱
  3. softmax는 FP32 유지
  4. V_q도 동일 경로로 dequant → attn·V

flash attention의 softmax online 흐름 유지 → ↗ V07 §4.

6 INT4 추가 세부

  • 4-bit packed: byte당 2 element (nibble)
  • unpack in register → INT8 확장 → MMA
  • group_size=64 or 128 within token (↗ V10 §6)

7 Quant 포인트

  x → W_K → k_fp  ─┐
                   │ online amax (per token · head)
                   ▼
                 k_q, k_s
                   │ KV cache 저장 (INT8/INT4)
                   ▼
  attention (fused dequant)
                   │
                   ▼
                 output FP16

8 정확도 고려

dtypePPL ΔKV mem
FP16 (baseline)0
FP8 (E4M3)~00.5×
INT8 per-token+0.020.5×
INT4 group-128+0.1~0.30.25×

Δ는 WikiText2 PPL 기준 범위 · 모델/calibration 의존.

9 주의

RoPE 이후 값을 cache하면 sin/cos 분포가 섞여 outlier ↑. RoPE 전 K 저장 + 매 step 회전 or RoPE 후 FP16 유지.
pre vs post softmax quant: post는 불가 (smax 후 스케일 의미 없음).

1 RoPE 아이디어 ★

정의 position m의 회전 변환을 Q,K에 적용. 내적 Q·K가 상대위치 m-n만의 함수가 되도록 설계.
  • absolute embedding과 달리 추가 파라미터 없음
  • sequence 연장 시 interpolation 가능 (YaRN, NTK)

2 복소 수식 ★

qm ∈ ℂd/2 : 2개씩 묶어 복소수 해석
θi = b−2i/d (i=0..d/2−1)
RoPE(qm)i = qm,i · ej·m·θi
⟨RoPE(qm), RoPE(kn)⟩ = f(q,k, m−n) b : base (보통 10000)   θi : rotation freq   j : 허수단위

3 실수 2×2 rotation ★

[q'2i; q'2i+1] =
  [cos(mθi), −sin(mθi);
   sin(mθi),  cos(mθi)] · [q2i; q2i+1] head dim Dh개 component를 (2i, 2i+1) 쌍으로 처리
풀어 쓰면:
  q'_{2i}   = q_{2i}·c − q_{2i+1}·s
  q'_{2i+1} = q_{2i}·s + q_{2i+1}·c
  (c = cos(mθ_i), s = sin(mθ_i))

4 sin/cos cache

  • pre-computed: cos[Smax, D/2], sin[Smax, D/2]
  • FP32로 보존 권장 (정밀도 손실 예방)
  • head 간 공유 · layer 간 공유

5 Kernel 구조

__global__ rope(Q, K, cos_t, sin_t, B, H, S, D){
  int t = blockIdx.y, h = blockIdx.z;
  int i = threadIdx.x; // 0..D/2
  float c = cos_t[t*D/2 + i];
  float s = sin_t[t*D/2 + i];
  float a = Q[..2*i], b = Q[..2*i+1];
  Q[..2*i]   = a*c - b*s;
  Q[..2*i+1] = a*s + b*c;
}

QKV projection epilogue로 fuse하는 것이 표준.

6 Interleave vs Split

방식pair모델
Interleaved(0,1)(2,3)…원 RoPE
Split-half(i, i+D/2)Llama 계열

같은 회전이지만 element layout이 다름 → 구현 시 불일치 주의.

7 Context 확장

  • PI (Position Interpolation): m ← m · Ltrain/Lnew
  • NTK: b ← b · (Lnew/Ltrain)d/(d−2)
  • YaRN: freq별 wavelength으로 차등 스케일 (low freq만 interp)

YaRN: wavelength ≫ Ltrain 이면 extrapolate, 짧으면 interpolate.

8 융합 전략

QKV_proj → (q,k,v)
   │
   ├─ rope(q, cos, sin)   ┐ epilogue
   ├─ rope(k, cos, sin)   │ 로 fuse
   └─ k,v → KV cache      ┘
   │
   ▼
attention (↗ V07)
MLA의 RoPE는 latent와 decouple → RoPE 부분만 별도 head.

1 정의 비교

LayerNorm: y = γ · (x − μ)/√(σ² + ε) + β
RMSNorm:    y = γ · x / √(mean(x²) + ε) μ,σ² : mean · var   γ,β : learnable   RMSNorm은 mean 없음 → 1-pass 더 쉬움

2 왜 RMSNorm

  • mean 제거: 1-pass 축 합 1개만
  • Llama · GPT-NeoX · Mistral 채택
  • β(bias) 제거로 parameter 절감

3 2-pass → 1-pass

방식pass수치
naive 2-pass2정확
1-pass (sum, sumsq)1cancel risk
Welford 1-pass1안정
RMSNorm 1-pass1단순

4 Welford 1-pass LayerNorm용

count ← count + 1
δ ← x − mean
mean ← mean + δ/count
M2 ← M2 + δ · (x − mean)
→ var = M2 / count running update · FP32 안정성   warp shuffle로 병렬 merge

5 Parallel reduction layout

// block당 token 1개 · D thread 병렬
int t = threadIdx.x;
float v = x[row, t];
float sq = v*v;
// warp reduce sum
for(int o=16; o>0; o>>=1)
  sq += __shfl_xor_sync(0xFFFFFFFF, sq, o);
// block reduce across warps (shmem)
...
float rms = rsqrtf(sq/D + eps);
y[row,t] = v * rms * gamma[t];

6 Vectorized load

  • D=4096 · FP16 → half2/float4로 4원소/thread
  • 128-bit access · shmem 없이 register 누산
  • D > block 시 thread당 multiple element

7 Residual fusion ★

fuse attention/FFN 출력 y에 residual x 더한 뒤 바로 다음 layer Norm 입력. add → norm 두 kernel을 1개로 융합.
x_res = x + attn_out    (add)
y     = RMSNorm(x_res)  (norm)
─────────────────────────────
fused: 한 kernel에서
  read x, attn_out
  sum = x + attn_out
  write x_res to mem  (다음 residual용)
  reduce sum(x_res²)
  y = x_res · rms · γ

8 수치 안정성

  • mean·var 축적은 FP32 강제
  • mixed precision: x FP16/BF16, 통계 FP32, 출력 FP16 (↗ V09 §9)
  • ε ≈ 1e−5 ~ 1e−6 (RMSNorm은 1e−6 흔함)

9 실패 모드

var 계산 naive: Σ(x−μ)² 를 (Σx²) − (Σx)²/n 로 대체 시 catastrophic cancellation. FP16이면 +∞ · NaN 가능.
shmem 부족: D 크면 block당 1 token 기본이지만 thread당 원소 ↑ 필요.

1 FFN 구조 변화

actFFN 식mat 수
ReLUW2·ReLU(W1x)2
GELUW2·GELU(W1x)2
SwiGLUW2·(Swish(Wgx) ⊙ Wux)3

SwiGLU: gate / up / down 세 matrix · Hff는 보통 dense FFN의 2/3.

2 Activation 수식

Swish(x) = x · σ(x)   σ(x) = 1/(1+e−x)
GELU(x) = 0.5·x·(1 + erf(x/√2))
GELU-tanh ≈ 0.5·x·(1 + tanh(√(2/π)·(x + 0.044715·x³)))
SwiGLU(a,b) = Swish(a) ⊙ b ⊙ : element-wise product   a = Wgx, b = Wux

3 GEMM + act + GEMM 흐름

x [T, d]
  │ GEMM1  (gate·up · [d, 2H])
  ▼
z [T, 2H]  →  a (first H), b (second H)
  │ SwiGLU: s = silu(a) * b
  ▼
s [T, H]
  │ GEMM2  [H, d]
  ▼
y [T, d]

4 Fusion 옵션 ★

fuse이득난이도
GEMM1 + actwrite 생략low (epilogue)
act + GEMM2read 생략high (prologue)
all three최대대부분 비현실

실전은 GEMM1 epilogue에 SwiGLU fuse가 표준. ↗ V06 §11.

5 Epilogue pseudo

// GEMM1 output tile [BM, 2BN]
half a = tile[m, n];
half b = tile[m, n + BN];   // second half
half sa = a * __hsigmoid(a); // Swish
half out = sa * b;
// write out [BM, BN] → feeds GEMM2
smem[m, n] = out;

6 GELU 근사식 선택

  • exact erf: CUDA erff · 느림, 정확
  • tanh approx: GPT-2 · BERT · 하드웨어 친화
  • sigmoid approx: x·σ(1.702·x) — 일부 구현

inference: tanh approx가 속도·정확도 균형.

7 수치 세부

act대칭음수 영역
ReLUY= 0
GELU~Ysmall leak
SwishNsmooth neg lobe

Swish는 non-monotone (min at x≈−1.28).

8 메모리 이득

fused (epilogue): H FLOP, 0 extra mem
unfused: H element write + H element read
bytes saved = T · H · 2 · 2B (FP16)
Quant FFN: GEMM 출력이 INT32 accumulator. act 전에 dequant 필요 → epilogue order = dequant → act → requant (↗ V10).

1 Sampling 파이프 온페탑피

  1. logits [B, V] from LM head
  2. penalty: repetition / presence
  3. temperature: l ← l / τ
  4. top-k: 상위 k만 남김
  5. top-p (nucleus): 누적 확률 ≤ p
  6. softmax + sample

2 Temperature · Penalty 식

li ← li / τ  (τ=0 → argmax)
repetition: li ← li / rp if li>0 else li · rp
presence: li ← li − α  (등장한 token i에 한해)

3 V 크기

모델V
Llama 232,000
Llama 3128,256
GPT-NeoX50,257

V ~ 10⁵ → reduction·top-k 커널의 주 축.

4 Top-k 알고리즘

방식복잡도용도
full sortO(V log V)k 큼
bitonic top-kO(V log² k)k≤1024
radix selectO(V · bits)k 중간
heap (warp)O(V log k)k≤32

실전 top-k ∈ {50, 64, 200} 대부분 radix select 또는 bitonic.

5 Top-p (nucleus) ★

  1. descending sort or top-k + sort
  2. softmax → cumulative sum
  3. cutoff idx = 최소 i s.t. cumi > p
  4. i 이후 logit = −∞ (mask)
  5. 재정규화 softmax → sample

top-k(V') + top-p 조합이 일반적: k'≈1024로 V 축소 후 nucleus.

6 Fused sampling kernel

// block = 1 batch element · V thread (또는 분할)
penalty_apply(l);
scale(l, 1/τ);
(vals, idxs) = topk(l, K);
p = softmax(vals);
cum = scan_inclusive(p);
cutoff = lower_bound(cum, p_top);
mask(&p, cutoff);
renorm(p);
out = inv_cdf_sample(p, rand);

7 Softmax 안정화

lmax = max(l)
pi = exp(li − lmax) / Σ exp(lj − lmax)

overflow 방지 · FP16 logit은 FP32로 승격 후 exp.

8 Sample 선택

method
greedyargmax(l)
multinomialinv-CDF
Gumbel-maxargmax(l + g)

Gumbel: gi = −log(−log(ui)) · sort 없이 argmax 한 번.

9 Batch 주의

batch 결정론: batch element마다 다른 seed 사용 · global RNG 공유 시 race. curand_philox per-slot.
per-request τ·k·p가 다름 → kernel 인자 [B] array.

1 정의

beam search 각 step에서 beam width W개 후보 seq를 유지. 확률 곱(= log-prob 합)이 큰 상위 W를 선택.
  • greedy: W=1
  • W↑ → quality ↑, memory ↑, 다양성 ↓
  • 실전 W ∈ {4, 8, 16}

2 Score 누적

scoreseq = Σt log p(yt | y<t)
length norm: score / lenα (α ∈ [0.6,1.0]) 곱 대신 log 합 · length penalty로 짧은 seq 선호 완화

3 확장 규칙

  1. 현 W개 beam 각각에 V개 token 확장 후보
  2. 총 W·V 후보 점수 = score + log p
  3. 상위 W개 선택 → 다음 step beam
  4. EOS 만난 beam은 finalized buffer로 분리

4 Top-W of (W·V) 커널 ★

// 각 beam b ∈ [0,W)
for(b=0; b<W; b++){
  logp[b,:] = log_softmax(logits[b,:]);
  cand[b,:] = beam_score[b] + logp[b,:];
}
// flatten [W,V] → [W·V] 후 top-W
(vals, flat_idx) = topk(cand, W);
int parent_b = flat_idx / V;
int tok      = flat_idx % V;

5 GPU 자료구조

버퍼shape
beam_score[B, W]
beam_tokens[B, W, Smax]
parent[B, W, Smax]
KV cache[L, B·W, H, S, D]

effective batch = B·W → KV mem W배.

6 KV cache 재배치 ★

문제 선택된 W개 beam의 parent beam이 섞임. 이전 step cache [B·W] → 새 parent 순서로 gather 필요.
prev beams: 0 1 2 3
selected parents: 2 0 2 1
  → new beam 0 ← old 2 의 cache
  → new beam 1 ← old 0
  → new beam 2 ← old 2 (share)
  → new beam 3 ← old 1

7 Cache gather 방식

  • full copy: [L·B·W·H·S·D] gather → 비쌈
  • pointer swap: cache slot index를 parent로 재할당 (paged KV에 유리)
  • lazy: 새 token에 대해서만 write, 조상 공유

paged KV는 block_table만 rewire · ↗ V16 §3.

8 종료 조건

  1. 모든 beam이 EOS 또는 max_len 도달
  2. finalized pool에 W_finish 개 확보 & best_alive ≤ best_finished
  3. return top-W by length-normalized score

9 Beam vs Sampling

beamsample
qualitydeterministic maxstochastic
diversity낮음높음
KV
용도MT · summarizechat · creative
중복 beam: 같은 parent에서 top-W 모두 선택되면 부모-cache share. KV gather 시 src 중복 고려.

1 아이디어 ★

핵심 작은 draft 모델이 γ token을 빠르게 생성. target 모델이 γ+1개를 한 번에 verify → 여러 token을 1 step으로 수락.
  • target 분포와 정확히 동일 sample (no approx)
  • memory-bound decode 구간에서 FLOP 여유 활용

2 Loop 구조

while not done:
  1) draft: x_1..x_γ ~ q(·|ctx)      (γ forward, small)
  2) target: p(·|ctx + x_1..x_i) for i=0..γ  (1 forward, large, batched)
  3) for i = 1..γ:
       r ~ U(0,1)
       if r < min(1, p_i(x_i) / q_i(x_i)): accept x_i
       else: reject → resample from p_i - q_i, stop
  4) if 모두 accept: bonus token ~ p_{γ+1}

3 Acceptance rule ★

accept xi if   r < min(1, pi(xi) / qi(xi))

reject 시 재샘플:
p'(x) = max(0, pi(x) − qi(x)) / Z
Z = Σx max(0, pi(x) − qi(x)) p : target prob   q : draft prob   r : U(0,1)   결과 분포 = p

4 기대 accept rate

α = Ex∼q[ min(1, p(x)/q(x)) ]
   = Σx min(p(x), q(x)) α ∈ [0,1] · q≈p 이면 α→1 · 독립이면 α≈0

5 기대 token / step ★

E[accepted + 1] = (1 − αγ+1) / (1 − α) γ개 draft의 prefix accept 기대값 + bonus 1 (모두 accept 시)
αγ=3γ=4γ=6
0.51.881.941.98
0.72.532.803.08
0.853.063.564.26

6 Speedup 공식 ★

c = costdraft / costtarget  (c ≪ 1)
speedup = (1 − αγ+1) / ((1 − α) · (γ·c + 1)) 분자 : 평균 수락 길이   분모 : draft γ forward + target 1 forward

Leviathan 2023 수식 · c는 보통 1/7B ~ 1/70B ratio.

7 γ 선택

  • α 낮으면 γ 작게 (낭비 ↓)
  • c 작으면 γ 크게 여유
  • 실전 γ ∈ {3, 4, 5}
  • 도메인별 α 측정 → adaptive γ

8 Verify 커널

input ctx + [x_1..x_γ]  len = C+γ
  → target forward (하나의 batch)
  → logits [γ+1, V]
  → softmax row별
  → acceptance loop (CPU or GPU)
  → update KV cache with accepted prefix
KV cache rollback: reject 시 거부된 위치 이후 cache invalidate. prefix만 유효 → seq len 조정.

1 왜 self-draft

  • 별도 draft 모델 운영 비용 ↑
  • target의 마지막 hidden을 활용해 draft 생성
  • Medusa · EAGLE · Lookahead 등 variant

2 Medusa 다중 head 병렬

구조 target 마지막 hidden ht 위에 k개의 predictor head. 각 head k는 t+k+1 위치 token 분포 예측.
h_t
  ├── LM head (t+1)       original
  ├── Medusa_1 head (t+2)
  ├── Medusa_2 head (t+3)
  └── Medusa_3 head (t+4)

Top-s per head → 후보 조합 → tree

3 EAGLE

  • single extra transformer layer를 draft로 사용
  • 입력: target hidden + 직전 embedding
  • Medusa 대비 accept rate ↑

4 Tree attention ★

정의 draft 후보가 트리 구조(여러 branch) → 한 번의 target forward로 모두 검증. tree mask로 각 node가 자신의 조상만 참조.
        ctx
          │
         t+1(A)         t+1(B)
        /    \            │
    t+2(a1) t+2(a2)   t+2(b1)
       │      │          │
    t+3(.) t+3(.)     t+3(.)

mask M[i,j] = 1  iff  j ∈ ancestors(i) ∪ {i}

5 Tree mask 예

node order: ctx, A, B, a1, a2, b1
M =
 [1 0 0 0 0 0]  ctx
 [1 1 0 0 0 0]  A
 [1 0 1 0 0 0]  B
 [1 1 0 1 0 0]  a1
 [1 1 0 0 1 0]  a2
 [1 0 1 0 0 1]  b1
(행 i가 attend 가능한 col j)

flash attention + custom mask → ↗ §14.

6 Verify 흐름

  1. tree의 linearized node sequence · position_ids
  2. causal tree mask
  3. target 1 forward → 각 node의 다음 token prob
  4. 각 경로(root→leaf)에 spec decode acceptance 적용
  5. 최장 accepted path 선택

7 Tree 크기 최적화

파라미터의미
depth예측 step 수
topkdepth각 level 후보
node 예산검증 cost
calibrationαnode로 가지치기

tree 크기는 target 1 forward batch의 seq dim으로 들어감.

8 변형 비교

방식draft 재료α 범위
Vanilla spec별도 소형 모델0.5~0.75
Medusamulti-head0.6~0.8
EAGLE1-layer on hidden0.75~0.85
Lookaheadn-gram cache가변

α는 논문 보고 범위 · 데이터 분포 의존.

9 커널 포인트

  • tree mask는 희소 → bit-packed matrix
  • position embedding은 tree depth 기반 (선형 idx 아님)
  • KV append는 accepted path node만
오용: tree root가 여러 개면 prefix 공유 깨짐 → 항상 root 1개.

1 Mask 역할

정의 attention score S = QKT에 특정 위치를 무효화. boolean mask 또는 additive bias 두 형태.
  • causal: j > i 금지
  • padding: pad 토큰 무시
  • tree / custom: 조상만 허용

2 Additive vs Boolean

additive:  S' = S + B
  B[i,j] = 0 (allow) · −∞ (block)

boolean: S'[i,j] = S[i,j] if M[i,j] else −∞ softmax(−∞) = 0 · 수학적 동일 · 메모리 형태만 차이

3 −∞ 대체값

  • FP16: −65504.0 또는 −1e4 (큰 값)
  • BF16: 동일, exp→0
  • INT8 quant: scale 고려 큰 음수

실제 −∞는 IEEE NaN 위험 → 큰 음수 상수 사용.

4 Causal mask 융합

// flash attn 내부 · score 타일별
if(j > i) s_ij = -1e4;
// 타일 경계 밖은 skip (조기 break)
if(j_tile_start > i_tile_end) continue;

flash attention은 causal을 타일 경계 비교로 loop 단축 (↗ V07 §4).

5 Sliding window

allow(i,j) iff max(0, i−w+1) ≤ j ≤ i
Mistral 계열: w = 4096 local causal · full cache보다 메모리 낮음

6 ALiBi bias

B[i,j] = −mh · |i − j|  (j ≤ i)
mh : head별 slope (등비수열) position embedding 대체 · 거리 기반 선형 penalty

7 Custom mask 커널

// packed bit mask: M[i, j/8] & (1 << (j%8))
bool allow = (M[i, j>>3] >> (j & 7)) & 1;
float s_ij = allow ? qk : -1e4;
// 또는 per-(i,j) additive bias load
s_ij += bias[i, j];

8 Mask 형태 비교

형태mem비용
implicit causal0O(1)
sliding w0O(1)
bit maskS²/81 bit load
FP32 bias4·S²load + add
tree bit maskT²/8spec decode

9 실패 모드

전부 mask인 row: softmax 분모 = 0 → NaN. 최소 1 slot allow 강제 또는 row 전체 skip.
bias dtype 혼용: FP16 bias + FP32 score 시 BF16로 통일 권장.

1 Forward 1-layer 순서 ★

x [T,d]
  │  RMSNorm (§8)
  ▼
x_n
  │  QKV_proj  (GEMM ↗ V06)
  ▼
q,k,v
  │  RoPE (§7) on q,k
  ▼
  │  KV cache append (§5)
  ▼
  │  Attention (↗ V07)
  ▼
a
  │  O_proj (GEMM)
  │  + residual (§8 fuse)
  ▼
x1
  │  RMSNorm
  ▼
  │  FFN: GEMM1 → SwiGLU (§9) → GEMM2
  │  or MoE: router (§1) → permute (§2)
  │         → grouped GEMM (§3) → unpermute
  ▼
  │  + residual
  ▼
xout → next layer

2 Sampling 순서

x_final [T,d]
  │  RMSNorm
  │  LM head [d,V]
  ▼
logits
  │  penalty · temperature (§10)
  ▼
  │  top-k · top-p · softmax
  ▼
token
  │  (spec decode: verify §12)
  ▼
output

3 Kernel 복잡도 표 ★

kernelFLOPmembound
RMSNormT·dT·dmemory
QKV projT·d·3dT·dcompute
RoPET·H·DT·H·Dmemory
Attention prefillT²·H·DT·H·Dcompute
Attention decodeS·H·DS·H·Dmemory
FFN GEMM1T·d·2HffT·dcompute
SwiGLUT·HffT·Hffmemory
FFN GEMM2T·Hff·dT·Hffcompute
MoE route+permT·N + T·kTmemory
Grouped GEMMT·k·d·Hffraggedcompute
SamplingV log VVmemory

4 Cross-ref 지도

주제상세
attention 본체↗ V07
GEMM · CUTLASS↗ V06
reduction · scan↗ V05
quantization 알고리즘↗ V10
paged KV · 서빙↗ V16
학습 · backward↗ V17
expert parallel 통신↗ V15
Hopper TMA/WGMMA↗ V04

5 핵심 수식 3선 ★

RoPE: q'2i,2i+1 = R(mθi) · q2i,2i+1
Spec: E[acc+1] = (1−αγ+1)/(1−α)
SwiGLU: y = Swish(Wgx) ⊙ (Wux)

6 흔한 실수 7선

  1. RoPE interleave/split 불일치
  2. RMSNorm var FP16 축적 → NaN
  3. MoE permute 후 atomic race
  4. KV cache beam parent rewire 누락
  5. softmax max subtract 생략 → overflow
  6. spec decode reject 후 KV rollback 누락
  7. top-p 후 renormalize 잊기

7 Fusion 우선순위

fuse
1QKV_proj + RoPE + KV write
2Add + RMSNorm (residual)
3GEMM1 + SwiGLU
4Attn + O_proj epilogue
5MoE permute + GEMM1
LLM layer 6단: 노Q로어오포 (름 · QKV · 프 · 텐션 · ·proj · 워드 FFN/MoE)
out-of-scope 재확인: vLLM 서빙 시스템은 ↗ V16, 학습 backward·optimizer는 ↗ V17.