Changelog

  • version 0.4.1
    • bug-fix: in torch2jax_with_vjp, nondiff arguments were erroneously memoized
  • version 0.4.0
    • added batching (vmap support) using torch.vmap, which makes jax.jacobian work
    • made gradient support more robust
    • added mixed type arguments, including support for float16, float32, float64 and integer types
    • removed unnecessary torch function calls in defining gradients
    • added an example of wrapping a BERT model in JAX (with weights modified from JAX), examples/bert_from_jax.ipynb
  • version 0.3.0
    • added a beta version of a new wrapping method, torch2jax_with_vjp, which allows recursively defining reverse-mode gradients for the wrapped torch function and works in JAX both normally and under JIT
  • version 0.2.0
    • arbitrary input and output structure is now allowed
    • removed the restriction on the number of arguments or their maximum dimension
    • old interface is available via torch2jax.compat.torch2jax
  • version 0.1.2
    • full support for a CPU-only version, selected via torch.cuda.is_available()
    • bug-fix: compilation should now cache properly
  • version 0.1.1
    • bug-fix: functions no longer get overwritten; the manual fn id parameter is replaced with automatic id generation
    • compilation caching is now better
  • version 0.1.0
    • first working version of the package
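
The version 0.4.0 entry ties jax.jacobian to batching support. A minimal sketch of why, in plain JAX: the reverse-mode Jacobian is obtained by batching the VJP over output basis vectors with jax.vmap, so a function without a batching rule cannot be used there. The function f and the helper jacobian_via_vmap are hypothetical stand-ins for illustration, not part of torch2jax, and f is ordinary JAX code rather than a wrapped torch function.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical stand-in for a wrapped function (plain JAX, not torch2jax).
    return jnp.sin(x) * x


def jacobian_via_vmap(fn, x):
    # Evaluate fn once and get its vector-Jacobian product (VJP) function.
    y, vjp_fn = jax.vjp(fn, x)
    # One basis vector per output element.
    basis = jnp.eye(y.size, dtype=y.dtype)
    # vmap batches the VJP over the rows of the identity matrix;
    # this is the step that requires fn to support batching.
    (jac,) = jax.vmap(vjp_fn)(basis)
    return jac


x = jnp.arange(3.0)
jac_manual = jacobian_via_vmap(f, x)
jac_builtin = jax.jacobian(f)(x)  # jax.jacobian is an alias of jax.jacrev
```

Both results should agree up to floating-point tolerance, since the helper mirrors the approach jax.jacrev takes for a one-dimensional input and output.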