cudatraining · 학습 기록

LESSON 07 · 2026.04.19 · T4

PyTorch Custom Op — ~50 줄로 CUDA 를 꽂다

레슨 6 의 Flash Attention 커널을 torch.ops.mylib.flash_attention 으로 등록. 이 순간부터 vLLM 의 csrc/*.cu 가 낯설지 않다.

GPU · T4 layer · C++ host wrapper baseline · F.scaled_dot_product_attention

핵심 ~5 줄

한 파일 (330 줄, 대부분이 커널 본체) 에서 PyTorch 세계로의 전환은 아래 5 줄이 전부다.

check_qkv(Q, "Q");                           // shape/device/dtype/contig
auto O = torch::empty({N, d}, Q.options());  // PyTorch 가 alloc
auto stream = at::cuda::getCurrentCUDAStream();
my_kernel<<<grid, block, 0, stream>>>(
    Q.data_ptr<float>(), ..., O.data_ptr<float>());
return O;

이게 레슨 6 의 main() 수백 줄 (CLI 파싱, malloc, CSV 출력, CPU ref 검증) 을 완전 대체. production 패턴.

정확도

Nnaive abs errflash abs err
1283.3e-72.5e-7
5125.4e-74.3e-7
10244.0e-72.8e-7
20483.7e-73.7e-7

FP32 machine epsilon (~1.2e-7) 의 3–5 배. SDPA 와 우리 커널이 둘 다 FP32 rounding 한계 안.

속도 (vs F.scaled_dot_product_attention, T4, d=64)

Nours naiveours flashSDPAflash / sdpa
5120.4750.7340.2620.36×
10240.7991.3090.4290.33×
20483.0861.2530.4280.34×
40962.4981.3740.55×

SDPA 대비 0.33–0.55×. N 이 커질수록 격차 좁아진다 (N² 에 묶여 둘 다 비슷한 bound). cuDNN 의 튜닝 수준이 선명히 보이는 숫자.

교훈 1 · stream-aware 런치는 타협 불가

at::cuda::getCurrentCUDAStream() 을 네 번째 런치 인자에 안 주면 default stream 으로 감. PyTorch 는 다른 stream 을 쓰고 있을 수 있음 → 침묵의 race condition. 크래시 안 나는데 결과가 비결정적. vLLM 의 PagedAttention 런치도 정확히 같은 패턴.

교훈 2 · dispatcher 가 CPU 가드를 공짜로 준다

TORCH_LIBRARY_IMPL(mylib, CUDA, m) {
  m.impl("flash_attention", &flash_attention_forward);
}

CUDA backend 에만 등록. CPU 텐서가 오면 dispatcher 가 먼저 차단:

RuntimeError: Could not run 'mylib::flash_attention' with arguments from the 'CPU' backend

autograd, dtype promotion, device dispatch 가 하나의 파이프에 엮이는 PyTorch op 시스템의 깔끔한 부분.

교훈 3 · 3× 격차는 선이자 현실

우리 자리는 SOTA 추격이 아니다. 내장에 빠진 연산 (novel op, custom sparsity) 을 채우는 곳.

함정 기록

레슨 1–7 스택

Python 모델 (vLLM, 내 서비스)
        │
        │   torch.ops.mylib.flash_attention(q,k,v)    ← 레슨 07 이 뚫은 층
        ▼
torch dispatcher
        │
        ▼
C++ host wrapper (tensor → raw ptr, stream)
        │
        ▼
CUDA kernel (flash_attention_v1)                     ← 레슨 06
        │
        ▼
Warp / thread (shuffle, tiled mma)                   ← 레슨 03,04,05
        │
        ▼
Memory hierarchy (HBM↔L2↔smem↔reg)                   ← 레슨 01,02

여기까지가 CUDA Phase 1. 다음 레슨부터 Triton.