LESSON 07 · 2026.04.19 · T4
PyTorch Custom Op — wiring CUDA in ~50 lines
Register Lesson 6's Flash Attention kernel as torch.ops.mylib.flash_attention. From this moment on, vLLM's csrc/*.cu stops looking foreign.
GPU · T4
layer · C++ host wrapper
baseline · F.scaled_dot_product_attention
The core ~5 lines
In a single file (330 lines, mostly the kernel body), the pivot into the PyTorch world is these five lines:
check_qkv(Q, "Q"); // shape/device/dtype/contig
auto O = torch::empty({N, d}, Q.options()); // PyTorch allocates
auto stream = at::cuda::getCurrentCUDAStream();
my_kernel<<<grid, block, 0, stream>>>(
    Q.data_ptr<float>(), ..., O.data_ptr<float>());
return O;
That replaces Lesson 6's hundreds of lines of main() (CLI parsing, malloc, CSV output, CPU ref check) entirely. This is the production pattern.
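A sketch of how a wrapper like this is typically compiled and exercised from Python with `torch.utils.cpp_extension.load` (the source file name `flash_attention_ext.cu` and the module name `mylib_ext` are assumptions, not from the lesson; requires a CUDA build of PyTorch and a GPU):

```python
# Sketch: JIT-compile the extension, then call the registered op.
# Importing/loading runs the TORCH_LIBRARY registration in the .cu file,
# which makes the op visible under torch.ops.mylib.
import torch
from torch.utils.cpp_extension import load

load(
    name="mylib_ext",                    # hypothetical module name
    sources=["flash_attention_ext.cu"],  # hypothetical file name
    verbose=True,
)

N, d = 512, 64
q = torch.randn(N, d, device="cuda", dtype=torch.float32).contiguous()
k = torch.randn_like(q)
v = torch.randn_like(q)

o = torch.ops.mylib.flash_attention(q, k, v)

# Cross-check against the baseline from the tables below
# (SDPA wants a leading batch dim, hence the unsqueeze).
ref = torch.nn.functional.scaled_dot_product_attention(q[None], k[None], v[None])[0]
print((o - ref).abs().max())
```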
Accuracy
| N | naive abs err | flash abs err |
| 128 | 3.3e-7 | 2.5e-7 |
| 512 | 5.4e-7 | 4.3e-7 |
| 1024 | 4.0e-7 | 2.8e-7 |
| 2048 | 3.7e-7 | 3.7e-7 |
3–5× FP32 machine epsilon (~1.2e-7). SDPA and our kernel both sit inside FP32 rounding limits.
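A quick stdlib-Python sanity check of that epsilon claim, with the error values copied from the table above (FP32 machine epsilon is the spacing between 1.0 and the next float32, i.e. 2⁻²³):

```python
# FP32 machine epsilon = 2**-23 ≈ 1.19e-7.
eps32 = 2.0 ** -23

# Max abs errors from the table above: N -> (naive, flash).
errors = {
    128:  (3.3e-7, 2.5e-7),
    512:  (5.4e-7, 4.3e-7),
    1024: (4.0e-7, 2.8e-7),
    2048: (3.7e-7, 3.7e-7),
}

for n, (naive, flash) in errors.items():
    print(f"N={n:5d}  naive={naive / eps32:.1f}x eps  flash={flash / eps32:.1f}x eps")
# Every ratio lands in the low single digits: pure FP32 rounding noise,
# not an algorithmic accuracy difference between the kernels.
```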
Speed (vs F.scaled_dot_product_attention, T4, d=64)
| N | ours naive | ours flash | SDPA | flash / sdpa |
| 512 | 0.475 | 0.734 | 0.262 | 0.36× |
| 1024 | 0.799 | 1.309 | 0.429 | 0.33× |
| 2048 | 3.086 | 1.253 | 0.428 | 0.34× |
| 4096 | — | 2.498 | 1.374 | 0.55× |
0.33–0.55× of SDPA's speed. The gap narrows as N grows: both kernels are O(N²)-bound and converge toward a similar regime. These numbers make cuDNN's level of tuning plainly visible.
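To make the flash/sdpa column unambiguous, it can be recomputed from the table: it is SDPA's latency divided by our flash kernel's latency, i.e. our speed as a fraction of SDPA's (values copied verbatim from the table):

```python
# Recompute the flash/sdpa column from the table above.
rows = [
    # (N, ours_flash, sdpa)  -- latencies as listed in the table
    (512,  0.734, 0.262),
    (1024, 1.309, 0.429),
    (2048, 1.253, 0.428),
    (4096, 2.498, 1.374),
]
for n, flash, sdpa in rows:
    print(f"N={n:5d}  flash/sdpa = {sdpa / flash:.2f}x")
# The ratio climbs from ~0.33x to ~0.55x as N grows: both kernels pay the
# O(N^2) cost, so cuDNN's constant-factor advantage matters relatively less.
```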
Takeaway 1 · stream-aware launches are non-negotiable
Skip at::cuda::getCurrentCUDAStream() as the fourth launch argument and the kernel lands on the default stream. PyTorch may be issuing work on a different stream → a silent race condition: no crash, just non-deterministic output. vLLM's PagedAttention launches follow exactly the same pattern.
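The two launch forms side by side (illustrative C++/CUDA fragment; `args...` stands in for the real pointer arguments, and `my_kernel`/`grid`/`block` follow the snippet above):

```cpp
// Wrong: no 4th launch parameter -> legacy default stream.
// Races silently against work PyTorch has already queued on its own stream.
my_kernel<<<grid, block, 0>>>(args...);

// Right: join whatever stream PyTorch is currently issuing work on
// (may be non-default, e.g. inside a torch.cuda.stream(...) region).
// at::cuda::CUDAStream converts implicitly to cudaStream_t.
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
my_kernel<<<grid, block, /*smem_bytes=*/0, stream>>>(args...);
```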
Takeaway 2 · the dispatcher gives you the CPU guard for free
TORCH_LIBRARY_IMPL(mylib, CUDA, m) {
  m.impl("flash_attention", &flash_attention_forward);
}
Register only under the CUDA backend. When a CPU tensor arrives, the dispatcher blocks it upfront:
RuntimeError: Could not run 'mylib::flash_attention' with arguments from the 'CPU' backend
Autograd, dtype promotion, device dispatch all wired into one pipe — a clean part of PyTorch's op system.
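The TORCH_LIBRARY_IMPL above assumes the op schema was declared once elsewhere; a minimal sketch of the full pairing (the signature is inferred from the torch.ops.mylib.flash_attention(q, k, v) call, so treat the argument names as assumptions):

```cpp
#include <torch/extension.h>

// Implemented in the CUDA host wrapper shown earlier.
torch::Tensor flash_attention_forward(const torch::Tensor& q,
                                      const torch::Tensor& k,
                                      const torch::Tensor& v);

// Declare the op schema once: name, argument types, return type.
// The dispatcher owns this entry and routes calls by backend key.
TORCH_LIBRARY(mylib, m) {
  m.def("flash_attention(Tensor q, Tensor k, Tensor v) -> Tensor");
}

// Attach an implementation for the CUDA backend only. With no CPU impl
// registered, a CPU tensor is rejected by the dispatcher before our code runs.
TORCH_LIBRARY_IMPL(mylib, CUDA, m) {
  m.impl("flash_attention", &flash_attention_forward);
}
```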
Takeaway 3 · the 3× gap is both a boundary and reality
- cuDNN has register allocation and instruction scheduling hand-tuned by NVIDIA engineers
- Uses warp-specialized FP32 paths (on Turing, WMMA tensor cores take FP16 inputs, so a pure-FP32 path can't use them)
- Smem layouts that avoid bank conflicts
- At our Br=64 / d=64, register pressure may be causing spills
Our place isn't SOTA chasing. It's filling operators the library is missing (novel ops, custom sparsity).
Trap log
- import mylib_ext → PyInit_mylib_ext not defined. Fix: add an empty PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}.
- nvcc not found. Fix: set CUDA_HOME=/usr/local/cuda.
- CUDA 12.9 toolkit vs PyTorch cu121 wheel: runs under minor-version compat. For production, keep toolkit and wheel aligned.
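All three traps live in the build step. A hedged setup.py sketch for an ahead-of-time build that bakes the fixes in (the source file name and the mylib_ext module name are assumptions):

```python
# setup.py sketch. Run with the toolkit pinned explicitly, e.g.:
#   CUDA_HOME=/usr/local/cuda python setup.py install   (trap 2)
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="mylib_ext",
    ext_modules=[
        CUDAExtension(
            # Must match the import: `import mylib_ext` looks up the symbol
            # PyInit_mylib_ext, which the (possibly empty) PYBIND11_MODULE
            # block in the .cu file provides (trap 1).
            name="mylib_ext",
            sources=["flash_attention_ext.cu"],  # hypothetical file name
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)
```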
Lessons 1–7 stack
Python model (vLLM, my service)
│
│ torch.ops.mylib.flash_attention(q,k,v) ← the layer lesson 07 punched through
▼
torch dispatcher
│
▼
C++ host wrapper (tensor → raw ptr, stream)
│
▼
CUDA kernel (flash_attention_v1) ← lesson 06
│
▼
Warp / thread (shuffle, tiled mma) ← lessons 03, 04, 05
│
▼
Memory hierarchy (HBM↔L2↔smem↔reg) ← lessons 01, 02
That's CUDA Phase 1 in the bag. Triton starts next lesson.