torch2jax_with_vjp (deprecated)
Deprecated: use torch2jax(..., depth=2) instead.
torch2jax_with_vjp
torch2jax.gradients.torch2jax_with_vjp(*args, depth=2, **kw)
Deprecated: use torch2jax(..., depth=2) instead.
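The signature above suggests that torch2jax_with_vjp is a thin alias. It appears to forward every positional and keyword argument to torch2jax and default depth to 2. The following is a minimal, self-contained sketch of that forwarding pattern under that assumption. It is not the library's actual source. The stand-in `torch2jax` below is hypothetical: it performs no conversion and only records what it receives. The DeprecationWarning is also an assumption; the real wrapper may or may not emit one.

```python
import warnings


def torch2jax(*args, depth, **kw):
    """Hypothetical stand-in for the real torch2jax: it records its inputs only."""
    return ("converted", args, depth, kw)


def torch2jax_with_vjp(*args, depth=2, **kw):
    """Deprecated alias sketch: forward all arguments, defaulting depth to 2."""
    # Assumption: a deprecation warning is emitted; not confirmed by the docs.
    warnings.warn(
        "torch2jax_with_vjp is deprecated; use torch2jax(..., depth=2) instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return torch2jax(*args, depth=depth, **kw)


def torch_fn(a, b):
    # Placeholder for a PyTorch function; plain floats keep the sketch dependency-free.
    return a + b


# Old spelling: record the warning so the example runs quietly.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old = torch2jax_with_vjp(torch_fn, 1.0, 2.0)

# New spelling recommended by the deprecation note.
new = torch2jax(torch_fn, 1.0, 2.0, depth=2)
```

Under this assumption, migrating means renaming the call and passing depth=2 explicitly. Any other arguments stay unchanged because the alias passes them through untouched.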