Utilities
torch2jax.j2t(x, via='dlpack')
Transfer a single jax.Array to a PyTorch tensor.
Source code in torch2jax/dlpack_passing.py
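A minimal sketch of what a DLPack-based JAX-to-PyTorch transfer can look like, assuming recent PyTorch and JAX releases in which jax.Array implements the DLPack protocol (__dlpack__ / __dlpack_device__). The name j2t_sketch is hypothetical and this is not the torch2jax implementation; it only illustrates the via='dlpack' idea and ignores any other via options.

```python
import jax.numpy as jnp
import torch


def j2t_sketch(x):
    """Hypothetical stand-in for torch2jax.j2t(x, via='dlpack').

    torch.from_dlpack accepts any object implementing __dlpack__, which
    jax.Array does, so the tensor typically shares the JAX buffer instead
    of copying it.
    """
    # JAX assumes its arrays are immutable, so treat the result as read-only.
    return torch.from_dlpack(x)


a = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)
t = j2t_sketch(a)
print(tuple(t.shape), t.dtype)  # (2, 3) torch.float32
```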
torch2jax.t2j(x, via='dlpack')
Transfer a single PyTorch tensor to a jax.Array.
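The reverse direction can be sketched the same way, assuming a JAX version whose jax.dlpack.from_dlpack accepts objects that implement __dlpack__ (older releases expected a DLPack capsule instead). t2j_sketch is a hypothetical name, not the library function. Two practical details are handled explicitly: PyTorch refuses to export a tensor that requires grad through DLPack, so the tensor is detached first, and the sketch hands JAX a row-major buffer to stay on the safe side with strided views.

```python
import jax
import jax.dlpack
import torch


def t2j_sketch(x):
    """Hypothetical stand-in for torch2jax.t2j(x, via='dlpack')."""
    # detach(): DLPack export fails for tensors that require grad.
    # contiguous(): copy strided views (e.g. transposes) into row-major memory.
    return jax.dlpack.from_dlpack(x.detach().contiguous())


t = torch.arange(6, dtype=torch.float32).reshape(2, 3).T  # non-contiguous view
a = t2j_sketch(t)
print(a.shape, a.dtype)  # (3, 2) float32
```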
torch2jax.tree_j2t(xs, via='dlpack')
Transfer a tree of jax.Array-s to a corresponding tree of PyTorch tensors.
Source code in torch2jax/dlpack_passing.py
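A tree version can be sketched by mapping the single-array transfer over every leaf. This assumes the tree is a JAX pytree (nested lists, tuples, dicts) and uses jax.tree_util.tree_map; whether torch2jax.tree_j2t accepts arbitrary pytrees or only flat lists and tuples is not stated here, and tree_j2t_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp
import torch


def tree_j2t_sketch(xs):
    """Hypothetical stand-in for torch2jax.tree_j2t(xs, via='dlpack')."""
    # jax.Array objects are pytree leaves, so tree_map converts each array
    # with torch.from_dlpack and rebuilds the same container structure.
    return jax.tree_util.tree_map(torch.from_dlpack, xs)


tree = {"w": jnp.ones((2, 2)), "bs": [jnp.zeros(3), jnp.arange(4.0)]}
out = tree_j2t_sketch(tree)
print(type(out["bs"]).__name__, out["bs"][1].tolist())  # list [0.0, 1.0, 2.0, 3.0]
```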
torch2jax.tree_t2j(xs, via='dlpack')
Transfer a tree of PyTorch tensors to a corresponding tree of jax.Array-s.
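Going from PyTorch to JAX over a tree works the same way, under the same assumptions: the container is a JAX pytree, and jax.dlpack.from_dlpack accepts objects implementing __dlpack__. torch.Tensor is not a registered pytree node, so jax.tree_util.tree_map treats each tensor as a leaf. tree_t2j_sketch is a hypothetical name, not the library's code.

```python
import jax
import jax.dlpack
import torch


def tree_t2j_sketch(xs):
    """Hypothetical stand-in for torch2jax.tree_t2j(xs, via='dlpack')."""
    # Each tensor leaf is detached (DLPack rejects grad-requiring tensors),
    # made row-major, and wrapped as a jax.Array; containers are preserved.
    return jax.tree_util.tree_map(
        lambda t: jax.dlpack.from_dlpack(t.detach().contiguous()), xs
    )


params = (torch.ones(2, 2), [torch.zeros(3), torch.arange(4.0)])
out = tree_t2j_sketch(params)
print(type(out).__name__, out[1][1].tolist())  # tuple [0.0, 1.0, 2.0, 3.0]
```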