2022 년 Tri Dao 의 FlashAttention 논문 이후, Transformer 의 attention block 은 같은 수학을 GPU 위에서 다르게 구현한다. “N×N score 행렬을 절대 HBM 에 통째로 만들지 않는다” 는 한 줄짜리 원칙 뒤에 — online softmax, tile streaming, register accounting 의 3개 메커니즘이 깔려 있다. Thomas Viehmann 이 직접 짠 CUDA 구현을 반대로 해부하면서 그 원칙이 코드에 어떻게 떨어지는지를 깐다.
표준 attention 의 수식 — O = softmax(QKT/√d) · V — 은 수학적으로 깔끔하다. 하지만 그 수식을 그대로 구현하면 중간 산출물 S = QKT 가 N×N 크기의 행렬이 되어 HBM 을 통째로 들었다 놨다 한다. seq_len 4K 에서 16M 원소, 32K 에서는 1G 원소.
강의가 풀려는 질문은 두 개다.
flash_attention.cu (d=128, B_r=8, B_c=32) 의 해부.대부분의 다른 GPU Mode 강의가 “이미 있는 도구를 어떻게 부르는가” 라면 — 이 강의는 “알고리즘과 hardware 의 만남이 어떻게 새 implementation 을 강제하는가” 다. attention 의 수학은 안 바뀌었지만, GPU 의 모양 (HBM 너무 느림, SRAM 충분히 큼) 이 알고리즘을 다시 쓰게 만들었다. “hardware aware algorithms” 의 대표 사례.
표준 attention 의 비용을 메모리 시각으로 분해 — 같은 dimension d=128, seq_len N=4K 의 한 head 에 대해.
표준 attention 의 메모리 트래픽 분해 (Q, K, V, S, P, O 가 N×d, N×N, N×d 크기일 때, 한 번씩 R/W 한다고 가정).
FlashAttention 의 메모리 트래픽 — Q 는 한 번씩, K/V 는 outer loop 만큼 반복.
FlashAttention 의 가장 영리한 부분 — softmax 가 “전체 row 를 봐야 normalize 가능” 한 op 인데, tile 단위로 보면서 동시에 정확한 normalizer 를 누산해 나간다. 핵심 도구는 “running max” 와 그에 따른 보정 항.
Numerical stability 를 위해 softmax 는 softmax(x_i) = exp(x_i − max(x)) / Σ exp(x_j − max(x)) 로 계산한다. 문제는 — tile 별로 보면 그 tile 안의 max 만 알고, 다음 tile 에서 더 큰 값이 나오면 이미 누산해 둔 값이 틀린 normalizer 로 계산된 것이 된다. 그래서 max 가 갱신될 때마다 이전 누산을 “보정”한다.
α = exp(m − m_new) 는 — 이전 단계에서 계산해둔 exp(s − m) 들이 새 max m_new 기준으로는 exp(s − m_new) = exp(s − m) · exp(m − m_new) = exp(s − m) · α 가 되어야 하기 때문. 전체 누산이 알고리즘적으로 lossless — softmax 의 평행이동 불변성이 정확히 들어맞는 자리.
Thomas 의 코드 (flash_attention.cu) 의 핵심 구조 — d=128, B_r=8 (Q tile rows), B_c=32 (K/V tile cols). 이 숫자들이 어떻게 정해지고 어떤 의미인지.
코드의 메모리 레이아웃 (Thomas 의 flash_attention.cu).
// shared memory: SRAM 에 올라가는 것들
__shared__ float Q_i[B_r][d]; // 8×128 = 4 KB
__shared__ float K_j[B_c][d]; // 32×128 = 16 KB
__shared__ float V_j[B_c][d]; // 32×128 = 16 KB
__shared__ float S[B_r][B_c]; // 8×32 = 1 KB
// register: thread 별 누산 상태
float l_i[B_r_over_bdy]; // running sum
float m_i[B_r_over_bdy]; // running max
float O_i[B_r_over_bdy][d_over_bdx]; // running output
// outer loop: Q tile 한 번 (per block)
for (int i = bix; i < bix + 1; i++) {
// Q_i 를 한 번 shmem 으로
load_Q(Q_i, Q + i * B_r);
// inner loop: K/V tiles
for (int j = 0; j < T_c; j++) {
load_K_V(K_j, V_j, j);
__syncthreads();
// 1) S = scaling * Q @ K^T
compute_S(S, Q_i, K_j, scaling);
__syncthreads();
// 2) online softmax 갱신 + O 누산
update_running_state(m_i, l_i, O_i, S, V_j);
}
// 3) 최종 O = O_i / l_i, store l_i 도 (backward 위해)
write_output(out, out_l, O_i, l_i, m_i, i);
}
B_c = 32 는 SRAM 용량과 K/V tile 사이즈의 trade-off. K/V 두 tile 이 16 KB + 16 KB = 32 KB 를 잡으니까 192 KB 의 SRAM 에 충분. B_r = 8 은 register 용량으로 결정 — 한 thread 가 누산할 O_i 의 행 수가 너무 크면 register spill 이 일어난다 (§ 05).
Thomas 가 강의에서 가장 중요하게 보여준 진단 자리 — “register spilling from registers”. running 누산기 (m_i, l_i, O_i) 가 register 에 들어가는 게 FA 의 메모리 모델의 전제다. 만약 thread 당 register 가 부족하면 그 누산기들이 local memory (실제로는 HBM-backed) 로 spill 되고 — 이때 모든 IO advantage 가 사라진다.
Thomas 가 본 spill 의 신호.
ld.local, st.local 명령이 등장.강의 repo 에 flash_attention_spilling_from_registers.cu 라는 파일이 따로 있는데 — 같은 알고리즘을 register 가 모자라게 짠 “일부러 안 좋은” 버전. profile 비교용.
(1) __launch_bounds__(...) 로 nvcc 에 thread 당 register 한계 hint, (2) loop 을 unroll 안 함, (3) 누산 차원 줄이기 (B_r 줄임), (4) 일부를 shared memory 에 두기 (대신 SRAM 압박). 이 trade-off 가 FA 구현의 가장 까다로운 자리.
nvcc -Xptxas -v --resource-usage — 컴파일 시 thread 당 register 사용량과 spill 바이트 수를 직접 출력. spill bytes > 0 이면 거의 항상 문제. PTX 도 같이 dump 해서 어디서 spill 되는지 추적 가능.
FA 의 forward 가 N×N score 행렬을 안 만들고 했다면 — backward 도 같은 제약 안에서 풀려야 한다. 하지만 backward 에는 forward 의 attention probabilities P 가 다시 필요하다. 어떻게 푸는가.
핵심 idea — attention probabilities 를 저장하지 말고 다시 계산한다 (recompute). 그 대신 forward 에서 logsumexp = m + log(l) 한 줄을 저장. 이 한 줄만 있으면 backward 에서 P_ij = exp(s_ij − logsumexp_i) 로 정확한 P 를 빠르게 복원할 수 있다.
저장량 비교 — 표준 attention backward 가 N² 의 P 를 저장해야 한다면, FA backward 는 N 크기의 logsumexp 한 줄만 저장. 메모리가 N² → N 로 줄어든다.
대신 backward 의 compute 는 forward 와 비슷한 수준 (recompute 가 들어가니까). 결과: backward 가 forward 의 약 2× 시간. 표준 attention 대비 절대 시간은 비슷하거나 더 빠른 경우가 많다 — HBM 트래픽이 dominate 하기 때문.
// forward 에서 저장하는 한 줄
out_l[i] = m_i + logf(l_i); // = logsumexp_i
// backward 에서 P 복원 (tile 단위)
float s_ij = scaling * Q_i[ii] @ K_j[jj];
float P_ij = expf(s_ij - out_l[i]); // 정확
// 그 P_ij 로 dQ, dK, dV 누산
dV_j += P_ij * dO_i;
dP_ij = dO_i @ V_j;
dS_ij = P_ij * (dP_ij - /* row reduction */);
dQ_i += scaling * dS_ij @ K_j;
dK_j += scaling * dS_ij^T @ Q_i;
PyTorch 2.0 의 F.scaled_dot_product_attention 은 가능하면 FA-2 backend 로 자동 dispatch 한다. 사용자 코드를 안 바꿔도 — 같은 코드가 환경에 따라 표준/FA 두 implementation 사이에서 선택된다. 강제로 FA 를 끄고 비교하려면 torch.backends.cuda.sdp_kernel(enable_flash=False).
FA 가 학계에 준 가장 큰 영향 중 하나 — algorithm 의 비용을 측정할 때 FLOPs 만 보는 게 아니라 “HBM ↔ SRAM 트래픽”도 별도 metric 으로 본다 는 시각.
논문의 정식 표현 — FA 의 HBM accesses = O(N²d²/M), M 은 SRAM 크기. 표준 attention 은 O(N²) 이상. d=128, M ≈ 100KB 이면 d²/M ≈ 0.16 — 이론상 6× 적은 IO. 실측은 그 수치 근처에서 도착한다.
FA 논문이 두 차례 나왔다. v1 (Dao 2022) 은 위에 깐 알고리즘. v2 (Dao 2023) 는 같은 알고리즘이지만 — “누가 어떤 일을 하느냐” 의 work partition 을 다시 짠다.
v1 의 한계 — outer loop 이 K/V 위에 있고 inner loop 이 Q 위에 있었다. 즉 한 thread block 이 한 K/V tile 을 잡고 모든 Q 를 훑는다. 결과: 같은 K/V 가 SRAM 에 한 번 올라오면 좋다 — 하지만 thread block 들 사이의 reduction 이 필요해서 추가 sync.
v2 의 변경 — outer loop 을 Q 로 바꿈. 한 block 이 한 Q tile 을 잡고 K/V 를 streaming. 그래서:
강의의 flash_attention.cu 가 사실은 v2 의 work partitioning 을 따르는 형태 — Q 가 outer.
같은 GPU 위에서 v1 대비 v2 가 약 2× throughput. 일부 길이에서는 더 많이. 알고리즘은 안 바뀌었지만 partitioning 만 바뀌었다 — “같은 알고리즘도 work distribution 으로 두 배 빨라질 수 있다” 의 좋은 사례.
강의 시점 (2024 March) 에는 v2 가 최신. 이후 FA-3 (Hopper-aware) 가 나오면서 H100 의 TMA, WGMMA, asynchronous warpgroup 같은 새 hardware 기능을 활용. 강의 시점 이후 자료. 별도로 추적 필요.
강의가 끝나면서 Thomas 가 짚는 미래 방향 — FlashAttention 의 알고리즘 패턴이 NVIDIA 의 다음 hardware 디자인을 끌었다. H100 의 새 instruction 들이 사실상 “FA 를 빠르게 돌리려면 무엇이 필요한가” 의 답으로 설계되어 있다.
__syncthreads() 의 비동기 변형. producer/consumer 패턴을 명시적으로 — load 끝났음을 신호하고 compute 를 기다린다.attention 위에서 다시 일을 하게 됐을 때 — 이 핵심들이 빠르게 복원되어야 한다.
nvcc -Xptxas -v 로 진단.F.scaled_dot_product_attention 가 자동 dispatch. sdp_kernel 컨텍스트로 백엔드 선택.nvcc -Xptxas -v 와 함께 빌드. register 사용량 출력 확인. 작은 N 에서 표준 attention 과 결과 일치 확인.flash_attention_spilling_from_registers.cu 와 정상 버전을 같은 입력으로 실행. NCU 로 “Local Memory Overhead” 비교.sdp_kernel(enable_flash=True/False) 컨텍스트로 같은 입력의 forward+backward 시간 비교. 모델 크기·seq_len 별 cross-over 점 찾기.FA 의 메커니즘이 이후 강의들의 핵심 패턴을 결정한다.
FA 는 빠르게 진화하는 자리라 — 강의 시점 (2024 March) 이후의 변화가 크다.
이 노트의 timing 수치 (~4 ms, ~1.0 ms) 는 강의 패턴 + 일반적 A100 측정의 paraphrase. 자기 GPU 와 자기 length 에서 직접 측정해야 baseline 이 잡힌다. F.scaled_dot_product_attention 의 sdp_kernel 컨텍스트가 가장 실용적.