CUDA · KERNEL · COMPILER 18권 · T2 KERNEL · A4 LANDSCAPE · 22p

V06 — GEMM 완전판 단권화

CUTLASS 5-Level Hierarchy · CuTe Layout Algebra · Swizzle · TMA Pipeline · Collective · EVT
Volume 06/18
Tier T2 Kernel 패턴
선행 V01 · V02 · V03 · V04 · V05
용도 CUTLASS/CuTe 코드를 펼쳐 읽을 때의 지도

목차

1. GEMM의 대수 C=αAB+βCp.2
2. Naive → Tiled GEMMp.3
3. CUTLASS 5-Level Hierarchyp.4
4. Threadblock (CTA tile · mainloop)p.5
5. Warp tile · MmaMultistage · Fragmentp.6
6. Thread Iterator · Register Fragmentp.7
7. CuTe의 동기: Shape + Stride 대수p.8
8. CuTe Layout 기본p.9
9. CuTe Layout 연산 10종p.10
10. CuTe Tensor · local_tile · partitionp.11
11. Swizzle<B,M,S>p.12
12. TMA + CuTe (Hopper)p.13
13. CUTLASS 3.x Collectivep.14
14. Pipeline · Stage 설계p.15
15. Cooperative vs Pingpongp.16
16. Epilogue fused op chainp.17
17. Tile 선택 heuristicp.18
18. Grouped / Batched GEMMp.19
19. Mixed-dtype GEMM (FP8×BF16)p.20
20. Epilogue Visitor Tree (EVT)p.21
21. Cheat Sheet + Layout 공식p.22

범례

핵심 용어 (노랑)
매우 중요 / 표 헤더
정의·공식 박스
예시 박스
red주의·함정
핵심 (페이지당 ≤3)
(!)니모닉 (권당 ≤5)
다른 권 참조 (xref)
⊕⊗∘Layout 대수 연산
∵∴이유·결론
인쇄 A4 가로 / 여백 없음 / 배경 그래픽 포함
CUTLASS 3.x · CuTe · Hopper sm_90a · 22 pages

1 GEMM 표준형 General Matrix-Matrix Multiply α·A·B+β·C

정의 D = α · op(A) · op(B) + β · C. BLAS-3 level 연산. op(·) ∈ {N (no-trans), T (transpose), C (conjugate)}. 일반적으로 D는 C를 in-place 갱신.
A : M×K, B : K×N, C, D : M×N
Dm,n = α · Σk=0K−1 Am,k · Bk,n + β · Cm,n M, N, K : 행렬 차원 · α, β : scalar · k : reduction axis

2 전치 조합 표기 Column-major BLAS 관습

layoutA strideB stride해석
NNcol-majorcol-majorA(M×K) · B(K×N)
TNrow-major Acol-major BAᵀ 입력
NTcol-major Arow-major BBᵀ 입력
TTrow-majorrow-major둘 다 전치

CUTLASS는 LayoutA/LayoutB/LayoutC로 명시. cf. cuBLAS convention.

3 dtype 조합 A · B · accumulator · output

ABAccD용례
FP16FP16FP32FP16/FP32학습·추론
BF16BF16FP32BF16학습 표준
TF32TF32FP32FP32A100 FP32 가속
FP8 E4M3FP8 E4M3FP32BF16/FP8H100 추론
FP8 E5M2FP8 E5M2FP32FP32H100 gradient
INT8INT8INT32INT32/FP16양자화 ↗ V10 §4
FP4FP4FP32BF16Blackwell 추론

dtype 조합·bit layout 근거는 ↗ V09 §3·§4.

4 FLOP · byte 회계

FLOPs = 2 · M · N · K mul + add
bytesin = (M·K + K·N) · sizeof(dtypeA,B)
bytesout = M·N · sizeof(dtypeD)
AI = FLOPs / (bytesin + bytesout)

Roofline 계산 원리는 ↗ V18 §3.

5 Compute / Memory bound 경계 M·N 作, K 消

AIgemm ≈ (2·M·N·K) / ((M·K + K·N + M·N) · sz)
≈ 2·K · 1/(1/N + 1/M + K/(M·N) ) · 1/sz
  • M, N 크고 K 크면 compute-bound
  • M 또는 N 작으면 memory-bound → Split-K/Stream-K 필요 ↗ V05 §7·§8
  • K=1 수준 (GEMV) → 완전 memory-bound

6 ScaleType · accumulator 의미

CUTLASS ScaleType α·A·B + β·C 의 처리 모드. Default(둘 다), NoBetaScaling(β=1), OnlyAlphaScaling(β=0), Nothing(α=1,β=0).

epilogue scale은 출력 직전 FP32 acc에서 수행 (precision 보존).

7 본 권의 좌표 제외 항목

  • Attention dispatch → ↗ V07 §4·§10
  • Fusion 일반론 → ↗ V13 §8
  • MoE grouped GEMM 맥락 → ↗ V08 §3
  • PTX mma·ldmatrix → ↗ V03 §7·§8
  • WGMMA·TMA PTX → ↗ V04 §5·§7

1 Naive 3-nested loop

// row-major, no tiling
for (m = 0; m < M; ++m)
  for (n = 0; n < N; ++n) {
    acc = 0;
    for (k = 0; k < K; ++k)
      acc += A[m,k] * B[k,n];
    D[m,n] = alpha*acc + beta*C[m,n];
  }
∴ A의 재사용 N회, B의 재사용 M회가 전부 HBM 왕복. AI ≈ 1 수준 → memory bound.

2 3-level 메모리 blocking ★

HBM  ──► L2  ──► SMEM ──► Register
 ↑       ↑       ↑         ↑
 M,N,K   CTA     warp      thread
 전체    tile    tile      fragment
  • CTA tile: MC × NC × KC
  • Warp tile: MW × NW × KW
  • Thread tile: MT × NT (reg)

3 Tiled GEMM 의사코드 ★

// CTA(M_C, N_C), stages=S
for (k0 = 0; k0 < K; k0 += K_C) {
  // load to smem (cp.async / TMA)
  load_tile(sA[k0%S], gA[m_blk, k0 : k0+K_C]);
  load_tile(sB[k0%S], gB[k0 : k0+K_C, n_blk]);
  __syncthreads();
  for (kw : warp_K_iter) {
    rA = ldmatrix(sA, kw);
    rB = ldmatrix(sB, kw);
    mma(acc, rA, rB);  // tensor core
  }
}
epilogue(gD, acc, alpha, beta, gC);

4 재사용률 공식

reuseA = NC (CTA당 A tile이 N_C번 쓰임)
reuseB = MC
HBM traffic ≈ M·N·K·sz · (1/NC + 1/MC)

tile이 클수록 HBM traffic ↓ 그러나 SMEM · reg pressure ↑.

5 Double / Multi-stage buffering

stagesmem특성
11 tileload⇄compute 직렬
22 tilesdouble buffer · 기본
3~7S tilesmultistage · deep prefetch
Titer = max(Tcompute, Tload)
S ≥ ⌈ Tload / Tcompute

6 계층별 BW 비율

계층A100 BW배율
HBM2e~1.5 TB/s
L2~4.7 TB/s~3×
SMEM / L1~19 TB/s aggregate~13×
Register~500 TB/s~300×

A100 whitepaper · SMEM은 SM 당 ~140 GB/s × 108 SM. 계층 상세 ↗ V02 §3·§5.

7 실수 포인트

  • K 분할 안 하면 K 전체를 SMEM에 올리려 함 → overflow
  • __syncthreads 누락 → stage 간 race
  • stride 잘못 → uncoalesced HBM load

1 5-Level 전체 도식 ★ D-K-C-W-T

┌──────────────────────────────────────┐
│ Device   M×N×K (전체 문제)            │  in : A(M,K), B(K,N), C(M,N)
│                                      │  out: D(M,N)
│   ▼ gemm_universal launches grid     │
│                                      │
│ Kernel   grid + TileScheduler        │  in : problem_shape, tile_shape
│                                      │  out: per-CTA (m_blk, n_blk, k_iter)
│   ▼ CTA = tile 1개 담당              │
│                                      │
│ CTA     M_C×N_C × (k0..K) stages      │  in : gA tile, gB tile
│                                      │  out: CTA acc fragment (smem/reg)
│   ▼ 여러 warp(group)가 분담          │
│                                      │
│ Warp    M_W×N_W × K_W                 │  in : sA slice, sB slice
│                                      │  out: warp acc (register)
│   ▼ warp = mma instruction 묶음      │
│                                      │
│ Thread  M_T×N_T (reg fragment)        │  in : rA, rB fragment
│                                      │  out: rC fragment
└──────────────────────────────────────┘

2 각 레벨의 책임

Level관심사파일 예
Devicegrid launch, args 포장device/gemm_universal.h
Kerneltile scheduler, 경계kernel/gemm_universal.h
CTAmainloop · stage · smemthreadblock/mma_*.h
Warpmma inst · fragmentwarp/mma_tensor_op.h
Threaditerator · register tilethread/mma.h

3 입출력 타입 명시

Level입력 storage출력 storage
DeviceHBM ptrHBM ptr
KernelHBM ptr + tile idHBM 부분
CTAHBM → SMEM (async)SMEM → reg (epi)
WarpSMEM (via ldmatrix)Register
ThreadRegisterRegister

4 Template 계층 이름 규칙

// Device
cutlass::gemm::device::GemmUniversal<
  Kernel_>;
// Kernel
cutlass::gemm::kernel::GemmUniversal<
  CollectiveMainloop_, CollectiveEpilogue_,
  TileScheduler_>;
// Threadblock (2.x)
cutlass::gemm::threadblock::Mma<
  Shape_, Warp_, Policy_, Stages_>;

5 2.x vs 3.x 차이 ★

항목2.x (Ampere)3.x (Hopper+)
추상IteratorA/B + MmaCollective + CuTe
tile 표현Shape / PolicyCuTe Layout
asynccp.asyncTMA · WGMMA
warp spec없음Producer/Consumer

6 암기 키

D-K-C-W-T: 위에서 아래로 tile이 쪼개지고, 아래에서 위로 reduction이 모인다.

1 CTA tile 구성

ThreadblockShape Shape<M_C, N_C, K_C>. CTA 하나가 처리하는 출력 tile + K iteration chunk.
세대전형 tile (FP16)
A100 sm_80128×128×32 / 256×128×32
H100 sm_90a128×256×64 / 64×256×64
B200 sm_100128×256×64 (FP4 더 큼)

2 Mainloop 골격 ★

// CTA scope
acc = zero();
prologue_prefetch(S);  // 첫 S stage
for (k_iter : K/K_C - S) {
  issue_next_stage();   // async load
  wait_oldest_stage();
  warp_mma(acc, sA[oldest], sB[oldest]);
  advance_stage();
}
drain_remaining(S);     // 마지막 tiles
epilogue(acc);

3 SMEM 사용량 ★

smem = S · (MC·KC + KC·NC) · sz
   + epilogue buffer + mbarrier slots S : stages · sz : sizeof(dtypeA/B)
128×64 + 64×256 = 16384 elem · 7 stage · 2B(FP16) = 224 KB
→ H100 SMEM cap 228 KB 안쪽으로 fit.
SMEM 초과 → stage ↓ 또는 tile ↓ 강제. heuristic: ↗ §17 p18.

4 Predicate & 경계

  • M, N, K가 tile 배수 아님 → predicate mask
  • out-of-bound load는 0 또는 identity
  • 2.x: PredicatedTileIterator · 3.x: TMA descriptor가 처리
  • 경계 CTA는 K iteration도 다름 (split-K)

5 Stage count trade-off

S ↑효과
latency hide개선 (load·compute 겹침)
smemS배 증가 → occupancy ↓ 가능
mbarrierS개 필요 (Hopper)
registeraddr tracking 증가

6 SMEM 레이아웃 배치

smem:
 ┌─ sA[0] ──┐ ┌─ sA[1] ──┐ ... ┌─ sA[S-1] ──┐
 │ M_C×K_C  │ │ M_C×K_C  │     │ M_C×K_C    │
 ├─ sB[0] ──┤ ├─ sB[1] ──┤ ... ├─ sB[S-1] ──┤
 │ K_C×N_C  │ │ K_C×N_C  │     │ K_C×N_C    │
 └──────────┘ └──────────┘     └────────────┘
 ┌─ acc buffer (for epilogue split) ──┐
 └────────────────────────────────────┘

7 Warp-group 분할 (Hopper) 3.x

  • 1 producer WG + 1~2 consumer WG
  • producer: TMA load → mbarrier arrive
  • consumer: mbarrier wait → wgmma → epilogue
  • 상세 스케줄은 ↗ §15 p16

1 Warp tile 구성

WarpShape Shape<M_W, N_W, K_W>. 1개 warp (또는 warp-group)이 한 번에 다룰 출력 영역. CTA tile이 warp 수만큼 나뉘어 할당.
num_warpsM = MC / MW, num_warpsN = NC / NW
warps/CTA = numM · numN

2 MMA shape · warp tile 관계 ★

instm·n·k세대
mma.sync16·8·16 (FP16)Ampere
mma.sync16·8·8 (TF32)Ampere
mma.sync16·8·32 (INT8)Ampere
wgmma.mma_async64·N·16 (FP16)Hopper
wgmma.mma_async64·N·32 (FP8)Hopper

mma shape 전체 목록은 ↗ V03 §7 (Ampere) · ↗ V04 §8 (Hopper).

3 Warp tile → MMA 반복

itersm = MW / minst
itersn = NW / ninst
itersk = KW / kinst
MW=64, NW=128, KW=16, inst=16·8·16 (FP16)
→ iters = 4·16·1 = 64 mma per warp per K chunk.

4 MmaMultistage 구조 (2.x)

// Ampere warp-level multi-stage
class MmaMultistage {
  IteratorA iter_a;   // gmem→smem
  IteratorB iter_b;
  Policy::Warp warp_mma;
  int smem_stage = 0;

  void operator()(acc, k_iter) {
    prologue();
    for (k : k_iter) {
      ldmatrix(rA, sA[stage]);
      ldmatrix(rB, sB[stage]);
      warp_mma(acc, rA, rB, acc);
      advance();
    }
  }
};

5 Tensor Core fragment layout ★

Fragment 한 warp 32 thread에 분산 저장된 register tile. dtype + element 개수 + mma-shape별 고정 매핑.
operanddtypeelems/thread
A (16·8·16)FP168
B (16·8·16)FP164
C/D (16·8·16)FP324
A (WGMMA 64·N·16)FP16SMEM 직접

WGMMA는 A를 register 대신 SMEM으로 받을 수 있음 (A-descriptor).

6 ldmatrix ↔ fragment 정합

smem tile (16×16 FP16)
 ┌──────────────────────┐
 │ warp 32 thread 협력   │ ──► ldmatrix.x4
 │ 16B × 32 = 512B       │
 └──────────────────────┘
        │
        ▼
 register fragment (A shape)

ldmatrix.trans variant = 전치 로드. mma 입력 shape에 맞춰 thread-mapping 자동.

7 Accumulator 수명

  • K iteration 전체를 register에 누적
  • epilogue 직전에만 smem/HBM으로 flush
  • register pressure = MW·NW / 32 elem/thread · 4B
  • e.g. 64×128 FP32 acc → 64·128/32·4 = 1024 B/thread = 256 reg

1 Iterator 개념 2.x 추상

TileIterator gmem/smem 상의 tile을 순회하는 stride-aware 반복자. thread-map과 layout을 조합해 thread별 access pattern을 결정.
  • PredicatedTileIterator : 경계 mask 포함
  • RegularTileIterator : 경계 없음 (aligned)
  • MaskedTileIterator : 외부 mask 받음

2 IteratorA / IteratorB 템플릿

using IteratorA = PredicatedTileIterator<
  MatrixShape<M_C, K_C>,
  ElementA,
  LayoutA,
  AdvanceRank,   // K 축 전진
  ThreadMap,     // thread → elem 맵
  AccessSize>;   // vector width

3 ThreadMap 역할

thread_idx (0..NT-1) → (row, col) in tile
contiguous: vector width 단위로 row 먼저
strided: 32-thread 단위로 row 묶음

4 Thread-level 출력 fragment ★

level저장수명
gmem fragHBM전역
smem fragSMEM stageCTA mainloop
reg frag A/BRegistermma 1회
reg frag CRegisterK 전체 (acc)

5 Fragment 크기 계산

frag_elems = (tile_elems) / threads_per_scope
e.g. 16·8 (C frag) / 32 thread = 4 elem/thread
MMA 16·8·16 FP16→FP32:
A frag = 16·16/32 = 8 · FP16 = 16 B
B frag = 8·16/32 = 4 · FP16 = 8 B
C frag = 16·8/32 = 4 · FP32 = 16 B

6 Vectorized access

  • 128-bit (16 B) per thread가 이상적 (coalesced)
  • FP16 → 8 elem/thread vector load
  • FP32 → 4 elem/thread
  • access size < 128 bit이면 transaction 낭비

7 AccessSize trade-off

AccessSize효과
32 bitcoalesce ↓ · 유연
64 bit중간
128 bit최대 BW · alignment 제약

8 Register 압박 진단

reg/thread ≈ acc_frag + A_frag·iters + B_frag·iters + addr_tracking
제약 = 255 reg/thread (Ampere)
초과 → spill to local mem → 지연 폭증

9 2.x → 3.x 전환

CuTe 대체 3.x에서는 Iterator/ThreadMap이 CuTe Layout + local_partition으로 일반화. 수동 thread-map 작성 대신 layout 대수로 유도.

다음 페이지 §7부터 CuTe 본론.

1 문제의식 ★

CUTLASS 2.x의 한계 Iterator / ThreadMap / Policy가 각 kernel마다 수동으로 template 조합. tile 방식이 바뀌면 전부 재작성. 조합 폭발.
  • IteratorA, IteratorB, SmemIterator, AccumulatorIterator ...
  • each × (layout × dtype × shape × stage) 조합
  • indexing 수식을 template으로 인코딩 → 난독

2 CuTe의 한 줄 요약 shape+stride=Layout

CuTe 모든 tensor indexing을 Layout = (Shape, Stride) 하나로 통일. compose/divide/product 대수로 모든 partition을 표현.
Layout L : ℕrank → ℕ
L(c0, c1, ...) = Σ ci · stridei

3 Layout의 위력

작업2.xCuTe
tile 분할Iterator 템플릿logical_divide
thread 할당ThreadMaplocal_partition
swizzleSmemLayout 고정Swizzle composition
TMA desc별도 작성Layout 직변환
mma shape 매핑수동 fragmentTiledMMA partition

4 Hierarchical shape ★

Nested tuple shape ((M_1, M_2), (N_1, N_2)) 처럼 shape가 tuple의 tuple. 외부에서는 평면처럼, 내부에서는 sub-structure로 취급.
shape = (4, (2, 3))    → 4×(2·3) = 24 elem
stride = (6, (3, 1))   → col-major outer, row-major inner
                         4 그룹 × (2×3 tile)

5 같은 데이터, 여러 View

M·K 행렬을
(M, K) 평면,
((M/bM, bM), (K/bK, bK)) block view,
((thr_M, bM/thr_M), K) thread view
—stride만 바꾸면 메모리 원본 동일.

6 Compile-time vs runtime

종류표기
static (cxx)_128{}, Int<64>
dynamicint M
mixedmake_shape(M, _64{})

static 부분은 template 메타로 fold 되어 런타임 cost 0.

7 요구 지식

  • C++17 template + CTAD
  • integer_sequence / tuple 메타
  • GCD · 나눗셈 · coprime 개념
  • modular arithmetic (swizzle)

1 Layout 정의 ★

Layout<Shape, Stride> L(c) = Σ cᵢ · strideᵢ. 좌표 tuple → 정수 offset. Shape와 Stride는 동일 tree 구조.
Layout<(M,N), (1,M)>     → col-major  M×N
Layout<(M,N), (N,1)>     → row-major
Layout<(M,N), (1,0)>     → N축 broadcast
Layout<(M,N), (0,1)>     → M축 broadcast

2 make_* helper

auto s = make_shape(_8{}, _16{});
auto d = make_stride(_16{}, _1{});
auto L = make_layout(s, d);   // row-major 8×16

// 간이 ctor
auto L2 = make_layout(
   make_shape(_8{}, _16{}),
   LayoutRight{});              // row-major
auto L3 = make_layout(s, LayoutLeft{});  // col-major

3 size · cosize · rank

함수의미
size(L)Π shape (총 elem 수)
cosize(L)max(L(c)) + 1 (필요 memory)
rank(L)shape tuple 길이
depth(L)tuple nesting 깊이
Layout<(4,3),(1,4)>: size=12, cosize=12, rank=2, depth=1

4 Injective / Surjective

  • size == cosize → bijective (1:1)
  • size < cosize → "sparse" (내부 갭)
  • size > cosize → broadcast (동일 elem 공유, stride=0)

5 Hierarchical shape 예 ★

shape=((4,2),(8,4))
stride=((1,16),(4,32))

좌표 ((i,j),(k,l)):
L = i·1 + j·16 + k·4 + l·32
  → col-major 8×8 tile in row-major 4×8 super-grid

6 Column/Row major 레시피

레이아웃shape (M,N)stride
col-major(M,N)(1, M)
row-major(M,N)(N, 1)
3D col-major(M,N,K)(1, M, M·N)
3D row-major(M,N,K)(N·K, K, 1)
batched (MNK, B)((M,N,K), B)((...),M·N·K)

7 print_layout 출력 이해

// Layout<(4,3),(1,4)> col-major 4×3
     0   1   2
0    0   4   8
1    1   5   9
2    2   6  10
3    3   7  11
// cell 값 = L(i,j) offset

print_layout(L)·print_tensor(t) 디버그 필수.

1 연산 전체 표 ★★

opsignature의미
composition(A,B)L∘LA(B(c)) · 좌표 변환
complement(L,N)L→L'직교 basis로 size=N 채움
coalesce(L)L→L'정합 가능한 축 합병
logical_divide(L,T)→(T,Rest)Tile과 잔여 분해
logical_product(L,T)→replicateL을 T번 타일링
zipped_divide→divide + ziplogical_divide + 축 재배열
tiled_divide→flat(divide)divide 후 flat
flat_divide→(T,R) flatshallow 결과
inverse(L)L-1bijective 역
right_inverse(L)부분역L(Lr-1)=id

2 composition ∘ 함수 합성

(A ∘ B)(c) = A(B(c)) B의 출력이 A의 좌표가 돼야 함 (cosize(B) ≤ size(A))
A = (8):(1), B = (4):(2)
→ A∘B = (4):(2) (0,2,4,6)

3 logical_divide ★

divide(L, T) = (T, R) 쌍으로 재배열
L의 shape가 T·R 로 factorize 가능해야 함
L = (24):(1), T = (4):(1)
((4),(6)) : ((1),(4)) : tile-내·tile-외 분리

4 logical_product

product(L, T) : L을 "기본 block"으로 두고 T번 stride

= replicate L tile. GEMM의 CTA → 전체 grid 복원에 사용.

5 coalesce · complement 보조

coalesce( ((2,4),(1,2)) ) = (8):(1)
complement( (4):(4), 16 ) = (4):(1)
  • coalesce: stride가 정확히 연속되는 축 병합 → 연산 단순화
  • complement: bijective 유지하며 빈 stride 채움

6 zipped_divide 사용 예 ★

// CTA tile → thread partition
auto cta_tile = Layout<Shape<_128,_64>,
                         Stride<_64,_1>>{};
auto thr_tile = Layout<Shape<_32,_8>,
                         Stride<_8,_1>>{};
auto ztile = zipped_divide(
      cta_tile, thr_tile);
// ((thr_M, thr_N), (rest_M, rest_N))
// 각 thread가 자신의 rest tile 받음

7 조합 법칙 요약

법칙
associativity(A∘B)∘C = A∘(B∘C)
identityid∘L = L∘id = L
divide · productdivide ∘ product = id (조건부)
non-commutA∘B ≠ B∘A (일반)

8 디버깅 습관

print_layout으로 op 전·후를 수식 대신 그림으로 확인. CuTe tutorial의 첫 번째 조언.

1 Tensor = Engine + Layout

Tensor<Engine, Layout> Engine은 ptr 래퍼 (gmem/smem/rmem). Layout은 indexing. T(c)engine[layout(c)].
Engine저장
gmem_ptr<T>HBM
smem_ptr<T>SMEM
rmem_ptr<T>Register
ArrayEngineC++ array (stack)

2 make_tensor 사용

auto gA = make_tensor(
   make_gmem_ptr(A_ptr),
   make_shape(M, K),
   make_stride(K, _1{}));

auto sA = make_tensor(
   make_smem_ptr(smemA),
   Layout<Shape<_128,_64>,
          Stride<_64,_1>>{});

auto rA = make_tensor<half>(
   Layout<Shape<_8>>{});   // register

3 local_tile ★

local_tile(T, tile_shape, coord) 전역 tensor T에서 tile 크기의 sub-tensor를 coord로 인덱싱. 반환은 view (메모리 복사 없음).
// 128×64 CTA tile 한 개
auto tA = local_tile(gA,
   make_shape(_128{}, _64{}),
   make_coord(m_blk, k_blk));

4 local_partition ★

local_partition(T, thr_layout, thr_id) tile을 thread-layout으로 분할, thr_id의 몫을 반환. thread별 access pattern 자동 유도.
// 32×8 thread layout (256 thr/CTA)
auto thr_layout = Layout<
   Shape<_32,_8>,
   Stride<_8,_1>>{};
auto tAgA = local_partition(
   tA, thr_layout, threadIdx.x);

5 copy / gemm 호출

// Copy atom으로 gmem → smem
copy(copy_atom,
     thr_copy.partition_S(gA_tile),
     thr_copy.partition_D(sA));
cp_async_fence();
cp_async_wait<0>();

// TiledMMA로 연산
auto tCrA = thr_mma.partition_A(sA);
auto tCrB = thr_mma.partition_B(sB);
auto tCrC = thr_mma.partition_C(rC);
gemm(tiled_mma, tCrA, tCrB, tCrC);

6 Iterator 패턴

기능기존 2.xCuTe
tile 전진iter++coord 수동
경계 checkPredicateTMA / crd2idx
thread 분배ThreadMaplocal_partition
shape 변환Transformcomposition

7 Helper 라이브러리

  • make_tensor / make_layout
  • partition_A/B/C (TiledMMA slice)
  • partition_S/D (TiledCopy)
  • recast<T> (dtype 재해석, e.g. uint32 → 2×half)
  • flatten, group_modes
  • print_tensor, print_layout

1 왜 필요한가 bank conflict 회피

SMEM bank 32 bank × 4 B word. 같은 bank에 동시 access 하면 n-way conflict → n cycle serialized.
  • row-major FP16 row = 128 B = 32 word → 한 row가 32 bank 전부 hit
  • 같은 col 접근은 stride 32 word → 모두 bank 0에 몰림
  • ldmatrix / wgmma는 bank conflict 없는 layout 요구

2 Swizzle 함수 정의 ★

Swizzle<B, M, S>(offset) offset의 비트 [M+S+B−1 : M+S][M+B−1 : M]XOR.
offset' = offset XOR (
  ((offset >> (M+S)) & ((1<<B)−1)) << M ) B : XOR 비트 수 · M : 보존할 하위 bit · S : shift (XOR 원천 bit 위치)

3 파라미터 의미 표 ★

param역할전형
BXOR 영역 bit 수 (= swizzle 분산 배수 log₂)3
M보존 lower bit (= atomic unit log₂)3 (8 elem 묶음)
SXOR 상위 bit 시작 위치 (= row stride log₂)3

Swizzle<3,3,3> = 128 B line 내 8×8 block 재배치. WGMMA A/B 기본.

4 시각적 예시 ★ Swizzle<2,0,2>

원본 offset(6-bit):
 bit:  5 4 3 2 1 0
 ─────────────────
 S=2, B=2, M=0
 → XOR region [3:2] with [1:0]

offset  bin     swizzled
 0      000000  000000  (row 0, col 0)
 1      000001  000001  (row 0, col 1)
 4      000100  000100  (row 1, col 0)
 5      000101  000101  (row 1, col 1)
 8      001000  001010  (row 2, col 0 → 2)
 9      001001  001011  (row 2, col 1 → 3)
12      001100  001110
16      010000  010100
20      010100  010000  (swapped!)

5 표준 preset (CUTLASS)

preset<B,M,S>용도
NoSwizzle<0,0,0>identity
B32_M3_S3<3,3,3>128B swizzle
B64_M4_S3<3,4,3>128B wide
B128_M5_S3<3,5,3>Hopper WGMMA

preset 이름의 숫자는 보통 swizzle chunk 크기 바이트.

6 Layout composition

// 기본 SMEM layout 위에 swizzle 덮어쓰기
using SmemLayoutA = decltype(
  composition(
    Swizzle<3, 3, 3>{},
    Layout<Shape<_64,_64>,
           Stride<_64,_1>>{}));

// 좌표 (i,j) → swizzled offset
offset = SmemLayoutA(i, j);
// ldmatrix / wgmma descriptor에도 swizzle 인코딩

7 WGMMA descriptor 제약

WGMMA는 <B,M,S>의 일부 조합만 인식. 허용되지 않는 swizzle이면 silent wrong result. CUTLASS builder가 안전 조합 선택.

1 TMA 요약 Tensor Memory Accelerator

TMA Hopper SM90에서 추가된 HW unit. 1 thread가 다차원 tile 로드 발행, descriptor에 담긴 shape·stride·swizzle을 HW가 해석.
  • PTX: cp.async.bulk.tensor.{1,2,3,4,5}d
  • mode: tile, im2col
  • 완료: mbarrier arrive (async proxy)

TMA PTX 전체는 ↗ V04 §5·§6.

2 Tensor descriptor 구조

field의미
base_ptrglobal addr
global_shape전체 M, N, ...
global_stridebyte stride
box_shapetile 크기
element_stride통상 1
swizzleSMEM swizzle mode
dtypeFP16/BF16/FP8/...

3 CuTe → TMA 생성 ★

// CuTe Layout이 TMA descriptor로 직변환
auto tma_a = make_tma_copy(
   SM90_TMA_LOAD{},
   gA_tensor,                    // global Tensor
   SmemLayoutA{},                // swizzled smem
   make_shape(_128{}, _64{}));   // box

// kernel 내에서 partition
auto thr_tma = tma_a.get_slice(block_id);
auto tAgA = thr_tma.partition_S(gA);
auto tAsA = thr_tma.partition_D(sA);

4 Launch 흐름

Host :
 ─ make_tma_copy  ──► TMA desc struct
                      (global tensor map)
 ─ kernel<<<...>>>(desc, ...)

Device :
 thread 0 만 발행:
 ─ cp.async.bulk.tensor (desc, coord, smem, bar)
 ─ HW가 shape/stride/swizzle 해석
 ─ 완료시 bar.arrive(expected_bytes)
 모든 thread:
 ─ bar.wait(parity)

5 expected_bytes 수식

exp_bytes = box_M · box_N · sz(dtype) producer가 arrive 시 전달 · mbarrier가 완료 판정

여러 TMA가 같은 bar 사용시 exp_bytes는 합산.

6 multicast TMA ★

TMA multicast (cluster) cluster 내 여러 CTA의 SMEM에 동시 로드. A tile을 2 CTA가 공유하면 HBM traffic 절반.
auto tma_a_mc = make_tma_copy(
   SM90_TMA_LOAD_MULTICAST{},
   gA, SmemLayoutA{},
   CtaShape{},
   cluster_shape);

multicast는 cluster ≥ 2 + same A tile CTAs.

7 제약 · 함정

  • box_shape가 descriptor의 global_shape 안에 fit 해야
  • SMEM alignment: 16 B (WGMMA) · 128 B (swizzle chunk)
  • per-tensor descriptor 수 제한 (RT API cuTensorMapEncodeTiled)
  • cluster launch 실패시 swizzle 조합 재검토

1 3.x의 재구성

Collective Threadblock / Warp / Thread 계층을 "공동 작업 단위"로 합친 추상. CollectiveMainloop + CollectiveEpilogue + TileScheduler.
Kernel<...>
 ├ CollectiveMainloop  (K iter · stages · wgmma)
 ├ CollectiveEpilogue  (acc → epilogue → gD)
 └ TileScheduler       (persistent · streamk · default)

2 Dispatch Policy ★

Policy의미
KernelTmaWarpSpecialized1P + 1C WG
..Cooperative1P + 2C (동일 tile 분할)
..Pingpong1P + 2C (교대 tile)
KernelMultistageAmpere cp.async

sm_90a에서만 TmaWarpSpecialized 선택 가능.

3 Builder 패턴 ★

// High-level auto-select
using Mainloop =
  cutlass::gemm::collective::CollectiveBuilder<
    ArchTag,             // SM90
    OperatorClass,       // TensorOp
    ElementA, LayoutA, AlignA,
    ElementB, LayoutB, AlignB,
    ElementAccumulator,
    TileShape_MNK,
    ClusterShape_MNK,
    StageCountAuto,
    KernelScheduleAuto
  >::CollectiveOp;

4 대칭 Epilogue

using Epilogue =
  cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag, OperatorClass,
    TileShape_MNK, ClusterShape_MNK,
    EpilogueTileType,
    ElementAccumulator, ElementCompute,
    ElementC, LayoutC, AlignC,
    ElementD, LayoutD, AlignD,
    EpilogueScheduleAuto,
    FusionCallbacks        // EVT 연결
  >::CollectiveOp;

5 TileScheduler 선택

Scheduler동작
Defaultrow-major tile iter
Persistentlaunch = SM 수, work-stealing
StreamKK split, 잔여 fixup kernel ↗ V05 §8

6 Kernel 조립

using GemmKernel =
  cutlass::gemm::kernel::GemmUniversal<
    ProblemShape,
    Mainloop,
    Epilogue,
    TileScheduler
  >;

using Gemm =
  cutlass::gemm::device::GemmUniversalAdapter<
    GemmKernel>;

7 장점 요약

  • 2.x 템플릿 대비 1/3 수준 boilerplate
  • CuTe layout이 warp·thread까지 관통
  • 새 기능(WGMMA·TMA·cluster) 1곳에 캡슐화
  • autoselection: Auto tag

1 PipelineTmaAsync 개념 ★

PipelineTmaAsync<Stages> TMA 비동기 load와 consumer WGMMA를 ring buffer로 잇는 primitive. mbarrier slot = Stages개.
stage:  0   1   2   ...  S-1   (ring)
 full :  F   E   E         E   (producer arrive)
 empty:  E   F   F         F   (consumer arrive)

 producer: wait(empty) → TMA → arrive(full)
 consumer: wait(full)  → wgmma → arrive(empty)

2 mbarrier 짝 full/empty

slotfull barempty bar
0P→C signalC→P signal
1P→CC→P
...S개S개

mbarrier 상세는 ↗ V04 §7.

3 Stage count 선택

Tload(tile) : HBM BW 제한 시간
Tmma(tile) : WGMMA 시간
Smin = ⌈ Tload / Tmma ⌉ + 1
tile 128×64×64 FP16, H100
Tload ≈ 0.6 µs, Tmma ≈ 0.15 µs
→ S ≥ 5 추천 (실전 6~7)

4 SMEM 한계와 균형

smem_total = S · (bM·bK + bK·bN) · sz + C
H100 per-SM limit = 228 KB C : epilogue 버퍼 + mbarrier + reserved
  • S ↑ → latency hide ↑ · smem ↑
  • S ↑ → occupancy ↓ 가능 (CTA/SM ↓)
  • heuristic: ↗ §17 p18

5 Producer / Consumer 분리

WG역할reg budget
producer (1 WG)TMA issue · signalsetmaxnreg.dec(40)
consumer (1~2 WG)WGMMA · epiloguesetmaxnreg.inc(232)

setmaxnreg PTX · producer 레지스터 해방 → consumer 더 받음.

6 Async proxy fence

  • WGMMA는 async proxy: 일반 SMEM write와 visibility 별개
  • producer의 TMA 완료 → mbarrier arrive fence 역할
  • consumer는 wait 후 wgmma 가능
  • fence.proxy.async.shared::cta 필요 케이스 주의

7 공용 Pipeline 변형

type용도
PipelineTmaAsyncTMA + WGMMA (Hopper)
PipelineAsynccp.async (Ampere)
PipelineTransactionexpected_bytes 기반
OrderedSequenceBarrierPingpong WG 순서

1 왜 2 consumer WG인가

  • WGMMA는 async → 발행 후 return 즉시
  • 한 WG만 있으면 accumulator 비어있는 중 epilogue 대기
  • 2 WG이 교대 또는 분할하면 compute hole 제거

2 Pingpong 구조 ★

Time →
WG1: [load+mma Mi   ][epilogue ]
WG2:           [load+mma Mi+1][epilogue]
WG1:                      [load+mma Mi+2][epi]
WG2:                                 [mma Mi+3][epi]
         ↑ 교대: 서로 다른 M tile

epilogue와 mainloop가 다른 WG에서 겹침. 큰 epilogue에 유리.

3 Cooperative 구조 ★

Time →
WG1: [mma Mi, N0..N/2-1    ][epi L]
WG2: [mma Mi, N/2..N-1      ][epi R]
     ↑ 같은 M tile을 N 축으로 분할
     acc 합쳐서 epilogue 동시

같은 M tile 처리. 큰 N에 유리 (N/2씩 담당).

4 비교 표 ★

항목PingpongCooperative
WG 역할다른 M tile 교대같은 tile 분할
N 분할전체N/2
EpilogueWG 각자같이 + sync
SMEMepilogue buf 2×공유
큰 epilogue유리overhead
큰 N보통유리
작은 GEMM 수보통유리

5 OrderedSequenceBarrier

OSB Pingpong 스케줄에서 두 WG의 epilogue 순서를 token passing으로 강제. epilogue가 같은 C를 쓰는 경우 안전.
// WG1 먼저 epi, WG2 나중
osb.wait(tok_epi, wg_id);
epilogue(acc);
osb.arrive(tok_epi);

6 선택 결정 트리 ★

시작
 │
 ├─ N < 128 ?  ──Y──► Pingpong
 │     (큰 epi가 relative 비용 ↑)
 │
 ├─ epilogue gemma fuse 복잡?
 │     ──Y──► Pingpong (WG별 독립)
 │
 ├─ N ≥ 256 & 작은 batch?
 │     ──Y──► Cooperative
 │
 └─ default ──► Cooperative
               (CUTLASS 권장)

7 주의 ★

stage · tile · cluster 값이 스케줄마다 지원 범위 다름. CollectiveBuilder의 Auto가 안전 조합 반환. 수동 지정 시 compile-time assert 확인.

1 Epilogue의 자리 α·AB+β·C→fuse

Epilogue K iteration 종료 후 register-resident accumulator에 적용되는 element-wise 변환 + C/D write-back. D = g(α·AB, β·C, bias, ...).
  • 입력: FP32 acc fragment
  • 출력: dtypeD (BF16/FP8/...)
  • fuse가 별도 kernel launch 제거

2 표준 chain

acc (FP32)
  │ · α
  ▼
  + β · C   (load gmem)
  │
  ▼
  + bias    (load broadcast)
  │
  ▼
  activation (ReLU/GELU/Sigmoid)
  │
  ▼
  quantize/cast → dtype_D
  │
  ▼
  store D (gmem)

3 일반적인 fused ops 표 ★

op입력비고
Linear Combinationα, β기본
Bias Addbias[N]broadcast
ReLU / GELU / Siluacctanh approx 택
Scale → FP8scaleper-tensor amax
Dequant (INT→FP)scaleper-channel
Residual Addgmem 입력skip-conn
Clampmin/maxbf16 overflow 방지

Fusion 일반론은 ↗ V13 §8. 여기선 CUTLASS epilogue 한정.

4 CTA 내 flush 단계

  1. acc fragment → smem 전송 (stride 재정렬)
  2. smem → reg 재로드 (output layout)
  3. op chain 적용
  4. gmem store (TMA_STORE or vector st)

smem-staged epilogue가 output coalescing과 layout 재배열에 유리.

5 Register vs SMEM epilogue

방식장점단점
Register onlylatency ↓layout 유연성 ↓
SMEM stagedcoalesce 용이SMEM 사용
TMA storedescriptor 재활용alignment 제약

6 순서 규칙

bias/activation은 FP32 domain에서. cast/quant는 마지막. 순서 섞이면 precision loss.
  • residual add도 FP32 또는 BF16 (dtype 맞출 것)
  • quantize scale은 ReLU/GELU

7 FP8 GEMM의 epilogue

acc(FP32)
  ├ · alpha (FP32)
  ├ + bias (FP32)
  ├ activation
  ├ · output_scale (FP32)  ← per-tensor
  └ cast → E4M3 (clamp ±448)

amax history update도 epilogue 한 노드.

1 원칙 ★

  • 큰 tile = HBM 재사용 ↑, SMEM 압박 ↑, 작은 GEMM 맞추기 어려움
  • 작은 tile = 작은 GEMM 유연, 재사용 ↓
  • "M·N가 tile의 배수가 되는 최소 tile" 선호
  • stage ≥ 3 기본 (H100 5~7)

2 A100 (sm_80) 권장 표 ★

M×N×Ktile M·N·Kstage
4096²×4096 FP16256·128·323
2048×4096×1024128·128·324
512×4096×102464·128·645
M≤256 skinny64·128·32 + SplitK3
K ≤ 64128·128·K2

3 H100 (sm_90a) 권장 표 ★

M×N×Ktileclustersched
8192² FP16128·256·64(2,1,1)Coop
8192² FP8128·256·128(2,1,1)Coop
4096×8192 FP16128·128·64(2,2,1)Coop
512×8192 FP1664·256·64(1,1,1)Pingpong
큰 epilogue128·128·64(1,1,1)Pingpong
skinny M64·128·64 + StreamK(1,1,1)default

4 Cluster shape 규칙

cluster_M · cluster_N ≤ 8 (H100 limit)
grid_M mod cluster_M = 0
grid_N mod cluster_N = 0
cluster ≥ 2 이면 M 또는 N이 cluster shape의 배수여야 launch 통과.

5 M, N 크기별 휴리스틱 ★

형상전략
M=N, 둘 다 큼큰 tile + Cooperative
M≫N (tall)bigger M, SplitK 가능
M≪N (wide)bigger N, Cooperative
M 작음, N 큼Pingpong + 작은 N tile
M, N 둘 다 작음StreamK + persistent
K 매우 큼deep stage (7+)
K 작음stage ↓ (2~3)

6 L2 capacity 고려

L2 resident ≈ bM·K + K·bN + bM·bN · sz
H100 L2 = 50 MB · 여러 CTA 공유 대상

같은 B가 여러 CTA 재사용 되면 L2 hit 기대. tile 순서 (swizzle scheduler)가 이 hit에 영향.

7 실전 결정 흐름

1. M, N, K 구하기
2. tile 후보 2~3개 고르기
   · H100: (128,256,64), (128,128,64), (64,256,64)
3. cluster (1,1) or (2,1)
4. SMEM budget 확인 → stage 조정
5. Coop vs Pingpong
6. autotune 시 이 4~6 조합 시도

1 Batched GEMM 정의

Batched 동일 M, N, K GEMM이 B개. A, B, C 각각 rank-3 tensor (B, M, K), (B, K, N), (B, M, N).
iter b = 0..B-1 :
   D[b] = α·A[b]·B[b] + β·C[b]

2 Grouped GEMM 정의 ★

Grouped 서로 다른 (M_g, N_g, K_g)의 GEMM G개를 한 번의 launch로. ragged shape. MoE expert별 FFN이 대표.

MoE 맥락 상세는 ↗ V08 §3·§4.

3 자료구조 비교 표 ★

방식A 포맷launch
Batched stride(B,M,K) contiguous1
Batched ptr-arrayptr[B]1
Groupedper-group ptr + shape1
loop cublas각기G

4 GroupedGemm kernel 구조

// host
problem_sizes = [(M0,N0,K0), ...];
ptr_A = [A_0, A_1, ...];
ptr_B, ptr_C, ptr_D = ...;
ld_A, ld_B, ld_C, ld_D = ...;

// kernel 내
int g = tile_scheduler.group_id();
auto [M,N,K] = problem_sizes[g];
run_gemm(ptr_A[g], ptr_B[g], ..., M,N,K);

5 TileScheduler 분배

  • 각 group의 tile 수 = ⌈M/bM⌉·⌈N/bN⌉
  • 전체 tile을 sequential ID로 매핑
  • persistent kernel이 tile → group 역매핑
  • 작은 group이 많으면 tail loss ↓

6 이점 · 한계

BatchedGrouped
shape동일각기
load balance자동shape 분산에 민감
launch11
코드단순복잡
MoE 적합xo

7 권당 한계

MoE dispatch/gate/permute는 본 권 범위 밖. GEMM 호출부만 다룸. 상세 ↗ V08 §2·§3.

1 Mixed-dtype의 의미

Mixed-dtype dtype(A) ≠ dtype(B). 입력 중 하나가 quantized (INT4/INT8/FP4/FP6)이고, 다른 하나는 FP8/BF16. accumulator는 FP32.
  • FP8(E4M3) × BF16 — decode path
  • INT4 × FP16 — W4A16 (AWQ/GPTQ)
  • FP4 × BF16 — Blackwell 추론

양자화 알고리즘 상세는 ↗ V10 §5·§6·§7.

2 수식 흐름

Âij = sA · (Aijquant − zA)
acc += Âik · Bkj
D = epilogue(acc, α, β, C, ...) s : scale, z : zero-point. per-tensor/channel/group 중 하나

3 양자화 granularity 표 ★

granularityscale 형상
per-tensorscalarFP8 (delayed scale)
per-channel(N,) or (M,)W8A8 INT8
per-token(M,) runtimesmooth-quant act
per-group (G=128)(K/G, N)AWQ W4
block-scale(K/B,) E8M0MXFP8

4 Dequant 삽입 위치

SMEM (quant)
  │
  │ ldmatrix → reg (quant)
  ▼
  dequant(scale, zp) ★  ← 여기 (inline, before mma)
  │
  ▼
 reg (bf16 / fp16 / fp8)
  │
  ▼
 mma(acc, Â, B)

scale fetch는 별도 TMA/cp.async로 prefetch.

5 Scale layout 이슈

  • per-group scale: (K/G, N) tensor · N축 broadcast
  • SMEM에서 K 단위로 묶어 저장 → iter마다 1회 load
  • swizzle 불필요 (작은 tile)
  • scale dtype: E8M0 (MXFP8 exp) / FP8 / FP16 다양

6 INT4 packing

INT4 pack 2 element / byte. CuTe에서는 recast<uint8_t> 후 bit-unpack kernel 삽입.
packed (uint8):  0xBA
                   │└── low nibble: A (int4)
                   └── high nibble: B (int4)

7 FP8 GEMM 특이점 ★

항목E4M3E5M2
range±448±57344
precision높음낮음
용도forward/outputgradient
scale 관리per-tensor amaxper-tensor amax

FP8 bit 상세 ↗ V09 §4.

1 EVT의 동기

EVT epilogue op chain을 tree-structured DSL로 기술. 각 노드는 Load/Compute/Store. 여러 연산을 단일 kernel에 컴파일타임 합성.
  • LinearCombination 만으로 안 되는 복합 fusion
  • multi-output (D, activation stats 등)
  • auxiliary load (bias, residual, scale)

2 3 종 노드

노드역할
Loadaux input 로드 (bias/C/scale)
Computeelement-wise op (add/mul/act)
Storegmem/aux out 저장

3 Tree 구조 예 ★

          Store(D)
            │
         Compute(cast→D)
            │
         Compute(scale_D ·)
            │
         Compute(GELU)
            │
         Compute(+)
          ┌─┴─┐
       Compute    Load(bias)
        (·α)
          │
        acc         

4 DSL 표기

// pseudo-syntax
using EVT = Sm90EVT<
  Sm90Compute<cast, D_t, F32>,
  Sm90EVT<
    Sm90Compute<mul, F32, F32>,
    Sm90EVT<
      Sm90Compute<gelu, F32, F32>,
      Sm90EVT<
        Sm90Compute<add, F32, F32>,
        Sm90ScalarBroadcast<alpha>,
        Sm90RowBroadcast<bias>
      >
    >,
    Sm90ScalarBroadcast<scale_D>
  >
>;

5 자주 쓰는 사전 조합

이름의미
LinearCombinationα·AB + β·C
LinCombBias+ bias
LinCombBiasRelu+ bias → ReLU
LinCombGeluGELU(α·AB+β·C)
LinCombBiasEltwise+ bias → act → residual
PerRowLinCombrow 별 α·β

6 FP8 fused scale 예

// 입력: FP8 acc, 출력: FP8 + amax_next
using EVT_FP8 = Sm90EVT<
  Sm90Compute<cast, E4M3, F32>,
  Sm90EVT<
    Sm90Compute<mul, F32, F32>,  // · scale_D
    Sm90Compute<amax_update,     // side-effect
                F32, F32>
  >>;

7 제약

  • tree는 compile-time 고정
  • 동적 op 선택 시 다중 kernel 생성
  • aux load는 SMEM 증가 → stage 재계산
  • 노드 간 dtype 일관성은 수동 cast 명시

1 Layout 조작 10선 ★

용도호출
col-majormake_layout(s, LayoutLeft{})
row-majormake_layout(s, LayoutRight{})
tile 꺼냄local_tile(T, tile, coord)
thread 분배local_partition(T, thr, tid)
MMA slicethr_mma.partition_{A,B,C}(T)
Copy slicethr_copy.partition_{S,D}(T)
dtype 재해석recast<U>(T)
축 평탄화flatten(L)
축 묶기group_modes<lo,hi>(L)
swizzle 적용composition(Swizzle{}, L)

2 수식 상자

L(c) = Σ cᵢ · sᵢ
(A∘B)(c) = A(B(c))
size(L) = Π shape
cosize(L) = max L + 1
smem = S·(bM·bK + bK·bN)·sz

3 CUTLASS 3.x 빌더 template ★

using Mainloop = CollectiveBuilder<
  SM90, TensorOp,
  half, RowMajor, 8,
  half, ColMajor, 8,
  float,
  Shape<_128,_256,_64>,
  Shape<_2,_1,_1>,
  StageCountAuto,
  KernelTmaWarpSpecializedCooperative
>::CollectiveOp;

using Epilogue = EpilogueBuilder<
  SM90, TensorOp,
  Shape<_128,_256,_64>,
  Shape<_2,_1,_1>,
  EpilogueTileAuto,
  float, float,
  half, RowMajor, 8,
  half, RowMajor, 8,
  EpilogueScheduleAuto,
  LinearCombination
>::CollectiveOp;

using Kernel = GemmUniversal<
  Shape<int,int,int>,
  Mainloop, Epilogue,
  PersistentScheduler>;

4 5-Level D·K·C·W·T Device·Kernel·CTA·Warp·Thread

Device  M·N·K      gemm_universal
 Kernel  grid      TileScheduler
  CTA    bM·bN·bK  mainloop · stage
   Warp  wM·wN·wK  (wg)mma iter
    Thread fragA/B/C reg

5 스케줄 결정 1-liner

  • 큰 N·작은 GEMM수 → Cooperative
  • 큰 epilogue / 작은 N → Pingpong
  • skinny M·N → StreamK
  • default → Persistent + Coop

6 디버그 경로

  1. print_layout / print_tensor
  2. compile-time assert: shape / alignment / dtype
  3. silent wrong result → swizzle 조합 확인
  4. hang → mbarrier expected_bytes / arrive count

7 함정 체크

sm_90a (a 포함) · cluster 배수 · align 16B · WGMMA swizzle 허용 조합 · setmaxnreg 쌍. 이 다섯이 대부분의 launch 실패 원인.