Umar Jamil 이 자기 손으로 한 CUDA → Triton → Flash Attention 학습의 연대기. attention 의 어디가 GPU 위에서 어렵고, 그 어려움이 online softmax 와 tile 단위 사고 로 어떻게 해소되는지 — 그리고 같은 길을 따라가려는 사람이 어디서 막히는지에 대한 학습 노트. 원본 자막이 실패해 본 페이지는 도메인 지식과 공개된 자료로 재구성한 노트다.
Umar Jamil 의 강의는 알고리즘 강의가 아니다 — 한 사람이 attention 을 바닥부터 자기 손으로 짤 때 어디에 부딪혔는가 의 회고에 가깝다. 그 회고가 의미 있는 이유는, 같은 길을 가려는 사람이 거의 똑같은 자리에 부딪히기 때문이다.
강의가 던지는 두 개의 질문.
이 노트는 원본 transcript 가 실패한 강의를 도메인 지식과 공개된 자료(Tri Dao 의 Flash Attention 논문 1·2·3, Umar 의 다른 YouTube 비디오, OpenAI Triton 튜토리얼)로 재구성한 학습 노트다. 강의 안에서 직접 확인되지 않은 주장은 본문에 “원본 영상 확인 필요” 표시로 남겨둔다.
같은 알고리즘이 — naive PyTorch attention, CUDA 직접 구현, Triton 구현 — 세 형태로 짜졌을 때, 무엇이 같고 무엇이 다른지를 본인의 학습 순서대로 따라간다. 그 차이의 핵심은 “메모리 위계 어디에 무엇이 사는가” 의 통제권 이전이다.
강의의 실질적 도착점은 — Tri Dao 의 Flash Attention 2 를 Triton 튜토리얼 형태로 다시 짠 코드를 한 줄씩 읽을 수 있는 상태다. 그 자리에 도착하기 위한 사다리가 §02 부터 §07 까지에 깔려 있다.
CUDA 를 처음 짜본 사람의 진술은 거의 일치한다 — 문법이 어려운 게 아니라 “내가 thread 한 명이라면 지금 무엇을 하고 있나” 의 1인칭 사고로 넘어가는 게 어렵다. PyTorch 의 batched 연산 사고와 정반대다.
강의에서 Umar 가 거쳤다고 알려진 첫 단계.
idx = blockIdx.x * blockDim.x + threadIdx.x 의 관용구를 손에 새긴다.row = blockIdx.y * blockDim.y + threadIdx.y 가 자연스러워질 때까지.__shared__, __syncthreads() 의 의미를 본인이 그림으로 그릴 수 있어야 한다.이 시퀀스는 PMPP (Programming Massively Parallel Processors) 책 1–5장과 거의 같다. L002 Andreas Köpf 의 PMPP 강의가 같은 자리를 다룬다.
// matmul tiled — CUDA 의 인지적 어려움이 한꺼번에 모이는 자리
__global__ void matmul_tiled(float* A, float* B,
float* C, int N) {
__shared__ float As[16][16];
__shared__ float Bs[16][16];
int row = blockIdx.y * 16 + threadIdx.y;
int col = blockIdx.x * 16 + threadIdx.x;
float acc = 0.0f;
for (int t = 0; t < N/16; ++t) {
As[threadIdx.y][threadIdx.x] =
A[row*N + t*16 + threadIdx.x];
Bs[threadIdx.y][threadIdx.x] =
B[(t*16 + threadIdx.y)*N + col];
__syncthreads();
for (int k = 0; k < 16; ++k)
acc += As[threadIdx.y][k] * Bs[k][threadIdx.x];
__syncthreads();
}
C[row*N + col] = acc;
}
이 코드 안에서 “같은 thread 가 한 iteration 에서 어떤 element 를 읽고 어떤 element 를 쓰는가” 를 머릿속에서 동시 추적할 수 있어야 한다. __syncthreads() 의 의미는 “이 줄이 끝날 때까지 같은 block 의 모든 thread 가 도착해야 한다”. 이걸 빼먹으면 race condition 이 일어나는데 — 가장 짜증나는 점은 대부분의 입력에서 정답이 나온다는 것. 무작위로 깨진다.
CUDA 의 첫 단계에서 학습자가 보통 6~8주 정도를 보낸다. Umar 도 같은 시기를 거쳤음을 자기 채널의 다른 영상에서 여러 번 언급한다. 이 단계의 끝에 도착하면 — 그제야 같은 알고리즘을 다르게 짜면 다르게 빠를 수 있다 의 직관이 잡힌다.
Triton 의 가장 큰 인지적 변화는 thread 가 사라진다는 점이다. 코드 안에서 thread 단위 indexing 을 직접 쓰지 않는다. 대신 tile (작은 행렬 블록) 을 통째로 다룬다 — tl.load, tl.dot, tl.store. compiler 가 thread 분배를 알아서 한다.
__shared__ 명시, __syncthreads() 명시. 사용자가 SRAM 의 모양과 위치를 통제. bank conflict 까지 사용자 책임.
통제권 강함
tl.load(ptrs, mask), tl.dot(a, b), tl.store. SRAM 사용은 compiler 가 결정. bank conflict 도 compiler 가 회피. 하지만 launch 설정(BLOCK_M/BLOCK_N/BLOCK_K, num_warps)에 결과가 민감.
생산성 ↑
A @ B 한 줄. 메모리 계층 통제 없음. 간단한 패턴은 torch.compile 이 fused Triton 으로 자동 lowering.
생산성 ↑↑
# Triton matmul — tile 단위로 사고
@triton.jit
def matmul_kernel(A, B, C, M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
a = tl.load(A + offs_m[:,None]*stride_am
+ (k+offs_k)[None,:]*stride_ak)
b = tl.load(B + (k+offs_k)[:,None]*stride_bk
+ offs_n[None,:]*stride_bn)
acc += tl.dot(a, b)
tl.store(C + offs_m[:,None]*stride_cm
+ offs_n[None,:]*stride_cn, acc)
위 코드에서 사라진 것 들을 본다.
__shared__ 선언이 없다 — Triton 이 알아서 SRAM 으로 staging.__syncthreads() 가 없다 — tile 연산이 자동 동기화.threadIdx.x 같은 thread 단위 인덱스가 없다 — tl.arange 로 한 tile 의 인덱스 벡터를 통째로 다룬다.새로 생긴 것.
BLOCK_M / BLOCK_N / BLOCK_K 라는 launch 설정. 같은 코드가 설정에 따라 5–10배 차이가 난다.tl.constexpr 의 의미 — JIT 가 컴파일 시점에 값을 고정해서 unroll 할 수 있게.num_warps, num_stages — kernel decorator 의 hyperparameter. 이 자리에서 사람이 이해해야 하는 GPU 의 디테일이 다시 등장한다.코드가 짧아지는 것이 아니라 — bank conflict, register tiling, swizzling 같은 “GPU 디테일에서 오는 hard-to-debug 버그” 가 사라진다는 점이다. 학습자가 알고리즘 자체에 집중할 수 있게 된다. CUDA 학습 6주 + Triton 학습 1주가 합쳐져야 진짜 의미가 있다 — Triton 만 배우면 “왜 이 설정이 빠른가” 의 직관이 안 잡힌다.
Flash Attention 이 풀려고 하는 문제는 두 개다 — (1) attention matrix 가 N² 메모리를 먹는다는 점, (2) softmax 가 row 전체를 알아야 정규화된다는 점. 둘이 합쳐지면 — 큰 matrix 를 HBM 에 한 번 쓰고 다시 읽고, 또 쓰고 다시 읽는 패턴이 강제된다.
standard attention 의 메모리 패턴을 풀면 —
S = Q @ K.T → S 가 (N, N) HBM 에 쓰임P = softmax(S, dim=-1) → row max 와 sum 을 위해 S 를 다시 읽고, 또 다시 쓰고O = P @ V → P 를 다시 읽고, 결과 O 를 쓰고여기서 S 와 P 는 N=8192 만 되어도 64M 원소 — fp16 으로 128MB. 이게 HBM 을 왔다갔다하는 시간이 매트릭스 곱 자체보다 더 길다. memory-bound.
attention 은 FLOPs 의 알고리즘이라기보다 메모리 패턴의 알고리즘이다. 같은 FLOPs 를 계산하면서 HBM 왕복을 줄이면 거의 그대로 빨라진다. Flash Attention 1 의 주장이 정확히 이것 — “더 적은 일을 하는 것이 아니라, 같은 일을 더 적은 메모리 왕복으로.”
그런데 “SRAM 위에서 끝까지 돌게 한다” 가 실제로 만나는 첫 벽이 — softmax 의 정규화는 row 전체의 max 와 sum 을 안 다음에야 끝난다는 점이다. naive 하게는 (1) row 전체를 한 번 읽어서 max 구하고, (2) 다시 읽어서 exp(x-max) 의 sum 구하고, (3) 또 다시 읽어서 정규화한다 — 같은 row 를 세 번 읽는다. 한 tile 안에서 끝낼 수 없다. 이 자리를 푸는 게 online softmax다.
online softmax 의 아이디어는 단순하다 — 지금까지 본 max 와 지금까지의 (rescaled) sum 만 들고 다닌다. 새 tile 의 max 가 들어오면 둘 다 보정한다. 이게 numerically stable 하게 동작한다는 점이 50년 전부터 알려져 있었다 (Milakov & Gimelshein 2018 의 재정리, 더 거슬러 올라가면 LogSumExp 의 stable 계산).
m₀ = max(x₀), ℓ₀ = sum(exp(x₀ - m₀)). SRAM 에 둘만 들고 있는다.
m₁_local = max(x₁), ℓ₁_local = sum(exp(x₁ - m₁_local)) 계산.
m₁ = max(m₀, m₁_local). 기존 sum 을 새 max 기준으로 rescale: ℓ₀_new = ℓ₀ * exp(m₀ - m₁). 새 sum 도 마찬가지: ℓ₁ = ℓ₁_local * exp(m₁_local - m₁). 합친다: ℓ = ℓ₀_new + ℓ₁.
O = O * exp(m₀ - m₁) + tile₁ × V₁. 마지막 tile 후에 O / ℓ 한 번으로 정규화.
(m, ℓ, O) 세 개만 들고 다녔다.
각 단계의 max 보정이 지수 함수의 평행이동 이다. exp(x - m₁) = exp(x - m₀) * exp(m₀ - m₁). 항상 m 이 지금까지 본 최대보다 크거나 같으므로 x - m ≤ 0 — overflow 안 난다. 이 한 줄이 Flash Attention 전체를 가능하게 한다.
이 자리에서 강의의 한 가지 미세한 점 — “naive softmax 가 fp16 에서 overflow 나는 이유” 는 max-shift 가 빠진 형태이고, 그 보정 자체가 online softmax 의 식과 같다는 점. PyTorch 의 F.softmax 도 내부적으로 max-shift 를 한다. Umar 가 강의에서 이 점을 강조했을 가능성이 높다 (원본 영상 확인 필요).
Flash Attention 의 코드를 한 줄씩 읽으면 — 알고리즘을 자체로 새로 발명한 게 아니라, 같은 attention 을 “tile 사이로 무엇을 carry 할 것인가” 의 관점에서 다시 짠 것이라는 사실이 드러난다. 이 사고가 Triton 의 tile 추상과 정확히 맞물린다.
Flash Attention 2 forward 의 outer 루프는 query block 단위, inner 루프는 key/value block 단위. 한 query block 을 잡으면 SRAM 위에 그 q 를 박고, K/V 를 tile 씩 읽어들이면서 (m, ℓ, O) 를 갱신한다.
한 query block 의 일이 끝나면 (m, ℓ, O) 를 HBM 에 한 번 쓰고 다음 query block 으로. HBM 왕복이 query block 수 만큼만 일어난다 — naive 의 N² 항이 한 번 사라진 자리.
Flash Attention 1 은 K 를 outer 로 두었다. FA2 는 Q 를 outer 로 둔다 — 이유는 output O 가 query 마다 독립이고, query block 이 outer 일 때 다른 thread block 끼리 communicate 할 일이 없기 때문이다. 더 좋은 parallelism. backward 도 마찬가지로 분해 단위가 달라졌다.
# Flash Attention 2 forward 의 한 query block (의사 코드)
def fa2_forward_block(q_block, K, V, scale):
# SRAM 에 살아있는 state
m = -inf # running max
l = 0 # running sum
O = 0 # output
for kv_block in kv_blocks_in(K, V):
s = q_block @ kv_block.k.T * scale
m_new_local = s.max(dim=-1)
m_new = max(m, m_new_local)
alpha = exp(m - m_new)
l = l * alpha + exp(s - m_new).sum(dim=-1)
O = O * alpha + exp(s - m_new) @ kv_block.v
m = m_new
return O / l, m, l # (m,l) 은 backward 용
이 코드의 모든 변수는 SRAM 에 들어가는 작은 사이즈다. BLOCK_M = 64 ~ 128, head_dim ≤ 128 정도면 한 block 의 q, kv tile, m, l, O 가 모두 register/shared 에 들어간다. HBM 왕복은 query block 한 번 + 모든 kv block 의 한 번씩.
실제 FA2 Triton 구현에는 — causal mask 처리(상삼각만 계산), dropout, GQA(group query attention 의 K/V 공유), bias term 같은 것들이 더 들어간다. 핵심 outer/inner 루프 구조는 같다. Tri Dao 의 GitHub flash-attention repo 에서 flash_attn_triton.py 가 가장 읽기 좋은 reference.
forward 가 풀려도 backward 가 어렵다. 왜냐하면 attention matrix S, P 가 forward 에서 HBM 에 안 저장됐기 때문이다 — 그게 알고리즘의 포인트였으니까. 그러면 backward 는 어떻게? 다시 계산한다.
O, row-wise m (max), row-wise ℓ (sum). softmax 의 정규화 정보. 메모리 N×d + 2N — N² 가 아니다.
O(N·d)
P_ij = exp(s_ij - m_i) / ℓ_i 를 backward 에서 같은 tile 단위로 다시 계산. forward 의 m, ℓ 가 있어서 stable.
flops 두 배
backward 식을 풀면 (Tri Dao FA paper appendix 참고) —
이 식들은 다 같은 (m, ℓ) 정보로 P 를 다시 만들 수 있어야 stable 하게 돈다. 그래서 forward 가 (m, ℓ) 를 저장하는 것이 backward 의 전제다.
backward 의 정확한 식과 그 안의 trick — 특히 D = rowsum(dO ⊙ O) 라는 작은 보조 변수를 미리 계산해두면 식이 깔끔해진다는 점 — 은 Tri Dao FA1 paper Algorithm 4, FA2 paper Algorithm 2 에 있다. Umar 가 강의에서 이 식을 손으로 끌어냈을 가능성이 있다 (원본 영상 확인 필요). 만약 안 했다면 — 자기가 손으로 한 번 풀어보는 게 학습에서 가장 큰 단계.
강의 자체가 transcript 실패라 — 같은 길을 따라가려는 사람을 위해 자료 목록을 정리해둔다. Umar 의 채널이 가장 좋은 시작점이고, Tri Dao 의 두 paper 가 reference 다.
06-fused-attention.py 가 FA2 의 가장 읽기 좋은 reference 구현. 약 200줄.flash_attn_triton.py 가 Triton 버전.CUDA → Triton → FA 의 학습 시퀀스에서 학습자들이 똑같이 막히는 자리들. Umar 의 강의가 이 자리들을 명시적으로 짚었는지는 확실하지 않지만 (원본 확인 필요), 학습 노트로서 정리해둘 가치가 있다.
FA 를 자기 손으로 한 줄씩 짤 수 있게 되는 데 — CUDA 부터 시작해 평균 3~4개월. Umar 의 강의가 그 학습 곡선을 “한 사람의 회고” 로 보여주는 데 의미가 있다. 한 번 듣고 안 잡힌다고 자책하지 말 것.
강의의 detail 보다 — 이 노트가 의미를 지니려면 손에 박아야 할 짧은 사실들.
__shared__, __syncthreads() 직접 통제. 6+ 주 학습.tl.load, tl.dot, tl.store). compiler 가 thread 분배. launch 설정에 결과 민감.06-fused-attention.py 를 한 줄씩 주석.L050 의 학습 시퀀스가 시리즈의 어느 강의들과 직접 맞물리는지.
강의 자체의 transcript 가 실패해서, 본 노트는 도메인 지식과 공개된 자료로 재구성됐다. 다음에 영상을 직접 다시 본다면 확인할 자리들.
본 노트에서 등장하는 모든 식, 알고리즘 단계, 메모리 추정값 — 강의 자체의 직접 인용이 아니라 Tri Dao 의 paper 와 Triton tutorial 에서 파생된 재구성. Umar 의 영상에서 다른 강조점이 나왔을 수 있으므로 영상을 직접 본 사람이 노트를 보강하면 좋음.