한 GPU 의 HBM 만으로는 100M token context 를 한 모델에 못 넣는다 — 1000 GB 를 요구하니까. Ring Attention 은 이 압박을 풀려고 sequence 차원을 device 들 사이에 분산하고, K/V 를 ring 토폴로지 위에서 회전시키면서 — online softmax 를 그대로 multi-device 로 확장한다. Andreas Köpf 가 cuda-mode/ring-attention 의 구현을 깐다.
Ring Attention 논문 (Liu, Yan, Abbeel 2023) 이 직접 인용한 수치 — batch 1, seq_len 100M token, modest model 한 개에 1000 GB 이상의 메모리가 필요하다. 단일 H100 (80 GB) 한 장으로 안 되고 — 8 장이라도 부족하다. 이 압박이 ring attention 이 등장한 이유.
강의가 풀려는 질문 두 가지.
Ring Attention 은 새 알고리즘이 아니라 — Flash Attention 의 online softmax 가 device 경계를 넘는 형식이다. § L012 에서 깐 “tile 사이의 누산 보정” 이 “device 사이의 누산 보정” 으로 바뀐다. 같은 수학, 다른 위치.
분산 학습의 표준 축은 셋이었다 — DP (data), TP (tensor), PP (pipeline). ring attention 이 도입한 것은 네 번째 축, SP (sequence). 각 축이 무엇을 자르는지 명확히 짚어둘 필요가 있다.
SP 는 다른 3개 축과 동시에 적용 가능하다. 실제 production 에서는 보통 DP × TP × PP × SP 로 4D mesh 를 만든다. 예: H100 64장 = 8(SP) × 4(TP) × 2(PP) × 1(DP). 각 축이 무엇을 자르는지 명확해야 mesh 디자인이 가능.
강의에서 Andreas 가 짚은 자리 — “ring attention 은 inference 가 아니라 training 의 기술이다”. inference 에서는 KV cache 를 사용해서 한 번에 한 token 을 생성하는데 — 그 흐름에서 K/V 가 device 사이를 회전하는 게 큰 의미가 없다. RA 의 표적은 long-context training (Large World Model 같은 사례).
강의 repo 의 howto_log_sum_exp.ipynb 가 보여주는 것 — softmax 의 정규화 분모가 log-sum-exp 형태이고, log-sum-exp 는 두 부분의 결과를 합칠 수 있는 결합법칙을 가진다. 이 성질이 ring attention 의 수학적 토대.
numerically stable softmax 는
softmax(x)_i = exp(x_i − m) / Σ_j exp(x_j − m)
where m = max(x)
여기서 분모를 log 로 묶으면 log-sum-exp:
lse(x) = log Σ_j exp(x_j) = m + log Σ_j exp(x_j − m)
핵심 성질 — x 를 두 부분 x¹, x² 로 나누면:
lse(x) = lse([lse(x¹), lse(x²)])
즉 부분 lse 두 개로 전체 lse 를 “bottom-up” 으로 만들 수 있다. 각 부분에 다른 device 가 일하고 결과만 모아도 된다는 뜻.
# howto_log_sum_exp.ipynb 의 검증 패턴
import torch
def naive_softmax(x):
return x.exp() / x.exp().sum()
x = torch.randn(10)
x1, x2 = torch.chunk(x, 2)
# 부분 softmax 둘
s1 = naive_softmax(x1)
s2 = naive_softmax(x2)
# 합치는 자리 — running max + lse 보정
m1, m2 = x1.max(), x2.max()
m = torch.max(m1, m2)
l1 = torch.exp(m1 - m) * x1.exp().sum() * 1
l2 = torch.exp(m2 - m) * x2.exp().sum() * 1
l = l1 + l2
# reweight: s1 와 s2 를 전체 normalizer 로 다시 표현
combined = torch.cat([
s1 * (l1 / l), # s1 의 분모를 l 로 교체
s2 * (l2 / l),
])
# target = naive_softmax(x) 와 정확히 일치
§ L012 의 online softmax 패턴 — running max m, running sum l, running output o — 이 정확히 이 결합법칙의 streaming 버전. tile 단위로 보면 FA, device 단위로 보면 RA. 같은 수학이 두 scale 에 적용된다.
강의의 본론. 4개 device 위에서 sequence 가 4 chunk 로 나뉘어 있을 때 — Q 는 자기 chunk 가 device 마다 고정되어 있고, K/V 는 매 step 마다 다음 device 로 흘러간다. 4 step 후에 K/V 가 한 바퀴 돌아서 모든 Q × K^T 짝이 다 계산된다.
S_ii = Q_i × K_i^T → online softmax state (m, l, o) 누산.
통신 0. 모두 SRAM 안.
이 누산은 § 03 의 lse 결합법칙 덕분에 — 분산하지 않은 표준 attention 과 정확히 같은 결과를 낸다 (numerical precision 한도 안에서). approximation 이 아니다.
RA 의 핵심 성능 메커니즘. 매 step 의 K/V send/recv 는 NCCL 의 ring 통신을 쓰는데, 그 통신이 도는 동안 GPU 는 비어 있지 않고 이전 step 의 K/V 로 attention 계산을 한다. 두 경로가 같은 시간에 진행.
실제로는 — (1) interconnect 가 느리면 (PCIe 만 있고 NVLink 없으면) 통신이 계산보다 길어짐, (2) chunk 가 너무 작으면 계산이 통신보다 짧아짐, (3) NCCL kernel 자체도 SM 을 쓰기 때문에 큰 attention 과 SM 경합. § 08 에서 topology aware 매핑 디테일.
분산의 가치를 직접 측정하는 자리. D 개의 device 가 있을 때, 각 device 가 들고 있어야 하는 것:
Q, K, V: 각 N×d → 3 N·d
activation memory: ~Nd × layers
seq_len 100M, d=128, 32 layer 의 fp16: ~800 GB. 한 GPU 에 못 들어감.
device i 가 들고 있는 것: Q_i, K_i, V_i (N/D)×d + 한 개의 K/V tile rotation buffer (N/D)×d.
per-device: 4·N·d/D
D=8 일 때 device 당 ~100 GB. 충분히 들어감.
이 1차 효과 외에도 — activation memory 도 sequence 차원으로 잘려서 학습 중 layer 별 activation 도 device 사이에 분산된다. PyTorch 의 torch.distributed.checkpoint 와 함께 쓰면 forward/backward activation 도 잘게 자를 수 있다.
device D 를 늘리면 메모리는 1/D 로 줄지만 — step 수도 D 로 늘어난다. attention 자체는 N²/D 의 계산을 D 번 → 총 N² 계산. 즉 wall-clock 의 strong scaling 은 “통신이 계산 안에 흡수되는 한” 만 유지된다. 그 한계점이 RA 의 실용적 sweet spot 결정.
Andreas 가 강의에서 자주 언급한 사례 — Hao Liu, Wilson Yan, Pieter Abbeel 의 Large World Model 시리즈. 이 모델이 ring attention 으로 학습 가능한 첫 production-grade 결과.
주요 사실 (논문 참조).
이 시리즈가 보여준 것 — “sequence 차원만 늘려서 model 의 long-range 능력을 직접 다룰 수 있다”. 기존의 sliding-window / sparse attention 같은 approximation 과 다른 자리.
강의 시점에 Andreas 가 운영하던 community implementation — cuda-mode/ring-attention. PyTorch + Triton 기반. 현재 시점 repo 상태와 활성도는 확인 필요.
Andreas 가 강조 — “RA 는 training 의 기술이지, inference (특히 token-by-token decode) 의 기술이 아니다.” autoregressive decode 에서 KV cache 가 device 사이에 흩어져 있으면 매 token 마다 ring 통신이 필요해서 latency 가 폭발. inference 는 다른 분산 패턴 (TP, sharded KV) 을 쓴다.
RA 의 알고리즘에서 “next neighbor” 는 logical 개념이지만 — 실제 성능은 logical neighbor 가 physical NVLink 로 직결된 device 인지에 달려 있다. ring 이 NVLink 의 ring 과 일치해야 함.
NVIDIA HGX 의 8-GPU 노드 — NVSwitch 가 연결되어 있어 어느 GPU 두 개도 NVLink 로 직접 통신 가능. ring 이 어떤 순서여도 상관없다. 최적 자리.
multi-node 가 끼면 — node 안 (intra-node) 은 NVLink, node 사이 (inter-node) 는 InfiniBand 또는 RoCE. node 간 hop 이 ring 안에 있으면 그 step 만 5-10× 느려진다. 이 때문에 ring 을 “node 안 ring + node 간 ring” 의 hierarchical 형태로 짜야 함.
NCCL 의 NCCL_TOPO_DUMP_FILE 환경변수로 자기 노드의 topology graph 를 dump. 그 위에서 ring 을 명시적으로 정의 (NCCL_RINGS). PyTorch 의 torch.distributed.init_process_group 호출 전에 설정해야 함.
실제 RA 구현에서는 — K 와 V 를 따로 send/recv 하지 않고 concatenated buffer 한 번으로 보낸다. NCCL 의 latency overhead 가 message 크기에 약간만 비례하므로 두 작은 message 보다 한 큰 message 가 빠르다.
Andreas 가 강의 후반부에 짚은 함정 — RA 가 unmasked attention 에서는 device 마다 동등한 일을 하지만, autoregressive LM 의 causal mask 가 들어가면 부하가 device 마다 크게 달라진다.
원인. causal mask 는 “query token i 가 key token j ≤ i 만 본다” 의 조건 — 즉 attention 행렬이 lower-triangular. RA 의 단순 partition 에서는:
결과: device 마다 부하가 D 배 차이. wall-clock 은 가장 느린 device 에 묶임 → strong scaling 깨짐.
해결책 1. token 들을 순서가 아닌 stripe (interleave) 로 device 사이에 배치. 모든 device 가 sequence 의 앞/뒤 일부를 섞어 가지므로 각 device 의 평균 부하가 비슷해짐. 알고리즘은 그대로, partition 만 바뀜.
해결책 2. ring 의 회전 방향을 step 마다 바꿔서 부하를 분산. 더 정교한 변형 — Liu et al 의 후속 연구 / 다른 group 들의 implementation. 강의에서는 이름만 언급.
(1) flash attention 자체의 register spilling 이 분산 안에서도 그대로 — 각 device 의 local attention 이 register-aware 하게 짜여야 한다, (2) backward 의 통신 패턴이 forward 와 다름 — gradient 가 ring 의 반대 방향으로도 흐름, (3) gradient checkpointing 과 결합할 때 추가 sync 필요.
NCCL_RINGS 와 topology dump 로 강제.howto_log_sum_exp.ipynb 를 그대로 실행. 두 chunk 의 부분 softmax 를 합쳐 전체 softmax 와 일치하는지 확인.NCCL_DEBUG=INFO 와 NCCL_TOPO_DUMP_FILE 로 자기 노드의 ring 이 NVLink 와 일치하는지 검증.이 노트의 step 0~3 시각화는 4-device ring 의 단순 모델. 실제 production 구현은 stripe partition, hierarchical ring, gradient checkpoint 결합 등이 섞여 더 복잡하다. cuda-mode/ring-attention 의 코드를 직접 보는 것이 강의의 마무리 학습.