
Converting a PyTorch Function to JAX (without gradients)

torch2jax

torch2jax.api.torch2jax(fn, *example_args, example_kw=None, example_kwargs=None, output_shapes=None, input_struct=None, use_torch_vmap=True)

Define a jit-compatible JAX function that calls a PyTorch function. Arbitrary nesting of arguments and outputs is supported.
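Arbitrary nesting means the inputs and outputs may be pytrees (e.g. dicts or tuples of tensors); the structure is inferred from the example arguments and reapplied at call time. A minimal sketch of this:

>>> import torch
>>> from torch2jax import torch2jax, tree_t2j
>>> torch_fn = lambda d: {"sum": d["a"] + d["b"], "prod": d["a"] * d["b"]}
>>> example = {"a": torch.randn(3), "b": torch.randn(3)}
>>> jax_fn = torch2jax(torch_fn, example)  # nested inputs and outputs are flattened internally
>>> out = jax_fn(tree_t2j(example))
>>> sorted(out.keys())
['prod', 'sum']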

Parameters:

fn (Callable): PyTorch function to wrap. Required.

*example_args (Any): Example arguments as tensors or torch-compatible args. Defaults to ().

example_kw (Any | None): Example keyword arguments. Defaults to None.

example_kwargs (Any | None): Alias for example_kw; provide at most one of the two. Defaults to None.

output_shapes (Any): Output shapes, or shape-and-dtype structs (e.g. jax.ShapeDtypeStruct). Defaults to None.

input_struct (PyTreeDef | None): Input structure, which can be inferred from the example arguments and keywords. Defaults to None.

use_torch_vmap (bool): Whether to batch using torch.vmap or a dumb loop. Defaults to True.

Returns:

Callable: A JIT-compatible JAX function.

Examples:

>>> import torch, jax
>>> from torch2jax import torch2jax, tree_t2j
>>> # let's define the torch function and create some example arguments
>>> torch_fn = lambda x, y: torch.nn.CrossEntropyLoss()(x, y)
>>> xt, yt = torch.randn(10, 5), torch.randint(0, 5, (10,))
>>> # we can now convert the function to jax using the torch fn and example args
>>> jax_fn = torch2jax(torch_fn, xt, yt)
>>> jax_fn = jax.jit(jax_fn) # we can jit it too
>>> # let's convert the arguments to JAX arrays and call the function
>>> x, y = tree_t2j((xt, yt))
>>> jax_fn(x, y)
>>> # it works!
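If running the torch function on the example arguments at conversion time is undesirable (for instance, it is expensive), output_shapes can be passed explicitly; per the parameter list above it accepts jax.ShapeDtypeStruct entries, in which case fn is not called during conversion. A hedged sketch:

>>> import torch, jax
>>> import jax.numpy as jnp
>>> from torch2jax import torch2jax
>>> torch_fn = lambda x: torch.nn.functional.relu(x)
>>> xt = torch.randn(10, 5)
>>> out_spec = jax.ShapeDtypeStruct((10, 5), jnp.float32)
>>> jax_fn = torch2jax(torch_fn, xt, output_shapes=out_spec)  # no forward pass happens here
>>> jax_fn(jnp.ones((10, 5)))  # output shape and dtype follow out_spec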
Source code in torch2jax/api.py
def torch2jax(
    fn: Callable,
    *example_args: Any,
    example_kw: Any | None = None,
    example_kwargs: Any | None = None,
    output_shapes: Any = None,
    input_struct: PyTreeDef | None = None,
    use_torch_vmap: bool = True,
) -> Callable:
    """Define a jit-compatible JAX function that calls a PyTorch function.  Arbitrary nesting of
    arguments and outputs is supported.

    Args:
        fn (Callable): PyTorch function to wrap.
        *example_args (Any): Example arguments as tensors or torch-compatible args.
        example_kw (Any | None, optional): Example keyword arguments. Defaults to None.
        example_kwargs (Any | None, optional): Alias for example_kw; provide at most one of the
                                               two. Defaults to None.
        output_shapes (Any, optional): Output shapes, or shape-and-dtype structs (e.g.
                                       jax.ShapeDtypeStruct). Defaults to None.
        input_struct (PyTreeDef | None, optional): Input structure, which can be inferred from
                                                   example arguments and keywords. Defaults to None.
        use_torch_vmap (bool, optional): Whether to batch using torch.vmap or a dumb loop. Defaults to
                                         True.
    Returns:
        Callable: JIT-compatible JAX function.

    Examples:
        >>> import torch, jax
        >>> from torch2jax import torch2jax, tree_t2j
        >>> # let's define the torch function and create some example arguments
        >>> torch_fn = lambda x, y: torch.nn.CrossEntropyLoss()(x, y)
        >>> xt, yt = torch.randn(10, 5), torch.randint(0, 5, (10,))
        >>> # we can now convert the function to jax using the torch fn and example args
        >>> jax_fn = torch2jax(torch_fn, xt, yt)
        >>> jax_fn = jax.jit(jax_fn) # we can jit it too
        >>> # let's convert the arguments to JAX arrays and call the function
        >>> x, y = tree_t2j((xt, yt))
        >>> jax_fn(x, y)
        >>> # it works!
    """

    # check for presence of example_args and example_kw
    msg = "Please provide either example_kw or example_kwargs, not both."
    assert example_kw is None or example_kwargs is None, msg
    if example_kwargs is not None:
        example_kw = example_kwargs
    has_kw = example_kw is not None

    if input_struct is None:
        if has_kw:
            input_struct = tree_structure((example_args, example_kw))
        else:
            input_struct = tree_structure(example_args)

    if output_shapes is None:
        with torch.no_grad():
            output = fn(*example_args, **example_kw) if has_kw else fn(*example_args)
        output_shapes, output_struct = tree_flatten(
            tree_map(lambda x: ShapeDtypeStruct(x.shape, dtype_t2j(x.dtype)), output)
        )

    else:
        output_shapes, output_struct = tree_flatten(output_shapes)
        msg = "Please provide all shapes as torch.Size or jax.ShapeDtypeStruct."
        assert all(
            isinstance(x, (torch.Size, ShapedArray, ShapeDtypeStruct)) or hasattr(x, "shape")
            for x in output_shapes
        ), msg

    # define flattened version of the function (flat arguments and outputs)
    def flat_fn(*args_flat):
        nonlocal output_shapes, example_args
        if has_kw:
            args, kw = tree_unflatten(input_struct, args_flat)
            ret = fn(*args, **kw)
        else:
            args = tree_unflatten(input_struct, args_flat)
            ret = fn(*args)
        return tree_flatten(ret)[0]

    # define the wrapped function using flat interface
    wrapped_fn_flat = torch2jax_flat(
        flat_fn, output_shapes=output_shapes, use_torch_vmap=use_torch_vmap
    )

    if has_kw:

        def wrapped_fn(*args, **kw):
            ret = wrapped_fn_flat(*tree_flatten((args, kw))[0])
            return tree_unflatten(output_struct, ret)

    else:

        def wrapped_fn(*args):
            ret = wrapped_fn_flat(*tree_flatten(args)[0])
            return tree_unflatten(output_struct, ret)

    return wrapped_fn
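Keyword arguments follow the same pattern: per the source above, example_kw fixes the input structure, and the converted function is then called with matching keywords. A minimal sketch:

>>> import torch
>>> import jax.numpy as jnp
>>> from torch2jax import torch2jax
>>> torch_fn = lambda x, scale: x * scale
>>> jax_fn = torch2jax(torch_fn, torch.randn(4), example_kw=dict(scale=torch.tensor(2.0)))
>>> jax_fn(jnp.ones(4), scale=jnp.asarray(2.0))  # keyword structure must match example_kw

When the converted function is batched (e.g. under jax.vmap), use_torch_vmap selects whether the batching is delegated to torch.vmap or performed with a plain loop.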