Optimization Sensitivity in JAX - Reference Documentation - First- and Second-order Optimization
sensitivity_jax is a package designed to allow taking first- and second-order derivatives through optimization or any other fixed-point process.
Source code for this package is located here: github.com/rdyro/sensitivity_jax.
This package builds on top of JAX. We also maintain an implementation in PyTorch here.
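
To make "derivatives through optimization" concrete, here is a minimal sketch in plain JAX of first-order sensitivity via the implicit function theorem. It does not use the sensitivity_jax API. The names `inner_loss`, `solve_inner`, and `implicit_jacobian` are hypothetical, and the quadratic objective is an assumption chosen so the answer can be checked by hand. If z*(p) minimizes f(z, p), then k(z*, p) = ∇_z f(z*, p) = 0, and differentiating this condition gives dz*/dp = -(D_z k)^{-1} D_p k.

```python
import jax
import jax.numpy as jnp

# Assumed toy problem: f(z, p) = 0.5 z^T A z - p^T z with A symmetric positive definite,
# so the minimizer is z*(p) = A^{-1} p and its sensitivity is dz*/dp = A^{-1}.
A = jnp.array([[3.0, 1.0], [1.0, 2.0]])


def inner_loss(z, p):
    # Inner objective, minimized over z for a fixed parameter p.
    return 0.5 * z @ A @ z - p @ z


# Optimality (fixed-point) condition: k(z, p) = grad_z f(z, p) = 0 at the solution.
k = jax.grad(inner_loss, argnums=0)


def solve_inner(p, steps=200, lr=0.1):
    # Plain gradient descent; any solver that reaches k(z, p) ≈ 0 would do.
    z = jnp.zeros_like(p)
    for _ in range(steps):
        z = z - lr * k(z, p)
    return z


def implicit_jacobian(z_star, p):
    # Implicit function theorem: dz*/dp = -(D_z k)^{-1} (D_p k), evaluated at the solution.
    Dzk = jax.jacobian(k, argnums=0)(z_star, p)
    Dpk = jax.jacobian(k, argnums=1)(z_star, p)
    return -jnp.linalg.solve(Dzk, Dpk)


p = jnp.array([1.0, -1.0])
z_star = solve_inner(p)
J = implicit_jacobian(z_star, p)
print(z_star)
print(J)
```

For this quadratic, `J` should agree with the inverse of `A` up to numerical tolerance. The sensitivity is computed from the optimality condition at the solution only, so the solver iterations themselves are never differentiated.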