
GPUMODE Matmul v2: Custom kernels only where they mattered

I left PyTorch in charge of the shapes it already handled well, and replaced only the shapes where a Triton kernel could actually win. This was not a full rewrite of the problem. It was a selective optimization of the parts that moved the leaderboard score.

Date May 5, 2026
GPU NVIDIA L4
Core idea Shape-specific hybrid dispatch
Winning shapes 2048x3072x2048 · 4096x5120x4096
Problem

The problem is still matrix multiplication

The matmul_v2 task is exactly what its name says: matrix multiplication. The inputs are a, b, and c, and the operation is basically this.

operation · fp16
c = a @ b

The dtype is float16, and the output has to match the reference. Speed alone is not enough. If the result is wrong, the submission fails.

(128, 128, 128) · Small shape. PyTorch fallback is already strong.
(256, 256, 256) · Kernel launch overhead can dominate.
(512, 512, 512) · Not a major contributor to the final score.
(1024, 1024, 1024) · Hard to beat the baseline cleanly.
(2048, 2048, 2048) · Looked fast, but failed correctness.
(2048, 3072, 2048) · A Triton custom kernel won here.
(1024, 1536, 1024) · Safer to leave on fallback.
(4096, 5120, 4096) · Most of the work. This is where the score moved.

The key fact was that the largest shape, (4096, 5120, 4096), dominated the total compute. Optimizing small shapes can be satisfying, but it barely changes the score. Cutting even a few dozen microseconds from the largest shape matters a lot.

Do not try to optimize every shape. Start with the shapes that actually affect the score.
Step 01

First, I built a safe baseline

The first submission was intentionally simple. I called it v0_safe.py.

v0_safe.py · baseline
import torch

def custom_kernel(data):
    a, b, c = data
    # Write the product straight into c: no temporary tensor, no copy back.
    torch.mm(a, b, out=c)
    return c

It looks like “just PyTorch,” but the important part is out=c. Using torch.mm(a, b, out=c) avoids allocating a temporary tensor and copying the result back. It is safe, simple, and surprisingly strong.
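
If you want to see the effect of out=c for yourself, a throwaway timing loop is enough. The sketch below is not from the submission; bench_us is a made-up helper, and it assumes the shape triple is written as (M, N, K).

timing sketch · hypothetical
import torch

def bench_us(fn, iters=50):
    # Throwaway helper: average GPU time per call, in microseconds.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(5):                      # warmup
        fn()
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters * 1000

M, N, K = 4096, 5120, 4096                  # assuming the triple is (M, N, K)
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = torch.empty(M, N, device="cuda", dtype=torch.float16)

print("out=c      :", bench_us(lambda: torch.mm(a, b, out=c)), "us")
print("temp + copy:", bench_us(lambda: c.copy_(torch.mm(a, b))), "us")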

Run Result Note
v0_safe benchmark 2992.352 us GCP L4 benchmark result
v0_safe leaderboard proxy 3010.897 us Proxy number before official submission
largest shape about 2219 us (4096, 5120, 4096) consumed most of the total runtime

The whole run was around 3ms, and the last shape alone took about 2.2ms. That made the target obvious. Not the small shapes. The big one.

Step 02

I did not rewrite everything

Next I started attaching Triton kernels, but only for a few large shapes at first.

(2048, 2048, 2048) · Initially a candidate, but removed after a correctness mismatch.
(2048, 3072, 2048) · Stayed in the final Triton dispatch path.
(4096, 5120, 4096) · The largest shape, and by far the most important one.
others · PyTorch fallback. Leaving code alone can also be an optimization.

The kernel split the matrix into 128 x 128 tiles and walked the K dimension in chunks of 64. In plain terms: break a big matrix into smaller blocks, then let the GPU compute those blocks in parallel.
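
The post does not include the kernel source, so here is a minimal sketch of that tiling scheme: fp16 inputs, fp32 accumulation, and the assumption that M, N, and K are multiples of the block sizes, which holds for the shapes above. The names and the triton_matmul wrapper are mine, not the submission's.

tiled matmul · sketch
import triton
import triton.language as tl

@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                  stride_am, stride_ak, stride_bk, stride_bn,
                  stride_cm, stride_cn,
                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    # Each program owns one BLOCK_M x BLOCK_N tile of C.
    # Assumes M, N, K are multiples of the block sizes (true for the target shapes).
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Walk K in BLOCK_K chunks, accumulating the fp16 tile products in fp32.
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc.to(tl.float16))

def triton_matmul(a, b, c, BLOCK_M=128, BLOCK_N=128, BLOCK_K=64,
                  num_warps=4, num_stages=4):
    M, K = a.shape
    _, N = b.shape
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    matmul_kernel[grid](a, b, c, M, N, K,
                        a.stride(0), a.stride(1),
                        b.stride(0), b.stride(1),
                        c.stride(0), c.stride(1),
                        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
                        num_warps=num_warps, num_stages=num_stages)
    return c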

But a problem appeared immediately. The (2048, 2048, 2048) shape did not match the reference.

correctness log · reject
mismatch found!
custom implementation doesn't match reference

A fast wrong answer is useless on a leaderboard. It is not “almost good.” It is just invalid. So that shape was removed from the custom path.

Rule Use custom kernels only for shapes that are both faster and correct.
Step 03

Sweeps found the real candidates

From this point on, I stopped tuning by feel. I ran sweeps on a GCP L4 instance named cuda-l4-dev-lesson10, varying Triton settings and checking which combinations were both fast and correct.

At first I looked at BK=64. BK means how much of the K dimension the kernel processes at once. For example, 128x128x64 means 128 along M, 128 along N, and 64 along K.

Later sweeps showed that BK=32 was better. For the two large shapes, 128x128x32, num_warps=4, and num_stages=4 looked consistently strong.
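
A sweep like that does not need much machinery: a nested loop over configs, a correctness check against torch.mm, and a timing call. The sketch below reuses the hypothetical triton_matmul wrapper from earlier, picks its own tolerances, and leaves out the GROUP_M axis, which was swept the same way.

config sweep · sketch
import itertools
import torch
import triton

shapes = [(2048, 3072, 2048), (4096, 5120, 4096)]            # assuming (M, N, K)
configs = list(itertools.product([32, 64], [4, 8], [3, 4]))  # BK, num_warps, num_stages

for M, N, K in shapes:
    a = torch.randn(M, K, device="cuda", dtype=torch.float16)
    b = torch.randn(K, N, device="cuda", dtype=torch.float16)
    c = torch.empty(M, N, device="cuda", dtype=torch.float16)
    ref = torch.mm(a, b)
    base = triton.testing.do_bench(lambda: torch.mm(a, b, out=c)) * 1000
    for bk, warps, stages in configs:
        triton_matmul(a, b, c, BLOCK_K=bk, num_warps=warps, num_stages=stages)
        # Wrong answers are rejected before they are ever timed.
        if not torch.allclose(c, ref, atol=1e-2, rtol=1e-2):
            print(f"({M}, {N}, {K}) BK={bk} w={warps} s={stages}: mismatch, rejected")
            continue
        t = triton.testing.do_bench(
            lambda: triton_matmul(a, b, c, BLOCK_K=bk, num_warps=warps, num_stages=stages)) * 1000
        print(f"({M}, {N}, {K}) BK={bk} w={warps} s={stages}: {t:.1f} us (torch.mm {base:.1f} us)")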

Shape torch.mm 128x128x32_w4_s4_g8 Delta
(4096, 5120, 4096) 2263.859 us 2131.968 us -131.891 us
(2048, 3072, 2048) 371.507 us 326.246 us -45.261 us

This was the first point where the approach felt like a real candidate. The result became v2_bigshape_bk32.py.

Version Benchmark Leaderboard proxy
v2_bigshape_bk32 2882.787 us 2904.946 us

The baseline proxy was about 3010us, so this was clearly moving in the right direction.

Step 04

I made GROUP_M shape-specific

The next parameter was GROUP_M. In Triton, GROUP_M affects the order in which blocks are grouped and scheduled. Roughly speaking, it changes how the GPU walks around the matrix.

This matters for cache reuse. If blocks are processed in an order that reuses the same data, memory behavior improves. But the sweep showed that the best value was not the same for both large shapes.
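
Concretely, GROUP_M shows up inside the kernel as a remapping of the flat program ID into (pid_m, pid_n). The fragment below is the standard grouped-ordering pattern from the Triton matmul tutorial, shown here to illustrate the idea rather than to reproduce the exact kernel.

GROUP_M remapping · fragment
# Inside the kernel, assuming a 1D launch grid of size
# cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N).
pid = tl.program_id(0)
num_pid_m = tl.cdiv(M, BLOCK_M)
num_pid_n = tl.cdiv(N, BLOCK_N)
num_pid_in_group = GROUP_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
# Consecutive programs walk GROUP_M rows of C tiles before moving
# along N, so the same A and B tiles stay hot in L2 for longer.
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m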

shape-specific dispatch · v3
# Inside custom_kernel(data), after unpacking (a, b, c) and reading the problem shape:
if shape == (2048, 3072, 2048):
    group_m = 8
elif shape == (4096, 5120, 4096):
    group_m = 16
else:
    # Everything else stays on the PyTorch fallback.
    torch.mm(a, b, out=c)
    return c

This was the turning point. I had started by looking for one good kernel. The actual answer was that each shape wanted its own setting.

Turn The winning strategy was not a universal kernel. It was shape-specific dispatch that only kept the cases that won.
Version Benchmark Leaderboard proxy Proxy (repeat run)
v3_bigshape_bk32_grouped 2827.212 us 2862.659 us 2856.437 us
Step 05

I added cache hints only to the largest shape

The final target was the largest shape, (4096, 5120, 4096). Since it mattered most to the score, it was worth tuning more aggressively.

In v4_bigshape_cache_hint.py, I added cache hints only for this shape.

cache hint · v4
tl.load(..., cache_modifier=".cg", eviction_policy="evict_first")
tl.load(..., cache_modifier=".cg", eviction_policy="evict_last")

In simple terms, the kernel gives the GPU hints about which data does not need to stay around for long and which data may be reused. More precisely, the A and B matrix loads use different cache and eviction policies.
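
Inside the K-loop, that means the two loads carry different hints. Which matrix got which eviction policy is not stated, so the pairing below is only one plausible arrangement, written against the earlier kernel sketch.

cache-hinted loads · fragment
# K-loop fragment with cache hints; the A/B pairing is my guess,
# not something the submission states.
for k in range(0, tl.cdiv(K, BLOCK_K)):
    a = tl.load(a_ptrs, cache_modifier=".cg", eviction_policy="evict_first")
    b = tl.load(b_ptrs, cache_modifier=".cg", eviction_policy="evict_last")
    acc += tl.dot(a, b)
    a_ptrs += BLOCK_K * stride_ak
    b_ptrs += BLOCK_K * stride_bk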

This is not universally better, so I swept it too. I also retested GROUP_M values like 10, 12, 14, 16, 18, 20, 24, 28, and 32. For the largest shape, the sweet spot was around GROUP_M=16, and the cache-hint variant became the best candidate.

Version Benchmark Leaderboard proxy Largest shape
v0_safe 2992.352 us 3010.897 us about 2219 us
v4_bigshape_cache_hint 2806.145 us 2840.097 us about 2062 us

The largest shape dropped by more than 150us. In this problem, that is huge, because most of the total score comes from that shape.

Final kernel

The final kernel structure was surprisingly simple

The core of the final submission, v4_bigshape_cache_hint.py, looked like this.

final dispatch · hybrid
if shape == (2048, 3072, 2048):
    # Triton 128x128x32, GROUP_M=8
    run_default_triton_kernel()
    return c

if shape == (4096, 5120, 4096):
    # Triton 128x128x32, GROUP_M=16, cache hint
    run_cache_hint_triton_kernel()
    return c

# Everything else stays on PyTorch
torch.mm(a, b, out=c)
return c

The important part is that the other shapes were not forced through a custom kernel. For small shapes, PyTorch is already fast enough, and Triton launch overhead or branching can easily erase the benefit.

The (2048, 2048, 2048) shape looked promising on speed but failed correctness, so it was dropped. That restraint was part of the optimization.

small shapes · PyTorch fallback
mid shapes · PyTorch fallback
2048x3072x2048 · Triton custom kernel
4096x5120x4096 · Triton custom kernel + cache hint

The key was making the complicated problem simple.

Numbers

GCP numbers and GPUMODE numbers were different

Most experiments ran on GCP L4.

Environment GPU Torch Triton
GCP cuda-l4-dev-lesson10 NVIDIA L4 2.11.0+cu128 3.6.0
GPUMODE / Modal L4 NVIDIA L4 2.11.0+cu129 runner environment

That means the GCP proxy benchmark and the GPUMODE official ranked score did not match exactly. On GCP, the v4 leaderboard proxy was about 2840us. On GPUMODE, the official ranked score was 2076.331us.

Official result: Submission ID 780611, file v4_bigshape_cache_hint.py, official ranked score 2076.331 us, L4 rank #1.

The gap comes from runner differences, Torch/CUDA versions, details of the GPUMODE benchmark harness, and the behavior of its representative and secret runs. The official number is the one that matters.

Lessons

What I learned from this optimization

The biggest lesson is that GPU optimization is not just about writing one impressive kernel. The more important parts were these.

  • 01 Start with a strong baseline. Before writing custom kernels, I used torch.mm(out=c) as the baseline. Because it was already strong, weak Triton variants were easy to reject.
  • 02 Correctness comes first. The (2048, 2048, 2048) custom kernel looked fast, but it was wrong. So it was removed. A fast incorrect kernel has no leaderboard value.
  • 03 Optimize the shapes that move the score. Small shapes are fun to tune, but they barely affect the final number. In this task, the largest shape decided almost everything.
  • 04 Shape-specific dispatch is powerful. One kernel for every shape is often a compromise. Some shapes belong to PyTorch, some to Triton, and some need extra cache hints.
  • 05 You do not know until you sweep. BK=64 may look plausible. BK=32 may win. GROUP_M=8 may be best for one shape, while GROUP_M=16 wins for another. The answer came from measurement.
Timeline

Rank 1 was the result of a process, not one magic kernel

v0_safe
torch.mm(out=c) baseline
Started with a safe, strong baseline.
~3010 us proxy
v1_hybrid_large
Triton for large shapes
Found correctness failures and removed the failing shapes.
mismatch reject
v2_bigshape_bk32
BK=32 candidate
128x128x32, w4, s4 became a strong candidate for the two large shapes.
~2905 us proxy
v3_grouped
Shape-specific GROUP_M
Used GROUP_M=8 for 2048x3072x2048 and GROUP_M=16 for the largest shape.
~2856-2862 us
v4_cache_hint
Cache hints on the largest shape
GCP proxy reached about 2840us, and GPUMODE official ranked score landed at 2076.331us.
L4 rank #1
Measure, discard what is wrong, keep only the shapes that win, and verify on the official server.

What I like most about this result is that it was not just “I wrote a fast kernel.” More precisely, I looked at the full problem, narrowed in on the real bottleneck, honestly dropped the shapes that failed, and validated the final answer on the official runner. That was the core of this matmul_v2 L4 rank 1 result.
