// 한 block = 한 token · thread N개 병렬 __shared__ float s_logits[N]; int t = threadIdx.x; s_logits[t] = dot(x[tid], Wr[:,t]); __syncthreads(); // warp-level bitonic/heap top-k topk_warp(s_logits, k, idx, val); g = softmax(val);
N ≤ warp 시 single-warp top-k, 그 외 block-wide merge tree.
| 모델 | N | k | 특징 |
|---|---|---|---|
| Switch | 128 | 1 | hard routing |
| GShard | E | 2 | capacity cap |
| Mixtral-8×7B | 8 | 2 | softmax gate |
| DeepSeek-V3 | 256 | 8 | shared+routed |
| Qwen-MoE | 60 | 4 | fine-grained |
N 수치는 각 모델 카드 기준 · 2024~2025.
Inference 시는 aux loss 불필요 (학습 ↗ V17).
token [T,d] │ router ▼ logits [T,N] → top-k → idx [T,k], g [T,k] │ permute (↗ §2) ▼ grouped GEMM W1 (↗ §3) │ act (SwiGLU ↗ §9) ▼ grouped GEMM W2 │ unpermute + weighted sum ▼ y [T,d]
(k=1, T=6, N=3 예시)
[1] router 결과 (token order)
t0→E1 t1→E0 t2→E2 t3→E1 t4→E0 t5→E2
[2] histogram (expert별 카운트)
E0: 2 E1: 2 E2: 2
[3] exclusive prefix sum (offset)
E0: 0 E1: 2 E2: 4
[4] permuted (expert-major)
idx: 0 1 2 3 4 5
slot: E0 E0 E1 E1 E2 E2
src : t1 t4 t0 t3 t2 t5
└──┬──┘└──┬──┘└──┬──┘
group0 group1 group2
(→ grouped GEMM ↗ §3)
[5] unpermute (역매핑 · weighted)
y[t_src] = g · expert_out[slot]
(k>1 이면 k개 slot 합산)
// kernel 1: histogram __global__ hist(idx, cnt){ int t = global_tid; int e = idx[t]; atomicAdd(&cnt[e], 1); } // kernel 2: exclusive scan (block-wide) scan_exclusive(cnt, offset, N); // kernel 3: scatter __global__ scatter(idx, offset, perm, inv){ int t = global_tid, e = idx[t]; int slot = offset[e] + atomicAdd(&loc[e],1); perm[slot] = t; // src index inv[t] = slot; // reverse map }
| 버퍼 | shape | 의미 |
|---|---|---|
| idx | [T,k] | token→expert |
| cnt | [N] | expert별 token |
| offset | [N+1] | group 시작 |
| perm | [T·k] | slot→src token |
| inv | [T·k] | src→slot |
__global__ unpermute(out, expert_y, inv, g, T, k, d){ int t = blockIdx.x; int c = threadIdx.x; float acc = 0.f; for(int j=0; j<k; j++){ int slot = inv[t*k + j]; acc += g[t*k + j] * expert_y[slot*d + c]; } out[t*d + c] = acc; }
N 큰 경우(N>64) radix sort가 histogram+scatter보다 안정.
expert 0: m=17 × [d, H] expert 1: m= 5 × [d, H] expert 2: m=23 × [d, H] ... expert N: m=11 × [d, H]
| 방식 | 아이디어 | pros |
|---|---|---|
| Grouped | expert별 독립 GEMM | simple |
| Block-sparse | 하나의 sparse matmul | dropless (↗ §4) |
problem_sizes [N]: {(m_e, n, k)}
ptr_A [N]: X_0, X_1, ... X_{N-1}
ptr_B [N]: W_0, W_1, ... W_{N-1}
ptr_C [N]: Y_0, Y_1, ... Y_{N-1}
lda/ldb/ldc [N]
kernel: 각 CTA가 (e, tile_m, tile_n) 담당
→ tile scheduler가 ragged m_e 고려
→ 빈 expert(m_e=0) skip
↗ V06 §10 CUTLASS GroupedGemm & Persistent kernel 구조.
X[m_e,d] │ GEMM1 W1_e [d, 2H] (gate·up) ▼ Z[m_e,2H] │ SwiGLU (↗ §9) ▼ A[m_e, H] │ GEMM2 W2_e [H, d] ▼ Y[m_e, d]
| 축 | 분할 | 비고 |
|---|---|---|
| expert | CTA.z | ragged me |
| m | CTA.y | tileM |
| n | CTA.x | tileN |
| k | split-K opt | me 작으면 |
Mixtral-8×7B decode 단계는 T=1 (batch=1 per seq) → GEMV 위주.
full FFN weight W ∈ [N·H, d] W_stack: expert 0 ... expert N 을 행 방향 concat mask: [T·k, N] block-diagonal-ish X' [T·k, d] (permuted) W_stack [N·H, d] result [T·k, H] 은 block-sparse matmul: row r ∈ expert e 이면 W_e 구간만 접촉
| 방식 | drop | pad | kernel |
|---|---|---|---|
| Token choice | 有 | 有 | batched |
| Grouped | 無 | 有(tile) | N grouped |
| Megablocks | 無 | 無 | block-sparse 1 |
block tile: BM × BK (e.g. 128×128) row_ptr [T/BM + 1] : row당 active block 누적 col_idx [nnz] : 각 block의 열 tile idx values [nnz,BM,BK] : block data
Megablocks : SDD (Sparse-Dense→Dense) · DSD · DDS 세 변형.
| 단계 | 커널 |
|---|---|
| topology | permute 결과에서 block meta |
| SDD | block-sparse matmul |
| DSD/DDS | backward / dW (↗ V17) |
| 모델 | L | Hkv | Dh | KB/token |
|---|---|---|---|---|
| Llama3-8B | 32 | 8 | 128 | 128 |
| Llama3-70B | 80 | 8 | 128 | 320 |
| Mistral-7B | 32 | 8 | 128 | 128 |
FP16 · 2(K,V) 포함 · KB/token = 2·L·Hkv·Dh·2B.
| layout | order | access 특성 |
|---|---|---|
| head-major | [B,H,S,D] | head별 연속 · attention kernel 친화 |
| seq-major | [B,S,H,D] | append 쉬움 · scatter 효율 |
| paged | [page, H, D] | vLLM · ↗ V16 |
vLLM PagedAttention 세부 → ↗ V16 §3.
// new token k,v 1개 삽입 __global__ kv_update(K_cache, V_cache, k_new, v_new, b, s, h, D){ int tid = threadIdx.x; if(tid < D){ K_cache[b,h,s,tid] = k_new[b,h,tid]; V_cache[b,h,s,tid] = v_new[b,h,tid]; } }
보통 QKV projection epilogue와 fuse.
| Hq | Hkv | cache 절감 | |
|---|---|---|---|
| MHA | H | H | 1× |
| GQA-8 | H | H/G | G× |
| MQA | H | 1 | H× |
| MLA | H | — | ~D/r |
attention kernel 연결 → ↗ V07 §6.
| granularity | 메모리 | 정확도 |
|---|---|---|
| per-tensor | 1 | 낮음 |
| per-channel (head) | H | 중 |
| per-token | S | 높음 |
| per-token per-head | S·H | 최고 |
KV는 보통 per-token (outlier가 token 단위로 나타남).
K_q : [L, B, H_kv, S, D_h] INT8/INT4
K_s : [L, B, H_kv, S] FP16/BF16
decode 시:
fetch K_q[b,h,:S,:] 와 K_s[b,h,:S]
on-the-fly dequant:
K̂[t,d] = K_q[t,d] · K_s[t]
flash attention의 softmax online 흐름 유지 → ↗ V07 §4.
x → W_K → k_fp ─┐
│ online amax (per token · head)
▼
k_q, k_s
│ KV cache 저장 (INT8/INT4)
▼
attention (fused dequant)
│
▼
output FP16
| dtype | PPL Δ | KV mem |
|---|---|---|
| FP16 (baseline) | 0 | 1× |
| FP8 (E4M3) | ~0 | 0.5× |
| INT8 per-token | +0.02 | 0.5× |
| INT4 group-128 | +0.1~0.3 | 0.25× |
Δ는 WikiText2 PPL 기준 범위 · 모델/calibration 의존.
풀어 쓰면:
q'_{2i} = q_{2i}·c − q_{2i+1}·s
q'_{2i+1} = q_{2i}·s + q_{2i+1}·c
(c = cos(mθ_i), s = sin(mθ_i))
cos[Smax, D/2], sin[Smax, D/2]__global__ rope(Q, K, cos_t, sin_t, B, H, S, D){ int t = blockIdx.y, h = blockIdx.z; int i = threadIdx.x; // 0..D/2 float c = cos_t[t*D/2 + i]; float s = sin_t[t*D/2 + i]; float a = Q[..2*i], b = Q[..2*i+1]; Q[..2*i] = a*c - b*s; Q[..2*i+1] = a*s + b*c; }
QKV projection epilogue로 fuse하는 것이 표준.
| 방식 | pair | 모델 |
|---|---|---|
| Interleaved | (0,1)(2,3)… | 원 RoPE |
| Split-half | (i, i+D/2) | Llama 계열 |
같은 회전이지만 element layout이 다름 → 구현 시 불일치 주의.
YaRN: wavelength ≫ Ltrain 이면 extrapolate, 짧으면 interpolate.
QKV_proj → (q,k,v) │ ├─ rope(q, cos, sin) ┐ epilogue ├─ rope(k, cos, sin) │ 로 fuse └─ k,v → KV cache ┘ │ ▼ attention (↗ V07)
| 방식 | pass | 수치 |
|---|---|---|
| naive 2-pass | 2 | 정확 |
| 1-pass (sum, sumsq) | 1 | cancel risk |
| Welford 1-pass | 1 | 안정 |
| RMSNorm 1-pass | 1 | 단순 |
// block당 token 1개 · D thread 병렬 int t = threadIdx.x; float v = x[row, t]; float sq = v*v; // warp reduce sum for(int o=16; o>0; o>>=1) sq += __shfl_xor_sync(0xFFFFFFFF, sq, o); // block reduce across warps (shmem) ... float rms = rsqrtf(sq/D + eps); y[row,t] = v * rms * gamma[t];
half2/float4로 4원소/threadx_res = x + attn_out (add) y = RMSNorm(x_res) (norm) ───────────────────────────── fused: 한 kernel에서 read x, attn_out sum = x + attn_out write x_res to mem (다음 residual용) reduce sum(x_res²) y = x_res · rms · γ
| act | FFN 식 | mat 수 |
|---|---|---|
| ReLU | W2·ReLU(W1x) | 2 |
| GELU | W2·GELU(W1x) | 2 |
| SwiGLU | W2·(Swish(Wgx) ⊙ Wux) | 3 |
SwiGLU: gate / up / down 세 matrix · Hff는 보통 dense FFN의 2/3.
x [T, d] │ GEMM1 (gate·up · [d, 2H]) ▼ z [T, 2H] → a (first H), b (second H) │ SwiGLU: s = silu(a) * b ▼ s [T, H] │ GEMM2 [H, d] ▼ y [T, d]
| fuse | 이득 | 난이도 |
|---|---|---|
| GEMM1 + act | write 생략 | low (epilogue) |
| act + GEMM2 | read 생략 | high (prologue) |
| all three | 최대 | 대부분 비현실 |
실전은 GEMM1 epilogue에 SwiGLU fuse가 표준. ↗ V06 §11.
// GEMM1 output tile [BM, 2BN] half a = tile[m, n]; half b = tile[m, n + BN]; // second half half sa = a * __hsigmoid(a); // Swish half out = sa * b; // write out [BM, BN] → feeds GEMM2 smem[m, n] = out;
erff · 느림, 정확inference: tanh approx가 속도·정확도 균형.
| act | 대칭 | 음수 영역 |
|---|---|---|
| ReLU | Y | = 0 |
| GELU | ~Y | small leak |
| Swish | N | smooth neg lobe |
Swish는 non-monotone (min at x≈−1.28).
| 모델 | V |
|---|---|
| Llama 2 | 32,000 |
| Llama 3 | 128,256 |
| GPT-NeoX | 50,257 |
V ~ 10⁵ → reduction·top-k 커널의 주 축.
| 방식 | 복잡도 | 용도 |
|---|---|---|
| full sort | O(V log V) | k 큼 |
| bitonic top-k | O(V log² k) | k≤1024 |
| radix select | O(V · bits) | k 중간 |
| heap (warp) | O(V log k) | k≤32 |
실전 top-k ∈ {50, 64, 200} 대부분 radix select 또는 bitonic.
top-k(V') + top-p 조합이 일반적: k'≈1024로 V 축소 후 nucleus.
// block = 1 batch element · V thread (또는 분할)
penalty_apply(l);
scale(l, 1/τ);
(vals, idxs) = topk(l, K);
p = softmax(vals);
cum = scan_inclusive(p);
cutoff = lower_bound(cum, p_top);
mask(&p, cutoff);
renorm(p);
out = inv_cdf_sample(p, rand);
overflow 방지 · FP16 logit은 FP32로 승격 후 exp.
| method | 식 |
|---|---|
| greedy | argmax(l) |
| multinomial | inv-CDF |
| Gumbel-max | argmax(l + g) |
Gumbel: gi = −log(−log(ui)) · sort 없이 argmax 한 번.
// 각 beam b ∈ [0,W) for(b=0; b<W; b++){ logp[b,:] = log_softmax(logits[b,:]); cand[b,:] = beam_score[b] + logp[b,:]; } // flatten [W,V] → [W·V] 후 top-W (vals, flat_idx) = topk(cand, W); int parent_b = flat_idx / V; int tok = flat_idx % V;
| 버퍼 | shape |
|---|---|
| beam_score | [B, W] |
| beam_tokens | [B, W, Smax] |
| parent | [B, W, Smax] |
| KV cache | [L, B·W, H, S, D] |
effective batch = B·W → KV mem W배.
prev beams: 0 1 2 3 selected parents: 2 0 2 1 → new beam 0 ← old 2 의 cache → new beam 1 ← old 0 → new beam 2 ← old 2 (share) → new beam 3 ← old 1
paged KV는 block_table만 rewire · ↗ V16 §3.
| beam | sample | |
|---|---|---|
| quality | deterministic max | stochastic |
| diversity | 낮음 | 높음 |
| KV | W× | 1× |
| 용도 | MT · summarize | chat · creative |
while not done:
1) draft: x_1..x_γ ~ q(·|ctx) (γ forward, small)
2) target: p(·|ctx + x_1..x_i) for i=0..γ (1 forward, large, batched)
3) for i = 1..γ:
r ~ U(0,1)
if r < min(1, p_i(x_i) / q_i(x_i)): accept x_i
else: reject → resample from p_i - q_i, stop
4) if 모두 accept: bonus token ~ p_{γ+1}
| α | γ=3 | γ=4 | γ=6 |
|---|---|---|---|
| 0.5 | 1.88 | 1.94 | 1.98 |
| 0.7 | 2.53 | 2.80 | 3.08 |
| 0.85 | 3.06 | 3.56 | 4.26 |
Leviathan 2023 수식 · c는 보통 1/7B ~ 1/70B ratio.
input ctx + [x_1..x_γ] len = C+γ → target forward (하나의 batch) → logits [γ+1, V] → softmax row별 → acceptance loop (CPU or GPU) → update KV cache with accepted prefix
h_t ├── LM head (t+1) original ├── Medusa_1 head (t+2) ├── Medusa_2 head (t+3) └── Medusa_3 head (t+4) Top-s per head → 후보 조합 → tree
ctx
│
t+1(A) t+1(B)
/ \ │
t+2(a1) t+2(a2) t+2(b1)
│ │ │
t+3(.) t+3(.) t+3(.)
mask M[i,j] = 1 iff j ∈ ancestors(i) ∪ {i}
node order: ctx, A, B, a1, a2, b1 M = [1 0 0 0 0 0] ctx [1 1 0 0 0 0] A [1 0 1 0 0 0] B [1 1 0 1 0 0] a1 [1 1 0 0 1 0] a2 [1 0 1 0 0 1] b1 (행 i가 attend 가능한 col j)
flash attention + custom mask → ↗ §14.
| 파라미터 | 의미 |
|---|---|
| depth | 예측 step 수 |
| topkdepth | 각 level 후보 |
| node 예산 | 검증 cost |
| calibration | αnode로 가지치기 |
tree 크기는 target 1 forward batch의 seq dim으로 들어감.
| 방식 | draft 재료 | α 범위 |
|---|---|---|
| Vanilla spec | 별도 소형 모델 | 0.5~0.75 |
| Medusa | multi-head | 0.6~0.8 |
| EAGLE | 1-layer on hidden | 0.75~0.85 |
| Lookahead | n-gram cache | 가변 |
α는 논문 보고 범위 · 데이터 분포 의존.
−65504.0 또는 −1e4 (큰 값)실제 −∞는 IEEE NaN 위험 → 큰 음수 상수 사용.
// flash attn 내부 · score 타일별 if(j > i) s_ij = -1e4; // 타일 경계 밖은 skip (조기 break) if(j_tile_start > i_tile_end) continue;
flash attention은 causal을 타일 경계 비교로 loop 단축 (↗ V07 §4).
// packed bit mask: M[i, j/8] & (1 << (j%8)) bool allow = (M[i, j>>3] >> (j & 7)) & 1; float s_ij = allow ? qk : -1e4; // 또는 per-(i,j) additive bias load s_ij += bias[i, j];
| 형태 | mem | 비용 |
|---|---|---|
| implicit causal | 0 | O(1) |
| sliding w | 0 | O(1) |
| bit mask | S²/8 | 1 bit load |
| FP32 bias | 4·S² | load + add |
| tree bit mask | T²/8 | spec decode |
x [T,d] │ RMSNorm (§8) ▼ x_n │ QKV_proj (GEMM ↗ V06) ▼ q,k,v │ RoPE (§7) on q,k ▼ │ KV cache append (§5) ▼ │ Attention (↗ V07) ▼ a │ O_proj (GEMM) │ + residual (§8 fuse) ▼ x1 │ RMSNorm ▼ │ FFN: GEMM1 → SwiGLU (§9) → GEMM2 │ or MoE: router (§1) → permute (§2) │ → grouped GEMM (§3) → unpermute ▼ │ + residual ▼ xout → next layer
x_final [T,d] │ RMSNorm │ LM head [d,V] ▼ logits │ penalty · temperature (§10) ▼ │ top-k · top-p · softmax ▼ token │ (spec decode: verify §12) ▼ output
| kernel | FLOP | mem | bound |
|---|---|---|---|
| RMSNorm | T·d | T·d | memory |
| QKV proj | T·d·3d | T·d | compute |
| RoPE | T·H·D | T·H·D | memory |
| Attention prefill | T²·H·D | T·H·D | compute |
| Attention decode | S·H·D | S·H·D | memory |
| FFN GEMM1 | T·d·2Hff | T·d | compute |
| SwiGLU | T·Hff | T·Hff | memory |
| FFN GEMM2 | T·Hff·d | T·Hff | compute |
| MoE route+perm | T·N + T·k | T | memory |
| Grouped GEMM | T·k·d·Hff | ragged | compute |
| Sampling | V log V | V | memory |
| 주제 | 상세 |
|---|---|
| attention 본체 | ↗ V07 |
| GEMM · CUTLASS | ↗ V06 |
| reduction · scan | ↗ V05 |
| quantization 알고리즘 | ↗ V10 |
| paged KV · 서빙 | ↗ V16 |
| 학습 · backward | ↗ V17 |
| expert parallel 통신 | ↗ V15 |
| Hopper TMA/WGMMA | ↗ V04 |
| 순 | fuse |
|---|---|
| 1 | QKV_proj + RoPE + KV write |
| 2 | Add + RMSNorm (residual) |
| 3 | GEMM1 + SwiGLU |
| 4 | Attn + O_proj epilogue |
| 5 | MoE permute + GEMM1 |