《GPU Mode》L025Composable Kernel · ROCmHigh prioritytranscript · slides · available
Speaking Composable Kernel (CK)
AMD GPU 위의 “CUTLASS 대응” — Composable Kernel 의 layered template 디자인. tile programming 으로 GEMM/Conv/Attention 을 짜고, AI Template 가 그 위에서 모델 graph 를 lowering. ROCm 의 Haocong Wang 이 보여주는 — CUTLASS 와 같은 idea, 다른 hardware (MI100/200/300, MFMA instruction) 위의 stack 학습 노트.
2018년 AMD MI100 출시. 첫 matrix-core (MFMA) 가 들어가는 generation. NVIDIA 에는 이미 cuDNN/cuBLAS/CUTLASS 가 있었다. AMD 에는 — 같은 위치의 라이브러리 스택이 없었다. CK 가 그 자리를 메우러 출발한 프로젝트.
강의가 깐 큰 질문 세 개.
왜 fused / composable 한 단위로 짜는가 — 한 모델 안에 수백 개의 layer 가 있다. 각각 hand-tuned 라이브러리 호출이면 polynomial 한 코드 폭발. tile-level building block 으로 합성 의 idea(§ 04).
AMD 의 MFMA 가 NVIDIA 의 mma.sync 와 어떻게 다른가 — instruction 레벨에서 미묘한 차이. 코드 portable 이 어려운 자리(§ 03).
이 stack 위에 어떻게 올라가는 ML 프레임워크 — Meta 의 AI Template 와의 협업이 turning point. PyTorch 모델이 CK 커널로 lowering 되는 path(§ 07).
강의의 인지적 frame
CK 의 디자인 철학은 CUTLASS 와 거의 같다 — “tile-level template 으로 큰 hardware 위 산술을 작은 단위로 표현, 그 단위들을 합성”. 차이는 hardware (MFMA vs mma.sync, LDS vs shared memory). 그래서 두 라이브러리는 같은 idea 의 다른 instantiation 이라고 봐야 한다.
“making one tile-level building block, then composing it into many fused kernels — that's the goal of composable kernel.”Haocong Wang (요약)
§ 02CK 의 역사· 2018 → 2024
“2018년 모델 inference 를 위한 라이브러리 → 2024년 LLM 학습/추론의 기둥” 까지
2018MI100 (CDNA1) 출시. CK 프로젝트 출발 — “AMD 위에 cuDNN-equivalent 가 없다” 의 답으로.
2020초기 CK 가 GEMM, conv, batched matmul 같은 기초 op 를 지원. 그러나 fusion abstraction 이 부족.
2022Meta 의 AI Template (AIT) 와 협업. AIT 는 CK 의 building block 위에 model graph compile 을 올리는 framework. 방향성의 큰 전환점.
2023tile programming abstraction 정식화. tile_window, tile_distribution, sweep 등의 추상이 raw template 위에 layer 됨.
2024MI300 generation 지원. ROCm FlashAttention 이 CK 위에서 정식 deployment. PyTorch 의 SDPA backend 로 통합 시작.
왜 AI Template 이 transformative 였나
그 전 CK 는 “라이브러리 작성자용” 도구였다. AI Template 가 모델 작성자가 직접 쓸 수 있는 entry point 를 만들었다. 이 흐름이 CK 가 productionize 되는 큰 momentum.
§ 03AMD GPU 의 다른 점· MFMA · LDS · CDNA
“CUDA 코드를 그냥 못 옮기는” 자리들
NVIDIA (CUDA)
Tensor Core: mma.sync, wgmma.mma_async
shared memory (SMEM)
warp = 32 thread
SM (Streaming Multiprocessor)
cuBLAS / cuDNN / CUTLASS / CuTe
AMD (ROCm / HIP)
Matrix Core: v_mfma_* (MFMA)
LDS (Local Data Share) — shared memory equivalent
wavefront = 64 thread (NVIDIA 의 2배)
CU (Compute Unit) — SM equivalent
rocBLAS / MIOpen / Composable Kernel
가장 미묘한 차이 — wavefront 가 64 thread. CUDA 코드의 “warp 32” 가정이 거의 모든 곳에서 깨진다. shared memory bank 수도 다르고 (32 vs 32 — 같지만 wavefront 가 64 라서 access pattern 이 다름), MFMA 의 tile shape 도 다르다(예: v_mfma_f32_16x16x16, v_mfma_f32_32x32x8).
CK 가 풀려는 추상 문제
“같은 algorithm 을 hardware 별로 다른 tile shape 로 instance 화 한다”. 사용자는 algorithm (GEMM, conv, attention) 만 표현, CK 가 hardware-specific 한 tile/swizzle/MFMA 를 자동 선택. CUTLASS 의 정신과 정확히 같다.
§ 04tile programming 의 핵심· tile distribution · sweep
“tensor 를 tile 로 쪼개고, tile 위 sweep 을 쓴다”
강의의 가장 중요한 abstraction. “CK 의 코드는 일반 CUDA 보다 high-level 하다” — tile 단위의 op 들이 사용자가 보는 primitive 다.
// CK tile programming — 약식// 1. DRAM 위에 tile_window 를 만든다 — “이 region 을 본다”auto a_dram_window = make_tile_window(
a_dram_tensor,
Sequence<128, 32>{}, // tile shape
{iM, iK}); // tile origin// 2. LDS (shared) 위에 distribution 만들기auto a_lds_block = make_static_distributed_tensor<...>();
// 3. load — DRAM tile → LDS
load_tile(a_lds_block, a_dram_window);
// 4. sweep — tile 의 각 element 위 연산
sweep_tile(a_lds_block, [&](auto idx) {
// idx 가 tile 안 좌표
acc(idx) += a_lds_block(idx) * b_lds_block(idx);
});
// 5. MFMA — block-level matrix multiply
block_tile_gemm_xdl_cshuffle(c_acc, a_lds_block, b_lds_block);
tile_window
tensor 의 부분 영역. shape 와 origin 으로 정의. global memory 의 “lens”.
tile_distribution
한 tile 의 element 가 thread 들에 어떻게 분배되는가. 좋은 distribution 이 coalescing 과 MFMA 호환을 동시에 보장.
load_tile / store_tile
tile 단위 데이터 이동. DRAM ↔ LDS ↔ register. 한 함수 호출로 여러 thread 협력.
sweep_tile
tile 안 elementwise 연산을 lambda 로. compile time 에 thread 별 할당이 결정.
block_tile_gemm
block-level GEMM building block. 내부에서 MFMA 를 띄움. CK 의 가장 중요한 primitive.
§ 05implicit GEMM 으로 conv· NHWC → GEMM mapping
“conv 를 GEMM 으로 본다 — input 을 직접 reorder 안 하고”
강의의 specific 한 예시. CNN 의 convolution 을 어떻게 CK 의 GEMM building block 위에서 짜는가. 답 — implicit GEMM.
traditional approach (im2col) — input NHWC 를 (N·H·W, K·K·C) 의 큰 matrix 로 펼침. memory 폭발.
implicit GEMM — input 을 그대로 두고, GEMM 의 index 계산 안에 conv 의 sliding window 를 직접 표현. “GEMM_M = N·H·W, GEMM_N = output channel, GEMM_K = K·K·input_channel” 의 mapping.
// implicit GEMM — input tensor index 가 conv 의 spatial index 로 변환// gemm_m = n * H * W + h * W + w// gemm_k = ky * Kx * C + kx * C + c// (n, h+ky*dilation, w+kx*dilation, c) → input// CK 가 이 mapping 을 tile_window 의 stride 로 표현auto a_window_strided = make_implicit_gemm_input_window(
input_tensor, kernel_size, strides, padding);
// 그 다음은 일반 GEMM 코드 그대로
왜 이게 강력한가
같은 GEMM building block 이 — matmul, conv, transposed conv, attention 의 Q·K^T 까지 모두 처리. “implicit” 한 mapping 들 (im2col, im2win 등) 만 추가. 한 building block 이 N 개 op 의 backbone. composable 의 정확한 의미.
L0 · MFMAv_mfma_f32_32x32x8_f16 같은 matrix instruction한 wave 가 한 instruction 으로 32×32×8 곱셈PTX 대응 ISA
L1 · warp tile한 wave 가 들고 가는 sub-tileMFMA 를 여러 번 호출해 sub-tile 합치기register
L2 · block tile한 block (CTA) 의 tile여러 wave 가 협력. LDS staging.LDS · shared
L3 · grid전체 GEMM 의 gridblock 들이 output tile 을 분담global
L4 · devicehost-facing APIdevice_gemm_xdl_cshuffle 한 줄user code
strict 하게 — CUTLASS 의 layer 와 거의 1:1 대응. CUTLASS 의 “collective” 가 CK 의 “block tile” 에 대응, “kernel” 이 “grid” 에 대응. 사용자는 보통 L4 entry 만 쓰고, customization 이 필요할 때만 한 layer 씩 내려간다.
§ 07AI Template — graph 위의 layer· model → CK kernels
“PyTorch 모델 → AIT compile → CK kernels” 의 path
CK 가 라이브러리라면, AI Template (AIT) 는 그 위에 올라가는 model compiler. PyTorch 모델을 AIT 의 graph 표현으로 받아 — 각 op 를 CK 의 fused kernel 로 lowering, 한 .so 파일로 codegen. inference 시점에 PyTorch 가 그 .so 를 호출.
FIG · AIT 의 lowering pipelinegraph → CK
1
PyTorch model
forward 함수
2
AIT graph
op 단위 IR — fusion 후보 식별
3
fusion pass
conv+bias+relu 같은 패턴을 한 CK kernel 로
4
CK kernel select
tile shape, schedule autotune
5
codegen
C++ source → hipcc → .so
6
runtime
PyTorch 가 .so import 해서 호출
중요한 건 — 이 stack 이 AOT (ahead-of-time) compile. torch.compile 같은 JIT 가 아니다. 첫 build 가 길지만 (분 단위), 이후 inference 는 native binary 의 속도. inference serving 에 잘 맞음.
§ 08ROCm FlashAttention· CK 위의 첫 큰 사용 사례
같은 algorithm 이 AMD 위에 정착하는 길
FlashAttention (Tri Dao) 의 origin 은 CUTLASS 위. 같은 algorithm 을 AMD 에 옮기는 작업이 CK 의 큰 첫 stress test. 강의에서 Haocong 이 직접 설명한 자리.
online softmax + tiled attention — algorithm 자체는 hardware-agnostic. tile shape, MFMA instruction 만 hardware-specific.
CK 위에서 거의 그대로 mapping — block tile GEMM 두 번 (Q·K^T, attn·V) + epilogue 안에 online softmax. tile programming 의 sweep 이 softmax statistics 누적.
upstream 통합 — Tri Dao 의 official FlashAttention repo 에 ROCm path 가 합류. 같은 PyTorch SDPA API 가 H100/MI300 모두에서 동작.
production 의미
이게 동작하는 시점이 “AMD GPU 위에서 LLM 추론을 production 할 수 있는” 시점. PyTorch 의 SDPA 가 backend swap 만으로 H100 ↔ MI300 호환. cluster TCO 에 직접 영향.
“FlashAttention on ROCm — 같은 algorithm 의 같은 결과, 다른 hardware. CK 의 abstraction 이 이걸 가능하게 했다.”Haocong Wang (요약)
§ 09CUTLASS vs CK· 같은 idea, 다른 hardware
두 라이브러리의 1:1 대응 — 그리고 미묘한 차이
CUTLASS · CuTe
NVIDIA GPU only
mma.sync (Ampere) / wgmma (Hopper)
warp = 32 thread
CuTe 의 layout algebra (composition · tile · partition)
(1) tile-level building block 으로 lowering. (2) layout/swizzle 를 type 안에 박음. (3) epilogue fusion 을 수동 가능. (4) hardware 별 instruction (MMA / MFMA) 를 atom 으로 추상화. 두 라이브러리는 같은 mental model 을 공유 — 한 쪽을 배우면 다른 쪽이 빠르다.
portability 노트
“같은 코드가 두 hardware 에 동작” 은 아직 어렵다. 일부 PyTorch frontend 는 두 backend 를 swap 가능 (SDPA, scaled_mm 등). 그러나 custom kernel 작성자는 여전히 CUTLASS 와 CK 를 각각 짜야. 이 gap 을 채우려는 시도가 Triton, MLIR, IREE 같은 cross-vendor 프로젝트들.
§ 10기억할 메모와 코드 자료· key takeaways
CK = AMD 의 CUTLASS
같은 idea — tile-level template building block 의 layered abstraction. hardware-specific (MFMA · LDS · wave64) 만 다름.
tile programming
tile_window + tile_distribution + sweep + block_tile_gemm. 일반 HIP 보다 high-level 하면서도 hardware tier 성능 보장.
implicit GEMM
conv 를 GEMM 으로 — input reorder 없이 stride 로. 한 building block 이 matmul/conv/attention 의 backbone.
4 layer hierarchy
MFMA → warp tile → block tile → grid. CUTLASS 의 4 layer 와 1:1 대응.
AI Template (AIT)
PyTorch model → CK kernels AOT compile. Meta 와 협업으로 출발. 모델 작성자가 CK 의 entry point 로 쓴다.
ROCm FlashAttention
CK 위의 첫 큰 production 사례. PyTorch SDPA 의 ROCm backend. 같은 algorithm 다른 hardware.
wave64 의 함의
CUDA warp32 가정의 코드는 옮길 수 없다. shared memory access pattern, MFMA tile shape 모두 다시 디자인.
FP8 GEMM
MI300 부터 FP8 (E4M3/E5M2) 지원. CK 의 GEMM 이 FP8 도 cover. quantize 된 추론에 필요.