
## sensitivity

| name | summary |
| --- | --- |
| `generate_optimization_fns(loss_fn, opt_fn, k_fn, normalize_grad, optimizations, jit, custom_arg_serializer)` | Directly generates upper/outer bilevel program derivative functions. |
| `generate_optimization_with_state_fns(loss_fn, opt_fn, k_fn, normalize_grad, optimizations, jit, custom_arg_serializer)` | Directly generates upper/outer bilevel program derivative functions, for optimizers that additionally carry state. |
| `implicit_hessian(k_fn, z, params, nondiff_kw, Dg, Hg, jvp_vec, optimizations)` | Computes the implicit Hessian or chain rule, depending on `Dg`, `Hg`, `jvp_vec`. |
| `implicit_jacobian(k_fn, z, params, nondiff_kw, Dg, jvp_vec, matrix_free_inverse, full_output, optimizations)` | Computes the implicit Jacobian, VJP, or JVP, depending on `Dg`, `jvp_vec`. |
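For example, here is a minimal sketch of wiring `loss_fn`, `opt_fn`, and `k_fn` together on a ridge-regression inner problem. The top-level import path and the assumption that the generator returns outer loss/gradient/Hessian callables are mine, not stated on this page:

```python
import jax
import jax.numpy as jnp

from sensitivity import generate_optimization_fns  # assumed import path

X, y = jnp.ones((10, 3)), jnp.ones((10,))

def inner_obj(z, params):  # inner (lower-level) ridge objective
    return jnp.sum((X @ z - y) ** 2) + params * jnp.sum(z ** 2)

def k_fn(z, params):  # optimality conditions: gradient of the inner objective
    return jax.grad(inner_obj)(z, params)

def opt_fn(params):  # inner solver: z*(params) in closed form
    return jnp.linalg.solve(X.T @ X + params * jnp.eye(3), X.T @ y)

def loss_fn(z, params):  # outer (upper-level) loss evaluated at z*(params)
    return jnp.sum(z ** 2)

# Assumed return convention: outer loss, gradient, and Hessian in params.
f_fn, g_fn, h_fn = generate_optimization_fns(loss_fn, opt_fn, k_fn)
print(f_fn(jnp.array(0.1)), g_fn(jnp.array(0.1)))
```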

## batch_sensitivity

| name | summary |
| --- | --- |
| `generate_optimization_fns(loss_fn, opt_fn, k_fn, normalize_grad, optimizations, jit, use_cache, kw_in_key)` | Directly generates upper/outer bilevel program derivative functions. |
| `implicit_hessian(k_fn, z, params, nondiff_kw, Dg, Hg, jvp_vec, optimizations)` | Computes the implicit Hessian or chain rule, depending on `Dg`, `Hg`, `jvp_vec`, using batched operations. |
| `implicit_hessian2(k_fn, z, params, nondiff_kw, Dg, Hg, jvp_vec, optimizations)` | Computes the implicit Hessian or chain rule, depending on `Dg`, `Hg`, `jvp_vec`, using `vmap`. |
| `implicit_jacobian(k_fn, z, params, nondiff_kw, Dg, jvp_vec, matrix_free_inverse, full_output, optimizations)` | Computes the implicit Jacobian, VJP, or JVP, depending on `Dg`, `jvp_vec`, using batched operations. |
| `implicit_jacobian2(k_fn, z, params, nondiff_kw, Dg, jvp_vec, matrix_free_inverse, full_output, optimizations)` | Computes the implicit Jacobian, VJP, or JVP, depending on `Dg`, `jvp_vec`, using `vmap`. |
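A toy call to the batched `implicit_jacobian` is sketched below, using a per-example stationarity map whose implicit Jacobian is the identity; the import path and the reliance on default values for `nondiff_kw`, `Dg`, and `jvp_vec` are assumptions:

```python
import jax.numpy as jnp

from sensitivity.batch_sensitivity import implicit_jacobian  # assumed path

def k_fn(z, params):  # stationarity of 0.5 * ||z - params||^2, per example
    return z - params

B, n = 4, 3
params = jnp.arange(B * n, dtype=jnp.float32).reshape((B, n))
z = params  # the exact inner solution here: z*(params) = params

# Implicit function theorem: dz*/dparams = -(D_z k)^{-1} D_params k = I.
J = implicit_jacobian(k_fn, z, params)
```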

## differentiation

| name | summary |
| --- | --- |
| `BATCH_HESSIAN(fn, config)` | Computes the Hessian, assuming the first in/out dimension is the batch. |
| `BATCH_JACOBIAN(fn, config)` | Computes the Jacobian, assuming the first in/out dimension is the batch. |
| `HESSIAN_DIAG(fn, config)` | Generates a function which computes per-argument partial Hessians. |
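The batch convention of `BATCH_HESSIAN` and `BATCH_JACOBIAN` can be illustrated with plain JAX; this equivalent sketch is given because the table does not pin down the helpers' own call convention:

```python
import jax
import jax.numpy as jnp

W = jnp.ones((3, 2))

def fn(x):  # per-example map from R^3 to R^2
    return jnp.tanh(x) @ W

xs = jnp.zeros((5, 3))  # leading axis is the batch
J = jax.vmap(jax.jacobian(fn))(xs)  # (5, 2, 3): one Jacobian per example
H = jax.vmap(jax.hessian(fn))(xs)   # (5, 2, 3, 3): one Hessian per example
```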

## extras.optimization.agd

| name | summary |
| --- | --- |
| `minimize_agd(f_fn, g_fn, args, verbose, verbose_prefix, max_it, ai, af, full_output, callback_fn, use_writer, use_tqdm, state, optimizer)` | Minimizes a loss function `f_fn` with Accelerated Gradient Descent (AGD) with respect to `*args`. Uses PyTorch. |
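A minimal AGD call on a strongly convex quadratic might look as follows; the import path, passing the optimization variable positionally per the `*args` note, and the return value being the optimized iterate are all assumptions:

```python
import torch

from sensitivity.extras.optimization.agd import minimize_agd  # assumed path

def f_fn(x):  # quadratic with minimizer at x = 2
    return torch.sum((x - 2.0) ** 2)

def g_fn(x):  # its exact gradient
    return 2.0 * (x - 2.0)

x0 = torch.zeros(3)
x_opt = minimize_agd(f_fn, g_fn, x0, max_it=100)  # assumed return: final iterate
```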

## extras.optimization.lbfgs

| name | summary |
| --- | --- |
| `minimize_lbfgs(f_fn, g_fn, args, verbose, verbose_prefix, lr, max_it, full_output, callback_fn, use_writer, use_tqdm, state)` | Minimizes a loss function `f_fn` with L-BFGS with respect to `*args`. Implementation adapted from PyTorch. |
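The L-BFGS entry point follows the same shape, adding a step size `lr`; as above, the import path and return convention are assumptions:

```python
import torch

from sensitivity.extras.optimization.lbfgs import minimize_lbfgs  # assumed path

def f_fn(x):
    return torch.sum((x - 1.0) ** 2)

def g_fn(x):
    return 2.0 * (x - 1.0)

x_opt = minimize_lbfgs(f_fn, g_fn, torch.zeros(3), lr=1.0, max_it=50)
```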

## extras.optimization.sqp

| name | summary |
| --- | --- |
| `minimize_sqp(f_fn, g_fn, h_fn, args, reg0, verbose, verbose_prefix, max_it, ls_pts_nb, force_step, full_output, callback_fn, use_writer, use_tqdm, state, parallel_ls, jit)` | Minimizes an unconstrained objective using Sequential Quadratic Programming (SQP). |
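SQP additionally takes a Hessian function `h_fn`, so on a quadratic all three model callables are exact; the import path and the calling convention beyond the documented parameter names are assumptions:

```python
import jax.numpy as jnp

from sensitivity.extras.optimization.sqp import minimize_sqp  # assumed path

def f_fn(x):
    return jnp.sum((x - 1.0) ** 2)

def g_fn(x):
    return 2.0 * (x - 1.0)

def h_fn(x):  # constant Hessian of the quadratic
    return 2.0 * jnp.eye(x.shape[-1])

x_opt = minimize_sqp(f_fn, g_fn, h_fn, jnp.zeros(3))
```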

## utils

| name | summary |
| --- | --- |
| `fn_with_sol_and_state_cache(fwd_fn, cache, jit, use_cache, kw_in_key, custom_arg_serializer)` | Wraps a function in a version where computation of the first argument via `fwd_fn` is cached, together with the accompanying solver state. |
| `fn_with_sol_cache(fwd_fn, cache, jit, use_cache, kw_in_key, custom_arg_serializer)` | Wraps a function in a version where computation of the first argument via `fwd_fn` is cached. |
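Reading the summary literally, `fwd_fn` computes an expensive first argument (e.g., an inner solution) from the remaining arguments, and repeat calls reuse the cached value. The decorator-style usage and the `cache` being a plain dict below are assumptions:

```python
import jax.numpy as jnp

from sensitivity.utils import fn_with_sol_cache  # assumed path

def fwd_fn(params):  # expensive inner solve, keyed on params in the cache
    return jnp.linalg.solve((1.0 + params) * jnp.eye(3), jnp.ones(3))

def loss_fn(z, params):  # receives the cached solution as its first argument
    return jnp.sum(z ** 2) + params

cached_loss = fn_with_sol_cache(fwd_fn, cache={})(loss_fn)  # assumed usage
v1 = cached_loss(jnp.array(0.1))  # first call: runs fwd_fn and caches z
v2 = cached_loss(jnp.array(0.1))  # repeat call: reuses the cached z
```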