
torch2jax_without_vjp — forward-only with sharding support

Use this forward-only variant for multi-GPU sharding (output_sharding_spec) and keyword arguments (example_kw).

torch2jax_without_vjp

torch2jax.api._torch2jax(fn, *example_args, example_kw=None, output_shapes=None, output_sharding_spec=None, vmap_method='sequential')

Define a jit-compatible JAX function that calls a PyTorch function. Arbitrary nesting of arguments and outputs is supported.

Parameters:

    fn (Callable): PyTorch function to wrap. [required]

    *example_args (Any): Example arguments as tensors or torch-compatible args. [default: ()]

    example_kw (Any | None): Example keyword arguments. [default: None]

    output_shapes (Any): Output shapes or shape + dtype structs. [default: None]

    output_sharding_spec (PartitionSpec | None): jax.sharding.PartitionSpec specifying the
        sharding of the output; the mesh of the inputs is reused. [default: None]

    vmap_method (str): Batching method, see
        https://docs.jax.dev/en/latest/ffi.html#batching-with-vmap

        NOTE: only vmap_method="sequential" is supported non-experimentally.

        NOTE: try "expand_dims" or "broadcast_all" to experiment with PyTorch-side batching
        (see the sketch after this list). [default: 'sequential']
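Under the default vmap_method="sequential", jax.vmap over the wrapped function evaluates the batch element by element on the PyTorch side. A minimal sketch (this assumes torch2jax_without_vjp and tree_t2j are the package's top-level exports, as this page's title and the example below suggest):

>>> import torch, jax, jax.numpy as jnp
>>> from torch2jax import torch2jax_without_vjp, tree_t2j
>>> torch_fn = lambda x: torch.nn.functional.relu(x)
>>> xt = torch.randn(10, 5)
>>> jax_fn = torch2jax_without_vjp(torch_fn, xt, vmap_method="sequential")
>>> xb = jnp.stack([tree_t2j(xt)] * 3)  # add a leading batch axis: (3, 10, 5)
>>> jax.vmap(jax_fn)(xb)  # each batch element is evaluated sequentially in torch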

Returns: Callable: JIT-compatible JAX function.
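A note on output_shapes: when it is omitted, fn is called once with the example arguments under torch.no_grad() to infer the output structure (see the source below); passing it explicitly skips that trial call. A minimal sketch, with shapes chosen purely for illustration:

>>> import torch, jax
>>> from torch2jax import torch2jax_without_vjp
>>> torch_fn = lambda x: x @ x.T
>>> xt = torch.randn(6, 3)
>>> jax_fn = torch2jax_without_vjp(torch_fn, xt,
...     output_shapes=jax.ShapeDtypeStruct((6, 6), jax.numpy.float32))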

Examples:

>>> import torch, jax
>>> from torch2jax import torch2jax_without_vjp, tree_t2j
>>> # let's define the torch function and create some example arguments
>>> torch_fn = lambda x, y: torch.nn.CrossEntropyLoss()(x, y)
>>> xt, yt = torch.randn(10, 5), torch.randint(0, 5, (10,))
>>> # we can now convert the function to jax using the torch fn and example args
>>> jax_fn = torch2jax_without_vjp(torch_fn, xt, yt)
>>> jax_fn = jax.jit(jax_fn) # we can jit it too
>>> # let's convert the arguments to JAX arrays and call the function
>>> x, y = tree_t2j((xt, yt))
>>> jax_fn(x, y)
>>> # it works!
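The example above exercises only positional arguments. A hedged sketch of the two features highlighted at the top of this page, example_kw and output_sharding_spec, might look like the following; the mesh setup is standard JAX, the leading axis is assumed divisible by the device count, and the sharding spec is applied using the mesh of the inputs, per the parameter description above:

>>> import torch, jax
>>> from jax.sharding import Mesh, NamedSharding, PartitionSpec
>>> from torch2jax import torch2jax_without_vjp, tree_t2j
>>> torch_fn = lambda x, bias: x + bias  # takes a keyword argument
>>> xt, bt = torch.randn(16, 8), torch.randn(8)
>>> jax_fn = torch2jax_without_vjp(torch_fn, xt, example_kw=dict(bias=bt),
...                                output_sharding_spec=PartitionSpec("data", None))
>>> mesh = Mesh(jax.devices(), ("data",))
>>> x = jax.device_put(tree_t2j(xt), NamedSharding(mesh, PartitionSpec("data", None)))
>>> jax_fn(x, bias=tree_t2j(bt))  # output rows come back sharded over the "data" axis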
Source code in torch2jax/api.py
def _torch2jax(
    fn: Callable,
    *example_args: Any,
    example_kw: Any | None = None,
    output_shapes: Any = None,
    output_sharding_spec: PartitionSpec | None = None,
    vmap_method: str = "sequential",
) -> Callable:
    """Define a jit-compatible JAX function that calls a PyTorch function.  Arbitrary nesting of
    arguments and outputs is supported.

    Args:
        fn (Callable): PyTorch function to wrap.
        *example_args (Any): Example arguments as tensors or torch-compatible args.
        example_kw: Example keyword arguments. Defaults to None.
        output_shapes: Output shapes or shape + dtype structs. Defaults to None.
        output_sharding_spec: jax.sharding.PartitionSpec specifying the sharding of the
            output; the mesh of the inputs is reused.
        vmap_method: batching method, see
            [https://docs.jax.dev/en/latest/ffi.html#batching-with-vmap](https://docs.jax.dev/en/latest/ffi.html#batching-with-vmap)

            NOTE: only vmap_method="sequential" is supported non-experimentally

            NOTE: try "expand_dims", "broadcast_all" if you want to experiment with pytorch-side batching
    Returns:
        Callable: JIT-compatible JAX function.

    Examples:
        >>> import torch, jax
        >>> from torch2jax import torch2jax_without_vjp, tree_t2j
        >>> # let's define the torch function and create some example arguments
        >>> torch_fn = lambda x, y: torch.nn.CrossEntropyLoss()(x, y)
        >>> xt, yt = torch.randn(10, 5), torch.randint(0, 5, (10,))
        >>> # we can now convert the function to jax using the torch fn and example args
        >>> jax_fn = torch2jax_without_vjp(torch_fn, xt, yt)
        >>> jax_fn = jax.jit(jax_fn) # we can jit it too
        >>> # let's convert the arguments to JAX arrays and call the function
        >>> x, y = tree_t2j((xt, yt))
        >>> jax_fn(x, y)
        >>> # it works!
    """

    # check for presence of example_args and example_kw
    _had_output_shapes = output_shapes is not None
    has_kw = example_kw is not None

    # find the input structure
    if has_kw:
        input_struct = jax.tree.structure((example_args, example_kw))
    else:
        input_struct = jax.tree.structure(example_args)

    # define flattened version of the function (flat arguments and outputs)
    def flat_fn(*args_flat):
        if has_kw:
            args, kw = jax.tree.unflatten(input_struct, args_flat)
            ret = fn(*args, **kw)
        else:
            args = jax.tree.unflatten(input_struct, args_flat)
            ret = fn(*args)
        return jax.tree.leaves(ret)

    example_inputs = (example_args, example_kw) if has_kw else example_args
    input_shapes = jax.tree.map(lambda x: ShapeDtypeStruct(x.shape, dtype_t2j(x.dtype)), example_inputs)

    # find the output structure
    if output_shapes is None:
        with torch.no_grad():
            output = fn(*example_args, **example_kw) if has_kw else fn(*example_args)
        output_shapes, output_struct = jax.tree.flatten(
            jax.tree.map(lambda x: ShapeDtypeStruct(x.shape, dtype_t2j(x.dtype)), output)
        )
    else:
        if not all(
            isinstance(x, (torch.Size, ShapeDtypeStruct, jax.Array, torch.Tensor)) or hasattr(x, "shape")
            for x in jax.tree.leaves(output_shapes)
        ):
            warn_once(_WARN_OUTPUT_SHAPES_FORMAT, fn)
        output_shapes = normalize_shapes(output_shapes, extra_args=input_shapes)
        output_shapes, output_struct = jax.tree.flatten(output_shapes)
    if output_sharding_spec is not None:
        output_sharding_spec_flat, output_sharding_struct = jax.tree.flatten(output_sharding_spec)
        msg = (
            "When providing `output_shading_spec` its structure must match the structure of `output_shapes`."
            f"\nExpected: {output_struct}\n Actual:   {output_sharding_struct}"
        )
        assert jax.tree.structure(output_sharding_spec) == output_struct, msg
    else:
        output_sharding_spec_flat, output_sharding_struct = None, None

    # define the wrapped function using flat interface
    wrapped_fn_flat = _torch2jax_flat(
        flat_fn,
        input_shapes=None,
        output_shapes=output_shapes,
        output_sharding_spec=output_sharding_spec_flat,
        vmap_method=vmap_method,
    )

    # shape-aware cache for automatic re-wrapping on shape changes
    _cache = {}
    _original_shape_key = tuple(
        (tuple(a.shape), dtype_t2j(a.dtype))
        for a in jax.tree.leaves((example_args, example_kw) if has_kw else example_args)
    )
    format_key = lambda key: ", ".join([f"{np.dtype(k[1]).name}{list(k[0])}" for k in key])

    # define the actual wrapper function
    def wrapped_fn(*args, **kw):
        if not has_kw and len(kw) > 0:
            raise RuntimeError("Keyword arguments not expected!")
        if has_kw:
            args = (args, kw)
        if jax.tree.structure(args) != input_struct:
            msg = (_MISMATCH_ARGS_KW_MSG if has_kw else _MISMATCH_ARGS_MSG).format(args, input_struct)
            raise RuntimeError(msg)

        shape_key = tuple((tuple(a.shape), a.dtype) for a in jax.tree.leaves(args))
        if shape_key != _original_shape_key:
            if shape_key not in _cache:
                msg = _SHAPE_CHANGE_WARN_EXPLICIT if _had_output_shapes else _SHAPE_CHANGE_WARN_CONCRETE
                warn_always(msg.format(format_key(_original_shape_key), format_key(shape_key)))
                dummy_flat = [torch.zeros(a.shape, dtype=dtype_j2t(a.dtype)) for a in jax.tree.leaves(args)]
                dummy_tree = jax.tree.unflatten(input_struct, dummy_flat)
                opts = dict(output_sharding_spec=output_sharding_spec, vmap_method=vmap_method)
                dummy_args, dummy_kw = dummy_tree if has_kw else (dummy_tree, None)
                _cache[shape_key] = _torch2jax(fn, *dummy_args, example_kw=dummy_kw, **opts)
            return _cache[shape_key](*args[0], **args[1]) if has_kw else _cache[shape_key](*args)

        ret = wrapped_fn_flat(*jax.tree.leaves(args))
        return jax.tree.unflatten(output_struct, ret)

    return wrapped_fn
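One behavior worth noting from the source above: the returned wrapper keys a cache on the input shapes and dtypes (the _cache / _original_shape_key logic), so calling it with shapes that differ from the example arguments warns once and transparently re-wraps the torch function. A minimal sketch of what that looks like in use:

>>> import torch
>>> from torch2jax import torch2jax_without_vjp, tree_t2j
>>> torch_fn = lambda x: x.sum()
>>> jax_fn = torch2jax_without_vjp(torch_fn, torch.randn(4))
>>> jax_fn(tree_t2j(torch.randn(4)))  # matches the example shape: cached fast path
>>> jax_fn(tree_t2j(torch.randn(8)))  # new shape key: warns, re-wraps, then caches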