gpumode · 강의 아카이브
《GPU Mode》 L033 2024 · OCT · 26 High priority transcript · available

Bitblas — Mixed-Precision GEMM with the Ladder Compiler

INT4 / FP4 / FP8 / W4A16 같은 비대칭 mixed-precision GEMM 은 — Tensor Core 의 layout 요구와 layout 변환의 비용 때문에 손으로 짜기 매우 어렵다. Microsoft 의 Wang Lei 가 만든 Bitblas 는 그 자리를 자동 코드 생성으로 푸는 라이브러리 + Ladder 라는 endtoend 컴파일러. 그리고 새 DSL 인 TileLang (TI Language) 으로 Triton-like 작성도. 이 셋의 디자인 결정과 — 같은 GEMM 이 cuBLAS 보다 빠를 수 있는 이유의 학습 노트.

Bitblas Ladder compiler TileLang mixed precision GEMM W4A16 INT4 · INT8 · FP4 · FP8 layout transformation Tensor Core Microsoft Research
W
Speaker
Wang Lei
Microsoft Research · Bitblas / Ladder 저자
강의 번호
L033
스피커
Wang Lei
학습 우선순위
High · 정독
다시 볼 때
TileLang 로 GEMM 짜본다
§ 01강의가 풀려는 문제· low-bit GEMM 의 어려움

“INT4 weight 와 FP16 activation 을 같은 GEMM 안에 넣어라” — 손으로 짜면 거의 불가능한 자리

강의의 출발점 — 모던 LLM inference 는 weight 만 양자화하는 W4A16 패턴이 dominant. weight 는 INT4 (메모리 절감), activation 은 FP16 (정확도 보존). mma 명령은 같은 dtype input 을 요구하는데, 어떻게 둘을 곱하는가?

강의가 답하려는 두 줄 —

  1. INT4 weight 를 FP16 mma 의 input 으로 어떻게 박는가 — dequant 를 어디서 어떻게 할지의 디자인.
  2. 그 dequant + matmul 을 cuBLAS 보다 빠르게 어떻게 만드는가 — vendor 라이브러리가 cover 하지 않는 자리에서.

Wang Lei 의 출발점은 명시적이다. “같은 GEMM 이 — Tensor Core 가 요구하는 정확한 layout 을 못 맞추면 — RTX 3090 위에서 cuBLAS 의 28% perf 만 낸다”. layout 은 단순한 메모리 배치가 아니라 perf 의 결정자. Bitblas 의 모든 디자인이 이 한 사실에서 나온다.

강의의 frame

Bitblas 는 — TVM 의 후예, Ladder 컴파일러의 일부, 그리고 mixed-precision GEMM 라이브러리. “스케줄과 컴퓨트의 분리” 라는 TVM/Triton 의 디자인을 mixed-precision 으로 확장. 사용자는 “이 dtype 끼리의 GEMM” 을 정의하기만 하고, 컴파일러가 layout 변환부터 mma 매핑까지 자동.

“같은 GEMM 코드가 잘못된 layout 위에서는 cuBLAS 의 28%, 잘 짜인 layout 위에서는 cuBLAS 보다 빠릅니다 — layout 이 이렇게 결정적이에요.”Wang Lei · 14:32
§ 02low-bit GEMM 의 trade-off· memory vs Tensor Core

“weight 만 quantize 하면 — 절약은 메모리에서, perf 는 Tensor Core 가 결정”

강의의 작은 역사 — 2018 년대 양자화는 “모델을 더 작은 칩에 넣기” 위함. 2024 년대는 다르다. 모델이 너무 커서 — “GPU 의 cache 와 HBM 안에 들어가게” 하는 의 도구가 양자화. inference 의 throughput 도 메모리 bandwidth 가 dominant.

핵심 trade-off —

  • weight only 양자화 — weight INT4, activation FP16. 메모리 절감 4× (weight 만), 정확도는 거의 그대로. matmul 자체는 FP16 mma.
  • weight + activation 양자화 — 둘 다 INT4. 메모리 + Tensor Core 모두 절감. 정확도 손실이 큼. inference 에는 적합하지만 fine-tuning 은 어렵.
  • 이상적 경로 — weight INT4 + activation FP16, 그런데 mma 가 같은 dtype 만 받음. dequant 를 어디서 할지가 문제.
FIG · weight only 양자화의 메모리 분포 변화Llama-2 7B inference
weight FP16
14 GB
weight INT8
7 GB
weight INT4
3.5 GB
weight INT2
1.75 GB
weight 가 dominant 메모리. INT4 는 4× 절감 — A100 80GB 에 70B 모델이 들어가는 길.
memory-bound 인 자리에서의 win

LLM decode (KV cache 사용 중 next-token 생성) 의 핵심은 — weight 를 HBM 에서 SM 으로 한 번 fetch 한다. weight 가 작을수록 fetch 가 빠르고, throughput 이 비례적으로 증가. INT4 weight 는 단순히 메모리 절감이 아니라 throughput 의 4× 가 되는 자리.

§ 03자동 코드 생성 파이프라인· Ladder 의 흐름

Bitblas 의 본체 — 컴파일러가 layout · pipeline · tile size 를 모두 결정한다

Bitblas 는 라이브러리지만, 그 안에서 도는 건 Ladder 컴파일러. TVM 의 후예로 — “스케줄과 컴퓨트의 분리” + “low-bit 형식의 자동 layout 변환” 이 더해진 endtoend 컴파일러.

FIG · Bitblas/Ladder 의 코드 생성 단계compute 정의 → 자동 tuning → CUDA 코드
L0 · COMPUTE 사용자 정의matmul(int4, fp16) → fp16 의 추상적 식만 박는다. 어떻게 실행할지는 명시 안 함. ~5 줄 Python
L1 · LAYOUT layout transformation passweight 의 INT4 layout 을 Tensor Core 가 요구하는 모양(swizzle, fragment shape)으로 자동 변환. 이게 Bitblas 의 핵심 contribution. 자동
L2 · SCHEDULE tile + pipeline + warp specializationBLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps 의 방대한 search space. tuning 으로 자동. ~수십 candidate
L3 · TILELANG / TIR 중간 IRTVM 의 TIR(Tensor IR) 또는 TileLang. 이 시점에 layout 이 모두 박혀 있음. 중간 자료
L4 · CUDA / PTX 최종 CUDA 코드nvcc 가 받아서 PTX → SASS. 같은 mma 명령을 쓰지만 layout 이 정확. cuBLAS 동급 또는 빠름
사용자 시점은 L0 만. L1, L2 가 Bitblas 의 진짜 가치 — layout 변환을 자동화하고, search space 를 자동 tuning. 사람이 손으로 짠다면 L1 만 며칠.
# Bitblas 사용 — 사용자 시점
import bitblas
from bitblas import Matmul, MatmulConfig

# 1. 추상적 정의 — input/output dtype 만
config = MatmulConfig(
    M=4096, N=4096, K=4096,
    A_dtype="float16",         # activation
    W_dtype="int4",            # weight
    accum_dtype="float16",
    out_dtype="float16",
    layout="nt",
    with_scaling=True,
    group_size=128,
)

# 2. Bitblas 가 자동 — layout 변환 + tuning + 코드 생성
matmul = Matmul(config=config)
matmul.tune()                    # 처음 호출 시 ~수십 초

# 3. 호출 — cuBLAS 같은 인터페이스
out = matmul(activation, weight_int4_packed, scale)
§ 04Tensor Core 위 매핑· layout 변환의 자리

“같은 데이터인데 layout 이 다르면 mma 가 안 받는다” — Bitblas 의 핵심 문제

Tensor Core 의 mma 명령은 — input fragment 가 정확한 layout 으로 들어와야. m16n8k16, m16n8k32 같은 명령마다 register 안 데이터 배치가 박혀 있음. 잘못된 layout 이면 shared memory 를 거쳐 layout conversion 이 일어나고 — 그게 perf 의 절반 이상을 먹는다.

FIG · mma 입력의 fragment layoutm16n8k16 fp16 mma · NVIDIA Ampere
A0
A1
A2
A3
A4
A5
A6
A7
A8
A9
Aa
Ab
Ac
Ad
Ae
Af

↑ 한 warp 의 32 thread 가 16×16 의 A 행렬을 어떻게 분담하는가 — thread t 는 정확한 (row, col) 의 8 element 만 들고 있어야.

shared memory 의 swizzle 패턴도 이 fragment 를 효율적으로 채울 수 있어야 — bank conflict 0. “잘못된 swizzle ⇒ bank conflict ⇒ 한 자리수 perf 손실”.

Bitblas 가 자동으로 푸는 layout 의 종류 —

  • weight pre-permutation — 디스크에 저장된 INT4 weight 를 fragment 모양으로 재배치. inference 시작 전에 한 번.
  • shared memory swizzle — XOR-based 패턴. fragment fill 시 bank conflict 0.
  • register fragment — mma 명령이 정확히 받는 register 매핑.
  • scale layout — group-wise scale 의 배치도 mma 와 함께 흐르도록.
왜 손으로 짜기 어려운가

(a) mma 명령의 종류가 많다 — Ampere/Hopper/Blackwell 별로 다름. (b) 각 명령의 fragment 가 다 다르다. (c) INT4 의 packed layout 이 추가로 — 한 byte 에 두 weight 가 박혀 있어 한 단계 더의 unpack. (d) group-wise scale 의 alignment 까지. 한 (M, N, K, dtype, group_size) 조합마다 layout 이 달라진다 — 그래서 자동 코드 생성이 필요.

§ 05INT4 vs INT8 vs FP4· format 별 trade-off

같은 “4 비트” 인데 다른 길 — 정확도와 hardware support 의 분기

INT8 대칭 양자화 + scale. activation 도 가능. Ampere 부터 mma int8 native. 가장 안정. 2× memory
INT4 (W4A16) weight 만 INT4 + scale, activation FP16. dequant inside-kernel + FP16 mma. 가장 흔한 LLM inference 자리. 4× memory
INT2 / INT3 실험적. 큰 정확도 손실 — outlier-aware scaling 필요. BitNet 의 INT2 가 한 사례. 8×–16×
FP4 (NF4 · MXFP4) 정규분포 가정 codebook. Blackwell 부터 hardware mma. INT4 보다 정확도 좋음. 4× memory
FP8 (e4m3 / e5m2) Hopper 부터 mma native. 학습/추론 모두. weight + activation 둘 다 가능. 2× memory
mixed (W2A8 등) 실험적. weight 매우 작게, activation 적당히. dequant 가 dominant. 조합형

Bitblas 의 매트릭스 — 모든 (W, A) 조합을 cover 함. W4A16, W2A16, W4A8, W4A4, W8A8, FP4 와 FP8 도. 새 조합이 추가될 때마다 layout 변환 pass 만 새로 — 사용자 코드 그대로.

왜 INT4 가 dominant 이 됐는가

(1) 메모리 4× — Llama 70B 가 단일 80GB GPU 에 들어감. (2) 정확도 손실이 작음 — BF16 baseline 대비 perplexity diff < 0.1. (3) Ampere 의 mma 가 INT4 input 을 받음 — Ampere/Hopper 표준. “INT4 가 LLM inference 의 사실상 표준”.

각 format 의 hardware mma 지원 —

  • Ampere (A100, A6000, RTX 30) — fp16, bf16, int8, int4 mma.
  • Ada Lovelace (RTX 40) — Ampere + fp8 (e4m3, e5m2).
  • Hopper (H100, H200) — fp8 wgmma, TMA, async barrier.
  • Blackwell (B100, B200) — fp8 + fp4 mma. INT4 mma 는 deprecated 추세.

Bitblas 는 hardware 별 mma 명령 매핑까지 자동. 같은 GEMM 정의가 Ampere 위 INT4 mma, Hopper 위 fp8 mma 로 lower 된다.

§ 06TileLang — 새 DSL· Triton 의 후속

Triton-like syntax + Tensor Core 직접 통제 — 두 시점이 합쳐진 새 언어

강의 후반부의 surprise — Wang Lei 가 만든 새 DSL TileLang (TI Language). Triton 의 사용자 친화적 syntax 와 — TVM 의 schedule 분리, CUTLASS 의 Tensor Core 직접 통제 — 를 합친 시도.

# TileLang 의 GEMM — Triton 과 닮은 모습
import tilelang.language as T

@T.prim_func
def gemm_kernel(
    A: T.Buffer((M, K), "float16"),
    B: T.Buffer((K, N), "int4"),
    C: T.Buffer((M, N), "float16"),
):
    with T.Kernel(M//128, N//128,
                  threads=128) as (bx, by):
        A_s = T.alloc_shared((128, 32), "float16")
        B_s = T.alloc_shared((32, 128), "int4")
        C_l = T.alloc_local((128, 128), "float16")

        T.clear(C_l)
        for k in T.Pipelined(K//32, num_stages=3):
            T.copy(A[bx*128, k*32], A_s)
            T.copy(B[k*32, by*128], B_s)
            T.gemm(A_s, B_s, C_l, transpose_B=True)
        T.copy(C_l, C[bx*128, by*128])

TileLang 의 디자인 결정 4 가지 —

  • Triton-like programming model. tile/block 단위 op. shared memory, local memory 명시적.
  • T.Pipelined — software pipeline 을 사용자가 직접 켠다. Triton 의 num_stages 와 같은 자리.
  • T.gemm — Tensor Core mma 의 high-level 추상. dtype 매칭은 자동.
  • schedule 의 분리 — 같은 compute 가 다른 hardware 에서 다른 schedule 로 lower (TVM 의 유산).
Triton 과의 비교

Triton = 가장 사용자 친화. layout/pipeline 결정이 컴파일러 안. 디버깅 어려움.
TileLang = 더 explicit. T.alloc_shared, T.Pipelined 같은 hint 가 명시적. 고급 사용자가 직접 통제.
CUTLASS = 가장 explicit. fragment, tile, mma 모두 사용자 결정. C++ 보일러플레이트 큼.
TileLang 은 Triton 과 CUTLASS 의 사이.

§ 07cuBLAS · CUTLASS 와 비교· vendor 라이브러리 vs 자동

“cuBLAS 가 빠르다” — 어떤 자리에서? 그리고 어디서 Bitblas 가 더 빠른가

cuBLAS
FP16/BF16 GEMM최강
INT8 GEMM강함
INT4 weight지원 약함
FP4 / mixed없음
커스텀 layoutX
새 모델 모양몇 달 후
CUTLASS
FP16/BF16 GEMM강함
INT8 GEMM강함
INT4 weight가능 (Marlin)
FP4 / mixed일부
커스텀 layoutO (template)
학습 곡선매우 길다
Bitblas / Ladder
FP16/BF16 GEMMcuBLAS 동급
INT8 GEMMcuBLAS 동급
INT4 weight최강
FP4 / mixed모든 조합
커스텀 layout자동
새 조합config 만 추가
자리별 강점 정리

(1) 표준 dense GEMM (FP16, BF16, INT8) 은 cuBLAS 가 가장 안전. NVIDIA 가 끊임없이 튜닝.
(2) 커스텀 INT4 weight + scale + group size 같은 조합은 Bitblas 가 dominant. cuBLAS 가 cover 안 하는 자리.
(3) fully custom GEMM (특수 epilogue, sparse 등) 은 CUTLASS. 단 학습 곡선과 boilerplate 가 큼.
(4) Bitblas 가 잡는 자리는 — “LLM inference 의 W4A16, W4A8, FP4A16 같은 조합”. cuBLAS 가 비어 있고 사용자 빈도가 폭발하는 자리.

“cuBLAS 가 모든 GEMM 을 cover 한다는 건 신화입니다. 새로 등장하는 양자화 조합마다 vendor 가 따라잡는 데 몇 달이 걸려요. 그 사이를 컴파일러로 메우는 게 우리의 자리입니다.”Wang Lei · 1:02:18
§ 08실측 결과· RTX 3090 · A100 위

“Bitblas 가 같은 GEMM 을 cuBLAS 보다 빠르게 — 어디까지?”

FIG · GEMM throughput — RTX 30904096×4096×4096 · TFLOPS
cuBLAS FP16
110 TF
Bitblas FP16
108 TF
cuBLAS INT8
220 TF
Bitblas W4A16
~4 TB/s effective
Bitblas W2A16
+30% vs W4A16
FP16 표준 GEMM 은 cuBLAS 와 동급. INT4 / INT2 같은 mixed-precision 자리는 Bitblas 가 dominant — cuBLAS 가 cover 하지 않으니까.

강의에서 인용된 측정 —

  • FP16 dense GEMM — Bitblas 가 cuBLAS 의 ~98%. 동급 안 자리.
  • INT4 W4A16 — Bitblas 가 Marlin (CUTLASS-based) 동급 또는 약간 빠름. vendor 가 cover 안 하는 자리에서의 SOTA.
  • 특수 layout (group_size=64, 128, 256) — cuBLAS 가 받지 않음. Bitblas 가 자동 코드 생성으로 조합별 SOTA.
  • 긴 K (K=11008 같이 비표준) — Bitblas 의 자동 tuning 이 vendor 라이브러리보다 잘 맞춤.
FIG · LLM decode latency — Llama-2 7B토큰당 ms · A100
FP16 (cuBLAS)
22 ms
INT8 (Bitblas)
13 ms
INT4 (Bitblas)
8 ms
INT4 (Marlin)
9 ms
decode 의 throughput 이 거의 정확히 메모리 절감 비율로 따라옴. weight bandwidth 가 dominant.
§ 09채택 사례· vLLM · BitNet · Marlin

이미 어디서 쓰이고 있는가

vLLM 의 옵션
vLLM 의 quantization backend 중 한 옵션. AWQ, GPTQ 등 weight only 양자화 모델의 inference 에 사용.
BitNet — 1-bit LLM
Microsoft 의 BitNet 가 INT1 / INT2 weight 를 사용. Bitblas 의 가장 극단 사례 — weight 가 거의 1비트. 메모리 절감의 끝.
Marlin / CUTLASS 와 공존
Marlin (NVIDIA, CUTLASS-based) 와 같은 자리. Bitblas 는 더 다양한 조합 cover 가 강점, Marlin 은 정해진 패턴에서 정점.
PyTorch torchao
torchao 의 backend 중 하나로 통합 진행 중. quantize_ 의 한 옵션.
GPTQ / AWQ 양자화
두 표준 W4A16 양자화의 weight 를 받는 inference engine 중 가장 빠른 자리. AutoGPTQ, AutoAWQ 와 함께.
학술 사용
low-bit research 의 표준 평가 도구. 새 양자화 방법의 perf 측정 baseline.
생태계의 위치

Bitblas 는 “low-bit 양자화의 표준 GEMM kernel 라이브러리” 의 자리를 잡아가는 중. cuBLAS 가 dense GEMM 의 자리이듯, Bitblas 가 mixed-precision GEMM 의 자리. “양자화 모델을 만드는 도구(GPTQ, AWQ) 와 양자화 모델을 돌리는 엔진(vLLM, sglang) 사이의 layer”.

§ 10기억할 메모와 코드· key takeaways

다시 열었을 때 5분 안에 손에 잡혀야 할 것

low-bit GEMM 의 어려움
mma 의 layout 요구 + INT4 packed 의 unpack + group-wise scale alignment + hardware 별 명령 차이. 손으로 짜기 어려운 자리.
Ladder 컴파일러
사용자 시점 = compute 정의만. 컴파일러가 layout 변환 + schedule + tile size + pipeline 자동 결정.
layout transformation pass
Bitblas 의 핵심 contribution. weight 를 fragment 모양으로 자동 재배치. shared memory swizzle 도 자동.
format 별 자리
INT8 (안정), INT4 W4A16 (LLM dominant), FP4 (Blackwell 부터), FP8 (Hopper 부터). Bitblas 가 모두 cover.
TileLang
Triton-like syntax + Tensor Core 직접 통제 + TVM schedule 분리. Triton 과 CUTLASS 사이의 새 DSL.
cuBLAS 와의 자리
표준 dense GEMM 은 cuBLAS 동급. mixed-precision · 비표준 layout 은 Bitblas 가 dominant.
실측 win
FP16 GEMM cuBLAS 의 98%. INT4 W4A16 은 Marlin 동급/약간 빠름. 다양한 (W,A) 조합에서 SOTA.
채택 위치
vLLM/sglang 의 backend, BitNet 의 INT1, AutoGPTQ/AWQ 의 inference, torchao 의 옵션. low-bit 표준 layer.

손에 새기기 — 실습 시퀀스

  1. Bitblas 설치 + 첫 GEMMpip install bitblas. MatmulConfig 으로 W4A16 정의. tune() 후 cuBLAS 와 perf 비교.
  2. 같은 GEMM 의 다른 group_size — 32, 64, 128, 256 으로 tune. 어떤 size 가 자기 GPU 에서 최적인지 직접 측정.
  3. ~/.bitblas 의 cache 보기 — Bitblas 도 specialization cache. 각 (M,N,K,dtype,group_size) 의 생성된 CUDA 코드를 직접 열어본다.
  4. TileLang 첫 커널 — 같은 GEMM 을 TileLang 으로. T.Pipelined, T.gemm 의 사용. Triton 으로 짜는 것과 비교.
  5. Llama 7B 양자화 모델 inference — AutoGPTQ 또는 AWQ 로 양자화 + Bitblas backend. cuBLAS 기반 baseline 과 토큰/초 비교.
  6. 새 dtype 조합 시도 — W2A16 (INT2 weight). 정확도 손실과 perf win 측정.
  7. FP4 시도 (Blackwell 사용 가능 시) — Blackwell hardware 위 FP4 mma 의 perf 확인.
  8. vLLM 안 Bitblas backend — vLLM 으로 LLM serving + Bitblas. throughput 측정.
§ 11다른 강의로 이어지는 길· connections

이 강의의 도구가 시리즈 안에 어떻게 다시 등장하는지

§ 12열린 질문· open questions

다음에 다시 들었을 때 직접 검증해야 할 것들

검증 메모

이 노트의 perf 수치(cuBLAS 의 98%, Marlin 동급 등)는 강의 시점(2024 10월) 의 측정. Bitblas / TileLang 모두 빠르게 발전 중이고, 새 hardware (Blackwell) 의 mma 명령도 추가되는 중. GitHub 의 README benchmark 를 다시 확인할 것.

← Lecture 032 Unsloth — Daniel Han 의 fine-tuning 가속 Lecture 034 → Low Bit Triton Kernels — Hicham Badri