Skip to content

utils

Utilities

torch2jax.j2t(x, via='dlpack')

Transfer a single jax.Array to a PyTorch tensor.

Source code in torch2jax/dlpack_passing.py
def j2t(x: Array, via: str = "dlpack") -> Tensor:
    """Transfer a single jax.Array to a PyTorch tensor.

    Args:
        x: The JAX array to transfer. It must be concrete (querying its
            devices must not raise ``ConcretizationTypeError``) and must live
            on exactly one device.
        via: Transfer mechanism, forwarded unchanged to ``_transfer``.

    Returns:
        The PyTorch tensor produced by ``_transfer`` for ``x`` on the single
        device that holds it.

    Raises:
        RuntimeError: If ``x`` is non-concrete, or if it spans more than one
            device.
    """
    # Only ``x.devices()`` can raise ConcretizationTypeError, so keep it alone
    # in the ``try`` and chain the original error as the explicit cause.
    try:
        devices = x.devices()
    except ConcretizationTypeError as err:
        msg = "You are attempting to convert a non-concrete JAX array to a PyTorch tensor."
        msg += " This is not supported, since that JAX array does not contain any numbers."
        raise RuntimeError(msg) from err
    if len(devices) > 1:
        msg = "You are attempting to convert a JAX array with multiple devices to a PyTorch tensor."
        msg += " This is not supported"
        raise RuntimeError(msg)
    # ``devices`` is not indexed directly (it may be an unordered collection),
    # so materialize it before taking its sole element.
    device = list(devices)[0]
    return _transfer(x, via=via, device=device)
torch2jax.t2j(x, via='dlpack')

Transfer a single PyTorch tensor to a jax.Array.

Source code in torch2jax/dlpack_passing.py
def t2j(x: Tensor, via: str = "dlpack") -> Array:
    """Transfer a single PyTorch tensor to a jax.Array.

    Args:
        x: The PyTorch tensor to transfer.
        via: Transfer mechanism, forwarded unchanged to ``_transfer``.

    Returns:
        The jax.Array produced by ``_transfer``, targeting the tensor's own
        device (``x.device``).
    """
    source_device = x.device
    return _transfer(x, via=via, device=source_device)
torch2jax.tree_j2t(xs, via='dlpack')

Transfer a tree of jax.Array-s to a corresponding tree of PyTorch tensors.

Source code in torch2jax/dlpack_passing.py
def tree_j2t(xs: list[Array] | tuple[Array], via: str = "dlpack") -> list[Tensor] | tuple[Tensor]:
    """Transfer a tree of PyTorch tensors to a corresponding tree of jax.Array-s."""
    return jax.tree.map(lambda x: j2t(x, via=via) if isinstance(x, Array) else x, xs)
torch2jax.tree_t2j(xs, via='dlpack')

Transfer a tree of PyTorch tensors to a corresponding tree of jax.Array-s.

Source code in torch2jax/dlpack_passing.py
def tree_t2j(xs: list[Tensor] | tuple[Array], via: str = "dlpack") -> list[Array] | tuple[Array]:
    """Transfer a tree of  jax.Array-s to a corresponding tree of PyTorch tensors."""
    return jax.tree.map(lambda x: t2j(x, via=via) if isinstance(x, Tensor) else x, xs)