LESSON 10 · 2026.04.20 · L4
nsys · ncu 로 내 커널 9 개를 뜯어봤더니 — 숫자 뒤에 숨어 있던 것들
"빠르다 / 느리다" 를 넘어서 "왜 그 숫자인가" 에 닿는 도구. 레슨 1–9 에서 넘어갔던 세 가지 주장을 nsys timeline 과 ncu stall counter 로 숫자화한 세션.
GPU · L4 · sm_89
tools · nsys 2025.1 · ncu 2025.2
phases · 3
레슨 1–9 까지 CUDA / Triton 커널 10 개를 짰다. 각각 벤치를 돌리고 "v4 가 v1 보다 200× 빨랐다", "ours 가 SDPA 의 78 % 속도를 냈다" 같은 숫자를 남겼다.
그런데 왜 그 숫자인지 정확히 모른 채 넘어간 부분이 많았다:
- atomic 이 느리다는 걸 알지만, 얼마나 느린지, 그 "느림" 이 HW counter 어디에 나타나는지?
- pinned vs pageable 의 속도 차이가 어느 경로 에서 나는지?
- Triton 으로 짠 FA 가 SDPA 의 78 % 라는데, 잃는 22 % 가 어디에 있는지?
이번 글은 기존 커널을 한 줄도 안 바꾸고 nsys (timeline profiler) 와 ncu (per-kernel metric profiler) 로 뜯어본 기록. 새 커널을 더 짜는 것보다 이걸 먼저 하는 게 시간당 이득이 컸다.
두 도구, 두 시각
| 도구 | 보여주는 것 | 오버헤드 | 대표 질문 |
| nsys (Nsight Systems) | 시간축 이벤트 타임라인 (CUDA API, kernel, memcpy, stream sync) | ~1–5 % | "시간축 어디에서 뭐가 기다리고 있나" |
| ncu (Nsight Compute) | 커널 한 번의 내부 HW counter (stall reason, tensor pipe, memory SOL) | 10–30 × 실시간 (replay) | "warp 가 매 cycle 뭐 하고 있나" |
nsys 는 "kernel 과 transfer 의 critical path 균형" 을 본다. ncu 는 "커널 내부에서 warp 가 놀고 있는지, 계산하는지, 기다리는지" 를 본다. 둘이 겹치지 않는다.
Phase 1 · nsys — pageable 은 D2H 에서 특히 느리다
질문: 레슨 04 에서 pinned memory 가 pageable 보다 빠르다고 했다. 얼마나, 어느 방향에서?
실험: bin/vector_add --n 16M --iterations 5 을 두 번 — 한 번은 --pageable, 한 번은 --pinned — nsys 아래서 돌린 뒤 .nsys-rep 를 로컬로 가져와서 GUI + CLI stats 로 숫자 추출.
| direction | pageable GB/s | pinned GB/s | speedup |
| H2D (134 MB) | 4.77 | 12.35 | 2.59× |
| D2H (67 MB) | 1.33 | 13.19 | 9.91× |
놀란 점: pageable H2D 가 4.77 GB/s 인데 pageable D2H 는 1.33 GB/s — 같은 PCIe 인데 3.6× 느리다. 타임라인을 보면 이유가 명확:
- pageable D2H: device → pinned bounce buffer → pageable host 의 2-hop. 마지막 memcpy-to-pageable 이 SIMD-친화적이지 않고 page fault / cache eviction 이 섞여서 느림.
- pageable H2D: reverse 이지만, 드라이버가 pinned staging buffer 에 먼저 복사한 뒤 DMA 로 GPU 까지. 이 "host memcpy → pinned" 구간이 OS 입장에서 seq-read 라 상대적으로 빠름.
- pinned: 두 방향 모두 0-hop (DMA 직결) 대칭 — 12–13 GB/s 수렴.
L4 PCIe Gen4 x16 effective BW ≈ 26 GB/s. pinned 는 그 ~50 % 도달. pageable D2H 는 5 % 에 머무름 — 구조에서 나오는 세금.
그리고 커널 시간은 움직이지 않는다 (pageable 0.834 ms, pinned 0.836 ms). pinning 은 전송 경로만 건드리지 on-device 실행과 무관 — 이 당연한 사실을 숫자로 확인.
한 줄 결론
pageable → pinned 전환을 "전송 시간 2× 빨라짐" 으로 모호하게 기억하지 말고, "D2H 에서 10× 빨라짐" 으로 기억하자. 실제 사용자 latency 에서 가장 큰 변화가 거기서 온다.
Phase 2 · ncu — atomic 의 "느림" 은 occupancy 가 아니라 lg_throttle
질문: 레슨 02 reduction v1 (atomicAdd per thread) 이 v4 (warp shuffle + block 당 1 atomic) 보다 수백 배 느린 건 아는데, 그 "느림" 이 정확히 어떤 HW counter 에 드러나는가?
실험: bin/reduction --n 4M --version {1,4} 을 각각 ncu --set detailed --launch-skip 20 --launch-count 1 -k "regex:reduce_v{1,4}_" 로 한 번씩 뜬 뒤, stall 분포 + SOL 지표 비교.
| metric | v1 (atomic per thread) | v4 (shuffle + block atomic) |
| Elapsed cycles | 12,085,435 | 55,229 (218× 적음) |
| DRAM throughput | 0.46 % | 88.2 % (192× 많음) |
| L2 hit rate | 88.74 % (!) | 0.95 % |
| Achieved occupancy | 91.17 % | 91.89 % (거의 같음) |
| Dominant stall | lg_throttle 31.1 % | long_scoreboard 84.6 % |
세 가지 놀라운 점:
- Occupancy 가 같다. 둘 다 91–92 %. 직관적으론 "v1 이 atomic 에 막혀서 warp 가 못 뜰 것" 같지만, 사실은 warp 는 다 뜨는데 뜬 채로 기다린다. occupancy 엔 잡히지 않는다.
- DRAM 이 비어 있다. v1 의 DRAM 0.46 %. 이 커널은 memory-bound 가 아니다.
- L2 hit 이 88.74 % — 비정상적으로 높다. 모든 thread 가 같은 4-byte accumulator 를 건드리니까 그 cache line 이 L2 에 못 박혀서 계속 hit. 하지만 hit 이 많다고 빠른 게 아니다 — 모든 SM 이 그 한 line 을 두고 싸우는 직렬화 가 일어난다.
이게 lg_throttle 31.1 % 로 나옴 — "local/global memory throttle", LSU (load/store unit) 가 atomic path 에서 back-pressure 를 받는 신호. v4 에서는 lg_throttle 0 %, dominant stall 이 long_scoreboard (정상 DRAM load 대기) 로 바뀌고, DRAM 이 88 % 까지 차면서 memory-bound 의 건강한 모양 이 된다.
레슨
"atomic 이 느리다" 는 이 정도 세부로 기억하자 — "atomic 은 L2 cache line serialization 을 만들고, 그게 lg_throttle 로 counter 에 나오고, 그 사이 DRAM 은 빈다". 세 문장이 같이 있어야 "왜 느린지" 가 설명됐다.
Phase 3 · ncu 로 우리 커널 vs SDPA 의 20 % gap 추적
질문: 레슨 09 에서 Triton 으로 짠 4-D causal FA 가 F.scaled_dot_product_attention 의 78–90 % 속도. 그 22 % 가 어디에 있나?
실험: B=1 H=32 N=2048 d=128 causal fp16 (LLaMA-7B mid-range, gap 이 가장 컸던 shape). 두 구현을 각각 한 번씩 ncu 로 뜨고 metric 비교.
먼저 발견한 것 — SDPA backend 가 cuDNN 이 아니었다. kernel 이름:
void flash_fwd_kernel<Flash_fwd_kernel_traits<128, 64, 64, 4, 0, 0, half_t, ...>>(Flash_fwd_params)
이건 Tri Dao 의 Flash Attention 2 CUDA 구현 — PyTorch 2.11 이 번들로 가지고 있다가 L4 + fp16 + causal 조합에서 디스패치한 것. cuDNN 아님. 즉 우리는 Triton FA 를 같은 알고리즘의 숙성된 CUDA 구현 과 비교하게 된다.
| metric | ours (Triton) | SDPA (FA-2 CUDA) | 비율 |
| Elapsed cycles | 1,565,141 | 827,328 | 1.89× |
| Compute (SM) throughput | 39.3 % | 72.1 % | 1.84× |
| Tensor pipe utilization | 44.6 % | 78.8 % | 1.77× |
| DRAM throughput | 10.6 % | 20.3 % | 1.92× |
| Registers per thread | 255 (spill 직전) | 184 | 0.72× |
| Achieved occupancy | 8.3 % | 16.2 % | 1.95× |
Stall 분포:
| stall reason | ours | SDPA |
| total samples | 78,144 | 42,886 |
wait (MMA output dep) | 38.6 % | 19.0 % |
selected (issue 됨) | 21.7 % | 13.6 % |
math_pipe_throttle (tensor saturation) | 19.4 % | 41.5 % |
short_scoreboard (reg dep) | 14.9 % | 2.2 % |
20 % gap 이 있는 네 군데
- Register pressure → Occupancy 반토막. ours 는
BLOCK_M=128 을 autotune 이 골라서 register 가 255 (literally max, spill 직전). 그 결과 SM 에 resident warp 수가 절반. SDPA 는 BLOCK_M=64 타일로 184 reg/thread, 2 배의 warp 를 동시에 살려 둔다. Occupancy 8.3 % vs 16.2 %.
- MMA dependency chain (
wait 38.6 %). tl.dot 직후 output accumulator 를 너무 빨리 consume. num_stages 가 부족해서 producer MMA 가 아직 끝나지 않은 상태에서 consumer 가 기다림. SDPA 는 wait 19 % 로 절반.
- Register dependency (
short_scoreboard 14.9 % vs SDPA 2.2 %). (1) 의 연쇄 — register file 이 꽉 차서 producer-consumer 가 자주 같은 물리 register 를 참조.
- SDPA 는 이미 "좋은 bottleneck" 에 도달함.
math_pipe_throttle 41.5 % — tensor core 가 포화. 이게 wait 보다 좋은 신호인 이유: "FLOP 을 더 박아야 빨라지는" 구간에 왔다는 뜻. 우리는 거기까지 못 감.
이 세션의 세 가지 교훈
(a) Occupancy 는 throughput 이 아니다
Phase 2 와 Phase 3 의 공통점: occupancy 만 봤으면 틀린 진단을 내렸을 것.
- Phase 2: v1 과 v4 의 occupancy 는 91–92 % 로 같은데 wall cycle 은 218× 차이.
- Phase 3: ours 의 occupancy 가 8.3 % 로 SDPA 의 16.2 % 대비 절반 이지만, 이건 "warp pool 이 비어서" 가 아니라 "큰 tile + 높은 register pressure" 의 부작용.
occupancy 는 "몇 warp 가 살 수 있느냐" 의 상한. "그 warp 가 뭘 하고 있느냐" 는 별도로 봐야 하고, DRAM / compute / tensor pipe 의 SOL % 와 stall 분포를 같이 봐야 한다.
(b) ncu 의 stall 분포 = kernel 의 성격 지문
long_scoreboard dominant = DRAM / L2 load 대기 — memory-bound. 처방: access pattern, tiling.
math_pipe_throttle dominant = tensor / FP pipe saturated — compute-bound (healthy). 처방: "더 빨리 가기 어렵다" — 알고리즘 변경 또는 HW 변경.
wait dominant = MMA output dependency — 파이프라이닝 부족. 처방: num_stages 올리기, accumulator 패턴 재구성.
lg_throttle dominant = LSU atomic / misaligned — 알고리즘 설계 문제. 처방: 알고리즘 재설계 (Phase 2 의 reduction v1 → v4).
이 분포를 보고 나서야 "무엇을 고쳐야 할지" 가 명확해진다. ncu 없이 이 판단은 못 한다.
(c) 큰 tile 이 빠를 거라는 직관은 틀릴 수 있다
Phase 3 에서 autotune 은 BLOCK_M=128 을 골랐지만, 그게 L4 의 이 shape 에서 최적은 아니었다. register pressure 가 warp pool 을 고갈. 작은 tile 의 장점 (register 적게 씀 → occupancy 올라감, K/V 재사용 주기 짧아서 파이프라이닝 잘 됨) vs 큰 tile (각 block 이 한 번 읽고 많이 계산 → AI 높음) 은 측정하지 않으면 모름. 그리고 autotune 이 wall-time 으로 best 를 골라도 "그 best 가 HW 를 최대로 쓰고 있는가" 는 ncu 로 확인해야 한다.
실용적인 도구 체인
이 세션 이후로 새 커널 작성 시 기본 체크리스트가 바뀌었다:
- 커널 돌려서 wall time 재기
nsys 로 timeline 떠서 kernel vs transfer 의 critical path 확인
ncu --set detailed 로 DRAM / Compute / Tensor pipe 의 SOL % 확인
- Stall reason 분포 확인 — 어떤 stall 이 dominant 인가?
- dominant stall 에 따라 처방 (위 표)
"speedup 을 자랑하기 전에" 적어도 DRAM % 와 dominant stall 은 기록으로 남긴다. 반대로 "왜 느린지 모르는 상태" 에서 발표하는 자료는 이제 만들지 않는다.
마지막으로 — 이 세션이 말하는 것
- 레슨 04 의 pinning 효과는 D2H 10×, H2D 2.6× 로 비대칭. nsys timeline 의 2-hop memcpy path 를 보면 이해된다.
- 레슨 02 의 reduction v1→v4 218× 차이는 occupancy 가 아니라 atomic serialization (
lg_throttle 31 %) 때문.
- 레슨 09 의 FA 가 SDPA 의 78–90 % 에서 잃는 22 % 는 register pressure (255 regs → occupancy 절반) + MMA dependency (
wait 39 %) 때문. SDPA 는 이미 tensor pipe throttle (healthy bottleneck) 지점.
이 세 인사이트는 프로파일링 툴 없이는 얻을 수 없는 해석. 그리고 이 해석이 있어야 다음 iteration 에서 뭘 고쳐야 할지 말이 된다. 새 커널 0 개 만든 세션이지만, 다음부터의 모든 커널 튜닝의 출발점을 앞당겼다 는 의미에서 제일 남는 장사였다.
부록 — 재현용 커맨드
# Phase 1 — nsys timeline diff (pinned vs pageable)
./scripts/gcp_run_lesson10_phase1.sh <PROJECT_ID> us-west1-b cuda-l4-dev-lesson10
# Phase 2 — ncu reduction v1 vs v4
./scripts/gcp_run_lesson10_phase2.sh <PROJECT_ID> us-west1-b cuda-l4-dev-lesson10
# Phase 3 — ncu ours vs SDPA
./scripts/gcp_run_lesson10_phase3.sh <PROJECT_ID> us-west1-b cuda-l4-dev-lesson10
GCP DL image 에서 ncu 는 sudo -E env PATH=$PATH ncu ... 로 감싸야 perf counter 접근 가능.
LESSON 10 · 2026.04.20 · L4
Nine kernels under the knife — what was hiding behind the numbers
Beyond "fast/slow," reaching "why that number." Turning three claims from lessons 1–9 into numbers with nsys timelines and ncu stall counters.
GPU · L4 · sm_89
tools · nsys 2025.1 · ncu 2025.2
phases · 3
Through lessons 1–9 I wrote ten CUDA / Triton kernels. I benched each and logged numbers like "v4 is 200× faster than v1" and "ours runs at 78 % of SDPA."
But there was a lot I had moved past without actually knowing why the number was what it was:
- I knew atomic is slow — but how much, and in which HW counter does that slowness show up?
- What path creates the pinned-vs-pageable gap?
- Our Triton FA is 78 % of SDPA — where is the 22 % we lose?
This essay is a record of tearing those kernels apart with nsys (timeline profiler) and ncu (per-kernel metric profiler) — without changing a single line of kernel code. Doing this first paid back more per hour than writing another kernel.
Two tools, two viewpoints
| tool | what it shows | overhead | the question it answers |
| nsys (Nsight Systems) | Time-axis event timeline (CUDA API, kernel, memcpy, stream sync) | ~1–5 % | "where on the time axis is something waiting?" |
| ncu (Nsight Compute) | Per-kernel internal HW counters (stall reason, tensor pipe, memory SOL) | 10–30 × real time (replay) | "what is a warp doing every cycle?" |
nsys looks at the "critical-path balance between kernel and transfer." ncu looks at "whether warps inside a kernel are idle, computing, or waiting." They don't overlap.
Phase 1 · nsys — pageable is especially slow on D2H
Question: Lesson 04 said pinned memory is faster than pageable. How much, in which direction?
Experiment: Run bin/vector_add --n 16M --iterations 5 twice — once with --pageable, once with --pinned — under nsys, then pull .nsys-rep locally and inspect in GUI + CLI stats.
| direction | pageable GB/s | pinned GB/s | speedup |
| H2D (134 MB) | 4.77 | 12.35 | 2.59× |
| D2H (67 MB) | 1.33 | 13.19 | 9.91× |
The surprise: pageable H2D at 4.77 GB/s but pageable D2H at 1.33 GB/s — same PCIe, and yet 3.6× slower. The timeline makes the reason obvious:
- Pageable D2H: a 2-hop path — device → pinned bounce buffer → pageable host. That final memcpy-to-pageable isn't SIMD-friendly and mixes in page faults / cache eviction.
- Pageable H2D: reverse — the driver first copies into a pinned staging buffer, then DMAs to the GPU. "host memcpy → pinned" is close to a seq-read from the OS's view, relatively fast.
- Pinned: 0-hop (DMA direct) in both directions, symmetric — converges to 12–13 GB/s.
L4's PCIe Gen4 x16 effective BW ≈ 26 GB/s. Pinned reaches ~50 % of that. Pageable D2H sits at 5 % — a structural tax.
And the kernel time doesn't move (pageable 0.834 ms, pinned 0.836 ms). Pinning touches only the transfer path, not on-device execution — an obvious fact, now confirmed with numbers.
One-line takeaway
Don't remember pageable → pinned as the vague "transfer gets 2× faster." Remember it as "D2H gets 10× faster." That's where the user-visible latency drop actually comes from.
Phase 2 · ncu — atomic's "slowness" is not occupancy, it's lg_throttle
Question: Lesson 02 reduction v1 (atomicAdd per thread) is hundreds of times slower than v4 (warp shuffle + 1 atomic per block). Fine — but what specific HW counter exposes that slowness?
Experiment: Run bin/reduction --n 4M --version {1,4} each under ncu --set detailed --launch-skip 20 --launch-count 1 -k "regex:reduce_v{1,4}_", then compare stall distribution + SOL metrics.
| metric | v1 (atomic per thread) | v4 (shuffle + block atomic) |
| Elapsed cycles | 12,085,435 | 55,229 (218× fewer) |
| DRAM throughput | 0.46 % | 88.2 % (192× higher) |
| L2 hit rate | 88.74 % (!) | 0.95 % |
| Achieved occupancy | 91.17 % | 91.89 % (essentially the same) |
| Dominant stall | lg_throttle 31.1 % | long_scoreboard 84.6 % |
Three surprises:
- Occupancy is the same. Both sit at 91–92 %. Intuitively you'd think "v1's warps can't launch because atomic blocks them." In fact the warps do launch — and then sit there waiting. That doesn't register in occupancy.
- DRAM is empty. v1's DRAM is 0.46 %. This kernel is not memory-bound.
- But L2 hit is 88.74 % — absurdly high. All threads touch the same 4-byte accumulator, so that cache line gets pinned in L2 and keeps hitting. But lots of hits isn't speed — every SM fighting over one line creates serialization.
That shows up as lg_throttle 31.1 % — "local/global memory throttle," a signal that the LSU (load/store unit) is getting back-pressured on the atomic path. In v4, lg_throttle goes to 0 %, the dominant stall flips to long_scoreboard (normal DRAM load wait), DRAM fills to 88 %, and the kernel takes the healthy shape of a memory-bound kernel.
Lesson
Remember "atomic is slow" at this resolution: "atomic creates L2 cache-line serialization, which shows up on the counter as lg_throttle, and meanwhile DRAM sits empty." Only with those three sentences together is "why it's slow" actually explained.
Phase 3 · ncu-tracing the 20 % gap between ours and SDPA
Question: Lesson 09's 4-D causal FA in Triton hit 78–90 % of F.scaled_dot_product_attention. Where is the 22 %?
Experiment: B=1 H=32 N=2048 d=128 causal fp16 (LLaMA-7B mid-range, where the gap was biggest). Profile each under ncu and compare metrics.
First finding — SDPA's backend wasn't cuDNN. Kernel name:
void flash_fwd_kernel<Flash_fwd_kernel_traits<128, 64, 64, 4, 0, 0, half_t, ...>>(Flash_fwd_params)
That's Tri Dao's Flash Attention 2 CUDA implementation — PyTorch 2.11 ships it and dispatches to it on L4 + fp16 + causal. Not cuDNN. So we're actually comparing Triton FA to a seasoned CUDA implementation of the same algorithm.
| metric | ours (Triton) | SDPA (FA-2 CUDA) | ratio |
| Elapsed cycles | 1,565,141 | 827,328 | 1.89× |
| Compute (SM) throughput | 39.3 % | 72.1 % | 1.84× |
| Tensor pipe utilization | 44.6 % | 78.8 % | 1.77× |
| DRAM throughput | 10.6 % | 20.3 % | 1.92× |
| Registers per thread | 255 (verge of spilling) | 184 | 0.72× |
| Achieved occupancy | 8.3 % | 16.2 % | 1.95× |
Stall distribution:
| stall reason | ours | SDPA |
| total samples | 78,144 | 42,886 |
wait (MMA output dep) | 38.6 % | 19.0 % |
selected (issued) | 21.7 % | 13.6 % |
math_pipe_throttle (tensor saturation) | 19.4 % | 41.5 % |
short_scoreboard (reg dep) | 14.9 % | 2.2 % |
Four places the 20 % gap lives
- Register pressure → occupancy halved. Autotune picked
BLOCK_M=128, pushing registers to 255 (literally the max, on the verge of spill). Resident warps on the SM get halved. SDPA uses BLOCK_M=64, 184 regs/thread, and keeps 2× the warps alive. Occupancy 8.3 % vs 16.2 %.
- MMA dependency chain (
wait 38.6 %). We consume the output accumulator too close to a tl.dot. num_stages is low, so the consumer waits on the producer MMA. SDPA's wait is only 19 %.
- Register dependency (
short_scoreboard 14.9 % vs SDPA 2.2 %). A follow-on effect of #1 — with the register file stuffed, producer-consumer often reference the same physical register.
- SDPA is already sitting at a "good" bottleneck.
math_pipe_throttle 41.5 % — tensor core saturated. That's a better signal than wait: it means they're in the "you'd need more FLOPs to go faster" regime. We don't reach it.
Three lessons from this session
(a) Occupancy is not throughput
Common thread across Phase 2 and Phase 3: had I only looked at occupancy, I'd have made the wrong diagnosis.
- Phase 2: v1 and v4 have the same 91–92 % occupancy, but 218× wall-cycle difference.
- Phase 3: our 8.3 % occupancy is half of SDPA's 16.2 %, but that's a side effect of "big tile + high register pressure," not "empty warp pool."
Occupancy caps "how many warps can live." What those warps are doing is separate, and you need DRAM / compute / tensor-pipe SOL % + stall distribution to see it.
(b) A ncu stall distribution is a kernel's fingerprint
long_scoreboard dominant = waiting on DRAM / L2 — memory-bound. Fix: access pattern, tiling.
math_pipe_throttle dominant = tensor / FP pipe saturated — compute-bound (healthy). Fix: "hard to go faster" — change the algorithm or the hardware.
wait dominant = MMA output dependency — pipelining deficit. Fix: raise num_stages, reshape accumulator usage.
lg_throttle dominant = LSU atomic / misaligned — an algorithm design issue. Fix: redesign (Phase 2's reduction v1 → v4).
Only once this distribution is visible does "what to fix" become clear. Without ncu, you can't make that judgement.
(c) "Bigger tile is faster" is an unreliable intuition
In Phase 3, autotune picked BLOCK_M=128, but that wasn't optimal for this shape on L4. Register pressure drained the warp pool. Small tiles (fewer registers → more resident warps, shorter K/V reuse cycle makes software pipelining easier) vs big tiles (each block reads once and computes more → higher arithmetic intensity) — you don't know without measuring. And autotune picking best-by-wall-time doesn't guarantee that best is actually using the HW fully. Confirm with ncu.
The practical tool chain I took home
- Run the kernel, measure wall time.
nsys timeline → check kernel vs transfer critical path.
ncu --set detailed → check SOL % for DRAM / Compute / Tensor pipe.
- Read stall reason distribution. What's dominant?
- Prescribe based on dominant stall (see table above).
"Before bragging about the speedup," log at least the DRAM % and the dominant stall. Conversely, no more publishing numbers from a "I don't know why it's fast/slow" state.
Closing — what this session is saying
- Lesson 04's pinning effect is asymmetric — D2H 10×, H2D 2.6×. The nsys timeline's 2-hop memcpy path explains it.
- Lesson 02's reduction v1→v4 218× gap is not occupancy — it's atomic serialization (
lg_throttle 31 %).
- Lesson 09's FA leaving 22 % to SDPA comes from register pressure (255 regs → half-occupancy) + MMA dependency (
wait 39 %). SDPA is already in the healthy tensor-pipe-throttle regime.
These three interpretations are impossible without profiling tools. And with them, the next iteration's "what to change" actually makes sense. A session with zero new kernels — but one that moved forward the starting point of every kernel-tuning session after this.
Appendix — reproduction commands
# Phase 1 — nsys timeline diff (pinned vs pageable)
./scripts/gcp_run_lesson10_phase1.sh <PROJECT_ID> us-west1-b cuda-l4-dev-lesson10
# Phase 2 — ncu reduction v1 vs v4
./scripts/gcp_run_lesson10_phase2.sh <PROJECT_ID> us-west1-b cuda-l4-dev-lesson10
# Phase 3 — ncu ours vs SDPA
./scripts/gcp_run_lesson10_phase3.sh <PROJECT_ID> us-west1-b cuda-l4-dev-lesson10
On the GCP DL image, wrap ncu with sudo -E env PATH=$PATH ncu ... to get perf-counter access.