INT4 / FP4 / FP8 / W4A16 같은 비대칭 mixed-precision GEMM 은 — Tensor Core 의 layout 요구와 layout 변환의 비용 때문에 손으로 짜기 매우 어렵다. Microsoft 의 Wang Lei 가 만든 Bitblas 는 그 자리를 자동 코드 생성으로 푸는 라이브러리 + Ladder 라는 endtoend 컴파일러. 그리고 새 DSL 인 TileLang (TI Language) 으로 Triton-like 작성도. 이 셋의 디자인 결정과 — 같은 GEMM 이 cuBLAS 보다 빠를 수 있는 이유의 학습 노트.
강의의 출발점 — 모던 LLM inference 는 weight 만 양자화하는 W4A16 패턴이 dominant. weight 는 INT4 (메모리 절감), activation 은 FP16 (정확도 보존). mma 명령은 같은 dtype input 을 요구하는데, 어떻게 둘을 곱하는가?
강의가 답하려는 두 줄 —
Wang Lei 의 출발점은 명시적이다. “같은 GEMM 이 — Tensor Core 가 요구하는 정확한 layout 을 못 맞추면 — RTX 3090 위에서 cuBLAS 의 28% perf 만 낸다”. layout 은 단순한 메모리 배치가 아니라 perf 의 결정자. Bitblas 의 모든 디자인이 이 한 사실에서 나온다.
Bitblas 는 — TVM 의 후예, Ladder 컴파일러의 일부, 그리고 mixed-precision GEMM 라이브러리. “스케줄과 컴퓨트의 분리” 라는 TVM/Triton 의 디자인을 mixed-precision 으로 확장. 사용자는 “이 dtype 끼리의 GEMM” 을 정의하기만 하고, 컴파일러가 layout 변환부터 mma 매핑까지 자동.
강의의 작은 역사 — 2018 년대 양자화는 “모델을 더 작은 칩에 넣기” 위함. 2024 년대는 다르다. 모델이 너무 커서 — “GPU 의 cache 와 HBM 안에 들어가게” 하는 의 도구가 양자화. inference 의 throughput 도 메모리 bandwidth 가 dominant.
핵심 trade-off —
LLM decode (KV cache 사용 중 next-token 생성) 의 핵심은 — weight 를 HBM 에서 SM 으로 한 번 fetch 한다. weight 가 작을수록 fetch 가 빠르고, throughput 이 비례적으로 증가. INT4 weight 는 단순히 메모리 절감이 아니라 throughput 의 4× 가 되는 자리.
Bitblas 는 라이브러리지만, 그 안에서 도는 건 Ladder 컴파일러. TVM 의 후예로 — “스케줄과 컴퓨트의 분리” + “low-bit 형식의 자동 layout 변환” 이 더해진 endtoend 컴파일러.
matmul(int4, fp16) → fp16 의 추상적 식만 박는다. 어떻게 실행할지는 명시 안 함.
~5 줄 Python
# Bitblas 사용 — 사용자 시점
import bitblas
from bitblas import Matmul, MatmulConfig
# 1. 추상적 정의 — input/output dtype 만
config = MatmulConfig(
M=4096, N=4096, K=4096,
A_dtype="float16", # activation
W_dtype="int4", # weight
accum_dtype="float16",
out_dtype="float16",
layout="nt",
with_scaling=True,
group_size=128,
)
# 2. Bitblas 가 자동 — layout 변환 + tuning + 코드 생성
matmul = Matmul(config=config)
matmul.tune() # 처음 호출 시 ~수십 초
# 3. 호출 — cuBLAS 같은 인터페이스
out = matmul(activation, weight_int4_packed, scale)
Tensor Core 의 mma 명령은 — input fragment 가 정확한 layout 으로 들어와야. m16n8k16, m16n8k32 같은 명령마다 register 안 데이터 배치가 박혀 있음. 잘못된 layout 이면 shared memory 를 거쳐 layout conversion 이 일어나고 — 그게 perf 의 절반 이상을 먹는다.
↑ 한 warp 의 32 thread 가 16×16 의 A 행렬을 어떻게 분담하는가 — thread t 는 정확한 (row, col) 의 8 element 만 들고 있어야.
Bitblas 가 자동으로 푸는 layout 의 종류 —
(a) mma 명령의 종류가 많다 — Ampere/Hopper/Blackwell 별로 다름. (b) 각 명령의 fragment 가 다 다르다. (c) INT4 의 packed layout 이 추가로 — 한 byte 에 두 weight 가 박혀 있어 한 단계 더의 unpack. (d) group-wise scale 의 alignment 까지. 한 (M, N, K, dtype, group_size) 조합마다 layout 이 달라진다 — 그래서 자동 코드 생성이 필요.
Bitblas 의 매트릭스 — 모든 (W, A) 조합을 cover 함. W4A16, W2A16, W4A8, W4A4, W8A8, FP4 와 FP8 도. 새 조합이 추가될 때마다 layout 변환 pass 만 새로 — 사용자 코드 그대로.
(1) 메모리 4× — Llama 70B 가 단일 80GB GPU 에 들어감. (2) 정확도 손실이 작음 — BF16 baseline 대비 perplexity diff < 0.1. (3) Ampere 의 mma 가 INT4 input 을 받음 — Ampere/Hopper 표준. “INT4 가 LLM inference 의 사실상 표준”.
각 format 의 hardware mma 지원 —
Bitblas 는 hardware 별 mma 명령 매핑까지 자동. 같은 GEMM 정의가 Ampere 위 INT4 mma, Hopper 위 fp8 mma 로 lower 된다.
강의 후반부의 surprise — Wang Lei 가 만든 새 DSL TileLang (TI Language). Triton 의 사용자 친화적 syntax 와 — TVM 의 schedule 분리, CUTLASS 의 Tensor Core 직접 통제 — 를 합친 시도.
# TileLang 의 GEMM — Triton 과 닮은 모습
import tilelang.language as T
@T.prim_func
def gemm_kernel(
A: T.Buffer((M, K), "float16"),
B: T.Buffer((K, N), "int4"),
C: T.Buffer((M, N), "float16"),
):
with T.Kernel(M//128, N//128,
threads=128) as (bx, by):
A_s = T.alloc_shared((128, 32), "float16")
B_s = T.alloc_shared((32, 128), "int4")
C_l = T.alloc_local((128, 128), "float16")
T.clear(C_l)
for k in T.Pipelined(K//32, num_stages=3):
T.copy(A[bx*128, k*32], A_s)
T.copy(B[k*32, by*128], B_s)
T.gemm(A_s, B_s, C_l, transpose_B=True)
T.copy(C_l, C[bx*128, by*128])
TileLang 의 디자인 결정 4 가지 —
Triton = 가장 사용자 친화. layout/pipeline 결정이 컴파일러 안. 디버깅 어려움.
TileLang = 더 explicit. T.alloc_shared, T.Pipelined 같은 hint 가 명시적. 고급 사용자가 직접 통제.
CUTLASS = 가장 explicit. fragment, tile, mma 모두 사용자 결정. C++ 보일러플레이트 큼.
TileLang 은 Triton 과 CUTLASS 의 사이.
(1) 표준 dense GEMM (FP16, BF16, INT8) 은 cuBLAS 가 가장 안전. NVIDIA 가 끊임없이 튜닝.
(2) 커스텀 INT4 weight + scale + group size 같은 조합은 Bitblas 가 dominant. cuBLAS 가 cover 안 하는 자리.
(3) fully custom GEMM (특수 epilogue, sparse 등) 은 CUTLASS. 단 학습 곡선과 boilerplate 가 큼.
(4) Bitblas 가 잡는 자리는 — “LLM inference 의 W4A16, W4A8, FP4A16 같은 조합”. cuBLAS 가 비어 있고 사용자 빈도가 폭발하는 자리.
강의에서 인용된 측정 —
quantize_ 의 한 옵션.Bitblas 는 “low-bit 양자화의 표준 GEMM kernel 라이브러리” 의 자리를 잡아가는 중. cuBLAS 가 dense GEMM 의 자리이듯, Bitblas 가 mixed-precision GEMM 의 자리. “양자화 모델을 만드는 도구(GPTQ, AWQ) 와 양자화 모델을 돌리는 엔진(vLLM, sglang) 사이의 layer”.
pip install bitblas. MatmulConfig 으로 W4A16 정의. tune() 후 cuBLAS 와 perf 비교.이 노트의 perf 수치(cuBLAS 의 98%, Marlin 동급 등)는 강의 시점(2024 10월) 의 측정. Bitblas / TileLang 모두 빠르게 발전 중이고, 새 hardware (Blackwell) 의 mma 명령도 추가되는 중. GitHub 의 README benchmark 를 다시 확인할 것.