JAX Advanced Techniques
Created by: Robert Dyro, Spencer Richards
Introduction
What is JAX?
JAX is a Python framework for:
- creating computational graphs (or simply routines) using a high-level language, Python
- manipulating those computation graphs/functions/routines to:
  - compile them into efficient computational code
  - take automatic derivatives of the created functions
  - parallelize functions over lots of data
  - run the same high-level computational functions on a variety of hardware: CPUs, GPUs, and TPUs
This tutorial is an attempt to produce an educational reference for JAX
techniques we found particularly interesting in our work and research. As such,
it is a collection of a few disjoint topics.
PyTrees
Since JAX is designed with function transformations in mind (e.g., compilation, derivatives,
parallelization), the first difficulty to arise is how to deal with Python
functions that can have multiple arguments, keyword vs. positional arguments,
nested dictionaries for configurations, and so on. The philosophy of JAX is simple:
any nesting of standard Python containers is allowed. It doesn't matter if your function
takes tuples, nested tuples, nested dictionaries, dictionaries of tuples, etc., as arguments.
The JAX name for any valid nested Python structure is a Pytree, a tree because
each container (e.g., tuple, list, dictionary) is a root with several child
nodes (the elements in the container).
There are many excellent JAX introductory resources for Pytrees, but the advanced
techniques using Pytrees we cover in this tutorial are the following:
- automatically determining which function arguments should be batched over in vmap
- how to compile a JAX function that takes arguments not supported by JAX
The three most important methods for manipulating trees in JAX are:
- tree_flatten for flattening a Pytree (nested Python container) into a single list
  of leaves, together with a description of the tree structure
- tree_unflatten for reforming a single list of leaves (e.g., once modified) into
  the original Pytree (nested Python container) structure
- tree_map for applying a function (typically a simple Python lambda) to every
  leaf of a Pytree (nested Python container)
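A minimal sketch of the three methods on a small nested structure; the names `problem`, `data`, `config`, and `scale` are illustrative only. Dictionary keys are visited in sorted order, so the leaves of "config" come before the leaf of "data".

```python
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_map, tree_unflatten

# a hypothetical nested problem: an array plus a tuple that contains a dictionary
problem = {"data": jnp.arange(3.0), "config": (2, {"scale": 0.5})}

# flatten into a list of leaves and a tree structure (dict keys are visited in sorted order)
leaves, treedef = tree_flatten(problem)  # leaves == [2, 0.5, jnp.arange(3.0)]

# rebuild the original structure from a modified list of leaves
rebuilt = tree_unflatten(treedef, [leaf * 2 for leaf in leaves])

# tree_map performs flatten -> apply -> unflatten in a single call
doubled = tree_map(lambda leaf: leaf * 2, problem)
```

Both `rebuilt` and `doubled` have the same nested dictionary/tuple structure as `problem`, with every leaf doubled.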
1. Automatically batching over arguments in vmap
When you're working on computational routines, you might have a large nested
Python dictionary representing your problem data and problem configuration
parameters. If you're trying to batch-compute a function over your problem, you
might need to batch over all the data, but none of the configuration parameters
(these are the same for each batch element). How can we automatically determine
which arguments should be batched over and which should not?
Note: by default, vmap batches over axis 0 of every argument; to leave some arguments (or individual Pytree leaves) unbatched, the user has to say so explicitly through in_axes, which can itself be a Pytree matching the arguments.
```python
from typing import Any, Callable, Dict

from jax import Array, vmap
from jax.tree_util import tree_map


def computational_routine(problem: Dict[str, Any]) -> Array:
    ...


def vmap_automatically(computation_fn: Callable, problem: Dict[str, Any], batch_size: int):
    # let's define a way to check whether x has a shape
    # (is some form of data; strings don't have a "shape" attribute)
    # and whether the first entry in the shape corresponds to the `batch_size` provided
    def is_batchable(x: Any, batch_size: int) -> bool:
        return hasattr(x, "shape") and len(x.shape) > 0 and x.shape[0] == batch_size

    # we're going to map over every leaf in the problem (recursively):
    # if the leaf has a first dimension of size batch_size, 0 is the batch axis,
    # otherwise, we put None to indicate that the leaf should not be batched over
    in_axes = tree_map(lambda x: 0 if is_batchable(x, batch_size) else None, problem)
    # vmap expects one in_axes entry per positional argument, hence the 1-tuple
    return vmap(computation_fn, in_axes=(in_axes,))(problem)
```
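To see the automatic in_axes in action, here is a small self-contained sketch; the routine and the keys of the `problem` dictionary are hypothetical. Only the array whose leading dimension equals batch_size is batched, while the 0-d array, the Python float, and the string are passed unchanged to every batch element.

```python
import jax.numpy as jnp
from jax import vmap
from jax.tree_util import tree_map


def vmap_automatically(computation_fn, problem, batch_size):
    # same logic as above: axis 0 for leaves whose leading dimension equals batch_size
    def axis_for(x):
        batchable = hasattr(x, "shape") and len(x.shape) > 0 and x.shape[0] == batch_size
        return 0 if batchable else None

    in_axes = tree_map(axis_for, problem)
    return vmap(computation_fn, in_axes=(in_axes,))(problem)


# hypothetical routine: a per-example scalar built from batched data and shared config
def computational_routine(problem):
    scale = problem["config"]["scale"]
    return scale * jnp.sum(problem["data"]) + problem["offset"]


problem = {
    "data": jnp.ones((8, 3)),  # batched: the leading dimension is 8
    "offset": jnp.array(1.0),  # 0-d array, not batched
    "config": {"scale": 2.0, "mode": "fast"},  # Python float and string, not batched
}
out = vmap_automatically(computational_routine, problem, batch_size=8)
# each of the 8 entries is 2.0 * 3 + 1.0
```

One caveat of this heuristic: a configuration array whose leading dimension happens to equal batch_size would also be batched over.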
2. How to compile a JAX function that takes arguments unsupported by JAX?
When calling jax.jit to indicate that we want a function compiled, there are
extra arguments, static_argnums (and static_argnames for identifying arguments
by name), for indicating which arguments might change the internal logic of the
function or are not supported by JAX. This way, JAX compiles a separate version
of the function for each combination of "static" argument values.
The simplest example of this is a boolean flag that changes the internal logic
of our function:
```python
def linear_transform(x, W, b, affine: bool = True):
    if affine:
        return x @ W + b
    else:
        return x @ W
```
Here, if we want JAX to properly handle the argument affine during compilation, we use
```python
import jax

linear_transform_compiled = jax.jit(linear_transform, static_argnames="affine")
```
which will cause JAX to (just-in-time) compile a separate version of the
function for each value of affine it encounters.
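A short sketch of the static argument in use, with arbitrary all-ones inputs. It also shows what happens without static_argnames: the boolean is traced as an array, and the Python `if` cannot be evaluated on it.

```python
import jax
import jax.numpy as jnp


def linear_transform(x, W, b, affine: bool = True):
    if affine:
        return x @ W + b
    else:
        return x @ W


linear_transform_compiled = jax.jit(linear_transform, static_argnames="affine")

x, W, b = jnp.ones((2, 3)), jnp.ones((3, 4)), jnp.ones((4,))
y_affine = linear_transform_compiled(x, W, b, affine=True)  # compiles the affine version
y_linear = linear_transform_compiled(x, W, b, affine=False)  # compiles a second version

# without static_argnames, `affine` is traced as an array and `if affine:` fails
try:
    jax.jit(linear_transform)(x, W, b, True)
    failed_without_static = False
except jax.errors.ConcretizationTypeError:
    failed_without_static = True
```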
The need to explicitly specify which arguments are static results in at least two annoying cases:
- compiling member functions of a class (where self is an argument JAX cannot represent)
- compiling functions whose internal logic is selected by a string
We would like to quickly and automatically convert all arguments of a jitted
function that are not JAX-representable into static arguments.
```python
from typing import Callable

from jax import Array, jit


class AutoJIT:
    def __init__(self, fn: Callable):
        self.fn = fn
        self.compiled_fn = None

    def __call__(self, *args, **kw):
        # the first time we see the arguments to the function,
        # we determine which arguments are `static`
        if self.compiled_fn is None:
            # 1. every positional argument that is not a jax.Array is static
            static_argnums = [i for (i, arg) in enumerate(args) if not isinstance(arg, Array)]
            # 2. similarly, every keyword argument that is not a jax.Array is static
            static_argnames = [k for (k, v) in kw.items() if not isinstance(v, Array)]
            # 3. we jit the function with the static argnums and argnames
            self.compiled_fn = jit(
                self.fn, static_argnums=static_argnums, static_argnames=static_argnames
            )
        # we just call the compiled function
        return self.compiled_fn(*args, **kw)
```
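A usage sketch for AutoJIT with a function whose logic is selected by a string; the `activation` function and its `kind` argument are hypothetical. The class is repeated in compact form so that the snippet runs on its own.

```python
from typing import Callable

import jax.numpy as jnp
from jax import Array, jit


class AutoJIT:
    # compact copy of the class above
    def __init__(self, fn: Callable):
        self.fn = fn
        self.compiled_fn = None

    def __call__(self, *args, **kw):
        if self.compiled_fn is None:
            static_argnums = [i for (i, a) in enumerate(args) if not isinstance(a, Array)]
            static_argnames = [k for (k, v) in kw.items() if not isinstance(v, Array)]
            self.compiled_fn = jit(
                self.fn, static_argnums=static_argnums, static_argnames=static_argnames
            )
        return self.compiled_fn(*args, **kw)


# hypothetical function whose internal logic is selected by a string
def activation(x, kind: str = "relu"):
    if kind == "relu":
        return jnp.maximum(x, 0.0)
    elif kind == "square":
        return x**2
    raise ValueError(f"unknown activation: {kind}")


act = AutoJIT(activation)
x = jnp.array([-1.0, 2.0])
y_relu = act(x, kind="relu")  # "relu" is not a jax.Array, so `kind` becomes static
y_square = act(x, kind="square")  # a new static value triggers a recompilation
```

Because the static arguments are decided on the first call only, later calls should pass arguments the same way (here, `kind` always by keyword). The advanced version that follows keys its cache on the argument structure instead.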
Advanced example
We can make use of argument flattening to define a more advanced version of the
automatic jit function, which also handles Pytree arguments (e.g., a dictionary
mixing arrays and strings):
```python
from typing import Callable

from jax import Array, jit
from jax.tree_util import tree_flatten, tree_map, tree_unflatten


class AdvancedAutoJIT:
    """The advanced version of AutoJIT, which flattens the arguments to handle pytree inputs."""

    def __init__(self, fn: Callable):
        self.fn = fn
        # we now keep a cache of compiled functions, one per combination of provided arguments
        self.compiled_fn_cache = dict()

    def __call__(self, *args, **kw):
        # 1. first, we flatten the positional and keyword arguments
        flat_args, args_struct = tree_flatten(args)
        flat_kw, kw_struct = tree_flatten(kw)
        args_types, kw_types = tree_map(type, flat_args), tree_map(type, flat_kw)
        # 2. we produce a "unique" identifier for the provided arguments: structure and types
        cache_key = (args_struct, kw_struct, tuple(args_types), tuple(kw_types))
        # 3. underneath, we call the function with all arguments flattened
        flat_args_kw = tuple(flat_args) + tuple(flat_kw)
        n_args = len(flat_args)
        if cache_key not in self.compiled_fn_cache:
            # 4. we produce a flat-argument version of the function
            def flat_fn(*flat_args_kw):
                # 5. we unflatten the positional and keyword arguments
                args = tree_unflatten(args_struct, flat_args_kw[:n_args])
                kw = tree_unflatten(kw_struct, flat_args_kw[n_args:])
                # 6. we call the function with the unflattened arguments
                return self.fn(*args, **kw)

            # 7. using the flat-argument version, we determine which args need to be static
            static_argnums = [
                i for (i, arg) in enumerate(flat_args_kw) if not isinstance(arg, Array)
            ]
            # 8. we compile the flat-argument version of the function
            self.compiled_fn_cache[cache_key] = jit(flat_fn, static_argnums=static_argnums)
        # 9. we call the compiled function with all flat (positional and keyword) arguments
        return self.compiled_fn_cache[cache_key](*flat_args_kw)
```
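A usage sketch with a Pytree argument; the `layer` function and its `params` dictionary are hypothetical. The class is repeated in compact form so that the snippet runs on its own. The string hidden inside the dictionary is found after flattening and made static, and keyword arguments work too.

```python
from typing import Callable

import jax.numpy as jnp
from jax import Array, jit
from jax.tree_util import tree_flatten, tree_unflatten


class AdvancedAutoJIT:
    # compact copy of the class above
    def __init__(self, fn: Callable):
        self.fn = fn
        self.compiled_fn_cache = dict()

    def __call__(self, *args, **kw):
        flat_args, args_struct = tree_flatten(args)
        flat_kw, kw_struct = tree_flatten(kw)
        cache_key = (
            args_struct,
            kw_struct,
            tuple(type(a) for a in flat_args),
            tuple(type(v) for v in flat_kw),
        )
        flat_args_kw = tuple(flat_args) + tuple(flat_kw)
        n_args = len(flat_args)
        if cache_key not in self.compiled_fn_cache:

            def flat_fn(*flat_args_kw):
                args = tree_unflatten(args_struct, flat_args_kw[:n_args])
                kw = tree_unflatten(kw_struct, flat_args_kw[n_args:])
                return self.fn(*args, **kw)

            static = [i for (i, a) in enumerate(flat_args_kw) if not isinstance(a, Array)]
            self.compiled_fn_cache[cache_key] = jit(flat_fn, static_argnums=static)
        return self.compiled_fn_cache[cache_key](*flat_args_kw)


# hypothetical layer whose parameters mix arrays with a string option
def layer(x, params):
    y = x @ params["W"] + params["b"]
    return jnp.tanh(y) if params["activation"] == "tanh" else y


layer_jit = AdvancedAutoJIT(layer)
x = jnp.ones((2, 3))
params = {"W": jnp.ones((3, 4)), "b": jnp.zeros((4,)), "activation": "tanh"}
out_tanh = layer_jit(x, params)  # "activation" is found inside the dict and made static
out_linear = layer_jit(x, {**params, "activation": "linear"})  # same cache entry, jit recompiles
out_kw = layer_jit(x, params=params)  # keyword arguments get their own cache entry
```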
The PyTree section has been heavily inspired by the Equinox package
(https://github.com/patrickkidger/equinox), which pushes the PyTree philosophy
in JAX even further.
Using random numbers in JAX
Using random numbers in JAX requires explicitly providing a random number
generator key, and the key uniquely determines the numbers generated.
The two most important methods here are:
- jax.random.PRNGKey(seed: int) to generate the initial random key
- jax.random.split(key, num: int = 2) to generate num new keys
Generally, the strategy for generating new random numbers is to:
- split the existing global random key into n + 1 new keys
- retain the first of the split keys as the new global key
- use the other n keys in the application
```python
import time

import jax
import jax.numpy as jnp

global_key = jax.random.PRNGKey(int(time.time()))


# a hypothetical per-example function that consumes one random key
def fn(x, key):
    return x + jax.random.normal(key, x.shape)


# 1. we have some data with a batch dimension of 100
data = jnp.ones((100, 56))
# 2. we generate one key per batch element, plus one to keep as the new global key
new_keys = jax.random.split(global_key, data.shape[0] + 1)
global_key, random_keys = new_keys[0], new_keys[1:]
# 3. we use the keys in the downstream application
out = jax.vmap(fn)(data, random_keys)
```
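A short sketch of the determinism this implies (the seed 0 and the shapes are arbitrary). Reusing a key reproduces exactly the same numbers, so fresh numbers require new keys from jax.random.split.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# the same key always produces the same numbers ...
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# ... so fresh numbers require splitting; keys[0] is kept as the new global key
keys = jax.random.split(key, 3)
global_key, subkeys = keys[0], keys[1:]
c = jax.random.normal(subkeys[0], (3,))
d = jax.random.normal(subkeys[1], (3,))
```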
Pickling JAX and Persistent Compilation
In our applications, we found the need to serialize ("pickle" in Python)
arbitrary code and JAX data. While the built-in pickle module in Python works
well for data, it serializes functions by reference (module and name), so it
cannot handle lambdas or inline-defined functions, which are often useful when
working with JAX functional transformations. For these, we found that
cloudpickle (https://github.com/cloudpipe/cloudpickle) works much better.
We recognize that sending arbitrary Python code to remote workers is not
necessarily a common application, but with JAX's functional philosophy, functions
are often as important as data, so a problem-to-compute can consist of both data
and functions.
In this section of the tutorial, we cover two code serialization topics:
- how to achieve persistent compilation for long-compiling JAX routines
- how to avoid code recompilation when sending code to remote workers
1. How to achieve persistent compilation for long-compiling JAX routines?
As far as we could tell at the time of writing, there was no way to cache or save
the result of a function compilation to disk so that it survives a program
restart. Persistent compilation is particularly useful in research contexts,
where development speed can be significantly hindered if function compilation
takes upwards of one minute.
The only solution to persistent compilation we found was to use a persistent
process (a separate program) that accepts Python functions, compiles them once,
and then computes on the data it receives.
One simple possibility is to use the remote procedure call Python package rpyc.
For the server code, we simply have:
```python
import cloudpickle as ser
import jax
import rpyc
from rpyc.utils.server import ThreadedServer


class PersistentJAXFunctions(rpyc.Service):
    def __init__(self):
        super().__init__()
        # compiled functions, keyed by their serialized bytes, live as long as the server
        self.persistent_functions = dict()

    def exposed_call(self, fn_bytes: bytes, args_bytes: bytes, kwargs_bytes: bytes) -> bytes:
        if fn_bytes not in self.persistent_functions:
            self.persistent_functions[fn_bytes] = jax.jit(ser.loads(fn_bytes))
        args, kwargs = ser.loads(args_bytes), ser.loads(kwargs_bytes)
        # return bytes, so the result is sent by value rather than as an rpyc proxy
        return ser.dumps(self.persistent_functions[fn_bytes](*args, **kwargs))


if __name__ == "__main__":
    # pass an instance (not the class), so that all client connections share one cache
    ThreadedServer(PersistentJAXFunctions(), port=18861).start()
```
and for the client:
```python
import cloudpickle as ser
import rpyc


# a decorator to make a function persistent via calls to an rpyc server
def persistent_call(fn, port=18861):
    c = rpyc.connect("localhost", port)
    fn_bytes = ser.dumps(fn)

    def persistent_fn(*args, **kw):
        args_bytes, kw_bytes = ser.dumps(args), ser.dumps(kw)
        # the server returns the serialized result, which we deserialize locally
        return ser.loads(c.root.call(fn_bytes, args_bytes, kw_bytes))

    return persistent_fn


@persistent_call
def linear_transform(x, W, b):
    return x @ W + b
```
The obvious disadvantage of this solution is that it can easily be bottlenecked
by what is effectively a TCP-based data transfer: all arguments and results are
serialized and sent over a socket on every call, which can be slow.
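Newer JAX releases also ship a built-in persistent compilation cache that writes compiled executables to disk, which may be worth trying before setting up a separate server. The following is a minimal sketch, assuming a JAX version recent enough to support both options below. The cache directory is an arbitrary example path, and backend support for the cache has varied across JAX versions.

```python
import jax
import jax.numpy as jnp

# assumption: a recent JAX release; set these before the first compilation in the program
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_compilation_cache")
# by default, only compilations slower than a threshold are cached; 0 caches everything
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)


@jax.jit
def linear_transform(x, W, b):
    return x @ W + b


# on supported backends, the compiled executable is stored on disk and can be
# reused by a later run of the program instead of being recompiled
y = linear_transform(jnp.ones((2, 3)), jnp.ones((3, 4)), jnp.ones((4,)))
```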
2. What to do if pickling fails on a JAX object?
In Python, if an object fails to pickle, it is possible to define a custom serialization method using the built-in copyreg module like so:
```python
import copyreg


class P:
    def __init__(self, a=1.0):
        self.a = a


# the constructor must be defined at module level so that pickle can find it by name
def custom_constructor(a):
    return P(a)


def custom_pickling(p: P):
    # return a callable and the (pickleable) arguments that rebuild the object on load
    pickleable_objs_to_save = (p.a,)
    return custom_constructor, pickleable_objs_to_save


# register the custom pickling method
copyreg.pickle(P, custom_pickling)
```
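As a sketch of how this applies to JAX, consider a hypothetical class that holds parameters together with a jitted lambda. The standard pickle module cannot serialize such a lambda by reference. The reduction function below saves only the parameters, as a NumPy array, and a module-level constructor rebuilds the object on load; the jitted function is recreated at that point and recompiles on its first call. The names `Model`, `rebuild_model`, and `reduce_model` are illustrative.

```python
import copyreg
import pickle

import jax
import jax.numpy as jnp
import numpy as np


class Model:
    # hypothetical class: parameters plus a compiled function created from a lambda
    def __init__(self, W):
        self.W = W
        self.apply = jax.jit(lambda x: x @ self.W)


def rebuild_model(W):
    # module-level constructor: recreates the jitted function after loading
    return Model(jnp.asarray(W))


def reduce_model(m: Model):
    # save only the parameters, as a plain NumPy array
    return rebuild_model, (np.asarray(m.W),)


copyreg.pickle(Model, reduce_model)

model = Model(jnp.eye(2))
restored = pickle.loads(pickle.dumps(model))
y = restored.apply(jnp.array([1.0, 2.0]))
```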
Extending JAX
1. Without defining gradients (KDTree)
Let's consider the example of wrapping SciPy's KDTree object for nearest-point
search. This object primarily returns integer indices, so we need not worry about
differentiability.
```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.spatial import KDTree


class JAXKDTree(KDTree):
    def __init__(self, data, n_return: int = 10, **kw):
        # a JAX copy of the data, so that it can be indexed with traced indices
        self.jax_data = jnp.asarray(data)
        super().__init__(np.array(data), **kw)
        self.n_return = n_return

    def query(self, x, **kw):
        # 1. the nearest-neighbor index is an integer, so we compute it in SciPy via a callback
        # (zero-argument super() does not work as intended inside a lambda, so we bind it here)
        parent_query = super().query
        idx = jax.pure_callback(
            lambda x: np.asarray(parent_query(x, **kw)[1], dtype=np.int32),
            jax.ShapeDtypeStruct((), jnp.int32),
            jax.lax.stop_gradient(x),  # callbacks have no derivative rule
        )
        # 2. we recompute the distance to the nearest neighbor in JAX to make it differentiable
        dist = jnp.linalg.norm(x - self.jax_data[idx, :])
        return dist, idx

    def _query_ball_point(self, x, r, **kw):
        # for JAX, we need to make sure that this function returns the same output shape every time
        # 1. we call the original function
        idx = super().query_ball_point(x, r, **kw)
        # 2. we sort the indices by distance
        dists = [np.linalg.norm(x - self.data[i, :]) for i in idx]
        idx = [idx[i] for i in np.argsort(dists)]
        # 3. we truncate, or pad with -1, the output to a fixed length of (self.n_return,)
        idx = idx[: self.n_return] + [-1 for _ in range(self.n_return - len(idx))]
        return np.asarray(idx, dtype=np.int32)

    def query_ball_point(self, x, r, **kw):
        # 1. for a pure_callback in JAX, we construct the function output shape and dtype
        # (this can be a pytree of these shape + dtype structs)
        output_shape = jax.ShapeDtypeStruct((self.n_return,), jnp.int32)
        # 2. we wrap the original function with jax.pure_callback
        return jax.pure_callback(
            lambda x, r: self._query_ball_point(x, r, **kw),
            output_shape,
            jax.lax.stop_gradient(x),
            r,
        )
```
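The same pattern can be checked without SciPy. The sketch below uses a hypothetical brute-force NumPy nearest-neighbor search as a stand-in for KDTree.query. The integer index comes from jax.pure_callback, and the distance is recomputed in JAX so that both jax.jit and jax.grad work. Passing jax.lax.stop_gradient(x) into the callback matters here, because pure callbacks have no differentiation rule and jax.grad would otherwise raise an error.

```python
import jax
import jax.numpy as jnp
import numpy as np

# hypothetical point set, kept both as NumPy (for the callback) and JAX (for gradients)
points = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]], dtype=np.float32)
jax_points = jnp.asarray(points)


def nearest_index_numpy(x):
    # runs on the host with NumPy inputs (a stand-in for scipy's KDTree.query)
    return np.asarray(np.argmin(np.linalg.norm(points - x, axis=-1)), dtype=np.int32)


def nearest_distance(x):
    # integer index from the callback; stop_gradient keeps jax.grad away from the callback
    idx = jax.pure_callback(
        nearest_index_numpy, jax.ShapeDtypeStruct((), jnp.int32), jax.lax.stop_gradient(x)
    )
    # the distance is recomputed in JAX, so it is differentiable with respect to x
    return jnp.linalg.norm(x - jax_points[idx])


x = jnp.array([0.9, 0.2])
dist = jax.jit(nearest_distance)(x)
grad_x = jax.grad(nearest_distance)(x)  # (x - nearest point) / distance
```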
2. With gradients (Torch)
Let's wrap the PyTorch cross-entropy loss function (torch.nn.functional.cross_entropy),
which takes both floating-point (logits) and integer (class label) arguments.
We first define the PyTorch interface as follows:
```python
import torch
from jax.tree_util import tree_flatten, tree_map


# we define two NumPy-compatible PyTorch functions
def cross_entropy_loss_torch(x, y):
    # 1. arguments may come in as NumPy arrays (e.g., from jax.pure_callback),
    # in which case we convert them to torch tensors
    numpy_mode = any([not isinstance(arg, torch.Tensor) for arg in [x, y]])
    x, y = tree_map(lambda arg: torch.as_tensor(arg), [x, y])
    # 2. we use PyTorch here directly (PyTorch expects int64 class targets, hence .long())
    out = torch.nn.functional.cross_entropy(x, y.long())
    # 3. we convert the output back to a NumPy array, which is what JAX expects
    return out.cpu().numpy() if numpy_mode else out


def cel_vjp(args, gs):
    # 1. arguments may come in as NumPy arrays, in which case we convert them to torch tensors
    numpy_mode = any([not isinstance(arg, torch.Tensor) for arg in tree_flatten([args, gs])[0]])
    args, gs = tree_map(lambda arg: torch.as_tensor(arg), [args, gs])
    x, y = args
    # 2. we make use of the functional interface introduced in PyTorch 2.0
    _, vjp_fn = torch.func.vjp(lambda x: cross_entropy_loss_torch(x, y), x)
    out = vjp_fn(gs)[0]  # only the x term matters, y is an integer and has no gradient
    dx = out.cpu().numpy() if numpy_mode else out
    dy = None
    return (dx, dy)
```
Then, we construct the custom vector-Jacobian product (VJP) rule in JAX:
```python
import jax


# time to define the JAX function with a reverse-mode gradient
@jax.custom_vjp
def cross_entropy_loss_jax(x, y):
    # the loss is a scalar with the same dtype as the logits x
    output_shape = jax.ShapeDtypeStruct((), x.dtype)
    return jax.pure_callback(cross_entropy_loss_torch, output_shape, x, y)


# 1. we define a forward pass function which returns the output and the arguments as a tuple
def fwd_fn(x, y):
    args = (x, y)
    return cross_entropy_loss_jax(x, y), args


# 2. we define a backward pass function that returns just the sensitivity (cotangent) terms
def bwd_fn(args, gs):
    x, _ = args
    output_shape = (jax.ShapeDtypeStruct(x.shape, x.dtype), None)
    return jax.pure_callback(cel_vjp, output_shape, args, gs)


cross_entropy_loss_jax.defvjp(fwd_fn, bwd_fn)
```
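The same custom_vjp plus pure_callback structure can be tried without PyTorch. In the sketch below, a NumPy softplus and its hand-written VJP are hypothetical stand-ins for the PyTorch routines. The gradient obtained through the callbacks is compared against JAX's own jax.nn.softplus.

```python
import jax
import jax.numpy as jnp
import numpy as np


def softplus_np(x):
    # host-side NumPy routine (a stand-in for the PyTorch function)
    return np.logaddexp(0.0, x).astype(x.dtype)


def softplus_vjp_np(x, g):
    # d softplus(x) / dx = sigmoid(x), multiplied by the incoming cotangent g
    return (g / (1.0 + np.exp(-x))).astype(x.dtype)


@jax.custom_vjp
def softplus(x):
    return jax.pure_callback(softplus_np, jax.ShapeDtypeStruct(x.shape, x.dtype), x)


def fwd_fn(x):
    return softplus(x), x  # keep x as the residual for the backward pass


def bwd_fn(x, g):
    dx = jax.pure_callback(softplus_vjp_np, jax.ShapeDtypeStruct(x.shape, x.dtype), x, g)
    return (dx,)  # one cotangent per primal argument


softplus.defvjp(fwd_fn, bwd_fn)

x = jnp.array([-1.0, 0.0, 2.0])
grad_callback = jax.jit(jax.grad(lambda x: softplus(x).sum()))(x)
grad_reference = jax.grad(lambda x: jax.nn.softplus(x).sum())(x)
```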
Compile-friendly control flow
If-statements
The JAX documentation provides a fairly good introduction to converting control flow
into JAX-compatible operations.
In summary, there are two types of conditional branching operators:
- jax.lax.select & jax.numpy.where compute both alternative values, but return only one
  - because both paths are computed, NaNs from the unselected path can propagate in
    reverse-mode jax.grad (this issue is discussed in the JAX documentation's FAQ)
  - where works elementwise on an array and is batch-friendly
- jax.cond computes only one of the two code paths, depending on the condition value
  - avoids unnecessary computation
  - will not propagate NaNs from the other path (as that path is not computed)
  - requires a single boolean predicate, meaning it has to be manually batched
    (vmap-ed) over a batch of data conditions; under vmap, a batched predicate turns
    cond into a select, so both branches are then computed
Note: confusingly, there is also jax.numpy.select (distinct from jax.lax.select), but we do not discuss it here.
```python
import jax
import jax.numpy as jnp

a, b, c = jnp.array(1.0), jnp.array(3.0), jnp.array(True)
z = 2
if c:
    val = a * z
else:
    val = b * z
# becomes
val = jax.numpy.where(c, a * z, b * z)  # or jax.lax.select(c, a * z, b * z)
# or
val = jax.cond(c, lambda: a * z, lambda: b * z)
# or
val = jax.cond(c, lambda z: a * z, lambda z: b * z, z)  # argument version
```
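A small sketch of the NaN issue described above, using a hypothetical "safe log" that returns 0 for non-positive inputs. With where, the gradient at 0 becomes NaN: the log branch receives a zero cotangent, and its derivative 0 / x at x = 0 is 0 / 0. With cond, only the branch that was taken is differentiated.

```python
import jax
import jax.numpy as jnp


def safe_log_where(x):
    # both branches are computed: log(0) = -inf is evaluated even though it is not selected
    return jnp.where(x > 0, jnp.log(x), 0.0)


def safe_log_cond(x):
    # only the selected branch is computed and differentiated (x is a scalar here)
    return jax.lax.cond(x > 0, jnp.log, jnp.zeros_like, x)


grad_where = jax.grad(safe_log_where)(0.0)  # NaN
grad_cond = jax.grad(safe_log_cond)(0.0)  # 0.0
```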
Loops
JAX can trace Python loops, but the main justification for using a JAX loop-like
operator instead of a Python loop is:
- since JAX does not see a Python loop and simply traces the computation, the
  unrolled computation can lead to a large number of operations in the resulting
  computation graph, which takes (much) longer to compile
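One way to observe the unrolling is to count the operations in the traced program with jax.make_jaxpr. In this sketch (the function names are illustrative), a 100-step Python loop contributes one operation per step, while fori_loop is traced into a loop primitive whose body is traced only once.

```python
import jax
import jax.numpy as jnp


def python_loop(x):
    for _ in range(100):
        x = jnp.sin(x)  # traced 100 times: the loop is unrolled into the graph
    return x


def fori(x):
    return jax.lax.fori_loop(0, 100, lambda i, x: jnp.sin(x), x)


n_python = len(jax.make_jaxpr(python_loop)(1.0).jaxpr.eqns)
n_fori = len(jax.make_jaxpr(fori)(1.0).jaxpr.eqns)
```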
JAX provides three looping routines:
- jax.lax.scan runs a fixed number of iterations sequentially and returns the entire history
- jax.lax.fori_loop runs a fixed number of iterations sequentially and returns the final value
- jax.lax.while_loop runs a variable number of iterations and returns the final value
Notably, JAX lacks a routine that runs a variable number of iterations and
returns the history. This is likely because the use case is uncommon and because
JAX needs array shapes to be known at compile time, while a variable-length
history would have a data-dependent shape.
Loops in JAX accept PyTree arguments, so their use is actually not that complicated!
jax.lax.scan returns the sequential mapping history for a fixed number of iterations
```python
import jax
import jax.numpy as jnp
import jax.random as random

r = random.normal(random.PRNGKey(0), (10,))
scaled_fib = jnp.array([1.0, 1.0])
for i in range(r.size):
    next_num = r[i] * scaled_fib[-1] + scaled_fib[-2]
    scaled_fib = jnp.concatenate([scaled_fib, jnp.array([next_num])])


# becomes
def scan_fn(carry, multiplier):
    prev_num, cur_num = carry  # arguments can be PyTrees
    next_num = multiplier * cur_num + prev_num
    carry = (cur_num, next_num)
    return carry, next_num


prev_num, cur_num = jnp.array(1.0), jnp.array(1.0)
scaled_fib2 = jnp.concatenate(
    [
        jnp.array([prev_num, cur_num]),
        # we select the mapped array (the history), not the final carry
        jax.lax.scan(scan_fn, (prev_num, cur_num), r)[1],
    ]
)
```
jax.lax.fori_loop returns only the last value for a fixed number of iterations
```python
# a simple ODE
import jax
import jax.numpy as jnp

A = jnp.array([[1.0, 0.1], [0.0, 0.9]])
x0 = jnp.array([0.0, 1.0])
x = x0
for i in range(10):
    x = A @ x
# becomes
x = jax.lax.fori_loop(0, 10, lambda i, x: A @ x, x0)
```
jax.lax.while_loop returns only the last value for a variable number of iterations
```python
# a simple ODE, which we terminate once we exceed position x[0] = 0.5
# we want to count the number of iterations
import jax
import jax.numpy as jnp

A = jnp.array([[1.0, 0.1], [0.0, 0.9]])
x0 = jnp.array([0.0, 1.0])
it, x = 0, x0
while x[0] < 0.5:
    x = A @ x
    it += 1
print(f"it = {it}")
print(f"x = {x}")
# becomes
# we make use of PyTrees to propagate a tuple of iteration counter and the state x
cond_fn = lambda it_x: it_x[1][0] < 0.5
body_fn = lambda it_x: (it_x[0] + 1, A @ it_x[1])
print(jax.lax.while_loop(cond_fn, body_fn, (0, x0)))
```
Advanced loop example
What if we want to perform a computation a variable number of times while
retaining the history? We can make our own loop by combining scan with cond:
scan executes a function for a maximum number of iterations (an upper bound on
the number of iterations), while cond avoids unnecessary computation once the
end condition is met.
```python
# a simple ODE, which we terminate once we exceed position x[0] = 0.5
# we want to count the number of iterations
import jax
import jax.numpy as jnp

A = jnp.array([[1.0, 0.1], [0.0, 0.9]])
x0 = jnp.array([0.0, 1.0])
it, x = 0, x0
while x[0] < 0.5:
    x = A @ x
    it += 1
print(f"it = {it}")
print(f"x = {x}")


# becomes
# we make use of PyTrees to propagate a tuple of iteration counter and the state x
def scan_fn(carry, _):
    it, x = carry
    # only advance x if the position is not yet reached
    it_x = jax.lax.cond(x[0] < 0.5, lambda x: (it + 1, A @ x), lambda x: (it, x), x)
    return it_x, it_x


# at most 100 iterations; once the condition is met, the history repeats the final state
(it, x), (it_history, x_history) = jax.lax.scan(scan_fn, (0, x0), None, length=100)
```
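Since the scan history keeps repeating the final state once the condition is met, here is a sketch of how it might be trimmed outside of jit, where the final counter is a concrete value. It also cross-checks the result against jax.lax.while_loop. The system is the same as above; recording only x in the history is an arbitrary choice.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[1.0, 0.1], [0.0, 0.9]])
x0 = jnp.array([0.0, 1.0])


def scan_fn(carry, _):
    it, x = carry
    it_x = jax.lax.cond(x[0] < 0.5, lambda x: (it + 1, A @ x), lambda x: (it, x), x)
    return it_x, it_x[1]  # record only the state x in the history


(it, x), x_history = jax.lax.scan(scan_fn, (0, x0), None, length=100)
x_history = x_history[: int(it)]  # keep only the iterations that actually advanced x

cond_fn = lambda it_x: it_x[1][0] < 0.5
body_fn = lambda it_x: (it_x[0] + 1, A @ it_x[1])
it_ref, x_ref = jax.lax.while_loop(cond_fn, body_fn, (0, x0))
```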
NN libraries for JAX
Here, we quickly cover the three main machine learning (ML) libraries
for JAX. These, in short, are:
| Library | Comments |
|---|---|
| equinox | represents ML building blocks using PyTrees; it defines custom-container rules to allow JAX tree_util routines to operate over the ML model object |
| haiku | developed by DeepMind, it closely resembles PyTorch |
| flax | the most popular library, developed by Google (like JAX) |
In our work, we qualitatively evaluated these libraries on the following main metrics:
- compilation speed: how quickly the resulting ML models can be compiled with jax.jit
- documentation quality
- adoption: how widely adopted the library is
- PyTorch-likeness: how closely model building resembles PyTorch
- state handling: how well the library allows handling state (e.g., batch norm mean and variance)
| Library | compilation speed | documentation | adoption | PyTorch-like | state handling |
|---|---|---|---|---|---|
| equinox | ✓ | ✓ | x | x | ~ |
| haiku | ✓ | ✓ | ~ | ✓ | ✓ |
| flax | ~ | ✓ | ✓ | ✓ | ✓ |
Flax is undoubtedly the most widely adopted, and it recently added a new module
initialization interface that closely resembles PyTorch, making it easy for us
to jump between PyTorch and JAX. However, we ran into long compilation times
with Flax when working on an (admittedly very sprawling) DARTS network.
Equinox, despite its beautiful design philosophy, defines ML modules to accept
single (unbatched) examples by default, with support for batches through vmap.
This philosophy, however, presents conceptual problems when working with layers
that operate on a whole batch, like BatchNorm. The author does provide support
for this layer, but the support is currently conceptually confusing to us.
For this reason, we are currently using Haiku!
The jfi package (friendly JAX interface)
In our research, and coming from PyTorch, we found it much easier to work with
JAX after writing a library of utility functions for:
- generating random numbers without specifying a generator key every time
  - by maintaining and updating a global generator key
- placing the generated arrays on the correct device and with the correct dtype
  - by wrapping generation routines (e.g., zeros, ones, randn) and using jax.device_put
- creating arrays on the CPU by default, even when a GPU is available
  - by setting the correct JAX environment flags and wrapping the generation routines
- having access to the JAX transformations grad, jit, vmap and NumPy-like routines under one module
  - by copying the jax.numpy module and binding some jax and jax.random functions to it
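Here is a minimal sketch of the first two ideas, a global generator key and device placement, written from scratch for illustration. `randn` and `_global_key` are hypothetical names, not the jfi package's actual API.

```python
import jax
import jax.numpy as jnp

# hypothetical sketch, not the jfi API: a module-level global key, split on every call
_global_key = jax.random.PRNGKey(0)


def randn(*shape, dtype=jnp.float32, device=None):
    global _global_key
    # keep the first split key as the new global key and use the second one here
    _global_key, subkey = jax.random.split(_global_key)
    x = jax.random.normal(subkey, shape, dtype=dtype)
    # optionally place the array on a given device, e.g. jax.devices("cpu")[0]
    return x if device is None else jax.device_put(x, device)


a, b = randn(3), randn(3)  # different numbers without passing a key explicitly
c = randn(2, 2, device=jax.devices("cpu")[0])
```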
Check out our work here:
jfiJAXFriendlyInterface.