Skip to content

sensitivity_jax.batch_sensitivity.implicit_hessian2(k_fn, z, *params, nondiff_kw=None, Dg=None, Hg=None, jvp_vec=None, optimizations=None)

Computes the implicit Hessian, or a chain-rule / Hessian-vector-product variant of it, depending on which of Dg, Hg, jvp_vec are provided; the computation is batched using vmap.

Parameters:

Name Type Description Default
k_fn Callable

k_fn(z, *params) = 0, lower/inner implicit function

required
z JAXArray

the optimal embedding variable value array

required
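*params JAXArray

the parameters p of the bilevel program

()
nondiff_kw Mapping

optional keyword arguments passed on to the underlying Hessian computation as nondiff_kw

None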
Dg JAXArray

gradient sensitivity vector (wrt z), for chain rule

None
Hg JAXArray

Hessian sensitivity vector (wrt z), for chain rule

None
jvp_vec Union[JAXArray, Sequence[JAXArray]]

right sensitivity vector(s) (wrt p) for Hessian-vector-product

None
optimizations Mapping

optional optimizations

None

Returns:

Type Description

Hessian/chain rule Hessian as specified by arguments

Source code in sensitivity_jax/batch_sensitivity.py
def implicit_hessian2(
    k_fn: Callable,
    z: JAXArray,
    *params: JAXArray,
    nondiff_kw: Mapping = None,
    Dg: JAXArray = None,
    Hg: JAXArray = None,
    jvp_vec: Union[JAXArray, Sequence[JAXArray]] = None,
    optimizations: Mapping = None,
):
    """Batched wrapper that runs `implicit_hessian_` through `jaxm.vmap`.

    Depending on which of `Dg`, `Hg` and `jvp_vec` are supplied, the result is
    the implicit Hessian or its chain-rule / Hessian-vector-product variant.

    Args:
        k_fn: k_fn(z, *params) = 0, lower/inner implicit function
        z: the optimal embedding variable value array; `z.shape[0]` is used
            as the size passed to `_split_for_broadcast`
        *params: the parameters p of the bilevel program
        nondiff_kw: optional mapping forwarded to `implicit_hessian_` as
            `nondiff_kw` (presumably non-differentiated keyword arguments
            for `k_fn` -- TODO confirm)
        Dg: gradient sensitivity vector (wrt z), for chain rule
        Hg: Hessian sensitivity vector (wrt z), for chain rule
        jvp_vec: right sensitivity vector(s) (wrt p) for Hessian-vector-product
        optimizations: optional optimizations
    Returns:
        Hessian/chain rule Hessian as specified by arguments
    """
    # The two mappings need special handling: `_split_for_broadcast` turns each
    # into a pair whose element 0 is handed to vmap as a keyword argument and
    # whose element 1 is captured by closure and merged back in per call.
    if optimizations is None:
        optimizations = {}
    if nondiff_kw is None:
        nondiff_kw = {}
    batch_size = z.shape[0]
    opts_split = _split_for_broadcast(optimizations, batch_size)
    kw_split = _split_for_broadcast(nondiff_kw, batch_size)

    # NOTE(review): both mappings are replaced by {} before splitting, so the
    # `is None` branches below only fire if `_split_for_broadcast` itself
    # returns None -- confirm whether they are still needed.
    def _per_sample(
        z_i,
        *params_i,
        nondiff_kw_=None,
        Dg=None,
        Hg=None,
        jvp_vec=None,
        optimizations_=None,
    ):
        # Re-assemble the full mappings: vmapped part first, then the part
        # held in the closure (whose entries win on key collisions).
        if kw_split is None:
            merged_kw = None
        else:
            merged_kw = dict(nondiff_kw_, **kw_split[1])
        if opts_split is None:
            merged_opts = None
        else:
            merged_opts = dict(optimizations_, **opts_split[1])
        return implicit_hessian_(
            k_fn,
            z_i,
            *params_i,
            nondiff_kw=merged_kw,
            Dg=Dg,
            Hg=Hg,
            jvp_vec=jvp_vec,
            optimizations=merged_opts,
        )

    mapped_kw = None if kw_split is None else kw_split[0]
    mapped_opts = None if opts_split is None else opts_split[0]

    # call the function with vmap
    return jaxm.vmap(_per_sample)(
        z,
        *params,
        nondiff_kw_=mapped_kw,
        Dg=Dg,
        Hg=Hg,
        jvp_vec=jvp_vec,
        optimizations_=mapped_opts,
    )