Roadmap

  • call PyTorch functions on JAX data without copying the input data
  • call PyTorch functions on JAX data without copying the input data under jit
  • support both GPU and CPU
  • (feature) support partial CPU building on systems without CUDA
  • (user-friendly) support functions with a single output (return a single output, not a tuple)
  • (user-friendly) support arbitrary argument input and output structure (use pytrees on the Python side)
  • (feature) support batching (e.g., support for jax.vmap)
  • (feature) support integer input/output types
  • (feature) support mixed-precision arguments in inputs/outputs
  • (feature) support defining VJP for the wrapped function (import the experimental functionality from jit-JAXFriendlyInterface)
  • (tests) test how well device mapping works on multiple GPUs
  • (tests) set up automatic tests for multiple versions of Python, PyTorch and JAX
  • (feature) look into supporting in-place functions (support for output without copy)
  • (feature) support TPU
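
The two user-friendly items above (single outputs and arbitrary argument structure) can be illustrated with a small, dependency-free sketch. This is not the library's implementation: `flatten`, `unflatten`, and `wrap` are hypothetical helpers written only for illustration, and they handle just lists, tuples, and dicts. In real JAX code, `jax.tree_util.tree_flatten` and `jax.tree_util.tree_unflatten` would normally play this role.

```python
# Hypothetical sketch: adapt a "flat" function (positional leaves in, tuple of
# leaves out) so it accepts and returns nested Python structures.

def flatten(tree):
    """Return (leaves, treedef) for nested lists, tuples, and dicts."""
    if isinstance(tree, (list, tuple)):
        leaves, defs = [], []
        for item in tree:
            sub_leaves, sub_def = flatten(item)
            leaves.extend(sub_leaves)
            defs.append(sub_def)
        return leaves, (type(tree), defs)
    if isinstance(tree, dict):
        keys = sorted(tree)  # fixed key order keeps the leaf order deterministic
        leaves, defs = [], []
        for key in keys:
            sub_leaves, sub_def = flatten(tree[key])
            leaves.extend(sub_leaves)
            defs.append(sub_def)
        return leaves, (dict, keys, defs)
    return [tree], None  # anything else is treated as a single leaf


def _build(treedef, leaf_iter):
    if treedef is None:
        return next(leaf_iter)
    if treedef[0] is dict:
        _, keys, defs = treedef
        return {k: _build(d, leaf_iter) for k, d in zip(keys, defs)}
    kind, defs = treedef
    return kind(_build(d, leaf_iter) for d in defs)


def unflatten(treedef, leaves):
    """Rebuild a structure from a treedef produced by flatten()."""
    return _build(treedef, iter(leaves))


def wrap(flat_fn, output_structure):
    """Wrap flat_fn so it takes arbitrary nested args and returns output_structure's shape."""
    _, out_def = flatten(output_structure)

    def wrapped(*args):
        leaves, _ = flatten(args)
        out = flat_fn(*leaves)
        if not isinstance(out, tuple):
            out = (out,)  # normalize a single return value to a tuple of leaves
        return unflatten(out_def, out)

    return wrapped


# Single output: a leaf-shaped output structure yields a bare value, not a tuple.
add3 = wrap(lambda x, y, z: x + y + z, output_structure=0)
single = add3({"a": 1, "b": [2, 3]})

# Structured output: leaves are returned in sorted-key order ("diff", "total").
diff_total = wrap(lambda x, y: (x - y, x + y), output_structure={"diff": 0, "total": 0})
structured = diff_total(5, 3)
```

Here the output structure is passed as an example value whose leaves are placeholders; a real wrapper would more likely take shape and dtype descriptions for each output leaf.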