Pallas-Triton kernels and kernel auto-tuning
Robert Dyro
Modern kernel languages and why use them?

Some algorithms for LLM inference and training have enough algorithmic structure that implementing them in terms of higher-level ops (e.g., numpy operations) is often suboptimal on modern hardware. Lower-level languages expose more control over the hardware and allow encoding that structure directly in the implementation (e.g., smaller intermediate values improving cache locality in flash attention). Currently (2025), modern accelerators like NVIDIA GPUs and Google's TPUs consist largely of a matrix multiplication unit, which needs to be driven efficiently in software to reach peak computational performance. Any lower-level language that offers low-level hardware control and the ability to program the matrix multiplication unit efficiently is a good candidate for LLM algorithms. For NVIDIA GPUs, CUDA satisfies both requirements, but it is a fairly low-level dialect, which means writing code can be tedious and error-prone. Modern alternatives positioned somewhere between CUDA's low-level control and high-level (numpy) ops are kernel languages like Triton and Pallas, which both use Python syntax and array-level operations to express a concise but still low-level interface for programming accelerators. Both Triton and Pallas have dialects targeting different hardware and different levels of manual control. Triton supports NVIDIA GPUs (with some support for AMD), whereas Pallas supports NVIDIA GPUs and Google TPUs. The old GPU-targeting dialect of Pallas used Triton to lower Pallas programs on GPU. However, with the architectural changes in the Hopper and Blackwell GPUs, Triton's original dialect is not expressive enough to use the new matrix multiplication unit efficiently. Both Triton and Pallas introduced new dialects, Gluon and Mosaic GPU respectively, that give lower-level control and allow writing kernels that utilize the hardware fully. Additionally, Pallas contains a TPU dialect which does the same for TPUs. An opinionated summary:
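To give a flavor of the interface, here is a minimal sketch of a Pallas kernel (an element-wise add) written against the `jax.experimental.pallas` API; the block size and grid are illustrative choices rather than tuned values, and it needs a GPU or TPU backend (or `interpret=True` for CPU debugging).

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_kernel(x_ref, y_ref, o_ref):
    # Each program instance sees one block of the inputs, already staged into
    # fast on-chip memory; array-level ops express the computation on the block.
    o_ref[...] = x_ref[...] + y_ref[...]


@jax.jit
def add(x, y):
    block = 1024
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(x.size // block,),  # one program instance per block
        in_specs=[
            pl.BlockSpec((block,), lambda i: (i,)),
            pl.BlockSpec((block,), lambda i: (i,)),
        ],
        out_specs=pl.BlockSpec((block,), lambda i: (i,)),
    )(x, y)


x = jnp.arange(4096, dtype=jnp.float32)
y = jnp.full((4096,), 2.0, dtype=jnp.float32)
print(jnp.allclose(add(x, y), x + y))  # True
```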
Kernels

Flash Attention

The attention operator is fundamental to modern LLMs, capturing temporal dependencies between tokens. Flash attention is a memory-efficient and cache-friendly algorithm for computing the attention operator exactly. An example implementation in Pallas-Triton: https://github.com/rdyro/flash_attention_pallas_triton

Ragged Dot (Group Matrix Multiplication)

Ragged dot is a matrix multiplication operator where the right-hand side (RHS) consists of a stack of "expert" matrices and the left-hand side (LHS) has a ragged assignment of rows to these experts. For example, the first 5 LHS rows are multiplied with the first "expert" matrix, the next 3 rows with the second, the next 7 rows with the eighth "expert" matrix, and so on, with the output rows having the same assignment as the LHS. Both the LHS and the output are 2D tensors with a ragged row dimension. The RHS is a 3D tensor with the first dimension representing the "expert" index.
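To pin down these semantics before looking at a kernel, here is a plain NumPy/JAX reference sketch; the shapes and group sizes are made up for illustration, and `jax.lax.ragged_dot` (available in recent JAX versions) is only used as a cross-check.

```python
import numpy as np
import jax
import jax.numpy as jnp


def ragged_dot_reference(lhs, rhs, group_sizes):
    # lhs: (m, k), rhs: (num_experts, k, n), group_sizes: (num_experts,)
    # consecutive chunks of lhs rows are multiplied with consecutive expert matrices
    out = np.zeros((lhs.shape[0], rhs.shape[-1]), dtype=lhs.dtype)
    start = 0
    for expert_idx, size in enumerate(group_sizes):
        out[start:start + size, :] = lhs[start:start + size, :] @ rhs[expert_idx]
        start += size
    return out


m, k, n, num_experts = 16, 32, 8, 3
group_sizes = np.array([5, 3, 8], dtype=np.int32)  # sums to m
lhs = np.random.randn(m, k).astype(np.float32)
rhs = np.random.randn(num_experts, k, n).astype(np.float32)

ref = ragged_dot_reference(lhs, rhs, group_sizes)
# cross-check against JAX's built-in ragged dot
out = jax.lax.ragged_dot(jnp.asarray(lhs), jnp.asarray(rhs), jnp.asarray(group_sizes))
print(np.allclose(ref, np.asarray(out), atol=1e-4))  # True up to float32 accumulation error
```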
An example implementation in Pallas-Triton: https://github.com/rdyro/gpu_ragged_dot

Auto-tuning

One of the most successful techniques for getting the full performance out of a kernel is to parameterize it over choices like work-chunk size, block tiling, or memory pipeline buffering. These can then be tuned by randomly sampling or exhaustively scanning the space of parameters to identify the set resulting in the fastest run-time on a particular hardware and problem size. tune-jax is a Python package (also available on PyPI via `pip install tune-jax`) that implements this kind of tuning for JAX functions; a usage sketch is included at the end of this section. For example, after tuning flash attention on my RTX 3090, I got the following sorted and truncated results:
The fastest parameter setting runs
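For reference, here is a rough sketch of how such a tuning call can look. The `tune_jax.tune` wrapper and its `hyperparams` argument reflect my reading of the tune-jax README, and the `attention` placeholder with its `block_q`/`block_k` parameters is purely illustrative; treat the exact interface as an assumption and check the project documentation.

```python
# Rough sketch only: the exact tune-jax API may differ, see the project README.
import jax
import jax.numpy as jnp
import tune_jax


def attention(q, k, v, block_q=32, block_k=32):
    # Placeholder implementation; a real flash attention kernel would use
    # block_q / block_k to pick its tiling. Here they only illustrate the
    # kind of parameters the tuner sweeps over.
    return jax.nn.softmax(q @ k.T / jnp.sqrt(q.shape[-1]), axis=-1) @ v


# Assumed interface: wrap the function and list candidate values per parameter;
# the tuner times the combinations on the current device and keeps the fastest.
tuned_attention = tune_jax.tune(
    attention,
    hyperparams={"block_q": [16, 32, 64, 128], "block_k": [16, 32, 64, 128]},
)

q = k = v = jnp.ones((1024, 64), dtype=jnp.float32)
out = tuned_attention(q, k, v)  # first call triggers the timing sweep
```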