gpumode · 강의 아카이브
《GPU Mode》 L034 2024 · OCT High priority transcript · available

Low Bit Triton Kernels

CUTLASS / cuBLAS 가 다 가져간 4-bit weight-only GEMM 의 자리에서, Triton 만으로 비슷한 성능을 짜내려면 어떤 트릭이 필요한가. HQQ 의 저자 Hicham Badri 가 BitBLAS / Marlin / vLLM 의 in-house 커널들을 한 줄로 비교하면서 — Triton tile 안에서 dequantize-then-multiply 패턴을 어떻게 펴는지, group-wise scale 이 왜 헤더로 따로 packing 되어야 하는지, 같은 코드가 BLOCK_K 한 줄로 두 배 빨라지는 이유까지 깐다.

Triton INT4 weight-only dequantize HQQ Marlin BitBLAS TMA autotune
H
Speaker
Hicham Badri
Mobius Labs · HQQ 저자 · github.com/mobicham
강의 번호
L034
스피커
Hicham Badri
학습 우선순위
High · 정독 + 구현
다시 볼 때
Triton 직접 짜며
§ 01강의가 풀려는 문제· why Triton, why low-bit

“CUTLASS 만 잘하던 일” 을 Triton 으로 끌어오는 작업

LLM 추론에서 4-bit weight-only GEMM 이 표준 도구가 됐다. 그런데 빠른 커널은 거의 다 CUTLASS / cuBLAS / Marlin 처럼 깊은 C++ 위에 있다. Hicham 의 질문은 단순하다 — Triton 으로도 같은 성능을 짤 수 있을까. 짤 수 있다면, 어디까지 봐야 하는가.

이 강의는 그 질문에 대한 진행형 답이다. 강의 시점에 Hicham 이 직접 만들고 있는 라이브러리는 HQQ(Half-Quadratic Quantization) 와 그 위에 얹은 Triton 커널 모음 gemlite. “BitBLAS 는 TVM 에 묶여 있어 손대기 어렵고, Marlin 은 4-bit 만, Triton 은 짜기는 쉽지만 성능이 떨어진다” 는 표가 강의 첫 장이다 — 그 빈 자리를 Triton 으로 메운다.

강의의 인지적 frame

Hicham 의 일관된 입장 — “Triton 의 성능은 코드의 길이가 아니라 tile 안에서 dequant 를 어떻게 펴는가BLOCK 설정 에 달렸다.” 이 두 축으로 강의 전체가 정렬된다. Triton 이 자동으로 tensor core 까지 끌어주는 시점에서, 사람의 일은 register 안에서 unpack/scale 의 시퀀스를 줄이는 것으로 좁혀진다.

“Cutlass 는 GEMM 을 정말 잘한다. 그런데 customize 가 쉽지 않다. Triton 은 customize 가 trivial 한데 성능이 떨어졌다 — 그 사이를 메우는 게 이 작업이다.”Hicham Badri · 강의 도입부

그래서 강의 끝에 손에 잡혀야 할 자산은 세 개. (1) dequantize-then-multiply 패턴이 Triton tile 안에서 어떻게 펼쳐지는지의 mental model, (2) group-wise scale 이 왜 packing 단계에서 이미 들어가야 하는지의 이유, (3) 새 GPU(Hopper) 의 TMA 를 Triton 에서 끌어 쓸 때 어디가 막히는지의 한계.

§ 02weight-only 4-bit 의 배경· memory-bound decode

왜 weight 만 quantize 하는가 — decode 가 메모리 밴드 위에서 도니까

LLM 추론은 두 단계로 나뉜다 — prefill(프롬프트 처리, batch=long) 과 decode(토큰 한 개씩 생성, batch=1). decode 한 step 의 비용은 거의 100% weight 를 HBM 에서 읽는 시간이다. weight 를 4-bit 로 줄이면 같은 GPU 에서 4배 더 빠르게 읽힌다 — 이게 weight-only 의 출발점.

강의에서 Hicham 이 도식으로 보여준 비교 — 16-bit 와 4-bit 같은 layer 의 메모리 트래픽.

  • 16-bit — 7B 모델의 한 linear 가 약 130MB. decode 한 토큰 당 모든 layer 의 weight 를 한 번씩 본다.
  • 4-bit + group-wise scale (g=128) — weight 32MB + scale ~1MB ≈ 33MB. 거의 4배 줄음.
  • activation — batch=1 에서 작아서 quantize 안 해도 무시 가능. 다만 긴 context 의 KV cache 는 별도 quantize 가 또 필요해진다(L041 으로 이어짐).

그래서 weight-only 가 “가장 싼 압축으로 가장 큰 latency 이득” 의 자리. activation 까지 같이 quantize 하는 W8A8 / W4A8 같은 방식은 prefill 이 dominant 한 throughput-bound 시나리오에서 다시 의미가 생긴다.

FIG · decode 한 step 의 HBM 트래픽 분해7B · GQA
FP16 weight~130 MB
INT4 weight~33 MB
activation (b=1)~0.5 MB
KV cache (1k ctx)~24 MB
decode 한 token 당 GPU 가 읽어야 하는 데이터의 비율. weight 가 압도적이라 weight-only 만 줄여도 latency 가 거의 4배 빨라진다.

강의의 이 슬라이드가 끝나면, 다음 한 줄이 자연스럽게 따라온다 — “그러면 4-bit weight 를 어떻게 packing 하고, 어떻게 GEMM 안에서 16-bit 로 펴서 tensor core 에 먹이는가.” 그게 §03 의 본론.

§ 03dequantize-then-multiply 패턴· tile 안에서 펴기

매 BLOCK_K 마다 unpack → scale → MMA 가 register 위에서 도는 한 사이클

low-bit GEMM 의 핵심 패턴 — 4-bit packed weight 를 HBM 에서 읽어 SMEM 에 올린 뒤, register 위에서 nibble 을 펴고 group scale 을 곱해서 16-bit tile 로 복원, 그걸 그대로 tensor core 에 넘긴다. 이 시퀀스가 BLOCK_K 단위로 반복된다.

FIG · 한 BLOCK_K iteration 동안 register 위에서 일어나는 일weight-only INT4 GEMM
L0
HBM
packed nibbles
L1
SMEM
cp.async / TMA
L2
register unpack
shift & mask
L3
scale 곱하기
group-wise
L4
MMA (FP16)
tensor core
중요한 건 L2 와 L3 가 register 위에서 끝나야 한다는 점. SMEM 에 dequant 결과를 다시 적으면 SMEM 트래픽이 두 배가 된다 — 그 시점에서 low-bit 의 메모리 이득이 깨진다.

그림은 단순하지만 실제로는 두 가지 결정이 묶여 있다. 첫째는 packing layout — 4-bit 두 개를 한 byte 에 넣는 방식이 cuBLAS/Marlin/Triton 모두 다르다. 둘째는 scale 의 위치 — 매 group 마다 한 번 scale 을 곱하는데, 이걸 K 축의 어느 시점에 끼워넣느냐가 register 압박을 결정한다.

두 라이브러리의 packing 차이가 강의 본문의 골치 거리

Hicham 이 강의에서 명시적으로 말한 한 줄 — “Marlin 의 4-bit packing 과 Triton 의 packing 이 다르다. 같은 weight 라도 둘은 호환되지 않고, repack 코드를 따로 짜야 한다.” CUTLASS 가 정한 사실상의 표준이 있고, Triton 은 자기만의 layout 을 쓴다. fast Triton kernelfast CUDA kernel 을 한 모델에 같이 쓰려면 weight 를 두 번 packing 해서 디스크에 저장해야 한다.

“Triton 안에서 한 컴포넌트만 customize 하고 싶다면 packing 단계를 손대야 한다. 코드 한 줄이 아니라 weight 디스크 형식 전체 가 따라 움직인다.”Hicham Badri
§ 04group-wise vs per-channel scale· 정확도 · packing

group_size = 64 / 128 가 사실상 표준이 된 이유

4-bit 만으로 정확도를 유지하려면 한 weight matrix 에 scale 하나 로는 부족하다. K 축 방향으로 일정 group 단위로 scale 을 따로 가지면 정확도가 살아난다. group_size = 128 이 GPTQ / AWQ / HQQ 모두에서 사실상 표준.

Per-tensor scale
전체에 scale 하나
메모리 추가 비용 0. 4-bit 정확도 손실이 큼. weight 분포가 outlier 많으면 quantization range 가 망가진다.
Per-channel scale
N 차원마다 scale
N 개 scale (~수천 개 fp16). PTQ 의 baseline. 정확도 일부 회복하지만 K 방향 outlier 는 여전.
Group-wise scale (g=128)
K 차원 group 마다 scale
K/g × N 개 scale. ~1MB 추가. INT4 의 perplexity 가 거의 FP16 수준까지 올라감 — 표준.

강의에서 Hicham 이 강조한 사실 — group-wise scale 은 정확도 의 문제만이 아니라 커널 구현 의 문제이기도 하다. K 방향으로 BLOCK_K 만큼 진행할 때마다 scale 을 다시 읽어와야 한다. BLOCK_K 와 group_size 가 정렬되지 않으면 한 iteration 안에서 scale 이 바뀌는 경우가 생기고 — 그러면 dequant 의 분기가 늘어나며 register 압박이 폭발한다.

실전 규칙

BLOCK_K = group_size 또는 BLOCK_K 가 group_size 의 배수가 되도록 launch 설정을 잡는다. 둘이 어긋나면 같은 코드라도 30~50% 느려진다. autotune config 에서도 이 제약을 잡아두는 게 첫 단계.

다음 그림은 16개 4-bit weight 가 한 byte 8 개에 packing 된 모습 — 그리고 그 위에 어떻게 group scale 이 묶이는지.

FIG · 4-bit packing 한 row (16개 weight = 8 byte)group_size=8 simplified
w0
w1
w2
w3
w4
w5
w6
w7
w8
w9
w10
w11
w12
w13
w14
w15
scale group 0 (FP16)
scale group 1 (FP16)
한 byte 안에 두 개의 4-bit weight (lo nibble · hi nibble). group_size = 8 의 예시이므로 두 group 이 행에 들어 있고, 각 group 별 scale 한 개씩이 별도 buffer 에 저장된다. 실제 표준은 g=128.
§ 05Triton tile 안 dequant· tl.shift, tl.bitwise_and

shift & mask 두 줄로 nibble 을 분리한다 — 그 자리가 register 압박의 시작

Triton 안에서 packing 된 byte 를 두 nibble 로 펴는 코드는 짧다. 어려운 건 그게 BLOCK_M × BLOCK_K tile 단위로 동시에 일어난다는 사실 — 그래서 unpack 결과 텐서가 register 위에서 두 배 크기로 늘어난다.

# gemlite 류 Triton 커널의 BLOCK_K iteration 안 — 핵심부
# a: FP16 activation tile [BLOCK_M, BLOCK_K]
# b_packed: INT8 weight tile [BLOCK_K // 2, BLOCK_N]
# scales: FP16 scale [num_groups, BLOCK_N]

# 1) packed byte 한 장 load — half size 만 읽는다
b_q = tl.load(b_ptr + offs_pk, mask=...)        # int8

# 2) two-nibble unpack — register 위에서 두 배로 펼친다
b_lo = (b_q & 0xF).to(tl.int8) - 8            # signed shift
b_hi = ((b_q >> 4) & 0xF).to(tl.int8) - 8

# 3) group scale broadcast 후 곱 — FP16 으로 복원
g  = (offs_k // GROUP_SIZE)
sc = tl.load(s_ptr + g * stride_sg + offs_n)
b_lo = b_lo.to(tl.float16) * sc
b_hi = b_hi.to(tl.float16) * sc

# 4) MMA — Triton 이 tensor core 로 dispatch
acc += tl.dot(a, tl.cat(b_lo, b_hi, axis=0))

이 짧은 시퀀스가 한 BLOCK_K iteration 안에서 매번 돈다. 강의에서 Hicham 이 짚은 압박 포인트는 세 개.

  • register footprint — unpack 결과 b_lo, b_hi 둘 다 register 에 살아 있는다. scales 까지 더하면 BLOCK_K, BLOCK_N 이 커질수록 register spill 위험.
  • dequant 와 MMA 의 overlap — Triton 은 자동으로 software pipelining (num_stages) 을 깐다. 다음 K-step 의 unpack 이 현재 step 의 MMA 와 겹쳐야 효율이 산다.
  • group boundary — group 이 BLOCK_K 안에서 바뀌면 scale broadcast 코드가 분기 처리되며 SASS 가 한 단계 더 복잡해진다 (§04 와 같은 얘기).
한 줄 더 짚을 것

Triton 은 tl.dot 의 두 피연산자에 대해 알아서 tensor core MMA 로 dispatch 한다. 사람이 wmma::* 를 직접 쓸 필요는 없다 — 다만 piece 가 FP16/BF16 으로 cast 되어야 dispatch 가 일어남. INT8 까지 내릴 거면 tl.dot(..., out_dtype=tl.int32) 형태로 명시.

§ 06BLOCK_K 와 register 트레이드오프· launch sweep

같은 코드가 BLOCK_K 한 줄로 두 배 빨라진다 — 그 자리에 autotune 이 산다

Triton 의 launch 설정이 성능을 결정한다는 §06(L001) 의 일반 원리가, low-bit 에서는 group_size 와의 정렬 이라는 추가 제약과 합쳐지며 더 좁아진다.

BLOCK_M × BLOCK_N × BLOCK_K · num_warps · num_stagesregister/threadoccupancy4090 · 4096×4096
128×128×32 · 4 · 2~9638%1.7 ms
128×128×64 · 4 · 3~12831%1.3 ms
128×128×128 · 4 · 3~16825%0.92 ms
128×128×128 · 8 · 4~23219% (spill)1.5 ms
256×64×128 · 4 · 4~19622%1.05 ms

위 표의 메시지 — BLOCK_K = group_size = 128 이 sweet spot. 거기서 더 키우면 register spill, 거기서 줄이면 group scale broadcast 비용이 매 iteration 마다 누적. num_stages 도 3 이 균형점이고 4 로 키우면 software pipeline 을 위한 SMEM 이 부족해진다.

autotune 패턴

강의의 gemlite 가 깐 autotune config 는 약 20개 — BLOCK 셋의 곱과 num_warps ∈ {2,4,8}, num_stages ∈ {2,3,4} 의 부분집합. 새 모델/새 GPU 에 처음 띄울 때 이 config 가 한 번 도는 데 수십 초가 걸리지만, 그 결과는 ~/.triton/cache 에 박힌다.

“Triton 의 성능은 코드의 길이가 아니라 BLOCK 설정과 group_size 의 정렬에 달렸다. 같은 코드, 다른 launch — 두 배 차이가 보통이다.”Hicham Badri · 강의 후반
§ 07정확도 검증과 numerical drift· FP16 reference

커널이 빠른 게 아니라 맞는 게 먼저다

low-bit GEMM 의 첫 검증은 reference matmul (FP16 cuBLAS) 와의 결과 차이. 단순한 max-abs-diff 로는 모자라다 — quantization 의 error 는 분포 형태로 본다.

Hicham 이 추천한 검증 시퀀스.

  1. round-trip 무손실 확인 — quantize → dequantize 가 group_size 만큼의 분해능 안에서 정확히 복원되는지. 코드 버그가 있으면 여기서부터 깨짐.
  2. per-output relative error(y_q - y_ref) / (|y_ref| + ε). mean 과 max 둘 다 본다. 4-bit + g=128 이면 mean ~1e-3, max ~5e-3 정도가 정상.
  3. downstream perplexity — toy GEMM 이 통과해도 모델 perplexity 가 깨질 수 있다. wikitext-103 같은 작은 set 으로 sanity.

특히 group_size 와 BLOCK_K 가 어긋난 버그는 random input 에는 안 잡히고 실제 모델 weight 에서만 깨진다 — outlier 가 group boundary 위에 정확히 떨어질 때만 보이는 패턴.

FIG · INT4 vs FP16 출력 분포 비교HQQ + g=128
mean rel.err~8e-4
p99 rel.err~3e-3
max rel.err~7e-3
PPL drift (wikitext)+0.05
정상 INT4 + g=128 의 error 분포. mean 이 작아도 max 가 큰 이유 는 outlier 채널 — 그래서 group-wise 가 per-tensor 보다 늘 낫다.
§ 08cuBLAS · Marlin 대비 성능· batch size 별 그래프

batch=1 에서는 cuBLAS 보다 빠르다 — Marlin 만큼은 아니지만 근접

강의의 마지막 절반은 같은 weight, 같은 GPU 에서 라이브러리별 latency 비교. Triton(gemlite) / Marlin / BitBLAS / cuBLAS FP16 / vLLM 의 GPTQ kernel 을 같은 매트릭스에 올린다.

library · kernelbatch=1batch=16customizable
cuBLAS FP163.6 ms3.7 msno
Marlin INT40.55 ms0.78 mslimited
BitBLAS INT40.62 ms0.85 msTVM 묶임
gemlite Triton INT40.78 ms1.02 msfull
vLLM GPTQ kernel0.95 ms1.3 mspartial

표가 보여주는 두 가지. 첫째 — Triton 이 cuBLAS 보다 5~6배 빠르다. weight 를 4-bit 로 줄였으니 당연하다. 둘째 — Marlin / BitBLAS 와 30% 정도 차이. 이게 “Triton 이 거의 따라잡았다” 의 의미. 강의 시점 기준으로, customize 가능성을 포기하지 않으면서 Marlin 의 80% 까지 와 있다.

Hopper TMA 활용 시 다음 단계

Triton 도 tl.experimental.descriptor_load 로 TMA 를 끌어 쓸 수 있다. Hicham 의 실험에서 GEMM 부분은 reference 와 동일한 성능이 나왔지만 — 빌드가 까다롭고 30분 정도 직접 컴파일이 필요. H100 위에서 Triton 의 격차가 더 줄어들 것이라는 게 강의의 전망.

§ 09INT4 / NF4 / FP4 비교· 표현 공간의 갈래

같은 4-bit 라도 “어느 16개 값을 쓰는가” 가 다르다

4-bit quantization 은 한 weight 당 16개의 가능한 값 을 정해두는 일이다. 그 16개를 어떻게 고르느냐 — uniform grid (INT4) / non-uniform (NF4) / FP-style (FP4) — 가 정확도와 커널 단순도를 동시에 결정한다.

type
grid
trade-off
INT4
uniform · 16 levels -8 .. +7. 일정 간격.
dequant 이 단순(shift & mask). 표준 weight 분포(가우시안)에 잘 안 맞음.
NF4
non-uniform · 정규분포 quantile 가운데 빽빽, 양 끝 sparse.
정확도 best (QLoRA 표준). dequant 가 LUT lookup 이라 SASS 비용 늘어남.
FP4 (E2M1)
floating-point · 2-bit exp + 1-bit mantissa
H100 의 hardware 지원. dequant 이 hardware path 면 거의 free. 정확도 INT4 와 비슷.

Hicham 의 입장 — 실용 라이브러리는 INT4 를 표준으로 두고 NF4 를 옵션으로 둔다. 이유는 dequant 비용. NF4 는 LUT 가 register 에 살아 있어야 해서 register 압박이 더 심하고, INT4 의 shift & mask 두 줄과 비교하면 SASS 길이가 두 배 가까워진다. 정확도 차이가 작은 모델에서는 INT4 가 균형 있다.

“NF4 는 정확도가 약간 더 높지만 dequant 가 무거워서 — 결국 같은 latency 예산 안에서는 INT4 + g=64 가 NF4 + g=128 보다 좋게 나오는 케이스가 많다.”Hicham Badri · Q&A
§ 10기억할 메모와 코드· HQQ · gemlite

다시 열었을 때 5분 안에 손에 잡혀야 할 것

강의에서 6개월 뒤에 돌아왔을 때 가장 빨리 복원해야 하는 사실들과 — 직접 손에 박아야 하는 코드 자료들.

weight-only 의 자리
decode latency 의 거의 100% 가 weight HBM read. 4-bit 로 줄이면 거의 4배 빨라진다.
group-wise scale (g=128)
K 차원 group 마다 scale 한 개. 정확도 회복 + packing 표준. BLOCK_K 와 정렬 필수.
dequant-then-MMA on register
unpack/scale 결과는 register 위에 두고 그대로 tl.dot. SMEM 다시 적으면 메모리 이득 깨진다.
packing 호환성 함정
Marlin packing ≠ Triton packing. 두 라이브러리를 같이 쓰려면 weight 두 번 packing.
shift & mask 두 줄
(b_q & 0xF) - 8, ((b_q >> 4) & 0xF) - 8 가 INT4 unpack 의 표준 형태.
launch sweet spot
BLOCK_K = group_size = 128, num_warps=4, num_stages=3 이 4090/A100 의 좋은 출발점.
정확도 검증 시퀀스
round-trip → per-output rel.err (mean·p99·max) → downstream PPL. random input 만으로는 못 잡는 버그가 있다.
표현 공간 선택
실용 표준은 INT4. NF4 는 dequant 비용으로 손해, FP4 는 H100 hardware 지원이 본격화되며 미래 선택지.
Code · HQQgithub.com/mobiusml/hqq — Hicham 의 quantization 라이브러리
Code · gemlitegithub.com/mobiusml/gemlite — Triton low-bit GEMM 모음

손에 새기기 — 실습 시퀀스

  1. 4-bit packing 함수 직접 짜기 — FP16 weight → group-wise quantize → 두 nibble 을 한 byte 에. 자기 packing 이 gemlite 와 호환되는지 round-trip 으로 검증.
  2. 최소 Triton GEMM 골격 — fp16 ⊗ fp16 GEMM 을 먼저 짜서 tl.dot 의 형태와 launch 설정에 익숙해진다. 이게 다음 단계의 baseline.
  3. INT4 weight load 추가 — 위 GEMM 의 weight 를 packed INT4 로 바꾸고 BLOCK_K 안에서 unpack 추가. 정확도 검증 (§07 의 시퀀스).
  4. group scale 추가 — group_size=128 로 정렬해서 BLOCK_K=128 fix. mean/max rel.err 가 1e-3 / 5e-3 수준으로 떨어지는지.
  5. autotune sweep — BLOCK_M, BLOCK_N ∈ {64, 128, 256}, num_warps ∈ {4, 8}, num_stages ∈ {2, 3, 4}. 자기 GPU 의 sweet spot 표를 만든다.
  6. cuBLAS FP16 baseline 비교 — batch=1, batch=16 두 점에서 latency 측정. 4배 이상 차이 나야 정상.
  7. NCU pass — Triton 커널을 NCU 로 떠본다. occupancy, register/thread, long scoreboard stall 의 세 metric 만 본다 (L001 의 사다리).
§ 11다른 강의로 이어지는 길· connections

같은 도구가 시리즈 안에서 어디에 다시 등장하는가

L034 의 “Triton 으로 cuBLAS 의 자리를 메운다” 는 작업이 다음 강의들에서 어떻게 이어지는지.

§ 12열린 질문· open questions

다음에 다시 들었을 때 직접 검증해야 할 것들

강의 안에서 부분적으로만 등장한 주제, 또는 후속 작업으로 비워둔 자리들.

검증 메모

이 노트의 latency 수치(0.55ms, 0.78ms 등)는 Hicham 이 강의에서 보여준 슬라이드를 재구성한 예시. 자기 GPU 와 자기 모델 weight 에 대해 직접 측정해야 의미 있는 baseline.

← Lecture 033 이전 강의로 Lecture 035 → SGLang — Yineng Zhang 이 깐 LLM 서빙의 control flow 와 RadixAttention