bug-fix: in torch2jax_with_vjp, nondiff arguments were erroneously memoized
version 0.4.0
added batching (vmap support) via torch.vmap; this makes jax.jacobian work
made gradient support more robust
added support for mixed-type arguments, including float16, float32, float64 and integer types
removed unnecessary torch function calls when defining gradients
added an example of wrapping a BERT model in JAX (with weights modified from JAX), examples/bert_from_jax.ipynb
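Why batching matters for jax.jacobian: in JAX, jax.jacobian is reverse-mode (an alias of jax.jacrev). It builds the full Jacobian by applying the function's VJP to every output basis vector under jax.vmap, so a wrapped function whose backward pass cannot be vmapped cannot be used with it. The minimal sketch below shows that equivalence in plain JAX. The function f is a hypothetical stand-in for a wrapped torch function, and the code does not use torch2jax itself.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a wrapped function (3 inputs -> 3 outputs).
def f(x):
    return jnp.sin(x) * jnp.sum(x)

x = jnp.arange(3.0)

# Build one VJP closure, then apply it to every output basis vector at once.
y, vjp_fn = jax.vjp(f, x)
basis = jnp.eye(y.shape[0])
# vjp_fn returns a tuple with one cotangent per input; row i is e_i^T J.
(jac_via_vmap,) = jax.vmap(vjp_fn)(basis)

# jax.jacobian (reverse mode) yields the same matrix.
jac = jax.jacobian(f)(x)
print(jnp.allclose(jac_via_vmap, jac))
```

If the VJP of f could not be batched, the jax.vmap call above would fail, and so would jax.jacobian.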
version 0.3.0
added a beta version of a new wrapping method, torch2jax_with_vjp, which
allows recursively defining reverse-mode gradients for the wrapped torch
function; these gradients work in JAX both normally and under JIT
version 0.2.0
arbitrary input and output structure is now allowed
removed the restriction on the number of arguments or their maximum dimension
old interface is available via torch2jax.compat.torch2jax
version 0.1.2
full support for CPU-only setups, selected via torch.cuda.is_available()
bug-fix: compilation should now cache properly
version 0.1.1
bug-fix: functions no longer get overwritten; the manual fn id parameter was replaced with automatic id generation