Roadmap
- call PyTorch functions on JAX data without input data copy (see the DLPack sketch after this list)
- call PyTorch functions on JAX data without input data copy under jit (see the `jax.pure_callback` sketch after this list)
- support both GPU and CPU
- (feature) support CPU-only builds on systems without CUDA
- (user-friendly) support functions with a single output (return a single output, not a tuple)
- (user-friendly) support arbitrary input and output argument structures (use pytrees on the Python side)
- (feature) support batching (e.g., support for `jax.vmap`)
- (feature) support integer input/output types
- (feature) support mixed-precision arguments in inputs/outputs
- (feature) support defining VJP for the wrapped function (import the experimental functionality from jit-JAXFriendlyInterface); see the `jax.custom_vjp` sketch after this list
- (tests) test how well device mapping works on multiple GPUs
- (tests) set up automatic tests for multiple versions of Python, PyTorch, and JAX
- (feature) look into supporting in-place functions (outputs returned without a copy)
- (feature) support TPU
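
For context on the first two items: the standard zero-copy exchange path between JAX and PyTorch is the DLPack protocol. The sketch below is illustrative background only, not this package's API; `call_torch_sin` is a made-up name, and it assumes reasonably recent JAX and PyTorch versions.

```python
import jax.numpy as jnp
import torch
from jax import dlpack as jax_dlpack

def call_torch_sin(x_jax):
    # Hand the JAX buffer to PyTorch through DLPack (no copy when both
    # frameworks see the same device), apply a PyTorch op, and hand the
    # result back to JAX the same way.
    x_torch = torch.from_dlpack(x_jax)
    y_torch = torch.sin(x_torch)
    return jax_dlpack.from_dlpack(y_torch)

x = jnp.linspace(0.0, 1.0, 8)
print(call_torch_sin(x))  # runs eagerly; this path is not available under jax.jit
```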
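
The "under jit" item is harder because jit-compiled code cannot call back into Python directly. The usual baseline is `jax.pure_callback`, which incurs a host round trip and input/output copies, i.e. exactly the overhead this roadmap item aims to remove. A minimal sketch, with `torch_sin_host` and `f` as illustrative names rather than part of this package:

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch

def torch_sin_host(x_np):
    # Called on the host with a NumPy array; run the PyTorch op and
    # return a NumPy array with the same shape and dtype.
    return torch.sin(torch.from_numpy(np.asarray(x_np))).numpy()

@jax.jit
def f(x):
    # pure_callback lets jit-traced code call back into Python/PyTorch,
    # at the cost of a host round trip and data copies.
    return jax.pure_callback(torch_sin_host, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

print(f(jnp.linspace(0.0, 1.0, 8)))
```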
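
For the VJP item, the standard JAX mechanism for attaching a gradient rule to an opaque (e.g. callback-based) function is `jax.custom_vjp`. The sketch below reuses the callback pattern from the previous example; the names are again illustrative and not this package's API.

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch

def torch_square_host(x_np):
    # Host-side PyTorch computation: elementwise square.
    return (torch.from_numpy(np.asarray(x_np)) ** 2).numpy()

@jax.custom_vjp
def torch_square(x):
    return jax.pure_callback(torch_square_host, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

def torch_square_fwd(x):
    # Save the input as the residual needed by the backward pass.
    return torch_square(x), x

def torch_square_bwd(x, g):
    # VJP of x**2: cotangent is 2 * x * g.
    return (2.0 * x * g,)

torch_square.defvjp(torch_square_fwd, torch_square_bwd)

print(jax.grad(lambda x: torch_square(x).sum())(jnp.arange(3.0)))  # [0. 2. 4.]
```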