gpumode · 강의 아카이브
《GPU Mode》 L030 2024 · OCT · 06 High priority transcript · available

Quantized Training

Inference 양자화는 비교적 쉬운 길이지만 — 학습 시 양자화는 forward 의 정확도와 backward 의 gradient flow 를 동시에 깨뜨리지 않으면서 메모리/속도를 따 내야 하는 어려운 문제다. Thien Tran 이 깐 master weights, stochastic rounding, low-bit optimizer state 의 디자인 결정과 — Triton 위에서 dequantize/quantize 를 shared memory 안에서 끝내는 패턴의 학습 노트.

quantized training FP8 INT8 stochastic rounding master weights 8-bit optimizer e4m3 / e5m2 torchao delayed scaling
T
Speaker
Thien Tran
torchao contributor · gau-nernst (GitHub)
강의 번호
L030
스피커
Thien Tran
학습 우선순위
High · 정독
다시 볼 때
8-bit optimizer 직접 짠다
§ 01강의가 풀려는 문제· 왜 학습 양자화가 어려운가

“양자화 학습” 은 두 개의 다른 문제를 동시에 푼다 — 메모리와 throughput

Inference 양자화는 결과가 정해진 weight 를 줄이는 일이다. 학습 양자화는 — weight 가 매 step 바뀌고, gradient 가 weight 를 업데이트하고, optimizer state 가 그 사이를 잇는다. 정밀도를 깨면 어디서든 학습이 발산할 수 있다.

Thien 이 강의 첫 5 분에 깐 두 줄 —

  1. 학습 메모리의 어디를 어떻게 줄일 것인가 — weight, gradient, optimizer state, activation 의 4 자리. 자리마다 양자화의 길이 다르다.
  2. Tensor Core 의 low-bit compute(FP8/INT8 mma)를 어떻게 학습에 박을 것인가 — 단순 메모리 절감이 아니라 throughput 도 따 내려면 mma 입력이 low-bit 여야 한다.

강의의 frame 은 — “정밀도를 일률적으로 줄이지 말라. 자리마다 다른 정밀도를 쓰는 게 정답”. forward 의 mma input 은 FP8/INT8, accumulation 은 FP32, master weight 는 FP32, optimizer state 는 8-bit, gradient 는 BF16. 한 모델 안에 5–6 개 정밀도가 동시에 산다.

강의의 frame

이 강의의 모든 디자인 결정은 한 질문에서 나온다 — “이 자리에서 정밀도를 깎으면 학습이 발산하는가”. 발산하지 않는 자리는 깎고, 발산하는 자리는 fp32 를 유지한다. 그 분기점을 찾는 것이 핵심.

“Inference 양자화에서는 ‘이 weight 가 얼마나 정확한가’만 보면 됩니다. 학습에서는 — 그 weight 가 어떻게 움직이는가 도 봐야 해요. gradient 의 부호 한 비트가 잘못되면 발산합니다.”Thien Tran · 04:32
§ 02학습 시 메모리의 분포· weight · grad · opt state · act

“학습 메모리의 60% 가 weight 가 아니다” — 어디서 잡힐 수 있는지의 지도

강의의 첫 그림 — Llama-3 8B 학습 시 메모리 사용을 분해. weight 자체는 약 16 GB(BF16), 그런데 학습 step 의 peak 는 60–80 GB. 차이가 어디서 오는지 알아야 양자화의 표적이 잡힌다.

FIG · 8B BF16 학습 step 의 메모리 분포Adam optimizer · seq 4096 · batch 16
model weights (BF16)
16 GB
gradients (BF16)
16 GB
Adam state (FP32 m + v)
64 GB
master weights (FP32)
32 GB
activations (BF16)
~24 GB
Adam 의 m, v 가 각각 fp32 모델 크기와 같다. 두 개니까 weight 의 4 배. 여기에 master weight 가 또 한 벌. weight 의 8 배 + activation 이 학습 메모리의 본체.

각 자리의 양자화 가능성 —

  • weight — BF16 까지는 안전. FP8 weight 는 mma 와 직결. delicate.
  • gradient — BF16 표준. FP8 grad 는 delayed scaling 이 필요. 발산의 주요 원인.
  • optimizer state (m, v)가장 많이 차지. 가장 양자화 안전한 자리. 8-bit Adam 이 표준.
  • master weights — FP32 한 벌. 정확도가 직결되니 항상 fp32. 단 stochastic rounding 으로 유지하면 BF16 으로 줄일 수 있다.
  • activations — BF16 대부분. recompute 로 줄이는 게 우선, 양자화는 까다로움.
Thien 의 우선순위

가장 큰 win 은 optimizer state 의 8-bit 화. 64 GB 가 16 GB 로 — 전체의 30% 절감. accuracy 영향이 거의 없다 (state 의 정확도가 학습에 직결되지 않으니까). 그 다음이 master weight 의 BF16 화 + stochastic rounding. 그 다음이 mma 의 FP8 화. 이 순서가 risk/reward 의 sweet spot.

§ 03forward/backward 정밀도 분리· low-bit compute, high-bit accum

한 mma 안에 두 정밀도가 산다 — input 은 FP8, accumulator 는 FP32

Tensor Core 의 mma 명령 자체가 혼합 정밀도다. 입력은 FP16/BF16/FP8/INT8, 출력은 FP32. 학습 양자화의 핵심 메커니즘은 그 “입력만 줄인다” 는 점에 있다.

FIG · 정밀도의 사다리 — 같은 mma 안에서storage vs compute
FP32 1+8+23 / 32b
BF16 1+8+7 / 16b
FP16 1+5+10 / 16b
FP8 e4m3 1+4+3 / 8b
FP8 e5m2 1+5+2 / 8b
INT8 signed 8b
같은 mma 의 두 시간 단위: input 의 BF16/FP8/INT8 → 누적은 항상 fp32. 그래서 input precision 만 깎고 sum 의 정확도는 유지된다.
# forward: weight FP8 + activation FP8 → output FP32 (cast back)
def linear_fp8_forward(x_bf16, w_bf16):
    # amax 추적 + scale
    s_x = compute_scale(x_bf16)
    s_w = compute_scale(w_bf16)
    x_fp8 = (x_bf16 / s_x).to(fp8)
    w_fp8 = (w_bf16 / s_w).to(fp8)

    # Tensor Core mma — FP8 in, FP32 out
    out_fp32 = mma_fp8(x_fp8, w_fp8)

    # scale 복원 후 BF16 output
    return (out_fp32 * s_x * s_w).to(bf16)

이 패턴의 본질 —

  • compute 정밀도 ≠ storage 정밀도. mma 의 input 만 low-bit. accumulator 는 fp32 그대로.
  • scaling 이 핵심. FP8 의 dynamic range 가 좁으니 — 입력의 amax 를 보고 scale 한 다음 cast.
  • scale 은 metadata 로 따라다닌다. PyTorch 의 torch._scaled_mm 같은 API 가 (tensor, scale) 쌍을 받는다.
  • backward 도 같은 패턴. weight grad / activation grad 둘 다 FP8 입력 + fp32 누적.
delayed scaling 의 의미

scale 을 매 step 새로 계산하면 한 비트가 흔들려 학습이 발산할 수 있다. 그래서 NVIDIA Transformer Engine 등은 delayed scaling — 지난 N 개 step 의 amax 를 history 로 저장해 max 를 scale 로. 학습이 stable.

§ 04stochastic rounding· unbiased gradient 의 길

round-to-nearest 가 아니라 — “확률적으로 위 또는 아래로”

학습 양자화의 가장 우아한 trick 중 하나. 작은 gradient 가 weight update 시 round-to-nearest 으로 0 이 되는 문제를 — 확률적 반올림으로 해결.

문제 상황을 구체적으로 —

  • BF16 master weight: 0.5000 (mantissa 7 비트)
  • gradient × lr: 0.0001 (BF16 의 LSB 보다 작음)
  • round-to-nearest 결과: 0.5000 (gradient 가 사라짐)
  • 여러 step 누적: 0.0001 × 1000 = 0.1 인데 — 매 step 에서 사라지면 영원히 안 더해진다

stochastic rounding 은 — 0.5001 을 받았을 때, 0.5000 으로 99% 확률, 0.5001 의 다음 representable 값으로 1% 확률로 round. 기댓값이 정확하다 (unbiased). 1000 step 누적시 평균적으로 0.1 가 더해진다.

# stochastic rounding — Triton 안에서
@triton.jit
def sr_cast_fp32_to_bf16(x_fp32):
    # 1. fp32 의 lower 16 bit 를 random bit 로 가산
    rand = tl.rand(seed, offset)         # [0, 1) uniform
    rand_bits = (rand * 65536).to(tl.uint32)

    # 2. fp32 비트 직접 manipulation
    bits = x_fp32.to(tl.uint32, bitcast=True)
    bits = bits + rand_bits              # lower 16b 에 random 가산

    # 3. truncate to bf16 (= upper 16 bits)
    return (bits >> 16).to(tl.bfloat16, bitcast=True)

핵심 — round-to-nearest 보다 비용이 거의 같다. random 한 번, addition 한 번. 그런데 학습 양자화의 정확도 손실이 한 자리수 줄어든다.

stochastic rounding 의 위치

fp32 → bf16 cast(master → working weight), bf16 → fp8 cast 등 모든 down-cast 자리에서 stochastic rounding 이 default 가 되는 추세. PyTorch 의 torch.ao, NVIDIA TE 모두 옵션으로 제공.

“round-to-nearest 는 작은 gradient 를 영원히 죽입니다. stochastic rounding 은 — 평균적으로 정확한 update 가 됩니다. 학습 양자화에서 한 줄 trick 으로 가장 큰 효과.”Thien Tran · 28:50
§ 05master weights 패턴· 왜 fp32 한 벌이 살아있는가

“mixed precision training” 의 구조 — 두 weight 가 동시에 산다

학습 시 weight 는 사실상 두 벌로 산다 — forward/backward 에 쓰이는 working weight (BF16/FP8) 와 optimizer 가 업데이트하는 master weight (FP32). 이 분리가 mixed-precision training 의 본질.

FIG · 한 step 의 master/working weight 흐름두 정밀도의 ping-pong
1
cast
master FP32 → working BF16
2
forward
BF16 weight × BF16 act → loss
3
backward
BF16 grad 산출
4
upcast
grad → FP32
5
opt step
master FP32 update
step 5 가 fp32 인 이유 — Adam 의 v = β₂·v + (1-β₂)·g² 에서 g² 가 매우 작은 값일 수 있고, 누적이 fp16 로는 부정확. 그래서 master 와 optimizer state 는 FP32 가 표준.
working weight
storageBF16 또는 FP8
역할forward · backward
크기2N 또는 N bytes
업데이트매 step master 에서 cast
master weight
storageFP32
역할optimizer 가 업데이트
크기4N bytes
업데이트매 step opt step 결과

이 분리의 미묘한 디테일들 —

  • master weight 가 메모리에서 가장 큰 자리. weight 의 4 배(FP32 vs FP16 storage 가 8 byte vs 2 byte). FSDP/ZeRO 의 sharding 이 이 자리에서 가장 큰 효과.
  • master weight 의 BF16 화 + stochastic rounding. 2024 의 새 추세. 메모리를 또 절반으로.
  • cast 의 cost — 매 step 마다 working = master.bf16(). N 의 weight 를 cast 하니 적지 않은 일.
FSDP / DeepSpeed 의 자리

FSDP 는 master weight 를 GPU 들 사이에 shard. opt step 시 자기 shard 만 업데이트. master 가 가장 큰 자리니까 sharding 의 효과가 비례적. ZeRO-3 도 같은 마음.

§ 06low-bit optimizer state· 8-bit Adam

“optimizer state 의 m, v 를 8-bit 로 — 학습 결과는 그대로”

학습 양자화의 가장 큰 win 자리. Tim Dettmers 의 bitsandbytes 가 처음 깐 패턴을 Thien 이 직접 강의에서 재해석. 핵심은 “dequant + update + quant 를 한 Triton 커널 안에서 끝낸다”.

# 8-bit Adam 의 한 step — Triton (요약)
@triton.jit
def adam_8bit_step(p_ptr, g_ptr,
                    m8_ptr, v8_ptr,
                    sm_ptr, sv_ptr,
                    lr, b1, b2, eps,
                    BLOCK: tl.constexpr):
    pid    = tl.program_id(0)
    offs   = pid * BLOCK + tl.arange(0, BLOCK)

    # 1. shared memory 안에서 dequant
    m8     = tl.load(m8_ptr + offs)             # int8
    v8     = tl.load(v8_ptr + offs)
    sm     = tl.load(sm_ptr + pid)              # scale
    sv     = tl.load(sv_ptr + pid)
    m      = m8.to(tl.float32) * sm
    v      = v8.to(tl.float32) * sv

    # 2. Adam update — fp32 안에서
    g      = tl.load(g_ptr + offs).to(tl.float32)
    m      = b1 * m + (1 - b1) * g
    v      = b2 * v + (1 - b2) * g * g
    p      = tl.load(p_ptr + offs)
    p_new  = p - lr * m / (tl.sqrt(v) + eps)

    # 3. 다시 quantize — 새 scale 산출
    abs_m  = tl.max(tl.abs(m))
    sm_new = abs_m / 127.
    m8_new = (m / sm_new).to(tl.int8)

    tl.store(p_ptr + offs, p_new)
    tl.store(m8_ptr + offs, m8_new)
    tl.store(sm_ptr + pid, sm_new)
    # v 도 같은 패턴 ...

이 커널의 핵심 디자인 결정 —

  • dequantized state 를 global memory 에 절대 안 쓴다. 모든 fp32 값이 register/SRAM 에서만 살아있음. 강의 Q&A 의 핵심 질문.
  • block-wise scaling. 한 BLOCK 단위(예: 2048 element) 의 amax 로 한 scale. tensor-wise 보다 정확, element-wise 보다 cheap.
  • Adam 의 b1, b2 가 1 에 가까우니 m, v 가 매우 안정. 양자화 noise 의 영향이 작다.
Q&A 의 핵심

강의 중 한 청중이 — “BLOCK 안에서 row-wise scaling 을 쓰면 row 가 너무 클 때 reduction 이 한 BLOCK 에 안 들어가지 않나?” 라고 질문. Thien 의 답: “그래서 group size 를 bias 시키되 한 SM 의 shared memory 에 들어오게 결정한다. 작은 group 이 안전하다”.

FIG · 메모리 절감 — 8-bit AdamLlama-3 8B
FP32 Adam (m + v)
64 GB
FP16 Adam
32 GB
8-bit Adam
16 GB (-75%)
4-bit Adam (실험)
9 GB
8-bit Adam 만으로 학습 메모리 30% 절감. bitsandbytes 의 표준. accuracy 영향 거의 0.
§ 07FP8 의 e4m3 / e5m2· 두 포맷의 분업

왜 FP8 는 두 포맷인가 — forward 와 backward 의 다른 distribution

FP8 spec 은 두 가지 — e4m3(exponent 4 + mantissa 3)와 e5m2(exponent 5 + mantissa 2). 같은 8 비트지만 trade-off 가 다르다. 학습에서 어디에 어떤 포맷을 쓰는지가 결정적.

e4m3 — 정확도 우선
exponent 4 비트 → range ±448
mantissa 3 비트 → relative precision ~1/16

forward activation, weight 에 사용. 값의 범위가 좁고 정확도가 중요한 자리.
denormal subnormal 까지 활용해 precision 을 짜냄.
e5m2 — 범위 우선
exponent 5 비트 → range ±57344
mantissa 2 비트 → relative precision ~1/8

backward gradient 에 사용. gradient 는 분포가 넓고(작은 grad + 큰 grad 동시), 정확도보다 dynamic range 가 중요.
IEEE 754 fp16 와 같은 exponent 폭.
NVIDIA Transformer Engine 의 표준

activation / weight = e4m3. gradient = e5m2. 두 포맷이 mma 의 다른 자리에서 만난다 — forward 의 w_e4m3 × a_e4m3 → fp32, backward 의 g_e5m2 × ... → fp32. FP16 mixed-precision 시대의 patterned trade-off 가 더 정교화된 버전.

Hopper 의 mma 는 두 포맷을 모두 input 으로 받는다. tcgen05.mma{f8e4m3, f8e5m2} 같은 명령. Triton 도 3.x 부터 양쪽 모두.

학습 시 한 mma 의 흐름을 풀면 —

  1. forward: out = w_e4m3 @ a_e4m3. fp32 accumulation. cast back to bf16/fp8.
  2. backward (input grad): g_in = w_e4m3.T @ g_out_e5m2. weight 는 forward 와 같은 e4m3, grad 는 e5m2.
  3. backward (weight grad): g_w = a_e4m3.T @ g_out_e5m2. forward 의 activation 을 재사용.

각 mma 마다 두 input 의 scale 을 따로 추적해야 한다. 그래서 scale 메타데이터 텐서가 weight/activation/grad 옆에 따라다닌다.

§ 08실측 사례· torchao · TE 비교

“이 디자인이 실제로 학습 결과를 깨뜨리지 않는가”

강의 후반의 본론. Thien 이 torchao 의 quantized training stack 으로 직접 측정한 결과. NVIDIA Transformer Engine 의 결과와도 비교.

FIG · 학습 perf — Llama-3 8B 1B tokenthroughput · loss
BF16 baseline
4500 tok/s
+ 8-bit Adam
4800
+ FP8 forward
7100 (+58%)
+ FP8 fwd+bwd (TE)
8200 (+82%)
+ stochastic master
8400 (+87%)
H100 80GB 단일 노드 기준. throughput 이 거의 2 배까지 — 메모리는 30–40% 절감까지 따 낸다. Llama 의 train loss 는 1B token 까지 BF16 baseline 과 visible 차이 없음(diff < 1%).

강의에서 짚은 미묘한 측정 디테일들 —

  • FP8 의 win 은 H100 부터. A100 은 FP8 mma 가 없다. mma 자체는 H100 / H200 / Blackwell.
  • 실효 win 은 모델 + batch + seq 에 따라 달라짐. memory-bound 인 자리(작은 batch, KV cache dominant)는 FP8 의 throughput win 이 작다.
  • delayed scaling 의 history window. 작으면 (16) 빠르고 stable, 크면 (128) 더 robust 하지만 cold start 가 느림.
  • convergence 가 깨지는 자리 — scale 이 sudden 하게 변하는 self-attention 의 logit. 그 자리만 BF16 fallback 하는 패턴.
torchao 와 TE 의 디자인 차이

NVIDIA Transformer Engine = 모듈 단위 wrapping. te.Linear, te.LayerNorm 으로 모델 코드 변경.
torchao = tensor subclass + dispatch. quantize_(model, Float8DynamicLinearConfig()) 한 줄로 모델의 nn.Linear 를 swap. 모델 코드 변경 거의 없음.
접근 방식의 trade-off — TE 는 각 자리의 정밀 통제, torchao 는 채택 비용 0.

“BF16 학습이 1.0 이라면, 8-bit Adam + FP8 fwd/bwd 까지 가면 1.8× 빠르고 메모리는 30% 적게 — 그리고 loss curve 가 같습니다. 이게 quantized training 의 본질입니다.”Thien Tran · 1:12:18
§ 09inference 양자화와의 차이· 대칭/비대칭, dynamic/static

“같은 INT8 인데 학습은 어렵고 inference 는 쉬운 이유”

강의 끝부분에서 Thien 이 명시적으로 던진 비교. 같은 양자화 단어가 두 환경에서 다른 의미를 가진다.

inference: weight only
weight 만 INT8/INT4. activation 은 BF16. 각 weight 가 정해져 있으니 calibration 한 번에 끝.
training: weight 가 매 step 변함
amax 가 매 step 변함. delayed scaling history 또는 dynamic scaling 필수. inference 의 calibration 이 통하지 않음.
inference: 비대칭 양자화 OK
zero point 를 따로 저장. INT4 가 [-8, 7] 이지만 zero point 로 mapping 자유.
training: 대칭 양자화 표준
학습 시 zero point 가 변하면 backward 가 복잡. 대부분 symmetric scaling.
inference: static or dynamic activation
activation 양자화는 calibration set 위 amax history 또는 매 batch 마다 dynamic.
training: dynamic + history
매 step amax 추적 + history rolling max. 발산 방지가 우선.
inference: rounding 거의 무관
round-to-nearest 으로 충분. 정밀도 손실은 weight 분포에 영향 받음.
training: stochastic rounding 필수
작은 grad 가 round-to-nearest 으로 죽으면 학습 발산. unbiased rounding 이 결정적.
한 줄 요약

inference 양자화 = 정확도의 게임. weight 분포를 어떻게 더 정확하게 표현하는가.
training 양자화 = gradient flow 의 게임. 학습이 발산하지 않으면서 어디까지 정밀도를 깎을 수 있는가.

§ 10기억할 메모와 코드· key takeaways

다시 열었을 때 5분 안에 손에 잡혀야 할 것

학습 메모리 분포
weight × 8 (BF16 + grad + master fp32 + Adam m,v fp32) + activation. opt state 가 가장 큼.
정밀도 분리
mma input 만 low-bit (FP8/INT8), accumulator 는 fp32. storage ≠ compute precision.
stochastic rounding
작은 grad 가 round-to-nearest 으로 죽는 것을 막음. 기댓값 unbiased. 비용은 거의 없음.
master weights
FP32 한 벌이 optimizer 가 만지는 자리. working weight (BF16) 와 분리. mixed-precision 의 본질.
8-bit Adam
opt state m, v 를 INT8 + scale 로. dequant/quant 가 shared memory 안에서. 정확도 영향 거의 0, 메모리 -75%.
FP8 e4m3 / e5m2
e4m3 = forward act/weight (정확도). e5m2 = backward grad (range). Hopper mma 가 양쪽.
delayed scaling
amax history N 개 step 의 max 를 scale 로. 학습 stability 의 핵심.
torchao vs TE
torchao = tensor subclass dispatch (모델 변경 0). TE = 모듈 wrapping (정밀 통제). 채택 비용 trade-off.

손에 새기기 — 실습 시퀀스

  1. 학습 메모리 baseline 측정 — Llama-2 1B 의 한 step 의 메모리 분포를 torch.cuda.memory_allocated() 로 추적. weight / grad / opt state / activation 의 % 확인.
  2. 8-bit Adam 적용 — bitsandbytes 또는 torchao. 같은 모델로 학습 1B token. 메모리 절감과 loss 곡선 일치 확인.
  3. stochastic rounding 직접 짜보기 — Triton 으로 fp32 → bf16 cast 의 두 버전 (RTN vs SR). small grad 누적 시뮬레이션으로 차이 확인.
  4. FP8 forward 적용 — H100 환경에서 torchao 의 Float8DynamicLinear. throughput / loss 비교.
  5. FP8 backward 까지 — TE 또는 torchao 의 fully-FP8 mode. e4m3 vs e5m2 의 자리를 직접 확인.
  6. delayed scaling history 영향 — history window 16 / 64 / 256 으로 학습. 짧을수록 발산 위험, 긴 게 stable.
  7. Triton 으로 8-bit Adam 구현 — 강의의 패턴을 first principle 로 짜본다. dequant + update + quant 를 한 커널.
  8. master weight BF16 화 — torchao 의 옵션. stochastic rounding 와 결합. 메모리 절감을 측정.
§ 11다른 강의로 이어지는 길· connections

이 강의의 도구가 시리즈 안에 어떻게 다시 등장하는지

§ 12열린 질문· open questions

다음에 다시 들었을 때 직접 검증해야 할 것들

검증 메모

이 노트의 throughput / 메모리 수치는 강의에서 Thien 이 보여준 측정을 재구성한 것. H100 / H200 / Blackwell 별로 변동이 클 수 있고, torchao 의 stability 도 빠르게 변하니 — 자기 환경에서 baseline 을 직접 떠봐야 한다.

← Lecture 029 Triton Internals — Kapil Sharma 의 컴파일 단계 Lecture 031 → Beginners Guide to Metal — Nikita Shulga