gpumode · 강의 아카이브
《GPU Mode》 L050 2024 High priority transcript · failed

A learning journey: CUDA, Triton, Flash Attention

Umar Jamil 이 자기 손으로 한 CUDA → Triton → Flash Attention 학습의 연대기. attention 의 어디가 GPU 위에서 어렵고, 그 어려움이 online softmaxtile 단위 사고 로 어떻게 해소되는지 — 그리고 같은 길을 따라가려는 사람이 어디서 막히는지에 대한 학습 노트. 원본 자막이 실패해 본 페이지는 도메인 지식과 공개된 자료로 재구성한 노트다.

CUDA Triton Flash Attention online softmax tile programming async copy SRAM backward pass
U
Speaker
Umar Jamil
YouTube ML educator · CUDA/Triton from-scratch tutorials
강의 번호
L050
스피커
Umar Jamil
Transcript
failed · 본 노트는 재구성
학습 우선순위
High · 흐름 따라가기
§ 01강의가 풀려는 문제· why this lecture exists

“CUDA 부터 짜야 하나, Triton 부터 짜야 하나” 의 학습 순서를 직접 걸어본 노트

Umar Jamil 의 강의는 알고리즘 강의가 아니다 — 한 사람이 attention 을 바닥부터 자기 손으로 짤 때 어디에 부딪혔는가 의 회고에 가깝다. 그 회고가 의미 있는 이유는, 같은 길을 가려는 사람이 거의 똑같은 자리에 부딪히기 때문이다.

강의가 던지는 두 개의 질문.

  1. CUDA 와 Triton 사이의 인지적 거리 — 학습자가 한쪽에서 다른쪽으로 넘어갈 때 무엇을 새로 배우고 무엇을 버려야 하는가.
  2. Flash Attention 이 “어려운” 진짜 이유 — 알고리즘이 복잡해서가 아니라 메모리 계층수학적 안정성autograd 회계 가 동시에 얽히기 때문이다.

이 노트는 원본 transcript 가 실패한 강의를 도메인 지식과 공개된 자료(Tri Dao 의 Flash Attention 논문 1·2·3, Umar 의 다른 YouTube 비디오, OpenAI Triton 튜토리얼)로 재구성한 학습 노트다. 강의 안에서 직접 확인되지 않은 주장은 본문에 “원본 영상 확인 필요” 표시로 남겨둔다.

강의의 인지적 frame

같은 알고리즘이 — naive PyTorch attention, CUDA 직접 구현, Triton 구현 — 세 형태로 짜졌을 때, 무엇이 같고 무엇이 다른지를 본인의 학습 순서대로 따라간다. 그 차이의 핵심은 “메모리 위계 어디에 무엇이 사는가” 의 통제권 이전이다.

“CUDA 는 thread 의 언어다. Triton 은 tile 의 언어다. 같은 알고리즘이 두 언어에서 서로 다른 모양으로 분해된다.” 학습 노트 · 재구성

강의의 실질적 도착점은 — Tri Dao 의 Flash Attention 2 를 Triton 튜토리얼 형태로 다시 짠 코드를 한 줄씩 읽을 수 있는 상태다. 그 자리에 도착하기 위한 사다리가 §02 부터 §07 까지에 깔려 있다.

§ 02CUDA 의 첫 인상· grid · block · thread 의 벽

“thread 한 명의 입장” 으로 코드를 짜는 사고가 가장 어렵다

CUDA 를 처음 짜본 사람의 진술은 거의 일치한다 — 문법이 어려운 게 아니라 “내가 thread 한 명이라면 지금 무엇을 하고 있나” 의 1인칭 사고로 넘어가는 게 어렵다. PyTorch 의 batched 연산 사고와 정반대다.

강의에서 Umar 가 거쳤다고 알려진 첫 단계.

  1. vector add — 가장 작은 커널. idx = blockIdx.x * blockDim.x + threadIdx.x 의 관용구를 손에 새긴다.
  2. matmul naive — 2D grid, 2D block. row = blockIdx.y * blockDim.y + threadIdx.y 가 자연스러워질 때까지.
  3. matmul tiled — shared memory 등장. __shared__, __syncthreads() 의 의미를 본인이 그림으로 그릴 수 있어야 한다.
  4. matmul + bank conflict — 같은 코드가 padding 한 줄 차이로 2배 빠르거나 느려지는 자리.

이 시퀀스는 PMPP (Programming Massively Parallel Processors) 책 1–5장과 거의 같다. L002 Andreas Köpf 의 PMPP 강의가 같은 자리를 다룬다.

// matmul tiled — CUDA 의 인지적 어려움이 한꺼번에 모이는 자리
__global__ void matmul_tiled(float* A, float* B,
                            float* C, int N) {
  __shared__ float As[16][16];
  __shared__ float Bs[16][16];

  int row = blockIdx.y * 16 + threadIdx.y;
  int col = blockIdx.x * 16 + threadIdx.x;

  float acc = 0.0f;
  for (int t = 0; t < N/16; ++t) {
    As[threadIdx.y][threadIdx.x] =
        A[row*N + t*16 + threadIdx.x];
    Bs[threadIdx.y][threadIdx.x] =
        B[(t*16 + threadIdx.y)*N + col];
    __syncthreads();

    for (int k = 0; k < 16; ++k)
      acc += As[threadIdx.y][k] * Bs[k][threadIdx.x];
    __syncthreads();
  }
  C[row*N + col] = acc;
}
학습 곡선의 첫 벽

이 코드 안에서 “같은 thread 가 한 iteration 에서 어떤 element 를 읽고 어떤 element 를 쓰는가” 를 머릿속에서 동시 추적할 수 있어야 한다. __syncthreads() 의 의미는 “이 줄이 끝날 때까지 같은 block 의 모든 thread 가 도착해야 한다”. 이걸 빼먹으면 race condition 이 일어나는데 — 가장 짜증나는 점은 대부분의 입력에서 정답이 나온다는 것. 무작위로 깨진다.

CUDA 의 첫 단계에서 학습자가 보통 6~8주 정도를 보낸다. Umar 도 같은 시기를 거쳤음을 자기 채널의 다른 영상에서 여러 번 언급한다. 이 단계의 끝에 도착하면 — 그제야 같은 알고리즘을 다르게 짜면 다르게 빠를 수 있다 의 직관이 잡힌다.

§ 03Triton 으로 갈아타기· tile-level abstraction

thread 의 언어 → tile 의 언어 — 사라지는 것과 새로 생기는 것

Triton 의 가장 큰 인지적 변화는 thread 가 사라진다는 점이다. 코드 안에서 thread 단위 indexing 을 직접 쓰지 않는다. 대신 tile (작은 행렬 블록) 을 통째로 다룬다 — tl.load, tl.dot, tl.store. compiler 가 thread 분배를 알아서 한다.

FIG · CUDA vs Triton — 같은 matmul 의 사고 단위thread → tile
CUDA thread 한 명의 입장row, col, k 인덱스를 직접 계산 __shared__ 명시, __syncthreads() 명시. 사용자가 SRAM 의 모양과 위치를 통제. bank conflict 까지 사용자 책임. 통제권 강함
Triton tile 한 장의 입장BLOCK_M × BLOCK_K, BLOCK_K × BLOCK_N 의 한 덩어리 tl.load(ptrs, mask), tl.dot(a, b), tl.store. SRAM 사용은 compiler 가 결정. bank conflict 도 compiler 가 회피. 하지만 launch 설정(BLOCK_M/BLOCK_N/BLOCK_K, num_warps)에 결과가 민감. 생산성 ↑
PyTorch batched op 의 입장매트릭스 통째로 A @ B 한 줄. 메모리 계층 통제 없음. 간단한 패턴은 torch.compile 이 fused Triton 으로 자동 lowering. 생산성 ↑↑
통제권 ↔ 생산성 의 trade-off 가 한 줄에 정렬된다. Triton 이 두 극단의 중간자리를 잡는다.
# Triton matmul — tile 단위로 사고
@triton.jit
def matmul_kernel(A, B, C, M, N, K,
                  stride_am, stride_ak,
                  stride_bk, stride_bn,
                  stride_cm, stride_cn,
                  BLOCK_M: tl.constexpr,
                  BLOCK_N: tl.constexpr,
                  BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        a = tl.load(A + offs_m[:,None]*stride_am
                       + (k+offs_k)[None,:]*stride_ak)
        b = tl.load(B + (k+offs_k)[:,None]*stride_bk
                       + offs_n[None,:]*stride_bn)
        acc += tl.dot(a, b)
    tl.store(C + offs_m[:,None]*stride_cm
                + offs_n[None,:]*stride_cn, acc)

위 코드에서 사라진 것 들을 본다.

  • __shared__ 선언이 없다 — Triton 이 알아서 SRAM 으로 staging.
  • __syncthreads() 가 없다 — tile 연산이 자동 동기화.
  • threadIdx.x 같은 thread 단위 인덱스가 없다 — tl.arange 로 한 tile 의 인덱스 벡터를 통째로 다룬다.

새로 생긴 것.

  • BLOCK_M / BLOCK_N / BLOCK_K 라는 launch 설정. 같은 코드가 설정에 따라 5–10배 차이가 난다.
  • tl.constexpr 의 의미 — JIT 가 컴파일 시점에 값을 고정해서 unroll 할 수 있게.
  • num_warps, num_stages — kernel decorator 의 hyperparameter. 이 자리에서 사람이 이해해야 하는 GPU 의 디테일이 다시 등장한다.
Triton 의 진짜 장점

코드가 짧아지는 것이 아니라 — bank conflict, register tiling, swizzling 같은 “GPU 디테일에서 오는 hard-to-debug 버그” 가 사라진다는 점이다. 학습자가 알고리즘 자체에 집중할 수 있게 된다. CUDA 학습 6주 + Triton 학습 1주가 합쳐져야 진짜 의미가 있다 — Triton 만 배우면 “왜 이 설정이 빠른가” 의 직관이 안 잡힌다.

§ 04attention 의 핵심 어려움· 메모리 폭발 + softmax

왜 standard attention 은 GPU 에 친절하지 않은가

Flash Attention 이 풀려고 하는 문제는 두 개다 — (1) attention matrix 가 N² 메모리를 먹는다는 점, (2) softmax 가 row 전체를 알아야 정규화된다는 점. 둘이 합쳐지면 — 큰 matrix 를 HBM 에 한 번 쓰고 다시 읽고, 또 쓰고 다시 읽는 패턴이 강제된다.

standard attention 의 메모리 패턴을 풀면 —

  1. S = Q @ K.TS 가 (N, N) HBM 에 쓰임
  2. P = softmax(S, dim=-1) → row max 와 sum 을 위해 S 를 다시 읽고, 또 다시 쓰고
  3. O = P @ VP 를 다시 읽고, 결과 O 를 쓰고

여기서 SP 는 N=8192 만 되어도 64M 원소 — fp16 으로 128MB. 이게 HBM 을 왔다갔다하는 시간이 매트릭스 곱 자체보다 더 길다. memory-bound.

핵심 통찰

attention 은 FLOPs 의 알고리즘이라기보다 메모리 패턴의 알고리즘이다. 같은 FLOPs 를 계산하면서 HBM 왕복을 줄이면 거의 그대로 빨라진다. Flash Attention 1 의 주장이 정확히 이것 — “더 적은 일을 하는 것이 아니라, 같은 일을 더 적은 메모리 왕복으로.”

FIG · 메모리 위계별 대역폭A100 기준 추정
HBM
2 TB/s
L2 cache
5 TB/s
SRAM (shared)
19 TB/s
register
40+ TB/s
SRAM 위에서 끝까지 도는 게 HBM 왕복을 한 번 줄이는 것의 10배 효과. Flash Attention 의 본질이 이 지점.

그런데 “SRAM 위에서 끝까지 돌게 한다” 가 실제로 만나는 첫 벽이 — softmax 의 정규화는 row 전체의 max 와 sum 을 안 다음에야 끝난다는 점이다. naive 하게는 (1) row 전체를 한 번 읽어서 max 구하고, (2) 다시 읽어서 exp(x-max) 의 sum 구하고, (3) 또 다시 읽어서 정규화한다 — 같은 row 를 세 번 읽는다. 한 tile 안에서 끝낼 수 없다. 이 자리를 푸는 게 online softmax다.

“Flash Attention 의 핵심은 fancy 한 hardware 트릭이 아니라 online softmax 라는 50년 된 numerical analysis 결과의 재발견이다.” 학습 노트 · 재구성
§ 05online softmax 가 푸는 자리· running max + running sum

row 전체를 안 보고도 softmax 를 정확히 계산하는 방법

online softmax 의 아이디어는 단순하다 — 지금까지 본 max 와 지금까지의 (rescaled) sum 만 들고 다닌다. 새 tile 의 max 가 들어오면 둘 다 보정한다. 이게 numerically stable 하게 동작한다는 점이 50년 전부터 알려져 있었다 (Milakov & Gimelshein 2018 의 재정리, 더 거슬러 올라가면 LogSumExp 의 stable 계산).

step 0 · 첫 tile tile 안에서 m₀ = max(x₀), ℓ₀ = sum(exp(x₀ - m₀)). SRAM 에 둘만 들고 있는다.
step 1 · 둘째 tile 도착 새 tile 의 m₁_local = max(x₁), ℓ₁_local = sum(exp(x₁ - m₁_local)) 계산.
step 1 · 보정 새 max m₁ = max(m₀, m₁_local). 기존 sum 을 새 max 기준으로 rescale: ℓ₀_new = ℓ₀ * exp(m₀ - m₁). 새 sum 도 마찬가지: ℓ₁ = ℓ₁_local * exp(m₁_local - m₁). 합친다: ℓ = ℓ₀_new + ℓ₁.
step 1 · output output 도 같은 rescaling 을 받는다 — O = O * exp(m₀ - m₁) + tile₁ × V₁. 마지막 tile 후에 O / ℓ 한 번으로 정규화.
완료 row 전체를 한 번도 통째로 메모리에 두지 않았다. SRAM 위에서 (m, ℓ, O) 세 개만 들고 다녔다.
왜 stable 한가

각 단계의 max 보정이 지수 함수의 평행이동 이다. exp(x - m₁) = exp(x - m₀) * exp(m₀ - m₁). 항상 m 이 지금까지 본 최대보다 크거나 같으므로 x - m ≤ 0 — overflow 안 난다. 이 한 줄이 Flash Attention 전체를 가능하게 한다.

FIG · 한 row 의 attention 계산이 tile 단위로 어떻게 흐르는가schematic
m₀
ℓ₀
O₀
tile₁
load
m,ℓ
update
O+=
tile₂
load
m,ℓ
update
O+=
tile₃
load
m,ℓ
update
O+=
O/ℓ
final
붉은 칸 = SRAM 에 항상 살아있는 작은 state, 황색 = HBM 에서 매번 새로 읽는 tile, 파란 = 끝에 한 번 정규화. 한 tile 처리하는 동안 HBM 은 K/V 의 한 tile 만 읽힌다.

이 자리에서 강의의 한 가지 미세한 점 — “naive softmax 가 fp16 에서 overflow 나는 이유” 는 max-shift 가 빠진 형태이고, 그 보정 자체가 online softmax 의 식과 같다는 점. PyTorch 의 F.softmax 도 내부적으로 max-shift 를 한다. Umar 가 강의에서 이 점을 강조했을 가능성이 높다 (원본 영상 확인 필요).

§ 06tile 단위 사고· SRAM 위에서만 도는 루프

알고리즘을 “이 tile 에서 무엇을 다음 tile 로 들고 나갈까” 로 다시 짜기

Flash Attention 의 코드를 한 줄씩 읽으면 — 알고리즘을 자체로 새로 발명한 게 아니라, 같은 attention 을 “tile 사이로 무엇을 carry 할 것인가” 의 관점에서 다시 짠 것이라는 사실이 드러난다. 이 사고가 Triton 의 tile 추상과 정확히 맞물린다.

Flash Attention 2 forward 의 outer 루프는 query block 단위, inner 루프는 key/value block 단위. 한 query block 을 잡으면 SRAM 위에 그 q 를 박고, K/V 를 tile 씩 읽어들이면서 (m, ℓ, O) 를 갱신한다.

한 query block 의 일이 끝나면 (m, ℓ, O) 를 HBM 에 한 번 쓰고 다음 query block 으로. HBM 왕복이 query block 수 만큼만 일어난다 — naive 의 N² 항이 한 번 사라진 자리.

왜 query block 을 outer 로 두는가 (FA2 의 변경점)

Flash Attention 1 은 K 를 outer 로 두었다. FA2 는 Q 를 outer 로 둔다 — 이유는 output O 가 query 마다 독립이고, query block 이 outer 일 때 다른 thread block 끼리 communicate 할 일이 없기 때문이다. 더 좋은 parallelism. backward 도 마찬가지로 분해 단위가 달라졌다.

# Flash Attention 2 forward 의 한 query block (의사 코드)
def fa2_forward_block(q_block, K, V, scale):
    # SRAM 에 살아있는 state
    m = -inf                  # running max
    l = 0                      # running sum
    O = 0                      # output

    for kv_block in kv_blocks_in(K, V):
        s = q_block @ kv_block.k.T * scale
        m_new_local = s.max(dim=-1)
        m_new = max(m, m_new_local)
        alpha = exp(m - m_new)
        l = l * alpha + exp(s - m_new).sum(dim=-1)
        O = O * alpha + exp(s - m_new) @ kv_block.v
        m = m_new

    return O / l, m, l       # (m,l) 은 backward 용

이 코드의 모든 변수는 SRAM 에 들어가는 작은 사이즈다. BLOCK_M = 64 ~ 128, head_dim ≤ 128 정도면 한 block 의 q, kv tile, m, l, O 가 모두 register/shared 에 들어간다. HBM 왕복은 query block 한 번 + 모든 kv block 의 한 번씩.

“Flash Attention 의 코드를 한 줄씩 읽고 손으로 따라 적어보면 — softmax 와 matmul 만 알면 짤 수 있는 알고리즘이라는 사실이 명확해진다. 어렵게 느껴지는 건 표기뿐이다.” 학습 노트 · 재구성
실전의 추가 디테일

실제 FA2 Triton 구현에는 — causal mask 처리(상삼각만 계산), dropout, GQA(group query attention 의 K/V 공유), bias term 같은 것들이 더 들어간다. 핵심 outer/inner 루프 구조는 같다. Tri Dao 의 GitHub flash-attention repo 에서 flash_attn_triton.py 가 가장 읽기 좋은 reference.

§ 07backward 의 회계· recompute vs save

forward 에서 무엇을 저장하고 backward 에서 무엇을 다시 계산할 것인가

forward 가 풀려도 backward 가 어렵다. 왜냐하면 attention matrix S, P 가 forward 에서 HBM 에 안 저장됐기 때문이다 — 그게 알고리즘의 포인트였으니까. 그러면 backward 는 어떻게? 다시 계산한다.

FIG · forward 가 저장하는 것 vs backward 가 다시 만드는 것autograd 회계
save forward 가 HBM 에 남기는 것O, m, ℓ — 작은 양 출력 O, row-wise m (max), row-wise (sum). softmax 의 정규화 정보. 메모리 N×d + 2N — N² 가 아니다. O(N·d)
recompute backward 가 다시 만드는 것S, P 의 한 tile 씩 P_ij = exp(s_ij - m_i) / ℓ_i 를 backward 에서 같은 tile 단위로 다시 계산. forward 의 m, ℓ 가 있어서 stable. flops 두 배
trade memory ↔ computeO(N²) HBM 절약을 위해 forward FLOPs 의 ~50% 추가 N 이 커질수록 메모리 절약이 절대적이고, FLOPs 추가는 어차피 SRAM 위에서 도니까 거의 공짜. Long context 에서 결정적. N 큼 → 이득 큼
backward 의 분해 단위가 forward 와 약간 달라야 한다는 점이 FA2 의 디테일. forward 는 q-outer, backward 는 kv-outer 의 분해가 더 빠르다는 게 Tri Dao 의 관찰.

backward 식을 풀면 (Tri Dao FA paper appendix 참고) —

  • dV = P.T @ dO — kv block 단위로 나눠서 계산
  • dP = dO @ V.T
  • dS = P ⊙ (dP - rowsum(dP ⊙ P)) — softmax 의 jacobian
  • dQ = dS @ K, dK = dS.T @ Q

이 식들은 다 같은 (m, ℓ) 정보로 P 를 다시 만들 수 있어야 stable 하게 돈다. 그래서 forward 가 (m, ℓ) 를 저장하는 것이 backward 의 전제다.

학습 자료 추적

backward 의 정확한 식과 그 안의 trick — 특히 D = rowsum(dO ⊙ O) 라는 작은 보조 변수를 미리 계산해두면 식이 깔끔해진다는 점 — 은 Tri Dao FA1 paper Algorithm 4, FA2 paper Algorithm 2 에 있다. Umar 가 강의에서 이 식을 손으로 끌어냈을 가능성이 있다 (원본 영상 확인 필요). 만약 안 했다면 — 자기가 손으로 한 번 풀어보는 게 학습에서 가장 큰 단계.

“forward 의 모든 sleek 한 디자인은 backward 의 회계까지 풀려야 의미가 있다. FA1 → FA2 의 가장 큰 변화도 backward 분해의 재설계.” 학습 노트 · 재구성
§ 08학습 자료 추적· paper · code · tutorials

이 길을 다시 걸을 사람이 손에 들어야 하는 자료

강의 자체가 transcript 실패라 — 같은 길을 따라가려는 사람을 위해 자료 목록을 정리해둔다. Umar 의 채널이 가장 좋은 시작점이고, Tri Dao 의 두 paper 가 reference 다.

Flash Attention 1 paper
Dao et al., 2022. Algorithm 1, 2, 4 가 핵심 — forward / backward 의 tile 단위 의사코드. 메모리 위계의 명시적 분석.
Flash Attention 2 paper
Dao, 2023. FA1 대비 (1) Q-outer 분해, (2) backward 재설계, (3) parallel over heads. 읽는 순서: FA1 먼저, FA2 차이점만.
Flash Attention 3 paper
Hopper (H100) 아키텍처 활용 — TMA, async, fp8. Triton 보다 CUTLASS 에 더 가까움. 학습 순서상 마지막.
Triton tutorials
OpenAI Triton 의 06-fused-attention.py 가 FA2 의 가장 읽기 좋은 reference 구현. 약 200줄.
Tri Dao flash-attention repo
CUDA C++ + CUTLASS 의 production 구현. flash_attn_triton.py 가 Triton 버전.
Umar Jamil YouTube
CUDA, Triton, Mamba, RoPE 등 from-scratch tutorial 시리즈. 같은 호흡으로 길게 따라갈 수 있는 화이트보드 강의.
Online softmax
Milakov & Gimelshein, 2018, “Online normalizer calculation for softmax”. FA 가 인용하는 base.
PMPP
Programming Massively Parallel Processors (Hwu, Kirk, El Hajj). CUDA 학습의 표준 교재. L002, L009 가 같은 책 위에 있다.
§ 09학습 곡선의 함정· 어디서 막히는가

같은 길을 가는 사람들이 보통 부딪히는 자리

CUDA → Triton → FA 의 학습 시퀀스에서 학습자들이 똑같이 막히는 자리들. Umar 의 강의가 이 자리들을 명시적으로 짚었는지는 확실하지 않지만 (원본 확인 필요), 학습 노트로서 정리해둘 가치가 있다.

  1. CUDA 의 “thread 1인칭” 사고 진입에 6주 이상 — 빨리 갈 수 없다. PMPP 책을 직접 풀어야 한다. 강의만으로는 안 됨.
  2. Triton 으로 빨리 넘어가려는 유혹 — Triton 만 배우면 “왜 이 설정이 빠른가” 의 직관이 안 잡힌다. CUDA 의 SRAM/bank/coalesce 의미를 손에 들고 Triton 으로 가야 진짜 이해.
  3. online softmax 의 식 한 번에 안 잡힘 — 종이 위에서 N=4 의 작은 예제로 한 번 손계산 해보는 게 직관 잡는 가장 빠른 길.
  4. FA1 코드 → FA2 코드 차이를 “자세히 안 봄” — 같은 알고리즘처럼 보이지만 outer 루프가 다름. 같이 두고 diff 해야 의미 보임.
  5. backward 식 직접 안 풀어봄 — autograd 가 자동이라서 안 풀어도 된다고 착각. 실제로 FA backward 의 reference 코드는 식을 알고 봐야 짤리는 자리가 많음.
  6. Triton autotuner 결과를 “공식 답” 으로 받아들임 — 사실 search space 가 좁아서 항상 최적이 아님. 직접 sweep 해본 사람이 더 빠른 설정을 찾는 경우가 종종 있음.
  7. “성능 측정” 할 때 워밍업, sync 빼먹음L001 의 timing 패턴 (CUDA Event + warmup + sync) 을 항상 쓸 것.
  8. head_dim 이 큰 경우(>128) 의 분해 차이 모름 — register pressure 때문에 같은 코드가 안 컴파일되거나 occupancy 가 무너짐. 이 자리를 위해 CUTLASS / FA3 가 개입.
한 번에 다 안 잡히는 게 정상

FA 를 자기 손으로 한 줄씩 짤 수 있게 되는 데 — CUDA 부터 시작해 평균 3~4개월. Umar 의 강의가 그 학습 곡선을 “한 사람의 회고” 로 보여주는 데 의미가 있다. 한 번 듣고 안 잡힌다고 자책하지 말 것.

“이건 FA 를 짜는 강의가 아니라 — FA 를 짜기까지 무엇을 배워야 했는지의 강의다.” 학습 노트 · 재구성
§ 10기억할 메모와 코드 자료· key takeaways

다시 열었을 때 5분 안에 손에 잡혀야 할 것

강의의 detail 보다 — 이 노트가 의미를 지니려면 손에 박아야 할 짧은 사실들.

CUDA 는 thread 의 언어
1인칭 thread 의 입장에서 코드를 짜는 사고. __shared__, __syncthreads() 직접 통제. 6+ 주 학습.
Triton 은 tile 의 언어
tile 단위 op (tl.load, tl.dot, tl.store). compiler 가 thread 분배. launch 설정에 결과 민감.
attention 은 memory-bound
naive 구현이 N² 행렬을 HBM 에 두 번 왕복. FA 는 같은 FLOPs 를 SRAM 위에서.
online softmax
running max + running sum 만 들고 다닌다. 새 tile 들어오면 둘 다 보정. SRAM 위에서 가능.
FA forward 의 outer 루프
FA1: K outer / FA2: Q outer. 후자가 더 좋은 parallelism, 더 작은 cross-block 통신.
backward 의 회계
forward 가 (O, m, ℓ) 만 저장. backward 는 P 를 다시 만든다. 메모리 N² → N·d 로 절약.
학습 시퀀스
PMPP 1–5장 → Triton tutorial → FA1 paper → FA2 paper → Tri Dao Triton 코드 한 줄씩.
학습 시간
CUDA 6주 + Triton 1주 + softmax/FA 식 종이로 1주 + 코드 정독 2주 ≈ 3개월.

손에 새기기 — 실습 시퀀스

  1. PMPP 1–5장 풀이 — vector add, matmul naive, matmul tiled 를 직접. L002 와 같은 자리.
  2. Triton 의 첫 matmul — OpenAI tutorial 03 을 그대로 베껴 짠 다음, BLOCK_M/BLOCK_N/BLOCK_K 를 sweep.
  3. online softmax 종이 위에서 풀기 — N=8 의 row 한 줄로 max-shift, running update 손계산. step by step.
  4. FA1 의 Algorithm 1 손코딩 — paper 의 의사코드를 그대로 PyTorch 로 짠다 (느려도 됨). 정답 검증 위주.
  5. Triton tutorial 06 정독 — OpenAI 의 06-fused-attention.py 를 한 줄씩 주석.
  6. FA1 vs FA2 코드 diff — Tri Dao 의 두 버전을 같이 두고 outer 루프 차이 직접 비교.
  7. backward 식 직접 풀기 — paper appendix 따라 dQ, dK, dV 를 손으로. autograd 안 쓰고.
  8. NCU 로 한 번 떠보기 — 자기가 짠 Triton FA 의 occupancy 와 SRAM 사용량을 직접 측정.
§ 11다른 강의로 이어지는 길· connections

같은 자리를 다른 각도에서 다루는 강의들

L050 의 학습 시퀀스가 시리즈의 어느 강의들과 직접 맞물리는지.

§ 12열린 질문· open questions

원본 자막 실패로 본 노트에서 비워둔 자리들

강의 자체의 transcript 가 실패해서, 본 노트는 도메인 지식과 공개된 자료로 재구성됐다. 다음에 영상을 직접 다시 본다면 확인할 자리들.

검증 메모

본 노트에서 등장하는 모든 식, 알고리즘 단계, 메모리 추정값 — 강의 자체의 직접 인용이 아니라 Tri Dao 의 paper 와 Triton tutorial 에서 파생된 재구성. Umar 의 영상에서 다른 강조점이 나왔을 수 있으므로 영상을 직접 본 사람이 노트를 보강하면 좋음.

← Lecture 049 이전 강의 Lecture 051 → Consumer GPU performance — Jake Cannell