| layout | A stride | B stride | 해석 |
|---|---|---|---|
| NN | col-major | col-major | A(M×K) · B(K×N) |
| TN | row-major A | col-major B | Aᵀ 입력 |
| NT | col-major A | row-major B | Bᵀ 입력 |
| TT | row-major | row-major | 둘 다 전치 |
CUTLASS는 LayoutA/LayoutB/LayoutC로 명시. cf. cuBLAS convention.
| A | B | Acc | D | 용례 |
|---|---|---|---|---|
| FP16 | FP16 | FP32 | FP16/FP32 | 학습·추론 |
| BF16 | BF16 | FP32 | BF16 | 학습 표준 |
| TF32 | TF32 | FP32 | FP32 | A100 FP32 가속 |
| FP8 E4M3 | FP8 E4M3 | FP32 | BF16/FP8 | H100 추론 |
| FP8 E5M2 | FP8 E5M2 | FP32 | FP32 | H100 gradient |
| INT8 | INT8 | INT32 | INT32/FP16 | 양자화 ↗ V10 §4 |
| FP4 | FP4 | FP32 | BF16 | Blackwell 추론 |
dtype 조합·bit layout 근거는 ↗ V09 §3·§4.
Roofline 계산 원리는 ↗ V18 §3.
Default(둘 다), NoBetaScaling(β=1), OnlyAlphaScaling(β=0), Nothing(α=1,β=0).
epilogue scale은 출력 직전 FP32 acc에서 수행 (precision 보존).
// row-major, no tiling for (m = 0; m < M; ++m) for (n = 0; n < N; ++n) { acc = 0; for (k = 0; k < K; ++k) acc += A[m,k] * B[k,n]; D[m,n] = alpha*acc + beta*C[m,n]; }
HBM ──► L2 ──► SMEM ──► Register ↑ ↑ ↑ ↑ M,N,K CTA warp thread 전체 tile tile fragment
// CTA(M_C, N_C), stages=S for (k0 = 0; k0 < K; k0 += K_C) { // load to smem (cp.async / TMA) load_tile(sA[k0%S], gA[m_blk, k0 : k0+K_C]); load_tile(sB[k0%S], gB[k0 : k0+K_C, n_blk]); __syncthreads(); for (kw : warp_K_iter) { rA = ldmatrix(sA, kw); rB = ldmatrix(sB, kw); mma(acc, rA, rB); // tensor core } } epilogue(gD, acc, alpha, beta, gC);
tile이 클수록 HBM traffic ↓ 그러나 SMEM · reg pressure ↑.
| stage | smem | 특성 |
|---|---|---|
| 1 | 1 tile | load⇄compute 직렬 |
| 2 | 2 tiles | double buffer · 기본 |
| 3~7 | S tiles | multistage · deep prefetch |
| 계층 | A100 BW | 배율 |
|---|---|---|
| HBM2e | ~1.5 TB/s | 1× |
| L2 | ~4.7 TB/s | ~3× |
| SMEM / L1 | ~19 TB/s aggregate | ~13× |
| Register | ~500 TB/s | ~300× |
A100 whitepaper · SMEM은 SM 당 ~140 GB/s × 108 SM. 계층 상세 ↗ V02 §3·§5.
┌──────────────────────────────────────┐ │ Device M×N×K (전체 문제) │ in : A(M,K), B(K,N), C(M,N) │ │ out: D(M,N) │ ▼ gemm_universal launches grid │ │ │ │ Kernel grid + TileScheduler │ in : problem_shape, tile_shape │ │ out: per-CTA (m_blk, n_blk, k_iter) │ ▼ CTA = tile 1개 담당 │ │ │ │ CTA M_C×N_C × (k0..K) stages │ in : gA tile, gB tile │ │ out: CTA acc fragment (smem/reg) │ ▼ 여러 warp(group)가 분담 │ │ │ │ Warp M_W×N_W × K_W │ in : sA slice, sB slice │ │ out: warp acc (register) │ ▼ warp = mma instruction 묶음 │ │ │ │ Thread M_T×N_T (reg fragment) │ in : rA, rB fragment │ │ out: rC fragment └──────────────────────────────────────┘
| Level | 관심사 | 파일 예 |
|---|---|---|
| Device | grid launch, args 포장 | device/gemm_universal.h |
| Kernel | tile scheduler, 경계 | kernel/gemm_universal.h |
| CTA | mainloop · stage · smem | threadblock/mma_*.h |
| Warp | mma inst · fragment | warp/mma_tensor_op.h |
| Thread | iterator · register tile | thread/mma.h |
| Level | 입력 storage | 출력 storage |
|---|---|---|
| Device | HBM ptr | HBM ptr |
| Kernel | HBM ptr + tile id | HBM 부분 |
| CTA | HBM → SMEM (async) | SMEM → reg (epi) |
| Warp | SMEM (via ldmatrix) | Register |
| Thread | Register | Register |
// Device cutlass::gemm::device::GemmUniversal< Kernel_>; // Kernel cutlass::gemm::kernel::GemmUniversal< CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_>; // Threadblock (2.x) cutlass::gemm::threadblock::Mma< Shape_, Warp_, Policy_, Stages_>;
| 항목 | 2.x (Ampere) | 3.x (Hopper+) |
|---|---|---|
| 추상 | IteratorA/B + Mma | Collective + CuTe |
| tile 표현 | Shape / Policy | CuTe Layout |
| async | cp.async | TMA · WGMMA |
| warp spec | 없음 | Producer/Consumer |
Shape<M_C, N_C, K_C>. CTA 하나가 처리하는 출력 tile + K iteration chunk.
| 세대 | 전형 tile (FP16) |
|---|---|
| A100 sm_80 | 128×128×32 / 256×128×32 |
| H100 sm_90a | 128×256×64 / 64×256×64 |
| B200 sm_100 | 128×256×64 (FP4 더 큼) |
// CTA scope acc = zero(); prologue_prefetch(S); // 첫 S stage for (k_iter : K/K_C - S) { issue_next_stage(); // async load wait_oldest_stage(); warp_mma(acc, sA[oldest], sB[oldest]); advance_stage(); } drain_remaining(S); // 마지막 tiles epilogue(acc);
PredicatedTileIterator · 3.x: TMA descriptor가 처리| S ↑ | 효과 |
|---|---|
| latency hide | 개선 (load·compute 겹침) |
| smem | S배 증가 → occupancy ↓ 가능 |
| mbarrier | S개 필요 (Hopper) |
| register | addr tracking 증가 |
smem: ┌─ sA[0] ──┐ ┌─ sA[1] ──┐ ... ┌─ sA[S-1] ──┐ │ M_C×K_C │ │ M_C×K_C │ │ M_C×K_C │ ├─ sB[0] ──┤ ├─ sB[1] ──┤ ... ├─ sB[S-1] ──┤ │ K_C×N_C │ │ K_C×N_C │ │ K_C×N_C │ └──────────┘ └──────────┘ └────────────┘ ┌─ acc buffer (for epilogue split) ──┐ └────────────────────────────────────┘
Shape<M_W, N_W, K_W>. 1개 warp (또는 warp-group)이 한 번에 다룰 출력 영역. CTA tile이 warp 수만큼 나뉘어 할당.
| inst | m·n·k | 세대 |
|---|---|---|
| mma.sync | 16·8·16 (FP16) | Ampere |
| mma.sync | 16·8·8 (TF32) | Ampere |
| mma.sync | 16·8·32 (INT8) | Ampere |
| wgmma.mma_async | 64·N·16 (FP16) | Hopper |
| wgmma.mma_async | 64·N·32 (FP8) | Hopper |
mma shape 전체 목록은 ↗ V03 §7 (Ampere) · ↗ V04 §8 (Hopper).
// Ampere warp-level multi-stage class MmaMultistage { IteratorA iter_a; // gmem→smem IteratorB iter_b; Policy::Warp warp_mma; int smem_stage = 0; void operator()(acc, k_iter) { prologue(); for (k : k_iter) { ldmatrix(rA, sA[stage]); ldmatrix(rB, sB[stage]); warp_mma(acc, rA, rB, acc); advance(); } } };
| operand | dtype | elems/thread |
|---|---|---|
| A (16·8·16) | FP16 | 8 |
| B (16·8·16) | FP16 | 4 |
| C/D (16·8·16) | FP32 | 4 |
| A (WGMMA 64·N·16) | FP16 | SMEM 직접 |
WGMMA는 A를 register 대신 SMEM으로 받을 수 있음 (A-descriptor).
smem tile (16×16 FP16)
┌──────────────────────┐
│ warp 32 thread 협력 │ ──► ldmatrix.x4
│ 16B × 32 = 512B │
└──────────────────────┘
│
▼
register fragment (A shape)
ldmatrix.trans variant = 전치 로드. mma 입력 shape에 맞춰 thread-mapping 자동.
PredicatedTileIterator : 경계 mask 포함RegularTileIterator : 경계 없음 (aligned)MaskedTileIterator : 외부 mask 받음using IteratorA = PredicatedTileIterator< MatrixShape<M_C, K_C>, ElementA, LayoutA, AdvanceRank, // K 축 전진 ThreadMap, // thread → elem 맵 AccessSize>; // vector width
| level | 저장 | 수명 |
|---|---|---|
| gmem frag | HBM | 전역 |
| smem frag | SMEM stage | CTA mainloop |
| reg frag A/B | Register | mma 1회 |
| reg frag C | Register | K 전체 (acc) |
| AccessSize | 효과 |
|---|---|
| 32 bit | coalesce ↓ · 유연 |
| 64 bit | 중간 |
| 128 bit | 최대 BW · alignment 제약 |
다음 페이지 §7부터 CuTe 본론.
| 작업 | 2.x | CuTe |
|---|---|---|
| tile 분할 | Iterator 템플릿 | logical_divide |
| thread 할당 | ThreadMap | local_partition |
| swizzle | SmemLayout 고정 | Swizzle composition |
| TMA desc | 별도 작성 | Layout 직변환 |
| mma shape 매핑 | 수동 fragment | TiledMMA partition |
((M_1, M_2), (N_1, N_2)) 처럼 shape가 tuple의 tuple. 외부에서는 평면처럼, 내부에서는 sub-structure로 취급.
shape = (4, (2, 3)) → 4×(2·3) = 24 elem
stride = (6, (3, 1)) → col-major outer, row-major inner
4 그룹 × (2×3 tile)
(M, K) 평면, ((M/bM, bM), (K/bK, bK)) block view,((thr_M, bM/thr_M), K) thread view| 종류 | 표기 |
|---|---|
| static (cxx) | _128{}, Int<64> |
| dynamic | int M |
| mixed | make_shape(M, _64{}) |
static 부분은 template 메타로 fold 되어 런타임 cost 0.
L(c) = Σ cᵢ · strideᵢ. 좌표 tuple → 정수 offset. Shape와 Stride는 동일 tree 구조.
Layout<(M,N), (1,M)> → col-major M×N Layout<(M,N), (N,1)> → row-major Layout<(M,N), (1,0)> → N축 broadcast Layout<(M,N), (0,1)> → M축 broadcast
auto s = make_shape(_8{}, _16{}); auto d = make_stride(_16{}, _1{}); auto L = make_layout(s, d); // row-major 8×16 // 간이 ctor auto L2 = make_layout( make_shape(_8{}, _16{}), LayoutRight{}); // row-major auto L3 = make_layout(s, LayoutLeft{}); // col-major
| 함수 | 의미 |
|---|---|
size(L) | Π shape (총 elem 수) |
cosize(L) | max(L(c)) + 1 (필요 memory) |
rank(L) | shape tuple 길이 |
depth(L) | tuple nesting 깊이 |
Layout<(4,3),(1,4)>: size=12, cosize=12, rank=2, depth=1
shape=((4,2),(8,4)) stride=((1,16),(4,32)) 좌표 ((i,j),(k,l)): L = i·1 + j·16 + k·4 + l·32 → col-major 8×8 tile in row-major 4×8 super-grid
| 레이아웃 | shape (M,N) | stride |
|---|---|---|
| col-major | (M,N) | (1, M) |
| row-major | (M,N) | (N, 1) |
| 3D col-major | (M,N,K) | (1, M, M·N) |
| 3D row-major | (M,N,K) | (N·K, K, 1) |
| batched (MNK, B) | ((M,N,K), B) | ((...),M·N·K) |
// Layout<(4,3),(1,4)> col-major 4×3 0 1 2 0 0 4 8 1 1 5 9 2 2 6 10 3 3 7 11 // cell 값 = L(i,j) offset
print_layout(L)·print_tensor(t) 디버그 필수.
| op | signature | 의미 |
|---|---|---|
composition(A,B) | L∘L | A(B(c)) · 좌표 변환 |
complement(L,N) | L→L' | 직교 basis로 size=N 채움 |
coalesce(L) | L→L' | 정합 가능한 축 합병 |
logical_divide(L,T) | →(T,Rest) | Tile과 잔여 분해 |
logical_product(L,T) | →replicate | L을 T번 타일링 |
zipped_divide | →divide + zip | logical_divide + 축 재배열 |
tiled_divide | →flat(divide) | divide 후 flat |
flat_divide | →(T,R) flat | shallow 결과 |
inverse(L) | L-1 | bijective 역 |
right_inverse(L) | 부분역 | L(Lr-1)=id |
(8):(1), B = (4):(2)(4):(2) (0,2,4,6)
(24):(1), T = (4):(1)((4),(6)) : ((1),(4)) : tile-내·tile-외 분리
= replicate L tile. GEMM의 CTA → 전체 grid 복원에 사용.
coalesce( ((2,4),(1,2)) ) = (8):(1)complement( (4):(4), 16 ) = (4):(1)
// CTA tile → thread partition auto cta_tile = Layout<Shape<_128,_64>, Stride<_64,_1>>{}; auto thr_tile = Layout<Shape<_32,_8>, Stride<_8,_1>>{}; auto ztile = zipped_divide( cta_tile, thr_tile); // ((thr_M, thr_N), (rest_M, rest_N)) // 각 thread가 자신의 rest tile 받음
| 법칙 | 식 |
|---|---|
| associativity | (A∘B)∘C = A∘(B∘C) |
| identity | id∘L = L∘id = L |
| divide · product | divide ∘ product = id (조건부) |
| non-commut | A∘B ≠ B∘A (일반) |
print_layout으로 op 전·후를 수식 대신 그림으로 확인. CuTe tutorial의 첫 번째 조언.
T(c)는 engine[layout(c)].
| Engine | 저장 |
|---|---|
gmem_ptr<T> | HBM |
smem_ptr<T> | SMEM |
rmem_ptr<T> | Register |
ArrayEngine | C++ array (stack) |
auto gA = make_tensor( make_gmem_ptr(A_ptr), make_shape(M, K), make_stride(K, _1{})); auto sA = make_tensor( make_smem_ptr(smemA), Layout<Shape<_128,_64>, Stride<_64,_1>>{}); auto rA = make_tensor<half>( Layout<Shape<_8>>{}); // register
// 128×64 CTA tile 한 개 auto tA = local_tile(gA, make_shape(_128{}, _64{}), make_coord(m_blk, k_blk));
// 32×8 thread layout (256 thr/CTA) auto thr_layout = Layout< Shape<_32,_8>, Stride<_8,_1>>{}; auto tAgA = local_partition( tA, thr_layout, threadIdx.x);
// Copy atom으로 gmem → smem copy(copy_atom, thr_copy.partition_S(gA_tile), thr_copy.partition_D(sA)); cp_async_fence(); cp_async_wait<0>(); // TiledMMA로 연산 auto tCrA = thr_mma.partition_A(sA); auto tCrB = thr_mma.partition_B(sB); auto tCrC = thr_mma.partition_C(rC); gemm(tiled_mma, tCrA, tCrB, tCrC);
| 기능 | 기존 2.x | CuTe |
|---|---|---|
| tile 전진 | iter++ | coord 수동 |
| 경계 check | Predicate | TMA / crd2idx |
| thread 분배 | ThreadMap | local_partition |
| shape 변환 | Transform | composition |
make_tensor / make_layoutpartition_A/B/C (TiledMMA slice)partition_S/D (TiledCopy)recast<T> (dtype 재해석, e.g. uint32 → 2×half)flatten, group_modesprint_tensor, print_layout[M+S+B−1 : M+S]를 [M+B−1 : M]과 XOR.
| param | 역할 | 전형 |
|---|---|---|
| B | XOR 영역 bit 수 (= swizzle 분산 배수 log₂) | 3 |
| M | 보존 lower bit (= atomic unit log₂) | 3 (8 elem 묶음) |
| S | XOR 상위 bit 시작 위치 (= row stride log₂) | 3 |
Swizzle<3,3,3> = 128 B line 내 8×8 block 재배치. WGMMA A/B 기본.
원본 offset(6-bit): bit: 5 4 3 2 1 0 ───────────────── S=2, B=2, M=0 → XOR region [3:2] with [1:0] offset bin swizzled 0 000000 000000 (row 0, col 0) 1 000001 000001 (row 0, col 1) 4 000100 000100 (row 1, col 0) 5 000101 000101 (row 1, col 1) 8 001000 001010 (row 2, col 0 → 2) 9 001001 001011 (row 2, col 1 → 3) 12 001100 001110 16 010000 010100 20 010100 010000 (swapped!)
| preset | <B,M,S> | 용도 |
|---|---|---|
| NoSwizzle | <0,0,0> | identity |
| B32_M3_S3 | <3,3,3> | 128B swizzle |
| B64_M4_S3 | <3,4,3> | 128B wide |
| B128_M5_S3 | <3,5,3> | Hopper WGMMA |
preset 이름의 숫자는 보통 swizzle chunk 크기 바이트.
// 기본 SMEM layout 위에 swizzle 덮어쓰기 using SmemLayoutA = decltype( composition( Swizzle<3, 3, 3>{}, Layout<Shape<_64,_64>, Stride<_64,_1>>{})); // 좌표 (i,j) → swizzled offset offset = SmemLayoutA(i, j); // ldmatrix / wgmma descriptor에도 swizzle 인코딩
cp.async.bulk.tensor.{1,2,3,4,5}dtile, im2colTMA PTX 전체는 ↗ V04 §5·§6.
| field | 의미 |
|---|---|
| base_ptr | global addr |
| global_shape | 전체 M, N, ... |
| global_stride | byte stride |
| box_shape | tile 크기 |
| element_stride | 통상 1 |
| swizzle | SMEM swizzle mode |
| dtype | FP16/BF16/FP8/... |
// CuTe Layout이 TMA descriptor로 직변환 auto tma_a = make_tma_copy( SM90_TMA_LOAD{}, gA_tensor, // global Tensor SmemLayoutA{}, // swizzled smem make_shape(_128{}, _64{})); // box // kernel 내에서 partition auto thr_tma = tma_a.get_slice(block_id); auto tAgA = thr_tma.partition_S(gA); auto tAsA = thr_tma.partition_D(sA);
Host :
─ make_tma_copy ──► TMA desc struct
(global tensor map)
─ kernel<<<...>>>(desc, ...)
Device :
thread 0 만 발행:
─ cp.async.bulk.tensor (desc, coord, smem, bar)
─ HW가 shape/stride/swizzle 해석
─ 완료시 bar.arrive(expected_bytes)
모든 thread:
─ bar.wait(parity)
여러 TMA가 같은 bar 사용시 exp_bytes는 합산.
auto tma_a_mc = make_tma_copy(
SM90_TMA_LOAD_MULTICAST{},
gA, SmemLayoutA{},
CtaShape{},
cluster_shape);
multicast는 cluster ≥ 2 + same A tile CTAs.
cuTensorMapEncodeTiled)Kernel<...> ├ CollectiveMainloop (K iter · stages · wgmma) ├ CollectiveEpilogue (acc → epilogue → gD) └ TileScheduler (persistent · streamk · default)
| Policy | 의미 |
|---|---|
KernelTmaWarpSpecialized | 1P + 1C WG |
..Cooperative | 1P + 2C (동일 tile 분할) |
..Pingpong | 1P + 2C (교대 tile) |
KernelMultistage | Ampere cp.async |
sm_90a에서만 TmaWarpSpecialized 선택 가능.
// High-level auto-select using Mainloop = cutlass::gemm::collective::CollectiveBuilder< ArchTag, // SM90 OperatorClass, // TensorOp ElementA, LayoutA, AlignA, ElementB, LayoutB, AlignB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountAuto, KernelScheduleAuto >::CollectiveOp;
using Epilogue = cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, TileShape_MNK, ClusterShape_MNK, EpilogueTileType, ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignC, ElementD, LayoutD, AlignD, EpilogueScheduleAuto, FusionCallbacks // EVT 연결 >::CollectiveOp;
| Scheduler | 동작 |
|---|---|
| Default | row-major tile iter |
| Persistent | launch = SM 수, work-stealing |
| StreamK | K split, 잔여 fixup kernel ↗ V05 §8 |
using GemmKernel = cutlass::gemm::kernel::GemmUniversal< ProblemShape, Mainloop, Epilogue, TileScheduler >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter< GemmKernel>;
Auto tagstage: 0 1 2 ... S-1 (ring) full : F E E E (producer arrive) empty: E F F F (consumer arrive) producer: wait(empty) → TMA → arrive(full) consumer: wait(full) → wgmma → arrive(empty)
| slot | full bar | empty bar |
|---|---|---|
| 0 | P→C signal | C→P signal |
| 1 | P→C | C→P |
| ... | S개 | S개 |
mbarrier 상세는 ↗ V04 §7.
| WG | 역할 | reg budget |
|---|---|---|
| producer (1 WG) | TMA issue · signal | setmaxnreg.dec(40) |
| consumer (1~2 WG) | WGMMA · epilogue | setmaxnreg.inc(232) |
setmaxnreg PTX · producer 레지스터 해방 → consumer 더 받음.
fence.proxy.async.shared::cta 필요 케이스 주의| type | 용도 |
|---|---|
| PipelineTmaAsync | TMA + WGMMA (Hopper) |
| PipelineAsync | cp.async (Ampere) |
| PipelineTransaction | expected_bytes 기반 |
| OrderedSequenceBarrier | Pingpong WG 순서 |
Time →
WG1: [load+mma Mi ][epilogue ]
WG2: [load+mma Mi+1][epilogue]
WG1: [load+mma Mi+2][epi]
WG2: [mma Mi+3][epi]
↑ 교대: 서로 다른 M tile
epilogue와 mainloop가 다른 WG에서 겹침. 큰 epilogue에 유리.
Time →
WG1: [mma Mi, N0..N/2-1 ][epi L]
WG2: [mma Mi, N/2..N-1 ][epi R]
↑ 같은 M tile을 N 축으로 분할
acc 합쳐서 epilogue 동시
같은 M tile 처리. 큰 N에 유리 (N/2씩 담당).
| 항목 | Pingpong | Cooperative |
|---|---|---|
| WG 역할 | 다른 M tile 교대 | 같은 tile 분할 |
| N 분할 | 전체 | N/2 |
| Epilogue | WG 각자 | 같이 + sync |
| SMEM | epilogue buf 2× | 공유 |
| 큰 epilogue | 유리 | overhead |
| 큰 N | 보통 | 유리 |
| 작은 GEMM 수 | 보통 | 유리 |
// WG1 먼저 epi, WG2 나중
osb.wait(tok_epi, wg_id);
epilogue(acc);
osb.arrive(tok_epi);
시작
│
├─ N < 128 ? ──Y──► Pingpong
│ (큰 epi가 relative 비용 ↑)
│
├─ epilogue gemma fuse 복잡?
│ ──Y──► Pingpong (WG별 독립)
│
├─ N ≥ 256 & 작은 batch?
│ ──Y──► Cooperative
│
└─ default ──► Cooperative
(CUTLASS 권장)
Auto가 안전 조합 반환. 수동 지정 시 compile-time assert 확인.
acc (FP32) │ · α ▼ + β · C (load gmem) │ ▼ + bias (load broadcast) │ ▼ activation (ReLU/GELU/Sigmoid) │ ▼ quantize/cast → dtype_D │ ▼ store D (gmem)
| op | 입력 | 비고 |
|---|---|---|
| Linear Combination | α, β | 기본 |
| Bias Add | bias[N] | broadcast |
| ReLU / GELU / Silu | acc | tanh approx 택 |
| Scale → FP8 | scale | per-tensor amax |
| Dequant (INT→FP) | scale | per-channel |
| Residual Add | gmem 입력 | skip-conn |
| Clamp | min/max | bf16 overflow 방지 |
Fusion 일반론은 ↗ V13 §8. 여기선 CUTLASS epilogue 한정.
smem-staged epilogue가 output coalescing과 layout 재배열에 유리.
| 방식 | 장점 | 단점 |
|---|---|---|
| Register only | latency ↓ | layout 유연성 ↓ |
| SMEM staged | coalesce 용이 | SMEM 사용 |
| TMA store | descriptor 재활용 | alignment 제약 |
acc(FP32) ├ · alpha (FP32) ├ + bias (FP32) ├ activation ├ · output_scale (FP32) ← per-tensor └ cast → E4M3 (clamp ±448)
amax history update도 epilogue 한 노드.
| M×N×K | tile M·N·K | stage |
|---|---|---|
| 4096²×4096 FP16 | 256·128·32 | 3 |
| 2048×4096×1024 | 128·128·32 | 4 |
| 512×4096×1024 | 64·128·64 | 5 |
| M≤256 skinny | 64·128·32 + SplitK | 3 |
| K ≤ 64 | 128·128·K | 2 |
| M×N×K | tile | cluster | sched |
|---|---|---|---|
| 8192² FP16 | 128·256·64 | (2,1,1) | Coop |
| 8192² FP8 | 128·256·128 | (2,1,1) | Coop |
| 4096×8192 FP16 | 128·128·64 | (2,2,1) | Coop |
| 512×8192 FP16 | 64·256·64 | (1,1,1) | Pingpong |
| 큰 epilogue | 128·128·64 | (1,1,1) | Pingpong |
| skinny M | 64·128·64 + StreamK | (1,1,1) | default |
| 형상 | 전략 |
|---|---|
| M=N, 둘 다 큼 | 큰 tile + Cooperative |
| M≫N (tall) | bigger M, SplitK 가능 |
| M≪N (wide) | bigger N, Cooperative |
| M 작음, N 큼 | Pingpong + 작은 N tile |
| M, N 둘 다 작음 | StreamK + persistent |
| K 매우 큼 | deep stage (7+) |
| K 작음 | stage ↓ (2~3) |
같은 B가 여러 CTA 재사용 되면 L2 hit 기대. tile 순서 (swizzle scheduler)가 이 hit에 영향.
1. M, N, K 구하기 2. tile 후보 2~3개 고르기 · H100: (128,256,64), (128,128,64), (64,256,64) 3. cluster (1,1) or (2,1) 4. SMEM budget 확인 → stage 조정 5. Coop vs Pingpong 6. autotune 시 이 4~6 조합 시도
iter b = 0..B-1 : D[b] = α·A[b]·B[b] + β·C[b]
MoE 맥락 상세는 ↗ V08 §3·§4.
| 방식 | A 포맷 | launch |
|---|---|---|
| Batched stride | (B,M,K) contiguous | 1 |
| Batched ptr-array | ptr[B] | 1 |
| Grouped | per-group ptr + shape | 1 |
| loop cublas | 각기 | G |
// host problem_sizes = [(M0,N0,K0), ...]; ptr_A = [A_0, A_1, ...]; ptr_B, ptr_C, ptr_D = ...; ld_A, ld_B, ld_C, ld_D = ...; // kernel 내 int g = tile_scheduler.group_id(); auto [M,N,K] = problem_sizes[g]; run_gemm(ptr_A[g], ptr_B[g], ..., M,N,K);
| Batched | Grouped | |
|---|---|---|
| shape | 동일 | 각기 |
| load balance | 자동 | shape 분산에 민감 |
| launch | 1 | 1 |
| 코드 | 단순 | 복잡 |
| MoE 적합 | x | o |
양자화 알고리즘 상세는 ↗ V10 §5·§6·§7.
| granularity | scale 형상 | 예 |
|---|---|---|
| per-tensor | scalar | FP8 (delayed scale) |
| per-channel | (N,) or (M,) | W8A8 INT8 |
| per-token | (M,) runtime | smooth-quant act |
| per-group (G=128) | (K/G, N) | AWQ W4 |
| block-scale | (K/B,) E8M0 | MXFP8 |
SMEM (quant) │ │ ldmatrix → reg (quant) ▼ dequant(scale, zp) ★ ← 여기 (inline, before mma) │ ▼ reg (bf16 / fp16 / fp8) │ ▼ mma(acc, Â, B)
scale fetch는 별도 TMA/cp.async로 prefetch.
recast<uint8_t> 후 bit-unpack kernel 삽입.
packed (uint8): 0xBA
│└── low nibble: A (int4)
└── high nibble: B (int4)
| 항목 | E4M3 | E5M2 |
|---|---|---|
| range | ±448 | ±57344 |
| precision | 높음 | 낮음 |
| 용도 | forward/output | gradient |
| scale 관리 | per-tensor amax | per-tensor amax |
FP8 bit 상세 ↗ V09 §4.
| 노드 | 역할 |
|---|---|
| Load | aux input 로드 (bias/C/scale) |
| Compute | element-wise op (add/mul/act) |
| Store | gmem/aux out 저장 |
Store(D)
│
Compute(cast→D)
│
Compute(scale_D ·)
│
Compute(GELU)
│
Compute(+)
┌─┴─┐
Compute Load(bias)
(·α)
│
acc
// pseudo-syntax using EVT = Sm90EVT< Sm90Compute<cast, D_t, F32>, Sm90EVT< Sm90Compute<mul, F32, F32>, Sm90EVT< Sm90Compute<gelu, F32, F32>, Sm90EVT< Sm90Compute<add, F32, F32>, Sm90ScalarBroadcast<alpha>, Sm90RowBroadcast<bias> > >, Sm90ScalarBroadcast<scale_D> > >;
| 이름 | 의미 |
|---|---|
| LinearCombination | α·AB + β·C |
| LinCombBias | + bias |
| LinCombBiasRelu | + bias → ReLU |
| LinCombGelu | GELU(α·AB+β·C) |
| LinCombBiasEltwise | + bias → act → residual |
| PerRowLinComb | row 별 α·β |
// 입력: FP8 acc, 출력: FP8 + amax_next using EVT_FP8 = Sm90EVT< Sm90Compute<cast, E4M3, F32>, Sm90EVT< Sm90Compute<mul, F32, F32>, // · scale_D Sm90Compute<amax_update, // side-effect F32, F32> >>;
| 용도 | 호출 |
|---|---|
| col-major | make_layout(s, LayoutLeft{}) |
| row-major | make_layout(s, LayoutRight{}) |
| tile 꺼냄 | local_tile(T, tile, coord) |
| thread 분배 | local_partition(T, thr, tid) |
| MMA slice | thr_mma.partition_{A,B,C}(T) |
| Copy slice | thr_copy.partition_{S,D}(T) |
| dtype 재해석 | recast<U>(T) |
| 축 평탄화 | flatten(L) |
| 축 묶기 | group_modes<lo,hi>(L) |
| swizzle 적용 | composition(Swizzle{}, L) |
using Mainloop = CollectiveBuilder< SM90, TensorOp, half, RowMajor, 8, half, ColMajor, 8, float, Shape<_128,_256,_64>, Shape<_2,_1,_1>, StageCountAuto, KernelTmaWarpSpecializedCooperative >::CollectiveOp; using Epilogue = EpilogueBuilder< SM90, TensorOp, Shape<_128,_256,_64>, Shape<_2,_1,_1>, EpilogueTileAuto, float, float, half, RowMajor, 8, half, RowMajor, 8, EpilogueScheduleAuto, LinearCombination >::CollectiveOp; using Kernel = GemmUniversal< Shape<int,int,int>, Mainloop, Epilogue, PersistentScheduler>;
Device M·N·K gemm_universal
Kernel grid TileScheduler
CTA bM·bN·bK mainloop · stage
Warp wM·wN·wK (wg)mma iter
Thread fragA/B/C reg
print_layout / print_tensor