Skip to content

sensitivity_jax.sensitivity.implicit_hessian(k_fn, z, *params, nondiff_kw=None, Dg=None, Hg=None, jvp_vec=None, optimizations=None)

Computes the implicit Hessian, or the chain-rule Hessian, depending on which of `Dg`, `Hg`, and `jvp_vec` are provided.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `k_fn` | `Callable` | `k_fn(z, *params) = 0`, lower/inner implicit function | *required* |
| `z` | `JAXArray` | the optimal embedding variable value array | *required* |
| `*params` | `JAXArray` | the parameters p of the bilevel program | `()` |
| `nondiff_kw` | `Mapping` | nondifferentiable parameters to the implicit function | `None` |
| `Dg` | `JAXArray` | gradient sensitivity vector (wrt z), for chain rule | `None` |
| `Hg` | `JAXArray` | Hessian sensitivity vector (wrt z), for chain rule | `None` |
| `jvp_vec` | `Union[JAXArray, Sequence[JAXArray]]` | right sensitivity vector(s) (wrt p) for Hessian-vector-product | `None` |
| `optimizations` | `Mapping` | optional optimizations | `None` |

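Returns:

| Type | Description |
|------|-------------|
| | A pair of first- and second-order results (the implicit Jacobian and Hessian, or the chain-rule gradient and Hessian), as specified by the arguments |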

Source code in sensitivity_jax/sensitivity.py
def implicit_hessian(
    k_fn: Callable,
    z: JAXArray,
    *params: JAXArray,
    nondiff_kw: Mapping = None,
    Dg: JAXArray = None,
    Hg: JAXArray = None,
    jvp_vec: Union[JAXArray, Sequence[JAXArray]] = None,
    optimizations: Mapping = None,
):
    """Computes the implicit Hessian or chain rule depending on Dg, Hg, jvp_vec.

    The inner variable z*(p) is defined implicitly by ``k_fn(z, *params) = 0``.
    The mode is selected by the optional arguments:

    * ``Dg is None``: derivatives of z*(p) itself. Returns ``(Dpz, Dppz)`` with
      shapes ``z.shape + p.shape`` and ``z.shape + p.shape + p.shape`` for
      every parameter ``p``.
    * ``Dg`` given, ``jvp_vec is None``: chain rule through z*(p) for an outer
      function with gradient ``Dg`` (and Hessian ``Hg``, treated as zero when
      omitted) wrt z. Returns ``(Dp, Dpp)`` with shapes ``p.shape`` and
      ``p.shape + p.shape``.
    * ``Dg`` and ``jvp_vec`` given: the same chain rule contracted with one
      right vector per parameter. Returns ``(Dp, Dpp)`` with shapes ``()`` and
      ``p.shape``.

    Only per-parameter blocks are formed; blocks mixing two different
    parameters are not computed.

    Args:
        k_fn: k_fn(z, *params) = 0, lower/inner implicit function
        z: the optimal embedding variable value array
        *params: the parameters p of the bilevel program
        nondiff_kw: nondifferentiable parameters to the implicit function, passed
            to ``k_fn`` as keyword arguments
        Dg: gradient sensitivity vector (wrt z), for chain rule; must have
            ``z.size`` elements
        Hg: Hessian sensitivity vector (wrt z), for chain rule; must have
            ``z.size ** 2`` elements
        jvp_vec: right sensitivity vector(s) (wrt p) for Hessian-vector-product,
            one per parameter; requires ``Dg``
        optimizations: optional optimizations. A shallow copy is used, so the
            caller's mapping is not modified. Keys read here are
            ``"Dzk_solve_fn"`` and ``"Dzzk"`` (a cached second derivative of
            ``k_fn`` wrt z).
    Returns:
        Hessian/chain rule Hessian as specified by arguments: a pair of first-
        and second-order results. Each element is a single array when there is
        exactly one parameter, and a list with one entry per parameter otherwise.
    Raises:
        AssertionError: if ``jvp_vec`` is given without ``Dg``, if ``jvp_vec``
            does not have one entry per parameter, or if ``Dg``/``Hg`` have the
            wrong number of elements.
    """
    optimizations = {} if optimizations is None else copy(optimizations)
    nondiff_kw = {} if nondiff_kw is None else nondiff_kw

    zlen, plen = prod(z.shape), [prod(param.shape) for param in params]
    jvp_vec = _ensure_list(jvp_vec) if jvp_vec is not None else jvp_vec
    if jvp_vec is not None:
        assert Dg is not None, "jvp_vec requires Dg"
        # the per-parameter loop below zips over params; a mismatch must not truncate silently
        assert len(jvp_vec) == len(params), "jvp_vec needs one entry per parameter"

    # nondiff_kw is always a mapping at this point, so it can always be bound in
    k_fn_ = lambda z, *params: k_fn(z, *params, **nondiff_kw)

    # construct a default Dzk_solve_fn ##########################
    if optimizations.get("Dzk_solve_fn", None) is None:
        _generate_default_Dzk_solve_fn(optimizations, k_fn_)
    #############################################################

    # compute 2nd implicit gradients
    if Dg is not None:
        assert Dg.size == zlen, "Dg must have z.size elements"
        assert Hg is None or Hg.size == zlen**2, "Hg must have z.size ** 2 elements"

        Dg_ = Dg.reshape((zlen, 1))
        Hg_ = Hg.reshape((zlen, zlen)) if Hg is not None else Hg

        # left hand (adjoint) vector of the VJP; fn(z, *params) = v^T k_fn(z, *params),
        # and every second-order term below is a derivative of fn
        Dzk_solve_fn = optimizations["Dzk_solve_fn"]
        v = -Dzk_solve_fn(z, *params, rhs=Dg_, T=True)
        fn = lambda z, *params: jaxm.sum(v.reshape(zlen) * k_fn_(z, *params).reshape(zlen))

        if jvp_vec is not None:
            # Dpz_i @ v_i for every parameter i
            Dpz_jvp = _ensure_list(
                implicit_jacobian(
                    k_fn,
                    z,
                    *params,
                    nondiff_kw=nondiff_kw,
                    jvp_vec=jvp_vec,
                    optimizations=optimizations,
                )
            )
            Dpz_jvp = [Dpz_v_i.reshape(-1) for Dpz_v_i in Dpz_jvp]

            def _with_param(i, p_i):
                # the outer ``params`` tuple with its i-th entry replaced by p_i
                return params[:i] + (p_i,) + params[i + 1 :]

            def _fn_wrt_param(i):
                # fn as a function of parameter i only (z and the other params fixed)
                return lambda p_i: fn(z, *_with_param(i, p_i))

            def _Dzfn_wrt_param(i):
                # grad_z fn as a function of parameter i only
                return lambda p_i: jaxm.grad(fn)(z, *_with_param(i, p_i))

            def _Dpfn_wrt_z(i):
                # grad_{p_i} fn as a function of z only
                return lambda z_: jaxm.grad(fn, argnums=i + 1)(z_, *params)

            Dzfn_wrt_z = lambda z_: jaxm.grad(fn)(z_, *params)

            # the 2nd order derivative of each parameter consists of 4 terms; each term
            # only uses parameter i and its own direction v_i (no cross-parameter terms)
            Dp_shaped, Dpp_shaped = [], []
            for i, (param, v_i, Dpz_v_i) in enumerate(zip(params, jvp_vec, Dpz_jvp)):
                plen_i = plen[i]
                dz_i = Dpz_v_i.reshape(z.shape)

                # term 1: (d^2 fn / dp_i^2) v_i
                Dpp1 = jaxm.jvp(jaxm.grad(_fn_wrt_param(i)), (param,), (v_i,))[1]

                # term 2: (d^2 fn / dp_i dz) Dpz_i v_i
                Dpp2 = jaxm.jvp(_Dpfn_wrt_z(i), (z,), (dz_i,))[1]

                # term 3: Dpz_i^T (d^2 fn / dz dp_i) v_i
                g3 = jaxm.jvp(_Dzfn_wrt_param(i), (param,), (v_i,))[1]
                Dpp3 = _ensure_list(
                    implicit_jacobian(
                        k_fn,
                        z,
                        *params,
                        nondiff_kw=nondiff_kw,
                        Dg=g3,
                        optimizations=optimizations,
                    )
                )[i]

                # term 4: Dpz_i^T (d^2 fn / dz^2 + Hg) Dpz_i v_i
                g4 = jaxm.jvp(Dzfn_wrt_z, (z,), (dz_i,))[1]
                if Hg is not None:
                    g4 = g4.reshape(zlen) + Hg_ @ Dpz_v_i
                Dpp4 = _ensure_list(
                    implicit_jacobian(
                        k_fn,
                        z,
                        *params,
                        Dg=g4,
                        nondiff_kw=nondiff_kw,
                        optimizations=optimizations,
                    )
                )[i]

                Dpp = sum(term.reshape(plen_i) for term in (Dpp1, Dpp2, Dpp3, Dpp4))
                Dp = Dg_.reshape((1, zlen)) @ Dpz_v_i

                Dp_shaped.append(Dp.reshape(()))
                Dpp_shaped.append(Dpp.reshape(param.shape))
        else:
            # compute the full first order 1st gradients
            Dpz = _ensure_list(
                implicit_jacobian(
                    k_fn,
                    z,
                    *params,
                    nondiff_kw=nondiff_kw,
                    optimizations=optimizations,
                )
            )
            Dpz = [Dpz_i.reshape((zlen, plen_i)) for (Dpz_i, plen_i) in zip(Dpz, plen)]

            # the 2nd order derivative of each parameter consists of 4 terms
            # term 1: d^2 fn / dp_i^2
            Dpp1 = HESSIAN_DIAG(lambda *params: fn(z, *params))(*params)
            Dpp1 = [term.reshape((plen_i, plen_i)) for (term, plen_i) in zip(Dpp1, plen)]

            # terms 2 and 3: the mixed (p_i, z) derivative of fn pushed through Dpz_i,
            # and its transpose
            mixed = JACOBIAN(
                lambda *params: JACOBIAN(fn)(z, *params),
                argnums=range(len(params)),
            )(*params)
            mixed = [jaxm.t(m.reshape((zlen, plen_i))) for (m, plen_i) in zip(mixed, plen)]
            Dpp2 = [
                (mixed_i @ Dpz_i).reshape((plen_i, plen_i))
                for (mixed_i, Dpz_i, plen_i) in zip(mixed, Dpz, plen)
            ]
            Dpp3 = [jaxm.t(term) for term in Dpp2]

            # term 4: Dpz_i^T (Hg + d^2 fn / dz^2) Dpz_i, with Hg treated as zero if omitted
            Dzz = HESSIAN(lambda z: fn(z, *params))(z).reshape((zlen, zlen))
            Dzz_total = Hg_ + Dzz if Hg is not None else Dzz
            Dpp4 = [jaxm.t(Dpz_i) @ Dzz_total @ Dpz_i for Dpz_i in Dpz]

            Dp = [Dg_.reshape((1, zlen)) @ Dpz_i for Dpz_i in Dpz]
            Dpp = [sum(terms) for terms in zip(Dpp1, Dpp2, Dpp3, Dpp4)]

            # return the results
            Dp_shaped = [Dp_i.reshape(param.shape) for (Dp_i, param) in zip(Dp, params)]
            Dpp_shaped = [
                Dpp_i.reshape(param.shape + param.shape) for (Dpp_i, param) in zip(Dpp, params)
            ]
        return (Dp_shaped[0], Dpp_shaped[0]) if len(params) == 1 else (Dp_shaped, Dpp_shaped)
    else:
        Dpz, optimizations = implicit_jacobian(
            k_fn,
            z,
            *params,
            nondiff_kw=nondiff_kw,
            full_output=True,
            optimizations=optimizations,
        )
        Dpz = _ensure_list(Dpz)
        Dpz = [Dpz_i.reshape(zlen, plen_i) for (Dpz_i, plen_i) in zip(Dpz, plen)]

        # second derivatives of k_fn; a cached Dzzk in optimizations is reused, in which
        # case only the parameter blocks are recomputed
        Dzzk = optimizations.get("Dzzk", None)
        if Dzzk is None:
            Hk = HESSIAN_DIAG(k_fn_)(z, *params)
            Dzzk, Dppk = Hk[0], Hk[1:]
            optimizations["Dzzk"] = Dzzk
        else:
            Dppk = HESSIAN_DIAG(lambda *params: k_fn_(z, *params))(*params)
        Dzpk = JACOBIAN(
            lambda *params: JACOBIAN(k_fn_)(z, *params),
            argnums=range(len(params)),
        )(*params)
        Dppk = [Dppk_i.reshape((zlen, plen_i, plen_i)) for (Dppk_i, plen_i) in zip(Dppk, plen)]
        Dzzk = Dzzk.reshape((zlen, zlen, zlen))
        Dzpk = [Dzpk_i.reshape((zlen, zlen, plen_i)) for (Dzpk_i, plen_i) in zip(Dzpk, plen)]
        Dpzk = [jaxm.t(Dzpk_i) for Dzpk_i in Dzpk]

        # solve the IFT equation: assemble the second-order right hand side per parameter
        lhs = [
            Dppk_i
            + Dpzk_i @ Dpz_i[None, ...]
            + jaxm.t(Dpz_i)[None, ...] @ Dzpk_i
            + (jaxm.t(Dpz_i)[None, ...] @ Dzzk) @ Dpz_i[None, ...]
            for (Dpz_i, Dzpk_i, Dpzk_i, Dppk_i) in zip(Dpz, Dzpk, Dpzk, Dppk)
        ]
        Dzk_solve_fn = optimizations["Dzk_solve_fn"]
        Dppz = [
            -Dzk_solve_fn(z, *params, rhs=lhs_i.reshape((zlen, plen_i * plen_i)), T=False).reshape(
                (zlen, plen_i, plen_i)
            )
            for (lhs_i, plen_i) in zip(lhs, plen)
        ]

        # return computed values
        Dpz_shaped = [Dpz_i.reshape(z.shape + param.shape) for (Dpz_i, param) in zip(Dpz, params)]
        Dppz_shaped = [
            Dppz_i.reshape(z.shape + param.shape + param.shape)
            for (Dppz_i, param) in zip(Dppz, params)
        ]
        return (Dpz_shaped[0], Dppz_shaped[0]) if len(params) == 1 else (Dpz_shaped, Dppz_shaped)