JAX Advanced Techniques

Robert Dyro, Spencer Richards



What is JAX?

JAX is a Python framework for

  • creating computational graphs (or simply routines) using a high-level language, Python
  • manipulating those computation graphs/functions/routines to:
    • compile them into efficient computational code
    • take automatic derivatives of created functions
    • vectorize and parallelize functions over large batches of data
    • run the same high-level computational functions on a variety of hardware: CPUs, GPUs, and TPUs

This tutorial is an attempt to produce an educational reference for JAX techniques we found particularly interesting in our work and research. As such, it is a collection of a few disjoint topics.


PyTrees

Since JAX is designed with function transformations in mind (e.g., compilation, differentiation, parallelization), the first difficulty that arises is how to deal with Python functions that take multiple arguments, keyword vs. positional arguments, nested dictionaries of configuration parameters, and so on. The philosophy of JAX is simple: any nested structure of Python containers (tuples, lists, dictionaries, and custom containers registered with JAX) is allowed. It doesn't matter if your function takes tuples, nested tuples, nested dictionaries, dictionaries of tuples, etc., as arguments.

The JAX name for any valid nested Python structure is a Pytree. It is a tree because each container (e.g., tuple, list, dictionary) is a node whose children are the elements in the container; the elements that are not containers themselves (e.g., arrays) are the leaves of the tree.

There are many excellent JAX introductory resources for Pytrees, but the advanced techniques using Pytrees we cover in this tutorial are the following:

  1. automatically determining which function arguments should be batched over in vmap
  2. how to compile a JAX function that takes arguments not supported by JAX

The three most important functions for manipulating trees in JAX (all in jax.tree_util) are:

  • tree_flatten for flattening a Pytree (nested Python container) into a single list of leaves, together with a description of the tree structure
  • tree_unflatten for reassembling a single list of leaves (e.g., once modified) into the original Pytree (nested Python container) structure
  • tree_map for applying a function (typically a simple Python lambda) to every leaf in a Pytree (nested Python container), producing a Pytree with the same structure
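A minimal sketch of the three functions on a small nested structure (the keys and values below are made up for illustration):

```python
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten, tree_map

# a made-up nested structure: a dict holding a tuple of arrays and a nested dict
problem = {"data": (jnp.ones(3), jnp.zeros(2)), "config": {"lr": 0.1, "steps": 10}}

# tree_flatten: the leaves in a deterministic order, plus the tree structure
leaves, treedef = tree_flatten(problem)

# tree_unflatten: rebuild the original nesting from (here, modified) leaves
doubled = tree_unflatten(treedef, [2 * leaf for leaf in leaves])

# tree_map: apply a function to every leaf, keeping the structure
ndims = tree_map(lambda x: jnp.ndim(x), problem)
```

Dictionaries are flattened in sorted-key order, so the leaves of "config" come before those of "data".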

1. Automatically batching over arguments in vmap

When you're working on computational routines, you might have a large nested Python dictionary representing your problem data and problem configuration parameters. If you're trying to batch-compute a function over your problem, you might need to batch over all data, but over none of the configuration parameters (these are the same for each batch element). How can we automatically determine which arguments should be batched over and which should not?

Note: vmap expects the user to explicitly specify which arguments to batch over (and along which axis) via its in_axes argument.

from typing import Dict, Any, Callable
from jax import vmap, Array
from jax.tree_util import tree_map

def computational_routine(problem: Dict[str, Any]) -> Array:
    ...  # your computation on a single (unbatched) problem instance

def vmap_automatically(computation_fn: Callable, problem: Dict[str, Any], batch_size: int):

    # let's define a way to check whether x has a shape
    # (is some form of array data; strings don't have a "shape" attribute)
    # and whether the first entry in the shape corresponds to the `batch_size` provided
    def is_batchable(x: Any, batch_size: int):
        return hasattr(x, "shape") and len(x.shape) > 0 and x.shape[0] == batch_size

    # we're going to map over every leaf in the problem (recursively):
    # if the leaf has a first dimension of size batch_size, 0 is the batch axis;
    # otherwise, we put None to indicate that the leaf should not be batched over
    in_axes = tree_map(lambda x: 0 if is_batchable(x, batch_size) else None, problem)
    # vmap expects one in_axes entry per positional argument, so we wrap the tree in a tuple
    return vmap(computation_fn, in_axes=(in_axes,))(problem)
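A usage sketch under stated assumptions: the problem dictionary and the per-example computational_routine below are hypothetical, and vmap_automatically is repeated (condensed) so that the snippet runs on its own.

```python
from typing import Any, Callable, Dict

import jax.numpy as jnp
from jax import vmap
from jax.tree_util import tree_map

def vmap_automatically(computation_fn: Callable, problem: Dict[str, Any], batch_size: int):
    # same logic as above: batch over leaves whose leading dimension equals batch_size
    def is_batchable(x: Any) -> bool:
        return hasattr(x, "shape") and len(x.shape) > 0 and x.shape[0] == batch_size

    in_axes = tree_map(lambda x: 0 if is_batchable(x) else None, problem)
    return vmap(computation_fn, in_axes=(in_axes,))(problem)

# hypothetical per-example computation: it sees one unbatched slice of the data
def computational_routine(problem: Dict[str, Any]):
    x, offset = problem["data"]["x"], problem["data"]["offset"]
    return problem["config"]["scale"] * x.sum() + offset

problem = {
    "data": {"x": jnp.ones((8, 3)), "offset": jnp.arange(8.0)},  # batched (leading dim 8)
    "config": {"scale": 2.0, "weights": jnp.ones(3)},  # shared: not batched over
}
out = vmap_automatically(computational_routine, problem, batch_size=8)
```

One caveat of this heuristic: a configuration array whose leading dimension happens to equal batch_size would also be batched over, so it is worth checking the resulting in_axes when shapes can collide.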

2. How to compile a JAX function that takes arguments unsupported by JAX?

When calling jax.jit to indicate that we want a function compiled, there are extra arguments, static_argnums (by position) and static_argnames (by keyword name), for indicating which arguments might change the internal logic of the function or have types not supported by JAX. JAX then compiles a separate version of the function for each distinct combination of values of the "static" arguments, which therefore must be hashable.

The simplest example of this is a boolean flag that changes the internal logic of our function:

def linear_transform(x, W, b, affine: bool = True):
    if affine:
        return x @ W + b
    else:
        return x @ W

Here, for JAX to properly handle the Python-level branch on the argument affine during compilation, we mark it as static:

linear_transform_compiled = jax.jit(linear_transform, static_argnames="affine")

which will cause JAX to (just-in-time) compile a separate version of the function for each possible value of affine.
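The following sketch shows both halves of the story with the linear_transform function above: marking affine as static works, while jitting without it fails because a traced boolean cannot drive a Python if statement. The error class caught here, jax.errors.ConcretizationTypeError, is the base class of the more specific error raised by recent JAX versions.

```python
import jax
import jax.numpy as jnp

def linear_transform(x, W, b, affine: bool = True):
    if affine:  # a Python-level branch: needs a concrete bool at trace time
        return x @ W + b
    return x @ W

linear_transform_compiled = jax.jit(linear_transform, static_argnames="affine")

x, W, b = jnp.ones((2, 3)), jnp.ones((3, 4)), jnp.ones(4)
y_affine = linear_transform_compiled(x, W, b, affine=True)   # traces/compiles the affine version
y_linear = linear_transform_compiled(x, W, b, affine=False)  # traces/compiles the linear version

# without static_argnames, `affine` is traced and `if affine:` cannot be evaluated
raised = False
try:
    jax.jit(linear_transform)(x, W, b, affine=True)
except jax.errors.ConcretizationTypeError:
    raised = True
```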

The need to explicitly specify which arguments are static is annoying in at least two cases:

  • compiling member functions of a class (the self argument is generally not a JAX array)
  • compiling functions whose internal logic is selected by a string argument

We would like to quickly and automatically convert all arguments of a jitted function that are not JAX-representable to static arguments.

from typing import Callable
from jax import jit, Array

class AutoJIT:
    def __init__(self, fn: Callable):
        self.fn = fn
        self.compiled_fn = None

    def __call__(self, *args, **kw):
        # the first time we see the arguments to the function,
        # we determine which arguments are `static` (this split is then fixed for later calls)
        if self.compiled_fn is None:
            # 1. for every positional argument, we check if it's a jax.Array; if not, it's static
            static_argnums = [i for (i, arg) in enumerate(args) if not isinstance(arg, Array)]
            # 2. similarly, every keyword argument that is not a jax.Array is made static
            static_argnames = [k for (k, v) in kw.items() if not isinstance(v, Array)]
            # 3. we jit the function with these static argnums and argnames
            #    (static arguments must be hashable, e.g., strings, numbers, objects; not NumPy arrays)
            self.compiled_fn = jit(
                self.fn, static_argnums=static_argnums, static_argnames=static_argnames
            )
        # we just call the compiled function
        return self.compiled_fn(*args, **kw)
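A usage sketch covering the two annoying cases listed above; the activation function and the Scaler class are hypothetical examples, and AutoJIT is repeated in condensed form so the snippet is self-contained:

```python
import jax.numpy as jnp
from jax import Array, jit

class AutoJIT:  # condensed copy of the class above
    def __init__(self, fn):
        self.fn, self.compiled_fn = fn, None

    def __call__(self, *args, **kw):
        if self.compiled_fn is None:
            static_argnums = [i for (i, a) in enumerate(args) if not isinstance(a, Array)]
            static_argnames = [k for (k, v) in kw.items() if not isinstance(v, Array)]
            self.compiled_fn = jit(
                self.fn, static_argnums=static_argnums, static_argnames=static_argnames
            )
        return self.compiled_fn(*args, **kw)

# case 1: internal logic selected by a string
def activation(x, kind: str = "relu"):
    if kind == "relu":
        return jnp.maximum(x, 0.0)
    if kind == "tanh":
        return jnp.tanh(x)
    raise ValueError(f"unknown activation {kind}")

act = AutoJIT(activation)
y_relu = act(jnp.array([-1.0, 2.0]), kind="relu")
y_tanh = act(jnp.array([-1.0, 2.0]), kind="tanh")  # same static pattern, new value: recompiles

# case 2: a member function, whose `self` argument is not a JAX array
class Scaler:  # hypothetical example class
    def __init__(self, factor: float):
        self.factor = factor

    def apply(self, x):
        return self.factor * x

scaled = AutoJIT(Scaler.apply)(Scaler(3.0), jnp.array([1.0, 2.0]))
```

Because the static/dynamic split is decided on the first call only, later calls should pass arrays in the same positions; a NumPy array or a list would be marked static and then fail, since static arguments must be hashable.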

Advanced example

We can make use of argument flattening to define a more advanced version of the automatic jit function, which also handles Pytree arguments (e.g., a configuration dictionary mixing arrays and strings):

from typing import Callable
from jax import jit, Array
from jax.tree_util import tree_flatten, tree_unflatten

class AdvancedAutoJIT:
    """The advanced version of AutoJIT, which flattens the arguments to handle pytree inputs."""
    def __init__(self, fn: Callable):
        self.fn = fn
        # we now keep a cache of compiled functions, one per argument structure and leaf types
        self.compiled_fn_cache = dict()

    def __call__(self, *args, **kw):
        # 1. first, we flatten the positional and keyword arguments into lists of leaves
        flat_args, args_struct = tree_flatten(args)
        flat_kw, kw_struct = tree_flatten(kw)
        args_types, kw_types = [type(a) for a in flat_args], [type(v) for v in flat_kw]
        # 2. we produce a "unique" identifier for the provided arguments: structure and leaf types
        cache_key = (args_struct, kw_struct, tuple(args_types), tuple(kw_types))
        # 3. underneath, we call the function with all arguments flattened into one tuple
        flat_args_kw = tuple(flat_args) + tuple(flat_kw)
        n_args = len(flat_args)
        if cache_key not in self.compiled_fn_cache:

            # 4. we produce a flat-argument version of the function
            def flat_fn(*flat_args_kw):
                # 5. we unflatten the positional and keyword arguments
                args_leaves, kw_leaves = flat_args_kw[:n_args], flat_args_kw[n_args:]
                args = tree_unflatten(args_struct, args_leaves)
                kw = tree_unflatten(kw_struct, kw_leaves)
                # 6. we call the original function with the unflattened arguments
                return self.fn(*args, **kw)

            # 7. using the flat-argument version, we determine which leaves need to be static
            #    (static leaves must be hashable, e.g., strings and numbers, not NumPy arrays)
            static_argnums = [
                i for (i, arg) in enumerate(flat_args_kw) if not isinstance(arg, Array)
            ]
            # 8. we compile the flat-argument version of the function
            self.compiled_fn_cache[cache_key] = jit(flat_fn, static_argnums=static_argnums)

        # 9. we call the compiled function with all flat arguments (positional and keyword)
        return self.compiled_fn_cache[cache_key](*flat_args_kw)
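A usage sketch, with a hypothetical model whose config dictionary mixes a string with a Python float (both end up static); the class is repeated in condensed form so that the snippet runs on its own:

```python
import jax.numpy as jnp
from jax import Array, jit
from jax.tree_util import tree_flatten, tree_unflatten

class AdvancedAutoJIT:  # condensed copy of the class above
    def __init__(self, fn):
        self.fn, self.compiled_fn_cache = fn, dict()

    def __call__(self, *args, **kw):
        flat_args, args_struct = tree_flatten(args)
        flat_kw, kw_struct = tree_flatten(kw)
        cache_key = (
            args_struct,
            kw_struct,
            tuple(type(a) for a in flat_args),
            tuple(type(v) for v in flat_kw),
        )
        flat_args_kw, n_args = tuple(flat_args) + tuple(flat_kw), len(flat_args)
        if cache_key not in self.compiled_fn_cache:

            def flat_fn(*flat_args_kw):
                args = tree_unflatten(args_struct, flat_args_kw[:n_args])
                kw = tree_unflatten(kw_struct, flat_args_kw[n_args:])
                return self.fn(*args, **kw)

            static_argnums = [i for (i, a) in enumerate(flat_args_kw) if not isinstance(a, Array)]
            self.compiled_fn_cache[cache_key] = jit(flat_fn, static_argnums=static_argnums)
        return self.compiled_fn_cache[cache_key](*flat_args_kw)

# a made-up model whose config mixes a string (static) with a float (static)
def model(params, x, config):
    h = x @ params["W"] + params["b"]
    if config["activation"] == "relu":
        h = jnp.maximum(h, 0.0)
    return config["scale"] * h

params = {"W": jnp.ones((3, 2)), "b": -4.0 * jnp.ones(2)}
x = jnp.ones((5, 3))
f = AdvancedAutoJIT(model)
y_relu = f(params, x, config={"activation": "relu", "scale": 2.0})
y_none = f(params, x, config={"activation": "none", "scale": 2.0})
```

Both calls share one cache entry (same structure and leaf types); jit itself then compiles a separate version for each distinct value of the static leaves ("relu" vs. "none").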

The PyTree section has been heavily inspired by the Equinox package (https://github.com/patrick-kidger/equinox), which pushes the PyTree philosophy in JAX even further.

Using random numbers in JAX

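Using random numbers in JAX requires explicitly providing a random number generator key, and the key uniquely determines the numbers generated: calling the same random function with the same key (and the same shape and dtype) always returns the same numbers.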

The two most important functions here are:

  • jax.random.PRNGKey(seed: int) - to generate the initial random key
  • jax.random.split(key, num: int = 2) - to generate num new keys from an existing key

Generally, the strategy for generating new random numbers is to:

  • split the global existing random key into n + 1 new keys
  • retain the first of the split keys as the new global key
  • use the n other keys for applications

import time
import jax

global_key = jax.random.PRNGKey(int(time.time()))

# 1. we have some data with a batch dimension of 100
data = jax.numpy.ones((100, 56))

# 2. we generate one random key per batch element, plus one to keep as the new global key
new_keys = jax.random.split(global_key, data.shape[0] + 1)
global_key, random_keys = new_keys[0], new_keys[1:]

# 3. we use the keys in the downstream application
#    (`fn` is a placeholder example that adds independent Gaussian noise to each batch element)
def fn(x, key):
    return x + jax.random.normal(key, x.shape)

noisy_data = jax.vmap(fn)(data, random_keys)
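Two properties make this strategy work: reusing a key reproduces exactly the same numbers, and keys obtained by splitting yield different streams. A small sketch (split_global is a hypothetical helper name for the "split into n + 1, keep the first" step):

```python
import jax
import jax.numpy as jnp

# the key fully determines the output: reusing it reproduces the same numbers
key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# hypothetical helper: split into n + 1 keys, keep the first as the new global key
def split_global(global_key, n):
    keys = jax.random.split(global_key, n + 1)
    return keys[0], keys[1:]

key, subkeys = split_global(key, 2)
c = jax.random.normal(subkeys[0], (3,))
d = jax.random.normal(subkeys[1], (3,))
```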

Pickling JAX and Persistent Compilation

In our applications, we found the need to serialize ("pickle" in Python) arbitrary code and JAX data. While the built-in pickle module in Python works well for data, we found that the cloudpickle package works much better for code, especially for lambdas and functions defined in-line, which are often useful when working with JAX functional transformations. (The standard pickle module serializes functions by reference, i.e., by module and name, so it cannot handle lambdas or locally defined functions; cloudpickle can serialize them by value.)

We recognize that sending arbitrary Python code to remote workers is not necessarily a common application, but with JAX's functional philosophy, functions are often as important as data, so a problem to compute can consist of both data and functions.

In this section of the tutorial, we cover two code serialization topics:

  1. how to achieve persistent compilation for long-compiling JAX routines (and avoid recompiling code sent to a remote worker)
  2. what to do if pickling fails on a JAX object

1. How to achieve persistent compilation for long-compiling JAX routines?

As far as we can tell, there's currently no convenient way to cache or save the result of a function compilation to disk so that it survives a program restart. Persistent compilation is particularly useful in research contexts, where development speed can be significantly hindered if function compilation takes upwards of 1 minute.

The only solution to persistent compilation we found was to use a persistent process (a separate program) that accepts Python functions, compiles each of them once, and subsequently evaluates them on the data it receives.

One simple possibility is to use the remote procedure call (RPC) Python package rpyc.

For the server code, we simply have:

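import rpyc
import jax
import cloudpickle as ser

class PersistentJAXFunctions(rpyc.Service):
    def __init__(self):
        super().__init__()
        # compiled functions, keyed by their serialized bytes; lives as long as the server process
        # (assumes an unchanged function serializes to identical bytes across client restarts)
        self.persistent_functions = dict()

    def exposed_call(self, fn_bytes: bytes, args_bytes: bytes, kwargs_bytes: bytes) -> bytes:
        # compile each distinct function only once
        if fn_bytes not in self.persistent_functions:
            self.persistent_functions[fn_bytes] = jax.jit(ser.loads(fn_bytes))
        args, kwargs = ser.loads(args_bytes), ser.loads(kwargs_bytes)
        out = self.persistent_functions[fn_bytes](*args, **kwargs)
        # serialize the result so that the client receives a value, not a remote reference
        return ser.dumps(out)

if __name__ == "__main__":
    from rpyc.utils.server import ThreadedServer

    # pass a service *instance* so that all client connections share the same cache
    ThreadedServer(PersistentJAXFunctions(), port=18861).start()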

and for the client:

import rpyc
import jax
import cloudpickle as ser

# a decorator to make a function persistent via calls to an rpyc server
def persistent_call(fn, port=18861):
    c = rpyc.connect("localhost", port)
    fn_bytes = ser.dumps(fn)

    def persistent_fn(*args, **kw):
        args_bytes, kw_bytes = ser.dumps(args), ser.dumps(kw)
        # the server returns the serialized result, which we deserialize locally
        return ser.loads(c.root.call(fn_bytes, args_bytes, kw_bytes))

    return persistent_fn

# the function is compiled (once) and evaluated in the server process
@persistent_call
def linear_transform(x, W, b):
    return x @ W + b

The obvious disadvantage of this solution is that it can easily be bottlenecked by data transfer: all arguments and results are serialized and sent over what is effectively a TCP connection, which can be slow for large arrays.
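For reference, newer JAX releases also ship a built-in persistent compilation cache that writes compiled executables to disk, which may remove the need for a separate server in some setups. A minimal sketch, assuming your installed JAX version supports the jax_compilation_cache_dir configuration option (backend support and defaults, such as a minimum compile time below which entries are not cached, vary between versions, so check the JAX documentation):

```python
import tempfile

import jax
import jax.numpy as jnp

# a directory that should survive program restarts (a temporary one here, for illustration only)
cache_dir = tempfile.mkdtemp()
jax.config.update("jax_compilation_cache_dir", cache_dir)

@jax.jit
def linear_transform(x, W, b):
    return x @ W + b

# the first compilation may be written to cache_dir and reused by later processes
y = linear_transform(jnp.ones((2, 3)), jnp.ones((3, 4)), jnp.ones(4))
```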

2. What to do if pickling fails on a JAX object?

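In Python, if an object fails to pickle, it is possible to define a custom serialization method using the built-in copyreg module, like so:

import copyreg

class P:
    def __init__(self, a = 1.0):
        self.a = a

# the constructor must itself be picklable by reference, so we define it at module level
# (the standard pickle module cannot serialize nested/local functions)
def custom_constructor(a):
    # multiplying by 2 only makes it visible that our custom constructor was used
    return P(a * 2)

def custom_pickling(p: P):
    # return a callable and a tuple of picklable arguments used to rebuild the object
    pickleable_objs_to_save = (p.a,)
    return custom_constructor, pickleable_objs_to_save

# register the custom pickling method for class P
copyreg.pickle(P, custom_pickling)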
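A sketch of how the same mechanism could apply to a JAX-flavored object: a hypothetical CompiledModel class holding a jitted function is pickled by saving only the plain Python function and re-jitting on load. (Whether a jitted function pickles directly may depend on the JAX version; this approach sidesteps the question.)

```python
import copyreg
import pickle

import jax
import jax.numpy as jnp

def square(x):
    return x**2

class CompiledModel:  # hypothetical class that keeps a jitted function around
    def __init__(self, fn):
        self.fn = fn
        self.compiled_fn = jax.jit(fn)  # the part we would rather not pickle

    def __call__(self, x):
        return self.compiled_fn(x)

def rebuild_compiled_model(fn):
    # re-jit on load; compilation happens again lazily on the first call
    return CompiledModel(fn)

def reduce_compiled_model(m: CompiledModel):
    # save only the plain Python function, which pickle stores by reference
    return rebuild_compiled_model, (m.fn,)

copyreg.pickle(CompiledModel, reduce_compiled_model)

model = CompiledModel(square)
restored = pickle.loads(pickle.dumps(model))
y = restored(jnp.array(3.0))
```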

Extending JAX

1. Without defining gradients (KDTree)

Let's consider the example of wrapping SciPy's KDTree object (scipy.spatial.KDTree) for nearest-neighbor search. This object primarily returns integer indices, so we need not worry about differentiating through it.

import numpy as np
import jax
import jax.numpy as jnp
from scipy.spatial import KDTree

class JAXKDTree(KDTree):
    def __init__(self, data, n_return: int = 10, **kw):
        self.jax_data = jnp.asarray(data)
        super().__init__(np.array(data), **kw)
        self.n_return = n_return

    def query(self, x):
        # this function returns a float distance and an int index, so it's easy to wrap
        # 1. we find the nearest neighbor's index with SciPy via a callback; the parent method is
        #    bound here because zero-argument super() does not work inside a lambda, and the input
        #    is wrapped in stop_gradient since pure_callback is not differentiable (an index needs no gradient)
        parent_query = super().query
        idx = jax.pure_callback(
            lambda x: np.asarray(parent_query(x)[1], dtype=np.int32),
            jax.ShapeDtypeStruct((), np.int32),
            jax.lax.stop_gradient(x),
        )
        # 2. we recompute the distance to the nearest neighbor in JAX to make it differentiable
        dist = jnp.linalg.norm(x - self.jax_data[idx, :])
        return dist, idx

    def _query_ball_point(self, x, r, **kw):
        # for JAX, we need to make sure that this function returns the same output shape every time
        # 1. we call the original function (x arrives as a NumPy array, r as a 0-d array)
        idx = super().query_ball_point(x, float(r), **kw)
        # 2. we sort the indices by distance
        dists = [np.linalg.norm(x - self.data[i, :]) for i in idx]
        idx = [idx[i] for i in np.argsort(dists)]
        # 3. we pad (with -1) or truncate the output to keep the fixed length of (self.n_return,)
        idx = idx[: self.n_return] + [-1 for _ in range(self.n_return - len(idx))]
        return np.array(idx, dtype=np.int32)

    def query_ball_point(self, x, r, **kw):
        # 1. for a pure_callback in JAX, we construct the function output shape and dtype
        #    (this can be a pytree of these shape + dtype structs); int32 avoids needing jax_enable_x64
        output_shape = jax.ShapeDtypeStruct((self.n_return,), np.int32)
        # 2. we wrap the original function with jax.pure_callback (indices need no gradient)
        return jax.pure_callback(
            lambda x, r: self._query_ball_point(x, r, **kw),
            output_shape,
            jax.lax.stop_gradient(x),
            r,
        )
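A usage sketch of the query method with made-up points (the class is condensed to query only so that the snippet is self-contained). The query point is wrapped in jax.lax.stop_gradient before entering the callback: pure_callback has no differentiation rule of its own and the integer index needs no gradient, while the distance is recomputed in JAX and stays differentiable.

```python
import numpy as np
import jax
import jax.numpy as jnp
from scipy.spatial import KDTree

class JAXKDTree(KDTree):  # condensed: only the `query` method from above
    def __init__(self, data, **kw):
        self.jax_data = jnp.asarray(data)
        super().__init__(np.array(data), **kw)

    def query(self, x):
        parent_query = super().query  # zero-argument super() would fail inside the lambda
        idx = jax.pure_callback(
            lambda x: np.asarray(parent_query(x)[1], dtype=np.int32),
            jax.ShapeDtypeStruct((), np.int32),
            jax.lax.stop_gradient(x),
        )
        dist = jnp.linalg.norm(x - self.jax_data[idx, :])
        return dist, idx

points = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
tree = JAXKDTree(points)
x = jnp.array([0.9, 0.1])

dist, idx = jax.jit(tree.query)(x)  # works under jit
# the distance is differentiable w.r.t. the query point x
grad_x = jax.grad(lambda x: tree.query(x)[0])(x)
```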

2. With gradients (Torch)

Let's wrap the PyTorch cross-entropy loss, torch.nn.functional.cross_entropy, which takes both a floating-point argument (the logits) and an integer argument (the class labels).

We first define the PyTorch interface as follows:

import torch
from jax.tree_util import tree_map, tree_flatten

# we define two NumPy-compatible PyTorch functions

def cross_entropy_loss_torch(x, y):
    # 1. arguments may come in as NumPy arrays (e.g., from jax.pure_callback), so we convert them to torch tensors
    numpy_mode = any([not isinstance(arg, torch.Tensor) for arg in [x, y]])
    x, y = tree_map(lambda arg: torch.as_tensor(arg), [x, y])
    # 2. we use PyTorch here directly; class-index targets must be int64 (torch.long)
    out = torch.nn.functional.cross_entropy(x, y.long())
    # 3. we convert the output back to a NumPy array, which is what JAX expects
    return out.cpu().numpy() if numpy_mode else out

def cel_vjp(args, gs):
    # 1. arguments may come in as NumPy arrays, so we convert them to torch tensors
    numpy_mode = any([not isinstance(arg, torch.Tensor) for arg in tree_flatten([args, gs])[0]])
    args, gs = tree_map(lambda arg: torch.as_tensor(arg), [args, gs])
    x, y = args
    # 2. we make use of the functional interface introduced in PyTorch 2.0
    _, vjp_fn = torch.func.vjp(lambda x: cross_entropy_loss_torch(x, y), x)
    dx = vjp_fn(gs)[0]  # only the x term matters, y is an integer and has no gradient
    # 3. we return only dx; the (absent) gradient for y is handled on the JAX side
    return dx.cpu().numpy() if numpy_mode else dx

Then, we define the JAX function and construct its custom vector-Jacobian product (VJP) rule:

import jax

def _cross_entropy_loss_callback(x, y):
    # the loss is a scalar with the same floating-point dtype as the logits x
    output_shape = jax.ShapeDtypeStruct((), x.dtype)
    return jax.pure_callback(cross_entropy_loss_torch, output_shape, x, y)

# time to define the JAX function with a custom reverse-mode gradient
@jax.custom_vjp
def cross_entropy_loss_jax(x, y):
    return _cross_entropy_loss_callback(x, y)

# 1. we define a fwd pass function which returns the output and the residuals (here, the arguments)
def fwd_fn(x, y):
    args = (x, y)
    return _cross_entropy_loss_callback(x, y), args

# 2. we define a backwards pass function that returns one sensitivity (cotangent) term per argument
def bwd_fn(args, gs):
    x, y = args
    output_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    dx = jax.pure_callback(cel_vjp, output_shape, args, gs)
    # y is an integer argument, so its cotangent is None
    return (dx, None)

cross_entropy_loss_jax.defvjp(fwd_fn, bwd_fn)
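The same custom_vjp + pure_callback pattern, sketched with NumPy standing in for PyTorch so that it runs without a PyTorch installation; sin_via_callback, np_sin and np_sin_vjp are illustrative names, and the backward rule uses d sin(x)/dx = cos(x):

```python
import numpy as np
import jax
import jax.numpy as jnp

# forward and backward computations done outside JAX, in NumPy
def np_sin(x):
    return np.sin(x).astype(x.dtype)

def np_sin_vjp(x, g):
    return (np.cos(x) * g).astype(x.dtype)

@jax.custom_vjp
def sin_via_callback(x):
    return jax.pure_callback(np_sin, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

def fwd(x):
    return sin_via_callback(x), x  # residual: the input

def bwd(x, g):
    dx = jax.pure_callback(np_sin_vjp, jax.ShapeDtypeStruct(x.shape, x.dtype), x, g)
    return (dx,)  # one cotangent per primal input

sin_via_callback.defvjp(fwd, bwd)

x = jnp.array([0.0, 1.0, 2.0])
y = jax.jit(sin_via_callback)(x)
g_x = jax.grad(lambda x: sin_via_callback(x).sum())(x)
```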

Compile-friendly control flow


JAX documentation provides a fairly good introduction to converting control flow into JAX-compatible operations.

In summary, there are two types of conditional branching operators:

  • jax.lax.select & jax.numpy.where compute both alternative values, but only return one (per entry)
    • because both paths are computed, NaNs (or infinities) arising in the unselected path can propagate into gradients computed with reverse-mode jax.grad
    • the JAX FAQ discusses this issue (gradients that contain NaN when using where)
    • where works elementwise on arrays, so it is batch friendly
  • jax.lax.cond only computes one of the two code paths, depending on the condition value
    • avoids unnecessary computation
    • will not propagate NaNs from the other path (as that path is not computed)
    • requires a single (scalar) boolean predicate, so a batch of conditions requires vmap; note that under vmap with a batched predicate, JAX converts cond into a select, so both paths are then computed

Note: Confusingly, there also exists jax.numpy.select (distinct from jax.lax.select), but we do not discuss it here.

z = 2
if c:
    val = a * z
else:
    val = b * z

# becomes
val = jax.numpy.where(c, a * z, b * z)  # or jax.lax.select(c, a * z, b * z), which does not broadcast
# or
val = jax.lax.cond(c, lambda: a * z, lambda: b * z)
# or
val = jax.lax.cond(c, lambda z: a * z, lambda z: b * z, z)  # argument version
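A small sketch of the NaN issue described above, using f(x) = sqrt(x) for x > 0 and 0 otherwise (a made-up example): where evaluates the sqrt branch at x = 0, whose infinite derivative turns the zero cotangent into NaN, while cond skips that branch. The third variant is the "double where" workaround described in the JAX FAQ: make the input of the unselected branch safe before computing it.

```python
import jax
import jax.numpy as jnp

# f(x) = sqrt(x) for x > 0, else 0
def f_where(x):
    return jnp.where(x > 0, jnp.sqrt(x), 0.0)

def f_cond(x):
    return jax.lax.cond(x > 0, jnp.sqrt, lambda x: 0.0 * x, x)

def f_double_where(x):
    # make the input of the sqrt branch safe where that branch is not selected
    safe_x = jnp.where(x > 0, x, 1.0)
    return jnp.where(x > 0, jnp.sqrt(safe_x), 0.0)

x0 = jnp.array(0.0)
g_where = jax.grad(f_where)(x0)                # nan: 0 (cotangent) * inf (d sqrt/dx at 0)
g_cond = jax.grad(f_cond)(x0)                  # 0.0: the sqrt branch is never differentiated
g_double_where = jax.grad(f_double_where)(x0)  # 0.0
```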


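JAX can trace Python loops, but there are good reasons to use a JAX loop operator instead of a Python loop:

  1. since JAX does not see a Python loop (it simply traces the computation), the loop is unrolled, which can produce a large number of operations in the resulting computation graph that take (much) longer to compile
  2. inside a jitted function, the number of iterations of a Python loop cannot depend on traced values (e.g., a while loop whose condition depends on a function argument)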
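To see the unrolling in practice, one can count the operations in the traced program with jax.make_jaxpr; in this sketch the loop body and the iteration count are arbitrary choices:

```python
import jax
import jax.numpy as jnp

def python_loop(x):
    for _ in range(100):  # unrolled during tracing
        x = jnp.sin(x)
    return x

def lax_loop(x):
    return jax.lax.fori_loop(0, 100, lambda i, x: jnp.sin(x), x)

# number of top-level operations (equations) in each traced program
n_python = len(jax.make_jaxpr(python_loop)(1.0).jaxpr.eqns)
n_lax = len(jax.make_jaxpr(lax_loop)(1.0).jaxpr.eqns)
print(n_python, n_lax)  # the unrolled version has one sin equation per iteration
```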

JAX provides 3 looping routines:

  • jax.lax.scan runs a fixed number of iterations sequentially and returns the full history
  • jax.lax.fori_loop runs a fixed number of iterations sequentially and returns the last value
  • jax.lax.while_loop runs a variable number of iterations and returns the last value

Notably, JAX is missing a version which runs a variable number of iterations and returns the history. This is likely because the use case is not common, and JAX generally struggles with representing a variable-length history, whose shape would not be known at compile time.

Loops in JAX accept PyTree arguments, so their use is actually not that complicated!

  1. jax.lax.scan returns the sequential mapping history for a fixed number of iterations
import jax, jax.numpy as jnp, jax.random as random

r = random.normal(random.PRNGKey(0), (10,))
scaled_fib = jnp.array([1.0, 1.0])
for i in range(r.size):
    next_num = r[i] * scaled_fib[-1] + scaled_fib[-2]
    scaled_fib = jnp.concatenate([scaled_fib, jnp.array([next_num])])
# becomes
def scan_fn(carry, multiplier):
    prev_num, cur_num = carry  # arguments can be PyTrees
    next_num = multiplier * cur_num + prev_num
    carry = (cur_num, next_num)
    return carry, next_num
prev_num, cur_num = jnp.array(1.0), jnp.array(1.0)
scaled_fib2 = jnp.concatenate(
    [
        jnp.array([prev_num, cur_num]),
        # we select the mapped (stacked per-iteration) outputs, not the final carry
        jax.lax.scan(scan_fn, (prev_num, cur_num), r)[1],
    ]
)
  2. jax.lax.fori_loop returns only the last value for a fixed number of iterations
# a simple ODE
import jax, jax.numpy as jnp
A = jnp.array([[1.0, 0.1], [0.0, 0.9]])
x0 = jnp.array([0.0, 1.0])

x = x0
for i in range(10):
    x = A @ x
# becomes (the loop body receives the iteration index i and the current value)
x = jax.lax.fori_loop(0, 10, lambda i, x: A @ x, x0)
  3. jax.lax.while_loop returns only the last value for a variable number of iterations
# a simple ODE, which we terminate once we exceed position x[0] = 0.5
# we want to count the number of iterations
import jax, jax.numpy as jnp
A = jnp.array([[1.0, 0.1], [0.0, 0.9]])
x0 = jnp.array([0.0, 1.0])

it, x = 0, x0
while x[0] < 0.5:
    x = A @ x
    it += 1 
print(f"it = {it}")
print(f"x = {x}")
# becomes
# we make use of PyTrees to propagate a tuple of iteration counter and the state x
cond_fn = lambda it_x: it_x[1][0] < 0.5
body_fn = lambda it_x: (it_x[0] + 1, A @ it_x[1])
print(jax.lax.while_loop(cond_fn, body_fn, (0, x0)))

Advanced loop example

What if we want to perform a computation a variable number of times while retaining the history? We can build our own loop by combining scan with cond: scan executes the step function for a maximum number of iterations (an upper bound on the number of iterations), while cond avoids unnecessary computation once the end condition is met by simply passing the state through unchanged.

# a simple ODE, which we terminate once we exceed position x[0] = 0.5
# we want to count the number of iterations
import jax, jax.numpy as jnp
A = jnp.array([[1.0, 0.1], [0.0, 0.9]])
x0 = jnp.array([0.0, 1.0])

it, x = 0, x0
while x[0] < 0.5:
    x = A @ x
    it += 1
print(f"it = {it}")
print(f"x = {x}")
# becomes
# we make use of PyTrees to propagate a tuple of iteration counter and the state x
def scan_fn(carry, _):
    it, x = carry
    # only advance x (and the counter) if the position is not yet reached
    it_x = jax.lax.cond(x[0] < 0.5, lambda x: (it + 1, A @ x), lambda x: (it, x), x)
    # the new carry is also emitted as this iteration's history entry
    return it_x, it_x

# at most 100 iterations (an upper bound); the history has one entry per iteration
(it, x), (it_history, x_history) = jax.lax.scan(scan_fn, (jnp.array(0), x0), None, length=100)
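The same idea can be packaged as a reusable helper; while_loop_with_history is a hypothetical name, and this is a sketch rather than a JAX API. It returns the iteration count, the final state and the padded history (entries after termination repeat the final state), which can be trimmed outside of jit once the count is known:

```python
import jax
import jax.numpy as jnp

def while_loop_with_history(cond_fn, body_fn, init, max_iters):
    """Hypothetical helper: like lax.while_loop, but also returns the padded history."""
    def scan_fn(carry, _):
        it, state = carry
        keep_going = cond_fn(state)
        # advance the state only while the condition holds, otherwise pass it through
        state = jax.lax.cond(keep_going, body_fn, lambda s: s, state)
        it = it + keep_going.astype(jnp.int32)
        return (it, state), state

    (n_iters, final_state), history = jax.lax.scan(
        scan_fn, (jnp.array(0, dtype=jnp.int32), init), None, length=max_iters
    )
    return n_iters, final_state, history

A = jnp.array([[1.0, 0.1], [0.0, 0.9]])
x0 = jnp.array([0.0, 1.0])
n_iters, x_final, x_history = jax.jit(
    while_loop_with_history, static_argnums=(0, 1, 3)
)(lambda x: x[0] < 0.5, lambda x: A @ x, x0, 100)

# outside of jit, the padded history can be trimmed to its valid part
x_valid = x_history[: int(n_iters)]
```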

NN libraries for JAX

Here, we quickly cover the three main machine learning (ML) libraries for JAX. These, in short, are:

  • equinox: represents ML building blocks using PyTrees; it defines custom-container rules to allow JAX tree_util routines to operate over the ML model object
  • haiku: developed by DeepMind, it closely resembles PyTorch
  • flax: the most popular library, developed by Google (like JAX)

In our work, we qualitatively evaluated these libraries on these main metrics:

  • compilation speed - how quickly the resulting ML models can be compiled with jax.jit
  • documentation quality
  • adoption - how widely adopted the library is
  • PyTorch likeness - how closely model building resembles PyTorch
  • state handling - how well the library allows handling state (e.g., batch norm mean and variance)
Library compilation speed documentation adoption PyTorch-like state handling
equinox x x ~
haiku ~
flax ~

Flax is undoubtedly the most widely adopted, and it recently gained a new module initialization interface that closely resembles PyTorch, making it easy for us to jump between PyTorch and JAX. However, we ran into long compilation times with Flax when working on an (admittedly very sprawling) DARTS network.

Equinox, despite its beautiful design philosophy, defines ML modules to accept single (unbatched) inputs by default, with batching handled by vmap. This philosophy, however, presents conceptual problems when working with layers that operate across the batch, like BatchNorm. The author provides support for this layer, but we currently find that support conceptually confusing.

For this reason, we are currently using Haiku!

The jfi package (friendly JAX interface)

In our research, and coming from PyTorch, we found it much easier to work with JAX after writing a library of utility functions for:

  • generating random numbers without specifying a generator key every time
    • by maintaining and updating a global generator key
  • placing generated arrays on the correct device and with the correct dtype
    • by wrapping generation routines (e.g., zeros, ones, randn) and using jax.device_put
  • creating arrays on the CPU by default, even when a GPU is available
    • by setting the correct JAX environment flags and wrapping the generation routines
  • accessing the JAX transformations grad, jit, vmap and numpy-like routines under one module
    • by copying the jax.numpy module and binding some jax and jax.random functions to it
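The first two ideas can be sketched in a few lines. This is a hypothetical illustration, not the actual jfi API: the names randn and zeros, the module-level key and the CPU default are assumptions made for this example.

```python
import jax
import jax.numpy as jnp

# hypothetical module-level defaults (NOT the actual jfi API)
_GLOBAL_KEY = jax.random.PRNGKey(0)
_DEFAULT_DEVICE = jax.devices("cpu")[0]
_DEFAULT_DTYPE = jnp.float32

def _next_key():
    # split the global key, keep one half as the new global key, hand out the other
    global _GLOBAL_KEY
    _GLOBAL_KEY, subkey = jax.random.split(_GLOBAL_KEY)
    return subkey

def randn(*shape, dtype=None, device=None):
    x = jax.random.normal(_next_key(), shape, dtype=dtype or _DEFAULT_DTYPE)
    return jax.device_put(x, device or _DEFAULT_DEVICE)

def zeros(*shape, dtype=None, device=None):
    x = jnp.zeros(shape, dtype=dtype or _DEFAULT_DTYPE)
    return jax.device_put(x, device or _DEFAULT_DEVICE)

a, b = randn(2, 3), randn(2, 3)  # different numbers, no key passed explicitly
```

These helpers rely on Python-side mutable state, so they are meant for eager code: inside a jitted function the key would be consumed once at trace time, and every call of the compiled function would reuse the same random numbers.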

Check out our work here: jfi-JAXFriendlyInterface.