LESSON 11 · 2026.04.22 · L4
vLLM 의 Paged Attention 을 Triton 으로 다시 짜보고, vLLM 의 실수를 반복했다
독립적으로 같은 refactor 에 수렴하는 게 설계가 맞다는 증거라는 이야기.
GPU · L4 · sm_89
stack · torch 2.11 + Triton 3.6
result · LLaMA-3-8B 에서 SDPA 를 14 % 이김
들어가며 — contiguous attention 에 끼워넣는 한 줄
Lesson 09 에서 (B, H, N, d) 4-D MHA + causal flash attention 을 Triton 으로 짰다. 그 커널은 torch.nn.functional.scaled_dot_product_attention (속을 까보면 Tri Dao 의 FA-2 CUDA) 의 78–90 % 속도. 벤치 수치만 보면 "잘 짠 커널".
문제는 LLM 서빙 (vLLM / SGLang / TensorRT-LLM) 은 contiguous KV cache 를 쓰지 않는다. sequence 별로 길이가 천차만별이고 들어왔다 나갔다 하니까 pre-allocate 하면 GPU 메모리 70 % 이상이 fragmentation 으로 날아간다. vLLM 이 SOSP '23 에서 밀어붙인 해법은 단순:
KV cache 를 고정 크기 block 의 pool 로 쪼개고, sequence 별 block_table 로 간접 참조한다.
(기존)
K: (B, H, N, d) ← seq 별 연속
(paged)
K_cache: (num_blocks, block_size, H_kv, d) ← 블록 풀
block_table: (B, max_blocks_per_seq) ← seq → physical block id
context_lens:(B,) ← 각 seq 유효 길이
attention 커널이 바뀌는 건 단 한 줄:
# 기존: K[b, h, start_n:end_n, :] 를 연속으로 로드
# paged: block_table[b, logical_blk] 조회 → phys_blk →
# K_cache[phys_blk, :, kv_head, :] 로드
이번 글은 "이 한 줄만 끼워넣으면 vLLM paged attention 이 되는지" 를 측정하는 세션. 되긴 되는데, 가는 길에 두 번 틀렸다. 그게 알맹이다.
Phase 1–2 · correctness 는 너무 쉽게 통과했다 (이게 함정)
작업 순서:
- PyTorch reference (Python loop 로 block_table 돌면서 gather 후 표준 attention) 로 oracle 만들기.
- Triton 커널: 처음엔
grid = (B, H_q) — lesson 09 의 (cdiv(N, BM), H, B) 에서 "decode 니까 N=1" 로 자연스럽게 축소.
- GQA 지원:
GQA_GROUP_SIZE = H_q // H_kv 를 constexpr 로, kv_head = pid_h // GQA_GROUP_SIZE. 총 4 줄 변경.
Correctness 벤치 (16 shape × 2 dtype = 32/32 PASS):
| shape | B | H_q | H_kv | group | fp16 max diff | fp32 max diff |
| MHA | 1–4 | 32 | 32 | 1 | 9.8e-04 | 3.6e-07 |
| LLaMA-3-8B GQA | 2 | 32 | 8 | 4 | 2.4e-04 | 3.6e-07 |
| LLaMA-70B GQA | 4 | 64 | 8 | 8 | 3.1e-05 | 3.6e-07 |
| MQA | 2 | 16 | 1 | 16 | 1.9e-06 | 3.6e-07 |
fp16 에서 1e-3 이하, fp32 에서 1e-7 수준. 모델 공장 기준 모든 게 통과.
이 시점의 판단: "되는구나. 끝나겠다." 이게 함정이었다.
Phase 3 · 속도 벤치가 구조 버그를 드러냈다
Correctness 만 보고 끝냈으면 이 커널이 LLaMA-3-8B 에서 SDPA 대비 2–3 배 느린 상태로 shipping 됐을 것.
| shape | B | group | SDPA ms | paged best | gap |
| llama7b (MHA) | 8 | 1 | 1.143 | 1.06× | -7 % ✅ |
| llama38b (GQA) | 8 | 4 | 0.271 | 0.86× | +16 % ⚠ |
| llama38b (GQA) | 32 | 4 | 1.147 | 0.73× | +37 % ⚠ |
| llama70b (GQA) | 4 | 8 | 0.069 | 0.31× | +217 % ❌ |
| llama70b (GQA) | 8 | 8 | 0.533 | 0.46× | +117 % ❌ |
| mqa | 16 | 32 | 0.062 | 0.07× | +1316 % 💥 |
MHA 는 parity. GQA 부터 gap 이 GROUP_SIZE 에 거의 선형.
- group=4 → +16–37 %
- group=8 → +117–217 %
- group=32 → +1316 % (13× 느림)
이 선형성이 진단의 핵심. 임의 regression 이 아니라 구조적.
원인 — (B, H_q) 그리드가 KV 를 GROUP 번 재로드한다
Grid (B, H_q) 는 한 program 이 한 (batch, query head). GQA 의 GROUP_SIZE 개의 query head 는 같은 KV head 의 캐시를 공유하지만, 각각의 program 이 독립적으로 block_table 를 돌며 같은 K/V block 을 DRAM 에서 다시 로드한다. GROUP_SIZE 배의 redundant DRAM 트래픽.
SDPA 는 왜 안 느리나? Contiguous KV 니까 L2 prefetcher 가 중복 로드를 흡수. 실제로 MQA 의 SDPA 는 L4 에서 542 GB/s — L4 DRAM peak (300 GB/s) 의 1.8 배. 이 숫자는 DRAM 단독으로 불가능. L2 가 반 이상을 먹는다는 증거.
우리 paged 는 block_table indirection 때문에 L2 prefetcher 가 패턴 인식을 못한다. 그래서 redundant load 가 전부 DRAM 으로.
교훈 #1
Correctness 가 통과해도 구조 문제는 속도로만 보인다. allclose 는 grid 설계에 무관. Reference 와의 비교만으론 이 버그 절대 안 잡혔다. 벤치 테이블 + SDPA gap 컬럼을 리포트로 남겨야 문제가 보인다.
Phase 3.5 · Grid 하나 바꿨을 뿐인데 (또 다른 버그가 튀어나왔다)
고칠 게 명확하다. Grid 를 (B, H_kv) 로 바꾸고, program 안에서 GQA group 의 GROUP_SIZE query head 를 한 번에 처리한다. K/V block 은 program 당 한 번만 로드.
grid = (B, H_kv) # 프로그램 수 / GROUP
q = tl.load(q_ptrs) # (GROUP, HEAD_DIM) — 2D tile
# 블록 루프 안:
scores = tl.dot(q_scaled, tl.trans(k)) # (GROUP, BLOCK)
acc += tl.dot(p.to(v.dtype), v) # (GROUP, HEAD)
이 변경 자체는 ~20 줄. 벤치:
| shape | Phase 3 gap | Phase 3.5 gap |
| llama38b (group=4) B=8 | +161 % | -14 % ← SDPA 를 이김 |
| llama38b (group=4) B=32 | +86 % | +3 % (parity) |
| llama70b (group=8) B=4 | -2 % | -1 % |
| llama70b (group=8) B=8 | -1 % | -1 % |
| mqa (group=32) | +1316 % | +85 % |
LLaMA-3-8B 의 -14 % — cuDNN / FA-2 를 우리 Triton 커널이 이기는 shape 가 생겼다. production 에서 흔한 mid-range batch.
그런데 fp32 correctness 가 깨졌다
- fp16: 모든 shape 여전히 통과 ✅
- fp32 MQA: max diff 4.1e-04 ← 원래 3.6e-07 이었다. 세 자릿수 나쁨.
당황. grid 만 바꿨는데 fp32 가 왜 깨지나?
진짜 원인 — tl.dot(fp32, fp32) 의 기본값은 TF32 (sm_80+)
Ampere 이상에서 Triton 은 tl.dot 의 fp32 × fp32 를 자동으로 TF32 로 하향 (10-bit mantissa). input_precision 을 명시 안 하면 기본이 TF32. MQA 의 (GROUP=16, BLOCK=16, HEAD=64) score tile 에서 summation 이 80–100 번 누적되면 10-bit 절단 오차가 쌓여서 softmax max 후보 경계가 4e-4 편향.
해결:
if IS_FP32:
# 3-pass TF32 스택 (2 low-bit 보정) 으로 IEEE 재구성 — 3× 느림
scores = tl.dot(q_scaled, tl.trans(k), input_precision="ieee")
else:
# fp16/bf16 MMA — default 는 이미 IEEE fp16
scores = tl.dot(q_scaled, tl.trans(k)).to(tl.float32)
fp32 max diff: 4.1e-04 → 3.6e-07 복구. fp16 speed 손해 없음.
왜 Phase 3 에선 안 보였나
Phase 3 은 manual broadcast (tl.sum(q * k)) 로 score 를 계산했고, 이건 순수 fp32 path. TF32 거치지 않음. Phase 3.5 에서 tl.dot 을 도입한 순간 처음 노출된 버그.
교훈 #2
두 독립 버그가 연쇄로 숨을 수 있다. Grid bug 를 안 고쳤으면 TF32 bug 가 안 나타남. 고치자마자 나타남. 큰 refactor 뒤엔 correctness 를 반드시 재실행.
Phase 4 · vLLM 소스 읽고 보니 나는 vLLM 역사를 miniature 로 재현했다
Phase 3.5 가 끝난 뒤에야 vLLM 소스를 읽었다. 일부러 — 독립적으로 설계한 뒤 vLLM 과 비교해서 수렴하는지 보고 싶었다.
읽은 파일:
| # | 파일 | 역할 |
| v1 | csrc/attention/paged_attention_v1.cu | 오리지널 CUDA 커널 (2023). Per-query-head 그리드. |
| v2 | csrc/attention/paged_attention_v2.cu | ctx 축 split-k + reduce 커널. |
| triton | vllm/v1/attention/ops/triton_unified_attention.py | 현행 Triton 구현. Per-KV-head 그리드. |
발견 1 — 내 Phase 3.5 는 vLLM 의 현행 Triton 커널과 axis-for-axis 매치
| axis | vLLM Triton unified | 내 Phase 3.5 |
| grid | (Σ q_blocks, H_kv) | (B, H_kv) |
| Q tile | (BLOCK_M, HEAD) | (GROUP, HEAD) |
| Matmul | tl.dot(Q, K) / tl.dot(P, V) | tl.dot (GROUP≥4) 또는 manual fallback |
| Softmax | per-row fp32 running (M, L, acc) | 동일 |
| KV layout | (num_blks, blk_size, H_kv, d) | 동일 |
vLLM 의 axis-0 가 (batch × query block) 이고 내 axis-0 가 pure batch 인 차이는, vLLM 은 prefill 까지 한 커널로 처리하니까 query 길이가 가변. 나는 decode 만 하니까 q_len=1 로 고정, batch 가 바로 0축. 구조는 같음, scope 가 다를 뿐.
발견 2 — vLLM 자신이 나와 같은 refactor 를 거쳤다
paged_attention_v1.cu:86:
dim3 grid(num_heads, num_seqs, 1);
이것이 per-query-head grid. 내 Phase 3 의 디자인과 동일. 2023 년 vLLM 이 shipping 한 오리지널.
당시엔 왜 괜찮았는가:
- LLM 시장의 대부분이 MHA (H_kv == H_q) → group redundancy 구조적으로 없음.
- 몇 안 되던 GQA 모델은 KV 가 작아서 L2 가 먹어줌.
- Triton 2.x 의 MMA 가 아직 CUDA 와 경쟁할 수 없어서 CUDA 가 정답.
LLaMA-2-chat, LLaMA-3, Mistral 이 GQA 로 shipping 되면서 per-query-head grid 가 병목. vLLM 은 Triton 으로 옮기면서 (q_block, H_kv) 로 restructure — 내가 Phase 3 → 3.5 에서 한 refactor 와 정확히 같다.
발견 3 — 내가 한 것 중 vLLM 이 안 한 것 하나
vLLM 은 tl.dot 에 precision 을 명시하지 않는다. production 이 fp16/bf16 만 돌리니까 문제 없음. 하지만 누가 fp32 로 돌리면 4e-4 오차. 내 IS_FP32 branching + input_precision="ieee" 는 lesson context 에서만 중요하지만, 그래도 걸려냈다.
발견 4 — 내가 못 한 것: ctx 축 split-k
MQA 잔여 +85 % gap 의 원인. SDPA 는 이 shape 에서 698 GB/s (DRAM 의 2.3×) — L2 가 반 이상 흡수. 1 KV head × 4k tokens × 128 dim × 2 B fp16 = 1 MB 가 L2 48 MB 에 쉽게 들어가고 32 query heads 가 공유.
내 paged 는 block_table indirection 때문에 같은 L2 reuse 가 안 된다. 구조적 한계. grid 만 바꿔서는 못 닫음. vLLM 의 v2 는 ctx 축을 partition 하고 reduce 커널로 softmax 재조합. 이걸 Lesson 12 로 이월.
교훈 #3
Paper + HW + workload 만으로 짜도 실전 소스와 수렴하면 그건 설계가 맞다는 증거. vLLM 의 현행 Triton 포트와 axis-for-axis 매치. 이건 "내가 똑똑한 것" 이 아니라 "맞는 답이 하나" 라는 것.
최종 숫자 (L4 sm_89, fp16, warmup=50, iters=200)
| shape | B | group | SDPA ms | paged best (bs) | gap |
| llama7b MHA | 8 | 1 | 1.322 | 1.227 (bs=16) | -7 % |
| llama7b MHA ctx=8k | 8 | 1 | 6.115 | 4.927 (bs=64) | -19 % |
| llama38b GQA | 8 | 4 | 0.308 | 0.264 (bs=16) | -14 % |
| llama38b GQA | 32 | 4 | 1.163 | 1.197 (bs=128) | +3 % |
| llama70b GQA | 4 | 8 | 0.049 | 0.048 (bs=128) | -1 % |
| llama70b GQA | 8 | 8 | 0.532 | 0.526 (bs=16) | -1 % |
| mqa | 16 | 32 | 0.048 | 0.089 (bs=128) | +85 % |
Correctness: 32/32 PASS. LLaMA-3-8B B=8 ctx=2k 에서 SDPA (= Tri Dao FA-2 CUDA) 를 14 % 이기는 275 줄 Triton 커널. LLaMA-70B 는 parity. MQA 는 +85 % (split-k 로 닫힐 residual).
세 가지 남는 것
(1) Correctness 가 통과하는 것이 "맞다" 를 의미하지 않는다
32/32 PASS 직후 "끝났다" 고 판단했으면 GQA shape 에서 2–13 배 느린 커널을 ship 했다. 이 함정은 속도 벤치에 SDPA gap 컬럼 이 없었으면 못 잡는다. allclose + gap 을 함께 리포트.
(2) 버그가 버그 뒤에 숨는다
Grid bug (Phase 3) 와 TF32 bug (Phase 3.5) 는 독립적이었고 순차적으로만 드러났다. 큰 refactor 뒤엔 correctness 를 반드시 재실행 — "한 버그 고쳤으니 안전" 은 정확히 저 상황에서 틀린다.
(3) 독립적으로 수렴하는 게 설계의 증거
vLLM 소스를 Phase 3.5 끝난 뒤에 읽었는데 axis-for-axis 매치. 이건 내가 똑똑한 게 아니라, 맞는 답이 하나고 같은 툴 (Triton) + 같은 HW (sm_80+) + 같은 workload (GQA) 면 거기로 수렴한다. 오히려 이 convergence 를 명시적으로 기록하는 게 credible — "ecosystem 이 이미 한 refactor 를 miniature 로 재현했다" 는 스토리.
다음 세션
남은 MQA +85 % gap 을 ctx 축 split-k (vLLM v2) 로 닫는다 — Lesson 12.
LESSON 11 · 2026.04.22 · L4
Rewriting vLLM's Paged Attention in Triton — and repeating vLLM's mistake
A story about how independently converging on the same refactor is the evidence the design is right.
GPU · L4 · sm_89
stack · torch 2.11 + Triton 3.6
result · beats SDPA by 14 % on LLaMA-3-8B
Prologue — the one line you slip into contiguous attention
In Lesson 09 I wrote a (B, H, N, d) 4-D MHA + causal flash attention in Triton. That kernel runs at 78–90 % of torch.nn.functional.scaled_dot_product_attention (which, under the hood, is Tri Dao's FA-2 CUDA). By the bench alone, "a good kernel."
The problem: LLM serving (vLLM / SGLang / TensorRT-LLM) doesn't use a contiguous KV cache. Sequences vary wildly in length and churn in and out, so pre-allocating burns 70 %+ of GPU memory to fragmentation. The fix vLLM pushed through at SOSP '23 is simple:
Split the KV cache into a pool of fixed-size blocks, and use a per-sequence block_table to reference them indirectly.
(before)
K: (B, H, N, d) ← contiguous per sequence
(paged)
K_cache: (num_blocks, block_size, H_kv, d) ← block pool
block_table: (B, max_blocks_per_seq) ← seq → physical block id
context_lens:(B,) ← valid length per seq
What changes in the attention kernel is exactly one line:
# before: load K[b, h, start_n:end_n, :] contiguously
# paged: look up block_table[b, logical_blk] → phys_blk →
# load K_cache[phys_blk, :, kv_head, :]
This essay measures "does slipping that one line in actually yield vLLM paged attention?" It does — but I was wrong twice on the way. That's the essay.
Phase 1–2 · correctness passed too easily (this was the trap)
Order of work:
- Build a PyTorch reference (Python loop through block_table, gather, then standard attention) as the oracle.
- Triton kernel: initially
grid = (B, H_q) — the natural reduction of Lesson 09's (cdiv(N, BM), H, B) with "decode, so N=1."
- GQA support: add
GQA_GROUP_SIZE = H_q // H_kv as a constexpr and kv_head = pid_h // GQA_GROUP_SIZE. Four lines changed total.
Correctness bench (16 shapes × 2 dtypes = 32/32 PASS):
| shape | B | H_q | H_kv | group | fp16 max diff | fp32 max diff |
| MHA | 1–4 | 32 | 32 | 1 | 9.8e-04 | 3.6e-07 |
| LLaMA-3-8B GQA | 2 | 32 | 8 | 4 | 2.4e-04 | 3.6e-07 |
| LLaMA-70B GQA | 4 | 64 | 8 | 8 | 3.1e-05 | 3.6e-07 |
| MQA | 2 | 16 | 1 | 16 | 1.9e-06 | 3.6e-07 |
Sub-1e-3 on fp16, 1e-7 on fp32. Everything any model shop would accept.
At this point the judgement was: "It works. We're done." That was the trap.
Phase 3 · the speed bench exposed a structural bug
Had I stopped at correctness, I would have shipped a kernel 2–3× slower than SDPA on LLaMA-3-8B.
| shape | B | group | SDPA ms | paged best | gap |
| llama7b (MHA) | 8 | 1 | 1.143 | 1.06× | -7 % ✅ |
| llama38b (GQA) | 8 | 4 | 0.271 | 0.86× | +16 % ⚠ |
| llama38b (GQA) | 32 | 4 | 1.147 | 0.73× | +37 % ⚠ |
| llama70b (GQA) | 4 | 8 | 0.069 | 0.31× | +217 % ❌ |
| llama70b (GQA) | 8 | 8 | 0.533 | 0.46× | +117 % ❌ |
| mqa | 16 | 32 | 0.062 | 0.07× | +1316 % 💥 |
MHA parity. From GQA onward, the gap is roughly linear in GROUP_SIZE.
- group=4 → +16–37 %
- group=8 → +117–217 %
- group=32 → +1316 % (13× slower)
That linearity is the diagnosis. Not a random regression — structural.
The cause — (B, H_q) grid reloads KV GROUP times
A grid of (B, H_q) means one program per (batch, query head). The GROUP_SIZE query heads in a GQA group share the same KV head, but each program independently walks block_table and reloads the same K/V block from DRAM again. Redundant DRAM traffic scales with GROUP_SIZE.
Why isn't SDPA slow? Contiguous KV lets the L2 prefetcher absorb the redundant loads. SDPA on MQA reaches 542 GB/s on L4 — 1.8× the DRAM peak (300 GB/s). That throughput is impossible from DRAM alone. L2 is absorbing more than half.
Our paged kernel breaks the L2 prefetch pattern because of block_table indirection, so the redundant loads all go to DRAM.
Takeaway #1
Correctness can pass and the structural bug only shows in speed. allclose is blind to grid design. The bug never surfaces by comparing to a reference. You need the bench table + the SDPA gap column in the report for the issue to become visible.
Phase 3.5 · I only changed the grid (and another bug popped up)
The fix is clear: change the grid to (B, H_kv) and handle the GROUP_SIZE query heads of the GQA group at once inside the program. K/V blocks get loaded once per program.
grid = (B, H_kv) # program count ÷ GROUP
q = tl.load(q_ptrs) # (GROUP, HEAD_DIM) — 2D tile
# inside the block loop:
scores = tl.dot(q_scaled, tl.trans(k)) # (GROUP, BLOCK)
acc += tl.dot(p.to(v.dtype), v) # (GROUP, HEAD)
The change itself is ~20 lines. Bench:
| shape | Phase 3 gap | Phase 3.5 gap |
| llama38b (group=4) B=8 | +161 % | -14 % ← beats SDPA |
| llama38b (group=4) B=32 | +86 % | +3 % (parity) |
| llama70b (group=8) B=4 | -2 % | -1 % |
| llama70b (group=8) B=8 | -1 % | -1 % |
| mqa (group=32) | +1316 % | +85 % |
LLaMA-3-8B at -14 % — our Triton kernel beats cuDNN / FA-2 at a shape common in production.
But fp32 correctness broke
- fp16: all shapes still pass ✅
- fp32 MQA: max diff 4.1e-04 ← it was 3.6e-07 before. Three orders of magnitude worse.
Confused. Why would fp32 break from just changing the grid?
The real cause — tl.dot(fp32, fp32) defaults to TF32 on sm_80+
On Ampere and later, Triton silently downgrades fp32 × fp32 in tl.dot to TF32 (10-bit mantissa). Without specifying input_precision, the default is TF32. On MQA's (GROUP=16, BLOCK=16, HEAD=64) score tile, 80–100 summation steps accumulate 10-bit truncation error and bias the softmax-max boundary by 4e-4.
Fix:
if IS_FP32:
# 3-pass TF32 stack (2 low-bit corrections) to reconstruct IEEE — 3× slower
scores = tl.dot(q_scaled, tl.trans(k), input_precision="ieee")
else:
# fp16/bf16 MMA — default is already IEEE fp16
scores = tl.dot(q_scaled, tl.trans(k)).to(tl.float32)
fp32 max diff: 4.1e-04 → 3.6e-07 recovered. No fp16 speed cost.
Why didn't this show up in Phase 3?
Phase 3 computed scores via manual broadcast (tl.sum(q * k)) — a pure fp32 path. No TF32 detour. The bug first showed up the moment Phase 3.5 introduced tl.dot.
Takeaway #2
Two independent bugs can hide in sequence. If you don't fix the grid bug, the TF32 bug stays invisible. It appears the instant you do. After a big refactor, always rerun the correctness bench.
Phase 4 · reading vLLM's source, I realized I had reproduced vLLM's history in miniature
I read the vLLM source only after finishing Phase 3.5 — deliberately, to see whether an independent design would converge with vLLM's.
Files I read:
| # | file | role |
| v1 | csrc/attention/paged_attention_v1.cu | Original CUDA kernel (2023). Per-query-head grid. |
| v2 | csrc/attention/paged_attention_v2.cu | ctx-axis split-k + reduce kernel. |
| triton | vllm/v1/attention/ops/triton_unified_attention.py | Current Triton implementation. Per-KV-head grid. |
Finding 1 — my Phase 3.5 matches vLLM's current Triton kernel axis-for-axis
| axis | vLLM Triton unified | my Phase 3.5 |
| grid | (Σ q_blocks, H_kv) | (B, H_kv) |
| Q tile | (BLOCK_M, HEAD) | (GROUP, HEAD) |
| Matmul | tl.dot(Q, K) / tl.dot(P, V) | tl.dot (GROUP≥4) or manual fallback |
| Softmax | per-row fp32 running (M, L, acc) | same |
| KV layout | (num_blks, blk_size, H_kv, d) | same |
The only real difference: vLLM's axis-0 is (batch × query block) because they handle prefill in the same kernel (variable q_len); mine is pure batch because decode pins q_len=1. Same structure, different scope.
Finding 2 — vLLM itself went through the same refactor
paged_attention_v1.cu:86:
dim3 grid(num_heads, num_seqs, 1);
That is a per-query-head grid — identical in design to my Phase 3. vLLM's 2023 ship.
Why it was fine then:
- Most of the LLM market was MHA (H_kv == H_q). Group redundancy didn't structurally exist.
- The few GQA models had small KVs that L2 could absorb.
- Triton 2.x's MMA couldn't yet compete with CUDA, so CUDA was the answer.
As LLaMA-2-chat, LLaMA-3, and Mistral shipped with GQA, the per-query-head grid became the bottleneck. vLLM moved to Triton and restructured to (q_block, H_kv) — exactly the refactor I did from Phase 3 → 3.5.
Finding 3 — one thing I did that vLLM didn't
vLLM doesn't specify precision on tl.dot. Production runs fp16/bf16 only, so it doesn't matter. But if someone flows fp32 through that path on sm_80+, they get the same 4e-4 error I caught. My IS_FP32 branching + input_precision="ieee" is only relevant in a lesson context, but still caught it.
Finding 4 — what I didn't do: split-k over the ctx axis
The source of my residual +85 % MQA gap. SDPA hits 698 GB/s on this shape (2.3× DRAM) — L2 absorbs more than half. 1 KV head × 4k tokens × 128 dim × 2 B fp16 = 1 MB fits trivially into L4's 48 MB L2 and is shared by 32 query heads.
My paged kernel can't replicate that L2 reuse because of block_table indirection — structural. The grid fix alone can't close it. vLLM's v2 partitions the ctx axis and uses a reduce kernel to recombine softmax. Pushed to Lesson 12.
Takeaway #3
If you can derive the design from paper + HW + workload alone, and it converges with the production source, the convergence is the evidence. Axis-for-axis match with vLLM's current Triton port. This isn't "I'm clever" — it's "the right answer is one."
Final numbers (L4 sm_89, fp16, warmup=50, iters=200)
| shape | B | group | SDPA ms | paged best (bs) | gap |
| llama7b MHA | 8 | 1 | 1.322 | 1.227 (bs=16) | -7 % |
| llama7b MHA ctx=8k | 8 | 1 | 6.115 | 4.927 (bs=64) | -19 % |
| llama38b GQA | 8 | 4 | 0.308 | 0.264 (bs=16) | -14 % |
| llama38b GQA | 32 | 4 | 1.163 | 1.197 (bs=128) | +3 % |
| llama70b GQA | 4 | 8 | 0.049 | 0.048 (bs=128) | -1 % |
| llama70b GQA | 8 | 8 | 0.532 | 0.526 (bs=16) | -1 % |
| mqa | 16 | 32 | 0.048 | 0.089 (bs=128) | +85 % |
Correctness: 32/32 PASS. A 275-line Triton kernel that beats SDPA (= Tri Dao FA-2 CUDA) by 14 % on LLaMA-3-8B B=8 ctx=2k. Parity on LLaMA-70B. MQA at +85 % — a residual gap that split-k will close.
Three things that stick
(1) Correctness passing doesn't mean "correct"
If I had called it done at 32/32 PASS, I would have shipped a kernel 2–13× slower on GQA shapes. This trap doesn't get caught without the SDPA-gap column in the speed bench. Report allclose and gap together.
(2) Bugs hide behind bugs
The grid bug (Phase 3) and the TF32 bug (Phase 3.5) were independent and only surfaced sequentially. After a big refactor, rerun correctness — always. "Fixed one, so we're safe" is exactly wrong in this situation.
(3) Converging independently is the evidence of design
I read vLLM's source only after finishing Phase 3.5, and found an axis-for-axis match. That's not me being clever; it's "the right answer is one" when the tool (Triton) + HW (sm_80+) + workload (GQA) are fixed. Recording this convergence is actually the credible story — "the ecosystem already did this refactor and I reproduced it in miniature."
Next session
Close the residual MQA +85 % gap with ctx-axis split-k (vLLM v2) — Lesson 12.