CUDA/COMPILER 18-VOL · CONTENT-FIRST · A4 LANDSCAPE · 18p

Triton 컴파일러 내부

Triton IR · TritonGPU IR · Layout Inference · Pipeline Pass — 사용법이 아닌 내부 구조
Volume V11 / 18
Tier T4 Compiler
선행 V01 · V03 · V04 · V06
용도 Triton 소스 읽기 · 내부 pass 지도

목차

1. Triton 이라는 추상의 정체p.2
2. Compilation pipeline 전체 흐름p.3
3. Frontend (Python → IR)p.4
4. Triton IR (상위)p.5
5. TritonGPU IR (하위)p.6
6. Layout 종류와 의미p.7
7. Layout Inference Passp.8
8. Coalescing Passp.9
9. Pipeline Pass (SW pipelining)p.10
10. Prefetch Passp.11
11. Dot 최적화 Passp.12
12. LLVM 변환 (NVPTX)p.13
13. Autotuner 메커니즘p.14
14. tl.constexpr · specializationp.15
15. 디버깅 / 소스 읽기 가이드p.16
16. Hopper backend 추가 사항p.17
17. Cheat Sheetp.18

범례

핵심 용어 (노란 배경)
표 헤더 / 매우 중요
정의·공식 박스
예시·워크드 박스
빨강주의·실수하기 쉬움
반드시 숙지
(!)니모닉 (권당 ≤5)
다른 권 cross-ref
인과·흐름·lowering
∵∴이유·결론
인쇄 설정 · A4 가로 / 여백 없음 / 배경 그래픽 포함 · Ctrl(⌘)+P
src: triton-lang/triton · MLIR dialect `triton`/`triton_gpu` · Tillet 2019·2023 · 18p

1 What Triton hides ★ warp-level 을 숨긴다 no thread, only tile

정의 Triton = block-level tile programming model on GPU. Thread/warp 개념을 user 에게서 감춘다. 연산 단위는 thread가 아니라 tile tensor.
  • CUDA 의 1급 시민 = threadIdx, warp, shared memory
  • Triton 의 1급 시민 = tl.arange, tl.load, tl.dot 위의 block tensor
  • Thread assignment · shmem stage · bank conflict 회피 = 컴파일러 책임

2 추상 레벨 스펙트럼

레벨사용자 단위할당 주체
CUDA C++threaduser
CUTLASS / CuTethread + tileuser (layout algebra)
Tritonblock tilecompiler
Linalg / XLA HLOtensor opcompiler stack

cf. CuTe 자세히 ↗ V06 §4

3 왜 이 추상인가 ★

productivity × peak ratio ≈ CUTLASS 수준 peak ratio: sustained / peak HW throughput, 대략 0.7~0.9
  1. tile-level 이면 coalescing·bank conflict·TC fragment 결정론적 lowering 가능
  2. SIMT divergence 가 노출되지 않음 → mask가 predicate
  3. Python AST 직접 변환 → debug loop 짧음
  4. DL 커널 ≈ block matmul + reduce → 이 영역에 최적

4 숨기는 것 vs 노출하는 것

hideexpose
thread indexprogram_id(axis)
warp / lanenum_warps hint
shmem 배치constexpr shape
bank conflict
sync barrier
pipeline stagenum_stages hint
TC fragmenttl.dot

5 Triton 이 잘 맞는 도메인

  • GEMM 변종: scaled mm, block-sparse, grouped
  • Fused attention: QK·softmax·AV 한 kernel
  • Elementwise + reduce: RMSNorm, LayerNorm
  • 그 외 elementwise fusion, rotary, moe routing

Triton 사용법·실제 matmul 코드 → ↗ cudalearning v2 p13·14

6 Triton 의 비-목표

  • general-purpose GPU compiler 아님 — DL tile workload 한정
  • dynamic shape 1급 지원 X autotune 재컴파일로 우회
  • host-side graph / scheduling 영역 아님 → ↗ V13 TorchInductor
  • HW 독립 IR 목표가 아니다 — NVIDIA·AMD backend 별도
오해: Triton은 "CUDA 위의 wrapper" 가 아니다. 독자적 MLIR dialect + pass pipeline 을 가진 컴파일러다. PTX 는 마지막 단계 산출물.

1 6 stage overview ★ Py→ttir→ttgir→llir→ptx→sass

Python source (@triton.jit fn)
      │  [1] AST walk, type infer
      ▼
Triton IR              (.ttir)   dialect: triton
      │  [2] convert-triton-to-tritongpu
      ▼
TritonGPU IR           (.ttgir)  dialect: triton_gpu  (+layout)
      │  [3] optimize-passes (coalesce/pipeline/...)
      ▼
TritonGPU IR optimized (.ttgir)
      │  [4] convert-tritongpu-to-llvm
      ▼
LLVM IR                (.llir)   target nvptx64
      │  [5] llc / ptxas frontend → NVPTX codegen
      ▼
PTX text               (.ptx)    sm_80 / sm_90
      │  [6] ptxas / driver JIT
      ▼
SASS binary                      sm-specific

2 각 단계 IR 예시 ★

# Python
@triton.jit
def add(X,Y,Z,N,BS:tl.constexpr):
  pid = tl.program_id(0)
  o = pid*BS + tl.arange(0,BS)
  m = o < N
  x = tl.load(X+o, m); y = tl.load(Y+o, m)
  tl.store(Z+o, x+y, m)
// Triton IR (ttir) — layout 無
tt.func @add(%X, %Y, %Z, %N) {
  %pid = tt.get_program_id x
  %r   = tt.make_range {start=0,end=BS}
  %o   = arith.addi (%pid*BS), %r
  %m   = arith.cmpi slt, %o, %N
  %x   = tt.load %X+%o, %m
  %y   = tt.load %Y+%o, %m
  %s   = arith.addf %x, %y
  tt.store %Z+%o, %s, %m
}
// TritonGPU IR (ttgir) — #blocked layout 주입
#bl = #triton_gpu.blocked<{
  sizePerThread=[4], threadsPerWarp=[32],
  warpsPerCTA=[4], order=[0]}>
%x : tensor<1024xf32, #bl>

3 pass 카테고리

stage주요 pass
ttirinliner, combine, canonicalize
ttir→ttgirconvert-triton-to-tritongpu
ttgircoalesce, layout-infer, remove-layout-conversions
ttgirpipeline, prefetch, accelerate-matmul
ttgir→llirconvert-tritongpu-to-llvm
llir→ptxNVPTX backend (LLVM)
ptx→sassptxas (closed source)

각 pass 세부 → 이후 §7~§11

4 입력·출력 구조

  • 입력: function ptr table, int/float constexpr, pointer dtype
  • 출력: PTX string + kernel metadata (spills, shmem bytes, num_warps)
  • 특정 sm_arch 에 바인딩 → arch 변경 시 재컴파일
주의: ttir 에는 layout 이 아직 없다. ttgir 진입 시 모든 tensor 가 layout encoding을 획득한다 (§5).

1 @triton.jit 의 역할

정의 @triton.jit 은 Python 함수를 AST 로 파싱한 뒤 각 tl.* 호출을 MLIR triton dialect op 으로 번역하는 custom compiler entry.
  • 실행시 Python interpreter 통해 호출되지 않는다 JIT cached
  • 함수 body 를 ast.parse 로 tree 화 → CodeGenerator visitor 가 IR builder 호출
  • caching key = (fn, signature, constexpr 값, sm_arch, num_warps, num_stages)

2 Python 지원 범위

구문지원
if / elsescf.if 로 lowering
for range(...)scf.for (static bound 권장)
whilescf.while (제한적)
assertdev assert
recursion불가
list / dict불가 (constexpr tuple 만)

3 AST 변환 예시

# Python
o = pid * BS + tl.arange(0, BS)
# AST nodes
BinOp(Add, BinOp(Mult, Name(pid), Name(BS)),
           Call(Attribute(tl, arange),
                [Const 0, Name BS]))
// Triton IR emission
%bs  = arith.constant BS : i32
%p0  = arith.muli %pid, %bs
%rng = tt.make_range {start=0, end=BS}
%pb  = tt.splat %p0 : tensor<BSxi32>
%o   = arith.addi %pb, %rng

4 tl.* → IR op 매핑 ★

tl.*MLIR op
tl.program_idtt.get_program_id
tl.arangett.make_range
tl.loadtt.load
tl.storett.store
tl.dottt.dot
tl.reduce / maxtt.reduce
tl.atomic_*tt.atomic_rmw / cas
broadcasttt.broadcast
reshapett.reshape

5 Type inference

  • 입력 argument: Python dtype → MLIR type
    • torch.float16f16
    • torch.float32f32
    • int64 ptr!tt.ptr<f16>
  • 중간 tensor: shape = (BM, BK) constexpr 전개
  • promotion rule: numpy 와 유사, tl.int32 + tl.float16 → f16

6 Specialization key

key = hash(fn_src, signature, constexpr_values, sm, num_warps, num_stages) constexpr 값이 바뀌면 재컴파일. 일반 int arg 는 signature 만 체크.

alignment hint (pow2) 도 specialization 영향 → §14

실수: tensor 크기를 Python int 로 넘기면 매 호출 재컴파일. 반드시 tl.constexpr 로 선언.

1 dialect triton

정의 triton dialect 는 MLIR upstream dialect 위에 추가된 target-agnostic tile-level dialect. Layout 정보가 아직 없다. Ops prefix = tt..
  • MLIR 자체 → ↗ V12 §2
  • 의존 dialect: arith, scf, cf, math
  • 정의 위치: include/triton/Dialect/Triton/IR/TritonOps.td

2 핵심 op 표 ★

op의미
tt.get_program_idgrid axis id
tt.get_num_programsgrid size
tt.make_range[start, end) i32 tensor
tt.splatscalar → tensor
tt.broadcastrank-match broadcast
tt.load / tt.storeptr tensor + mask
tt.dotC += A·B (tile)
tt.reduceaxis reduce + combiner region
tt.atomic_rmwatomic op
tt.func / tt.returnkernel entry

3 Block type system

T = tensor<shape × dtype>  |  !tt.ptr<dtype> shape 은 constexpr (BM, BN, BK...). dtype = f16/bf16/f32/f8/i32...
  • scalar: 그냥 f32 / i32
  • tile: tensor<128×64×f16>
  • ptr tile: tensor<128×64×!tt.ptr<f16>> gather/scatter 전용

4 tt.load / tt.store

%v = tt.load %ptrs, %mask, %other
       {cache = ca, evict = normal}
       : tensor<128x64xf16>
tt.store %p_out, %v, %mask
  • mask tensor 필요 시 OOB skip
  • other tensor = masked-out fill value
  • cache modifier = ca/cg/cs/wb → PTX .ca/.cg 매핑

5 tt.dot 의 의미

의미 tt.dot %a, %b, %c = C += A · B over (M, N, K) tile. Ttir 단계에서는 어떻게 실현될지 미정 (FMA vs TC, layout 미결정).

6 tt.reduce 구조

%r = "tt.reduce"(%x) ({
  ^bb(%a: f32, %b: f32):
    %s = arith.maxnumf %a, %b
    tt.reduce.return %s : f32
}) {axis = 1 : i32}
   : (tensor<128x64xf32>) -> tensor<128xf32>
  • axis 와 combiner region 으로 모든 reduce 일반화 (sum, max, argmax, ...)
  • ttgir 에서 warp shuffle / shmem 으로 lowering

7 Pointer arithmetic

ptr tensor + i32 tensor = ptr tensor
(element-wise, shape 일치)
BroadCast shape 맞추기:
  A_ptrs = A + m[:,None]*sAm + k[None,:]*sAk

실 example ↗ cudalearning v2 p13 §4

핵심: ttir 는 "what"만 표현. "how" (warp 분배, shmem 배치, pipeline) 는 ttgir pass 가 결정. 둘의 경계가 Triton 설계의 core.

1 dialect triton_gpu

정의 triton_gpu dialect = Triton IR 의 GPU-bound 버전. 모든 tensor 타입에 layout encoding 이 attribute 로 부착된다. 이 단계부터 "누가 어떤 원소를 들고 있나"가 확정.
  • prefix = #triton_gpu / op 는 대부분 tt.* 유지 + layout 부여
  • 핵심 op: triton_gpu.convert_layout, triton_gpu.alloc_tensor, triton_gpu.async_copy_global_to_local

2 tensor type w/ layout

tensor<128x64xf16, #blocked>
tensor<128x64xf16, #shared>
tensor<128x64xf16, #mma>
tensor<128x64xf16,
       #dot_operand<{opIdx=0, parent=#mma}>>

layout 은 같은 tensor 를 누가 어떻게 들고 있는지 를 기술한다.

3 convert_layout op ★ 같은 tensor, 다른 소유자

역할 triton_gpu.convert_layout %x : src → dst 은 tensor 의 값은 동일하되 layout 만 변경한다. 실제 lowering 시 shmem 경유한 재분배 로 구현됨.
blocked → shared   (store to shmem)
shared  → dot_op   (ldmatrix 발행)
mma     → blocked  (TC 결과 재배치)

insert / remove 는 Layout Inference pass 가 관리 (§7).

4 async copy op

// Ampere: cp.async
%tok = triton_gpu.async_copy_global_to_local
         %gptr, %smem_slice
         : tensor<…, #blocked> → memdesc<…, #shared>
triton_gpu.async_commit_group
triton_gpu.async_wait {num = 2 : i32}
  • Ampere → cp.async PTX
  • Hopper → TMA bulk tensor copy (§16)
  • commit/wait = stage 카운팅 (§9 Pipeline)

5 ttir → ttgir 전환 예

// ttir
%x : tensor<1024xf32>
%y : tensor<1024xf32>
%z = arith.addf %x, %y : tensor<1024xf32>
// ttgir (default blocked 주입)
#bl = #triton_gpu.blocked<{
  sizePerThread=[4], threadsPerWarp=[32],
  warpsPerCTA=[4], order=[0]}>
%x : tensor<1024xf32, #bl>
%y : tensor<1024xf32, #bl>
%z = arith.addf %x, %y : tensor<1024xf32, #bl>

6 num_warps 와의 관계

warpsPerCTA · 32 = num_warps · 32 = CTA thread count 사용자가 지정한 num_warps 가 곧 layout 의 warp 축 총곱
  • num_warps=4 · 1D → warpsPerCTA=[4]
  • num_warps=8 · 2D tile → warpsPerCTA=[4,2] 등 pass 결정
핵심: ttgir 는 "결정이 내려진 tile IR". ttir 와 달리 같은 알고리즘이라도 hardware 에 따라 전혀 다른 ttgir 이 나온다.

1 Layout 5종 한눈에 ★ Bl · Sh · Mma · Dot · Sl

layout용도생성자변환
Blockedregister tile (load/store/ewise)Coalescing pass→ Shared
Sharedshmem 배치 (swizzle)Pipeline/Prefetch↔ Blocked, → DotOp
MmaTC accumulator fragmentAccelerateMatmul→ Blocked (store)
DotOperandtt.dot A/B operand fragmentAccelerateMatmul← Shared (ldmatrix)
Slicereduce 결과 1DReduce lowering← parent (axis drop)

2 BlockedLayout 분해 ★

element = sizePerThread[d] · threadsPerWarp[d] · warpsPerCTA[d] · CTAsPerCGA[d] 각 축 d 를 4계층으로 분해. order = 내부 메모리 순서 (contig 우선).
tensor<128x64xf16, #blocked<{
  sizePerThread = [4, 4],
  threadsPerWarp = [8, 4],
  warpsPerCTA  = [4, 1],
  order = [1, 0]}>
→ thread 한 개가 (4,4) block 소유
→ warp 하나가 (32,16) tile 커버
→ 4 warps 면 (128, 16) — K축 4 iter

3 SharedLayout · swizzle

swizzle shmem bank conflict 회피용 주소 변환. vec · perPhase · maxPhase 세 param 으로 기술.
#shared = #triton_gpu.shared<{
  vec = 8, perPhase = 2, maxPhase = 4,
  order = [1,0]
}>
  • TC fragment load (ldmatrix) 요구 패턴과 일치 시켜야 conflict-free
  • PMPP shmem bank 기초 → ↗ V01 §8

4 MmaLayout · TC fragment

  • Ampere: #mma<{versionMajor=2, warpsPerCTA=[2,2], instrShape=[16,8]}>
  • Hopper: #mma<{versionMajor=3, warpsPerCTA=[4,1], instrShape=[64,N,16]}> (WGMMA)
  • 각 thread 가 TC 정의한 레지스터 slot 을 소유
  • PTX mma.sync fragment 와 1:1 대응

mma.sync / WGMMA 상세 → ↗ V03 §7 · V04 §7

5 DotOperandLayout

#dot_a = #triton_gpu.dot_op<{
  opIdx = 0, parent = #mma, kWidth = 8
}>
#dot_b = #triton_gpu.dot_op<{
  opIdx = 1, parent = #mma, kWidth = 8
}>
  • opIdx=0 = A operand, 1 = B operand
  • parent = 결과 MmaLayout
  • kWidth = inner dim 단위 (8·16)
  • shared → dot_op 변환 시 ldmatrix 생성

6 SliceLayout

역할 reduce 결과의 한 축이 제거된 1D tensor 를 표현. parent layout + dim 조합으로 "어느 축이 사라졌는지" 를 유지.

예: #slice<{dim=1, parent=#blocked}> → reduce(axis=1) 후.

1 목표

문제 ttir → ttgir 전환 시점에 모든 tensor 는 default blocked layout 으로 초기화된다. 실제로는 tt.dot 는 Mma 가 필요하고, 재사용 경계에는 Shared 가 필요. Inference pass 가 어디서 어떤 layout 을 쓸지를 결정한다.

2 알고리즘 개요 ★

  1. anchor op 에 강한 layout 고정 (tt.dot → Mma, tt.load → Blocked)
  2. bottom-up propagation: operand 가 다음 op 요구와 맞도록 layout 전파
  3. top-down propagation: 소비자 요구가 생산자로 역전파
  4. mismatch 지점에 convert_layout 삽입
  5. cost 기반 선택 (§4): conversion 비용 최소화

3 Anchor op 표

op선호 layout
tt.load (contig)Blocked (coalesced)
tt.store (contig)Blocked (coalesced)
tt.dot resultMma
tt.dot operandDotOperand (parent Mma)
tt.reduce 결과Slice (parent=입력)
async_copy → localShared

4 Cost model 대략

cost(convert src → dst) ≈ nElts · shmem_roundtrip_factor 같은 layout 이면 0, blocked↔blocked 는 작고, blocked→mma 는 shmem 경유로 큼
  • 동일 layout 사이 trivial → DCE 로 제거
  • Mma ↔ Mma (다른 warpsPerCTA) 는 shmem 필수
  • Blocked ↔ Blocked (order 다름) 는 transpose 비용

5 Remove-layout-conversions

after insertion: 많은 convert_layout
→ redundant 제거 pass:
  · convert(convert(x)) → convert(x)
  · convert(x) where src==dst → x
  · hoist convert out of loop (loop-invariant)
  · sink convert past elementwise

결과적으로 핵심 boundary 에만 convert 남음.

6 예시: matmul mainloop

// before
%a = tt.load … : #blocked
%b = tt.load … : #blocked
%c = tt.dot %a, %b : #mma

// after inference
%a_b = tt.load …             : #blocked
%a_s = convert_layout %a_b   : #shared
%a_d = convert_layout %a_s   : #dot_op<0,mma>
%b_b = tt.load …             : #blocked
%b_s = convert_layout %b_b   : #shared
%b_d = convert_layout %b_s   : #dot_op<1,mma>
%c   = tt.dot %a_d, %b_d, %c : #mma

이후 Pipeline pass 가 shared stage 를 loop 로 분할 → §9

1 목표 ★ BlockedLayout 조정

정의 tt.load / tt.store 의 pointer tensor 패턴을 분석해, 결과 BlockedLayout 이 128B-aligned coalesced access 를 만들도록 sizePerThread / order 를 수정한다.

coalescing 기초 규칙 → ↗ V01 §8

2 분석 순서

  1. pointer tensor 의 source 추적 (tt.make_range, tt.splat, arith.addi)
  2. 각 축 stride 상수/변수 분류
  3. contiguous 축 탐지 (stride == 1 인 축)
  4. 해당 축을 layout order[0] 로 설정
  5. vec width 결정 (dtype bits · alignment hint 기반)

3 vec width 결정 수식

vec = min( 128 / dtype_bits, align_hint, contig_len ) f16 → 최대 8, f32 → 최대 4. align_hint = pow2 alignment of ptr.
dtypemax vecPTX
f16 / bf168 elem (128b)ld.global.v4.b32
f324 elemld.global.v4.b32
f816 elemld.global.v4.b32
i32 / u324 elemld.global.v4.b32

4 2D tile 예시

// tile <128,64> f16, contig = axis 1
sizePerThread = [1, 8]   // vec=8 along axis 1
threadsPerWarp = [4, 8]
warpsPerCTA  = [4, 1]
order        = [1, 0]
→ thread 32개가 한 row 의 64 col 커버

5 비-coalesced case 판정

  • stride 가 대형 상수 → non-contig → vec = 1
  • stride 가 data-dependent (gather) → pass skip, scalar load
  • mask 가 stride-wise 패턴 깨면 vec 축소
주의: Coalescing pass 는 pointer 가 compile-time stride 로 표현될 때만 효과. Python 에서 affine 한 인덱싱 유지가 중요.

6 다른 layout 과의 상호작용

  • Coalesce 가 확정된 Blocked → 이후 Layout Inference 가 anchor 로 고정
  • store 단계 Mma → Blocked 변환 시 Coalesced Blocked 가 목표
  • scatter/gather path 는 pass 가 마지막 fallback 으로 vec=1 설정

7 pass 순서 위치

convert-triton-to-tritongpu
→ coalesce-pass   (여기)
→ accelerate-matmul
→ layout-propagation
→ remove-layout-conv
→ pipeline

1 목적 ★ load + compute 겹치기

아이디어 matmul K-loop 의 반복 i 에서 쓰일 데이터의 global→shmem 로딩을 반복 i-(s-1) 에서 미리 발행한다. num_stages = s 만큼의 shmem buffer 를 순환 사용.
  • 지연 숨김: HBM latency ~400 cycle vs compute cycle
  • Ampere: cp.async↗ V03 §6)
  • Hopper: TMA bulk + mbarrier (§16, ↗ V04 §4)

2 Before ★

// sync load + compute, no pipeline
scf.for %k = 0 to K step BK {
  %a_b = tt.load A[%k]   // blocks
  %b_b = tt.load B[%k]
  %a_s = conv %a_b : #shared
  %b_s = conv %b_b : #shared
  %a_d = conv %a_s : #dot_op
  %b_d = conv %b_s : #dot_op
  %c   = tt.dot %a_d, %b_d, %c
}
  • load 가 끝나야 dot 시작 → 완전 직렬
  • Tensor Core 대기 ↑↑

3 After (num_stages=3) ★

// prologue: stage 0,1 선행 발행
async_copy A[0] → smem[0]; async_commit
async_copy B[0] → smem[0]
async_copy A[1] → smem[1]; async_commit
async_copy B[1] → smem[1]

scf.for %k = 0 to K-2*BK step BK {
  // 곧 쓸 stage 는 wait
  async_wait {num = 2}
  // stage t : compute
  %a_d = conv smem[t]   // ldmatrix
  %b_d = conv smem[t]
  %c   = tt.dot %a_d, %b_d, %c

  // stage t+2 : 새 load 발행
  async_copy A[k+2*BK] → smem[(t+2)%3]
  async_copy B[k+2*BK] → smem[(t+2)%3]
  async_commit
}
// epilogue: 남은 stage 소진

4 stage vs resource 표

num_stagesshmem byteslatency 은닉
22·tile 크기약함
33·tile보통 (Ampere 기본)
44·tile강함 (Hopper TMA)
≥5shmem 초과 위험register spill 위험

shmem capacity: A100 192KB, H100 228KB (configurable) ↗ V02 §3

5 pass 내부 절차

  1. loop 내 async_copy/convert_layout(blocked→shared) 식별
  2. multi-buffer 할당: stage 수만큼 shmem tile 복제
  3. prologue · steady · epilogue 3부분 생성
  4. async_commit/async_wait 배치 (wait num = stages-1)
  5. shmem-offset rotation: (t + k) % stages
실수: num_stages 키우면 항상 빠르진 않다. shmem 초과 시 occupancy 1 block 으로 강등되며 오히려 느려진다.

1 Prefetch ≠ Pipeline

비교Pipeline PassPrefetch Pass
계층global → shmemshmem → register
도구cp.async / TMAldmatrix · LDS
단위stage (BK 전체)sub-tile (ldmatrix unit)
목적HBM latency 은닉shmem → reg latency 은닉

2 핵심 아이디어

아이디어 mainloop 각 iteration 에서 TC 에 쓸 다음 sub-k 의 shmem → register 로딩을 현재 dot 와 겹쳐 발행.

3 Before / After 스케치

before (k 단위만 pipeline):
  smem[t] ─ldmatrix→ A_frag
  smem[t] ─ldmatrix→ B_frag
             │
             ▼
           tt.dot

after (+ prefetch on k-sub):
  A_frag[0] already loaded (prologue)
  for kk in 0..BK/mma_k:
    A_frag[kk+1] = ldmatrix smem[t] (kk+1)
    B_frag[kk+1] = ldmatrix smem[t] (kk+1)
    tt.dot A_frag[kk], B_frag[kk], c

4 register 압박

reg usage ≈ 2 · fragment_size A_frag, B_frag 이중 보유. 16x16 f16 mma 기준 thread당 +16 reg.
  • spill 발생 시 local memory 로 내려감 → DRAM 왕복
  • 64K reg / SM (Ampere) · 256 threads CTA 기준 thread당 255 reg 상한

5 활성화 조건

  • tt.dot 가 K-loop 내부에 있어야 함
  • operand layout = DotOperand (parent Mma) 인 경우만
  • shmem tile 이 stage buffer 로 등록되어 있어야 함 (pipeline 이후)

6 trade-off 판단

상황prefetch 효과
BK 큼 (≥64)↑ dot-load 겹침 ↑
reg spill 발생↓ 오히려 악화
HD 큰 attention↑ main bottleneck 해소
elementwise fused미미
핵심: pipeline 이 "stage 간 겹침", prefetch 가 "stage 내부 sub-k 겹침". 둘 다 적용되어야 Tensor Core utilization 이 peak 에 근접.

1 목적

정의 tt.dotMmaLayout 기반으로 정확히 HW Tensor Core instruction 이 떨어지도록 재작성. operand layout 을 DotOperand 로, 결과 layout 을 Mma 로 확정.

2 HW version 선택

smMmaLayout verinstr
sm_70 (V100)v1mma.m8n8k4
sm_75 (Turing)v1mma.m16n8k8
sm_80/86 (Ampere)v2mma.sync.m16n8k16
sm_90 (Hopper)v3wgmma.m64nNk16

dtype별 shape 매트릭스 → ↗ V03 §7 · V04 §7

3 warpsPerCTA 결정

(M_warp, N_warp) : M_warp · N_warp = num_warps BM/BN 비율 · instrShape 반영. 일반 패턴: 비율 1:2 ~ 2:1.
  • BM=128, BN=128, num_warps=4 → (2,2)
  • BM=128, BN=256, num_warps=8 → (2,4)
  • BM=64, BN=256, num_warps=8 → (1,8) (Hopper WGMMA 선호)

4 operand 경로 ★

shmem (A tile)
  │  convert_layout shared → dot_op
  ▼           (lowering: ldmatrix)
A_dot_op  ─┐
           │ tt.dot
B_dot_op  ─┤        → C_mma
           │
shmem (B tile)
  │  convert_layout shared → dot_op
  ▼           (lowering: ldmatrix.trans 가능)
B_dot_op

5 ldmatrix 생성 규칙

  • source = #shared swizzle 과 정합
  • dest = DotOperand fragment (thread-register 매핑)
  • .x1/.x2/.x4 = 한 호출당 8x8 tile 수
  • .trans variant = B operand K-major 일 때

ldmatrix 개요 → ↗ V03 §8

6 FP8 / low-bit 경로

  • Hopper: wgmma FP8 e4m3/e5m2 지원
  • Blackwell: FP4/FP6 (2nd-gen TE) — backend 확장 영역
  • scaled dot: tt.dot_scaled 로 표현 (block-scaling)

low-bit 정밀도 → ↗ V09 · V10

7 pass 위치

coalesce
→ accelerate-matmul  (여기)
→ layout-propagation
→ pipeline
→ prefetch

1 역할

정의 ttgir 의 tensor-level op 를 per-thread LLVM IR + NVPTX intrinsic 로 lowering. 이 단계에서 tensor 는 thread-local register tuple 로 분해된다.

MLIR → LLVM 일반 이론 → ↗ V12 §10

2 Lowering 매핑 표 ★

ttgirLLVM / NVVM
tt.load (vec)ld.global.v4.b32 (inline asm)
tt.storest.global.v4.b32
tt.dot (mma)llvm.nvvm.mma.m16n8k16.*
tt.dot (wgmma)inline PTX wgmma.mma_async
convert shared→dotinline PTX ldmatrix
async_copyinline PTX cp.async.cg
async_waitcp.async.wait_group
tt.reducellvm.nvvm.shfl.sync.* + shmem
tt.atomic_rmwLLVM atomicrmw / inline PTX
program_idllvm.nvvm.read.ptx.sreg.ctaid.*

3 tensor → register tuple

tensor<128x64xf16, #blocked>
 layout: thread 당 (4,4) = 16 elements
 → 각 thread 가 16개 f16 reg 소유
 → LLVM: vector<16xf16> 혹은
          !llvm.struct<(f16, f16, ..., f16)>

op 는 thread 당 element 수 만큼 펼쳐진다 (unroll).

4 Address space 매핑

IRNVPTX AS
global ptraddrspace(1)
shared (shmem)addrspace(3)
constantaddrspace(4)
local (stack/spill)addrspace(5)
genericaddrspace(0)

NVPTX AS 상세 → ↗ V12 §10

5 Inline PTX 사용

// 실제 생성되는 LLVM IR snippet
call void asm sideeffect
  "cp.async.cg.shared.global [$0], [$1], 16;",
  "r,l"(i32 %smem_off, i8* %gptr)
  • LLVM intrinsic 이 없는 PTX 는 inline asm 으로 생성
  • register constraint: r(32b reg), l(64b reg), f(float reg)
  • clobber / sideeffect 지정 필수

6 이후 LLVM → PTX

LLVM IR (nvptx64 target)
 → LLVM opt: instcombine, mem2reg, ...
 → NVPTX backend codegen
 → PTX text file
 → ptxas (external) → SASS

SASS / ptxas 는 NVIDIA closed source → ↗ V04 §12

핵심: convert-tritongpu-to-llvm 이 Triton 고유 pass 중 가장 복잡. 각 layout 조합별 lowering template 이 존재.

1 @triton.autotune 구조

@triton.autotune(
  configs=[
    triton.Config({'BM':128,'BN':256,
                   'BK':64,'GM':8},
                   num_warps=8, num_stages=3),
    triton.Config({'BM':64, 'BN':64,
                   'BK':32,'GM':8},
                   num_warps=4, num_stages=2),
    …
  ],
  key=['M','N','K'],
  prune_configs_by={'perf_model': fn,
                    'early_config_prune': fn},
)
@triton.jit
def mm(…): …

2 Config space 차원

  • tile 크기: BM, BN, BK
  • swizzle: GROUP_M (§13 cudalearning v2)
  • parallel: num_warps, num_stages
  • dtype 별 지원 조합 제한

3 Caching key ★ k = (M,N,K,dtype,sm,arch)

cache_key = hash( key_arg_values, dtype_sig, sm_arch ) key 에 지정한 runtime arg 값 별로 독립 config 탐색.
  • 같은 (M,N,K) 호출 → 재탐색 없이 best config 사용
  • 새로운 값 조합 → 모든 config 벤치마크 후 저장
  • persistent cache: disk (~/.triton/cache)

4 Benchmark loop 내부

  1. pruner 가 config 후보 축소 (smem 초과 / reg 초과 제거)
  2. 각 남은 config 에 대해 compile → launch × N trial
  3. median elapsed time 측정 (cuda event)
  4. best config 선택 → cache
  5. 이후 호출은 fast path

5 Prune 방법 비교

방식의미
staticshmem / reg 추정 → 상한 초과 config 제거
perf_model사용자 제공 cost 함수 top-k
restore_valuein-place 연산 시 원복용
warmup/rep벤치마크 반복 수

6 Heuristic vs Autotune

  • @triton.heuristics = deterministic 규칙 (shape→config)
  • @triton.autotune = 실측 기반 탐색
  • 실무: heuristics 로 1차 후보 → autotune 최종
실수: key 에 포함되지 않은 shape 축에 대해서는 config 가 재사용된다. dynamic batch 일 때 key=['M','N','K'] 만 쓰면 잘못된 config 가 선택될 수 있음.

1 tl.constexpr 의 의미

정의 tl.constexpr 로 표시된 argument 는 kernel signature 의 type 일부 가 된다. 값이 바뀌면 다른 kernel 로 간주되어 재컴파일된다. C++ template parameter 와 유사.
  • tile 크기 (BM/BN/BK), grid 구조 (GROUP_M), dtype 선택, bool 플래그에 사용
  • 내부적으로 MLIR constant + type 로 embed

2 비교: 일반 int arg

종류재컴파일?특성
일반 intN (sig 동일 시)runtime value
tl.constexprY (값 달라지면)compile-time
pow2 hint부분alignment specialization

3 Specialization 의 힘 ★

  • 상수 전파: BM=128 이면 for loop bound 가 상수 → 완전 unroll
  • shape 확정: tensor type 완전 결정 → layout inference 최적
  • branch elision: if CAUSAL 이 사라지며 dead code 제거
  • vec/align: stride % 16 == 0 이면 vec=8 선택 가능

4 실제 예시

@triton.jit
def norm(X, Y, N,
         BS: tl.constexpr,
         USE_BF16: tl.constexpr):
  pid = tl.program_id(0)
  o = pid*BS + tl.arange(0, BS)
  x = tl.load(X+o, o < N)
  if USE_BF16:
    x = x.to(tl.bfloat16)
  ...

USE_BF16=True/False → 서로 다른 kernel 두 버전 컴파일.

5 Divisibility / alignment hint

divisibility(ptr) = 2^k   ⇒   vec width ≤ 2^k / dtype_bytes Triton 이 ptr alignment 를 추적해 load vec 를 자동 결정.
  • torch tensor → triton 자동으로 16B align 추정
  • annotate 로 강제 가능: tl.load(ptr, ..., eviction_policy=..., cache_modifier=...)

6 template-like 패턴

kernel(…, DTYPE: tl.constexpr)
  → 같은 kernel 소스가
    f16, bf16, f32 각각 독립 cache
  → 런타임 dispatch:
    if   dt=='f16':  kernel[grid](…, f16)
    elif dt=='bf16': kernel[grid](…, bf16)
실수: constexpr 값이 계속 바뀌면 cache thrashing. 제한된 집합 으로 제약해야 한다.

1 환경변수 표 ★

env효과
TRITON_CACHE_DIRJIT cache 위치 (default ~/.triton/cache)
MLIR_ENABLE_DUMP=1각 pass 전후 IR stdout dump
TRITON_ENABLE_LLIR_DUMP=1LLVM IR 덤프
TRITON_INTERPRET=1kernel 을 CPU 에서 Python 실행 (debug)
TRITON_DEBUG=1compile 단계 세부 로그
TRITON_PRINT_AUTOTUNING=1best config 출력

2 cache 파일 종류

~/.triton/cache/<hash>/
  ├── kernel.ttir     Triton IR
  ├── kernel.ttgir    TritonGPU IR
  ├── kernel.llir     LLVM IR
  ├── kernel.ptx      PTX text
  ├── kernel.cubin    SASS (arch-specific)
  └── kernel.json     metadata (smem, reg, shared, num_warps)

3 소스 트리 지도

경로내용
include/triton/Dialect/Triton/ttir dialect 정의 (ODS)
include/triton/Dialect/TritonGPU/ttgir dialect + layout attr
lib/Dialect/TritonGPU/Transforms/핵심 pass (Coalesce, Pipeline, …)
lib/Conversion/TritonToTritonGPU/ttir→ttgir
lib/Conversion/TritonGPUToLLVM/ttgir→llir
python/triton/compiler/frontend · cache · compile
python/triton/runtime/autotune · launch
python/triton/language/tl.* API

4 읽기 순서 추천

  1. python/triton/language/core.py — tl.* 정의
  2. compiler/code_generator.py — AST → IR
  3. ODS 파일 (TritonOps.td, TritonGPUOps.td)
  4. TritonGPU Transforms (Coalesce.cpp → Pipeline.cpp)
  5. TritonGPUToLLVM (LoadStoreOpToLLVM.cpp → DotOpToLLVM.cpp)

5 IR 덤프 읽기 팁

  • ttir 은 layout 無 → 알고리즘 구조 파악용
  • ttgir 은 layout 확정 → 실제 shmem/TC 경로 파악
  • // -----// IR Dump Before / After … 구분선 기준으로 pass 영향 비교
  • diff tool 로 before/after 차이만 보기: diff before.ir after.ir

6 디버그 모드

  • TRITON_INTERPRET=1 → Python 으로 tile 연산 에뮬
  • kernel 내부 tl.device_print 허용
  • CPU → GPU 이전 수식 검증용 (속도는 매우 느림)
  • pdb breakpoint 가능
주의: MLIR_ENABLE_DUMP=1 은 모든 pass 전후 IR 를 찍는다. kernel 하나에 수천줄 출력 가능 → 파일 redirect 필수.

1 sm_80 vs sm_90 요점 ★

요소sm_80sm_90
global→shmemcp.asyncTMA (bulk tensor)
syncbar.syncmbarrier (async)
TC instrmma.sync m16n8k16wgmma.mma_async m64nNk16
TC 단위warp (32)warpgroup (128)
cluster없음thread block cluster
shmem 접근localdistributed shared

HW 상세 → ↗ V02 §9 · V04 §4·§7

2 TMA 사용 조건

활성화 조건
  • sm_90 target
  • source tensor 가 contiguous 2D/3D tile
  • tile 크기가 TMA box dim 제약 충족 (16B 배수)
  • pointer align 이 최소 16B
  • 조건 안 맞으면 cp.async fallback
  • TMA descriptor = kernel launch 시 host 에서 구성

3 WGMMA 사용 조건

  • num_warps 가 4 의 배수 (warpgroup 단위)
  • A operand 가 shmem 또는 register
  • B operand 가 shmem 필수
  • accumulator 는 register (f32)
  • shmem swizzle 이 WGMMA 모드와 일치 (128B/64B)

WGMMA shape 매트릭스 → ↗ V04 §8

4 num_warps · num_stages 의미 변화

paramsm_80sm_90
num_warpsCTA thread 수 / 32warpgroup 수 × 4
num_stages 2prologue 얇음의미 약함 (TMA 깊은 파이프)
num_stages 4~6shmem 초과 위험TMA 권장 구간

5 Warp specialization

  • Hopper 에서 Triton 은 producer/consumer 분리 pattern 코드 생성
  • producer warpgroup: TMA issue + mbarrier.arrive
  • consumer warpgroup: mbarrier.wait + WGMMA
  • setmaxnreg 로 register 재분배 (inc/dec)

패턴 상세 → ↗ V04 §9

실수: sm_80 용으로 튜닝한 config 를 sm_90 에 그대로 쓰면 TMA/WGMMA 가 비활성화될 수 있음. autotune key 에 sm 반영 필수.

1 Stage × Pass 매트릭스 ★

stagepassuser 영향
frontendAST → ttirconstexpr 설계
ttir→ttgirconvert-to-tritongpunum_warps 선택
ttgircoalesceptr affine 패턴 유지
ttgiraccelerate-matmultl.dot 명시
ttgirlayout-propagateconvert 수 ↓
ttgirpipelinenum_stages 선택
ttgirprefetchreg 압박 주의
ttgir→llirconvert-to-llvm
llir→ptxNVPTX backendsm_arch 지정

2 Layout 선택 기준

상황layout
global load/storeBlocked (order=contig)
shmem stage bufferShared (swizzle)
tt.dot operandDotOperand(parent=Mma)
tt.dot resultMma
reduce 결과Slice(parent)

3 Debug flag 7종

  • MLIR_ENABLE_DUMP=1
  • TRITON_ENABLE_LLIR_DUMP=1
  • TRITON_INTERPRET=1
  • TRITON_PRINT_AUTOTUNING=1
  • TRITON_DEBUG=1
  • TRITON_CACHE_DIR=./tcache
  • TRITON_ALWAYS_COMPILE=1

4 xref 맵

  • MLIR / dialect 이론 → ↗ V12
  • TorchInductor → Triton → ↗ V13
  • XLA / TVM 비교 → ↗ V14
  • Triton 실전 코드 → ↗ cudalearning v2 p13·14

5 단계별 파일 확장자

ext단계
.ttirTriton IR (dialect triton)
.ttgirTritonGPU IR (+layout)
.llirLLVM IR (nvptx64)
.ptxPTX assembly
.cubinSASS binary
.jsonmetadata

6 튜닝 knob 우선순위

  1. tile (BM, BN, BK) — 가장 큰 영향
  2. num_warps — warp layout 결정
  3. num_stages — pipeline 깊이
  4. GROUP_M (swizzle) — L2 hit
  5. dtype 선택 (f16 / bf16 / f8) — 별개 구조
6 stages: Py·ttir·ttgir·llir·ptx·sass (AST → Triton → GPU+layout → LLVM → PTX → SASS)
최종 원칙: Triton 튜닝 = pass 동작을 조건에 맞춰 주는 것. knob 은 pass 의 입력 hint.