gpumode · 강의 아카이브
《GPU Mode》 L039 2024 · DEC High priority transcript · available

Torchtitan — PyTorch native distributed training

Megatron-LM 까지의 분산 학습 코드는 “그 framework 안에서 학습 코드를 다시 짜는” 부담을 강요했다. Torchtitan 의 답 — PyTorch native API 만으로 FSDP / Tensor Parallel / Sequence Parallel / Pipeline Parallel / 4D 까지 한 stack 으로. Mark Saroufim 과 Tianyu Liu 가 깐 PyTorch 분산 학습의 새로운 형태 — DTensor 가 어떻게 “여러 GPU 의 한 tensor” 를 표현하고, 통신과 계산이 어떻게 overlap 되며, checkpoint 를 어떻게 sharded 로 떠놓는지.

FSDP DTensor Tensor Parallel Sequence Parallel activation checkpointing comm overlap checkpoint sharding torch.compile
M
Speaker
Mark Saroufim
Meta · PyTorch core · GPU Mode 운영진
T
Speaker
Tianyu Liu
Meta · PyTorch distributed
강의 번호
L039
스피커
Mark Saroufim · Tianyu Liu
학습 우선순위
High · 정독
다시 볼 때
multi-GPU 직접
§ 01강의가 풀려는 문제· PyTorch native distributed

“분산 학습은 PyTorch 가 아니라 framework 가 한다” 의 시대를 끝낸다

2020년대 초까지 분산 학습은 보통 Megatron-LM 이나 DeepSpeed 같은 framework 의 일이었다. PyTorch 위에 layered 인 별도 stack 이 직접 model 을 받아 자기 식으로 분산 — 그러면 사용자가 그 framework 의 model code 를 따라가야 했다. 분산이 어렵다는 인식의 대부분이 거기서 나왔다.

Torchtitan 의 출발점은 단순하다 — “PyTorch 가 직접 분산을 한다.” 사용자는 자기 model 을 PyTorch 로 그대로 짜고, distributed 변환은 한 함수 호출 한 줄로 모델 위에 얹는다. FSDP, TP, SP, PP 가 모두 이 형태로 들어간다.

강의의 인지적 frame

Mark 의 입장 — “코드는 작아야 한다. Torchtitan 은 1000~2000 줄. 가져다가 자기 회사 안에서 자유롭게 fork 해라.” Megatron 처럼 self-contained 한 framework 가 아니라 학습 시작점으로의 reference. 이게 의도된 설계 — code base 차원의 결정.

“코드를 plagiarize 하라. 자기 회사로 가져가서 자기 식으로 고쳐라. 그게 우리가 원하는 사용 방식이다.”Mark Saroufim · 강의 도입부

그래서 강의 끝에 손에 잡혀야 할 자산 — (1) 5 가지 분산 차원 (DP / FSDP / TP / SP / PP) 이 어떻게 한 stack 위에서 결합되는가, (2) DTensor 가 “한 tensor 가 여러 GPU 에 나뉘어 있다” 를 어떻게 표현하는가, (3) 통신 overlap 이 forward / backward 양쪽에서 어떻게 일어나는가, (4) checkpoint 가 분산 환경에서 어떻게 sharded 로 떨어지는가.

§ 02분산 학습 코드의 표준화· 왜 새 stack 인가

DDP → FSDP → DTensor — PyTorch native 의 진화

2018 · DDP DistributedDataParallel모델 전체를 GPU 마다 복제. gradient 만 all-reduce. 같은 모델, 다른 데이터. data parallel
2022 · FSDP FullyShardedDataParallel모델 weight 도 GPU 마다 자른다. forward 직전 all-gather, backward 직후 reduce-scatter. ZeRO-3 의 PyTorch 구현. 큰 모델 · 한 GPU 안 못 들어감
2024 · DTensor distributed tensor 추상tensor 한 개가 “이 device mesh 위에서 sharded” 라는 metadata 를 가진다. TP / SP / FSDP 모두 이 위에서 표현. 통일 추상
2024 · Torchtitan DTensor 를 모델에 적용한 referenceLlama 류 모델에 한 줄로 FSDP+TP+SP 적용. 1500 줄 안. 최종 형태

강의에서 Tianyu 가 짚은 흐름 — “DDP 는 ‘모델이 한 GPU 에 들어간다’ 의 가정. FSDP 는 ‘들어가지 않는다’ 의 답. 그런데 FSDP 만으로는 큰 모델의 단일 layer 도 한 GPU 에 못 들어가는 경우가 있어 — 그래서 layer 안 weight 도 자르는 TP 가 필요. 이게 다 같은 추상으로 통일된 게 DTensor.”

코드 양 비교

Megatron-LM 은 자체 model 구현 + 분산 코드 합쳐 ~50k 줄. Torchtitan 은 ~1500~2000 줄. 차이의 본질은 “PyTorch 위에 얹는가, PyTorch 가 직접 하는가”. PyTorch 의 distributed primitive 가 충분히 성숙해진 시점이라 이 minimalism 이 가능.

§ 03FSDP · TP · SP 결합· 2D / 3D / 4D parallel

같은 모델을 5 차원으로 자른다 — DP × FSDP × TP × SP × PP

DP
Data Parallel
같은 모델, 다른 batch. gradient all-reduce. 가장 간단.
FSDP
Sharded Data Parallel
weight/optim/grad 모두 sharded. forward 직전 all-gather. 큰 모델 한 GPU 안 못 들어갈 때.
TP
Tensor Parallel
한 layer 안 weight 를 자른다. matmul 결과 all-reduce. node 안 NVLink 위에서.
SP
Sequence Parallel
activation 의 sequence 차원 분할. TP 와 자연스럽게 결합. activation 메모리 절약.
PP
Pipeline Parallel
layer stack 을 stage 로 자른다. node 간 bandwidth 절약. micro-batch 로 stall 채움.

강의에서 흥미로운 결정 — Torchtitan 은 PP 를 명시적으로 “이 강의에서는 안 다룬다” 고 못 박는다. 이유 — pipeline parallel 은 model 코드에 if/else 분기가 들어가는 경향이 있어 “단순한 model 코드 + 한 줄 분산” 의 철학이 깨진다. 다른 4 개는 model 코드 손 안 대고 wrap 한 줄로 적용 가능.

결합 패턴

실전 — FSDP × TP × SP 가 표준. 한 node (8 GPU) 안에서 TP+SP 로 layer 안 weight/activation 자르고, node 간으로 FSDP 로 layer-level sharding. 이게 2D parallel. 더 큰 모델은 PP 까지 더해 3D 또는 4D.

“5개 차원이 다 직교한다. 모델 코드는 그대로 두고, mesh 를 어떻게 자르는가만 바꾼다. 그 결과 같은 model 이 1 GPU 부터 1000 GPU 까지 같은 코드로 동작.”Tianyu Liu
§ 04PyTorch DTensor· 한 tensor, 여러 GPU

device mesh + placement 두 metadata 만으로 분산 표현

DTensor 의 핵심 — 하나의 tensor 가 (1) device mesh (이 GPU들에) + (2) placement (이렇게 자르거나 복제) 라는 두 metadata 를 갖는다. operator 가 이 metadata 를 보고 자동으로 통신을 끼워넣는다.

# DTensor — 한 tensor 가 mesh 위에서 sharded
from torch.distributed.tensor import (
    DeviceMesh, Shard, Replicate, distribute_tensor)

# 8 GPU 를 2x4 mesh 로
mesh = DeviceMesh("cuda", [[0,1,2,3],
                            [4,5,6,7]],
                  mesh_dim_names=("dp", "tp"))

# 같은 weight 를 dp 차원에는 복제, tp 차원으로는 분할
W = distribute_tensor(big_weight, mesh,
                      placements=[Replicate(), Shard(0)])

# 일반 tensor 처럼 쓴다 — operator 가 알아서 통신
y = x @ W                        # 자동으로 all-gather + matmul

이 구조의 의미 — 모든 분산 패턴이 “mesh + placement” 두 metadata 의 다른 조합으로 표현된다. FSDP 의 weight 는 Shard(0), TP 의 weight 는 Shard(0) on different mesh dim, DDP 의 weight 는 Replicate().

  • Shard(dim) — tensor 의 dim 차원을 mesh dim 위에 균등 분할.
  • Replicate() — mesh dim 의 모든 device 가 같은 사본.
  • Partial(reduce_op) — 부분합. 다음 op 에서 reduce 가 필요.

operator 가 placement 를 보고 적절한 collective (all-reduce / all-gather / reduce-scatter) 를 자동으로 끼워넣는다. 사용자는 “이 weight 가 어떻게 분산됐는지” 만 선언하고 통신 코드를 안 짠다.

redistribute — 한 표현에서 다른 표현으로

x.redistribute(placements=[...]) 가 placement 를 바꾼다. 예 — Shard(seq) 인 activation 을 Replicate() 로 바꾸면 all-gather. Partial 을 Replicate 으로 바꾸면 all-reduce. collective 가 redistribute 라는 한 함수로 통일.

§ 05checkpoint sharding· DCP — distributed checkpoint

1000 GPU 학습의 checkpoint 를 어떻게 떠놓고 다시 읽는가

분산 학습의 잘 안 보이는 어려운 문제 — checkpoint. 1000 GPU 가 sharded 로 weight 를 나눠 가지고 있을 때, 이걸 disk 에 어떻게 적고, 다시 읽을 때 다른 분산 형태(예: 다른 GPU 수, 다른 mesh)로 어떻게 reshape 하는가.

PyTorch 의 답 — DCP (Distributed Checkpoint). 두 개의 design 결정.

resharding 의 의미

예 — 1000 GPU 에서 학습 중간에 200 GPU 로 옮긴다. 이전 checkpoint 는 1000 rank shard. DCP 가 metadata 를 보고 새 200 rank 의 mesh 에 맞게 자동 재분할. weight 의 의미는 변하지 않으므로 학습이 그대로 이어진다. 이게 framework 차원에서 “계속 동작” 하는 게 production 의 핵심.

“checkpoint 가 분산 학습의 가장 무시당하는 hard problem. 1000 GPU 가 한 파일에 적으려 하면 그 자체가 학습보다 오래 걸린다.”학습 노트
§ 06활성화 재계산· selective AC

activation 메모리를 줄이려고 일부 layer 를 backward 에서 다시 forward

큰 모델의 학습은 메모리 부족이 거의 항상. activation 이 forward 동안 쌓여서 backward 에 쓰이는데, 이걸 다 보존하면 메모리 limit 을 넘긴다. 표준 답 — activation checkpointing (re-compute).

Torchtitan 이 추가한 것 — selective AC. 모든 layer 를 다 re-compute 하면 forward time 의 30% 정도 추가. 하지만 layer 마다 “저장 비용 / re-compute 비용” trade-off 가 다르다 — softmax 처럼 cheap 한 op 는 저장 안 하고, MLP 의 GEMM 결과처럼 expensive 한 op 는 저장.

두 모드

(a) full AC — layer 전체 re-compute. 메모리 절약 best, latency cost +30%. (b) selective AC — op 별 정책. “이 op 의 save 가 그 op 의 re-compute 보다 비싸면 re-compute”. Torchtitan 의 default 는 selective.

강의에서 Mark 가 짚은 — “selective AC 의 정책 list 는 hard-coded 가 아니라 offline 분석 결과. 일반적으로 잘 작동.” 즉 어떤 op 을 저장하고 어떤 op 을 re-compute 할지 미리 결정된 표가 있고, 사용자가 평소엔 안 건드린다.

§ 07통신 오버랩· async TP · compute hide

matmul 을 부분으로 쪼개서, 한 부분 결과 통신과 다음 부분 계산을 동시에

분산 학습의 두 번째 큰 비용 — collective communication (all-reduce, all-gather). NVLink/IB 가 빨라도 모델이 클수록 통신 시간이 꽉 찬다. 답은 계산과 통신의 overlap.

FIG · TP matmul 의 split + overlapidealized
naive
matmul (전체)
all-reduce
async TP
matmul A
matmul B
comm A
matmul C
comm B
comm C
matmul 을 N 조각으로 나누면, 한 조각 결과의 all-reduce 가 다음 조각의 matmul 과 겹친다. 통신 시간이 거의 사라짐.

강의에서 Tianyu 가 짚은 사실 — async TP 는 PyTorch 안에 일급 기능으로 들어와 있다. 사용자가 model 코드를 안 바꿔도 parallelize_module(...) 호출에서 async 옵션만 켜면 자동으로 분할 + overlap.

FSDP overlap

FSDP 도 비슷한 overlap 이 있다. forward N+1 layer 의 weight all-gather 를 forward N layer 의 compute 와 동시에. backward 는 그 반대 — backward N layer 의 grad reduce-scatter 를 backward N-1 layer 의 compute 와. 전부 PyTorch 내부에서 자동.

§ 08Megatron 과 비교· code base 차원

같은 일을 “하나의 framework 으로” vs “PyTorch native 로”

Megatron-LMTorchtitan의미
code base 크기~50k 줄~1.5k 줄철학 차이
model 정의framework 안 자체PyTorch 그대로대체 가능
분산 추상자체 구현DTensorPyTorch native
torch.compile제한기본 지원fusion
checkpoint자체 formatDCPcross-job 호환
기능 커버리지매우 넓음표준 시나리오trade-off
사용자가 fork 하기큰 frameworkcopy-paste 쉬움의도된 design

정확한 입장 — Megatron 이 사라지지 않는다. 매우 큰 모델의 깊은 최적화는 Megatron 이 여전히 유리한 영역. Torchtitan 은 “회사가 자기 model 코드를 그대로 쓰면서 분산을 하고 싶을 때”의 선택. 두 stack 이 다른 자리.

“Megatron 은 좋은 framework. 그런데 사용자가 그 framework 의 model 정의를 따라가야 한다. PyTorch native 의 의미는 그 부담이 없다는 것.”Mark Saroufim
§ 09채택 사례· production hooks

Meta 안팎에서 누가 어떻게 쓰는가

강의 시점 기준의 채택 풍경. NDA 로 자세한 회사 이름은 막혀 있지만 패턴은 분명. Llama 류 dense transformer 의 학습이 가장 자연스러운 fit.

실제 모델 종류 한계

강의 시점에 Torchtitan 이 강하게 지원하는 건 dense transformer (Llama 류). MoE 는 진행형. multimodal (vision + text) 도 부분적. 일반 model 을 넘어선 시나리오는 추가 작업이 필요. 이게 “회사가 자기 model 을 fork 해서 쓴다” 의 의미가 더 강해지는 자리.

§ 10기억할 메모와 코드· repo · paper

다시 열었을 때 5분 안에 손에 잡혀야 할 것

PyTorch native
framework 위 framework 가 아니라 PyTorch 자체로 분산. 모델 코드는 그대로.
DTensor
device mesh + placement 두 metadata 가 분산 표현의 1급 추상.
5 차원 분산
DP / FSDP / TP / SP / PP. PP 만 model 코드 분기 필요. 나머지 4 는 직교.
async TP
matmul 을 N 조각 + collective overlap. 통신 시간 거의 사라짐.
selective AC
activation checkpointing 의 op 별 정책. cheap 은 저장, expensive 는 re-compute.
DCP
distributed checkpoint. rank 별 파일 + metadata. resharding 가능.
code 양
~1500~2000 줄. fork 가 의도된 design. self-contained framework 안 됨.
Megatron 과의 위치
Megatron 안 사라짐. Torchtitan 은 “PyTorch 그대로” 가 우선인 사용자 자리.

손에 새기기 — 실습 시퀀스

  1. repo clone + 1 GPU 로 toy run — README 의 train.sh 한 줄. Llama 2 7B 의 초기 학습 step 이 single GPU 에서 도는지.
  2. 2 GPU 로 FSDP — 같은 명령에 --training.data_parallel_replicate_degree 1 --training.data_parallel_shard_degree 2. 메모리 사용량이 절반 가까이 줄어드는지.
  3. 4 GPU 로 FSDP×TP — TP=2, FSDP=2 의 2D mesh. matmul 안의 weight 가 어떻게 sharded 되는지 DTensor metadata 직접 print.
  4. DTensor 직접 만들기 — 작은 tensor 를 distribute_tensor 로 분산하고 redistribute 로 placement 변경. 어떤 collective 가 호출됐는지 NCCL trace 로 확인.
  5. checkpoint 떠보기 — 학습 중간에 DCP 로 ckpt. 그 ckpt 를 다른 GPU 수 (예: TP=4, FSDP=2) 로 다시 read. 학습이 이어지는지.
  6. activation checkpointing 모드 비교 — full / selective / off 세 모드의 메모리 사용량 + step 시간. selective 가 sweet spot 인지 직접 확인.
  7. async TP overlap 측정 — 같은 모델, async TP on/off. Chrome trace (L001) 위에서 collective 와 GEMM 이 실제로 겹치는지.
§ 11다른 강의로 이어지는 길· connections

Torchtitan 이 시리즈 안 다른 강의로 이어지는 자리

§ 12열린 질문· open questions

다음에 다시 들었을 때 직접 검증해야 할 것들

검증 메모

Torchtitan 은 빠르게 갱신 중. code 양 (1500~2000 줄), 채택 회사 수, 지원 모델 등은 강의 시점 기준. 자기 시점의 release 노트로 갱신 확인 필요.

← Lecture 038 Low Bit ARM kernels Lecture 040 → CUDA Docs for Humans — Charles Frye 가 깐 CUDA 의 task-oriented re-organization