threadIdx, warp, shared memorytl.arange, tl.load, tl.dot 위의 block tensor| 레벨 | 사용자 단위 | 할당 주체 |
|---|---|---|
| CUDA C++ | thread | user |
| CUTLASS / CuTe | thread + tile | user (layout algebra) |
| Triton | block tile | compiler |
| Linalg / XLA HLO | tensor op | compiler stack |
cf. CuTe 자세히 ↗ V06 §4
| hide | expose |
|---|---|
| thread index | program_id(axis) |
| warp / lane | num_warps hint |
| shmem 배치 | constexpr shape |
| bank conflict | — |
| sync barrier | — |
| pipeline stage | num_stages hint |
| TC fragment | tl.dot |
Triton 사용법·실제 matmul 코드 → ↗ cudalearning v2 p13·14
Python source (@triton.jit fn)
│ [1] AST walk, type infer
▼
Triton IR (.ttir) dialect: triton
│ [2] convert-triton-to-tritongpu
▼
TritonGPU IR (.ttgir) dialect: triton_gpu (+layout)
│ [3] optimize-passes (coalesce/pipeline/...)
▼
TritonGPU IR optimized (.ttgir)
│ [4] convert-tritongpu-to-llvm
▼
LLVM IR (.llir) target nvptx64
│ [5] llc / ptxas frontend → NVPTX codegen
▼
PTX text (.ptx) sm_80 / sm_90
│ [6] ptxas / driver JIT
▼
SASS binary sm-specific
# Python @triton.jit def add(X,Y,Z,N,BS:tl.constexpr): pid = tl.program_id(0) o = pid*BS + tl.arange(0,BS) m = o < N x = tl.load(X+o, m); y = tl.load(Y+o, m) tl.store(Z+o, x+y, m)
// Triton IR (ttir) — layout 無
tt.func @add(%X, %Y, %Z, %N) {
%pid = tt.get_program_id x
%r = tt.make_range {start=0,end=BS}
%o = arith.addi (%pid*BS), %r
%m = arith.cmpi slt, %o, %N
%x = tt.load %X+%o, %m
%y = tt.load %Y+%o, %m
%s = arith.addf %x, %y
tt.store %Z+%o, %s, %m
}
// TritonGPU IR (ttgir) — #blocked layout 주입
#bl = #triton_gpu.blocked<{
sizePerThread=[4], threadsPerWarp=[32],
warpsPerCTA=[4], order=[0]}>
%x : tensor<1024xf32, #bl>
| stage | 주요 pass |
|---|---|
| ttir | inliner, combine, canonicalize |
| ttir→ttgir | convert-triton-to-tritongpu |
| ttgir | coalesce, layout-infer, remove-layout-conversions |
| ttgir | pipeline, prefetch, accelerate-matmul |
| ttgir→llir | convert-tritongpu-to-llvm |
| llir→ptx | NVPTX backend (LLVM) |
| ptx→sass | ptxas (closed source) |
각 pass 세부 → 이후 §7~§11
sm_arch 에 바인딩 → arch 변경 시 재컴파일@triton.jit 의 역할tl.* 호출을 MLIR triton dialect op 으로 번역하는
custom compiler entry.
ast.parse 로 tree 화 → CodeGenerator visitor 가 IR builder 호출| 구문 | 지원 |
|---|---|
| if / else | scf.if 로 lowering |
| for range(...) | scf.for (static bound 권장) |
| while | scf.while (제한적) |
| assert | dev assert |
| recursion | 불가 |
| list / dict | 불가 (constexpr tuple 만) |
# Python o = pid * BS + tl.arange(0, BS)
# AST nodes
BinOp(Add, BinOp(Mult, Name(pid), Name(BS)),
Call(Attribute(tl, arange),
[Const 0, Name BS]))
// Triton IR emission
%bs = arith.constant BS : i32
%p0 = arith.muli %pid, %bs
%rng = tt.make_range {start=0, end=BS}
%pb = tt.splat %p0 : tensor<BSxi32>
%o = arith.addi %pb, %rng
tl.* | MLIR op |
|---|---|
tl.program_id | tt.get_program_id |
tl.arange | tt.make_range |
tl.load | tt.load |
tl.store | tt.store |
tl.dot | tt.dot |
tl.reduce / max | tt.reduce |
tl.atomic_* | tt.atomic_rmw / cas |
| broadcast | tt.broadcast |
| reshape | tt.reshape |
torch.float16 → f16torch.float32 → f32int64 ptr → !tt.ptr<f16>tl.int32 + tl.float16 → f16alignment hint (pow2) 도 specialization 영향 → §14
tl.constexpr 로 선언.
tritontt..
arith, scf, cf, mathinclude/triton/Dialect/Triton/IR/TritonOps.td| op | 의미 |
|---|---|
tt.get_program_id | grid axis id |
tt.get_num_programs | grid size |
tt.make_range | [start, end) i32 tensor |
tt.splat | scalar → tensor |
tt.broadcast | rank-match broadcast |
tt.load / tt.store | ptr tensor + mask |
tt.dot | C += A·B (tile) |
tt.reduce | axis reduce + combiner region |
tt.atomic_rmw | atomic op |
tt.func / tt.return | kernel entry |
f32 / i32tensor<128×64×f16>tensor<128×64×!tt.ptr<f16>> gather/scatter 전용%v = tt.load %ptrs, %mask, %other
{cache = ca, evict = normal}
: tensor<128x64xf16>
tt.store %p_out, %v, %mask
ca/cg/cs/wb → PTX .ca/.cg 매핑tt.dot %a, %b, %c = C += A · B
over (M, N, K) tile. Ttir 단계에서는 어떻게 실현될지 미정
(FMA vs TC, layout 미결정).
%r = "tt.reduce"(%x) ({
^bb(%a: f32, %b: f32):
%s = arith.maxnumf %a, %b
tt.reduce.return %s : f32
}) {axis = 1 : i32}
: (tensor<128x64xf32>) -> tensor<128xf32>
ptr tensor + i32 tensor = ptr tensor (element-wise, shape 일치) BroadCast shape 맞추기: A_ptrs = A + m[:,None]*sAm + k[None,:]*sAk
실 example ↗ cudalearning v2 p13 §4
triton_gpu ★#triton_gpu / op 는 대부분 tt.* 유지 + layout 부여triton_gpu.convert_layout, triton_gpu.alloc_tensor, triton_gpu.async_copy_global_to_localtensor<128x64xf16, #blocked> tensor<128x64xf16, #shared> tensor<128x64xf16, #mma> tensor<128x64xf16, #dot_operand<{opIdx=0, parent=#mma}>>
layout 은 같은 tensor 를 누가 어떻게 들고 있는지 를 기술한다.
triton_gpu.convert_layout %x : src → dst
은 tensor 의 값은 동일하되 layout 만 변경한다.
실제 lowering 시 shmem 경유한 재분배 로 구현됨.
blocked → shared (store to shmem) shared → dot_op (ldmatrix 발행) mma → blocked (TC 결과 재배치)
insert / remove 는 Layout Inference pass 가 관리 (§7).
// Ampere: cp.async
%tok = triton_gpu.async_copy_global_to_local
%gptr, %smem_slice
: tensor<…, #blocked> → memdesc<…, #shared>
triton_gpu.async_commit_group
triton_gpu.async_wait {num = 2 : i32}
cp.async PTX// ttir
%x : tensor<1024xf32>
%y : tensor<1024xf32>
%z = arith.addf %x, %y : tensor<1024xf32>
// ttgir (default blocked 주입)
#bl = #triton_gpu.blocked<{
sizePerThread=[4], threadsPerWarp=[32],
warpsPerCTA=[4], order=[0]}>
%x : tensor<1024xf32, #bl>
%y : tensor<1024xf32, #bl>
%z = arith.addf %x, %y : tensor<1024xf32, #bl>
| layout | 용도 | 생성자 | 변환 |
|---|---|---|---|
| Blocked | register tile (load/store/ewise) | Coalescing pass | → Shared |
| Shared | shmem 배치 (swizzle) | Pipeline/Prefetch | ↔ Blocked, → DotOp |
| Mma | TC accumulator fragment | AccelerateMatmul | → Blocked (store) |
| DotOperand | tt.dot A/B operand fragment | AccelerateMatmul | ← Shared (ldmatrix) |
| Slice | reduce 결과 1D | Reduce lowering | ← parent (axis drop) |
tensor<128x64xf16, #blocked<{
sizePerThread = [4, 4],
threadsPerWarp = [8, 4],
warpsPerCTA = [4, 1],
order = [1, 0]}>
→ thread 한 개가 (4,4) block 소유
→ warp 하나가 (32,16) tile 커버
→ 4 warps 면 (128, 16) — K축 4 iter
#shared = #triton_gpu.shared<{
vec = 8, perPhase = 2, maxPhase = 4,
order = [1,0]
}>
#mma<{versionMajor=2, warpsPerCTA=[2,2], instrShape=[16,8]}>#mma<{versionMajor=3, warpsPerCTA=[4,1], instrShape=[64,N,16]}> (WGMMA)mma.sync fragment 와 1:1 대응mma.sync / WGMMA 상세 → ↗ V03 §7 · V04 §7
#dot_a = #triton_gpu.dot_op<{
opIdx = 0, parent = #mma, kWidth = 8
}>
#dot_b = #triton_gpu.dot_op<{
opIdx = 1, parent = #mma, kWidth = 8
}>
ldmatrix 생성예: #slice<{dim=1, parent=#blocked}> → reduce(axis=1) 후.
tt.dot 는 Mma 가 필요하고, 재사용 경계에는 Shared 가 필요.
Inference pass 가 어디서 어떤 layout 을 쓸지를 결정한다.
convert_layout 삽입| op | 선호 layout |
|---|---|
tt.load (contig) | Blocked (coalesced) |
tt.store (contig) | Blocked (coalesced) |
tt.dot result | Mma |
tt.dot operand | DotOperand (parent Mma) |
tt.reduce 결과 | Slice (parent=입력) |
async_copy → local | Shared |
after insertion: 많은 convert_layout → redundant 제거 pass: · convert(convert(x)) → convert(x) · convert(x) where src==dst → x · hoist convert out of loop (loop-invariant) · sink convert past elementwise
결과적으로 핵심 boundary 에만 convert 남음.
// before %a = tt.load … : #blocked %b = tt.load … : #blocked %c = tt.dot %a, %b : #mma // after inference %a_b = tt.load … : #blocked %a_s = convert_layout %a_b : #shared %a_d = convert_layout %a_s : #dot_op<0,mma> %b_b = tt.load … : #blocked %b_s = convert_layout %b_b : #shared %b_d = convert_layout %b_s : #dot_op<1,mma> %c = tt.dot %a_d, %b_d, %c : #mma
이후 Pipeline pass 가 shared stage 를 loop 로 분할 → §9
tt.load / tt.store 의 pointer tensor 패턴을 분석해,
결과 BlockedLayout 이 128B-aligned coalesced access 를
만들도록 sizePerThread / order 를 수정한다.
coalescing 기초 규칙 → ↗ V01 §8
tt.make_range, tt.splat, arith.addi)order[0] 로 설정| dtype | max vec | PTX |
|---|---|---|
| f16 / bf16 | 8 elem (128b) | ld.global.v4.b32 |
| f32 | 4 elem | ld.global.v4.b32 |
| f8 | 16 elem | ld.global.v4.b32 |
| i32 / u32 | 4 elem | ld.global.v4.b32 |
// tile <128,64> f16, contig = axis 1 sizePerThread = [1, 8] // vec=8 along axis 1 threadsPerWarp = [4, 8] warpsPerCTA = [4, 1] order = [1, 0] → thread 32개가 한 row 의 64 col 커버
convert-triton-to-tritongpu → coalesce-pass (여기) → accelerate-matmul → layout-propagation → remove-layout-conv → pipeline
cp.async (§↗ V03 §6)// sync load + compute, no pipeline scf.for %k = 0 to K step BK { %a_b = tt.load A[%k] // blocks %b_b = tt.load B[%k] %a_s = conv %a_b : #shared %b_s = conv %b_b : #shared %a_d = conv %a_s : #dot_op %b_d = conv %b_s : #dot_op %c = tt.dot %a_d, %b_d, %c }
// prologue: stage 0,1 선행 발행 async_copy A[0] → smem[0]; async_commit async_copy B[0] → smem[0] async_copy A[1] → smem[1]; async_commit async_copy B[1] → smem[1] scf.for %k = 0 to K-2*BK step BK { // 곧 쓸 stage 는 wait async_wait {num = 2} // stage t : compute %a_d = conv smem[t] // ldmatrix %b_d = conv smem[t] %c = tt.dot %a_d, %b_d, %c // stage t+2 : 새 load 발행 async_copy A[k+2*BK] → smem[(t+2)%3] async_copy B[k+2*BK] → smem[(t+2)%3] async_commit } // epilogue: 남은 stage 소진
| num_stages | shmem bytes | latency 은닉 |
|---|---|---|
| 2 | 2·tile 크기 | 약함 |
| 3 | 3·tile | 보통 (Ampere 기본) |
| 4 | 4·tile | 강함 (Hopper TMA) |
| ≥5 | shmem 초과 위험 | register spill 위험 |
shmem capacity: A100 192KB, H100 228KB (configurable) ↗ V02 §3
async_copy/convert_layout(blocked→shared) 식별(t + k) % stagesnum_stages 키우면 항상 빠르진 않다.
shmem 초과 시 occupancy 1 block 으로 강등되며 오히려 느려진다.
| 비교 | Pipeline Pass | Prefetch Pass |
|---|---|---|
| 계층 | global → shmem | shmem → register |
| 도구 | cp.async / TMA | ldmatrix · LDS |
| 단위 | stage (BK 전체) | sub-tile (ldmatrix unit) |
| 목적 | HBM latency 은닉 | shmem → reg latency 은닉 |
before (k 단위만 pipeline):
smem[t] ─ldmatrix→ A_frag
smem[t] ─ldmatrix→ B_frag
│
▼
tt.dot
after (+ prefetch on k-sub):
A_frag[0] already loaded (prologue)
for kk in 0..BK/mma_k:
A_frag[kk+1] = ldmatrix smem[t] (kk+1)
B_frag[kk+1] = ldmatrix smem[t] (kk+1)
tt.dot A_frag[kk], B_frag[kk], c
tt.dot 가 K-loop 내부에 있어야 함| 상황 | prefetch 효과 |
|---|---|
| BK 큼 (≥64) | ↑ dot-load 겹침 ↑ |
| reg spill 발생 | ↓ 오히려 악화 |
| HD 큰 attention | ↑ main bottleneck 해소 |
| elementwise fused | 미미 |
tt.dot 를 MmaLayout 기반으로 정확히
HW Tensor Core instruction 이 떨어지도록 재작성.
operand layout 을 DotOperand 로, 결과 layout 을 Mma 로 확정.
| sm | MmaLayout ver | instr |
|---|---|---|
| sm_70 (V100) | v1 | mma.m8n8k4 |
| sm_75 (Turing) | v1 | mma.m16n8k8 |
| sm_80/86 (Ampere) | v2 | mma.sync.m16n8k16 |
| sm_90 (Hopper) | v3 | wgmma.m64nNk16 |
dtype별 shape 매트릭스 → ↗ V03 §7 · V04 §7
shmem (A tile)
│ convert_layout shared → dot_op
▼ (lowering: ldmatrix)
A_dot_op ─┐
│ tt.dot
B_dot_op ─┤ → C_mma
│
shmem (B tile)
│ convert_layout shared → dot_op
▼ (lowering: ldmatrix.trans 가능)
B_dot_op
#shared swizzle 과 정합.x1/.x2/.x4 = 한 호출당 8x8 tile 수.trans variant = B operand K-major 일 때ldmatrix 개요 → ↗ V03 §8
wgmma FP8 e4m3/e5m2 지원tt.dot_scaled 로 표현 (block-scaling)low-bit 정밀도 → ↗ V09 · V10
coalesce → accelerate-matmul (여기) → layout-propagation → pipeline → prefetch
MLIR → LLVM 일반 이론 → ↗ V12 §10
| ttgir | LLVM / NVVM |
|---|---|
tt.load (vec) | ld.global.v4.b32 (inline asm) |
tt.store | st.global.v4.b32 |
tt.dot (mma) | llvm.nvvm.mma.m16n8k16.* |
tt.dot (wgmma) | inline PTX wgmma.mma_async |
| convert shared→dot | inline PTX ldmatrix |
| async_copy | inline PTX cp.async.cg |
| async_wait | cp.async.wait_group |
tt.reduce | llvm.nvvm.shfl.sync.* + shmem |
tt.atomic_rmw | LLVM atomicrmw / inline PTX |
| program_id | llvm.nvvm.read.ptx.sreg.ctaid.* |
tensor<128x64xf16, #blocked>
layout: thread 당 (4,4) = 16 elements
→ 각 thread 가 16개 f16 reg 소유
→ LLVM: vector<16xf16> 혹은
!llvm.struct<(f16, f16, ..., f16)>
op 는 thread 당 element 수 만큼 펼쳐진다 (unroll).
| IR | NVPTX AS |
|---|---|
| global ptr | addrspace(1) |
| shared (shmem) | addrspace(3) |
| constant | addrspace(4) |
| local (stack/spill) | addrspace(5) |
| generic | addrspace(0) |
NVPTX AS 상세 → ↗ V12 §10
// 실제 생성되는 LLVM IR snippet
call void asm sideeffect
"cp.async.cg.shared.global [$0], [$1], 16;",
"r,l"(i32 %smem_off, i8* %gptr)
r(32b reg), l(64b reg), f(float reg)LLVM IR (nvptx64 target) → LLVM opt: instcombine, mem2reg, ... → NVPTX backend codegen → PTX text file → ptxas (external) → SASS
SASS / ptxas 는 NVIDIA closed source → ↗ V04 §12
@triton.autotune 구조@triton.autotune( configs=[ triton.Config({'BM':128,'BN':256, 'BK':64,'GM':8}, num_warps=8, num_stages=3), triton.Config({'BM':64, 'BN':64, 'BK':32,'GM':8}, num_warps=4, num_stages=2), … ], key=['M','N','K'], prune_configs_by={'perf_model': fn, 'early_config_prune': fn}, ) @triton.jit def mm(…): …
(M,N,K) 호출 → 재탐색 없이 best config 사용~/.triton/cache)| 방식 | 의미 |
|---|---|
| static | shmem / reg 추정 → 상한 초과 config 제거 |
| perf_model | 사용자 제공 cost 함수 top-k |
| restore_value | in-place 연산 시 원복용 |
| warmup/rep | 벤치마크 반복 수 |
@triton.heuristics = deterministic 규칙 (shape→config)@triton.autotune = 실측 기반 탐색key 에 포함되지 않은 shape 축에 대해서는
config 가 재사용된다. dynamic batch 일 때
key=['M','N','K'] 만 쓰면 잘못된 config 가 선택될 수 있음.
| 종류 | 재컴파일? | 특성 |
|---|---|---|
| 일반 int | N (sig 동일 시) | runtime value |
| tl.constexpr | Y (값 달라지면) | compile-time |
| pow2 hint | 부분 | alignment specialization |
if CAUSAL 이 사라지며 dead code 제거@triton.jit def norm(X, Y, N, BS: tl.constexpr, USE_BF16: tl.constexpr): pid = tl.program_id(0) o = pid*BS + tl.arange(0, BS) x = tl.load(X+o, o < N) if USE_BF16: x = x.to(tl.bfloat16) ...
USE_BF16=True/False → 서로 다른 kernel 두 버전 컴파일.
tl.load(ptr, ..., eviction_policy=..., cache_modifier=...)kernel(…, DTYPE: tl.constexpr)
→ 같은 kernel 소스가
f16, bf16, f32 각각 독립 cache
→ 런타임 dispatch:
if dt=='f16': kernel[grid](…, f16)
elif dt=='bf16': kernel[grid](…, bf16)
| env | 효과 |
|---|---|
TRITON_CACHE_DIR | JIT cache 위치 (default ~/.triton/cache) |
MLIR_ENABLE_DUMP=1 | 각 pass 전후 IR stdout dump |
TRITON_ENABLE_LLIR_DUMP=1 | LLVM IR 덤프 |
TRITON_INTERPRET=1 | kernel 을 CPU 에서 Python 실행 (debug) |
TRITON_DEBUG=1 | compile 단계 세부 로그 |
TRITON_PRINT_AUTOTUNING=1 | best config 출력 |
~/.triton/cache/<hash>/ ├── kernel.ttir Triton IR ├── kernel.ttgir TritonGPU IR ├── kernel.llir LLVM IR ├── kernel.ptx PTX text ├── kernel.cubin SASS (arch-specific) └── kernel.json metadata (smem, reg, shared, num_warps)
| 경로 | 내용 |
|---|---|
include/triton/Dialect/Triton/ | ttir dialect 정의 (ODS) |
include/triton/Dialect/TritonGPU/ | ttgir dialect + layout attr |
lib/Dialect/TritonGPU/Transforms/ | 핵심 pass (Coalesce, Pipeline, …) |
lib/Conversion/TritonToTritonGPU/ | ttir→ttgir |
lib/Conversion/TritonGPUToLLVM/ | ttgir→llir |
python/triton/compiler/ | frontend · cache · compile |
python/triton/runtime/ | autotune · launch |
python/triton/language/ | tl.* API |
python/triton/language/core.py — tl.* 정의compiler/code_generator.py — AST → IR// -----// IR Dump Before / After … 구분선 기준으로 pass 영향 비교diff before.ir after.irTRITON_INTERPRET=1 → Python 으로 tile 연산 에뮬tl.device_print 허용MLIR_ENABLE_DUMP=1 은 모든 pass 전후 IR 를 찍는다.
kernel 하나에 수천줄 출력 가능 → 파일 redirect 필수.
| 요소 | sm_80 | sm_90 |
|---|---|---|
| global→shmem | cp.async | TMA (bulk tensor) |
| sync | bar.sync | mbarrier (async) |
| TC instr | mma.sync m16n8k16 | wgmma.mma_async m64nNk16 |
| TC 단위 | warp (32) | warpgroup (128) |
| cluster | 없음 | thread block cluster |
| shmem 접근 | local | distributed shared |
HW 상세 → ↗ V02 §9 · V04 §4·§7
WGMMA shape 매트릭스 → ↗ V04 §8
| param | sm_80 | sm_90 |
|---|---|---|
| num_warps | CTA thread 수 / 32 | warpgroup 수 × 4 |
| num_stages 2 | prologue 얇음 | 의미 약함 (TMA 깊은 파이프) |
| num_stages 4~6 | shmem 초과 위험 | TMA 권장 구간 |
setmaxnreg 로 register 재분배 (inc/dec)패턴 상세 → ↗ V04 §9
sm 반영 필수.
| stage | pass | user 영향 |
|---|---|---|
| frontend | AST → ttir | constexpr 설계 |
| ttir→ttgir | convert-to-tritongpu | num_warps 선택 |
| ttgir | coalesce | ptr affine 패턴 유지 |
| ttgir | accelerate-matmul | tl.dot 명시 |
| ttgir | layout-propagate | convert 수 ↓ |
| ttgir | pipeline | num_stages 선택 |
| ttgir | prefetch | reg 압박 주의 |
| ttgir→llir | convert-to-llvm | — |
| llir→ptx | NVPTX backend | sm_arch 지정 |
| 상황 | layout |
|---|---|
| global load/store | Blocked (order=contig) |
| shmem stage buffer | Shared (swizzle) |
| tt.dot operand | DotOperand(parent=Mma) |
| tt.dot result | Mma |
| reduce 결과 | Slice(parent) |
MLIR_ENABLE_DUMP=1TRITON_ENABLE_LLIR_DUMP=1TRITON_INTERPRET=1TRITON_PRINT_AUTOTUNING=1TRITON_DEBUG=1TRITON_CACHE_DIR=./tcacheTRITON_ALWAYS_COMPILE=1| ext | 단계 |
|---|---|
| .ttir | Triton IR (dialect triton) |
| .ttgir | TritonGPU IR (+layout) |
| .llir | LLVM IR (nvptx64) |
| .ptx | PTX assembly |
| .cubin | SASS binary |
| .json | metadata |