gpumode · 강의 아카이브
《GPU Mode》 L012 2024 · MAR · 23 High priority transcript · available

Flash Attention

2022 년 Tri Dao 의 FlashAttention 논문 이후, Transformer 의 attention block 은 같은 수학을 GPU 위에서 다르게 구현한다. “N×N score 행렬을 절대 HBM 에 통째로 만들지 않는다” 는 한 줄짜리 원칙 뒤에 — online softmax, tile streaming, register accounting 의 3개 메커니즘이 깔려 있다. Thomas Viehmann 이 직접 짠 CUDA 구현을 반대로 해부하면서 그 원칙이 코드에 어떻게 떨어지는지를 깐다.

FlashAttention online softmax tile streaming HBM ↔ SRAM register spilling Tensor Core MMA CUDA C++ cuda-python FA v1 · v2
T
Speaker
Thomas Viehmann
PyTorch core · MathInf · GPU Mode mod
강의 번호
L012
스피커
Thomas Viehmann
학습 우선순위
High · 정독
코드
repo · flash_attention.cu
§ 01강의가 풀려는 문제· why this lecture exists

표준 attention 은 GPU 의 가장 약한 부분 (HBM bandwidth) 을 정확히 친다

표준 attention 의 수식 — O = softmax(QKT/√d) · V — 은 수학적으로 깔끔하다. 하지만 그 수식을 그대로 구현하면 중간 산출물 S = QKT 가 N×N 크기의 행렬이 되어 HBM 을 통째로 들었다 놨다 한다. seq_len 4K 에서 16M 원소, 32K 에서는 1G 원소.

강의가 풀려는 질문은 두 개다.

  1. 같은 수학적 결과를 어떻게 N×N 행렬을 만들지 않고 계산하는가 — online softmax 와 tile streaming 의 조합.
  2. 그 알고리즘을 CUDA 로 짜면 register / shared memory / Tensor Core 가 어떻게 분배되는가 — Thomas 가 직접 짠 flash_attention.cu (d=128, B_r=8, B_c=32) 의 해부.
강의의 인지적 frame

대부분의 다른 GPU Mode 강의가 “이미 있는 도구를 어떻게 부르는가” 라면 — 이 강의는 “알고리즘과 hardware 의 만남이 어떻게 새 implementation 을 강제하는가” 다. attention 의 수학은 안 바뀌었지만, GPU 의 모양 (HBM 너무 느림, SRAM 충분히 큼) 이 알고리즘을 다시 쓰게 만들었다. “hardware aware algorithms” 의 대표 사례.

“표준 attention 은 N² 의 메모리 트래픽으로, FlashAttention 은 N²/M 의 트래픽으로 같은 결과를 낸다 (M 은 SRAM 크기). 알고리즘이 같은 수학을 다른 비용으로 푼다.”학습 노트 paraphrase
§ 02큰 점수행렬을 만들지 마라· memory-bound attention

HBM 과 SRAM 의 한 자릿수 차이가 attention 을 뒤집는다

표준 attention 의 비용을 메모리 시각으로 분해 — 같은 dimension d=128, seq_len N=4K 의 한 head 에 대해.

FIG · A100 의 메모리 hierarchy대역폭 · 크기
HBMoff-chip DRAM
2.0 TB/s · 80 GB
slowest
L2 cacheon-chip
~5 TB/s · 40 MB
SRAMshared memory · L1
~19 TB/s · 192 KB/SM
~10× HBM
registerper-thread
>20 TB/s · 256 KB/SM
fastest
SRAM 이 HBM 보다 약 한 자릿수 빠르다. attention 이 memory-bound 일 때 — N×N score 행렬을 HBM 에 만들었다가 다시 읽는 비용이 전체 시간의 거의 전부. 만약 그 행렬을 만들지 않고, 입력 tile 을 SRAM 으로 한 번만 읽고 그 위에서 score + softmax + output 까지 끝낼 수 있다면 — HBM 트래픽이 N²/√M 단위로 떨어진다.

표준 attention 의 메모리 트래픽 분해 (Q, K, V, S, P, O 가 N×d, N×N, N×d 크기일 때, 한 번씩 R/W 한다고 가정).

  • Q, K, V read: 3 N·d
  • S = QKᵀ write:
  • P = softmax(S) read+write: 2 N²
  • O = PV write: N·d
  • 합계: ~3 N² + 4 N·d — N 이 커지면 N² 항이 dominate.

FlashAttention 의 메모리 트래픽 — Q 는 한 번씩, K/V 는 outer loop 만큼 반복.

  • Q read: N·d (한 번만 SRAM 으로)
  • K, V read (outer loop): 2 · ⌈N/B_r⌉ · N·d ≈ N²·d / B_r
  • O write: N·d
  • 합계: ~N²·d / B_r + 2 N·d — N² 의 계수가 d/B_r 로 깎인다.
§ 03online softmax 의 누산 패턴· running max + sum

전체 score 를 안 보고 softmax 를 어떻게 계산하는가 — running normalizer

FlashAttention 의 가장 영리한 부분 — softmax 가 “전체 row 를 봐야 normalize 가능” 한 op 인데, tile 단위로 보면서 동시에 정확한 normalizer 를 누산해 나간다. 핵심 도구는 “running max” 와 그에 따른 보정 항.

Numerical stability 를 위해 softmax 는 softmax(x_i) = exp(x_i − max(x)) / Σ exp(x_j − max(x)) 로 계산한다. 문제는 — tile 별로 보면 그 tile 안의 max 만 알고, 다음 tile 에서 더 큰 값이 나오면 이미 누산해 둔 값이 틀린 normalizer 로 계산된 것이 된다. 그래서 max 가 갱신될 때마다 이전 누산을 “보정”한다.

init m ← −∞ (running max)
l ← 0 (running sum of exp)
o ← 0 (running output)
m=−∞, l=0, o=0
tile 1 새 score 부분 s¹ 받음.
m_new ← max(m, max(s¹))
α ← exp(m − m_new) (이전 누산 보정 계수)
l ← α·l + Σ exp(s¹ − m_new)
o ← α·o + exp(s¹ − m_new) · v¹
m ← m_new
m=m¹, l=l¹, o=o¹
tile 2 새 score 부분 s² 받음.
m_new ← max(m, max(s²))
α ← exp(m − m_new)
l ← α·l + Σ exp(s² − m_new)
o ← α·o + exp(s² − m_new) · v²
m ← m_new
m=m², l=l², o=o²
모든 tile 을 처리하고 나면 m 은 row 전체의 max, l 은 row 전체의 정규화 분모와 같아진다. m=max, l=Σexp(s−max)
finalize O ← o / l
logsumexp ← m + log(l) (backward 에서 필요)
O 완성
왜 보정 계수가 정확한가

α = exp(m − m_new) 는 — 이전 단계에서 계산해둔 exp(s − m) 들이 새 max m_new 기준으로는 exp(s − m_new) = exp(s − m) · exp(m − m_new) = exp(s − m) · α 가 되어야 하기 때문. 전체 누산이 알고리즘적으로 lossless — softmax 의 평행이동 불변성이 정확히 들어맞는 자리.

“online softmax 는 사실 FlashAttention 보다 먼저 (Milakov & Gimelshein 2018) 등장했다. FA 의 contribution 은 그 온라인성을 attention 의 tile 구조와 합치는 자리다.”학습 노트 paraphrase
§ 04타일 스트리밍 재구성· B_r · B_c outer/inner loop

Q tile 을 outer 로 — 한 번 SRAM 에 올린 Q 위에서 K/V 를 흘린다

Thomas 의 코드 (flash_attention.cu) 의 핵심 구조 — d=128, B_r=8 (Q tile rows), B_c=32 (K/V tile cols). 이 숫자들이 어떻게 정해지고 어떤 의미인지.

FIG · Q outer × K/V inner — 한 block 이 한 Q tile 을 담당B_r=8, B_c=32, d=128
Q₀
K₀
K₁
K₂
K₃
Q₁
Q₂
Q₃
한 thread block 이 한 행 (한 Q tile 8×128) 을 담당. block 안에서 K₀, K₁, K₂, K₃, … 을 차례로 SRAM 으로 가져와 score · softmax 누산 · O 갱신을 반복. Q 는 처음 한 번만 SRAM 에 올라온다.

코드의 메모리 레이아웃 (Thomas 의 flash_attention.cu).

// shared memory: SRAM 에 올라가는 것들
__shared__ float Q_i[B_r][d];      // 8×128 = 4 KB
__shared__ float K_j[B_c][d];      // 32×128 = 16 KB
__shared__ float V_j[B_c][d];      // 32×128 = 16 KB
__shared__ float S[B_r][B_c];      // 8×32   = 1 KB

// register: thread 별 누산 상태
float l_i[B_r_over_bdy];        // running sum
float m_i[B_r_over_bdy];        // running max
float O_i[B_r_over_bdy][d_over_bdx];  // running output
// outer loop: Q tile 한 번 (per block)
for (int i = bix; i < bix + 1; i++) {
    // Q_i 를 한 번 shmem 으로
    load_Q(Q_i, Q + i * B_r);

    // inner loop: K/V tiles
    for (int j = 0; j < T_c; j++) {
        load_K_V(K_j, V_j, j);
        __syncthreads();

        // 1) S = scaling * Q @ K^T
        compute_S(S, Q_i, K_j, scaling);
        __syncthreads();

        // 2) online softmax 갱신 + O 누산
        update_running_state(m_i, l_i, O_i, S, V_j);
    }

    // 3) 최종 O = O_i / l_i, store l_i 도 (backward 위해)
    write_output(out, out_l, O_i, l_i, m_i, i);
}
B_r 과 B_c 가 결정되는 자리

B_c = 32 는 SRAM 용량과 K/V tile 사이즈의 trade-off. K/V 두 tile 이 16 KB + 16 KB = 32 KB 를 잡으니까 192 KB 의 SRAM 에 충분. B_r = 8 은 register 용량으로 결정 — 한 thread 가 누산할 O_i 의 행 수가 너무 크면 register spill 이 일어난다 (§ 05).

§ 05register accounting· spilling 진단

register 가 모자라면 “spill” — 그 순간 SRAM 이득이 다 사라진다

Thomas 가 강의에서 가장 중요하게 보여준 진단 자리 — “register spilling from registers”. running 누산기 (m_i, l_i, O_i) 가 register 에 들어가는 게 FA 의 메모리 모델의 전제다. 만약 thread 당 register 가 부족하면 그 누산기들이 local memory (실제로는 HBM-backed) 로 spill 되고 — 이때 모든 IO advantage 가 사라진다.

Thomas 가 본 spill 의 신호.

  • 동일 코드인데 특정 d 또는 B_r 에서만 두세 배 느려짐 — register limit 을 넘는 임계점 직후.
  • NCU 의 “Local Memory Overhead” 가 0 이 아닌 값으로 나옴.
  • PTX 안에 ld.local, st.local 명령이 등장.

강의 repo 에 flash_attention_spilling_from_registers.cu 라는 파일이 따로 있는데 — 같은 알고리즘을 register 가 모자라게 짠 “일부러 안 좋은” 버전. profile 비교용.

register 사용 줄이는 도구들

(1) __launch_bounds__(...) 로 nvcc 에 thread 당 register 한계 hint, (2) loop 을 unroll 안 함, (3) 누산 차원 줄이기 (B_r 줄임), (4) 일부를 shared memory 에 두기 (대신 SRAM 압박). 이 trade-off 가 FA 구현의 가장 까다로운 자리.

진단 명령

nvcc -Xptxas -v --resource-usage — 컴파일 시 thread 당 register 사용량과 spill 바이트 수를 직접 출력. spill bytes > 0 이면 거의 항상 문제. PTX 도 같이 dump 해서 어디서 spill 되는지 추적 가능.

“FA 는 SRAM 에서 도는 알고리즘이지만, 사실 가장 자주 만나는 함정은 register 다. running 누산기 한 개가 register 에서 떨어지면 전체가 붕괴.”Thomas Viehmann · 강의 paraphrase
§ 06backward 의 메모리 회계· recompute · L 저장

backward 는 score 를 어떻게 다시 만드는가 — N² 메모리 없이

FA 의 forward 가 N×N score 행렬을 안 만들고 했다면 — backward 도 같은 제약 안에서 풀려야 한다. 하지만 backward 에는 forward 의 attention probabilities P 가 다시 필요하다. 어떻게 푸는가.

핵심 idea — attention probabilities 를 저장하지 말고 다시 계산한다 (recompute). 그 대신 forward 에서 logsumexp = m + log(l) 한 줄을 저장. 이 한 줄만 있으면 backward 에서 P_ij = exp(s_ij − logsumexp_i) 로 정확한 P 를 빠르게 복원할 수 있다.

저장량 비교 — 표준 attention backward 가 N² 의 P 를 저장해야 한다면, FA backward 는 N 크기의 logsumexp 한 줄만 저장. 메모리가 N² → N 로 줄어든다.

대신 backward 의 compute 는 forward 와 비슷한 수준 (recompute 가 들어가니까). 결과: backward 가 forward 의 약 2× 시간. 표준 attention 대비 절대 시간은 비슷하거나 더 빠른 경우가 많다 — HBM 트래픽이 dominate 하기 때문.

// forward 에서 저장하는 한 줄
out_l[i] = m_i + logf(l_i);   // = logsumexp_i

// backward 에서 P 복원 (tile 단위)
float s_ij = scaling * Q_i[ii] @ K_j[jj];
float P_ij = expf(s_ij - out_l[i]);   // 정확

// 그 P_ij 로 dQ, dK, dV 누산
dV_j   += P_ij * dO_i;
dP_ij   = dO_i @ V_j;
dS_ij   = P_ij * (dP_ij - /* row reduction */);
dQ_i   += scaling * dS_ij @ K_j;
dK_j   += scaling * dS_ij^T @ Q_i;
PyTorch 와의 통합

PyTorch 2.0 의 F.scaled_dot_product_attention 은 가능하면 FA-2 backend 로 자동 dispatch 한다. 사용자 코드를 안 바꿔도 — 같은 코드가 환경에 따라 표준/FA 두 implementation 사이에서 선택된다. 강제로 FA 를 끄고 비교하려면 torch.backends.cuda.sdp_kernel(enable_flash=False).

§ 07IO complexity 의 새 모델· SRAM-aware

“FLOPs 가 같아도 algorithm 이 다르다” — IO 복잡도라는 새 metric

FA 가 학계에 준 가장 큰 영향 중 하나 — algorithm 의 비용을 측정할 때 FLOPs 만 보는 게 아니라 “HBM ↔ SRAM 트래픽”도 별도 metric 으로 본다 는 시각.

FIG · 표준 vs FA — 같은 FLOPs, 다른 IO 트래픽seq_len 4K, d=128
표준 attentionS, P 를 HBM 에
~50 MB · 100% baseline
~ 4 ms
FlashAttention v1tile streaming
~12 MB · 25%
~ 1.0 ms
FlashAttention v2work partitioning 개선
~9 MB · 18%
~ 0.7 ms
FLOPs 는 모두 같다 (수학적 결과가 같으니까). 차이는 전적으로 메모리 트래픽. 강의 시점 (FA-3 출현 전) 의 추정치.
IO complexity formula

논문의 정식 표현 — FA 의 HBM accesses = O(N²d²/M), M 은 SRAM 크기. 표준 attention 은 O(N²) 이상. d=128, M ≈ 100KB 이면 d²/M ≈ 0.16 — 이론상 6× 적은 IO. 실측은 그 수치 근처에서 도착한다.

§ 08FA v1 → v2 의 차이· work partitioning

같은 알고리즘 — 다른 work partition. 그게 throughput 을 두 배 키웠다

FA 논문이 두 차례 나왔다. v1 (Dao 2022) 은 위에 깐 알고리즘. v2 (Dao 2023) 는 같은 알고리즘이지만 — “누가 어떤 일을 하느냐” 의 work partition 을 다시 짠다.

v1 의 한계 — outer loop 이 K/V 위에 있고 inner loop 이 Q 위에 있었다. 즉 한 thread block 이 한 K/V tile 을 잡고 모든 Q 를 훑는다. 결과: 같은 K/V 가 SRAM 에 한 번 올라오면 좋다 — 하지만 thread block 들 사이의 reduction 이 필요해서 추가 sync.

v2 의 변경 — outer loop 을 Q 로 바꿈. 한 block 이 한 Q tile 을 잡고 K/V 를 streaming. 그래서:

  • block 사이 reduction 필요 없음 — 각 block 이 한 Q tile 의 출력을 완전히 만든다.
  • 병렬도 ↑ — Q 차원과 batch · head 차원을 모두 grid 로 사용 가능.
  • warp 분배 개선 — 한 block 안의 warp 들이 column 으로 나뉘어 inner reduction 을 같이 함.

강의의 flash_attention.cu 가 사실은 v2 의 work partitioning 을 따르는 형태 — Q 가 outer.

v2 의 throughput 효과

같은 GPU 위에서 v1 대비 v2 가 약 2× throughput. 일부 길이에서는 더 많이. 알고리즘은 안 바뀌었지만 partitioning 만 바뀌었다 — “같은 알고리즘도 work distribution 으로 두 배 빨라질 수 있다” 의 좋은 사례.

FA-3 의 등장 (강의 이후)

강의 시점 (2024 March) 에는 v2 가 최신. 이후 FA-3 (Hopper-aware) 가 나오면서 H100 의 TMA, WGMMA, asynchronous warpgroup 같은 새 hardware 기능을 활용. 강의 시점 이후 자료. 별도로 추적 필요.

§ 09hardware 가 따라잡는 길· FA-3 · TMA · WGMMA

알고리즘이 hardware 를 끌고 갔다 — H100 의 새 instruction 들

강의가 끝나면서 Thomas 가 짚는 미래 방향 — FlashAttention 의 알고리즘 패턴이 NVIDIA 의 다음 hardware 디자인을 끌었다. H100 의 새 instruction 들이 사실상 “FA 를 빠르게 돌리려면 무엇이 필요한가” 의 답으로 설계되어 있다.

TMA (Tensor Memory Accelerator)
SRAM ↔ HBM tile copy 를 하드웨어 비동기 엔진으로. FA 의 inner loop K/V 로드를 compute 와 overlap. host thread 가 직접 copy 를 발행 안 함.
WGMMA (warpgroup MMA)
4개 warp 이 한 group 으로 큰 MMA 를 비동기 발행. async barrier 와 함께 쓰면 다음 K/V 를 load 하는 동안 현재 score MMA 가 도는 구조.
async barrier
__syncthreads() 의 비동기 변형. producer/consumer 패턴을 명시적으로 — load 끝났음을 신호하고 compute 를 기다린다.
FP8 + Hopper
attention 의 K, V 를 FP8 로 양자화하고 누산은 FP32 로. throughput 2× 의 추가 잠재력 — accuracy 손실 거의 없음.
“FA-1 은 SRAM 인지를, FA-2 는 work partition 을, FA-3 은 hardware 의 비동기성을 끌어들였다 — 같은 알고리즘이 hardware 와 함께 점점 더 정교해진다.”학습 노트 paraphrase
§ 10기억할 메모와 코드· key takeaways · repo

다시 열었을 때 5분 안에 손에 잡혀야 할 것

attention 위에서 다시 일을 하게 됐을 때 — 이 핵심들이 빠르게 복원되어야 한다.

N² 행렬 회피 원칙
Q, K, V 만 HBM 에 두고 score S 와 prob P 는 SRAM 안에서 만들고 버린다. backward 에서는 logsumexp 한 줄만 저장.
online softmax 패턴
running max m, running sum l, running output o. tile 마다 m_new ← max(m, max(s)) 한 뒤 α=exp(m−m_new) 로 이전 누산 보정.
tile streaming 구조
FA-v2: outer Q, inner K/V. 한 block 이 한 Q tile 의 출력을 완전히 만든다. block 간 reduction 없음.
register accounting
m_i, l_i, O_i 가 register 에 있어야 algorithm 이 의미를 가진다. spill 한 번에 IO 이득이 사라짐. nvcc -Xptxas -v 로 진단.
B_r, B_c, d 의 trade-off
B_c 는 SRAM 으로, B_r 은 register 로 결정. d 가 커지면 B_r 작게 — 같은 SM 안에 들어갈 수 있도록.
backward = recompute
forward 에서 logsumexp 만 저장 (N 크기). backward 에서 P 를 다시 만들고 dQ, dK, dV 누산. 메모리 N²→N.
IO complexity
FA: O(N²d²/M). 표준: Ω(N²). FLOPs 는 같지만 메모리 트래픽이 d²/M 배 적다.
PyTorch 통합
F.scaled_dot_product_attention 가 자동 dispatch. sdp_kernel 컨텍스트로 백엔드 선택.
Slides repo 에 별도 슬라이드 없음 — 강의는 거의 전적으로 코드 + 라이브 실행

손에 새기기 — 실습 시퀀스

  1. online softmax 한 페이지 짜기 — Python/numpy 로 chunk 단위 softmax 가 전체 softmax 와 같은지 verify. § 03 의 누산 표를 코드로 그대로.
  2. flash_attention.cu 컴파일 + 실행 — Thomas 의 코드를 nvcc -Xptxas -v 와 함께 빌드. register 사용량 출력 확인. 작은 N 에서 표준 attention 과 결과 일치 확인.
  3. spilling 비교flash_attention_spilling_from_registers.cu 와 정상 버전을 같은 입력으로 실행. NCU 로 “Local Memory Overhead” 비교.
  4. B_r, B_c sweep — B_r ∈ {4, 8, 16}, B_c ∈ {16, 32, 64} 의 9 조합으로 시간 측정. spill 시작점과 SRAM 한계점 시각화.
  5. F.scaled_dot_product_attention vs eagersdp_kernel(enable_flash=True/False) 컨텍스트로 같은 입력의 forward+backward 시간 비교. 모델 크기·seq_len 별 cross-over 점 찾기.
  6. backward 의 logsumexp 활용 — forward 에서 저장한 logsumexp 한 줄로 P 를 복원하는 mini-script. 정확한 backward 결과와 일치 검증.
  7. IO complexity 측정 — NCU 의 “DRAM Read/Write Bytes” 메트릭으로 표준 vs FA 의 실측 트래픽 비교. 이론값과의 차이 분석.
§ 11다른 강의로 이어지는 길· connections

이 강의의 도구가 다음에 어디에 다시 등장하는지

FA 의 메커니즘이 이후 강의들의 핵심 패턴을 결정한다.

§ 12열린 질문· open questions

다음에 다시 들었을 때 직접 검증해야 할 것들

FA 는 빠르게 진화하는 자리라 — 강의 시점 (2024 March) 이후의 변화가 크다.

검증 메모

이 노트의 timing 수치 (~4 ms, ~1.0 ms) 는 강의 패턴 + 일반적 A100 측정의 paraphrase. 자기 GPU 와 자기 length 에서 직접 측정해야 baseline 이 잡힌다. F.scaled_dot_product_attention 의 sdp_kernel 컨텍스트가 가장 실용적.

← Lecture 011 Sparsity — hardware 모양이 algorithm 을 결정한 또 다른 사례 Lecture 013 → Ring Attention — online softmax 를 multi-GPU 로 확장한 그 다음 자리