CUTLASS / cuBLAS 가 다 가져간 4-bit weight-only GEMM 의 자리에서, Triton 만으로 비슷한 성능을 짜내려면 어떤 트릭이 필요한가. HQQ 의 저자 Hicham Badri 가 BitBLAS / Marlin / vLLM 의 in-house 커널들을 한 줄로 비교하면서 — Triton tile 안에서 dequantize-then-multiply 패턴을 어떻게 펴는지, group-wise scale 이 왜 헤더로 따로 packing 되어야 하는지, 같은 코드가 BLOCK_K 한 줄로 두 배 빨라지는 이유까지 깐다.
LLM 추론에서 4-bit weight-only GEMM 이 표준 도구가 됐다. 그런데 빠른 커널은 거의 다 CUTLASS / cuBLAS / Marlin 처럼 깊은 C++ 위에 있다. Hicham 의 질문은 단순하다 — Triton 으로도 같은 성능을 짤 수 있을까. 짤 수 있다면, 어디까지 봐야 하는가.
이 강의는 그 질문에 대한 진행형 답이다. 강의 시점에 Hicham 이 직접 만들고 있는 라이브러리는 HQQ(Half-Quadratic Quantization) 와 그 위에 얹은 Triton 커널 모음 gemlite. “BitBLAS 는 TVM 에 묶여 있어 손대기 어렵고, Marlin 은 4-bit 만, Triton 은 짜기는 쉽지만 성능이 떨어진다” 는 표가 강의 첫 장이다 — 그 빈 자리를 Triton 으로 메운다.
Hicham 의 일관된 입장 — “Triton 의 성능은 코드의 길이가 아니라 tile 안에서 dequant 를 어떻게 펴는가 와 BLOCK 설정 에 달렸다.” 이 두 축으로 강의 전체가 정렬된다. Triton 이 자동으로 tensor core 까지 끌어주는 시점에서, 사람의 일은 register 안에서 unpack/scale 의 시퀀스를 줄이는 것으로 좁혀진다.
그래서 강의 끝에 손에 잡혀야 할 자산은 세 개. (1) dequantize-then-multiply 패턴이 Triton tile 안에서 어떻게 펼쳐지는지의 mental model, (2) group-wise scale 이 왜 packing 단계에서 이미 들어가야 하는지의 이유, (3) 새 GPU(Hopper) 의 TMA 를 Triton 에서 끌어 쓸 때 어디가 막히는지의 한계.
LLM 추론은 두 단계로 나뉜다 — prefill(프롬프트 처리, batch=long) 과 decode(토큰 한 개씩 생성, batch=1). decode 한 step 의 비용은 거의 100% weight 를 HBM 에서 읽는 시간이다. weight 를 4-bit 로 줄이면 같은 GPU 에서 4배 더 빠르게 읽힌다 — 이게 weight-only 의 출발점.
강의에서 Hicham 이 도식으로 보여준 비교 — 16-bit 와 4-bit 같은 layer 의 메모리 트래픽.
그래서 weight-only 가 “가장 싼 압축으로 가장 큰 latency 이득” 의 자리. activation 까지 같이 quantize 하는 W8A8 / W4A8 같은 방식은 prefill 이 dominant 한 throughput-bound 시나리오에서 다시 의미가 생긴다.
강의의 이 슬라이드가 끝나면, 다음 한 줄이 자연스럽게 따라온다 — “그러면 4-bit weight 를 어떻게 packing 하고, 어떻게 GEMM 안에서 16-bit 로 펴서 tensor core 에 먹이는가.” 그게 §03 의 본론.
low-bit GEMM 의 핵심 패턴 — 4-bit packed weight 를 HBM 에서 읽어 SMEM 에 올린 뒤, register 위에서 nibble 을 펴고 group scale 을 곱해서 16-bit tile 로 복원, 그걸 그대로 tensor core 에 넘긴다. 이 시퀀스가 BLOCK_K 단위로 반복된다.
그림은 단순하지만 실제로는 두 가지 결정이 묶여 있다. 첫째는 packing layout — 4-bit 두 개를 한 byte 에 넣는 방식이 cuBLAS/Marlin/Triton 모두 다르다. 둘째는 scale 의 위치 — 매 group 마다 한 번 scale 을 곱하는데, 이걸 K 축의 어느 시점에 끼워넣느냐가 register 압박을 결정한다.
Hicham 이 강의에서 명시적으로 말한 한 줄 — “Marlin 의 4-bit packing 과 Triton 의 packing 이 다르다. 같은 weight 라도 둘은 호환되지 않고, repack 코드를 따로 짜야 한다.” CUTLASS 가 정한 사실상의 표준이 있고, Triton 은 자기만의 layout 을 쓴다. fast Triton kernel 과 fast CUDA kernel 을 한 모델에 같이 쓰려면 weight 를 두 번 packing 해서 디스크에 저장해야 한다.
4-bit 만으로 정확도를 유지하려면 한 weight matrix 에 scale 하나 로는 부족하다. K 축 방향으로 일정 group 단위로 scale 을 따로 가지면 정확도가 살아난다. group_size = 128 이 GPTQ / AWQ / HQQ 모두에서 사실상 표준.
강의에서 Hicham 이 강조한 사실 — group-wise scale 은 정확도 의 문제만이 아니라 커널 구현 의 문제이기도 하다. K 방향으로 BLOCK_K 만큼 진행할 때마다 scale 을 다시 읽어와야 한다. BLOCK_K 와 group_size 가 정렬되지 않으면 한 iteration 안에서 scale 이 바뀌는 경우가 생기고 — 그러면 dequant 의 분기가 늘어나며 register 압박이 폭발한다.
BLOCK_K = group_size 또는 BLOCK_K 가 group_size 의 배수가 되도록 launch 설정을 잡는다. 둘이 어긋나면 같은 코드라도 30~50% 느려진다. autotune config 에서도 이 제약을 잡아두는 게 첫 단계.
다음 그림은 16개 4-bit weight 가 한 byte 8 개에 packing 된 모습 — 그리고 그 위에 어떻게 group scale 이 묶이는지.
Triton 안에서 packing 된 byte 를 두 nibble 로 펴는 코드는 짧다. 어려운 건 그게 BLOCK_M × BLOCK_K tile 단위로 동시에 일어난다는 사실 — 그래서 unpack 결과 텐서가 register 위에서 두 배 크기로 늘어난다.
# gemlite 류 Triton 커널의 BLOCK_K iteration 안 — 핵심부
# a: FP16 activation tile [BLOCK_M, BLOCK_K]
# b_packed: INT8 weight tile [BLOCK_K // 2, BLOCK_N]
# scales: FP16 scale [num_groups, BLOCK_N]
# 1) packed byte 한 장 load — half size 만 읽는다
b_q = tl.load(b_ptr + offs_pk, mask=...) # int8
# 2) two-nibble unpack — register 위에서 두 배로 펼친다
b_lo = (b_q & 0xF).to(tl.int8) - 8 # signed shift
b_hi = ((b_q >> 4) & 0xF).to(tl.int8) - 8
# 3) group scale broadcast 후 곱 — FP16 으로 복원
g = (offs_k // GROUP_SIZE)
sc = tl.load(s_ptr + g * stride_sg + offs_n)
b_lo = b_lo.to(tl.float16) * sc
b_hi = b_hi.to(tl.float16) * sc
# 4) MMA — Triton 이 tensor core 로 dispatch
acc += tl.dot(a, tl.cat(b_lo, b_hi, axis=0))
이 짧은 시퀀스가 한 BLOCK_K iteration 안에서 매번 돈다. 강의에서 Hicham 이 짚은 압박 포인트는 세 개.
b_lo, b_hi 둘 다 register 에 살아 있는다. scales 까지 더하면 BLOCK_K, BLOCK_N 이 커질수록 register spill 위험.num_stages) 을 깐다. 다음 K-step 의 unpack 이 현재 step 의 MMA 와 겹쳐야 효율이 산다.Triton 은 tl.dot 의 두 피연산자에 대해 알아서 tensor core MMA 로 dispatch 한다. 사람이 wmma::* 를 직접 쓸 필요는 없다 — 다만 piece 가 FP16/BF16 으로 cast 되어야 dispatch 가 일어남. INT8 까지 내릴 거면 tl.dot(..., out_dtype=tl.int32) 형태로 명시.
Triton 의 launch 설정이 성능을 결정한다는 §06(L001) 의 일반 원리가, low-bit 에서는 group_size 와의 정렬 이라는 추가 제약과 합쳐지며 더 좁아진다.
위 표의 메시지 — BLOCK_K = group_size = 128 이 sweet spot. 거기서 더 키우면 register spill, 거기서 줄이면 group scale broadcast 비용이 매 iteration 마다 누적. num_stages 도 3 이 균형점이고 4 로 키우면 software pipeline 을 위한 SMEM 이 부족해진다.
강의의 gemlite 가 깐 autotune config 는 약 20개 — BLOCK 셋의 곱과 num_warps ∈ {2,4,8}, num_stages ∈ {2,3,4} 의 부분집합. 새 모델/새 GPU 에 처음 띄울 때 이 config 가 한 번 도는 데 수십 초가 걸리지만, 그 결과는 ~/.triton/cache 에 박힌다.
low-bit GEMM 의 첫 검증은 reference matmul (FP16 cuBLAS) 와의 결과 차이. 단순한 max-abs-diff 로는 모자라다 — quantization 의 error 는 분포 형태로 본다.
Hicham 이 추천한 검증 시퀀스.
(y_q - y_ref) / (|y_ref| + ε). mean 과 max 둘 다 본다. 4-bit + g=128 이면 mean ~1e-3, max ~5e-3 정도가 정상.특히 group_size 와 BLOCK_K 가 어긋난 버그는 random input 에는 안 잡히고 실제 모델 weight 에서만 깨진다 — outlier 가 group boundary 위에 정확히 떨어질 때만 보이는 패턴.
강의의 마지막 절반은 같은 weight, 같은 GPU 에서 라이브러리별 latency 비교. Triton(gemlite) / Marlin / BitBLAS / cuBLAS FP16 / vLLM 의 GPTQ kernel 을 같은 매트릭스에 올린다.
표가 보여주는 두 가지. 첫째 — Triton 이 cuBLAS 보다 5~6배 빠르다. weight 를 4-bit 로 줄였으니 당연하다. 둘째 — Marlin / BitBLAS 와 30% 정도 차이. 이게 “Triton 이 거의 따라잡았다” 의 의미. 강의 시점 기준으로, customize 가능성을 포기하지 않으면서 Marlin 의 80% 까지 와 있다.
Triton 도 tl.experimental.descriptor_load 로 TMA 를 끌어 쓸 수 있다. Hicham 의 실험에서 GEMM 부분은 reference 와 동일한 성능이 나왔지만 — 빌드가 까다롭고 30분 정도 직접 컴파일이 필요. H100 위에서 Triton 의 격차가 더 줄어들 것이라는 게 강의의 전망.
4-bit quantization 은 한 weight 당 16개의 가능한 값 을 정해두는 일이다. 그 16개를 어떻게 고르느냐 — uniform grid (INT4) / non-uniform (NF4) / FP-style (FP4) — 가 정확도와 커널 단순도를 동시에 결정한다.
Hicham 의 입장 — 실용 라이브러리는 INT4 를 표준으로 두고 NF4 를 옵션으로 둔다. 이유는 dequant 비용. NF4 는 LUT 가 register 에 살아 있어야 해서 register 압박이 더 심하고, INT4 의 shift & mask 두 줄과 비교하면 SASS 길이가 두 배 가까워진다. 정확도 차이가 작은 모델에서는 INT4 가 균형 있다.
강의에서 6개월 뒤에 돌아왔을 때 가장 빨리 복원해야 하는 사실들과 — 직접 손에 박아야 하는 코드 자료들.
tl.dot. SMEM 다시 적으면 메모리 이득 깨진다.(b_q & 0xF) - 8, ((b_q >> 4) & 0xF) - 8 가 INT4 unpack 의 표준 형태.tl.dot 의 형태와 launch 설정에 익숙해진다. 이게 다음 단계의 baseline.L034 의 “Triton 으로 cuBLAS 의 자리를 메운다” 는 작업이 다음 강의들에서 어떻게 이어지는지.
tl.dot, BLOCK 설정의 일반론num_stages, software pipelining 의 IR 단계강의 안에서 부분적으로만 등장한 주제, 또는 후속 작업으로 비워둔 자리들.
tl.dot 의 FP8 path 와 비슷한 형태로 Triton 이 받아들일 가능성.torch.library.custom_op 로 wrap 하면 torch.compile 이 그 boundary 를 못 fuse 한다는 사실. Hicham 이 지나가는 식으로 언급. 정확한 회피 패턴은 후속 강의에서.이 노트의 latency 수치(0.55ms, 0.78ms 등)는 Hicham 이 강의에서 보여준 슬라이드를 재구성한 예시. 자기 GPU 와 자기 모델 weight 에 대해 직접 측정해야 의미 있는 baseline.