Optimizing LLM inference speed in float16 in JAX with Pallas
Robert Dyro

Introduction

We aim to optimize the Local LLM model in JAX with …

Why

| Stage | Runtime |
|---|---|
| Context ingestion | 574 ms |
| Generation | 62,000 ms (62 s) |
Both context ingestion and generation involve a pass through the model, only with different computation shapes. Thus, in either case, the optimization should focus on accelerating the single-pass runtime of the model. However, because generation dominates the total runtime (62 s versus 574 ms), the computation shapes to optimize for are those of generation: single-query decoding.
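As a quick sanity check, the two runtimes in the table can be turned into the share of wall time spent in each stage (plain Python, using only the numbers above):

```python
# Runtimes from the table above, in milliseconds.
context_ingestion_ms = 574
generation_ms = 62_000

total_ms = context_ingestion_ms + generation_ms
generation_share = generation_ms / total_ms   # fraction of wall time spent generating
ratio = generation_ms / context_ingestion_ms  # generation time relative to ingestion

print(f"generation share: {generation_share:.1%}")  # generation share: 99.1%
print(f"generation / ingestion: {ratio:.0f}x")     # generation / ingestion: 108x
```

Even a large speedup of context ingestion would barely move the total, whereas every millisecond saved per decoding step is multiplied by the number of generated tokens.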
The core computation (the single pass through the model) consists of two phases:
Phase 1: The attention operation, which takes the form
$\text{softmax}\left(Q{K}^{T}\right)V$
where $Q\in \mathbb{R}^{n\times d}$, $K\in \mathbb{R}^{m\times d}$, and $V\in \mathbb{R}^{m\times d}$ are the query, key, and value matrices, respectively. The output is a matrix of size $n\times d$.
Because the softmax is applied row-wise to the $n\times m$ score matrix $QK^{T}$, the number of queries need not match the number of keys and values. This is particularly useful for single-query decoding, where the number of queries is $n = 1$.
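A minimal NumPy sketch of the attention formula above shows that with a single query the score matrix is just $1\times m$. NumPy is used so it runs anywhere; the two functions translate directly to jax.numpy. The names `softmax_rows` and `attention` are illustrative, and the $1/\sqrt{d}$ scaling of standard scaled dot-product attention is omitted to mirror the formula as written.

```python
import numpy as np


def softmax_rows(scores):
    # Subtract each row's max for numerical stability, then normalize each row.
    shifted = scores - scores.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)


def attention(q, k, v):
    # q: (n, d), k: (m, d), v: (m, d) -> (n, d)
    # Standard attention would divide the scores by sqrt(d); omitted here.
    scores = q @ k.T  # (n, m): one row of scores per query
    return softmax_rows(scores) @ v


rng = np.random.default_rng(0)
d, m = 64, 128
q = rng.standard_normal((1, d))  # single-query decoding: n = 1
k = rng.standard_normal((m, d))
v = rng.standard_normal((m, d))
out = attention(q, k, v)
print(out.shape)  # (1, 64)
```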
Phase 2: Output Multilayer Perceptron (MLP). Several variants are common, but for Mistral, the MLP is a SiLU-activated, gated feedforward network without bias.
$\left(\text{SiLU}(x{W}_{\text{gate}})\odot x{W}_{\text{up}}\right){W}_{\text{down}}$
where $x\in \mathbb{R}^{1\times d}$ is the input row vector and ${W}_{\text{up}}\in \mathbb{R}^{d\times 4d}$, ${W}_{\text{gate}}\in \mathbb{R}^{d\times 4d}$, and ${W}_{\text{down}}\in \mathbb{R}^{4d\times d}$ are the weight matrices of the MLP.
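The gated MLP can be sketched the same way. This assumes Mistral/Llama-style gating, where SiLU is applied to the gate projection only and the result multiplies the up projection elementwise, and a row-vector input $x\in\mathbb{R}^{1\times d}$; `silu` and `gated_mlp` are illustrative names. In single-query decoding, each of the three products is a single row vector times a large weight matrix, so the cost is largely that of streaming the weights.

```python
import numpy as np


def silu(z):
    # SiLU (a.k.a. swish): z * sigmoid(z)
    return z / (1.0 + np.exp(-z))


def gated_mlp(x, w_gate, w_up, w_down):
    # x: (1, d) row vector; w_gate, w_up: (d, 4d); w_down: (4d, d). No biases.
    hidden = silu(x @ w_gate) * (x @ w_up)  # (1, 4d): SiLU on the gate only
    return hidden @ w_down                  # (1, d)


rng = np.random.default_rng(0)
d = 64
x = rng.standard_normal((1, d))
w_gate = rng.standard_normal((d, 4 * d)) / np.sqrt(d)
w_up = rng.standard_normal((d, 4 * d)) / np.sqrt(d)
w_down = rng.standard_normal((4 * d, d)) / np.sqrt(4 * d)
y = gated_mlp(x, w_gate, w_up, w_down)
print(y.shape)  # (1, 64)
```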
We can profile the model to determine which phase to optimize first.
Fig: Profile of the decode layer with default settings for single-query decoding. The MLP phase is the bottleneck.
Unoptimized single-pass runtime (the main computation scope): 42 ms.
Fig: The highlighted blocks dominate the runtime of the decode layer. The MLP phase is the bottleneck.
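Because we established that the MLP phase is the bottleneck, and because the only computationally expensive operation in the MLP phase is matrix multiplication, we focus on optimizing the matrix multiplication (matmul) operation.
A simple matrix multiplication has three input dimensions: the number of rows of the left operand, the shared (contraction) dimension, and the number of columns of the right operand. In the kernel below, these are tiled by block_x, block_a, and block_d, respectively.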
JAX exposes Pallas, a Triton-like, higher-level kernel language.
At a high level, the strategy is to (1) implement a simple matrix multiplication kernel and (2) tune the kernel's parameters for particular hardware and input shapes.
Fig: A visual representation of a simple matrix multiplication kernel, with an inner loop that accumulates the product of a slice of the matrices by multiplying smaller blocks.
The figure above shows what is probably the simplest possible matrix multiplication kernel. The hope is that Pallas (or, in fact, the underlying Triton compiler) can optimize warp-level parallelism and memory access patterns well enough to make this simple kernel efficient.
The kernel is implemented in Pallas as follows:
```python
import jax.numpy as jnp
from jax import lax
from jax.experimental import pallas as pl


def matmul_kernel(x_ref, A_ref, o_ref, block_x: int, block_a: int, block_d: int):
    # Each program instance computes one (block_x, block_d) tile of o = x @ A.
    row_id, col_id = pl.program_id(0), pl.program_id(1)
    col_slice = pl.dslice(col_id * block_d, block_d)

    # Masks guard against reading/writing past the edges when a dimension
    # is not a multiple of its block size.
    A_mask_j = (col_id * block_d + jnp.arange(block_d) < A_ref.shape[1])[None, :]
    a_i = jnp.arange(block_a)
    x_mask_i = (row_id * block_x + jnp.arange(block_x) < x_ref.shape[0])[:, None]
    x_j = jnp.arange(block_a)

    def body_i(start_i, carry_i):
        # Accumulate one block_a-wide slice of the contraction dimension.
        o_prev = carry_i
        x_mask = x_mask_i & (start_i * block_a + x_j < x_ref.shape[1])[None, :]
        x = pl.load(
            x_ref,
            (pl.dslice(row_id * block_x, block_x), pl.dslice(start_i * block_a, block_a)),
            mask=x_mask,
            other=0.0,  # zero-fill masked-out elements so tail blocks add nothing
        )
        a_mask = A_mask_j & (start_i * block_a + a_i < A_ref.shape[0])[:, None]
        a = pl.load(
            A_ref,
            (pl.dslice(start_i * block_a, block_a), col_slice),
            mask=a_mask,
            other=0.0,
        )
        return pl.dot(x, a) + o_prev

    # Accumulate in float32 even for float16 inputs; cast only when storing.
    o_init = jnp.zeros((block_x, block_d), dtype=jnp.float32)
    o = lax.fori_loop(0, pl.cdiv(A_ref.shape[0], block_a), body_i, o_init)

    o_slice = (pl.dslice(row_id * block_x, block_x), pl.dslice(col_id * block_d, block_d))
    o_mask = (row_id * block_x + jnp.arange(block_x) < o_ref.shape[0])[:, None] & (
        col_id * block_d + jnp.arange(block_d) < o_ref.shape[1]
    )
    pl.store(o_ref, o_slice, o.astype(o_ref.dtype), mask=o_mask)
```
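To check the indexing and masking logic of matmul_kernel without a GPU, a plain NumPy sketch of the same tiling can help. It is a CPU reference, not a Pallas kernel: the two outer loops stand in for the 2D launch grid (the kernel is presumably launched with pl.pallas_call over roughly cdiv(M, block_x) × cdiv(N, block_d) programs), zero-padded tiles play the role of masked loads (assuming masked-out elements read as zero), and only the in-bounds part of each output tile is written. `blocked_matmul_reference` is a hypothetical helper.

```python
import numpy as np


def cdiv(a: int, b: int) -> int:
    # Ceiling division, the same role pl.cdiv plays in the kernel.
    return (a + b - 1) // b


def blocked_matmul_reference(x, A, block_x, block_a, block_d):
    """CPU reference of the tiling in matmul_kernel (a sketch, not a Pallas kernel)."""
    M, K = x.shape
    K2, N = A.shape
    assert K == K2, "contraction dimensions must match"
    out = np.zeros((M, N), dtype=x.dtype)
    # The two outer loops stand in for the grid axes program_id(0) and program_id(1).
    for row_id in range(cdiv(M, block_x)):
        for col_id in range(cdiv(N, block_d)):
            r0, c0 = row_id * block_x, col_id * block_d
            acc = np.zeros((block_x, block_d), dtype=np.float32)  # float32 accumulator
            # Inner loop over the contraction dimension, like the kernel's fori_loop.
            for start_i in range(cdiv(K, block_a)):
                k0 = start_i * block_a
                # Zero-padded tiles play the role of masked loads with other=0.
                x_tile = np.zeros((block_x, block_a), dtype=np.float32)
                a_tile = np.zeros((block_a, block_d), dtype=np.float32)
                x_src = x[r0:r0 + block_x, k0:k0 + block_a]
                a_src = A[k0:k0 + block_a, c0:c0 + block_d]
                x_tile[: x_src.shape[0], : x_src.shape[1]] = x_src
                a_tile[: a_src.shape[0], : a_src.shape[1]] = a_src
                acc += x_tile @ a_tile
            # Masked store: write back only the in-bounds part of the tile.
            rows, cols = min(block_x, M - r0), min(block_d, N - c0)
            out[r0:r0 + rows, c0:c0 + cols] = acc[:rows, :cols].astype(out.dtype)
    return out


# Shapes deliberately not multiples of the block sizes, to exercise the masks.
rng = np.random.default_rng(0)
x = rng.standard_normal((5, 70)).astype(np.float16)
A = rng.standard_normal((70, 33)).astype(np.float16)
out = blocked_matmul_reference(x, A, block_x=16, block_a=32, block_d=16)
print(out.shape)  # (5, 33)
```

A reference like this is also a convenient ground truth when tuning block sizes, since every configuration must reproduce the same product.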
The performance of a Pallas kernel, especially one this simple, is likely to be sensitive to the input dimensions. A single set of kernel hyperparameters is unlikely to be optimal for all input shapes.
JAX recompiles a program whenever the input dimensions change, so we can choose kernel hyperparameters based on the input dimensions. At one extreme, we could pick a large set of input-dimension combinations and find the optimal kernel hyperparameters for each. At the other extreme, we could pick a single set of kernel hyperparameters that works reasonably well for all input dimensions. The first option squeezes out the last bit of performance but may generalize poorly to unseen input shapes. The second option is likely to be suboptimal but is much more likely to generalize.
We choose a middle ground: a small set of input dimensions and a small set of kernel hyperparameters. We then benchmark every input-dimension combination with every hyperparameter combination. Finally, to improve generalization, we select only 4 hyperparameter sets and create a map from each input shape to the best of those sets. This requires finding a small set of hyperparameters that gives the best performance when, for every input-dimension combination, the best member of the set is used. This is a mixed-integer optimization problem.
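Because array shapes are static under jit, the choice of configuration can happen in plain Python at trace time. A minimal sketch of such a map, assuming shapes are keyed as (m, k, n) of an (m, k) @ (k, n) matmul and that an unseen shape borrows the configuration of the nearest tuned shape in log space (our assumption; the post does not say how unseen shapes are handled). `select_config` and the table entries are hypothetical placeholders, not measured results.

```python
import math
from typing import Dict, Optional, Tuple

Shape = Tuple[int, int, int]             # (m, k, n) of an (m, k) @ (k, n) matmul
Config = Optional[Tuple[int, int, int]]  # (block_x, block_a, block_d); None = built-in matmul


def select_config(shape: Shape, best_config: Dict[Shape, Config]) -> Config:
    """Return the config tuned for the nearest test shape, measured in log2 space."""

    def log_dist(a: Shape, b: Shape) -> float:
        return sum((math.log2(p) - math.log2(q)) ** 2 for p, q in zip(a, b))

    nearest = min(best_config, key=lambda s: log_dist(s, shape))
    return best_config[nearest]


# Placeholder table for illustration only; these are not measured results.
toy_table: Dict[Shape, Config] = {
    (1, 4096, 4096): (16, 64, 64),
    (1, 4096, 16384): (16, 128, 64),
    (64, 4096, 4096): None,  # fall back to the built-in matmul
}
print(select_config((1, 4096, 14336), toy_table))  # (16, 128, 64)
```

Keying on log-scaled sizes treats a doubling of any dimension as the same "distance", which seems a reasonable default when test shapes are spaced geometrically.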
Problem: pick a set of at most 4 hyperparameter configurations such that the total speedup is maximized when, for each input-dimension combination, we use the best configuration from that small set.
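One way to write this as a mixed-integer program, in our own notation (the post does not give its exact formulation): let $S$ be the set of test shapes, $H$ the candidate configurations (including the built-in matmul), and $r_{s,h}$ the measured speedup of configuration $h$ on shape $s$ relative to the built-in matmul. The binary variable $y_h$ marks whether $h$ is in the chosen set, and $z_{s,h}$ assigns shape $s$ to configuration $h$:

$$
\begin{aligned}
\max_{y,\,z}\quad & \sum_{s\in S}\sum_{h\in H} r_{s,h}\, z_{s,h} \\
\text{s.t.}\quad & \sum_{h\in H} z_{s,h} = 1 && \forall s\in S \\
& z_{s,h} \le y_h && \forall s\in S,\ h\in H \\
& \sum_{h\in H} y_h \le 4 \\
& y_h\in\{0,1\},\quad z_{s,h}\in[0,1]
\end{aligned}
$$

Only $y$ needs to be integer: once the selected set is fixed, the best assignment puts each shape entirely on its fastest selected configuration, so relaxing $z$ to $[0,1]$ does not change the optimum. This is what makes the problem mixed-integer rather than fully integer.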
In principle, this requires testing all combinations of size 4 drawn from a set of 45 possible hyperparameter configurations: $\binom{45}{4}\approx 1.5\times 10^{5}$.
This is large but not necessarily infeasible to iterate over. However, we need to consider one more possibility: the fallback to the native matmul implementation. This increases the number of combinations to $\binom{46}{4}\approx 1.6\times 10^{5}$. This is still feasible to iterate over, but we turn to a mixed-integer optimization solver to find the optimal set of hyperparameters instead.
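For comparison, the exhaustive search is only a few lines. A minimal sketch, assuming a precomputed table `speedup[s][h]` of measured speedups for test shape s and configuration h, where one column (for example, index 0) can stand for the built-in matmul with speedup 1.0; `best_subset` is a hypothetical helper, and the solver-based approach is not shown here.

```python
import itertools
import math


def best_subset(speedup, k=4):
    """Exhaustively choose k configurations maximizing the total per-shape best speedup.

    speedup[s][h] is the speedup of configuration h on test shape s (higher is better).
    Returns (best_total, chosen_config_indices).
    """
    num_configs = len(speedup[0])
    best_total, best_set = -math.inf, None
    for subset in itertools.combinations(range(num_configs), k):
        # Each shape uses whichever selected configuration is fastest for it.
        total = sum(max(row[h] for h in subset) for row in speedup)
        if total > best_total:
            best_total, best_set = total, subset
    return best_total, best_set


# Number of size-4 subsets with and without the built-in fallback.
print(math.comb(45, 4))  # 148995
print(math.comb(46, 4))  # 163185
```

The cost grows as the number of subsets times the number of test shapes, which is why a solver becomes attractive once the candidate set or the subset size grows.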
Fig: Optimal kernel hyperparameter configuration for each input-shape test point. Configuration 0 is the built-in matmul. Percentages denote how much slower the built-in implementation is than the optimized kernel. Tested on an Nvidia RTX 3090.
As the results show, the optimized kernel can be significantly faster than the built-in matmul implementation. The speedup is particularly pronounced for larger input shapes.
Fig: Profile of the decode layer with optimized matrix multiplication for single-query decoding. The MLP phase takes roughly the same time as the attention phase.
Optimized single-pass runtime (the main computation scope): 25 ms, down from 42 ms (about 1.7× faster).
In the context of local LLMs, the speed of individual operations is crucial. The inference pass is not particularly complex, so even an uncompiled PyTorch model achieves very high throughput. JAX, despite compilation, suffers from a slow default implementation of the matmul operation for some input sizes, at least on an RTX 3090 as of version 0.4.29.
What optimizations did not work:
… matmul implementation is already optimized for that hardware.

Faster model compilation & loading is important:
… float32 and memory usage explodes to over 30GB (which few single GPUs have). If the default device is the CPU, the pass is incredibly slow. It's possible to solve this by using _do_init=False initialization, but it is user-hostile and scantily documented.