Defining gradients automatically: support for AutoDiff
torch2jax_with_vjp
torch2jax.gradients.torch2jax_with_vjp(torch_fn, *example_args, depth=2, nondiff_argnums=None, nondiff_mask=None, output_shapes=None, use_zeros=True, use_torch_vjp=True, output_sharding_spec=None, vmap_method='sequential')
Convert a torch function to a JAX function and define a custom VJP rule for it, applied recursively up to depth levels deep.
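To make "recursively up to depth levels deep" concrete, here is a minimal hand-written sketch of a depth-2 rule in plain JAX. This is not torch2jax code. It rests on several assumptions: NumPy stands in for torch and is run through jax.pure_callback, which JAX cannot differentiate through; the derivatives of sin are written out by hand, whereas the library presumably derives each level from PyTorch; and host_call, foreign_sin and sin_vjp are hypothetical names.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_call(fn, *args):
    """Evaluate `fn` on NumPy arrays outside of JAX (a stand-in for calling torch).

    JAX treats the callback as opaque and cannot differentiate through it, so every
    derivative we need must come from a custom VJP rule.
    """
    spec = jax.ShapeDtypeStruct(args[0].shape, args[0].dtype)
    return jax.pure_callback(
        lambda *a: np.asarray(fn(*a), dtype=spec.dtype), spec, *args
    )


# Level 2: the VJP of sin, (x, g) -> g * cos(x), itself given a custom VJP.
@jax.custom_vjp
def sin_vjp(x, g):
    return host_call(lambda x, g: g * np.cos(x), x, g)


def sin_vjp_fwd(x, g):
    return sin_vjp(x, g), (x, g)


def sin_vjp_bwd(res, ct):
    x, g = res
    # Partial derivatives of g * cos(x) w.r.t. x and g, again computed on the host.
    dx = host_call(lambda x, g, ct: -ct * g * np.sin(x), x, g, ct)
    dg = host_call(lambda x, ct: ct * np.cos(x), x, ct)
    return dx, dg


sin_vjp.defvjp(sin_vjp_fwd, sin_vjp_bwd)


# Level 1: the "foreign" sin, whose backward pass calls the level-2 function.
@jax.custom_vjp
def foreign_sin(x):
    return host_call(np.sin, x)


def foreign_sin_fwd(x):
    return foreign_sin(x), x  # keep x as the residual for the backward pass


def foreign_sin_bwd(x, g):
    return (sin_vjp(x, g),)


foreign_sin.defvjp(foreign_sin_fwd, foreign_sin_bwd)

x = jnp.float32(0.5)
print(jax.grad(foreign_sin)(x))            # first derivative: cos(0.5)
print(jax.grad(jax.grad(foreign_sin))(x))  # second derivative: -sin(0.5)
```

If foreign_sin_bwd called host_call directly instead of sin_vjp, the first jax.grad would still work. The second would fail, though, because JAX has no derivative for the callback. That is the limit the depth parameter controls.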
Parameters:
Name | Type | Description | Default |
---|---|---|---|
torch_fn | Callable | Torch function to convert. | required |
*example_args | Any | Example arguments as tensors or torch-compatible args. | () |
depth | int | Maximum allowed differentiation depth; a higher value is cheap. Defaults to 2. | 2 |
nondiff_argnums | list \| tuple \| None | Which (whole) arguments not to differentiate. Defaults to None. | None |
nondiff_mask | Any \| None | Mask matching the full argument structure. Defaults to None. | None |
output_shapes | Any \| None | Output shapes of the function. If provided, the torch function is never called to infer them. Defaults to None. | None |
use_zeros | bool | Whether to set the gradients of non-differentiable arguments to zeros or to None. None does not currently appear to work with JAX. Defaults to True. | True |
use_torch_vjp | bool | (Not supported, please use inside | True |
output_sharding_spec | PartitionSpec \| None | (Not supported) Sharding spec of the output; use shard_map instead for a device-local version of this function. | None |
vmap_method | str | Batching method; see https://docs.jax.dev/en/latest/ffi.html#batching-with-vmap for the options. NOTE: only vmap_method="sequential" is supported non-experimentally. NOTE: try "expand_dims" or "broadcast_all" if you want to experiment with PyTorch-side batching. | 'sequential' |
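The nondiff_argnums and use_zeros rows can be illustrated with a short plain-JAX sketch. It is not torch2jax code: scaled_square is a hypothetical function, its VJP is written by hand with jax.custom_vjp, and we assume that marking an argument non-differentiable simply means its cotangent is reported as zeros. This explains why jax.grad with respect to such an argument returns zeros rather than its mathematical gradient.

```python
import jax
import jax.numpy as jnp


# Hypothetical f(x, scale) = scale * x**2, where `scale` plays the role of an argument
# listed in nondiff_argnums: its true gradient (x**2) is never computed.
@jax.custom_vjp
def scaled_square(x, scale):
    return scale * x**2


def scaled_square_fwd(x, scale):
    return scaled_square(x, scale), (x, scale)


def scaled_square_bwd(res, g):
    x, scale = res
    dx = g * 2.0 * scale * x        # ordinary gradient for the differentiable argument
    dscale = jnp.zeros_like(scale)  # zeros for the non-diff argument (cf. use_zeros=True)
    return dx, dscale


scaled_square.defvjp(scaled_square_fwd, scaled_square_bwd)

dx, dscale = jax.grad(scaled_square, argnums=(0, 1))(jnp.float32(3.0), jnp.float32(2.0))
print(dx, dscale)  # 12.0 0.0
```

Returning explicit zeros keeps the cotangent structure identical to the argument structure. Per the use_zeros row, returning None instead does not currently appear to work with JAX.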
Returns:
Callable: JIT-compatible JAX version of the torch function (VJP defined up to the requested depth).
Examples:
>>> import torch, jax
>>> from torch2jax import torch2jax_with_vjp, tree_t2j
>>> # let's define the torch function and create some example arguments
>>> torch_fn = lambda x, y: torch.nn.CrossEntropyLoss()(x, y)
>>> xt, yt = torch.randn(10, 5), torch.randint(0, 5, (10,))
>>> # we can now convert the function to jax using the torch fn and example args
>>> jax_fn = torch2jax_with_vjp(torch_fn, xt, yt)
>>> jax_fn = jax.jit(jax_fn) # we can jit it too
>>> # let's convert the arguments to JAX arrays and call the function
>>> x, y = tree_t2j((xt, yt))
>>> jax_fn(x, y).shape
()
>>> # it works!
>>> # taking gradients is easy too
>>> g_fn = jax.grad(jax_fn, argnums=0)
>>> g_fn(x, y).shape
(10, 5)
>>> # creating a more complicated computational graph is of course possible
>>> lin_model = lambda z, W, b: z @ W + b
>>> z, W, b = tree_t2j([torch.randn((10, 20)), torch.randn(20, 5), torch.randn(5)])
>>> gz_fn = jax.grad(lambda z, W, b: jax_fn(lin_model(z, W, b), y), argnums=(1, 2))
>>> dW, db = gz_fn(z, W, b)
>>> dW.shape, db.shape
((20, 5), (5,))
Source code in torch2jax/gradients.py