| bucket | B/param | 비율 |
|---|---|---|
| FP16 param | 2 | 10.5% |
| FP16 grad | 2 | 10.5% |
| FP32 master | 4 | 21.1% |
| FP32 momentum | 4 | 21.1% |
| FP32 variance | 4 | 21.1% |
| workspace·act | ~3 | ~15.7% |
base: Rajbhandari 2020 (ZeRO) Table 1, Adam + FP16 mixed
param (FP16) ██ 2B
grad (FP16) ██ 2B
master(FP32) ████ 4B
moment(FP32) ████ 4B
var (FP32) ████ 4B
└─── model state 16B/param
act/ws ████ ~3B (batch·seq 의존)
| 항목 | 값 |
|---|---|
| P (param) | 7.0 × 10⁹ |
| Model state (16B/p) | 112 GB |
| A100 80GB 수용 | ❌ |
| H100 80GB 수용 | ❌ |
| + activation (B=1,S=4k) | +15~40 GB |
∴ 7B조차 단일 GPU 학습 불가 → ZeRO/FSDP 필수
| 구성 | c_act (B/token·layer·H) |
|---|---|
| vanilla transformer | ~34 |
| SP (sequence parallel) | ~22 |
| + selective checkpoint | ~10 |
| full checkpoint | ~2 (+recompute) |
Korthikanti 2022 (Reducing Activation Recomputation)
| 모드 | B/param | 비고 |
|---|---|---|
| FP32 only + Adam | 16 | 4+4+4+4 |
| Mixed FP16 + Adam | 16 | 같음 (master 추가) |
| Mixed BF16 + Adam | 16 | 동일 |
| Mixed + AdaFactor | ~6 | m,v 축 분해 |
∴ mixed의 이점은 용량이 아니라 throughput(TC) · compute cost
w ← w − η·(m̂/√(v̂)+ε)FP32 master ──cast──► FP16 param │ │ forward (TC) │ ▼ │ activation (FP16) │ │ backward │ ▼ │ FP16 grad │ │ unscale ▼ ▼ FP32 m,v ◄──── grad (FP32) ← cast up │ update → master (in-place)
| 항목 | FP16 | BF16 |
|---|---|---|
| exponent | 5 | 8 (FP32 동일) |
| mantissa | 10 | 7 |
| loss scale | 필요 | 대부분 불필요 |
| 수렴 | 민감 | 안정 |
| TC 지원 | V100+ | A100+ |
| 옵션 | grad B/p | 주의 |
|---|---|---|
| FP32 grad | 4 | allreduce 대역폭 ×2 |
| FP16 grad | 2 | loss scale 필수 |
| BF16 grad | 2 | 안정·표준 |
| FP8 grad | 1 | TE·실험적 |
torch.cuda.amp.autocast: dtype 자동 캐스트GradScaler: scale / unscale / skip| stage | shard 대상 | 추가 통신 |
|---|---|---|
| ZeRO-1 | Optimizer state | 0 (step 이후 없음) |
| ZeRO-2 | + Gradient | RS (대체) |
| ZeRO-3 | + Parameter | AG fwd + AG bwd + RS |
AG = all-gather · RS = reduce-scatter
| stage | N=1 | N=8 | N=64 |
|---|---|---|---|
| baseline | 112 GB | 112 | 112 |
| ZeRO-1 | 112 | 44.5 | 29.3 |
| ZeRO-2 | 112 | 30.5 | 15.8 |
| ZeRO-3 | 112 | 14.0 | 1.75 |
16·P/N → N=64일 때 1.75 GB/GPU
| stage | fwd | bwd | 총 vs DP |
|---|---|---|---|
| DP | 0 | AR 2P | 1.0× |
| ZeRO-1 | 0 | AR 2P | 1.0× |
| ZeRO-2 | 0 | RS 2P | 1.0× |
| ZeRO-3 | AG 2P | AG 2P + RS 2P | 1.5× |
AR ≡ RS+AG · 전체 bytes 기준
[layer L] ┌── all-gather param (N→full) ──┐ │ forward │ │ free param (→1/N) │ ▼ ▼ ... ┌── all-gather param ────────────┐ │ backward │ │ reduce-scatter grad (→1/N) │ │ free param again │ ▼ ▼ step: local shard만 update
| optimizer | K | total B/p |
|---|---|---|
| SGD momentum | 8 | 12 |
| Adam/AdamW | 12 | 16 |
| AdaFactor | ~2 | ~6 |
| 8-bit Adam | 3 | 7 |
| 대상 | 위치 | 비고 |
|---|---|---|
| FP16 param | GPU | fwd/bwd hot |
| FP16 grad | GPU → CPU | bwd 즉시 이동 |
| FP32 master | CPU | update만 |
| m, v | CPU | CPU-side Adam |
Ren 2021 (ZeRO-Offload)
GPU HBM ← hot (param tile) ↕ NVLink/PCIe CPU DRAM ← warm (grad, opt state) ↕ PCIe NVMe SSD ← cold (huge opt state)
| 항목 | 값 |
|---|---|
| opt state (12B/p) | 156 GB |
| PCIe4 BW | 32 GB/s |
| Update 전송 | ~4.9 s |
| GPU step (TC) | ~0.3 s |
| Ratio (offload/compute) | ~16× |
| 상황 | 판단 |
|---|---|
| 단일 A100·10B+ 모델 | Offload 유효 |
| NVMe만·PCIe3 | Infinity 마지막 수단 |
| 다수 GPU (ZeRO-3 충분) | offload 불필요 |
| long seq·큰 activation | checkpoint 먼저 |
| 정책 | 단위 |
|---|---|
| size_based | ≥ threshold param |
| transformer_auto | 각 block |
| manual | 사용자 지정 |
transformer_auto_wrap_policy가 대부분 최적
| mode | param | grad | opt |
|---|---|---|---|
| FULL_SHARD | shard | shard | shard |
| SHARD_GRAD_OP | 복제 | shard | shard |
| NO_SHARD | 복제 | 복제 | 복제 (DDP) |
| HYBRID_SHARD | 내부 FULL + 외부 복제 (2D) | ||
HYBRID: intra-node shard + inter-node replicate
MixedPrecision( param_dtype = bfloat16, reduce_dtype = float32, buffer_dtype = bfloat16)
| policy | 효과 |
|---|---|
| BACKWARD_PRE | 다음 layer param AG 선행 |
| BACKWARD_POST | 현 layer 끝난 뒤 AG |
| NO_PREFETCH | stall 허용 |
BACKWARD_PRE가 보통 최적 (compute와 overlap)
fully_shard API · DTensor 기반PyTorch 2.4+, torch.distributed._composable
[unit i, rank r] local shard: P_i / N full param (AG 직후): P_i ← temporary grad shard: P_i / N opt state: P_i / N (AG buffer == compute buffer 재사용)
activation은 full param 보존 없음 → backward에 재-AG
fwd[L0]: AG→compute→reshard fwd[L1]: AG→compute→reshard ... fwd[Ln]: loss bwd[Ln]: AG→bwd→RS→reshard bwd[Ln-1]: AG→bwd→RS→reshard ... step: local shard update
comp stream: | F(L0) | F(L1) | F(L2) |... comm stream: | AG(L1) | AG(L2) |... bwd: comp stream: | B(Ln) | B(Ln-1) |... comm stream: | AG(Ln-1) | RS(Ln) + AG(Ln-2) |...
no_sync() context: backward 시 RS 생략, 여러 micro-batch 누적 뒤 마지막에만 sync.
clip_grad_norm_ FSDP API 사용| 단위 | 구현 |
|---|---|
| per-layer | 가장 흔함 |
| per-block | transformer block |
| per-module | custom boundary |
| per-op | selective (↓) |
layer: [QKV proj] [SDPA] [out proj] [FFN] store: ✓ ✗ ✓ partial recomp: ✓ FFN 하위
Korthikanti 2022 · Megatron selective activation recompute
from torch.utils.checkpoint import checkpoint def block(x): return layer2(layer1(x)) y = checkpoint(block, x, use_reentrant=False)
use_reentrant=False: autograd 2세대, 권장preserve_rng_state=True 기본값| mode | mem | compute |
|---|---|---|
| none | 1.0 | 1.0 |
| selective | ~0.4 | ~1.08 |
| full (per-block) | ~0.15 | ~1.33 |
| offload act | ~0.05 | +PCIe |
GPU HBM 해제·CPU DRAM 사용 — PCIe BW 여유 필요
raw: [s1: 123 ][s2:12][s3:1234] pack: [123 12 1234] seq=9 bdry: 0 3 5 (누적 시작)
cu_seqlens = [0, 3, 5, 9] max_seqlen = 4 total_tok = 9
FlashAttention varlen API 입력 ↗ V07 §12
s1 s2 s3 s1 ██ ░░ ░░ s2 ░░ ██ ░░ s3 ░░ ░░ ██ (block-diagonal)
position_ids = [0,1,2, 0,1, 0,1,2,3]| 요소 | 필요 |
|---|---|
| cu_seqlens | ✓ |
| max_seqlen | ✓ |
| position_ids per-sample | ✓ |
| label shift·ignore_index | ✓ |
| FA varlen flag | ✓ |
| axis | 대상 | 통신 |
|---|---|---|
| DP | batch 복제 | AR grad |
| TP | hidden 분할 | AR per layer |
| PP | layer 분할 | P2P activation |
| EP | MoE expert | all-to-all |
| CP | sequence 분할 | KV rotate |
상세 통신식 ↗ V15 §3
FFN: y = GeLU(x·W1)·W2 W1 : col-parallel (split N) W2 : row-parallel (split K) 마지막 all-reduce (sum) QKV proj: col-parallel (head 분할) Attn out: row-parallel + all-reduce
1 layer당 fwd 1× AR · bwd 1× AR
| axis | bytes/step·rank |
|---|---|
| TP | O(L · B · S · H) |
| PP | O(M · B · S · H / P) |
| DP | O(P_param/N_dp) |
B:batch, S:seq, H:hidden, L:layer, M:microbatch, P:PP stages
| size | TP | PP | DP | total GPU |
|---|---|---|---|---|
| 7B | 1 | 1 | 8 | 8 |
| 13B | 2 | 1 | 16 | 32 |
| 30B | 4 | 2 | 16 | 128 |
| 70B | 8 | 4 | 16 | 512 |
| 400B | 8 | 16 | 16 | 2048 |
TP=intra-node · PP=inter-node · DP 가장 바깥
inner TP (NVLink)
│
PP (node-to-node P2P)
│
outer DP (ZeRO/FSDP 가능)
TP는 가장 fast fabric, DP는 가장 slow에 배치
P=4, M=4
t: 0 1 2 3 4 5 6 7 8 9 A B
G0 F F F F B B B B
G1 F F F F B B B B
G2 F F F F B B B B
G3 F F F F B B B B
└─fwd warm─┘└─bwd tail─┘
bubble = 2(P-1)/M
∴ activation이 모든 micro 저장 → 메모리 많음
P=4, M=8, 1F1B: t: 0 1 2 3 4 5 6 7 8 9 A B C D G0 F F F F B F B F B F B B B B G1 F F F B F B F B F B F B B G2 F F F B F B F B F B F B G3 F B F B F B F B F B F warm-up steady (1F1B) drain bubble = (P-1)/M (GPipe의 절반)
P=4, v=2 (virtual stages/GPU), M=8: G0: [s0]...[s4] F F F F B B B B (교차) G1: [s1]...[s5] ... bubble = (P-1) / (v·M)
1F1B-BW:
G0 F F F F B W B W B W B W W W
└─ B·W 교차 ─┘
Qi, Zero Bubble Pipeline (2023)
bubble → 근사 0 · 구현 복잡도 ↑
| schedule | bubble | act mem | 구현 |
|---|---|---|---|
| GPipe | 2(P−1)/M | O(M) | 단순 |
| 1F1B | (P−1)/M | O(P) | 중간 |
| Interleaved | (P−1)/(v·M) | O(P·v) | 복잡 |
| Zero-Bubble | ~0 | O(P) | 매우복잡 |
bwd: G_L1 G_L2 G_L3 G_L4 ...
└bkt┘└bkt┘
RS starts asap (overlap next B)
DDP·FSDP 기본 동작
for i, batch in enumerate(loader): if (i+1) % accum != 0: with model.no_sync(): loss = model(batch) loss.backward() else: loss = model(batch) loss.backward() # sync optimizer.step()
| phase | overlap 대상 |
|---|---|
| fwd L_i | AG L_{i+1} |
| bwd L_i | AG L_{i-1} |
| bwd L_i | RS L_{i+1} |
3중 overlap이 가능하나 구현이 복잡
| 증상 | 원인 후보 |
|---|---|
| NCCL stream 빈 공간 | bucket 너무 큼 → 분할 |
| compute stream gap | AG 대기 → prefetch 활성 |
| HBM BW 포화 | bucket·GEMM 충돌 → 크기 조정 |
| per-step variance ↑ | NCCL 순서 불안정 |
Nsight Systems trace 기준 ↗ V18 §12
| 방식 | decay 적용 |
|---|---|
| Adam (L2) | grad에 λw 추가 후 Adam |
| AdamW | update 후 w ← w(1−ηλ) |
| optimizer | B/p (FP32 state) |
|---|---|
| SGD | 0 |
| SGD + momentum | 4 |
| Adam/AdamW | 8 (m+v) |
| Adafactor | ~2 |
| LAMB | 8 |
| Lion | 4 (momentum만) |
master copy 4B/p 별도 추가
| state | baseline | 8-bit |
|---|---|---|
| m | 4B | 1B + scale |
| v | 4B | 1B + scale |
| 총 state | 8B | ~2B |
Dettmers 2022 · 수렴 품질 Adam과 동등
| opt | mem | 품질 | 권장 맥락 |
|---|---|---|---|
| AdamW | 12B | 표준 | 대부분 |
| Adafactor | 6B | 약간 ↓ | 초대형·T5 |
| 8-bit AdamW | 6B | ≈ AdamW | 단일 GPU |
| Lion | 8B | 혼재 | 비전·실험 |
| SGD-M | 8B | 비전 OK | LLM 비권장 |
master 4B 포함 값 · LLM pretrain 기준
FusedAdam·PyTorch foreach·fused=Truestate:
S (scale factor) init ~2^16
growth_interval e.g. 2000
n_ok success 카운트
per step:
if any(grad = ±inf, NaN):
skip step
S ← S / 2; n_ok ← 0
else:
apply step
n_ok ← n_ok + 1
if n_ok == growth_interval:
S ← S · 2; n_ok ← 0
scaler = GradScaler() for x,y in loader: with autocast(dtype=FP16): loss = model(x,y) scaler.scale(loss).backward() scaler.unscale_(opt) # for clip clip_grad_norm_(...) scaler.step(opt) # skip if inf scaler.update() # adjust S
| format | E/M | 용도 |
|---|---|---|
| E4M3 | 4/3 | activation·weight |
| E5M2 | 5/2 | gradient (range↑) |
Micikevicius 2022 · IEEE binary8 candidate
x (BF16) ─cast→ E4M3 ─┐
W (BF16) ─cast→ E4M3 ─┤ WGMMA (FP8×FP8)
└→ FP32 accum → BF16
grad (BF16) ─cast→ E5M2 (더 큰 range)
amax history → per-tensor scale
| layer | FP8? |
|---|---|
| QKV/FFN GEMM | ✓ |
| LayerNorm | BF16 유지 |
| softmax·attention | FP16/BF16 |
| logit·loss | FP32 |
| embedding table | BF16 (보수적) |
| param | 의미 |
|---|---|
| num_workers | worker process 수 |
| prefetch_factor | worker당 미리 버퍼링 |
| pin_memory | pinned → H2D 빠름 |
| persistent_workers | epoch 간 재사용 |
| collate_fn | batch 조립 |
DistributedSampler: rank별 index 분할| 방식 | 특성 |
|---|---|
| index shuffle | map-style 전용 |
| buffer shuffle | iterable·window N |
| tar-shuffle | webdataset tar 단위 |
global shuffle은 많은 메모리 필요 → buffer approximation
| 포맷 | 장점 | 단점 |
|---|---|---|
| raw image/jpg | 호환성 | decode CPU |
| tfrecord | seq read 빠름 | 검색 ↓ |
| webdataset (tar) | streaming·shard | random X |
| arrow/parquet | columnar·효율 | image 불리 |
| mosaic (mds) | deterministic resume | tooling 의존 |
set_epoch 필수| signal | 의미 |
|---|---|
| data-wait % > 5% | loader 부족 |
| CPU ≈ 100% | decode bound |
| IO wait | 저장매체 bound |
| variance ↑ | buffer 부족 |
| 항목 | 필수 |
|---|---|
| model state_dict | ✓ |
| optimizer state | ✓ |
| lr scheduler | ✓ |
| grad scaler | FP16 시 |
| RNG state (cpu/cuda) | ✓ |
| dataloader sampler | ✓ |
| global_step | ✓ |
| mode | 특성 |
|---|---|
| rank-0 consolidated | 단일 파일·느림·OOM |
| sharded (per-rank) | 빠름·world 의존 |
| DCP (torch) | rank-independent load |
torch.distributed.checkpoint 2.2+
ckpt-1000/ ├ model/ │ ├ __0_0.distcp │ ├ __1_0.distcp │ └ ... ├ optim/ ├ meta.json (world topology) └ sampler_state.pt
StateDictType.FULL_STATE_DICTsave: TP=8, PP=4, DP=16 load: TP=4, PP=4, DP=32 ← OK load: TP=8, PP=4, DP=16 ← 동일 사용: resume시 node 수 변경
| 요인 | 권장 |
|---|---|
| MTBF 예상 | interval ≤ MTBF/2 |
| save cost | ≤ 1% wall-time |
| disk 용량 | rotation 3~5 |
interval ↑ → 손실 확률 ↑ · IO ↓
| 유형 | 빈도 |
|---|---|
| GPU Xid error | 주기적 |
| NVLink/IB link down | 드묾 |
| host crash | 드묾 |
| NCCL timeout | 네트워크 jitter |
| straggler (slow) | 자주 |
1000 GPU·30일 학습 시 수십건 통상
NCCL_TIMEOUT (default 30min)--max-restarts, --rdzv-backend=c10d[rdzv server]
│
┌────┴────┬────┬────┐
worker worker ...
│ │
fail!
↓
rdzv re-bargain (N→N-1 or +spare)
↓
load last ckpt (reshard topology)
↓
resume
NCCL_ALGO=Ring 고정| scale | 권장 interval |
|---|---|
| 8 GPU | ~1h |
| 128 GPU | ~30m |
| 1024 GPU | ~10~20m |
| 10k GPU | <10m (async 필수) |
∝ 1/√MTBF · 실측이 최종
| 항목 | 값 |
|---|---|
| HW | 8 × A100/H100 |
| Parallelism | DP=8 + FSDP(FULL_SHARD) |
| TP/PP | 1/1 |
| Mixed precision | BF16 / FP32 master |
| Optimizer | AdamW fused |
| Activation | selective ckpt |
| Seq packing | 권장 |
| 항목 | 값 |
|---|---|
| HW | 64~512 × H100 |
| TP | 8 (intra-node NVLink) |
| PP | 4~8 |
| DP | 2~16 (+ZeRO-1) |
| SP | on (activation ↓) |
| Schedule | 1F1B interleaved v=2 |
| Activation | full ckpt·attention만 off |
| 항목 | 값 |
|---|---|
| HW | 2k~16k × H100/B200 |
| TP | 8 |
| PP | 16 (zero-bubble) |
| DP | 16~128 + ZeRO-1 |
| CP | long seq 시 on |
| EP | MoE 시 8~64 |
| FP8 | TE path |
| Async ckpt | 필수 |
start
└ model ≤ 10B? → FSDP (FULL_SHARD) only
│
▽ no
└ intra-node NVLink? → TP=8
│
▽ yes
└ model ≤ 100B? → TP8 + PP4 + DP
│
▽ no
└ TP8 + PP≥8 + ZeRO-1 + CP/EP 조합
| OOM 원인 | 조치 |
|---|---|
| model state | ZeRO stage ↑ · TP 도입 |
| activation | ckpt selective → full |
| peak full param | unit wrap 작게 |
| kv cache | 학습엔 없음 (↗ V16) |
| optimizer state | 8-bit·offload |
| 증상 | 조치 |
|---|---|
| data wait | loader worker/prefetch ↑ |
| comm bound | overlap·bucket tuning (↗ §11) |
| pipeline bubble | M 증가·interleaved·zero-bubble |
| HBM sat | selective ckpt·FP8 |