def implicit_hessian(
    k_fn: Callable,
    z: JAXArray,
    *params: JAXArray,
    nondiff_kw: Mapping = None,
    Dg: JAXArray = None,
    Hg: JAXArray = None,
    jvp_vec: Union[JAXArray, Sequence[JAXArray]] = None,
    optimizations: Mapping = None,
):
    """Computes the implicit Hessian or chain rule depending on Dg, Hg, jvp_vec.

    The inner problem is given implicitly by ``k_fn(z, *params) = 0`` at ``z``.
    Three modes are supported:

    * ``Dg`` and ``jvp_vec`` given: per-parameter chain-rule directional
      derivatives and block-diagonal chain-rule Hessian-vector products
      (parameter ``i`` is paired only with its own direction ``jvp_vec[i]``).
    * ``Dg`` given, ``jvp_vec`` omitted: per-parameter chain-rule gradients and
      block-diagonal chain-rule Hessians.
    * ``Dg`` omitted: first- and second-order implicit derivatives of ``z``
      itself (block-diagonal in the parameters).

    Args:
        k_fn: k_fn(z, *params) = 0, lower/inner implicit function
        z: the optimal embedding variable value array
        *params: the parameters p of the bilevel program
        nondiff_kw: nondifferentiable parameters to the implicit function
        Dg: gradient sensitivity vector (wrt z), for chain rule; must have
            ``prod(z.shape)`` elements
        Hg: Hessian sensitivity vector (wrt z), for chain rule; must have
            ``prod(z.shape) ** 2`` elements
        jvp_vec: right sensitivity vector(s) (wrt p) for Hessian-vector-product;
            requires ``Dg``
        optimizations: optional optimizations; the mapping is shallow-copied so
            the caller's dict is not mutated. A default "Dzk_solve_fn" is
            generated when absent; a supplied "Dzzk" is reused.
    Returns:
        Hessian/chain rule Hessian as specified by arguments, as a pair
        ``(first_order, second_order)``. Each element is a single array when
        exactly one parameter is passed, otherwise a list with one entry per
        parameter.
    """
    optimizations = {} if optimizations is None else copy(optimizations)
    nondiff_kw = {} if nondiff_kw is None else nondiff_kw
    zlen, plen = prod(z.shape), [prod(param.shape) for param in params]
    jvp_vec = _ensure_list(jvp_vec) if jvp_vec is not None else jvp_vec
    if jvp_vec is not None:
        assert Dg is not None

    # nondiff_kw was defaulted to {} above, so the keyword binding always applies
    def k_fn_(z, *params):
        return k_fn(z, *params, **nondiff_kw)

    # construct a default Dzk_solve_fn ##########################
    if optimizations.get("Dzk_solve_fn", None) is None:
        _generate_default_Dzk_solve_fn(optimizations, k_fn_)
    #############################################################

    # compute 2nd implicit gradients (chain rule through an outer objective)
    if Dg is not None:
        assert Dg.size == zlen
        assert Hg is None or Hg.size == zlen**2
        Dg_ = Dg.reshape((zlen, 1))
        Hg_ = Hg.reshape((zlen, zlen)) if Hg is not None else Hg

        # left-hand (adjoint) vector of the VJP; T=True presumably selects the
        # transposed solve -- NOTE(review): confirm against Dzk_solve_fn
        Dzk_solve_fn = optimizations["Dzk_solve_fn"]
        v = -Dzk_solve_fn(z, *params, rhs=Dg_, T=True)

        def fn(z, *params):
            # scalar v^T k(z, params); the 4 Hessian terms below are its derivatives
            return jaxm.sum(v.reshape(zlen) * k_fn_(z, *params).reshape(zlen))

        if jvp_vec is not None:
            # z-tangents induced by each parameter's own direction (one per param)
            Dpz_jvp = _ensure_list(
                implicit_jacobian(
                    k_fn,
                    z,
                    *params,
                    nondiff_kw=nondiff_kw,
                    jvp_vec=jvp_vec,
                    optimizations=optimizations,
                )
            )
            Dpz_jvp = [dz.reshape(-1) for dz in Dpz_jvp]

            def with_param(i, p_i):
                # the params tuple with only the i-th entry replaced by p_i
                return params[:i] + (p_i,) + params[i + 1 :]

            # every term is computed per parameter (block-diagonal HVP). Taking
            # the JVP w.r.t. all params at once (as before) collapsed term 3 to
            # a single entry and truncated the result to the first parameter.
            Dp_shaped, Dpp_shaped = [], []
            for i, (param, plen_i, v_i, dz_i) in enumerate(
                zip(params, plen, jvp_vec, Dpz_jvp)
            ):
                dz_i_shaped = dz_i.reshape(z.shape)
                # term 1: D_{p_i p_i} fn . v_i
                Dpp1 = jaxm.jvp(
                    lambda p_i: jaxm.grad(fn, argnums=i + 1)(z, *with_param(i, p_i)),
                    (param,),
                    (v_i,),
                )[1].reshape(plen_i)
                # term 2: D_{p_i z} fn . (Dpz_i v_i)
                Dpp2 = jaxm.jvp(
                    lambda z_: jaxm.grad(fn, argnums=i + 1)(z_, *params),
                    (z,),
                    (dz_i_shaped,),
                )[1].reshape(plen_i)
                # term 3: Dpz_i^T . D_{z p_i} fn . v_i
                g3 = jaxm.jvp(
                    lambda p_i: jaxm.grad(fn)(z, *with_param(i, p_i)),
                    (param,),
                    (v_i,),
                )[1]
                Dpp3 = _ensure_list(
                    implicit_jacobian(
                        k_fn,
                        z,
                        *params,
                        nondiff_kw=nondiff_kw,
                        Dg=g3,
                        optimizations=optimizations,
                    )
                )[i].reshape(plen_i)
                # term 4: Dpz_i^T . (D_zz fn + Hg) . (Dpz_i v_i)
                g4 = jaxm.jvp(
                    lambda z_: jaxm.grad(fn)(z_, *params),
                    (z,),
                    (dz_i_shaped,),
                )[1]
                if Hg is not None:
                    g4 = g4.reshape(zlen) + Hg_ @ dz_i
                Dpp4 = _ensure_list(
                    implicit_jacobian(
                        k_fn,
                        z,
                        *params,
                        Dg=g4,
                        nondiff_kw=nondiff_kw,
                        optimizations=optimizations,
                    )
                )[i].reshape(plen_i)

                Dp = Dg_.reshape((1, zlen)) @ dz_i
                Dp_shaped.append(Dp.reshape(()))
                Dpp_shaped.append((Dpp1 + Dpp2 + Dpp3 + Dpp4).reshape(param.shape))
        else:
            # compute the full first order 1st gradients
            Dpz = _ensure_list(
                implicit_jacobian(
                    k_fn,
                    z,
                    *params,
                    nondiff_kw=nondiff_kw,
                    optimizations=optimizations,
                )
            )
            Dpz = [Dpz_i.reshape((zlen, plen_i)) for (Dpz_i, plen_i) in zip(Dpz, plen)]
            # term 1: D_pp fn (diagonal blocks only)
            Dpp1 = HESSIAN_DIAG(lambda *params: fn(z, *params))(*params)
            Dpp1 = [H.reshape((plen_i, plen_i)) for (H, plen_i) in zip(Dpp1, plen)]
            # terms 2 and 3: D_pz fn . Dpz and its transpose
            Dzp_fn = JACOBIAN(
                lambda *params: JACOBIAN(fn)(z, *params),
                argnums=range(len(params)),
            )(*params)
            Dpz_fn = [
                jaxm.t(J.reshape((zlen, plen_i))) for (J, plen_i) in zip(Dzp_fn, plen)
            ]
            Dpp2 = [
                (A @ Dpz_i).reshape((plen_i, plen_i))
                for (A, Dpz_i, plen_i) in zip(Dpz_fn, Dpz, plen)
            ]
            Dpp3 = [jaxm.t(M) for M in Dpp2]
            # term 4: Dpz^T . (D_zz fn + Hg) . Dpz
            Dzz = HESSIAN(lambda z: fn(z, *params))(z).reshape((zlen, zlen))
            if Hg is not None:
                Dzz = Hg_ + Dzz
            Dpp4 = [jaxm.t(Dpz_i) @ Dzz @ Dpz_i for Dpz_i in Dpz]

            Dp = [Dg_.reshape((1, zlen)) @ Dpz_i for Dpz_i in Dpz]
            Dpp = [d1 + d2 + d3 + d4 for (d1, d2, d3, d4) in zip(Dpp1, Dpp2, Dpp3, Dpp4)]
            Dp_shaped = [d.reshape(param.shape) for (d, param) in zip(Dp, params)]
            Dpp_shaped = [
                d.reshape(param.shape + param.shape) for (d, param) in zip(Dpp, params)
            ]
        if len(params) == 1:
            return Dp_shaped[0], Dpp_shaped[0]
        return Dp_shaped, Dpp_shaped

    # no Dg: second-order implicit derivatives of z itself
    Dpz, optimizations = implicit_jacobian(
        k_fn,
        z,
        *params,
        nondiff_kw=nondiff_kw,
        full_output=True,
        optimizations=optimizations,
    )
    Dpz = _ensure_list(Dpz)
    Dpz = [Dpz_i.reshape(zlen, plen_i) for (Dpz_i, plen_i) in zip(Dpz, plen)]

    # compute derivatives of k
    if optimizations.get("Dzzk", None) is None:
        Hk = HESSIAN_DIAG(k_fn_)(z, *params)
        Dzzk, Dppk = Hk[0], Hk[1:]
        optimizations["Dzzk"] = Dzzk
    else:
        # reuse the supplied Dzzk (previously left unbound -> NameError below)
        Dzzk = optimizations["Dzzk"]
        Dppk = HESSIAN_DIAG(lambda *params: k_fn_(z, *params))(*params)
    Dzpk = JACOBIAN(
        lambda *params: JACOBIAN(k_fn_)(z, *params),
        argnums=range(len(params)),
    )(*params)
    Dppk = [H.reshape((zlen, plen_i, plen_i)) for (H, plen_i) in zip(Dppk, plen)]
    Dzzk = Dzzk.reshape((zlen, zlen, zlen))
    Dzpk = [J.reshape((zlen, zlen, plen_i)) for (J, plen_i) in zip(Dzpk, plen)]
    Dpzk = [jaxm.t(J) for J in Dzpk]

    # solve the second-order IFT equation Dzk . Dppz = -lhs
    lhs = [
        Dppk_i
        + Dpzk_i @ Dpz_i[None, ...]
        + jaxm.t(Dpz_i)[None, ...] @ Dzpk_i
        + (jaxm.t(Dpz_i)[None, ...] @ Dzzk) @ Dpz_i[None, ...]
        for (Dpz_i, Dzpk_i, Dpzk_i, Dppk_i) in zip(Dpz, Dzpk, Dpzk, Dppk)
    ]
    Dzk_solve_fn = optimizations["Dzk_solve_fn"]
    Dppz = [
        -Dzk_solve_fn(
            z, *params, rhs=lhs_i.reshape((zlen, plen_i * plen_i)), T=False
        ).reshape((zlen, plen_i, plen_i))
        for (lhs_i, plen_i) in zip(lhs, plen)
    ]

    # return computed values
    Dpz_shaped = [d.reshape(z.shape + param.shape) for (d, param) in zip(Dpz, params)]
    Dppz_shaped = [
        d.reshape(z.shape + param.shape + param.shape) for (d, param) in zip(Dppz, params)
    ]
    if len(params) == 1:
        return Dpz_shaped[0], Dppz_shaped[0]
    return Dpz_shaped, Dppz_shaped