sensitivity_jax.batch_sensitivity.implicit_hessian2(k_fn, z, *params, nondiff_kw=None, Dg=None, Hg=None, jvp_vec=None, optimizations=None)
Computes the implicit Hessian or its chain-rule form, depending on which of `Dg`, `Hg`, and `jvp_vec` are supplied, using `vmap`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`k_fn` | `Callable` | `k_fn(z, *params) = 0`, lower/inner implicit function | required |
`z` | `JAXArray` | the optimal embedding variable value array | required |
`*params` | `JAXArray` | the parameters p of the bilevel program | `()` |
`Dg` | `JAXArray` | gradient sensitivity vector (wrt z), for chain rule | `None` |
`Hg` | `JAXArray` | Hessian sensitivity vector (wrt z), for chain rule | `None` |
`jvp_vec` | `Union[JAXArray, Sequence[JAXArray]]` | right sensitivity vector(s) (wrt p) for Hessian-vector-product | `None` |
`optimizations` | `Mapping` | optional optimizations | `None` |
Returns:
Type | Description |
---|---|
| | Hessian, or chain-rule Hessian, as specified by the arguments |
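
The math behind an implicit Hessian with a chain rule can be sketched on a scalar toy problem using only the Python standard library. This sketch does not call `sensitivity_jax` and is not its implementation. The function `k`, the helper names, and the reading of `Dg` and `Hg` as the first and second derivatives of an outer objective `g(z)` are assumptions made for illustration.

```python
# Scalar illustration of an implicit Hessian (standard library only).
# Hypothetical toy problem, not taken from sensitivity_jax:
#   k(z, p) = z**3 + z - p = 0 defines z*(p) implicitly.

def k(z, p):
    return z**3 + z - p

def solve_z(p, iters=50):
    """Newton's method for k(z, p) = 0 (k is strictly increasing in z)."""
    z = 0.0
    for _ in range(iters):
        z -= k(z, p) / (3.0 * z**2 + 1.0)
    return z

def implicit_derivatives(z):
    """First and second derivative of z*(p) via the implicit function theorem."""
    k_z = 3.0 * z**2 + 1.0                 # dk/dz
    k_p = -1.0                             # dk/dp
    k_zz, k_zp, k_pp = 6.0 * z, 0.0, 0.0   # second partials of k
    dz = -k_p / k_z
    # Differentiating k_z * dz + k_p = 0 once more with respect to p gives:
    d2z = -(k_zz * dz**2 + 2.0 * k_zp * dz + k_pp) / k_z
    return dz, d2z

def chain_rule_hessian(Dg, Hg, dz, d2z):
    """Second derivative of g(z*(p)) in p, given g'(z) = Dg and g''(z) = Hg."""
    return Hg * dz**2 + Dg * d2z

p = 2.0
z = solve_z(p)                      # z**3 + z = 2, so z = 1
dz, d2z = implicit_derivatives(z)   # 0.25 and -0.09375

# Assumed outer objective g(z) = z**2, so Dg = 2 z and Hg = 2.
hess = chain_rule_hessian(2.0 * z, 2.0, dz, d2z)   # -0.0625

# Central finite-difference check of the same second derivative.
h = 1e-4
g_of_p = lambda q: solve_z(q) ** 2
fd = (g_of_p(p + h) - 2.0 * g_of_p(p) + g_of_p(p - h)) / h**2
print(hess, fd)
```

In the vector case, the scalar divisions by `k_z` become linear solves with the Jacobian of `k_fn` with respect to `z`, and the products with `dz` become tensor contractions.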