DLRM 추천 모델과 LoRA-on-MLP 두 케이스를 통해 “커널을 합친다” 가 정확히 무엇을 의미하는지 — HBM 왕복 회계, torch.compile 이 만들어내는 fused Triton 커널 읽기, 그리고 거기서 한 단계 더 내려가 직접 CUDA 로 fused 하는 워크플로. Kapil Sharma 가 GPU Mode 18강에서 보여준 compile → 읽기 → 다시 짜기 사이클을 정리한 학습 노트.
Kapil 의 강의는 “fusion 이 왜 빠르게 만드는가” 의 한 줄 답이 아니라, 실제 모델 두 개(DLRM 추천 모델과 LoRA 가 붙은 MLP)를 들고 와서 — 어떻게 fused 커널을 만들어내고, 어떻게 검증하고, 어떻게 더 짜낼지의 워크플로를 깐다.
강의가 깐 큰 질문은 두 개다.
torch.compile 이 자동으로 만들어주는 Triton 커널을 먼저 읽고, 그게 어디서 부족한지 NCU 로 확인한 다음에야 직접 CUDA 로 내려간다.“fused 됐는지” 의 검증은 항상 코드로 한다 — wall-clock 으로가 아니라. TORCH_LOGS=output_code 가 dump 하는 Triton 코드 안에 tmp0 = a + b, tmp1 = tmp0 * c, tmp2 = relu(tmp1) 가 한 커널 안에 들어 있는지를 본다. 들어 있으면 fused, 안 들어 있으면 fusion 이 깨진 것이다.
강의 끝에서 손에 쥐어야 하는 건 3개의 도구와 1개 회계다 — torch.compile + TORCH_LOGS=output_code, Triton fused kernel template, 그리고 load_inline 으로 직접 CUDA. 회계는 “HBM 왕복 N 번 → fused 후 1 번” 의 단순한 산수.
강의 첫 figure 의 메시지는 단순하다. 산술 연산의 양은 fused 와 unfused 가 정확히 같다 — 그런데 메모리 트래픽이 다르다. pointwise 커널은 거의 항상 memory-bound 이므로, 트래픽이 줄면 시간이 줄어든다.
Kapil 이 강의에서 “fusion 의 이득을 가장 직관적으로 표현하는 방법” 이라고 강조한 것이 이 회계다. 하나 더 — 회계에는 launch overhead 가 들어 있지 않다. 작은 텐서에서는 launch 가 dominant 한 케이스가 많다(§ 08 의 CUDA Graphs 가 그 답).
두 조건 중 하나가 충족돼야 한다 — (1) 커널들이 모두 memory-bound 이고 producer 의 출력이 consumer 의 입력으로 곧장 흘러간다. (2) launch overhead 가 dominant 한 작은 텐서에서 — fusion 이 launch 횟수를 줄여준다. 두 조건 다 안 맞으면 fusion 으로 거의 안 빨라진다.
강의는 fusion 을 두 카테고리로 나눠 본다. 의미가 다르고, torch.compile 이 잡는 정도도 다르다.
한 커널의 출력이 다음 커널의 입력으로 곧장 들어간다. add → mul → relu 같은 elementwise 체인이 전형. 중간값이 HBM 을 안 다녀온다.
torch.compile 이 자동으로 잡는 1순위.같은 입력을 여러 다른 op 가 동시에 소비한다. attention QKV 의 세 projection 처럼 — 같은 input X 를 W_Q, W_K, W_V 가 각각 곱한다. 입력 read 한 번을 공유한다.
강의의 첫 실험. Meta 의 DLRM (Deep Learning Recommendation Model) 을 PyTorch 로 돌리고 PyTorch profiler 로 trace 를 찍는다. 그리고 “커널이 너무 많이 호출되는 자리” 를 fusion 표적으로 골라낸다.
(1) embedding lookup — sparse categorical feature 를 dense vector 로. (2) bottom MLP — dense feature 를 같은 차원으로. (3) interaction — embedding 들과 dense vector 를 pairwise dot product. (4) top MLP — interaction 결과 + dense 를 concat 해서 최종 score. CTR (click-through rate) 예측의 표준 구조.
Profiler 에서 잡힌 패턴 — 가장 자주 호출되는 게 vectorized_elementwise_kernel 과 functor_kernel. interaction 단계와 MLP 의 activation 단계가 fusion 표적으로 떠오른다. “span 1” 영역에서 elementwise 가 가장 자주 호출되고, “span 2” 에 들어가면 striking kernel(GEMM 등) 이 보이기 시작한다 — 강의 영상의 39:00 부근 화면.
Kapil 이 보여준 또 한 가지 — DLRM 은 작은 batch 와 작은 hidden dim 으로 도는 경우가 많아서 GEMM 자체도 launch overhead 가 dominant 한 케이스에 빠진다. 큰 모델의 attention 과는 다른 영역의 문제.
L001 에서 깔린 도구가 여기서 본격적으로 쓰인다. torch.compile 한 모델을 돌리면서 TORCH_LOGS=output_code 를 켜면, Inductor 가 만들어낸 Triton 커널이 /tmp 안에 dump 된다. 거기서 한 커널 안에 몇 개 op 가 들어가 있는지 를 직접 읽는다.
# 강의에서 Kapil 이 보여준 minimal 케이스
import torch
def add_mul_relu(a, b, c):
return torch.relu((a + b) * c)
opt = torch.compile(add_mul_relu)
a, b, c = [torch.randn(10000, 10000,
device='cuda') for _ in range(3)]
opt(a, b, c)
# 실행
TORCH_LOGS=output_code python add_mul_relu.py
# /tmp/torchinductor_*/triton_poi_fused_*.py 에 dump
# Inductor 가 만들어낸 Triton 커널 (약식)
@triton.jit
def triton_poi_fused_add_mul_relu_0(
in_ptr0, in_ptr1, in_ptr2, out_ptr0,
xnumel, XBLOCK: tl.constexpr):
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
tmp0 = tl.load(in_ptr0 + xindex, xmask)
tmp1 = tl.load(in_ptr1 + xindex, xmask)
tmp2 = tl.load(in_ptr2 + xindex, xmask)
tmp3 = tmp0 + tmp1 # add
tmp4 = tmp3 * tmp2 # mul
tmp5 = triton_helpers.maximum(0, tmp4) # relu
tl.store(out_ptr0 + xindex, tmp5, xmask)
커널 이름이 이미 답한다 — triton_poi_fused_add_mul_relu_0. poi = pointwise, fused = 합쳐짐, 그 뒤에 op 이름이 붙는다. body 안에 tmp3 / tmp4 / tmp5 가 한 커널 안에 흐른다 — 중간값이 register 에만 산다는 뜻. wall-clock 측정 없이도 fusion 이 일어났다는 게 확인된다.
Kapil 이 강의에서 강조한 패턴 — “기존 모델에 torch.compile 을 돌려보고 어떤 커널들이 만들어졌는지 보면, 어디를 수동으로 fuse 할지 힌트가 나온다.” Inductor 가 잡는 fusion 이 “이상적” 인 건 아니다. 가끔 fusion 이 깨지는데(§ 09), 그 경계가 곧 직접 짜야 할 자리.
repo 안 kernels/triton_fused_add_mul_activation.py 와 kernels/triton_fused_add_mul_relu.py 가 이 섹션의 코드. torch.compile 출력에서 변수명만 바꾸고, activation 을 tl.constexpr 인자로 받게 일반화한 형태.
# triton_fused_add_mul_activation.py — 약식
@triton.jit
def fused_add_mul_act_kernel(
a_ptr, b_ptr, c_ptr, out_ptr, n,
BLOCK: tl.constexpr,
ACT: tl.constexpr, # 'relu' / 'gelu' / 'none'
):
pid = tl.program_id(0)
offs = pid * BLOCK + tl.arange(0, BLOCK)
mask = offs < n
a = tl.load(a_ptr + offs, mask)
b = tl.load(b_ptr + offs, mask)
c = tl.load(c_ptr + offs, mask)
y = (a + b) * c
if ACT == 'relu':
y = tl.maximum(y, 0.0)
elif ACT == 'gelu':
y = 0.5 * y * (1.0 + tl.erf(y * 0.7071))
tl.store(out_ptr + offs, y, mask)
강의의 핵심은 “activation 을 인자로 받게 한다” 다 — 같은 fused 커널이 ReLU, GeLU, SiLU 등 어떤 activation 도 처리할 수 있게. 한 커널 = 여러 모델.
tl.constexpr 로 받으면 컴파일 시점에 분기가 사라진다 — 런타임 비용 0.autotune 을 그냥 받을 수 있다 — BLOCK, num_warps, num_stages 를 sweep 하는 코드를 따로 짤 필요 없음. @triton.autotune 데코레이터 한 줄. 직접 CUDA 로 같은 sweep 을 짜려면 한참 더 든다.
그 다음 단계 — kernels/src/pointwise_add_relu_fused.cu 가 같은 일을 직접 CUDA 로 한 버전. load_inline 으로 PyTorch 에 끼워넣었다. 이 시점부터는 “Triton 으로는 잡히지 않는 vectorized load 나 cooperative load” 같은 게 표적이 된다.
// pointwise_add_relu_fused.cu — 약식
__global__ void fused_add_relu_kernel(
const float4* a, const float4* b,
float4* out, int n4) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n4) {
float4 va = a[i];
float4 vb = b[i];
float4 vo;
vo.x = fmaxf(0.0f, va.x + vb.x);
vo.y = fmaxf(0.0f, va.y + vb.y);
vo.z = fmaxf(0.0f, va.z + vb.z);
vo.w = fmaxf(0.0f, va.w + vb.w);
out[i] = vo;
}
}
float4 로 한 thread 가 4 element 를 한 번에 load — vectorized load. Triton 의 default 가 못 잡는 자리 중 하나.
두 번째 케이스는 LoRA. 큰 weight W (frozen) 와 작은 두 행렬 A, B (trainable). forward 는 y = W·x + B·(A·x). 보통 PyTorch 에서는 세 번의 matmul + add 가 따로 launch 된다 — 강의는 이걸 한 fused 커널로 만든다.
repo 안 kernels/src/fused_kernels_lora_on_mlp.cu 가 그 코드. cuBLAS 의 cublasGemmEx 결과 위에 LoRA term 을 직접 붙여 쓰는 방식 — strict 한 epilogue fusion 은 아니지만 — 핵심 idea 는 같다. 중간 결과를 HBM 에 안 쓰고 가까운 메모리에서 끝낸다.
strict epilogue fusion 은 CUTLASS 의 cutlass::epilogue visitor pattern 으로 표현된다 — main loop accumulator 가 끝난 직후 같은 register 위에서 LoRA term 을 더한다. 더 깊은 통합. L023 (Tensor Cores) 와 L036 의 표적.
fusion 이 메모리 트래픽을 줄여도 — launch 자체에 ~5–10μs 정도가 든다. 작은 batch 에서 수백 개의 kernel 을 launch 하는 LLM decoding 에서는 이게 dominant 한다. CUDA Graphs 가 그 자리를 채운다.
torch.cuda.graph(g) context 안에서 model forward 한 번 돌리면, 그 안의 모든 kernel launch 가 graph 로 캡처된다. 이후에는 g.replay() 한 번에 그 graph 전체가 한 번의 launch 비용으로 돈다.
torch.compile(mode="reduce-overhead") 가 이걸 자동으로 켠다.
torch.compile 의 fusion 이 항상 일어나는 건 아니다. output_code 를 봤을 때 fused 커널 안에 op 가 1개 만 있으면 — 거기서 fusion 이 깨진 거다. 강의에서 명시적으로 다룬 자리는 일부지만, 실전에서 자주 만나는 패턴들을 함께 정리한다.
print, .cpu(), .item(), custom non-traceable op 등. graph 가 둘로 쪼개지면서 fusion 도 쪼개진다.torch.compile(dynamic=True) 명시 필요.x.relu_() 다음에 x.view(...) 같은 패턴이 alias 분석을 어렵게 만들어 fusion 안전성 보장 안 됨.(1) TORCH_LOGS=output_code 로 dump → fused 커널 이름 확인. (2) TORCH_LOGS=graph_breaks 로 graph break 위치 확인. (3) torch._dynamo.config.suppress_errors=False 로 “왜 끊겼는지” 메시지 받기. (4) 그래도 안 합쳐지면 — 직접 Triton/CUDA 로 짠다.
강의에서 다시 돌아왔을 때 가장 빨리 복원해야 하는 사실들과 — 직접 손에 박아야 하는 코드 자료들.
TORCH_LOGS=output_code 가 dump 한 Triton 커널 이름 — triton_poi_fused_X_Y_Z_0 의 X/Y/Z 가 합쳐진 op 들.relu/gelu/silu 모두 처리 — Triton 에서는 tl.constexpr ACT 로 분기, CUDA 에서는 template 으로.+B·A·x 도 이 패턴.torch.compile(mode="reduce-overhead") 가 자동으로 켠다.output_code 와 graph_breaks 로 진단.x = (a+b)*c; y = relu(x) 를 unfused 와 fused (torch.compile) 두 버전으로 측정. nvidia-smi dmon 또는 NCU 의 DRAM throughput 으로 트래픽이 약 2× 차이가 나는지 확인.TORCH_LOGS=output_code python script.py. 만들어진 Triton 커널 이름 안에 fused_X_Y_Z 의 X/Y/Z 가 무엇인지 확인. 같은 코드를 torch.compile 없이 돌렸을 때 어떻게 다른지 비교.triton_fused_add_mul_activation.py 의 패턴을 SiLU 까지 확장. @triton.autotune 로 BLOCK · num_warps sweep.float4 로 직접 CUDA 짜고 load_inline 으로 PyTorch 에 끼운다. Triton 버전과 비교 — 차이가 얼마나 나는지.fused_kernels_lora_on_mlp.cu 를 읽고, A·x 의 결과가 register 에만 사는 구조를 도식으로 그려본다. CUTLASS 의 epilogue 와 비교.print() 가 들어간 모델에 torch.compile 을 붙이고 TORCH_LOGS=graph_breaks 로 어디서 끊기는지 확인.torch.compile(mode="reduce-overhead") 와 default 모드의 trace 를 비교. cudaGraphLaunch 가 trace 에 보이는지.L018 의 도구들이 시리즈 안에서 어떻게 다시 호출되는지를 묶어둔다.
TORCH_LOGS=output_code 가 거기서 깔린다. 그 도구를 이 강의가 본격 활용학습 노트로 정리하면서 의도적으로 비워둔 자리 — 강의에서 부분적으로만 다뤄졌거나, 후속 강의에서 본격적으로 등장하는 주제.
fused_kernels_lora_on_mlp.cu 가 cuBLAS 위에 LoRA term 을 직접 더하는지, 아니면 CUTLASS 스타일 epilogue 인지 코드 직접 확인 필요.triton_fused_add_mul_activation.py 의 default 가 무엇인지 확인.이 노트의 % 분포 그래프와 cycle 추정치는 모두 강의 화면을 재구성한 예시 값이다. 자기 GPU 에서 직접 PyTorch profiler 와 NCU 를 돌려 검증해야 깊이 들어간다.