def implicit_hessian(
    k_fn: Callable,
    z: Tensor,
    *params: Tensor,
    Dg: Tensor = None,
    Hg: Tensor = None,
    jvp_vec: Union[Tensor, Sequence[Tensor]] = None,
    optimizations: Mapping = None,
):
    """Computes the implicit Hessian or chain rule depending on Dg, Hg, jvp_vec.

    Three modes are selected by the keyword arguments:

    * ``Dg is None``: the first and second order sensitivities of ``z`` with
      respect to each parameter are returned.
    * ``Dg`` given, ``jvp_vec is None``: the dense chain-rule gradient and
      Hessian with respect to each parameter are returned.
    * ``Dg`` and ``jvp_vec`` given: the chain-rule quantities contracted with
      ``jvp_vec`` are returned without building the dense Hessian.

    Args:
        k_fn: k_fn(z, *params) = 0, lower/inner implicit function
        z: the optimal embedding variable value array
        *params: the parameters p of the bilevel program
        Dg: gradient sensitivity vector (wrt z), for chain rule; must have
            ``z.numel()`` elements
        Hg: Hessian sensitivity vector (wrt z), for chain rule; must have
            ``z.numel() ** 2`` elements, only used when ``Dg`` is given
        jvp_vec: right sensitivity vector(s) (wrt p) for Hessian-vector-product;
            requires ``Dg``
        optimizations: optional optimizations; a shallow copy is taken so the
            caller's mapping is not modified. The keys read here are
            ``"Dzk_solve_fn"`` (a default is generated when missing) and
            ``"Dzzk"`` (only in the ``Dg is None`` mode)
    Returns:
        A pair ``(first_order, second_order)``. Each entry is a single tensor
        when exactly one parameter is passed, otherwise a list with one tensor
        per parameter. Per-parameter shapes:

        * ``Dg is None``: ``z.shape + param.shape`` and
          ``z.shape + param.shape + param.shape``
        * ``Dg`` given, no ``jvp_vec``: ``param.shape`` and
          ``param.shape + param.shape``
        * ``Dg`` and ``jvp_vec`` given: ``()`` and ``param.shape``
    Raises:
        AssertionError: if ``jvp_vec`` is given without ``Dg``, or if ``Dg`` /
            ``Hg`` do not have the expected number of elements.
    Note:
        When ``jvp_vec`` is given, ``requires_grad`` is set to ``True`` in place
        on ``z`` and on every tensor in ``params``.
    """
    optimizations = {} if optimizations is None else copy(optimizations)
    zlen, plen = _prod(z.shape), [_prod(param.shape) for param in params]
    jvp_vec = _ensure_list(jvp_vec) if jvp_vec is not None else jvp_vec
    if jvp_vec is not None:
        assert Dg is not None

    # construct a default Dzk_solve_fn ##########################
    if optimizations.get("Dzk_solve_fn", None) is None:
        _generate_default_Dzk_solve_fn(optimizations, k_fn)
    #############################################################

    # compute 2nd implicit gradients
    if Dg is not None:
        assert Dg.numel() == zlen
        assert Hg is None or Hg.numel() == zlen**2

        Dg_ = Dg.reshape((zlen, 1))
        Hg_ = Hg.reshape((zlen, zlen)) if Hg is not None else Hg

        # compute the left hand vector in the VJP
        Dzk_solve_fn = optimizations["Dzk_solve_fn"]
        v = -Dzk_solve_fn(z, *params, rhs=Dg_.reshape((zlen, 1)), T=True)
        # v is treated as a constant in every derivative taken below
        v = v.detach()

        def fn(z, *params):
            # scalar function <v, k_fn(z, *params)> that all four terms differentiate
            return torch.sum(v.reshape(zlen) * k_fn(z, *params).reshape(zlen))

        if jvp_vec is not None:
            # NOTE: mutates the caller's tensors so they can be differentiated
            for param in params:
                param.requires_grad = True
            z.requires_grad = True

            Dpz_jvp = _ensure_list(
                implicit_jacobian(
                    k_fn,
                    z,
                    *params,
                    jvp_vec=jvp_vec,
                    optimizations=optimizations,
                )
            )
            Dpz_jvp = [dz.reshape(-1).detach() for dz in Dpz_jvp]

            # compute the 2nd order derivatives consisting of 4 terms; each is
            # the jvp_vec-contracted counterpart of Dpp1..Dpp4 in the dense
            # branch below

            # term 1: derivative wrt params of (d fn / d params), along jvp_vec
            gp = grad(fn(z, *params), params, create_graph=True)
            # plen is zipped in so every parameter is flattened to its own
            # size (passing the whole list only worked for a single parameter)
            Dpp1 = [
                fwd_grad(gp_i, param, grad_inputs=vec).reshape(n)
                for (gp_i, param, vec, n) in zip(gp, params, jvp_vec, plen)
            ]

            # term 2: derivative wrt z of (d fn / d params), along Dpz_jvp
            gp = grad(fn(z, *params), params, create_graph=True)
            Dpp2 = [
                fwd_grad(gp_i, z, grad_inputs=dz.reshape(z.shape)).reshape(-1)
                for (gp_i, dz) in zip(gp, Dpz_jvp)
            ]

            # term 3: derivative wrt params of (d fn / d z), along jvp_vec,
            # then pulled back to parameter space through implicit_jacobian
            gz = grad(fn(z, *params), z, create_graph=True)
            gz_p = [
                fwd_grad(gz, param, grad_inputs=vec) for (param, vec) in zip(params, jvp_vec)
            ]
            Dpp3 = [
                _ensure_list(
                    implicit_jacobian(
                        k_fn,
                        z,
                        *params,
                        Dg=rhs,
                        optimizations=optimizations,
                    )
                )[i].reshape(-1)
                for (i, rhs) in enumerate(gz_p)
            ]

            # term 4: derivative wrt z of (d fn / d z), along Dpz_jvp, plus the
            # Hg contribution if given, pulled back through implicit_jacobian
            gz = grad(fn(z, *params), z, create_graph=True)
            gz_z = [fwd_grad(gz, z, grad_inputs=dz.reshape(z.shape)) for dz in Dpz_jvp]
            if Hg is not None:
                gz_z = [
                    g.reshape(zlen) + Hg_ @ dz.reshape(zlen) for (g, dz) in zip(gz_z, Dpz_jvp)
                ]
            Dpp4 = [
                _ensure_list(
                    implicit_jacobian(
                        k_fn,
                        z,
                        *params,
                        Dg=rhs,
                        optimizations=optimizations,
                    )
                )[i].reshape(n)
                for ((i, rhs), n) in zip(enumerate(gz_z), plen)
            ]

            Dp = [Dg_.reshape((1, zlen)) @ dz.reshape(zlen) for dz in Dpz_jvp]
            Dpp = [sum(terms) for terms in zip(Dpp1, Dpp2, Dpp3, Dpp4)]

            # return the results
            Dp_shaped = [dp.reshape(()) for dp in Dp]
            Dpp_shaped = [dpp.reshape(param.shape) for (dpp, param) in zip(Dpp, params)]
        else:
            # compute the full first order 1st gradients
            Dpz = _ensure_list(
                implicit_jacobian(
                    k_fn,
                    z,
                    *params,
                    optimizations=optimizations,
                )
            )
            Dpz = [jac.reshape((zlen, n)).detach() for (jac, n) in zip(Dpz, plen)]

            # compute the 2nd order derivatives consisting of 4 terms
            # term 1: Hessian of fn wrt each parameter
            Dpp1 = HESSIAN_DIAG(lambda *params: fn(z, *params), params)
            Dpp1 = [hess.reshape((n, n)) for (hess, n) in zip(Dpp1, plen)]

            # term 2: mixed derivative of fn (z first, then params) times Dpz
            temp = JACOBIAN(
                lambda *params: JACOBIAN(lambda z: fn(z, *params), z, create_graph=True),
                params,
            )
            temp = [
                mixed.reshape((zlen, n)).transpose(-1, -2) for (mixed, n) in zip(temp, plen)
            ]
            Dpp2 = [
                (mixed @ jac).reshape((n, n)) for (mixed, jac, n) in zip(temp, Dpz, plen)
            ]

            # term 3: transpose of term 2
            Dpp3 = [t(term2) for term2 in Dpp2]

            # term 4: Dpz^T (Hg + Hessian of fn wrt z) Dpz
            Dzz = HESSIAN(lambda z: fn(z, *params), z).reshape((zlen, zlen))
            if Hg is not None:
                Dpp4 = [t(jac) @ (Hg_ + Dzz) @ jac for jac in Dpz]
            else:
                Dpp4 = [t(jac) @ Dzz @ jac for jac in Dpz]

            Dp = [Dg_.reshape((1, zlen)) @ jac for jac in Dpz]
            Dpp = [sum(terms) for terms in zip(Dpp1, Dpp2, Dpp3, Dpp4)]

            # return the results
            Dp_shaped = [dp.reshape(param.shape) for (dp, param) in zip(Dp, params)]
            Dpp_shaped = [
                dpp.reshape(param.shape + param.shape) for (dpp, param) in zip(Dpp, params)
            ]
        return (Dp_shaped[0], Dpp_shaped[0]) if len(params) == 1 else (Dp_shaped, Dpp_shaped)
    else:
        # full_output=True also hands back the optimizations mapping, which
        # replaces the local copy from here on
        Dpz, optimizations = implicit_jacobian(
            k_fn,
            z,
            *params,
            full_output=True,
            optimizations=optimizations,
        )
        Dpz = _ensure_list(Dpz)
        Dpz = [jac.reshape(zlen, n) for (jac, n) in zip(Dpz, plen)]

        # compute derivatives
        Dzzk = optimizations.get("Dzzk", None)
        if Dzzk is None:
            Hk = HESSIAN_DIAG(k_fn, (z, *params))
            Dzzk, Dppk = Hk[0], Hk[1:]
            optimizations["Dzzk"] = Dzzk
        else:
            # Dzzk was supplied through optimizations (previously left unbound
            # on this path); only the parameter Hessians need computing
            Dppk = HESSIAN_DIAG(lambda *params: k_fn(z, *params), params)
        Dzpk = JACOBIAN(
            lambda *params: JACOBIAN(lambda z: k_fn(z, *params), z, create_graph=True),
            params,
        )
        Dppk = [hess.reshape((zlen, n, n)) for (hess, n) in zip(Dppk, plen)]
        Dzzk = Dzzk.reshape((zlen, zlen, zlen))
        Dzpk = [mixed.reshape((zlen, zlen, n)) for (mixed, n) in zip(Dzpk, plen)]
        Dpzk = [mixed.transpose(-1, -2) for mixed in Dzpk]

        # solve the IFT equation; the leading axis of every operand indexes the
        # zlen outputs of k_fn, so Dpz is broadcast across it with [None, ...]
        lhs = [
            Dppk_i
            + Dpzk_i @ jac[None, ...]
            + t(jac)[None, ...] @ Dzpk_i
            + (t(jac)[None, ...] @ Dzzk) @ jac[None, ...]
            for (jac, Dzpk_i, Dpzk_i, Dppk_i) in zip(Dpz, Dzpk, Dpzk, Dppk)
        ]
        Dzk_solve_fn = optimizations["Dzk_solve_fn"]
        Dppz = [
            -Dzk_solve_fn(z, *params, rhs=lhs_i.reshape((zlen, n * n)), T=False).reshape(
                (zlen, n, n)
            )
            for (lhs_i, n) in zip(lhs, plen)
        ]

        # return computed values
        Dpz_shaped = [jac.reshape(z.shape + param.shape) for (jac, param) in zip(Dpz, params)]
        Dppz_shaped = [
            hess.reshape(z.shape + param.shape + param.shape)
            for (hess, param) in zip(Dppz, params)
        ]
        return (Dpz_shaped[0], Dppz_shaped[0]) if len(params) == 1 else (Dpz_shaped, Dppz_shaped)