gpumode · 강의 아카이브
《GPU Mode》 L042 2025 · JAN · 25 High priority transcript · 약 90분 · available

Mosaic GPU

JAX 안에서 Hopper / Blackwell 의 모든 트릭 — TMA, WGMMA, async copy, warp specialization, block cluster — 을 직접 만질 수 있는 Python DSL. Adam Paszke 가 깐 Pallas + Mosaic GPU 의 설계 철학과 FlashAttention 3 를 150줄로 짠 사례를 두 축으로 정리한 학습 노트.

JAX / Pallas Mosaic GPU TMA WGMMA warp specialization MLIR block cluster FlashAttention 3
A
Speaker
Adam Paszke
Google · JAX core · Mosaic GPU 시작자
강의 번호
L042
스피커
Adam Paszke
학습 우선순위
High · 정독
다시 볼 때
FA3 의 150줄을 직접 읽는다
§ 01강의가 풀려는 문제· Why this lecture exists

“peak 성능과 Python 의 개발 속도” 를 한 자리에 — 그 자리가 어떤 모양이어야 하는지

강의의 출발점은 역할 변화다. PyTorch 가 처음 태어났을 때 ML 라이브러리의 목표는 generality + acceptable performance 였다 — 30% 손실은 받아들일 만 했다. 지금은 거꾸로다. transformer 가 거의 모든 모델이 됐고, scaling 이 비용의 대부분을 결정하면서 peak performance 가 먼저, 사용성은 그 다음이 됐다. Mosaic GPU 는 그 반전된 역할 위에서 다시 그린 DSL.

Adam 이 강의 도입에서 명시한 디자인 좌표는 세 개로 압축된다.

  1. compiler 가 못 따라잡는 자리 — Ampere → Hopper → Blackwell 로 갈수록 프로그래밍 모델이 매 세대 다르다. 일반 컴파일러가 이 변화를 따라잡으려면 시간이 걸리고, 그 시간 동안 사용자는 peak 의 60–70% 에서 머문다.
  2. C++ template metaprogramming 의 천장 — CUTLASS / ThunderKittens 가 가는 길. Adam 이 본 한계는 “error message 가 너무 거칠다, 디버깅이 사실상 불가능에 가까운 자리들이 있다”. 같은 metaprogramming 이 Python tracing 으로 풀리면 훨씬 깔끔하다.
  3. JAX 사용자에게 Triton 이 비어 있는 자리 — Triton 은 PyTorch 와 함께 자라며 거의 표준이 됐다. JAX 사용자는 같은 진입점이 없었다. Mosaic GPU 가 그 자리.
강의의 인지적 frame

Mosaic GPU 의 모든 디자인 결정은 한 문장으로 회수된다 — “boilerplate 는 자동화하되, performance 에 영향이 큰 manual knob 은 노출한다”. async copy 의 동기화는 자동화. block cluster / multicast 는 노출. WGMMA accumulator 는 reference 로 노출하지만, async copy 의 barrier wait/arrive 짝짓기는 helper 로 숨긴다.

“우리는 당신의 커널을 LLM-proof 하게 만들고 싶다 — LLM 한테 커널 짜달라고 부탁하지 않아도 되도록 boilerplate 를 다 걷어내고, performance 에 직결되는 결정만 손에 남게.”Adam Paszke · 17:50

강의의 실무적 산출은 두 개다 — Pallas 라는 frontend 와 Mosaic GPU 라는 backend, 그 둘의 결합으로 짠 FlashAttention 3 of 150 lines (Hopper 위에서 tensor core 의 70% 이상 활용). 이 결과 한 줄이 강의 중간에 박혀 있고, 나머지는 그 결과를 어떻게 끌어냈는지의 디자인 reasoning 이다.

§ 02JAX 위 GPU 커널의 빈자리· why a JAX-native DSL

Triton 이 PyTorch 에 한 일을, JAX 에 다시 해야 했다

Adam 이 도입에서 깐 사실 — Triton 의 큰 성공 요인이 integration triviality 였다는 것. PyTorch 에서 그냥 파이썬 함수처럼 import 하고 부른다. Mosaic GPU 도 같은 원칙: JAX 에서, 그리고 PyTorch 에서도 copy 없이 부를 수 있어야 한다.

강의가 깔린 시점(2025년 1월)에 시장에 있던 “Python 에서 GPU 커널을 짠다” 의 후보를 정리하면.

  • Triton — 가장 폭넓게 채택. PyTorch 표준 진입점. 단점: tile-level abstraction 이 일부 Hopper feature(특히 warp-group MMA + cluster + TMA multicast 의 결합)를 직접 노출하기 어렵다.
  • CUTLASS / CuTe — C++. peak 가깝게 짤 수 있지만 metaprogramming 이 template 으로 묶인다. Python 사용자에게는 거리감.
  • ThunderKittens — C++20 의 const-expr 활용. 우아하지만 Python 사용자에겐 또 다른 언어.
  • Triton inline / inductor 의 string 생성 — 빠르게 prototype 할 수 있지만, scoping/이름 충돌 등 string-based metaprogramming 의 고전 문제 그대로.

Mosaic GPU 의 답: JAX tracing 위에 얹는다. jax.experimental.pallas 를 frontend 로, Mosaic GPU 를 backend 로. 사용자는 JAX numpy 와 비슷한 코드를 쓰고, 컴파일러가 Hopper 자원을 하나씩 매핑한다.

FIG · 같은 “Python GPU DSL” 시장의 4축L042 시점 / 2025
L0 · C++CUTLASS / CuTetemplate metaprogramming · peak 성능NVIDIA
모든 Hopper 기능 노출
L1 · C++20ThunderKittensconst-expr · concise · 학습 자료 좋음Stanford HazyResearch
L2 · Python ASTTritontile-level DSL · 가장 큰 채택 · PyTorch 표준OpenAI
L3 · JAX tracingPallas + Mosaic GPUwarp-group level · Hopper feature 직접 노출Google · JAX team
네 도구가 같은 경쟁선이 아니다. CUTLASS/TK 는 C++ 로 peak 까지 갈 수 있고, Triton/Mosaic 는 Python 으로 빠르게 prototype 할 수 있다. Mosaic 의 자리는 “JAX 사용자에게 Triton-급 진입점, 그리고 Triton 이 노출 안 하는 Hopper 디테일까지”.

강의에서 Adam 이 한 번 더 분명히 했다 — "We just want to have a DSL for fast Hopper, and at this point also Blackwell, and we also want to program them from Python." 즉 GPU 의 새 세대를 따라잡을 수 있는 빠르게 진화 가능한 backend 가 그 빈자리.

§ 03메타프로그래밍 4형식· C++ template · string · Python AST · JAX tracing

왜 tracing 이 GPU 커널 metaprogramming 의 가장 자연스러운 형태인가

Mark 가 강의 중 끼어들어 한 질문 — "CUTLASS 는 C++ template 으로, 다른 프로젝트는 string 으로, Triton 과 Mosaic 는 MLIR 로 간다. 차이가 뭔가?" Adam 의 답이 강의의 첫 깊은 자료. 4가지 metaprogramming 패턴을 비교한다.

① C++ template

  • compile-time evaluation 가능
  • 그러나 error message 가 거칠다
  • const-expr (C++20) 가 도움은 되지만 loop unrolling 처럼 자연스럽게 풀리는 건 아님
  • "weird-encoded program 의 디버깅이 근본적으로 어렵다"

② string templating

  • FlashInfer / inductor 의 일부가 채택
  • scoping/name capture 를 직접 관리해야 함
  • "binding lifetime 관리가 까다로워서 제대로 일반화하기 어렵다"
  • 작은 code paste 는 OK

③ Python AST (Triton)

  • AST 를 직접 파싱해서 변환
  • Python control flow 가 자동으로 staged out
  • const-expr 를 자유롭게 추가 가능 (C++ 위원회 거치지 않음)
  • "const vs runtime 의 구분이 약하다 — 두 단계 evaluation 모델이 미묘"

④ JAX tracing (Pallas)

  • 한 번 함수를 evaluate 해서 IR 를 만든다
  • Python if → branch specialization, for → unrolling 자동
  • "이게 GPU 커널에는 오히려 장점 — pragma unroll 안 적어도 된다"
  • 조건분기를 staged 로 두려면 helper combinator (예: jax.lax.cond)

Adam 의 결론 — GPU 커널은 metaprogramming 이 사실상 모든 곳에 등장한다. tile shape 계산, register tile 분할, warp 별 작업 분배 모두 compile-time 결정이다. Python control flow 가 자연스럽게 unroll/specialize 되는 tracing 모델이 가장 적합하다는 입장.

미세하지만 큰 차이

JAX 의 traced loop 는 guaranteed unrolled. Triton 의 tl.range 도 비슷하지만 const-expr 표시가 필요하다. C++ 의 constexpr for 는 fold expression 으로만 가능하고 일반 unroll 과 다르다. 같은 의도를 표현하기 위한 manual 코드량이 언어마다 크게 다르다.

“ML 라이브러리에는 string metaprogramming 도 있고 template metaprogramming 도 있다. 다 잘 동작한다. 단, GPU 커널에서는 tracing 이 가장 적은 비용으로 가장 자연스럽게 풀리더라.”학습 노트 · Adam 의견 정리
§ 04Pallas 와 Mosaic GPU 의 관계· three backends, one frontend

Pallas 는 frontend 다 — backend 가 셋이다

강의에서 가장 많이 헷갈렸던 자리. Pallas 는 JAX 의 “커널 짜는 frontend” 다. 그리고 그 frontend 는 세 backend 로 lowering 된다.

FIG · Pallas 의 세 backendfrontend 동일, backend 분기
L0
Pallas (Python)
JAX numpy 비슷한 DSL · 동일 frontend
B1
Triton backend
block-level · NVIDIA GPU
B2
Mosaic TPU backend
work-group level · TPU
B3
Mosaic GPU backend
warp-group level · Hopper / Blackwell
강의 시점(2025년 1월)에 Pallas frontend 자체는 이미 stable. Mosaic GPU backend 가 가장 새로운 backend로, "low-level 이지만 같은 frontend 추상" 이라는 위치. 같은 Pallas 코드 일부를 Triton / Mosaic GPU 양쪽으로 컴파일할 수 있다.

Adam 이 두 번 강조한 작은 사실 — Pallas 의 “thread” 는 CUDA 의 thread 가 아니다. Pallas 의 thread 한 개는, Mosaic GPU backend 에서는 warp group 한 개에 매핑된다. 이게 Mosaic GPU 의 추상화 레벨 결정의 핵심이다.

왜 warp group 인가 — Adam 의 reasoning.

  • tile/block 단위는 너무 높다 — Triton 의 자리. Hopper 의 디테일이 모두 가려진다.
  • CUDA thread 단위는 너무 낮다 — “dense linear algebra 에서 단일 thread 단위로 partition 하는 건 의미 없다”. 모든 흥미로운 instruction 이 warp 또는 warp group 단위.
  • warp group 이 Hopper 의 자연 단위 — WGMMA 자체가 warp group 위에서 도는 instruction 이다. async copy + barrier 도 warp group 정렬이 자연스럽다.

그래서 Mosaic GPU 의 단일 “Pallas thread” 는 4개의 CUDA warp = 128개의 CUDA thread. WGMMA 의 자연스러운 host.

# Pallas 의 단일 thread = Mosaic GPU 의 warp group
@pl.kernel
def add_kernel(x_ref, y_ref, z_ref):
    # x_ref, y_ref 는 reference (mutable, shaped)
    # [...] 는 dereference
    z_ref[...] = x_ref[...] + y_ref[...]

jax_add = pl.pallas_call(
    add_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    grid=(num_blocks,),     # 1D 그리드
)
# 결과: 일반 JAX 함수. jit 가능, vmap 가능, autodiff 가능.

Pallas references 는 shaped pointer 다. Python slice 가 “이 자리에서 읽어라/써라” 의 의미. x_ref[...] 가 dereference (C 의 *x 와 비슷), z_ref[...] = ... 가 assign-through-pointer.

§ 05Hopper 자원 노출하기· TMA · WGMMA · cluster · TMA multicast

“boilerplate 는 자동, performance knob 은 manual” 의 실제 모양

Mosaic GPU 의 디자인 원칙은 추상이 아니다. Hopper 의 4가지 핵심 feature 각각에서 무엇을 자동화하고 무엇을 노출했는지를 보면 분명해진다.

TMA
async copy
SMEM 로의 bulk tensor 복사. shape transform 까지 free 로 처리. 자동: barrier 짝짓기, descriptor 생성. manual: tile shape, swizzle.
WGMMA
warp-group MMA
Hopper tensor core. SMEM ↔ register 양쪽에서 operand 가능. 자동: PTX 생성, fragment layout. manual: accumulator 의 reset/accumulate.
cluster
block cluster
같은 cluster 안의 block 들이 SMEM 통신 가능. 자동: cluster size 설정으로 활성화. manual: 사용 결정 자체 (occupancy 영향).
TMA multicast
collective copy
한 번의 GMEM read 가 여러 SM 의 SMEM 으로 동시에. 자동: collective_axes=("x",) 한 줄. manual: 어느 차원으로 묶을지.

이 패턴이 Mosaic GPU 코드를 짧게 만든다. 강의에서 Adam 이 비교한 — “같은 일을 CUTLASS 로 하면 metaprogramming + template 으로 백 줄 단위가 되는데, Pallas 위에서 한 줄 추가로 끝난다” 는 자리들이 cluster + multicast 같은 부분.

왜 이 디자인이 잘 굴러가는가 — 한 단계 더 들어가서

Adam 의 입장은 "compiler 가 우리한테 모든 결정을 잘 해주지 않는다". cluster 를 쓸지 말지는 occupancy / GPU 활용률에 큰 영향을 준다 — 자동화하면 잘못된 결정을 내릴 가능성. 반면 cluster 를 어떻게 구현하는지 (어느 block 이 어디로 broadcast 하는지) 는 거의 항상 같다 — 자동화 가능. “decision 은 manual, mechanism 은 automatic” 의 룰이 디자인 전체를 끌고 간다.

FIG · Hopper 의 producer-consumer 파이프라인 (warp-specialized matmul)Mosaic GPU 가 자동 생성
Producer WG
TMA load A0,B0
arrive
TMA load A1,B1
arrive
TMA load A2,B2
arrive
TMA load A3,B3
arrive
Consumer WG 0
wait
WGMMA acc += A0·B0
wait
WGMMA acc += A1·B1
softmax
WGMMA acc += A2·B2
Consumer WG 1
softmax
WGMMA
softmax
WGMMA
softmax
producer 는 TMA 로 SMEM 에 데이터를 담그고 barrier 로 알려준다. consumer 들은 wait 후 WGMMA 와 softmax 를 번갈아 돈다. Hopper 의 tensor core 와 ALU(softmax) 가 다른 회로이므로 두 consumer 가 각자 critical section 으로 분리되면 둘이 동시에 도는 효과 — 이게 FlashAttention 3 의 critical path.
§ 06예제 1 · matmul 파이프라인· emit_pipeline · block_spec · transforms

가장 짧은 형태의 Hopper-aware matmul 커널

강의 후반부에 깐 첫 큰 예제. matmul 을 Pallas Mosaic GPU 로 짜는 시퀀스를 코드 한 덩이씩 보여준다 — block_spec 으로 tile 을 정의하고, plgpu.emit_pipeline 으로 producer-consumer 자동 생성, plgpu.wgmma 로 tensor core 부른다.

# Mosaic GPU 의 matmul kernel — 강의 시퀀스 정리
def matmul_kernel(a_ref, b_ref, o_ref):
    # a_ref, b_ref 는 SMEM tile (per pipeline stage)
    def compute(_, a_smem, b_smem):
        acc = plgpu.wgmma(zeros, a_smem, b_smem)        # tensor core MMA
        return acc

    acc = plgpu.emit_pipeline(
        compute,
        grid=(K // BLOCK_K,),                            # K-차원 reduction
        in_specs=[
            plgpu.GPUBlockSpec((BLOCK_M, BLOCK_K),
                lambda k: (0, k),
                transforms=plgpu.TilingTransform((64, 32))),
            plgpu.GPUBlockSpec((BLOCK_K, BLOCK_N),
                lambda k: (k, 0),
                transforms=plgpu.TilingTransform((32, 64))),
        ],
    )
    plgpu.copy_smem_to_gmem(acc, o_ref)                  # 결과 store

이 코드의 의미 단위 4개.

  • emit_pipeline — TMA load → barrier wait → WGMMA → barrier arrive 의 producer-consumer 시퀀스를 자동 생성. K-차원으로 num_stages 만큼 prefetch.
  • GPUBlockSpec — tile 의 shape, 그리고 몇번째 grid step 에서 어느 위치를 가져오는지의 mapping (lambda).
  • TilingTransform — TMA descriptor 가 GMEM → SMEM 으로 가져오는 동안 data 자체를 reshuffle. tensor core 가 좋아하는 swizzled layout 으로 free 로 변환.
  • plgpu.wgmma — Hopper warp-group MMA. accumulator 는 Pallas reference 로 모델링되어 += 와 비슷한 의미.
TMA 의 숨은 슈퍼파워

TMA descriptor 는 단순히 “G→S 복사” 만 하지 않는다. stride trick 으로 row-major 데이터를 tile-major 형태로 재배열하면서 복사할 수 있다. tensor core 는 tile-major 를 좋아한다. 이 reshuffle 이 free — 별도 kernel 없이, descriptor 만 바꾸면 된다. Adam 이 강의에서 "another cool technique" 이라며 강조한 부분.

block cluster 한 줄 추가

같은 코드에 cluster 차원을 추가하고 collective_axes=("x",) 한 줄. 같은 cluster 안의 block 들이 GMEM 한 번 읽고 SMEM 으로 multicast — bandwidth 사용량이 절반으로. "한 줄 변경. 단, 쓸지 말지 결정은 occupancy 와 trade-off".

Adam 이 강조한 메타 메시지 — “정말 새로운 개념은 거의 없다. block_spec 은 Triton 의 block pointer 와 의미적으로 같고, transforms 는 SMEM swizzle 의 직접 노출이고, emit_pipeline 은 다단계 prefetching 의 helper 화다.” 친숙한 추상의 재배열이지 새로운 추상의 발명이 아니라는 입장.

§ 07예제 2 · FlashAttention 3· warp specialization · 150 lines

Hopper tensor core 70%+ 활용을 150줄 Python 으로

강의의 클라이맥스. FlashAttention 3 의 핵심 — 두 consumer warp groupWGMMAsoftmax 를 critical section 으로 분리해서, Hopper 의 tensor core 회로와 ALU 회로가 동시에 돌게 만든다 — 가 Mosaic GPU 위에서 자연스럽게 표현된다.

왜 이 trick 이 효과적인가 — Adam 의 설명을 풀면.

  • WGMMA 는 tensor core (HMMA 회로) 를 쓴다 — 이 회로는 matmul 외엔 거의 idle.
  • softmax 는 ALU 를 쓴다 — exp, reduce, divide. 별도 회로.
  • 한 warp group 이 두 일을 번갈아 하면 한 번에 하나의 회로만 사용. 다른 회로는 놀고 있음.
  • 두 consumer WG 가 각자 critical section 으로 들어가도록 barrier 로 묶으면 — 정확히 한 WG 는 softmax, 다른 WG 는 WGMMA 를 동시에. 두 회로 동시 사용.

이 동시성을 표현하는 게 plgpu.emit_pipeline_warp_specialized helper. 같은 producer-consumer 시퀀스를 warp specialized 형태로 lowering.

강의에서 직접 인용

"You use this special barrier to make it so that exactly one of those two warp groups is computing softmax while the other one is computing matmul. ... softmax uses the ALU which is one set of hardware circuits, and the matmuls use the tensor core which is another."

결과 한 줄

같은 trick 을 직접 PTX 로 구현하면 수백 줄. Mosaic GPU 위에서 helper 로 풀면 — 강의 시점 기준으로 FlashAttention 3 가 150줄, helper 가 새로 들어간 후엔 100줄 이하. "information density per line of code" 가 디자인 목표.

강의 끝에서 Adam 이 강조 — “우리가 제공하는 건 attention-specific 한 helper 가 아니다. emit_pipeline_warp_specialized 는 일반 pipeline 의 helper다. 다른 패턴 (decode, prefill, splash attention) 모두 같은 helper 로 풀린다.”

“같은 trick 을 위에서 본 사람이 짜면 어차피 거의 똑같이 짠다. 이 부분은 우리가 자동화하려고 하는 자리.”Adam Paszke · 53:55
§ 08Triton 과의 비교· block-level vs warp-group level

같은 자리를 노리지 않는다 — 추상화 레벨 한 단계 차이

청중에서 가장 자주 나온 질문. Adam 의 답이 분명했다 — "Triton 은 block-level 추상에 집중. Mosaic 는 warp-group level. 둘 다 가치 있는 자리".

Triton

  • Pallas thread = 1 block
  • tile-level op (tl.load, tl.dot) 가 단위
  • fragment layout / register 분배 모두 컴파일러가 결정
  • 대부분의 ML 커널이 이 레벨에서 충분히 빠르다
  • Hopper 의 일부 기능 (cluster + multicast 결합) 은 노출이 약함

Mosaic GPU

  • Pallas thread = 1 warp group
  • warp group 내 register layout 까지 manual 결정 가능
  • Hopper 의 모든 hardware feature 직접 노출
  • 같은 frontend (Pallas) 에서 lowering 만 다르게
  • peak 까지 짜려는 사용자, FlashAttention 같은 복잡 패턴이 표적

Adam 의 입장 — "Triton 의 block-level 도 한 단계 높다고 본다. 하지만 그 위치에서 정말 잘 작동한다. CUDA thread 단위는 dense linear algebra 에선 너무 fine-grained. 우리는 그 사이의 자리". 즉 Mosaic 가 Triton 을 대체한다고 말하지 않는다.

실용적 가이드

대부분의 ML 워크로드 — pointwise, normalization, 단순 attention — 는 Triton 으로 충분히 peak 가깝다. Hopper-specific 의 모든 trick(producer-consumer warp specialization, cluster multicast, TMA descriptor reshuffle) 을 다 끌어쓰고 싶을 때, 그리고 그 trick 들의 결합에 대한 peak 에 가까운 통제권 이 필요할 때, Mosaic GPU 의 자리가 열린다.

§ 09PyTorch 와의 인터롭· no-copy bindings

JAX 사용자만의 도구가 아니다

Adam 이 도입에서 한 번, 끝에서 한 번 강조 — "early PyTorch bindings 가 있다. Mosaic GPU 커널을 PyTorch 에서 부를 때 copy 가 일어나지 않는다".

실용적 함의가 크다. PyTorch 의 학습 루프 안에서 — JAX 를 import 할 필요 없이 — FlashAttention 3 급의 Mosaic GPU 커널을 한 줄로 갈아 끼울 수 있다는 뜻. JAX 가 없는 환경에서도 Mosaic GPU 의 출력물 (Hopper 자원을 다 끌어쓰는 PTX) 를 활용할 수 있는 길.

FIG · Mosaic GPU 의 사용 경로 두 갈래JAX 직접 / PyTorch 인터롭
SRC
Pallas Python
Mosaic GPU backend
JIT
JAX tracing
XLA + custom kernel
A
JAX 호출
jax.jit(...)
B
PyTorch 호출
torch tensor 직접 전달 · no copy
두 호출 경로 모두 동일 PTX 산출. PyTorch 학습 루프 한가운데에서 같은 GPU memory 를 가리키는 tensor 를 그대로 넘기면 된다 — DLPack 기반의 zero-copy 인터롭.
§ 10기억할 메모와 코드· key takeaways · repo

다시 열었을 때 5분 안에 손에 잡혀야 할 것

강의를 다시 열었을 때 가장 빨리 복원해야 하는 사실들과 — 직접 손으로 봐야 할 코드 자료들.

Pallas thread = warp group
Mosaic GPU backend 에서 Pallas 의 단일 thread 는 4 warp = 128 CUDA thread. WGMMA 의 자연 단위.
decision manual, mechanism auto
cluster 사용 결정은 manual (occupancy 영향). cluster 안의 broadcast 구현은 auto. 디자인 룰 전체.
emit_pipeline
producer-consumer 의 TMA load + barrier + WGMMA + barrier arrive 시퀀스 자동 생성. num_stages 가 prefetch 깊이.
TilingTransform
TMA descriptor 가 GMEM → SMEM 복사하면서 reshuffle 까지 free. tensor core 가 좋아하는 swizzled tile-major layout.
warp-specialized FA3
두 consumer WG 가 softmax / WGMMA 를 critical section 으로 분리. ALU + tensor core 회로 동시 사용.
Pallas 3 backend
Triton (NVIDIA), Mosaic TPU (TPU), Mosaic GPU (Hopper/Blackwell). frontend 동일.
PyTorch interop
DLPack 기반 zero-copy. PyTorch 학습 루프에서 JAX import 없이 Mosaic 커널 호출 가능.
tracing-based metaprogramming
Python if → branch specialize, for → unroll. jax.lax.cond/fori_loop 로 staged.
YouTube youtube.com/watch?v=wKd90avC8Nc · 약 90분
Mosaic GPU jax/experimental/mosaic/gpu · open source
Adam Paszke x.com/apaszke

손에 새기기 — 실습 시퀀스

  1. Pallas add kernel 부터 — JAX 환경에서 pl.pallas_call 로 elementwise add 커널을 짠다. x_ref[...] dereference 의미가 손에 잡힐 때까지.
  2. 1D grid 와 program_idgrid=(num_blocks,) 로 grid 를 만들고 pl.program_id(0) 로 자기 위치 조회. 각 thread 가 입력의 일부 슬라이스만 처리하도록.
  3. Mosaic GPU backend 명시plgpu.kernel 또는 compiler_params 로 Mosaic GPU backend 활성화. 같은 코드가 Triton vs Mosaic 양쪽으로 lowering 되는지 확인.
  4. WGMMA matmulplgpu.wgmmaGPUBlockSpec + TilingTransform 으로 minimal matmul. 베이스라인은 jnp.matmul.
  5. emit_pipeline — 같은 matmul 을 K-차원 producer-consumer 로 묶기. num_stages 를 1, 2, 3 으로 sweep — speedup 곡선이 어떻게 변하는지.
  6. FlashAttention 3 코드 읽기 — JAX repo 의 attention 예제. 150줄 안에서 어떤 줄이 producer/consumer 인지, barrier 가 어디 있는지 표시.
  7. PyTorch 인터롭 — Mosaic 커널을 PyTorch tensor 로 호출. torch.cuda.synchronize() 로 정확한 latency 측정.
§ 11다른 강의로 이어지는 길· connections

이 강의의 도구가 다음에 어디에 다시 등장하는지

L042 의 추상 결정 (Pallas frontend, warp-group 단위, manual knob 디자인) 이 시리즈 안에서 어떻게 다시 호출되는지 묶어둔다.

§ 12열린 질문· open questions

다음에 들었을 때 직접 검증해야 할 것들

학습 노트로 정리하면서 의도적으로 비워둔 자리들 — 강의 안에서 부분적으로만 등장한 주제, 또는 후속 자료가 더 깊게 다룰 주제.

검증 메모

본 노트의 인용은 강의 transcript (2025-01-25 자동생성 캡션) 에서 가져왔으며, 일부 시간 표시는 캡션 timestamp 그대로다. 정확한 발언 원문은 영상에서 직접 확인 권장. "FA3 가 150줄" 같은 수치는 강의 발화에서 인용한 것으로, 실제 코드 줄 수는 JAX repo 의 attention 예제에서 확인 가능.

← Lecture 041 FlashInfer — Charles Frye Lecture 043 → int8 tensorcore matmul for Turing — Erik Schultheis