Pallas-Triton kernels and kernel auto-tuning


Robert Dyro

Modern kernel languages: why use them?

Some algorithms used in LLM inference and training have enough algorithmic structure that implementing them in terms of higher-level ops (e.g., NumPy operations) is often suboptimal on modern hardware. Lower-level languages expose more control over the hardware and allow that structure to be encoded in the implementation (e.g., keeping intermediate values small to improve cache locality in flash attention). Currently (2025), modern accelerators like NVIDIA GPUs and Google's TPUs consist largely of matrix multiplication units, which need to be driven efficiently in software to reach peak computational performance.

Any lower-level language that offers low-level hardware control and can efficiently program the matrix multiplication unit is a good candidate for LLM algorithms. For NVIDIA GPUs, CUDA satisfies both requirements, but it is a fairly low-level dialect, so writing code in it can be tedious and error-prone. Modern alternatives positioned between CUDA's low-level control and high-level (NumPy) ops are kernel languages like Triton and Pallas, which both use Python syntax and array-level operations to express a concise but low-level interface for programming accelerators.

Both Triton and Pallas have dialects targeting different hardware and different levels of manual control. Triton supports NVIDIA GPUs (with some support for AMD GPUs), whereas Pallas supports NVIDIA GPUs and Google TPUs. The original GPU dialect of Pallas used Triton to lower Pallas programs on GPUs. However, with the architectural changes in Hopper and Blackwell GPUs, Triton's original dialect is not expressive enough to use the new matrix multiplication units efficiently. Triton and Pallas therefore introduced new dialects, Gluon and Mosaic GPU respectively, which give lower-level control and allow writing kernels that fully utilize the hardware. Pallas also has a TPU dialect that does the same for TPUs.

An opinionated summary:

  • Triton: relatively high-level, designed for GPUs
    • Triton dialect: for GPUs
      • great for all workloads on Ampere NVIDIA GPUs
      • good on Hopper+ for memory-bandwidth-bound workloads (e.g., some inference)
      • easy to write
    • Gluon dialect: for NVIDIA GPUs
      • great for compute-bound workloads on Hopper+ NVIDIA GPUs
      • more difficult to write
  • Pallas: relatively high-level, designed for GPUs and TPUs
    • Pallas-Triton dialect: for GPUs, like the Triton dialect
    • Pallas-Mosaic-GPU dialect: for NVIDIA GPUs
      • great for compute-bound workloads on Hopper+ NVIDIA GPUs
      • more difficult to write (but much more concise than writing CUDA directly)
    • Pallas-TPU dialect: for Google's TPUs
      • great for memory-bandwidth-bound and compute-bound workloads on all TPUs
      • relatively easy to write

Kernels

Flash Attention

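The attention operator is fundamental to modern LLMs because it captures temporal dependencies between tokens. Flash attention is a memory-efficient and cache-friendly algorithm for computing the attention operator exactly: it processes keys and values in blocks and maintains a running (online) softmax, so the full attention score matrix is never materialized.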
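The core trick can be illustrated on the CPU with NumPy. The sketch below is not the Pallas-Triton kernel from the linked repository; it is a minimal, assumed single-head, non-causal version that only shows the online-softmax bookkeeping. For each block of queries it loops over key/value blocks and keeps three running quantities (row max m, softmax denominator l, unnormalized output acc), rescaling them whenever a new block raises the row max. The block sizes block_q and block_k here play the same role as the tuned parameters later in this post.

```python
import numpy as np


def reference_attention(q, k, v):
    """Naive attention: materializes the full (T_q, T_k) score matrix."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v


def blockwise_attention(q, k, v, block_q=16, block_k=16):
    """Flash-attention-style loop: visits K/V block by block and keeps only
    running statistics per query row instead of the full score matrix."""
    t_q, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.empty((t_q, v.shape[-1]), dtype=q.dtype)
    for qs in range(0, t_q, block_q):
        q_blk = q[qs:qs + block_q]
        rows = q_blk.shape[0]
        m = np.full((rows, 1), -np.inf)        # running row max of scores
        l = np.zeros((rows, 1))                # running softmax denominator
        acc = np.zeros((rows, v.shape[-1]))    # running unnormalized output
        for ks in range(0, k.shape[0], block_k):
            s = q_blk @ k[ks:ks + block_k].T * scale
            m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
            p = np.exp(s - m_new)
            correction = np.exp(m - m_new)     # rescales old stats to the new max
            l = l * correction + p.sum(axis=-1, keepdims=True)
            acc = acc * correction + p @ v[ks:ks + block_k]
            m = m_new
        out[qs:qs + block_q] = acc / l         # normalize once at the end
    return out


rng = np.random.default_rng(0)
q = rng.standard_normal((50, 32))
k = rng.standard_normal((40, 32))
v = rng.standard_normal((40, 24))
print(np.allclose(blockwise_attention(q, k, v), reference_attention(q, k, v)))
```

In a real GPU kernel, each query block would typically be handled by a separate program instance with its blocks held in on-chip memory, which is where the memory savings and cache locality come from.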

An example implementation in Pallas-Triton: https://github.com/rdyro/flash_attention_pallas_triton

Ragged Dot (Group Matrix Multiplication)

Ragged dot is a matrix multiplication operator where the right-hand side (RHS) consists of a stack of "expert" matrices and the rows of the left-hand side (LHS) are assigned to these experts in contiguous groups of varying (ragged) size. For example, the first 5 LHS rows are multiplied with the first "expert" matrix, the next 3 rows with the second, the next 7 rows with the eighth "expert" matrix, and so on. The output rows follow the same assignment as the LHS rows. Both the LHS and the output are 2D tensors with a ragged row dimension, while the RHS is a 3D tensor whose first dimension is the "expert" index.
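The semantics can be pinned down with a slow NumPy reference. This is a sketch under the assumption that the assignment is given as a list of per-expert row counts (here called group_sizes), with LHS rows already sorted by expert; experts that receive no rows simply get a count of 0. The example uses 5, 3 and 7 rows for the first, second and eighth experts, with zero rows for the experts in between. JAX's jax.lax.ragged_dot takes (lhs, rhs, group_sizes) arguments in the same spirit, but this loop is only meant as an executable definition, not a fast implementation.

```python
import numpy as np


def ragged_dot_reference(lhs, rhs, group_sizes):
    """lhs: (M, K) rows sorted by expert; rhs: (G, K, N) stack of expert
    matrices; group_sizes: G row counts per expert that sum to M."""
    if lhs.shape[0] != sum(group_sizes):
        raise ValueError("group_sizes must sum to the number of LHS rows")
    out = np.zeros((lhs.shape[0], rhs.shape[-1]), dtype=np.result_type(lhs, rhs))
    start = 0
    for g, size in enumerate(group_sizes):
        # rows [start, start + size) are multiplied by expert matrix g
        out[start:start + size] = lhs[start:start + size] @ rhs[g]
        start += size
    return out


rng = np.random.default_rng(0)
group_sizes = [5, 3, 0, 0, 0, 0, 0, 7]   # experts 3..7 receive no rows
lhs = rng.standard_normal((15, 4))       # 5 + 3 + 7 = 15 rows
rhs = rng.standard_normal((8, 4, 6))     # 8 experts, each a (4, 6) matrix
out = ragged_dot_reference(lhs, rhs, group_sizes)
print(out.shape)
```

An efficient kernel avoids this Python loop, for example by tiling the rows and looking up which expert each tile belongs to, but it should produce the same result as this reference.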

An example implementation in Pallas-Triton: https://github.com/rdyro/gpu_ragged_dot

Auto-tuning

One of the most successful techniques for getting full performance out of a kernel is to parameterize it by settings like work-chunk size, block tiling, or memory pipeline buffering. These parameters can then be tuned by randomly sampling or exhaustively scanning the parameter space to identify the setting with the fastest run time on a particular hardware and problem size.
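A bare-bones exhaustive scan can be written with the standard library alone. The sketch below is a hypothetical, sequential illustration (autotune and blocked_matmul are made-up names, and blocked_matmul is a toy NumPy stand-in for a real kernel): it times every configuration in the Cartesian product of the search space, runs a warm-up call first so one-off costs like compilation are excluded, and returns the results sorted by mean time. For GPU kernels launched asynchronously, the timed call must also wait for the result (e.g., via block_until_ready in JAX) or the measurement is meaningless.

```python
import itertools
import statistics
import time

import numpy as np


def autotune(run, search_space, n_warmup=1, n_repeats=5):
    """Time run(**config) for every config in the Cartesian product of
    search_space; return (config, mean_s, std_s) tuples, fastest first."""
    names = list(search_space)
    results = []
    for values in itertools.product(*(search_space[n] for n in names)):
        config = dict(zip(names, values))
        for _ in range(n_warmup):  # exclude one-off costs (compilation, caches)
            run(**config)
        times = []
        for _ in range(n_repeats):
            t0 = time.perf_counter()
            run(**config)
            times.append(time.perf_counter() - t0)
        results.append((config, statistics.mean(times), statistics.pstdev(times)))
    return sorted(results, key=lambda r: r[1])


def blocked_matmul(a, b, block_m, block_n):
    """Toy 'kernel' whose speed depends on its tiling parameters."""
    out = np.empty((a.shape[0], b.shape[1]), dtype=np.result_type(a, b))
    for i in range(0, a.shape[0], block_m):
        for j in range(0, b.shape[1], block_n):
            out[i:i + block_m, j:j + block_n] = a[i:i + block_m] @ b[:, j:j + block_n]
    return out


rng = np.random.default_rng(0)
a = rng.standard_normal((128, 96))
b = rng.standard_normal((96, 64))
results = autotune(
    lambda block_m, block_n: blocked_matmul(a, b, block_m, block_n),
    {"block_m": [32, 64], "block_n": [16, 64]},
    n_repeats=3,
)
best_config, best_mean, best_std = results[0]
print(best_config, f"{best_mean:.3e} s")
```

Random sampling only changes how configurations are enumerated, for example drawing a subset of the product with random.sample, which helps when the full grid is too large to compile and time.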

tune-jax is a Python package (available on PyPI via pip install tune-jax) that automates the auto-tuning process, mainly by compiling candidate configurations in parallel and automatically parsing profiler output for accurate timing results.

For example, after tuning flash attention on my RTX 3090, I got the following sorted and truncated results:

  id    block_q    block_k    num_warps    num_stages    t_mean (s)    t_std (s)
----  ---------  ---------  -----------  ------------  ------------  -----------
 168        128         64                               3.8098e-02   3.1642e-04
 176        128         64            4                  3.8528e-02   3.0663e-04
 134         64        128                          2    3.9203e-02   1.6830e-04
 178        128         64            4             2    3.9289e-02   2.3250e-04
 110         64         32                          2    3.9595e-02   1.5604e-04
...
  31         16         64            2             4    1.4387e-01   1.3587e-03
 173        128         64            2             1    1.5114e-01   1.0685e-03
 185        128        128            2             1    2.4229e-01   6.5207e-04
 184        128        128            2                  1.1032e+00   7.1724e-03
 186        128        128            2             2    1.1805e+00   1.6940e-03

The fastest parameter setting is 1.1805e+00 / 3.8098e-02 ≈ 31× faster than the slowest one.
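The speedup can be recomputed from the rows shown in the table; this only covers the printed rows, since the middle of the sorted list was truncated.

```python
# Mean run times in seconds (t_mean column), keyed by configuration id
t_mean = {
    168: 3.8098e-02, 176: 3.8528e-02, 134: 3.9203e-02, 178: 3.9289e-02,
    110: 3.9595e-02, 31: 1.4387e-01, 173: 1.5114e-01, 185: 2.4229e-01,
    184: 1.1032e+00, 186: 1.1805e+00,
}
fastest = min(t_mean, key=t_mean.get)
slowest = max(t_mean, key=t_mean.get)
speedup = t_mean[slowest] / t_mean[fastest]
print(f"config {fastest} is {speedup:.1f}x faster than config {slowest}")
# prints "config 168 is 31.0x faster than config 186"
```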