CUDA ENGINEERING 18권 · CONTENT-FIRST · A4 LANDSCAPE · 18p

V17 Training System

FSDP · ZeRO · Gradient Checkpointing · Comm-Compute Overlap
Volume 17/18
Tier T5 분산/서빙
선행 V01 · V15
용도 대규모 학습 시스템 지도

목차

§1. 학습 메모리 구성p.2
§2. Mixed Precision 메모리식p.3
§3. ZeRO 3단계p.4
§4. ZeRO-Offload / Infinityp.5
§5. FSDP 아키텍처p.6
§6. FSDP Fwd/Bwd 흐름p.7
§7. Gradient Checkpointingp.8
§8. Sequence Packingp.9
§9. Megatron 3D Parallelismp.10
§10. Pipeline Schedulep.11
§11. Comm-Compute Overlapp.12
§12. Optimizer Statesp.13
§13. Loss Scaling · FP8p.14
§14. Dataloaderp.15
§15. Checkpointingp.16
§16. Fault Tolerancep.17
§17. Cheat Sheet (7B·70B·400B)p.18

범례

핵심 용어
표 헤더·매우 중요
정의·공식 박스
예시·워크드
빨강주의·흔한 실수
실무 핵심
(!)니모닉 (권당 ≤5)
다른 권 참조
결론
인쇄 A4 가로 / 여백 없음 / 배경 그래픽 포함 · Ctrl(⌘)+P
ZeRO · FSDP · Megatron-Core · DeepSpeed · FA varlen

1 GPU memory 4 bucket 학습 시 점유 주체 PGOA

정의 학습 중 device memory = Param + Grad + Optimizer state + Activation + workspace.
  • Param: forward에 필요한 weight
  • Grad: backward가 채우는 ∂L/∂W
  • Opt state: Adam momentum·variance 등
  • Activation: backward용 중간 tensor

2 Model state 정의 고정 비용

정의 Model state = Param + Grad + Optimizer state. batch size와 무관, 모델 크기 P에만 의존.
  • Activation은 batch·seq 따라 변동
  • ZeRO는 model state 만을 shard (↗ §3)
  • Gradient checkpointing은 activation만 감소 (↗ §7)

3 메모리 단위 공식 byte/param

M_state = P × (b_p + b_g + b_opt)
M_act = B · S · H · L · c_act P: param count · B: batch · S: seq len · H: hidden · L: layer · c_act: activation factor (구현별 10~34·dtype)

4 Bucket 비율 ★ Adam·FP16 mixed

bucketB/param비율
FP16 param210.5%
FP16 grad210.5%
FP32 master421.1%
FP32 momentum421.1%
FP32 variance421.1%
workspace·act~3~15.7%

base: Rajbhandari 2020 (ZeRO) Table 1, Adam + FP16 mixed

5 ASCII bar chart ★

param (FP16)  ██        2B
grad  (FP16)  ██        2B
master(FP32)  ████      4B
moment(FP32)  ████      4B
var   (FP32)  ████      4B
         └─── model state 16B/param
act/ws  ████ ~3B (batch·seq 의존)
∴ 학습 메모리의 75% 이상이 model state — ZeRO의 표적.

6 7B 모델 숫자 예 single-GPU 기준

항목
P (param)7.0 × 10⁹
Model state (16B/p)112 GB
A100 80GB 수용
H100 80GB 수용
+ activation (B=1,S=4k)+15~40 GB

∴ 7B조차 단일 GPU 학습 불가 → ZeRO/FSDP 필수

7 Activation factor c_act

구성c_act (B/token·layer·H)
vanilla transformer~34
SP (sequence parallel)~22
+ selective checkpoint~10
full checkpoint~2 (+recompute)

Korthikanti 2022 (Reducing Activation Recomputation)

8 경계선 정리 out-of-scope 핑

  • 수치 안정성 (loss NaN, under/overflow) ↗ V09 §3
  • Inference KV cache ↗ V16 §4
  • NCCL primitive 대역폭 식 ↗ V15 §2

1 Mixed precision 구성 정의

정의 forward/backward는 FP16/BF16으로, optimizer step은 FP32 master로 수행하는 하이브리드. 메모리 ↓ · TC throughput ↑.
  • master copy가 있어야 누적 오차 발산 방지
  • BF16: exponent 8bit → loss scaling 불필요 경향
  • FP16: exponent 5bit → loss scaling 필요 (↗ §13)

2 Adam 메모리 공식 ★ 2+2+4+4+4=16

M_Adam(P) = 2P + 2P + 4P + 4P + 4P = 16·P bytes FP16 param + FP16 grad + FP32 master + FP32 m + FP32 v
P = 1.3B → M_Adam = 20.8 GB · P = 13B → 208 GB · P = 70B → 1,120 GB

3 FP32 baseline 비교

모드B/param비고
FP32 only + Adam164+4+4+4
Mixed FP16 + Adam16같음 (master 추가)
Mixed BF16 + Adam16동일
Mixed + AdaFactor~6m,v 축 분해

∴ mixed의 이점은 용량이 아니라 throughput(TC) · compute cost

4 왜 FP32 master인가 수치

  • update: w ← w − η·(m̂/√(v̂)+ε)
  • η·gradient는 종종 |w|의 10⁻⁷ scale
  • FP16 mantissa 10bit → round-to-zero
  • ∴ master를 FP32로 보존 → precision 유지
FP16 only 학습은 "update가 사라짐" (stalled) 현상으로 수렴 실패.

5 Forward/Backward bucket 흐름

FP32 master ──cast──► FP16 param
  │                      │ forward (TC)
  │                      ▼
  │                   activation (FP16)
  │                      │ backward
  │                      ▼
  │                   FP16 grad
  │                      │ unscale
  ▼                      ▼
FP32 m,v ◄──── grad (FP32) ← cast up
  │
update → master (in-place)

6 BF16 vs FP16 학습 관점

항목FP16BF16
exponent58 (FP32 동일)
mantissa107
loss scale필요대부분 불필요
수렴민감안정
TC 지원V100+A100+

7 Grad 저장 dtype 옵션

옵션grad B/p주의
FP32 grad4allreduce 대역폭 ×2
FP16 grad2loss scale 필수
BF16 grad2안정·표준
FP8 grad1TE·실험적

8 Apex/AMP 경로 요약

  • torch.cuda.amp.autocast: dtype 자동 캐스트
  • GradScaler: scale / unscale / skip
  • O1·O2·O3 mode (Apex) 유산 — 현재는 autocast 단일화
  • 수치 상세 ↗ V09 §5

9 흔한 실수 주의

  • loss.backward() 전에 scale 안 걸면 grad=0
  • optimizer step 후 unscale 누락 → η 실효 ×scale
  • master copy 안 두고 FP16 파라미터 직접 업데이트 → 수렴 불가

1 ZeRO 아이디어 ★ DP 중복 제거 OGP

정의 DP는 model state를 N GPU에 전부 복제. ZeRO는 이를 N개로 shard하여 필요한 순간에 all-gather → 사용 → 해제.
  • 통신 증가 vs 메모리 절감의 trade
  • NCCL 식: allreduce ≡ reduce-scatter + all-gather (↗ V15)

2 Stage 정의 ★

stageshard 대상추가 통신
ZeRO-1Optimizer state0 (step 이후 없음)
ZeRO-2+ GradientRS (대체)
ZeRO-3+ ParameterAG fwd + AG bwd + RS

AG = all-gather · RS = reduce-scatter

3 메모리 절감 공식 ★★

M₀ = 2P + 2P + K·P (K=12 for Adam FP32)
ZeRO-1: M₁ = 2P + 2P + K·P/N
ZeRO-2: M₂ = 2P + (2P + K·P)/N
ZeRO-3: M₃ = (2P + 2P + K·P)/N = 16P / N P: param, N: DP world size, K=12 = 4(master)+4(m)+4(v)

4 숫자 예 ★ P=7B, Adam

stageN=1N=8N=64
baseline112 GB112112
ZeRO-111244.529.3
ZeRO-211230.515.8
ZeRO-311214.01.75

16·P/N → N=64일 때 1.75 GB/GPU

5 통신량 per step

stagefwdbwd총 vs DP
DP0AR 2P1.0×
ZeRO-10AR 2P1.0×
ZeRO-20RS 2P1.0×
ZeRO-3AG 2PAG 2P + RS 2P1.5×

AR ≡ RS+AG · 전체 bytes 기준

6 Stage 선택 결정표

  • model ≤ 1B · NVLink: DP + ZeRO-1 충분
  • 1B~10B: ZeRO-2 (grad shard)
  • 10B+: ZeRO-3 (param shard) or FSDP/Megatron TP

7 ZeRO-3 life cycle per layer

[layer L]
 ┌── all-gather param (N→full) ──┐
 │           forward              │
 │      free param (→1/N)         │
 ▼                                ▼
 ...
 ┌── all-gather param ────────────┐
 │         backward                │
 │    reduce-scatter grad (→1/N)   │
 │    free param again             │
 ▼                                 ▼
 step: local shard만 update

8 ZeRO ≡ FSDP

관계 PyTorch FSDP는 ZeRO-3의 PyTorch-native 구현. API·구조는 다르지만 메모리·통신 동등.
  • DeepSpeed ZeRO = Microsoft 진영
  • FSDP = Meta/PyTorch 진영 (↗ §5)
  • fairseq FSDP → torch.distributed.fsdp

9 K 상수 정리 optimizer별

optimizerKtotal B/p
SGD momentum812
Adam/AdamW1216
AdaFactor~2~6
8-bit Adam37

1 Offload 동기 GPU가 부족할 때

정의 CPU DRAM·NVMe를 cold storage처럼 사용해 GPU에 담기 불가능한 model state를 체류시키는 계층화.
  • param/grad/opt state 중 일부를 offload
  • opt state는 update에서만 필요 → 먼 저장소 적합
  • param·grad은 fwd/bwd마다 접근 → swap 비쌈

2 ZeRO-Offload 구성

대상위치비고
FP16 paramGPUfwd/bwd hot
FP16 gradGPU → CPUbwd 즉시 이동
FP32 masterCPUupdate만
m, vCPUCPU-side Adam

Ren 2021 (ZeRO-Offload)

3 ZeRO-Infinity 계층

GPU HBM   ← hot (param tile)
   ↕ NVLink/PCIe
CPU DRAM  ← warm (grad, opt state)
   ↕ PCIe
NVMe SSD  ← cold (huge opt state)
  • partitioned param prefetch
  • NVMe GDS (GPUDirect Storage)

4 Bandwidth 계산 ★

t_step ≥ max( flops/P_comp , bytes/B_link )
B_link: PCIe4 ≈ 32 GB/s · PCIe5 ≈ 64 · NVMe ≈ 3~7 · HBM3 ≈ 3 TB/s 각 buffer의 bytes가 B_link보다 크면 swap이 step을 지배

5 수치 예 P=13B, PCIe4

항목
opt state (12B/p)156 GB
PCIe4 BW32 GB/s
Update 전송~4.9 s
GPU step (TC)~0.3 s
Ratio (offload/compute)~16×
naive offload → PCIe-bound. partition·pipeline·async 없으면 비실용.

6 Offload 유효성 조건

  • opt state ≫ param (Adam 12:4) → offload ROI ↑
  • step이 매우 길어 swap을 감춤 (큰 batch)
  • CPU-side optimizer 효율 (DeepSpeed CPU-Adam, fused kernel)
  • bandwidth = NVLink > NVMe+GDS > PCIe SSD

7 적합 상황 vs 부적합

상황판단
단일 A100·10B+ 모델Offload 유효
NVMe만·PCIe3Infinity 마지막 수단
다수 GPU (ZeRO-3 충분)offload 불필요
long seq·큰 activationcheckpoint 먼저

8 CPU-Adam fused DeepSpeed

  • CPU intrinsic: AVX-512 · FMA로 32-lane
  • grad FP16 → FP32 unscale + update를 단일 kernel
  • GPU→CPU grad stream 비동기화

9 체크포인트와의 상호작용

  • offload 중이면 master/m/v가 CPU에 → save 시 consolidation 필요 (↗ §15)
  • NVMe에 체크포인트 직접 쓰기 가능
  • resume 시 dtype·shard map 일치해야

1 FSDP 정의 ★ ZeRO-3 native

정의 PyTorch Fully Sharded Data Parallel: 각 unit의 param을 N rank에 shard, 필요 시 all-gather, 사용 후 reshard.
  • unit = Transformer block 단위가 표준
  • FSDP1 (legacy)·FSDP2 (per-param) 두 세대
  • ZeRO-3와 동일 메모리식 (↗ §3)

2 FlatParameter FSDP1

정의 unit 내 모든 param을 하나의 1D tensor로 평탄화 (flatten) → shard·comm 단위가 단일 tensor.
  • 장점: NCCL 호출 1회/unit (묶음 통신)
  • 단점: mixed dtype 불가·일부 hook 부적합
  • FSDP2: per-param sharding으로 해결

3 Unit wrapping 전략

정책단위
size_based≥ threshold param
transformer_auto각 block
manual사용자 지정

transformer_auto_wrap_policy가 대부분 최적

4 Sharding 전략 옵션

modeparamgradopt
FULL_SHARDshardshardshard
SHARD_GRAD_OP복제shardshard
NO_SHARD복제복제복제 (DDP)
HYBRID_SHARD내부 FULL + 외부 복제 (2D)

HYBRID: intra-node shard + inter-node replicate

5 MixedPrecision config ★

MixedPrecision(
  param_dtype  = bfloat16,
  reduce_dtype = float32,
  buffer_dtype = bfloat16)
  • param_dtype: shard·compute dtype
  • reduce_dtype: gradient reduce 정밀도
  • buffer_dtype: BN running stats 등

6 Prefetch 옵션 통신 은폐

policy효과
BACKWARD_PRE다음 layer param AG 선행
BACKWARD_POST현 layer 끝난 뒤 AG
NO_PREFETCHstall 허용

BACKWARD_PRE가 보통 최적 (compute와 overlap)

7 FSDP2 per-param 최신

  • fully_shard API · DTensor 기반
  • 각 param이 독립 ShardedTensor
  • mixed dtype unit·expert shard 가능
  • hooks·LoRA 호환성 ↑

PyTorch 2.4+, torch.distributed._composable

8 메모리 구성 요소별 저장

[unit i, rank r]
  local shard: P_i / N
  full param (AG 직후): P_i  ← temporary
  grad shard: P_i / N
  opt state: P_i / N
  (AG buffer == compute buffer 재사용)

9 안티패턴 FSDP

  • 너무 작은 unit: comm 횟수 ↑ · BW 비효율
  • 너무 큰 unit: peak full-param GB ↑ · OOM
  • root module 전체 wrap: prefetch 기회 상실
  • grad clip 구현 미스: shard만 clip → norm 계산 틀림

1 Forward 5단계 ★ AG·F·RE

  1. all-gather unit param (shard → full)
  2. forward compute
  3. reshard param (free full → keep shard)
  4. activation은 unit 외부로 전달
  5. 다음 unit 반복

activation은 full param 보존 없음 → backward에 재-AG

2 Backward 6단계 ★

  1. activation에 ∂L 도달
  2. all-gather param 다시 (unshard)
  3. backward compute → grad_full
  4. reduce-scatter grad → grad_shard
  5. reshard param
  6. 다음 unit (이전 layer)

3 Timeline ASCII

fwd[L0]: AG→compute→reshard
fwd[L1]: AG→compute→reshard
 ...
fwd[Ln]: loss
bwd[Ln]: AG→bwd→RS→reshard
bwd[Ln-1]: AG→bwd→RS→reshard
 ...
step: local shard update

4 통신량 공식

per step·per rank:
AG = 2 · P · (N−1)/N (fwd+bwd 각 1회)
RS = 1 · P · (N−1)/N (bwd) 총 ≈ 3P · (N−1)/N bytes · N GPU 기준

5 Compute-comm overlap prefetch

comp stream: | F(L0) | F(L1) | F(L2) |...
comm stream: |   AG(L1) | AG(L2) |...

bwd:
comp stream: | B(Ln) | B(Ln-1) |...
comm stream: | AG(Ln-1) | RS(Ln) + AG(Ln-2) |...
  • compute stream ≠ comm stream
  • CUDA event로 dependency

6 Root unit 특수 처리

  • 최초 forward 직전: root all-gather 누적
  • 최후 backward 이후: root reduce-scatter 대기
  • root만 sharding 안 하는 경우 있음 (lazy init)

7 Gradient accumulation no_sync

정의 no_sync() context: backward 시 RS 생략, 여러 micro-batch 누적 뒤 마지막에만 sync.
  • comm을 K배 절감 (K=accum steps)
  • 단, 누적 grad는 rank-local → 마지막 step에 RS
  • 상세 ↗ §11

8 Freeze param (LoRA 등)

  • FSDP1: frozen param을 unit에 포함하면 AG·RS 낭비
  • 해결: 별도 unit wrap, or NO_SHARD
  • FSDP2: per-param shard로 해결 용이

9 흔한 실수

  • unit 외부에서 param 접근: shard 상태라 shape 불일치
  • grad clip 전 unshard 누락: grad_norm 전체 합산 못함 → clip_grad_norm_ FSDP API 사용
  • optimizer에 full param 전달: 실제론 shard만 존재

1 동기 activation 메모리

정의 L개 layer 중 √L 지점만 activation 저장, 나머지는 backward 중 재계산 (recompute).
  • 메모리 O(L) → O(√L)
  • compute: +1 forward pass (≈ +33%)
  • Chen 2016 (Training Deep Nets with Sublinear Memory)

2 √N 공식 ★

M_act(naive) = L · m_layer
M_act(ckpt) = √L · m_layer + m_segment
optimal segment size = √L L: layer · m_layer: per-layer activation · m_segment: segment 내부 peak
L = 64 → √L = 8 → 메모리 ~1/8, 계산 ~4/3

3 적용 단위

단위구현
per-layer가장 흔함
per-blocktransformer block
per-modulecustom boundary
per-opselective (↓)

4 Selective checkpoint ★

정의 비싼 op (e.g. matmul·attention)는 저장, 싼 op (e.g. dropout·RMSNorm·elementwise)는 recompute. 메모리·compute 둘 다 최적화.
layer: [QKV proj] [SDPA] [out proj] [FFN]
store:    ✓          ✗       ✓       partial
recomp:              ✓              FFN 하위

Korthikanti 2022 · Megatron selective activation recompute

5 API 형태 PyTorch

from torch.utils.checkpoint import checkpoint

def block(x):
  return layer2(layer1(x))

y = checkpoint(block, x,
     use_reentrant=False)
  • use_reentrant=False: autograd 2세대, 권장
  • RNG state 저장·복원으로 dropout 재현

6 수치 재현성 RNG

  • dropout mask가 fwd·recompute 달라지면 grad 틀림
  • preserve_rng_state=True 기본값
  • Megatron: CUDA RNG tracker로 tensor parallel RNG 분리

7 Trade-off 표

modememcompute
none1.01.0
selective~0.4~1.08
full (per-block)~0.15~1.33
offload act~0.05+PCIe

8 Offload activation

  • GPU → CPU pinned으로 activation 이동
  • backward에서 prefetch
  • compute·comm overlap이 되면 거의 공짜

GPU HBM 해제·CPU DRAM 사용 — PCIe BW 여유 필요

9 FA와의 상호작용

  • FlashAttention은 이미 O(N) activation (softmax statistics만 저장)
  • attention block checkpoint는 추가 recompute (↗ V07)
  • QKV proj·FFN block에 집중 적용이 ROI 높음

10 흔한 실수

  • in-place op 내부에 checkpoint → recompute 후 상태 손상
  • non-deterministic kernel (e.g. atomicAdd) → grad mismatch
  • training loop 밖 RNG 변경 → dropout divergence

1 문제 상황 padding 낭비

정의 길이가 다른 sample을 batch로 묶을 때 pad token으로 max_len에 맞추면, 짧은 sample의 FLOP/BW가 낭비됨.
  • 평균/최대 길이 비 ≤ 0.5 인 dataset 흔함
  • pad는 loss masking으로 무시되나 compute는 수행됨

2 Packing 정의 ★

정의 여러 짧은 sample을 하나의 긴 sequence로 연결하고, attention에서만 sample boundary를 mask로 격리.
raw: [s1: 123 ][s2:12][s3:1234]
pack: [123 12 1234]  seq=9
bdry:  0    3  5     (누적 시작)

3 cu_seqlens 자료구조

정의 cumulative sequence lengths: 누적 길이 벡터. packed batch의 각 sample 범위를 표현.
cu_seqlens = [0, 3, 5, 9]
max_seqlen  = 4
total_tok   = 9

FlashAttention varlen API 입력 ↗ V07 §12

4 Attention mask 구성 ★

  s1 s2 s3
s1 ██ ░░ ░░
s2 ░░ ██ ░░
s3 ░░ ░░ ██   (block-diagonal)
  • 각 block 내 causal mask 추가
  • dense mask(N²) 대신 cu_seqlens로 구현
  • FA varlen: block-level skip (↗ V07)

5 FLOP 절약 공식

FLOP_pad = B · S_max² · d
FLOP_pack = Σᵢ sᵢ² · d
saving = 1 − (Σsᵢ²) / (B · S_max²) 길이 분포가 고르면 saving = 1 − 1/B (배치 수만큼 이익)
B=16·S_max=4k·avg=1k → pad 대비 pack ~75% 절약

6 Position ID 재설정

  • packed seq 전체에 0~N-1 쓰면 RoPE 엇나감
  • 각 sample 내에서 0부터 다시 시작
  • position_ids = [0,1,2, 0,1, 0,1,2,3]

7 Loss mask 주의

  • SFT: prompt 영역 loss=0, completion만 loss
  • packing 시 sample 경계에서 shift 조심
  • attention mask ≠ loss mask (둘 다 필요)

8 Curriculum & bin packing

  • 길이 bucket → 비슷한 길이끼리 배치 (bucket batching)
  • greedy·first-fit decreasing으로 bin에 채움
  • throughput 평탄화·OOM 예방

9 구현 체크리스트

요소필요
cu_seqlens
max_seqlen
position_ids per-sample
label shift·ignore_index
FA varlen flag

10 흔한 실수

  • packing 시 RNG seed 동일 → 재배치 없음, 고정 셔플
  • cross-sample leak: mask 누락 → sample 간 attention 오염
  • varlen + padding 혼재: 한쪽 통일

1 5 axis 정리 ★ 분할 대상 DP·TP·PP·EP·CP

axis대상통신
DPbatch 복제AR grad
TPhidden 분할AR per layer
PPlayer 분할P2P activation
EPMoE expertall-to-all
CPsequence 분할KV rotate

상세 통신식 ↗ V15 §3

2 Megatron TP 구조

FFN: y = GeLU(x·W1)·W2
  W1 : col-parallel (split N)
  W2 : row-parallel (split K)
  마지막 all-reduce (sum)

QKV proj: col-parallel (head 분할)
Attn out: row-parallel + all-reduce

1 layer당 fwd 1× AR · bwd 1× AR

3 조합 기하 ★

world_size N_total = TP · PP · DP ( · EP · CP)
각 rank의 3D 좌표: (t, p, d) · t∈[TP], p∈[PP], d∈[DP] 직교 축으로 process group 생성

4 통신 패턴

axisbytes/step·rank
TPO(L · B · S · H)
PPO(M · B · S · H / P)
DPO(P_param/N_dp)

B:batch, S:seq, H:hidden, L:layer, M:microbatch, P:PP stages

5 TP NVLink 제약 ★

TP all-reduce는 매 layer 2회. NVLink 없으면 느림.
TP ≤ 8 (intra-node) 권장.
  • A100·H100 single node: NVLink 8-way full mesh
  • cross-node TP는 InfiniBand BW 제약

6 실전 구성표 ★

sizeTPPPDPtotal GPU
7B1188
13B211632
30B4216128
70B8416512
400B816162048

TP=intra-node · PP=inter-node · DP 가장 바깥

7 SP (Sequence Parallelism)

  • TP 안에서 LayerNorm·Dropout의 seq 축을 분할
  • all-gather·reduce-scatter가 TP all-reduce를 대체 (동일 bytes)
  • activation 메모리 추가 절감
  • Korthikanti 2022

8 축 순서 관습

inner  TP  (NVLink)
         │
         PP  (node-to-node P2P)
         │
outer  DP  (ZeRO/FSDP 가능)

TP는 가장 fast fabric, DP는 가장 slow에 배치

1 Bubble ratio 공식 ★

bubble = (P − 1) / M
efficiency = M / (M + P − 1) P: stage 수 · M: microbatch 수
  • M이 크면 bubble ↓ (파이프라인 채움)
  • M 제한: activation 메모리 (각 micro별)

2 GPipe 기본 all-fwd-then-bwd

P=4, M=4
t: 0 1 2 3 4 5 6 7 8 9 A B
G0 F F F F       B B B B
G1   F F F F   B B B B
G2     F F F F B B B B
G3       F F F F B B B B
     └─fwd warm─┘└─bwd tail─┘
bubble = 2(P-1)/M

∴ activation이 모든 micro 저장 → 메모리 많음

3 1F1B schedule ★ Megatron

P=4, M=8, 1F1B:
t: 0 1 2 3 4 5 6 7 8 9 A B C D
G0 F F F F B F B F B F B B B B
G1   F F F B F B F B F B F B B
G2     F F F B F B F B F B F B
G3       F B F B F B F B F B F
 warm-up    steady (1F1B)    drain
bubble = (P-1)/M (GPipe의 절반)
  • 정상 구간: 각 stage가 F,B를 교대
  • activation 보존 ≤ P (stage 수)

4 Interleaved (virtual pipeline)

P=4, v=2 (virtual stages/GPU), M=8:
G0: [s0]...[s4]  F F F F B B B B (교차)
G1: [s1]...[s5]  ...
bubble = (P-1) / (v·M)
  • 각 GPU가 v개 stage 담당
  • bubble v배 ↓ · P2P 통신 v배 ↑
  • Narayanan 2021 (Megatron)

5 Zero-Bubble ★ 최신

정의 backward를 B(activation grad) + W(weight grad)로 분리, W를 bubble 구간에 배치해 bubble → 0.
1F1B-BW:
G0 F F F F B W B W B W B W W W
           └─ B·W 교차 ─┘
Qi, Zero Bubble Pipeline (2023)

bubble → 근사 0 · 구현 복잡도 ↑

6 3종 비교표

schedulebubbleact mem구현
GPipe2(P−1)/MO(M)단순
1F1B(P−1)/MO(P)중간
Interleaved(P−1)/(v·M)O(P·v)복잡
Zero-Bubble~0O(P)매우복잡

7 선택 결정

  • M ≥ 4·P 가능 → 1F1B 충분
  • M 제약 있고 P 큼 → interleaved
  • 극한 throughput·구현 여유 → zero-bubble

1 원리 CUDA stream

정의 compute streamcomm stream을 별도로 두고, CUDA event로 의존성만 선언. NCCL kernel과 GEMM이 동시에 실행.
  • SM 자원 분할: NCCL은 일부 SM만 사용
  • HBM BW 공유가 병목이 됨

2 Bucketed grad reduce ★

정의 backward 중 grad가 준비되는 대로 bucket(예: 25MB)에 모아 즉시 RS/AR 시작.
bwd:  G_L1 G_L2 G_L3 G_L4 ...
              └bkt┘└bkt┘
              RS starts asap (overlap next B)

DDP·FSDP 기본 동작

3 Parameter prefetch

  • 다음 layer의 AG를 현재 compute와 겹침
  • FSDP: BACKWARD_PRE (↗ §5)
  • pipeline depth = prefetch 횟수

4 no_sync() ★

정의 DDP/FSDP의 grad sync(AR/RS)를 skip하는 context manager. 누적 후 마지막 step에만 sync.
for i, batch in enumerate(loader):
  if (i+1) % accum != 0:
    with model.no_sync():
      loss = model(batch)
      loss.backward()
  else:
    loss = model(batch)
    loss.backward()  # sync
    optimizer.step()

5 FSDP fine-grained overlap

phaseoverlap 대상
fwd L_iAG L_{i+1}
bwd L_iAG L_{i-1}
bwd L_iRS L_{i+1}

3중 overlap이 가능하나 구현이 복잡

6 TP overlap Megatron

  • SP all-gather를 다음 GEMM과 overlap
  • UserBuffers / CUDA graphs로 kernel launch 비용 감소
  • TE 모듈에 "ub_overlap_ag" flag

7 PP P2P overlap

  • P2P send/recv은 NCCL sendrecv group
  • compute가 길면 자연 은폐
  • short stage → P2P bound 가능

8 진단 포인트

증상원인 후보
NCCL stream 빈 공간bucket 너무 큼 → 분할
compute stream gapAG 대기 → prefetch 활성
HBM BW 포화bucket·GEMM 충돌 → 크기 조정
per-step variance ↑NCCL 순서 불안정

Nsight Systems trace 기준 ↗ V18 §12

9 흔한 실수

  • bucket 크기 과소: NCCL 호출 횟수 ↑ · BW ↓
  • single stream: NCCL이 compute 차단
  • event 누락: grad race → NaN

1 Adam 업데이트 식 복습

m_t = β₁ m_{t-1} + (1−β₁) g_t
v_t = β₂ v_{t-1} + (1−β₂) g_t²
m̂ = m_t / (1−β₁ᵗ) · v̂ = v_t / (1−β₂ᵗ)
w_t = w_{t-1} − η · m̂ / (√v̂ + ε) m·v: 1st·2nd moment · β₁≈0.9, β₂≈0.95~0.999

2 Adam vs AdamW weight decay

방식decay 적용
Adam (L2)grad에 λw 추가 후 Adam
AdamWupdate 후 w ← w(1−ηλ)
  • Loshchilov 2019 · LLM 사실상 표준은 AdamW
  • decay가 m/v 통해 왜곡되지 않음

3 메모리 B/param 비교

optimizerB/p (FP32 state)
SGD0
SGD + momentum4
Adam/AdamW8 (m+v)
Adafactor~2
LAMB8
Lion4 (momentum만)

master copy 4B/p 별도 추가

4 Adafactor ★ T5 학습

정의 v_t를 full matrix로 저장 대신 row sum r·col sum c로 rank-1 근사: v ≈ (r · cᵀ) / (1ᵀc).
  • matrix 한 축이 크면 메모리 O(m+n) vs O(mn)
  • 수렴 속도는 Adam 대비 약간 느릴 수 있음
  • Shazeer 2018

5 8-bit Adam ★ bitsandbytes

정의 m·v를 block-wise quantization으로 8-bit 저장, update 순간에만 FP32로 dequant.
statebaseline8-bit
m4B1B + scale
v4B1B + scale
총 state8B~2B

Dettmers 2022 · 수렴 품질 Adam과 동등

6 Lion 최신 대안

  • sign-based update: w ← w − η·sign(m̃)
  • state: momentum만 → 4B/p
  • LR·WD 튜닝 민감도 높음
  • Chen 2023 (EvoLved Sign Momentum)

7 품질-메모리 trade ★

optmem품질권장 맥락
AdamW12B표준대부분
Adafactor6B약간 ↓초대형·T5
8-bit AdamW6B≈ AdamW단일 GPU
Lion8B혼재비전·실험
SGD-M8B비전 OKLLM 비권장

master 4B 포함 값 · LLM pretrain 기준

8 Fused optimizer kernel

  • Apex FusedAdam·PyTorch foreach·fused=True
  • element-wise update를 단일 kernel
  • CUDA graph 친화·launch overhead ↓

9 흔한 실수

  • β₂ 기본값 0.999 고수: LLM에서 0.95가 통상 안정
  • ε=1e-8 FP16에서 0: 1e-6 권장
  • Adam + L2 대신 AdamW 혼동

1 왜 loss scaling FP16

정의 FP16 representable 최소 ≈ 6·10⁻⁵. gradient 대부분이 이보다 작아 underflow (→ 0). loss × S로 scale 올려 grad 보존.
  • backward 후 grad / S 로 복원 (unscale)
  • BF16은 range가 FP32와 같아 scaling 불필요
  • 수치 상세 ↗ V09 §4

2 Dynamic loss scale ★ up·skip·down

정의 inf/NaN 발생 시 step skip + S ← S/2. 2000 step 안정이면 S ← S·2.
state:
  S (scale factor)   init ~2^16
  growth_interval    e.g. 2000
  n_ok               success 카운트

per step:
  if any(grad = ±inf, NaN):
      skip step
      S ← S / 2; n_ok ← 0
  else:
      apply step
      n_ok ← n_ok + 1
      if n_ok == growth_interval:
          S ← S · 2; n_ok ← 0

3 GradScaler API

scaler = GradScaler()
for x,y in loader:
  with autocast(dtype=FP16):
    loss = model(x,y)
  scaler.scale(loss).backward()
  scaler.unscale_(opt)       # for clip
  clip_grad_norm_(...)
  scaler.step(opt)           # skip if inf
  scaler.update()            # adjust S

4 FP8 formats ★ H100+

formatE/M용도
E4M34/3activation·weight
E5M25/2gradient (range↑)

Micikevicius 2022 · IEEE binary8 candidate

5 FP8 scaling 전략

  • per-tensor scale factor (delayed scaling)
  • amax history 유지 → next-step scale 예측
  • TE가 amax·scale graph로 관리

6 Transformer Engine path ★

x (BF16) ─cast→ E4M3 ─┐
W (BF16) ─cast→ E4M3 ─┤ WGMMA (FP8×FP8)
                       └→ FP32 accum → BF16
grad (BF16) ─cast→ E5M2 (더 큰 range)
amax history → per-tensor scale
  • TE module이 자동 관리
  • 수치 상세 ↗ V09 §6

7 FP8 safe layer 경계

layerFP8?
QKV/FFN GEMM
LayerNormBF16 유지
softmax·attentionFP16/BF16
logit·lossFP32
embedding tableBF16 (보수적)

8 흔한 실수

  • unscale 전 clip: norm이 S배 왜곡
  • skip시 step 집계: global_step 증가 금지
  • FP8 amax 미동기화: DP 간 scale 불일치 → NaN

1 목표

정의 GPU가 매 step 입력 기다리지 않도록 prefetch · parallel decode · H2D overlap. 대상은 iostream·CPU preproc·H2D의 3병목.

2 PyTorch DataLoader

param의미
num_workersworker process 수
prefetch_factorworker당 미리 버퍼링
pin_memorypinned → H2D 빠름
persistent_workersepoch 간 재사용
collate_fnbatch 조립

3 Sharding 분산 학습

  • DistributedSampler: rank별 index 분할
  • 각 rank는 dataset의 1/N만 읽음
  • epoch마다 seed 동기화로 중복 방지
  • iterable dataset: webdataset·streaming

4 Shuffle 전략

방식특성
index shufflemap-style 전용
buffer shuffleiterable·window N
tar-shufflewebdataset tar 단위

global shuffle은 많은 메모리 필요 → buffer approximation

5 저장 포맷 ★

포맷장점단점
raw image/jpg호환성decode CPU
tfrecordseq read 빠름검색 ↓
webdataset (tar)streaming·shardrandom X
arrow/parquetcolumnar·효율image 불리
mosaic (mds)deterministic resumetooling 의존

6 GPU-side decoding ★

  • NVJPEG·NVIMGCODEC: JPEG를 GPU에서 decode
  • DALI pipeline: decode + augmentation 모두 GPU
  • CPU preproc bound 모델에서 2~5× 처리량

7 Text pretrain loader 특수 LLM

  • token-level sharding: fixed-length chunk
  • index file: .bin + .idx (Megatron)
  • seq boundary preserve 여부 선택
  • packing 전제 (↗ §8)

8 Resumable sampler

정의 state_dict에 sampler의 current index + epoch + seed를 기록. resume 시 중복/스킵 없이 재개.
  • DistributedSampler는 set_epoch 필수
  • streaming은 별도 offset tracking
  • ↗ §15 checkpoint

9 진단 지표

signal의미
data-wait % > 5%loader 부족
CPU ≈ 100%decode bound
IO wait저장매체 bound
variance ↑buffer 부족

10 흔한 실수

  • num_workers=0: GPU 유휴
  • pin_memory 미사용: H2D 대역폭 절반
  • augmentation CPU fork 비용: 과한 worker 역효과

1 저장 대상 resume 정확성

항목필수
model state_dict
optimizer state
lr scheduler
grad scalerFP16 시
RNG state (cpu/cuda)
dataloader sampler
global_step

2 저장 형태 ★

mode특성
rank-0 consolidated단일 파일·느림·OOM
sharded (per-rank)빠름·world 의존
DCP (torch)rank-independent load

torch.distributed.checkpoint 2.2+

3 파일 레이아웃 예

ckpt-1000/
├ model/
│ ├ __0_0.distcp
│ ├ __1_0.distcp
│ └ ...
├ optim/
├ meta.json   (world topology)
└ sampler_state.pt

4 Sharded save 장점

  • 각 rank가 병렬로 local shard만 dump
  • write BW ∝ N (node 수)
  • rank-0 bottleneck 없음
  • load도 shard-aware로 수신

5 Consolidation 배포용

정의 학습은 sharded, 배포·평가 시에만 단일 파일로 모음. DCP에는 별도 consolidate script 존재.
  • FSDP: StateDictType.FULL_STATE_DICT
  • rank-0에 모여 save
  • 메모리 여유 필요 (full param)

6 Reshard on load ★

정의 학습 world 구성이 바뀌어도 load 가능해야 함. DCP는 topology-agnostic metadata로 자동 reshape.
save: TP=8, PP=4, DP=16
load: TP=4, PP=4, DP=32   ← OK
load: TP=8, PP=4, DP=16   ← 동일
사용: resume시 node 수 변경

7 Async save CUDA stream

  • D2H copy를 비동기로 시작 → 학습 지속
  • CPU-side에서 disk write 병렬
  • save latency 감춤 → 빈번 저장 가능
  • Mosaic·DeepSpeed·Megatron 구현 존재

8 Save interval heuristic

요인권장
MTBF 예상interval ≤ MTBF/2
save cost≤ 1% wall-time
disk 용량rotation 3~5

interval ↑ → 손실 확률 ↑ · IO ↓

9 Resume 정확성 체크

  1. loss curve 연속성 (점프 없음)
  2. grad norm 분포 유사
  3. data iteration index 일치
  4. RNG 복원으로 dropout 재현

10 흔한 실수

  • sampler state 미저장: resume 후 data 중복
  • RNG state 누락: dropout 달라짐
  • rank-0 save만: world 변경 불가

1 실패 모드 대규모

유형빈도
GPU Xid error주기적
NVLink/IB link down드묾
host crash드묾
NCCL timeout네트워크 jitter
straggler (slow)자주

1000 GPU·30일 학습 시 수십건 통상

2 MTBF와 interval

E[useful] = MTBF − (interval/2) − t_save − t_resume
optimal_interval ≈ √(2·MTBF·t_save) Young-Daly 근사 · save가 비쌀수록 interval 길게

3 Graceful detect

  • NCCL async error → C10d exception
  • timeout 설정: NCCL_TIMEOUT (default 30min)
  • CUDA Xid log 파싱 (supervisor)
  • straggler detect: step time quantile

4 torchrun elastic ★

정의 rendezvous 서버가 살아있는 worker 세트를 재합의 → failed rank 제거·replacement 수용. auto reshard 후 resume.
  • --max-restarts, --rdzv-backend=c10d
  • memgroup·global_step이 rdzv metadata
  • world size 변경 가능 (±node)

5 Elastic training 구성

[rdzv server]
     │
┌────┴────┬────┬────┐
worker  worker  ...
  │        │
 fail!
     ↓
rdzv re-bargain (N→N-1 or +spare)
     ↓
load last ckpt (reshard topology)
     ↓
resume

6 Hot spare 전략

  • spare node 1~2대 대기
  • failure 시 rendezvous에 join
  • 대규모 scheduler (SLURM/k8s) 레벨 선택

7 Deterministic replay

  • RNG + sampler state 복원 (↗ §15)
  • NCCL non-determinism: NCCL_ALGO=Ring 고정
  • CuDNN deterministic flag
  • 단, FA·reduce-scatter는 non-associative
완벽한 bit-exact replay는 비현실적 — loss curve 근접을 목표.

8 Save interval tuning

scale권장 interval
8 GPU~1h
128 GPU~30m
1024 GPU~10~20m
10k GPU<10m (async 필수)

∝ 1/√MTBF · 실측이 최종

9 흔한 실수

  • rdzv 단일 실패점: etcd/consul HA 필요
  • NCCL_TIMEOUT 기본 30m: 대규모에 짧음
  • sampler state 없음: resume 데이터 중복
  • straggler 방치: throughput 전체 ↓

1 7B 권장 ★ 단일 node

항목
HW8 × A100/H100
ParallelismDP=8 + FSDP(FULL_SHARD)
TP/PP1/1
Mixed precisionBF16 / FP32 master
OptimizerAdamW fused
Activationselective ckpt
Seq packing권장

2 70B 권장 ★ 중규모

항목
HW64~512 × H100
TP8 (intra-node NVLink)
PP4~8
DP2~16 (+ZeRO-1)
SPon (activation ↓)
Schedule1F1B interleaved v=2
Activationfull ckpt·attention만 off

3 400B 권장 ★ 초대형

항목
HW2k~16k × H100/B200
TP8
PP16 (zero-bubble)
DP16~128 + ZeRO-1
CPlong seq 시 on
EPMoE 시 8~64
FP8TE path
Async ckpt필수

4 Decision tree

start
 └ model ≤ 10B? → FSDP (FULL_SHARD) only
     │
     ▽ no
 └ intra-node NVLink? → TP=8
     │
     ▽ yes
 └ model ≤ 100B? → TP8 + PP4 + DP
     │
     ▽ no
 └ TP8 + PP≥8 + ZeRO-1 + CP/EP 조합

5 흔한 안티패턴

  • TP cross-node: NVLink 없으면 금지
  • PP + ZeRO-3: AG가 stage 간 얽힘 → ZeRO-1만
  • small DP + huge PP: DP grad AR bubble

6 메모리 diagnosis ★

OOM 원인조치
model stateZeRO stage ↑ · TP 도입
activationckpt selective → full
peak full paramunit wrap 작게
kv cache학습엔 없음 (↗ V16)
optimizer state8-bit·offload

7 Throughput diagnosis

증상조치
data waitloader worker/prefetch ↑
comm boundoverlap·bucket tuning (↗ §11)
pipeline bubbleM 증가·interleaved·zero-bubble
HBM satselective ckpt·FP8

8 교차 참조 요약

  • NCCL primitive·대역폭 ↗ V15
  • KV cache·inference serving ↗ V16
  • FP8·loss scaling 수치 ↗ V09
  • FA varlen 상세 ↗ V07
  • Profiling·stall ↗ V18

9 최종 체크리스트

  1. model state 분산 전략 확정 (ZeRO/TP/PP)
  2. activation 전략 (ckpt level·SP)
  3. precision 경로 (BF16·FP8·loss scale)
  4. comm-compute overlap 활성
  5. checkpoint async·interval 설정
  6. elastic rdzv·straggler 감시