gpumode · 강의 아카이브
《GPU Mode》 L041 2024 · DEC High priority transcript · available

FlashInfer — Customizable kernels for LLM inference

LLM 추론의 attention 은 한 모양이 아니다. prefill 은 long Q × long KV 의 GEMM 류, decode 는 short Q × long KV 의 GEMV 류. 거기에 GQA / MQA / paged KV / sparse / chunked prefill 같은 변종이 더해진다. FlashInfer 의 답 — 이 모든 attention 을 한 통일 API + 한 generator 가 만든다. Zihao Ye 가 깐 unified attention engine 의 내부 — query tile 의 다양한 크기, paged KV cache 와 shared prefix, sparse attention, JIT compilation, vLLM/SGLang 통합까지.

FlashInfer paged KV prefill / decode GQA / MQA sparse attention JIT vLLM SGLang
Z
Speaker
Zihao Ye
UW · NVIDIA · FlashInfer 저자
강의 번호
L041
스피커
Zihao Ye
학습 우선순위
High · 정독
다시 볼 때
vLLM/SGLang 코드와
§ 01강의가 풀려는 문제· attention 의 다양성

“attention” 한 단어가 inference 에서 N 개 다른 모양을 가리킨다

L036 의 FlashAttention 3 는 학습 의 attention 을 풀었다. inference 에서는 attention 의 모양이 학습 시점보다 훨씬 다양하다. prefill 의 long-Q-long-KV, decode 의 short-Q-long-KV, GQA / MQA 의 head 비대칭, paged KV 의 비연속 메모리, sparse 의 토큰 부분 — 이게 모두 attention.

이전까지의 답 — 각 변종마다 별도 커널. vLLM 이 자기 paged attention, FA repo 가 다른 변종, Triton 이 또 다른 변종. FlashInfer 의 답 — 한 통일 API 가 이 모든 변종을 받아들이고, 내부에서 가장 적합한 커널을 dispatch / generate 한다.

강의의 인지적 frame

Zihao 의 입장 — “inference 의 attention 은 ‘shape 의 cartesian 곱’ 이다. (prefill/decode) × (GQA/MQA) × (paged/contiguous) × (sparse/dense) × (chunked/full). 이 곱을 사람이 한 본씩 짤 수 없다 — 코드 generation 으로 푼다.” FlashInfer 는 그 generation 의 추상.

“같은 attention 인데 inference 에서는 너무 많은 변종이 있다. 한 함수 + JIT 으로 그 cartesian 곱을 표현하는 게 이 라이브러리의 정체성.”Zihao Ye · 강의 도입부

그래서 강의 끝에 손에 잡혀야 할 자산 — (1) attention shape 의 차원들에 대한 mental model, (2) paged KV 의 table 구조와 shared prefix 의 결합, (3) JIT 가 코드 양을 어떻게 줄이는지, (4) vLLM / SGLang 같은 엔진과의 표준화.

§ 02prefill / decode / GQA / MQA· shape 의 분기

같은 “attention” 의 두 극단 — prefill 과 decode 가 사실상 다른 알고리즘

prefill

long Q × long KV — GEMM 류
q0
q1
q2
q3
q4
q5
q6
q7
k0
k1
k2
k3
k4
k5
k6
k7

FA-style tiled. tile size 128 이 자연스러움. compute-bound. tensor core 점유율 결정적. FA3 forward 의 직접 적용.

decode

short Q (=1 token) × long KV — GEMV 류
q
·
·
·
·
·
·
·
k0
k1
k2
k3
k4
k5
k6
k7

memory-bound. K/V read 가 latency 의 거의 전부. tile size 16 도 큼. tensor core 활용 어려움 — 대신 여러 request 의 query 를 한 batch 로 묶어 GEMV 의 M 을 키우는 트릭.

MHA (full)

Q heads = K heads = V heads
qH0
qH1
qH2
qH3
kH0
kH1
kH2
kH3

대칭. 표준 transformer 의 형태. KV cache 가 가장 큼.

GQA / MQA

Q heads > K/V heads
qH0
qH1
qH2
qH3
kH0
kH0
kH1
kH1

Llama-3 / Mistral 등의 표준. 여러 Q head 가 같은 K/V head 를 공유. K/V cache 가 작아진다 — decode 의 메모리 bandwidth 큰 절약.

강의에서 Zihao 가 강조한 사실 — “prefill 과 decode 는 같은 attention 이지만 hardware 위 dispatch 가 다르다.” prefill 은 큰 tile 의 GEMM kernel, decode 는 작은 tile 의 GEMV kernel. FlashInfer 의 unified API 는 사용자가 어느 모드인지 명시적으로 알리고, 내부에서 다른 코드로 dispatch.

왜 “하나의 통일 함수” 가 가능한가

둘이 같은 algorithm — online softmax 위 tiled attention. 다르게 보이는 건 tile 크기 / 병렬화 단위 / KV layout. 같은 표현 위에서 다른 instantiation. CUTLASS 의 layout algebra (L036 §08) 와 비슷한 정신.

§ 03unified attention API· 한 함수가 모든 모양

한 wrapper 가 prefill / decode / paged / sparse 를 모두 받는다

# FlashInfer — 한 wrapper, 다양한 모드
import flashinfer

# 1) prefill (paged KV)
prefill = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
    workspace_buffer, kv_layout="NHD")
prefill.plan(qo_indptr, paged_kv_indptr,
             paged_kv_indices, paged_kv_last_page_len,
             num_qo_heads, num_kv_heads,
             head_dim, page_size)
out = prefill.run(q, paged_kv_cache)

# 2) decode
decode = flashinfer.BatchDecodeWithPagedKVCacheWrapper(...)
decode.plan(...)
out = decode.run(q, paged_kv_cache)

# 3) prefill + decode 동시 (chunked prefill)
mixed = flashinfer.BatchPrefillWithPagedKVCacheWrapper(...)
out = mixed.run(q, paged_kv_cache, causal=True)

API 의 핵심 결정 두 가지.

  • plan / run 분리 — 한 batch 의 metadata (어떤 sequence 가 얼마나 긴지, paged table 의 layout) 가 plan 에서 미리 처리. 그 결과는 cache 가능. 다음 forward 에서 같은 metadata 면 plan 재호출 안 함.
  • indptr / indices 의 표준 — variable-length batch 를 표현하는 두 vector. SGLang/vLLM 의 batch 표현과 자연스럽게 연결.

이 추상의 실제 효과 — 커널 dispatch 가 한 함수 호출. 사용자는 “이 batch 가 어떤 sequence 들을 가지고 있는가” 만 plan 에 넘기고, 내부에서 적합한 attention kernel 이 골라진다.

plan / run 의 구조적 의미

plan 단계에서 scheduling 정보 (block 별 work assignment, SM 분배) 까지 미리 결정한다. SGLang 같은 엔진의 overlap scheduler 가 GPU forward 와 병렬로 다음 batch 의 plan 을 미리 호출하면, run 의 실제 latency 가 거의 GPU 시간만 됨. 이게 SGLang 의 throughput 이득 (L035) 의 한 축.

§ 04paged attention 통합· vLLM table 구조

page=1 의 KV 가 비연속이어도 attention 이 빠르다

vLLM 의 PagedAttention paper 가 도입한 표준 — KV cache 를 작은 page (보통 16 토큰) 로 잘라 메모리에 비연속적으로 둔다. attention kernel 은 page table (block table) 을 따라가며 K/V 를 gather. FlashInfer 는 이 page table 을 first-class 로 받음.

FIG · paged KV cache + shared prefix4 requests
P0 sys
P1 sys
P2 sys
P3 r1
P4 r2
P5 r3
P6 r4
·
P8 r1
P9 r2
·
P11 r1
초록색 page (P0~P2) 는 4 request 가 모두 공유하는 system prompt. 각 색은 한 request 의 unique 부분. 같은 page 가 여러 request 의 attention 에서 한 번만 read 되어야 메모리 bandwidth 의 이득이 산다 (RadixAttention §03 in L035).

FlashInfer 의 결정 — page_size = 1 도 first-class 로 지원. SGLang 이 RadixAttention 의 prefix match 를 더 정밀하게 하려고 page=1 을 쓰는데, page=1 의 paged attention 은 “메모리 access 가 매 토큰마다 indirect” 이라 일반 paged kernel 이 잘 다루기 어렵다. FlashInfer 는 이 케이스를 위한 별도 kernel path.

contiguous KV vs paged KV

contiguous 는 TMA 2D 를 직접 쓸 수 있다 (한 thread 가 큰 tile 을 비동기 load). paged (page=1) 는 TMA 2D 가 안 통하니 register 단위 gather 로 fall-back. 강의에서 Zihao 가 짚은 한 줄 — “성능 차이는 의외로 크지 않다 — 메인 비용은 register pressure 이고, 그건 둘 다 비슷하다.”

shared prefix 를 활용하는 mode — shared-prefix attention. 같은 prefix 를 가진 N 개 request 의 attention 에서 K/V 를 한 번만 read. SGLang 이 이 mode 를 적극 활용 — 그래서 SGLang + FlashInfer 의 결합이 자연스러움.

§ 05sparse attention· column-sparse · block-sparse

토큰의 부분만 attend 한다 — 같은 attention API 안에서

긴 컨텍스트의 LLM 에서 모든 토큰이 모든 토큰을 attend 할 필요가 없다는 가정. sparse attention — 일부 (key, query) 쌍만 계산. FlashInfer 가 이 mode 를 attention API 안에 통합한다.

block size 결정의 trade-off

Zihao 의 한 줄 — “larger block 이면 K cache 를 SMEM/register 까지 stage 해 tensor core 를 쓴다. smaller block 이면 hardware-managed sparse fall-back.” 그래서 sparse attention 을 위한 attention API 는 single block size 가 아니라 block size 의 mixture 를 받는다.

이 자리는 강의의 가장 흥미로운 자리 중 하나 — sparse attention 이 실용적이려면 attention API 가 그걸 자연스럽게 받아들여야 한다. 별도 kernel 로 두면 사람이 안 쓴다. FlashInfer 의 unified API 가 그 진입 장벽을 낮춘다.

§ 06chunked prefill· 긴 컨텍스트 분할

prefill 한 forward 의 KV 를 작은 chunk 로 잘라 decode 와 섞는다

long-context LLM (32k, 128k 토큰) 의 prefill 은 한 번에 너무 무겁다. 그동안 다른 request 의 decode 가 GPU idle 로 기다린다. chunked prefill 이 답 — prefill 을 N 조각으로 잘라, 각 chunk 사이사이 다른 request 의 decode 를 끼워넣는다.

FlashInfer 가 이걸 지원하는 방식 — BatchPrefill API 에서 prefill chunk + decode 를 한 batch 안에서 처리. 같은 forward 안에 “이 sequence 는 prefill 의 chunk 5/10, 저 sequence 는 decode 한 토큰” 을 동시에. attention kernel 이 두 모드를 한 dispatch 로.

왜 일반 attention kernel 로 안 되는가

chunked prefill 의 query 는 prefill chunk 의 토큰들이고, key 는 그 chunk 까지의 모든 KV. 즉 query length 는 chunk size, key length 는 지금까지의 누적 prefix. 이 모양이 일반 prefill 이나 decode 와도 다른 별개 case. “variable Q-length” attention 으로 일반화하면 자연스럽게 들어맞는다.

SGLang / vLLM 이 chunked prefill 을 채용한 이유는 — long-context request 의 prefill 이 다른 request 의 decode latency 를 망가뜨리지 않게 하는 것. throughput + latency 의 동시 균형. FlashInfer 가 이걸 하나의 attention API 호출로 expose.

§ 07통합 사례 (vLLM · SGLang)· backend 표준화

두 inference engine 이 같은 attention library 를 쓰는 시점

vLLM
FlashInfer 는 옵션 backend
VLLM_ATTENTION_BACKEND=FLASHINFER. FA, FA3, Triton 과 함께 선택 가능. 특정 워크로드(MQA / paged page=1)에서 FlashInfer 가 더 좋음.
SGLang
FlashInfer 가 main backend
RadixAttention + shared-prefix attention 의 결합이 자연. SGLang 의 throughput 이득의 한 축이 FlashInfer 의 plan/run 분리에서 옴.
research
새 attention 변종 prototyping
JIT 으로 새 변종을 빠르게. paper 단계의 attention 알고리즘이 production engine 에 빠르게 흡수되는 길.

강의에서 Zihao 가 명시한 — “우리는 SGLang 팀과 매우 가깝게 작업한다. 다음 release 에서 SGLang 에 10~20% 추가 throughput 이 들어올 예정.” 이 통합이 단순한 dependency 가 아니라 active co-development.

“같은 hardware 에서 두 inference engine 이 같은 attention library 를 쓴다는 건 — attention 의 표준화 시대의 시작.”학습 노트
§ 08성능 사례· FA3 와 비교

같은 hardware 에서 FA3 의 forward 정도 + inference 시나리오 우위

시나리오 · H100FA3FlashInfer차이의 근원
prefill (4k Q, 4k KV, MHA)~85% peak~85% peak동등 — FA3 path
decode (1 Q, 4k KV, GQA)limitedmemory-bw bounddecode 전용 path
paged page=1, shared prefix미지원first-classSGLang fit
chunked prefill제한nativevariable Q
sparse attention없음column · blockunified API
FP8 KV cache진행기본 지원inference 전용

두 자리에 분명한 차이 — FA3 가 학습 forward/backward 의 표준이고 FlashInfer 가 inference 의 다양성 의 표준. 한 라이브러리가 모두를 다 잘하는 건 아니고, 자기 자리가 다르다. SGLang/vLLM 이 두 backend 를 모두 wrap.

tile size 의 다양성

강의의 specific 한 사실 — “우리는 query tile size 16 / 32 / 64 / 128 을 모두 가지고 있고, batch 의 평균 query length 같은 statistic 으로 best tile 을 선택한다.” 이게 가능한 이유는 JIT — 모든 tile size 의 kernel 을 미리 컴파일하지 않고 필요시 생성.

§ 09JIT / DSL 다음 방향· CUDA · CUTLASS · TileLang

한 spec 에서 여러 backend 코드 — generation 의 다음 단계

FlashInfer 의 다음 release 가 흥미로운 자리. 강의에서 Zihao 가 짧게 언급 — JIT generator 가 하나의 attention spec 으로부터 CUDA / CUTLASS / TileLang 코드를 모두 생성한다는 방향.

L0 · spec attention variant specificationvariable Q-length, GQA ratio, page size, sparse mask, dtype 등을 한 spec 객체로 사용자 입력
L1 · generator JIT code synthesizerspec 을 파싱하고 dispatch 결정 + kernel template 인스턴스화 FlashInfer 내부
L2 · backend CUDA · CUTLASS · TileLang같은 spec 이 여러 backend 의 코드로 펴진다. 각 backend 의 trade-off 가 다름. multi-backend
L3 · binary .cubin · ahead-of-time 또는 just-in-timecache 가능. 다음 호출에서 같은 spec 이면 binary 재사용. runtime

이 방향의 의미 — “새 attention 변종이 paper 로 나오면, 한 spec 추가 + generator 의 한 분기로 production 라이브러리에 흡수”. CUTLASS 가 GEMM 에서 한 일을 attention 에서 하는 자리. 강의 시점에 진행형.

“JIT 의 진짜 가치는 컴파일 시간이 아니라 — 하나의 spec 에서 여러 hardware 의 코드를 생성한다는 점.”Zihao Ye · Q&A

TileLang 은 강의 시점에 새 DSL — Triton 과 CUTLASS 사이의 자리. 더 표현력 있는 high-level DSL 위에서 attention 을 표현, 컴파일러가 backend 별로 lowering. FlashInfer 가 그 generation pipeline 의 frontend 역할을 할 수 있는지가 강의의 마지막 흥미 자리.

§ 10기억할 메모와 코드· repo · paper

다시 열었을 때 5분 안에 손에 잡혀야 할 것

attention shape 의 cartesian 곱
prefill/decode × MHA/GQA/MQA × paged/contig × dense/sparse × full/chunked. 한 본씩 짤 수 없음.
unified API
하나의 wrapper 가 모든 변종. plan / run 분리. metadata 는 cache.
paged page=1
SGLang 의 RadixAttention 이 요구하는 layout. FlashInfer 의 first-class.
shared-prefix attention
같은 prefix 의 N request 가 K/V 를 한 번만 read. SGLang 의 throughput 이득의 한 축.
chunked prefill
prefill 을 N 조각 + decode 와 섞기. variable Q-length attention 으로 표현.
sparse attention 통합
column-sparse · block-sparse · composition. 별도 kernel 안 두고 같은 API.
vLLM · SGLang 통합
SGLang 이 main, vLLM 은 옵션. 두 engine 이 같은 attention library 를 공유.
JIT generator
spec → CUDA / CUTLASS / TileLang. 다음 release 의 큰 줄기.
vLLM 통합vLLM docs — VLLM_ATTENTION_BACKEND
SGLang 통합SGLang repo — main backend
Speakergithub.com/yzh119 · Zihao Ye

손에 새기기 — 실습 시퀀스

  1. FlashInfer 한 줄 hello — pip install flashinfer + decode wrapper 한 batch. variable-length 의 indptr 표현 손에.
  2. plan / run 분리 직접 측정 — 같은 metadata 로 run 만 반복 호출. plan 비용이 한 번이고 run 이 가벼운지 latency 로 확인.
  3. paged page=1 vs page=16 — 같은 sequence 를 두 page size 로 cache 하고 attention latency 비교. 의외로 차이 적은 게 정상.
  4. shared-prefix attention — 같은 system prompt 에 4 개 request. shared mode on/off 의 K/V read 트래픽 측정.
  5. chunked prefill 한 batch — 한 batch 안에 prefill chunk + decode 동시. 한 forward 의 GPU util 이 얼마나 채워지는지 NCU.
  6. SGLang backend 직접 교체 — SGLang 에 FlashInfer / Triton 두 backend 모두로 같은 모델 띄우고 throughput 비교.
  7. vLLM backend 비교 — 같은 모델 / 같은 batch 로 FlashInfer / FA3 / Triton 세 backend. 어느 시나리오가 누구의 자리인지 표 만들기.
§ 11다른 강의로 이어지는 길· connections

FlashInfer 가 시리즈 안 다른 강의들과 만나는 자리

§ 12열린 질문· open questions

다음에 다시 들었을 때 직접 검증해야 할 것들

검증 메모

FlashInfer 는 매우 빠르게 갱신되는 라이브러리. 이 노트의 API signature, 성능 수치, JIT 의 상태는 강의 시점 기준. 자기 시점의 release 노트와 docs 로 갱신 확인 필수. plan/run 분리 같은 핵심 추상은 안정적이나, 구체적 wrapper 이름은 변할 수 있음.

← Lecture 040 CUDA Docs for Humans Lecture 042 → 다음 강의로