LLM 추론의 attention 은 한 모양이 아니다. prefill 은 long Q × long KV 의 GEMM 류, decode 는 short Q × long KV 의 GEMV 류. 거기에 GQA / MQA / paged KV / sparse / chunked prefill 같은 변종이 더해진다. FlashInfer 의 답 — 이 모든 attention 을 한 통일 API + 한 generator 가 만든다. Zihao Ye 가 깐 unified attention engine 의 내부 — query tile 의 다양한 크기, paged KV cache 와 shared prefix, sparse attention, JIT compilation, vLLM/SGLang 통합까지.
L036 의 FlashAttention 3 는 학습 의 attention 을 풀었다. inference 에서는 attention 의 모양이 학습 시점보다 훨씬 다양하다. prefill 의 long-Q-long-KV, decode 의 short-Q-long-KV, GQA / MQA 의 head 비대칭, paged KV 의 비연속 메모리, sparse 의 토큰 부분 — 이게 모두 attention.
이전까지의 답 — 각 변종마다 별도 커널. vLLM 이 자기 paged attention, FA repo 가 다른 변종, Triton 이 또 다른 변종. FlashInfer 의 답 — 한 통일 API 가 이 모든 변종을 받아들이고, 내부에서 가장 적합한 커널을 dispatch / generate 한다.
Zihao 의 입장 — “inference 의 attention 은 ‘shape 의 cartesian 곱’ 이다. (prefill/decode) × (GQA/MQA) × (paged/contiguous) × (sparse/dense) × (chunked/full). 이 곱을 사람이 한 본씩 짤 수 없다 — 코드 generation 으로 푼다.” FlashInfer 는 그 generation 의 추상.
그래서 강의 끝에 손에 잡혀야 할 자산 — (1) attention shape 의 차원들에 대한 mental model, (2) paged KV 의 table 구조와 shared prefix 의 결합, (3) JIT 가 코드 양을 어떻게 줄이는지, (4) vLLM / SGLang 같은 엔진과의 표준화.
FA-style tiled. tile size 128 이 자연스러움. compute-bound. tensor core 점유율 결정적. FA3 forward 의 직접 적용.
memory-bound. K/V read 가 latency 의 거의 전부. tile size 16 도 큼. tensor core 활용 어려움 — 대신 여러 request 의 query 를 한 batch 로 묶어 GEMV 의 M 을 키우는 트릭.
대칭. 표준 transformer 의 형태. KV cache 가 가장 큼.
Llama-3 / Mistral 등의 표준. 여러 Q head 가 같은 K/V head 를 공유. K/V cache 가 작아진다 — decode 의 메모리 bandwidth 큰 절약.
강의에서 Zihao 가 강조한 사실 — “prefill 과 decode 는 같은 attention 이지만 hardware 위 dispatch 가 다르다.” prefill 은 큰 tile 의 GEMM kernel, decode 는 작은 tile 의 GEMV kernel. FlashInfer 의 unified API 는 사용자가 어느 모드인지 명시적으로 알리고, 내부에서 다른 코드로 dispatch.
둘이 같은 algorithm — online softmax 위 tiled attention. 다르게 보이는 건 tile 크기 / 병렬화 단위 / KV layout. 같은 표현 위에서 다른 instantiation. CUTLASS 의 layout algebra (L036 §08) 와 비슷한 정신.
# FlashInfer — 한 wrapper, 다양한 모드
import flashinfer
# 1) prefill (paged KV)
prefill = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, kv_layout="NHD")
prefill.plan(qo_indptr, paged_kv_indptr,
paged_kv_indices, paged_kv_last_page_len,
num_qo_heads, num_kv_heads,
head_dim, page_size)
out = prefill.run(q, paged_kv_cache)
# 2) decode
decode = flashinfer.BatchDecodeWithPagedKVCacheWrapper(...)
decode.plan(...)
out = decode.run(q, paged_kv_cache)
# 3) prefill + decode 동시 (chunked prefill)
mixed = flashinfer.BatchPrefillWithPagedKVCacheWrapper(...)
out = mixed.run(q, paged_kv_cache, causal=True)
API 의 핵심 결정 두 가지.
이 추상의 실제 효과 — 커널 dispatch 가 한 함수 호출. 사용자는 “이 batch 가 어떤 sequence 들을 가지고 있는가” 만 plan 에 넘기고, 내부에서 적합한 attention kernel 이 골라진다.
plan 단계에서 scheduling 정보 (block 별 work assignment, SM 분배) 까지 미리 결정한다. SGLang 같은 엔진의 overlap scheduler 가 GPU forward 와 병렬로 다음 batch 의 plan 을 미리 호출하면, run 의 실제 latency 가 거의 GPU 시간만 됨. 이게 SGLang 의 throughput 이득 (L035) 의 한 축.
vLLM 의 PagedAttention paper 가 도입한 표준 — KV cache 를 작은 page (보통 16 토큰) 로 잘라 메모리에 비연속적으로 둔다. attention kernel 은 page table (block table) 을 따라가며 K/V 를 gather. FlashInfer 는 이 page table 을 first-class 로 받음.
FlashInfer 의 결정 — page_size = 1 도 first-class 로 지원. SGLang 이 RadixAttention 의 prefix match 를 더 정밀하게 하려고 page=1 을 쓰는데, page=1 의 paged attention 은 “메모리 access 가 매 토큰마다 indirect” 이라 일반 paged kernel 이 잘 다루기 어렵다. FlashInfer 는 이 케이스를 위한 별도 kernel path.
contiguous 는 TMA 2D 를 직접 쓸 수 있다 (한 thread 가 큰 tile 을 비동기 load). paged (page=1) 는 TMA 2D 가 안 통하니 register 단위 gather 로 fall-back. 강의에서 Zihao 가 짚은 한 줄 — “성능 차이는 의외로 크지 않다 — 메인 비용은 register pressure 이고, 그건 둘 다 비슷하다.”
shared prefix 를 활용하는 mode — shared-prefix attention. 같은 prefix 를 가진 N 개 request 의 attention 에서 K/V 를 한 번만 read. SGLang 이 이 mode 를 적극 활용 — 그래서 SGLang + FlashInfer 의 결합이 자연스러움.
긴 컨텍스트의 LLM 에서 모든 토큰이 모든 토큰을 attend 할 필요가 없다는 가정. sparse attention — 일부 (key, query) 쌍만 계산. FlashInfer 가 이 mode 를 attention API 안에 통합한다.
Zihao 의 한 줄 — “larger block 이면 K cache 를 SMEM/register 까지 stage 해 tensor core 를 쓴다. smaller block 이면 hardware-managed sparse fall-back.” 그래서 sparse attention 을 위한 attention API 는 single block size 가 아니라 block size 의 mixture 를 받는다.
이 자리는 강의의 가장 흥미로운 자리 중 하나 — sparse attention 이 실용적이려면 attention API 가 그걸 자연스럽게 받아들여야 한다. 별도 kernel 로 두면 사람이 안 쓴다. FlashInfer 의 unified API 가 그 진입 장벽을 낮춘다.
long-context LLM (32k, 128k 토큰) 의 prefill 은 한 번에 너무 무겁다. 그동안 다른 request 의 decode 가 GPU idle 로 기다린다. chunked prefill 이 답 — prefill 을 N 조각으로 잘라, 각 chunk 사이사이 다른 request 의 decode 를 끼워넣는다.
FlashInfer 가 이걸 지원하는 방식 — BatchPrefill API 에서 prefill chunk + decode 를 한 batch 안에서 처리. 같은 forward 안에 “이 sequence 는 prefill 의 chunk 5/10, 저 sequence 는 decode 한 토큰” 을 동시에. attention kernel 이 두 모드를 한 dispatch 로.
chunked prefill 의 query 는 prefill chunk 의 토큰들이고, key 는 그 chunk 까지의 모든 KV. 즉 query length 는 chunk size, key length 는 지금까지의 누적 prefix. 이 모양이 일반 prefill 이나 decode 와도 다른 별개 case. “variable Q-length” attention 으로 일반화하면 자연스럽게 들어맞는다.
SGLang / vLLM 이 chunked prefill 을 채용한 이유는 — long-context request 의 prefill 이 다른 request 의 decode latency 를 망가뜨리지 않게 하는 것. throughput + latency 의 동시 균형. FlashInfer 가 이걸 하나의 attention API 호출로 expose.
VLLM_ATTENTION_BACKEND=FLASHINFER. FA, FA3, Triton 과 함께 선택 가능. 특정 워크로드(MQA / paged page=1)에서 FlashInfer 가 더 좋음.강의에서 Zihao 가 명시한 — “우리는 SGLang 팀과 매우 가깝게 작업한다. 다음 release 에서 SGLang 에 10~20% 추가 throughput 이 들어올 예정.” 이 통합이 단순한 dependency 가 아니라 active co-development.
두 자리에 분명한 차이 — FA3 가 학습 forward/backward 의 표준이고 FlashInfer 가 inference 의 다양성 의 표준. 한 라이브러리가 모두를 다 잘하는 건 아니고, 자기 자리가 다르다. SGLang/vLLM 이 두 backend 를 모두 wrap.
강의의 specific 한 사실 — “우리는 query tile size 16 / 32 / 64 / 128 을 모두 가지고 있고, batch 의 평균 query length 같은 statistic 으로 best tile 을 선택한다.” 이게 가능한 이유는 JIT — 모든 tile size 의 kernel 을 미리 컴파일하지 않고 필요시 생성.
FlashInfer 의 다음 release 가 흥미로운 자리. 강의에서 Zihao 가 짧게 언급 — JIT generator 가 하나의 attention spec 으로부터 CUDA / CUTLASS / TileLang 코드를 모두 생성한다는 방향.
이 방향의 의미 — “새 attention 변종이 paper 로 나오면, 한 spec 추가 + generator 의 한 분기로 production 라이브러리에 흡수”. CUTLASS 가 GEMM 에서 한 일을 attention 에서 하는 자리. 강의 시점에 진행형.
TileLang 은 강의 시점에 새 DSL — Triton 과 CUTLASS 사이의 자리. 더 표현력 있는 high-level DSL 위에서 attention 을 표현, 컴파일러가 backend 별로 lowering. FlashInfer 가 그 generation pipeline 의 frontend 역할을 할 수 있는지가 강의의 마지막 흥미 자리.
FlashInfer 는 매우 빠르게 갱신되는 라이브러리. 이 노트의 API signature, 성능 수치, JIT 의 상태는 강의 시점 기준. 자기 시점의 release 노트와 docs 로 갱신 확인 필수. plan/run 분리 같은 핵심 추상은 안정적이나, 구체적 wrapper 이름은 변할 수 있음.