gpumode · 강의 아카이브
《GPU Mode》 L036 2024 · NOV High priority transcript · available

CUTLASS & Flash Attention 3

H100 의 새 자원 — TMAWGMMA — 가 attention 커널의 모양을 통째로 바꾼다. FA2 까지의 multi-stage pipeline 이 H100 에서는 더 이상 효율적이지 않다는 측정에서 출발해, warp specialization (producer-consumer)ping-pong scheduling, FP8 attention 의 in-kernel transpose, 그리고 그 모든 걸 표현하는 CuTe layout 까지 — Tri Dao 의 공저자 Jay Shah 가 FA3 의 설계 결정을 한 시간 안에 깐 강의를 학습 노트로 다시 정리한다.

H100 TMA WGMMA warp specialization CuTe FP8 ping-pong FlashAttention
J
Speaker
Jay Shah
Colfax Research · FA3 공저자
강의 번호
L036
스피커
Jay Shah
학습 우선순위
High · 정독
다시 볼 때
CUTLASS code 같이
§ 01강의가 풀려는 문제· 왜 FA2 로 부족한가

FA2 를 H100 에 그대로 옮겨봤더니 — peak 의 35% 정도밖에 못 받았다

FlashAttention 2 는 A100 에서 매우 좋다. 그런데 같은 코드를 H100 으로 들고 가면 peak FP16 throughput 의 약 35% 만 받는다 (강의의 측정). gem 커널은 80% 위로 가는데 attention 만 떨어진다. 이 격차가 FA3 의 출발점.

이유는 이산적이다 — H100 은 두 개의 새 자원을 도입했는데, FA2 의 multi-stage pipeline 구조로는 그 둘을 동시에 못 끌어 쓴다.

둘 다 비동기다. FA2 의 multi-stage 패턴은 한 thread block 안에서 cp.async 로 K/V tile 을 미리 가져오고 다음 step 의 MMA 와 인터리브하는 구조 — 이게 H100 의 두 자원과 잘 안 맞는다. 자원의 양이 늘어났는데 그걸 동시에 차지하는 단일 thread/warp 의 결정 구조가 부족하다.

강의의 인지적 frame

Jay 의 일관된 입장 — “H100 의 진짜 prog 모델은 ‘하나의 thread block 이 모든 일을 다 한다’ 가 아니라 ‘warp group 들이 producer/consumer 로 나뉘어 비동기로 협력한다’ 다.” FA3 는 이 mental model 위에서 다시 짜진 attention 이고, 이 강의는 그 설계를 단계별로 깐다.

“BF16 에서 FA2 대비 최대 3배 빠르다. peak 의 85% 까지 올라간다 — gem 커널과 사실상 동등.”Jay Shah · 강의 도입부

강의 끝에 손에 잡혀야 하는 자산은 — (1) producer-consumer pipeline 이 attention 커널 안에서 어떻게 구체화되는지, (2) warp specialization 이 register 파티션과 어떻게 묶이는지, (3) ping-pong scheduling 이 softmax 와 GEMM 의 직렬화를 어떻게 푸는지, (4) CuTe layout 이 같은 SMEM 을 두 형태로 보는 트릭을 어떻게 표현하는지.

§ 02H100 의 새 자원· TMA · WGMMA

“하나의 thread 가 큰 일을 시키는” 모델로의 전환

A100 까지의 모델 — 모든 thread 가 자기 자리의 일을 하고, 그 합이 곧 SM 의 일. H100 에서는 그 모델이 깨진다. TMA 는 단일 thread 가 issue 하고, WGMMA 는 warp group 단위로 issue 한다. 이 두 자원은 “더 큰 단위” 로 일을 시키게 만든다.

A100 (cp.async)
thread 별 load
각 thread 가 자기 자리의 한 element 를 cp.async 로 가져옴. issue 비용이 thread 수만큼 누적. SMEM 채우려면 모든 thread 가 협력.
H100 (TMA)
한 thread 가 tile 을 issue
한 elected thread 가 “이만큼 가져와라” 한 줄. 나머지 thread 는 다른 일에 쓸 수 있다. address calculation 도 hardware 가 함.
register 효과
load 코드가 사라진다
FA2 에서 load 에 쓰던 register 들이 풀린다. 그 자리가 softmax stat (m, l) 또는 더 큰 tile 의 accumulator 로 간다.

WGMMA 도 비슷한 구조 변화를 일으킨다. A100 의 mma.sync 는 warp 단위 (32 thread) — 하지만 H100 의 wgmma 는 warp group (128 thread = 4 warp) 단위. 한 issue 가 더 큰 MMA 를 비동기로 시작하고, 결과는 wgmma_wait 로 받는다.

register 회계가 바뀐다

WGMMA 의 accumulator 는 register 위에 있고, 4 warp 가 분할 보유. FA2 에서는 이 자리가 매우 압박이었다. FA3 는 producer warp 가 register 를 거의 안 쓰게 deallocate 하고, 그만큼 consumer warp 에 register 를 더 줘서 spill 을 막는다 — 이게 register reallocation 의 핵심 이득.

§ 03producer-consumer pipeline· warp specialization

한 thread block 안에서 warp group 을 두 역할로 나눈다

FA3 의 가장 큰 구조 변화. 한 CTA(thread block) 안에 3개 warp group 을 띄우고, 그중 1 개를 producer, 2 개를 consumer 로 쓴다. producer 는 K/V tile 을 TMA 로 가져오는 일만 하고, consumer 는 MMA 와 softmax 를 한다.

FIG · FA3 의 한 CTA 내부 — 3 warp groupswarp specialization
WG 0 · 32 threads → 128 으로 보면 1 warp
PRODUCER
TMA 로 다음 K/V tile 을 SMEM 에 채운다. register 거의 안 씀.
WG 1 · 4 warp = 128 threads
CONSUMER A
현재 K tile 로 QK^T MMA + softmax + V MMA. 절반 row.
WG 2 · 4 warp = 128 threads
CONSUMER B
consumer A 의 다른 절반 row 를 동시에 처리. ping-pong (§04).
SMEM
circular buffer
N stage 로 K/V tile 을 미리 채워둔다. mbarrier 로 producer↔consumer 동기.
producer 와 consumer 사이 동기는 mbarrier (memory barrier object) 로 한다. 각 stage 마다 “채워졌다” / “비어졌다” 두 상태를 trace — 이게 H100 이 hardware 로 지원하는 mechanism.

이 분리가 만들어내는 이득은 register 차원에서 가장 크다. producer 는 작은 일만 하므로 register 를 deallocate (setmaxnreg.dec) 해서 consumer 에게 더 준다. 그 결과 consumer 가 attention 의 큰 accumulator 와 softmax stat (m, l) 을 register 에 모두 보유하면서도 spill 을 안 일으킨다.

“producer 가 register 를 양보하지 않으면 consumer 가 spill 한다 — 그러면 attention 커널의 register reallocation 이 attention 의 정체성이다.”Jay Shah
§ 04ping-pong scheduling· 두 consumer 의 교차

softmax 와 GEMM 이 서로의 그림자에 들어간다 — inter-warp-group 병렬

attention 의 inner loop 는 — QK^T → softmax → P·V → softmax stat 갱신. 이 시퀀스에서 softmax 가 GEMM 과 다른 unit (MUFU.EX2 / 일반 ALU) 위에서 도니까 둘이 동시에 굴러가야 throughput 이 산다. ping-pong 은 그 동시 진행을 강제하는 schedule.

FIG · ping-pong schedule (idealized)두 consumer 의 시간선
Consumer A
QK^T
softmax
PV
QK^T
softmax
PV
Consumer B
softmax
PV
QK^T
softmax
PV
QK^T
tensor core util
거의 항상 점유
A 가 softmax 를 도는 동안 B 는 GEMM. B 가 softmax 를 도는 동안 A 는 GEMM. tensor core 점유율이 떨어지지 않는다 — 같은 SM 위 두 consumer 가 서로의 그림자.

ping-pong 은 idealized 하다. 실제로 두 consumer 사이의 정확한 위상 정렬은 어렵고, 강의에서 Jay 가 명시한 한 줄 — “barrier 로 어느 정도 근사할 수 있지만 hardware 가 직접 보장해주지는 않는다.” 그래도 측정상 평균 점유율이 크게 올라간다.

intra-warp-group overlap 도 같이 깐다

위는 inter-warp-group (A vs B) 의 ping-pong. 거기에 더해 한 consumer 안에서도 softmax 의 일부 (rescale, stat update) 가 다음 GEMM 의 issue 와 겹치게 코드를 짠다. Jay 가 이걸 “warp specialization 의 fine-grained 버전” 으로 부른다.

§ 05FP8 attention 의 in-kernel transpose· layout 강제 회피

FP8 의 layout 제약을 register 위에서 트릭으로 푼다

H100 의 FP8 MMA 는 layout 제약이 까다롭다. 두 피연산자 중 한쪽이 row-major, 다른 쪽이 col-major 여야 하는 등. attention 안의 V 텐서가 자연스럽게 들어오는 layout 과 안 맞을 수 있다 — 그러면 transpose 가 필요한데, 그게 SMEM 트래픽으로 가면 비용이 크다.

FA3 의 답 — SMEM 위에서 layout 을 두 번 본다. 같은 데이터를 SMEM 에 한 번만 적고, 두 read path 가 각자 필요한 layout 으로 본다. CuTe 의 layout algebra 가 이걸 표현한다.

in-kernel transpose 의 위치

강의에서 Jay 가 짚은 한 줄 — “producer 가 SMEM 에 적은 후 consumer 가 register 로 옮길 때, layout swap 이 일어난다.” 추가 SMEM 트래픽이 없는 transpose. 비용은 register 단계의 인덱스 재계산뿐. Hopper 의 ldmatrix 변종이 이 swap 을 hardware path 로 지원.

이 트릭이 없으면 FP8 attention 의 throughput 이 큰 폭으로 떨어진다 — 강의의 측정으로 약 30% 까지. 같은 hardware 자원에서 단순 layout 정렬 한 줄이 그만큼의 차이를 만든다는 게 흥미로운 지점.

FP8 의 quality (정확도) 문제 — incoherent processing, smoothing, hadamard 변환 같은 PTQ 트릭 — 는 강의에서 이번 talk 의 범위가 아님 을 명시. attention 커널의 hardware 효율 차원만 다룸. quality 는 별도 paper.

§ 06backward 의 메모리 회계· recompute · stats

forward 가 풀린 자리에서 backward 가 다시 까다로워진다

FA forward 의 핵심 — attention matrix 를 HBM 에 안 적는다. backward 는 그 결정의 비용을 받는다 — gradient 를 계산하려면 P (attention prob) 가 필요한데, 다시 만들거나 stat 만 저장해서 partial 하게 복원해야 한다.

저장할 것크기backward 비용FA3 결정
전체 P (attention matrix)N² · b · hfree너무 큼
아무것도 저장 안 함0완전 recompute너무 느림
softmax stat (m, ℓ) per row2N · b · hQK^T recompute + 정확표준
stat + dropout mask+ N²/8dropout 정확option

FA 의 표준 결정은 softmax stat 만 저장. backward 는 그 stat 으로 P 를 정확히 복원할 수 있다 (online softmax 의 mathematical 성질). 강의에서 Jay 는 backward 의 H100 변형이 아직 진행형이라고 짚었다 — forward 만큼 깔끔한 producer-consumer schedule 을 짜기가 더 어렵다.

이유는 backward 의 data dependency 가 더 복잡하기 때문. forward 는 K/V 가 producer 의 입력, P/O 가 consumer 의 출력으로 단방향. backward 는 dQ/dK/dV 가 모두 같은 inner loop 에서 갱신되며 cross-row reduction 이 들어간다.

§ 07FA2 와의 차이를 한 줄로· multi-stage → warp-spec

같은 알고리즘 위에서, 한 thread block 의 내부 분업이 통째로 바뀌었다

알고리즘 차원 — online softmax + tile-based — 은 FA1 부터 같다. 바뀐 건 한 thread block 안에서 누가 뭘 하느냐. 그 표를 정리한다.

FA2 (Ampere)FA3 (Hopper)
load 단위cp.async per threadTMA per CTA
MMA 단위mma.sync per warpwgmma per warp group
overlap 패턴multi-stage (intra-warp)warp-specialized + ping-pong
register 분배warp 별 평등producer 가 양보 (setmaxnreg)
precisionFP16/BF16FP16/BF16/FP8
peak 의 비율A100 ~75% / H100 ~35%H100 ~85%
“같은 알고리즘인데 코드는 거의 다 다시 짜야 했다 — Hopper 의 prog 모델 자체가 다르니까.”Jay Shah · Q&A
§ 08CuTe layouts· tile · partition · copy

같은 SMEM 을 두 형태로 보는 트릭이 코드로 표현되는 자리

FA3 의 코드는 CUTLASS 위에 있고, 그 안의 layout 추상은 CuTe. shape 와 stride 를 한 객체로 묶어서, 같은 메모리에 대한 다른 view 를 정수 인덱스가 아니라 algebra 로 표현한다.

L0 · Layout (shape, stride)(M, N) : (1, M) — col-major. (N, M) : (M, 1) — row-major. 같은 메모리, 다른 view. 기본 type
L1 · Tensor data + Layout실제 메모리 포인터에 layout 을 묶어 1급 객체로 make_tensor
L2 · Tile / Partition 큰 tensor 를 작은 tile 로 자르고 thread 별 자리 결정local_tile, local_partition work decomposition
L3 · Copy copy(src, dst) — 가장 추상적인 형태로 모든 transfer 표현HBM→SMEM, SMEM→reg, reg→reg 모두 같은 함수 통일된 transfer
L4 · Atom / TiledCopy hardware instruction 까지 매핑SM90_TMA_LOAD, ldmatrix, cp.async — 같은 copy() 가 다른 atom 으로 specialize PTX 까지 내려간다

강의에서 Jay 가 강조한 한 줄 — “CuTe 의 layout algebra 를 손에 잡으면, FP8 의 in-kernel transpose 같은 트릭이 코드 두 줄로 표현된다. 그게 익숙하지 않으면 같은 트릭이 100 줄 인덱스 계산이 된다.” CuTe 자체가 한 강의 분량이고, 별도 자료로 깊게 봐야 한다 (cute docs).

§ 09hardware 가 어디까지 따라잡았나· 85% utilization

attention 이 GEMM 과 사실상 같은 수준의 효율로 도는 시점

강의의 마지막 측정 — FA3 가 BF16 에서 H100 peak 의 85% 까지 받는다. 같은 hardware 의 best gem 커널과 사실상 동등. 이게 “attention 도 이제 GEMM 만큼 효율적” 의 의미.

FA2 on H100 (BF16)~35% peak
FA3 on H100 (BF16)~85% peak
cuBLAS GEMM (BF16)~88% peak
FA3 on H100 (FP8)~75% peak
남는 15%

backward 가 forward 만큼 깔끔하게 안 되어 있다, FP8 의 in-kernel transpose 가 SMEM 트래픽을 약간 더 만든다, ping-pong 이 idealized 가 아니라 근사다 — 이런 자잘한 제약이 합쳐서 15% 가 남는다. backward 가 다음 큰 라운드의 작업.

“FA3 가 GEMM 과 동등 수준에 왔다는 건, attention 의 효율 게임이 사실상 끝났다는 뜻은 아니지만 — 한 시대의 종결.”학습 노트
§ 10기억할 메모와 코드· repo · paper

다시 열었을 때 5분 안에 손에 잡혀야 할 것

FA3 를 6개월 뒤 다시 마주했을 때 가장 빨리 복원해야 할 사실들과 — 직접 읽어볼 코드 위치들.

TMA · WGMMA
H100 의 두 비동기 자원. 단일 thread 또는 warp group 단위로 issue. 둘 다 동기 instruction 이 아닌 mbarrier 기반.
warp specialization
한 CTA 안에서 producer 1 + consumer 2 의 분업. mbarrier 로 stage 별 동기.
setmaxnreg.dec/inc
producer 가 register 를 deallocate, consumer 가 더 받음. spill 회피의 결정적 자리.
ping-pong scheduling
두 consumer 가 GEMM 과 softmax 를 교차. tensor core util 거의 100%.
in-kernel transpose
FP8 layout 제약을 SMEM 위에서 두 view 로 푼다. CuTe 가 표현.
backward 의 stat
softmax stat (m, ℓ) 만 저장. P 를 recompute. backward 의 H100 변형은 진행형.
CuTe layout algebra
(shape, stride) 의 한 줄 표현이 같은 SMEM 의 다른 view 들을 통일. 별도 학습 필요.
peak utilization
BF16 ~85%, FP8 ~75%. cuBLAS GEMM 과 동등 수준.

손에 새기기 — 실습 시퀀스

  1. FA2 → FA3 비교 측정 — 같은 H100 위에서 같은 (Q, K, V) 에 대해 FA2 와 FA3 forward 시간 비교. 차이가 위 표 (35% → 85%) 와 가까워야 정상.
  2. CUTLASS hello world — CUTLASS 의 example_xx 시리즈에서 H100 GEMM 을 빌드. WGMMA 가 코드 안에서 어떻게 issue 되는지 직접 본다.
  3. CuTe layout 연습 — make_layout, local_tile, local_partition 만 가지고 row-major / col-major 변환을 코드 두 줄로 표현. CuTe 의 정신이 손에 잡힐 때까지.
  4. FA3 의 forward 코드 읽기 — flash-attention repo 의 hopper/ 디렉토리. mainloop 함수를 따라가며 producer/consumer 분기, mbarrier 사용을 그림으로 옮긴다.
  5. NCU 로 떠보기 — ping-pong 이 실제로 일어나는지 metric 으로 확인 — tensor core util, sm__inst_executed_pipe_alu (softmax pipeline), warp scheduler stall reasons. 둘이 교차해야 한다.
  6. FP8 attention 의 quality — 강의 범위 밖이지만 짝지어 학습. incoherent processing, smoothing 의 PTQ 트릭들이 어떻게 perplexity 를 살리는지.
§ 11다른 강의로 이어지는 길· connections

FA3 의 도구가 시리즈 안에서 다시 등장하는 자리

FA3 의 producer-consumer 패턴이 다른 강의들의 어느 자리로 이어지는지.

§ 12열린 질문· open questions

다음에 다시 들었을 때 직접 검증해야 할 것들

강의에서 부분적으로만 등장한 주제, 후속 작업으로 비워둔 자리들.

검증 메모

이 노트의 percentage 수치(35%, 85%)는 강의 슬라이드를 재구성한 예시. 자기 H100 에서 직접 측정해야 baseline 이 의미를 가진다. 또 강의 시점 이후 FA 의 release 가 빠르게 갱신되고 있어 — version 별로 backward, FP8 path 의 상태를 매번 확인하는 게 현실적.

← Lecture 035 SGLang Lecture 037 → SASS & GPU Microarchitecture — Arun Demeure 가 깐 NVIDIA 머신코드의 자리