gpumode · 강의 아카이브
《GPU Mode》 L013 2024 · MAR · 30 High priority transcript · available

Ring Attention

한 GPU 의 HBM 만으로는 100M token context 를 한 모델에 못 넣는다 — 1000 GB 를 요구하니까. Ring Attention 은 이 압박을 풀려고 sequence 차원을 device 들 사이에 분산하고, K/V 를 ring 토폴로지 위에서 회전시키면서 — online softmax 를 그대로 multi-device 로 확장한다. Andreas Köpf 가 cuda-mode/ring-attention 의 구현을 깐다.

Ring Attention sequence parallelism online softmax · 분산 K/V rotation communication overlap NCCL all-to-all blockwise parallel transformer long-context training log-sum-exp
A
Speaker
Andreas Köpf
GPU Mode 공동창립자 · OpenAssistant · LLM 분산 학습
강의 번호
L013
스피커
Andreas Köpf
학습 우선순위
High · 정독
Slides
repo · ring_attention.pptx
§ 01강의가 풀려는 문제· why this lecture exists

“100M token 컨텍스트를 한 모델 위에 어떻게 띄울 것인가”

Ring Attention 논문 (Liu, Yan, Abbeel 2023) 이 직접 인용한 수치 — batch 1, seq_len 100M token, modest model 한 개에 1000 GB 이상의 메모리가 필요하다. 단일 H100 (80 GB) 한 장으로 안 되고 — 8 장이라도 부족하다. 이 압박이 ring attention 이 등장한 이유.

강의가 풀려는 질문 두 가지.

  1. seq_len 차원을 여러 device 에 분산하면 attention 의 N×N quadratic 부분이 어떻게 풀리는가 — flash attention 의 online softmax 가 multi-device 로 확장되는 자리.
  2. 그 분산 구조에서 통신 비용을 어떻게 숨기는가 — N step 의 K/V 회전이 일어나는 동안 매 step 의 attention 계산이 동시에 도는가.
강의의 인지적 frame

Ring Attention 은 새 알고리즘이 아니라 — Flash Attention 의 online softmax 가 device 경계를 넘는 형식이다. § L012 에서 깐 “tile 사이의 누산 보정” 이 “device 사이의 누산 보정” 으로 바뀐다. 같은 수학, 다른 위치.

“ring attention 은 quadratic 을 지우지 않는다 — 그저 quadratic 을 N device 에 잘 나눈다. brute force 를 device 차원으로 확장한 것.”Andreas Köpf · 강의 paraphrase
§ 02시퀀스 차원 분산의 동기· DP/TP/PP 와의 차이

data · tensor · pipeline · sequence — 4 가지 분산축의 자리

분산 학습의 표준 축은 셋이었다 — DP (data), TP (tensor), PP (pipeline). ring attention 이 도입한 것은 네 번째 축, SP (sequence). 각 축이 무엇을 자르는지 명확히 짚어둘 필요가 있다.

DP — data parallel
batch 차원을 device 사이에 분산. 각 device 가 다른 sample 을 다룬다. weight 는 모든 device 에 복제. backward 끝에 gradient all-reduce. 장점: 단순. 한계: weight 가 한 device 에 들어가야 한다.
TP — tensor parallel
한 layer 의 weight 를 column/row 으로 잘라 device 사이에 분산. forward/backward 중간에 collective 가 자주 필요. 장점: weight 가 큰 모델 가능. 한계: NVLink 같은 고대역 interconnect 필수.
PP — pipeline parallel
layer 들을 device 사이에 stage 로 나눔. micro-batch 로 채워야 bubble 이 줄어듦. 장점: 큰 모델 가능, 통신 적음. 한계: pipeline bubble, 복잡한 schedule.
SP — sequence parallel ★
한 sample 의 sequence 차원을 device 사이에 분산. 본 강의의 주제. 한 sample 의 100M token 이 device 들에 나뉘어 들어간다. attention 의 N×N 부분이 device 사이의 통신으로 풀린다.
DP/TP/PP 와 SP 의 직교성

SP 는 다른 3개 축과 동시에 적용 가능하다. 실제 production 에서는 보통 DP × TP × PP × SP 로 4D mesh 를 만든다. 예: H100 64장 = 8(SP) × 4(TP) × 2(PP) × 1(DP). 각 축이 무엇을 자르는지 명확해야 mesh 디자인이 가능.

강의에서 Andreas 가 짚은 자리 — “ring attention 은 inference 가 아니라 training 의 기술이다”. inference 에서는 KV cache 를 사용해서 한 번에 한 token 을 생성하는데 — 그 흐름에서 K/V 가 device 사이를 회전하는 게 큰 의미가 없다. RA 의 표적은 long-context training (Large World Model 같은 사례).

§ 03log-sum-exp 의 결합법칙· 분산 softmax 의 토대

두 부분의 softmax 결과를 어떻게 합치는가 — log-sum-exp 의 associative 성질

강의 repo 의 howto_log_sum_exp.ipynb 가 보여주는 것 — softmax 의 정규화 분모가 log-sum-exp 형태이고, log-sum-exp 는 두 부분의 결과를 합칠 수 있는 결합법칙을 가진다. 이 성질이 ring attention 의 수학적 토대.

numerically stable softmax 는

softmax(x)_i = exp(x_i − m) / Σ_j exp(x_j − m)
                 where m = max(x)

여기서 분모를 log 로 묶으면 log-sum-exp:

lse(x) = log Σ_j exp(x_j) = m + log Σ_j exp(x_j − m)

핵심 성질 — x 를 두 부분 x¹, x² 로 나누면:

lse(x) = lse([lse(x¹), lse(x²)])

즉 부분 lse 두 개로 전체 lse 를 “bottom-up” 으로 만들 수 있다. 각 부분에 다른 device 가 일하고 결과만 모아도 된다는 뜻.

# howto_log_sum_exp.ipynb 의 검증 패턴
import torch

def naive_softmax(x):
    return x.exp() / x.exp().sum()

x  = torch.randn(10)
x1, x2 = torch.chunk(x, 2)

# 부분 softmax 둘
s1 = naive_softmax(x1)
s2 = naive_softmax(x2)

# 합치는 자리 — running max + lse 보정
m1, m2 = x1.max(), x2.max()
m      = torch.max(m1, m2)
l1     = torch.exp(m1 - m) * x1.exp().sum() * 1
l2     = torch.exp(m2 - m) * x2.exp().sum() * 1
l      = l1 + l2

# reweight: s1 와 s2 를 전체 normalizer 로 다시 표현
combined = torch.cat([
    s1 * (l1 / l),    # s1 의 분모를 l 로 교체
    s2 * (l2 / l),
])

# target = naive_softmax(x) 와 정확히 일치
FlashAttention 과의 동치성

§ L012 의 online softmax 패턴 — running max m, running sum l, running output o — 이 정확히 이 결합법칙의 streaming 버전. tile 단위로 보면 FA, device 단위로 보면 RA. 같은 수학이 두 scale 에 적용된다.

“log-sum-exp 가 결합법칙을 가진다는 사실 하나가 — flash attention 과 ring attention 모두를 가능하게 한 수학적 토대.”학습 노트 paraphrase
§ 04ring 토폴로지 위 K/V 회전· N step · N device

4 device · 4 step — K/V 가 한 바퀴 도는 동안 모든 attention 이 끝난다

강의의 본론. 4개 device 위에서 sequence 가 4 chunk 로 나뉘어 있을 때 — Q 는 자기 chunk 가 device 마다 고정되어 있고, K/V 는 매 step 마다 다음 device 로 흘러간다. 4 step 후에 K/V 가 한 바퀴 돌아서 모든 Q × K^T 짝이 다 계산된다.

FIG · ring topology — 4 device 위에서 K/V 가 회전step 0 ~ step 3
Dev 0 Q₀ · K₀V₀ Dev 1 Q₁ · K₁V₁ Dev 2 Q₂ · K₂V₂ Dev 3 Q₃ · K₃V₃ K/V rotation →
step 0local pair 각 device 가 자기 Q 와 자기 K/V 로 attention 의 한 부분을 계산.
device i: S_ii = Q_i × K_i^T → online softmax state (m, l, o) 누산.
통신 0. 모두 SRAM 안.
step 1K/V 한 칸 회전 각 device 가 K/V 를 다음 이웃으로 보낸다 (NCCL send/recv).
device i 는 이제 K/V_(i−1) 를 가진다.
같은 Q_i 와 새 K/V_(i−1) 로 추가 누산: m_new ← max(m, max(s_new)), 이전 누산 보정.
통신 1 step + 계산 1 step (overlap)
step 2 K/V 한 칸 더 회전. device i 는 K/V_(i−2) 를 가진다.
같은 누산 상태에 추가 contribution.
통신 + 계산 overlap
step 3마지막 회전 K/V 가 한 바퀴 다 돌고 device i 는 모든 K/V 를 이미 본 상태.
최종 O_i = o_i / l_i — softmax 의 분모로 정규화.
출력 완성
정확성 — log-sum-exp 의 보장

이 누산은 § 03 의 lse 결합법칙 덕분에 — 분산하지 않은 표준 attention 과 정확히 같은 결과를 낸다 (numerical precision 한도 안에서). approximation 이 아니다.

§ 05통신 vs 계산 오버랩· communication-free

K/V 가 다음 이웃으로 가는 동안 — 현재 K/V 로 attention 이 돈다

RA 의 핵심 성능 메커니즘. 매 step 의 K/V send/recv 는 NCCL 의 ring 통신을 쓰는데, 그 통신이 도는 동안 GPU 는 비어 있지 않고 이전 step 의 K/V 로 attention 계산을 한다. 두 경로가 같은 시간에 진행.

FIG · 통신과 계산의 overlapcompute stream + comm stream
compute
attn(Q_i, KV₀)
attn(Q_i, KV₁)
attn(Q_i, KV₂)
attn(Q_i, KV₃)
finalize
comm
send/recv KV
send/recv KV
send/recv KV
통신과 계산이 같은 시간에 도는 자리. NVLink 의 양방향 대역폭이 충분하면 — 통신 시간이 계산 시간 안에 흡수되어, ring attention 의 wall-clock 이 단일 device 의 attention 시간과 거의 같아진다 (단일 device 에서는 못 도는 큰 N 을 가능하게 하면서).
overlap 가 깨지는 자리

실제로는 — (1) interconnect 가 느리면 (PCIe 만 있고 NVLink 없으면) 통신이 계산보다 길어짐, (2) chunk 가 너무 작으면 계산이 통신보다 짧아짐, (3) NCCL kernel 자체도 SM 을 쓰기 때문에 큰 attention 과 SM 경합. § 08 에서 topology aware 매핑 디테일.

“ring attention 의 의미는 — memory 를 device 사이에 나누면서도 wall-clock 이 늘지 않는다. 통신이 계산 안에 숨는 한.”학습 노트 paraphrase
§ 06메모리 회계· per-device footprint

device 당 메모리는 어떻게 줄어드는가 — N → N/D

분산의 가치를 직접 측정하는 자리. D 개의 device 가 있을 때, 각 device 가 들고 있어야 하는 것:

표준 attention (단일 device)

Q, K, V: 각 N×d → 3 N·d
activation memory: ~Nd × layers
seq_len 100M, d=128, 32 layer 의 fp16: ~800 GB. 한 GPU 에 못 들어감.

ring attention (D 개 device)

device i 가 들고 있는 것: Q_i, K_i, V_i (N/D)×d + 한 개의 K/V tile rotation buffer (N/D)×d.
per-device: 4·N·d/D
D=8 일 때 device 당 ~100 GB. 충분히 들어감.

이 1차 효과 외에도 — activation memory 도 sequence 차원으로 잘려서 학습 중 layer 별 activation 도 device 사이에 분산된다. PyTorch 의 torch.distributed.checkpoint 와 함께 쓰면 forward/backward activation 도 잘게 자를 수 있다.

scaling 의 한계

device D 를 늘리면 메모리는 1/D 로 줄지만 — step 수도 D 로 늘어난다. attention 자체는 N²/D 의 계산을 D 번 → 총 N² 계산. 즉 wall-clock 의 strong scaling 은 “통신이 계산 안에 흡수되는 한” 만 유지된다. 그 한계점이 RA 의 실용적 sweet spot 결정.

§ 07대규모 컨텍스트 학습· Large World Model 사례

Liu, Yan, Abbeel 의 “Large World Model” — 1M token video

Andreas 가 강의에서 자주 언급한 사례 — Hao Liu, Wilson Yan, Pieter Abbeel 의 Large World Model 시리즈. 이 모델이 ring attention 으로 학습 가능한 첫 production-grade 결과.

주요 사실 (논문 참조).

  • video tokenization → 1M token 길이의 sequence.
  • ring attention 위에서 학습 — 32~64 GPU mesh, sequence 차원이 그 mesh 위에 분산.
  • 관련 작업의 흐름: BlockwiseParallelTransformer (Liu, Abbeel 2023) → Ring Attention (2023) → World Model (2024).

이 시리즈가 보여준 것 — “sequence 차원만 늘려서 model 의 long-range 능력을 직접 다룰 수 있다”. 기존의 sliding-window / sparse attention 같은 approximation 과 다른 자리.

cuda-mode/ring-attention

강의 시점에 Andreas 가 운영하던 community implementation — cuda-mode/ring-attention. PyTorch + Triton 기반. 현재 시점 repo 상태와 활성도는 확인 필요.

inference 에서의 한계

Andreas 가 강조 — “RA 는 training 의 기술이지, inference (특히 token-by-token decode) 의 기술이 아니다.” autoregressive decode 에서 KV cache 가 device 사이에 흩어져 있으면 매 token 마다 ring 통신이 필요해서 latency 가 폭발. inference 는 다른 분산 패턴 (TP, sharded KV) 을 쓴다.

§ 08NCCL · NVLink 매핑· topology aware

ring 의 “이웃” 은 누구인가 — physical topology 에 정확히 매핑되어야 한다

RA 의 알고리즘에서 “next neighbor” 는 logical 개념이지만 — 실제 성능은 logical neighbor 가 physical NVLink 로 직결된 device 인지에 달려 있다. ring 이 NVLink 의 ring 과 일치해야 함.

NVIDIA HGX 의 8-GPU 노드 — NVSwitch 가 연결되어 있어 어느 GPU 두 개도 NVLink 로 직접 통신 가능. ring 이 어떤 순서여도 상관없다. 최적 자리.

multi-node 가 끼면 — node 안 (intra-node) 은 NVLink, node 사이 (inter-node) 는 InfiniBand 또는 RoCE. node 간 hop 이 ring 안에 있으면 그 step 만 5-10× 느려진다. 이 때문에 ring 을 “node 안 ring + node 간 ring” 의 hierarchical 형태로 짜야 함.

실용 도구

NCCL 의 NCCL_TOPO_DUMP_FILE 환경변수로 자기 노드의 topology graph 를 dump. 그 위에서 ring 을 명시적으로 정의 (NCCL_RINGS). PyTorch 의 torch.distributed.init_process_group 호출 전에 설정해야 함.

통신 패턴 분리

실제 RA 구현에서는 — K 와 V 를 따로 send/recv 하지 않고 concatenated buffer 한 번으로 보낸다. NCCL 의 latency overhead 가 message 크기에 약간만 비례하므로 두 작은 message 보다 한 큰 message 가 빠르다.

“ring 의 logical 순서가 physical NVLink 의 ring 과 어긋나면 — 통신 latency 가 두세 배. 같은 algorithm 이 mesh 매핑 한 줄로 무너진다.”학습 노트 paraphrase
§ 09한계와 변종· stripe · zigzag

causal mask 가 RA 의 자연스러운 부하 균형을 깬다 — stripe / zigzag 변종

Andreas 가 강의 후반부에 짚은 함정 — RA 가 unmasked attention 에서는 device 마다 동등한 일을 하지만, autoregressive LM 의 causal mask 가 들어가면 부하가 device 마다 크게 달라진다.

원인. causal mask 는 “query token i 가 key token j ≤ i 만 본다” 의 조건 — 즉 attention 행렬이 lower-triangular. RA 의 단순 partition 에서는:

  • device 0 (Q₀, sequence 의 앞쪽 chunk) 은 자기 K₀V₀ 만 보면 됨. 일이 적음.
  • device D-1 (Q_{D-1}, sequence 의 뒤쪽 chunk) 은 모든 K/V 를 봐야 함. 일이 많음.

결과: device 마다 부하가 D 배 차이. wall-clock 은 가장 느린 device 에 묶임 → strong scaling 깨짐.

stripe RA — 순서 재배치

해결책 1. token 들을 순서가 아닌 stripe (interleave) 로 device 사이에 배치. 모든 device 가 sequence 의 앞/뒤 일부를 섞어 가지므로 각 device 의 평균 부하가 비슷해짐. 알고리즘은 그대로, partition 만 바뀜.

zigzag RA

해결책 2. ring 의 회전 방향을 step 마다 바꿔서 부하를 분산. 더 정교한 변형 — Liu et al 의 후속 연구 / 다른 group 들의 implementation. 강의에서는 이름만 언급.

다른 한계들

(1) flash attention 자체의 register spilling 이 분산 안에서도 그대로 — 각 device 의 local attention 이 register-aware 하게 짜여야 한다, (2) backward 의 통신 패턴이 forward 와 다름 — gradient 가 ring 의 반대 방향으로도 흐름, (3) gradient checkpointing 과 결합할 때 추가 sync 필요.

§ 10기억할 메모와 코드· key takeaways · repo

다시 열었을 때 5분 안에 손에 잡혀야 할 것

log-sum-exp 결합법칙
lse(x) = lse([lse(x¹), lse(x²)]). 부분 lse 두 개로 전체 lse 복원. RA 와 FA 의 공통 토대.
SP — 4번째 분산축
DP/TP/PP 와 직교. 한 sample 의 sequence 를 device 사이에 분산. attention 의 N×N 부분이 통신으로 풀린다.
ring rotation
Q 는 device 마다 고정, K/V 가 D step 동안 ring 을 한 바퀴 회전. 매 step online softmax 누산 갱신.
comm-compute overlap
step n 의 K/V 가 다음 device 로 가는 동안 step n−1 의 K/V 로 attention 계산. NVLink 가 충분하면 통신이 흡수됨.
메모리 4·N·d/D
device 당 Q + K + V + rotation buffer. 단일 device 의 1/D. activation 도 sequence 차원으로 sharded.
training only
RA 는 train 시 long context 를 푸는 도구. autoregressive inference 에서는 ring 통신이 매 token 마다 발생 → latency 폭발.
causal mask 부하 불균형
단순 partition 시 device 간 D 배 부하 차이. stripe RA 또는 zigzag 로 재배치 필요.
topology mapping
logical ring 이 physical NVLink ring 과 일치해야 함. NCCL_RINGS 와 topology dump 로 강제.

손에 새기기 — 실습 시퀀스

  1. log-sum-exp 결합법칙 검증howto_log_sum_exp.ipynb 를 그대로 실행. 두 chunk 의 부분 softmax 를 합쳐 전체 softmax 와 일치하는지 확인.
  2. 2-GPU ring attention — torchrun 으로 2 process. seq_len 을 반으로 나누고 K/V 를 한 번 회전시켜 attention 결과가 단일 device 와 일치하는지 verify.
  3. 4-GPU ring + overlap — 같은 코드를 4 device 로. NCCL 통신과 attention 계산을 별도 stream 으로 깐다. nsys 로 실제 overlap 시각화.
  4. causal mask 부하 측정 — masked vs unmasked 에서 device 별 attention 시간을 따로 측정. 부하 불균형 D 배 차이 직접 관찰.
  5. stripe partition 적용 — token 을 stripe 로 재배치한 RA 구현. masked attention 의 device 간 부하가 균등해지는지 확인.
  6. NCCL ring 매핑 확인NCCL_DEBUG=INFONCCL_TOPO_DUMP_FILE 로 자기 노드의 ring 이 NVLink 와 일치하는지 검증.
  7. seq_len scaling — 4K, 16K, 64K 에서 ring attention 의 wall-clock 측정. communication-compute overlap 이 깨지는 지점 찾기.
§ 11다른 강의로 이어지는 길· connections

이 강의의 도구가 다음에 어디에 다시 등장하는지

§ 12열린 질문· open questions

다음에 다시 들었을 때 직접 검증해야 할 것들

검증 메모

이 노트의 step 0~3 시각화는 4-device ring 의 단순 모델. 실제 production 구현은 stripe partition, hierarchical ring, gradient checkpoint 결합 등이 섞여 더 복잡하다. cuda-mode/ring-attention 의 코드를 직접 보는 것이 강의의 마무리 학습.

← Lecture 012 Flash Attention — single device 에서 multi device 로의 자연스러운 확장 Lecture 014 → Practitioners Guide to Triton — Umer Adil 이 깐 Triton 의 실전 패턴