JAX 안에서 Hopper / Blackwell 의 모든 트릭 — TMA, WGMMA, async copy, warp specialization, block cluster — 을 직접 만질 수 있는 Python DSL. Adam Paszke 가 깐 Pallas + Mosaic GPU 의 설계 철학과 FlashAttention 3 를 150줄로 짠 사례를 두 축으로 정리한 학습 노트.
강의의 출발점은 역할 변화다. PyTorch 가 처음 태어났을 때 ML 라이브러리의 목표는 generality + acceptable performance 였다 — 30% 손실은 받아들일 만 했다. 지금은 거꾸로다. transformer 가 거의 모든 모델이 됐고, scaling 이 비용의 대부분을 결정하면서 peak performance 가 먼저, 사용성은 그 다음이 됐다. Mosaic GPU 는 그 반전된 역할 위에서 다시 그린 DSL.
Adam 이 강의 도입에서 명시한 디자인 좌표는 세 개로 압축된다.
Mosaic GPU 의 모든 디자인 결정은 한 문장으로 회수된다 — “boilerplate 는 자동화하되, performance 에 영향이 큰 manual knob 은 노출한다”. async copy 의 동기화는 자동화. block cluster / multicast 는 노출. WGMMA accumulator 는 reference 로 노출하지만, async copy 의 barrier wait/arrive 짝짓기는 helper 로 숨긴다.
강의의 실무적 산출은 두 개다 — Pallas 라는 frontend 와 Mosaic GPU 라는 backend, 그 둘의 결합으로 짠 FlashAttention 3 of 150 lines (Hopper 위에서 tensor core 의 70% 이상 활용). 이 결과 한 줄이 강의 중간에 박혀 있고, 나머지는 그 결과를 어떻게 끌어냈는지의 디자인 reasoning 이다.
Adam 이 도입에서 깐 사실 — Triton 의 큰 성공 요인이 integration triviality 였다는 것. PyTorch 에서 그냥 파이썬 함수처럼 import 하고 부른다. Mosaic GPU 도 같은 원칙: JAX 에서, 그리고 PyTorch 에서도 copy 없이 부를 수 있어야 한다.
강의가 깔린 시점(2025년 1월)에 시장에 있던 “Python 에서 GPU 커널을 짠다” 의 후보를 정리하면.
Mosaic GPU 의 답: JAX tracing 위에 얹는다. jax.experimental.pallas 를 frontend 로, Mosaic GPU 를 backend 로. 사용자는 JAX numpy 와 비슷한 코드를 쓰고, 컴파일러가 Hopper 자원을 하나씩 매핑한다.
강의에서 Adam 이 한 번 더 분명히 했다 — "We just want to have a DSL for fast Hopper, and at this point also Blackwell, and we also want to program them from Python." 즉 GPU 의 새 세대를 따라잡을 수 있는 빠르게 진화 가능한 backend 가 그 빈자리.
Mark 가 강의 중 끼어들어 한 질문 — "CUTLASS 는 C++ template 으로, 다른 프로젝트는 string 으로, Triton 과 Mosaic 는 MLIR 로 간다. 차이가 뭔가?" Adam 의 답이 강의의 첫 깊은 자료. 4가지 metaprogramming 패턴을 비교한다.
if → branch specialization, for → unrolling 자동jax.lax.cond)Adam 의 결론 — GPU 커널은 metaprogramming 이 사실상 모든 곳에 등장한다. tile shape 계산, register tile 분할, warp 별 작업 분배 모두 compile-time 결정이다. Python control flow 가 자연스럽게 unroll/specialize 되는 tracing 모델이 가장 적합하다는 입장.
JAX 의 traced loop 는 guaranteed unrolled. Triton 의 tl.range 도 비슷하지만 const-expr 표시가 필요하다. C++ 의 constexpr for 는 fold expression 으로만 가능하고 일반 unroll 과 다르다. 같은 의도를 표현하기 위한 manual 코드량이 언어마다 크게 다르다.
강의에서 가장 많이 헷갈렸던 자리. Pallas 는 JAX 의 “커널 짜는 frontend” 다. 그리고 그 frontend 는 세 backend 로 lowering 된다.
Adam 이 두 번 강조한 작은 사실 — Pallas 의 “thread” 는 CUDA 의 thread 가 아니다. Pallas 의 thread 한 개는, Mosaic GPU backend 에서는 warp group 한 개에 매핑된다. 이게 Mosaic GPU 의 추상화 레벨 결정의 핵심이다.
왜 warp group 인가 — Adam 의 reasoning.
그래서 Mosaic GPU 의 단일 “Pallas thread” 는 4개의 CUDA warp = 128개의 CUDA thread. WGMMA 의 자연스러운 host.
# Pallas 의 단일 thread = Mosaic GPU 의 warp group
@pl.kernel
def add_kernel(x_ref, y_ref, z_ref):
# x_ref, y_ref 는 reference (mutable, shaped)
# [...] 는 dereference
z_ref[...] = x_ref[...] + y_ref[...]
jax_add = pl.pallas_call(
add_kernel,
out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
grid=(num_blocks,), # 1D 그리드
)
# 결과: 일반 JAX 함수. jit 가능, vmap 가능, autodiff 가능.
Pallas references 는 shaped pointer 다. Python slice 가 “이 자리에서 읽어라/써라” 의 의미. x_ref[...] 가 dereference (C 의 *x 와 비슷), z_ref[...] = ... 가 assign-through-pointer.
Mosaic GPU 의 디자인 원칙은 추상이 아니다. Hopper 의 4가지 핵심 feature 각각에서 무엇을 자동화하고 무엇을 노출했는지를 보면 분명해진다.
collective_axes=("x",) 한 줄. manual: 어느 차원으로 묶을지.이 패턴이 Mosaic GPU 코드를 짧게 만든다. 강의에서 Adam 이 비교한 — “같은 일을 CUTLASS 로 하면 metaprogramming + template 으로 백 줄 단위가 되는데, Pallas 위에서 한 줄 추가로 끝난다” 는 자리들이 cluster + multicast 같은 부분.
Adam 의 입장은 "compiler 가 우리한테 모든 결정을 잘 해주지 않는다". cluster 를 쓸지 말지는 occupancy / GPU 활용률에 큰 영향을 준다 — 자동화하면 잘못된 결정을 내릴 가능성. 반면 cluster 를 어떻게 구현하는지 (어느 block 이 어디로 broadcast 하는지) 는 거의 항상 같다 — 자동화 가능. “decision 은 manual, mechanism 은 automatic” 의 룰이 디자인 전체를 끌고 간다.
강의 후반부에 깐 첫 큰 예제. matmul 을 Pallas Mosaic GPU 로 짜는 시퀀스를 코드 한 덩이씩 보여준다 — block_spec 으로 tile 을 정의하고, plgpu.emit_pipeline 으로 producer-consumer 자동 생성, plgpu.wgmma 로 tensor core 부른다.
# Mosaic GPU 의 matmul kernel — 강의 시퀀스 정리
def matmul_kernel(a_ref, b_ref, o_ref):
# a_ref, b_ref 는 SMEM tile (per pipeline stage)
def compute(_, a_smem, b_smem):
acc = plgpu.wgmma(zeros, a_smem, b_smem) # tensor core MMA
return acc
acc = plgpu.emit_pipeline(
compute,
grid=(K // BLOCK_K,), # K-차원 reduction
in_specs=[
plgpu.GPUBlockSpec((BLOCK_M, BLOCK_K),
lambda k: (0, k),
transforms=plgpu.TilingTransform((64, 32))),
plgpu.GPUBlockSpec((BLOCK_K, BLOCK_N),
lambda k: (k, 0),
transforms=plgpu.TilingTransform((32, 64))),
],
)
plgpu.copy_smem_to_gmem(acc, o_ref) # 결과 store
이 코드의 의미 단위 4개.
emit_pipeline — TMA load → barrier wait → WGMMA → barrier arrive 의 producer-consumer 시퀀스를 자동 생성. K-차원으로 num_stages 만큼 prefetch.GPUBlockSpec — tile 의 shape, 그리고 몇번째 grid step 에서 어느 위치를 가져오는지의 mapping (lambda).TilingTransform — TMA descriptor 가 GMEM → SMEM 으로 가져오는 동안 data 자체를 reshuffle. tensor core 가 좋아하는 swizzled layout 으로 free 로 변환.plgpu.wgmma — Hopper warp-group MMA. accumulator 는 Pallas reference 로 모델링되어 += 와 비슷한 의미.TMA descriptor 는 단순히 “G→S 복사” 만 하지 않는다. stride trick 으로 row-major 데이터를 tile-major 형태로 재배열하면서 복사할 수 있다. tensor core 는 tile-major 를 좋아한다. 이 reshuffle 이 free — 별도 kernel 없이, descriptor 만 바꾸면 된다. Adam 이 강의에서 "another cool technique" 이라며 강조한 부분.
같은 코드에 cluster 차원을 추가하고 collective_axes=("x",) 한 줄. 같은 cluster 안의 block 들이 GMEM 한 번 읽고 SMEM 으로 multicast — bandwidth 사용량이 절반으로. "한 줄 변경. 단, 쓸지 말지 결정은 occupancy 와 trade-off".
Adam 이 강조한 메타 메시지 — “정말 새로운 개념은 거의 없다. block_spec 은 Triton 의 block pointer 와 의미적으로 같고, transforms 는 SMEM swizzle 의 직접 노출이고, emit_pipeline 은 다단계 prefetching 의 helper 화다.” 친숙한 추상의 재배열이지 새로운 추상의 발명이 아니라는 입장.
강의의 클라이맥스. FlashAttention 3 의 핵심 — 두 consumer warp group이 WGMMA 와 softmax 를 critical section 으로 분리해서, Hopper 의 tensor core 회로와 ALU 회로가 동시에 돌게 만든다 — 가 Mosaic GPU 위에서 자연스럽게 표현된다.
왜 이 trick 이 효과적인가 — Adam 의 설명을 풀면.
이 동시성을 표현하는 게 plgpu.emit_pipeline_warp_specialized helper. 같은 producer-consumer 시퀀스를 warp specialized 형태로 lowering.
"You use this special barrier to make it so that exactly one of those two warp groups is computing softmax while the other one is computing matmul. ... softmax uses the ALU which is one set of hardware circuits, and the matmuls use the tensor core which is another."
같은 trick 을 직접 PTX 로 구현하면 수백 줄. Mosaic GPU 위에서 helper 로 풀면 — 강의 시점 기준으로 FlashAttention 3 가 150줄, helper 가 새로 들어간 후엔 100줄 이하. "information density per line of code" 가 디자인 목표.
강의 끝에서 Adam 이 강조 — “우리가 제공하는 건 attention-specific 한 helper 가 아니다. emit_pipeline_warp_specialized 는 일반 pipeline 의 helper다. 다른 패턴 (decode, prefill, splash attention) 모두 같은 helper 로 풀린다.”
청중에서 가장 자주 나온 질문. Adam 의 답이 분명했다 — "Triton 은 block-level 추상에 집중. Mosaic 는 warp-group level. 둘 다 가치 있는 자리".
tl.load, tl.dot) 가 단위Adam 의 입장 — "Triton 의 block-level 도 한 단계 높다고 본다. 하지만 그 위치에서 정말 잘 작동한다. CUDA thread 단위는 dense linear algebra 에선 너무 fine-grained. 우리는 그 사이의 자리". 즉 Mosaic 가 Triton 을 대체한다고 말하지 않는다.
대부분의 ML 워크로드 — pointwise, normalization, 단순 attention — 는 Triton 으로 충분히 peak 가깝다. Hopper-specific 의 모든 trick(producer-consumer warp specialization, cluster multicast, TMA descriptor reshuffle) 을 다 끌어쓰고 싶을 때, 그리고 그 trick 들의 결합에 대한 peak 에 가까운 통제권 이 필요할 때, Mosaic GPU 의 자리가 열린다.
Adam 이 도입에서 한 번, 끝에서 한 번 강조 — "early PyTorch bindings 가 있다. Mosaic GPU 커널을 PyTorch 에서 부를 때 copy 가 일어나지 않는다".
실용적 함의가 크다. PyTorch 의 학습 루프 안에서 — JAX 를 import 할 필요 없이 — FlashAttention 3 급의 Mosaic GPU 커널을 한 줄로 갈아 끼울 수 있다는 뜻. JAX 가 없는 환경에서도 Mosaic GPU 의 출력물 (Hopper 자원을 다 끌어쓰는 PTX) 를 활용할 수 있는 길.
강의를 다시 열었을 때 가장 빨리 복원해야 하는 사실들과 — 직접 손으로 봐야 할 코드 자료들.
num_stages 가 prefetch 깊이.if → branch specialize, for → unroll. jax.lax.cond/fori_loop 로 staged.pl.pallas_call 로 elementwise add 커널을 짠다. x_ref[...] dereference 의미가 손에 잡힐 때까지.grid=(num_blocks,) 로 grid 를 만들고 pl.program_id(0) 로 자기 위치 조회. 각 thread 가 입력의 일부 슬라이스만 처리하도록.plgpu.kernel 또는 compiler_params 로 Mosaic GPU backend 활성화. 같은 코드가 Triton vs Mosaic 양쪽으로 lowering 되는지 확인.plgpu.wgmma 와 GPUBlockSpec + TilingTransform 으로 minimal matmul. 베이스라인은 jnp.matmul.num_stages 를 1, 2, 3 으로 sweep — speedup 곡선이 어떻게 변하는지.torch.cuda.synchronize() 로 정확한 latency 측정.L042 의 추상 결정 (Pallas frontend, warp-group 단위, manual knob 디자인) 이 시리즈 안에서 어떻게 다시 호출되는지 묶어둔다.
학습 노트로 정리하면서 의도적으로 비워둔 자리들 — 강의 안에서 부분적으로만 등장한 주제, 또는 후속 자료가 더 깊게 다룰 주제.
본 노트의 인용은 강의 transcript (2025-01-25 자동생성 캡션) 에서 가져왔으며, 일부 시간 표시는 캡션 timestamp 그대로다. 정확한 발언 원문은 영상에서 직접 확인 권장. "FA3 가 150줄" 같은 수치는 강의 발화에서 인용한 것으로, 실제 코드 줄 수는 JAX repo 의 attention 예제에서 확인 가능.