sensitivity_jax.sensitivity.generate_optimization_fns(loss_fn, opt_fn, k_fn, normalize_grad=False, optimizations=None, jit=True, custom_arg_serializer=None)

Directly generates upper/outer bilevel program derivative functions.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `loss_fn` | `Callable` | `loss_fn(z, *params)`, the upper/outer-level loss | required |
| `opt_fn` | `Callable` | `opt_fn(*params) = z`, the lower/inner argmin function | required |
| `k_fn` | `Callable` | `k_fn(z, *params) = 0`, the lower/inner implicit function | required |
| `normalize_grad` | `bool` | whether to normalize the gradient by its norm | `False` |
| `optimizations` | `Mapping` | optional mapping of precomputed derivative routines (e.g. `"Hz_fn"`, the Hessian of `loss_fn` with respect to `z`) passed to the implicit differentiation routines | `None` |
| `jit` | `bool` | whether to apply just-in-time (JIT) compilation to the generated functions | `True` |
| `custom_arg_serializer` | `Optional[Callable]` | optional callable used by the inner-solution cache to serialize arguments | `None` |

Returns:

| Type | Description |
| --- | --- |
| `Tuple[Callable, Callable, Callable]` | `f_fn(*params), g_fn(*params), h_fn(*params)`: the parameters-only upper/outer-level loss, gradient, and Hessian |
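Example (a minimal usage sketch, not taken from the library's documentation): a toy bilevel problem in which the inner solution is available in closed form, so `opt_fn` and `k_fn` are trivial. The parameter name `p`, the toy losses, and the printed quantities are illustrative assumptions; exact return shapes depend on the library's internals.

```python
import jax.numpy as jnp
from sensitivity_jax.sensitivity import generate_optimization_fns

# Inner problem: z*(p) = argmin_z ||z - p||^2, solved in closed form.
# Its optimality condition gives the implicit function k(z, p) = z - p = 0.
def opt_fn(p):
    return p

def k_fn(z, p):
    return z - p

# Upper/outer-level loss, evaluated at the inner solution z.
def loss_fn(z, p):
    return jnp.sum((z - 1.0) ** 2) + 0.1 * jnp.sum(p**2)

f_fn, g_fn, h_fn = generate_optimization_fns(loss_fn, opt_fn, k_fn)

p = jnp.zeros(3)
print(f_fn(p))  # parameters-only outer loss
print(g_fn(p))  # total gradient d loss / d p via implicit differentiation
print(h_fn(p))  # total Hessian d^2 loss / d p^2

# One outer Newton step on the parameters.
p = p - jnp.linalg.solve(h_fn(p), g_fn(p))
```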

Source code in sensitivity_jax/sensitivity.py
def generate_optimization_fns(
    loss_fn: Callable,
    opt_fn: Callable,
    k_fn: Callable,
    normalize_grad: bool = False,
    optimizations: Optional[Mapping] = None,
    jit: bool = True,
    custom_arg_serializer: Optional[Callable] = None,
):
    """Directly generates upper/outer bilevel program derivative functions.

    Args:
        loss_fn: loss_fn(z, *params), upper/outer level loss
        opt_fn: opt_fn(*params) = z, lower/inner argmin function
        k_fn: k_fn(z, *params) = 0, lower/inner implicit function
        normalize_grad: whether to normalize the gradient by its norm
        optimizations: optional mapping of precomputed derivative routines
            (e.g. ``"Hz_fn"``, the Hessian of ``loss_fn`` with respect to ``z``)
            passed to the implicit differentiation routines
        jit: whether to apply just-in-time (jit) compilation to the functions
        custom_arg_serializer: optional callable used by the inner-solution
            cache to serialize arguments
    Returns:
        ``f_fn(*params), g_fn(*params), h_fn(*params)``
        parameters-only upper/outer level loss, gradient and Hessian.
    """
    sol_cache = {}
    optimizations = {} if optimizations is None else copy(optimizations)

    @fn_with_sol_cache(opt_fn, sol_cache, jit=jit, custom_arg_serializer=custom_arg_serializer)
    def f_fn(z, *params, **nondiff_kw):
        return loss_fn(z, *params, **nondiff_kw)

    @fn_with_sol_cache(opt_fn, sol_cache, jit=jit, custom_arg_serializer=custom_arg_serializer)
    def g_fn(z, *params, **nondiff_kw):
        g = JACOBIAN(loss_fn, argnums=range(len(params) + 1))(z, *params, **nondiff_kw)
        Dp = implicit_jacobian(
            k_fn,
            z,
            *params,
            nondiff_kw=None if len(nondiff_kw) == 0 else nondiff_kw,
            Dg=g[0],
            optimizations=optimizations,
        )
        Dp = Dp if len(params) != 1 else [Dp]
        ret = [Dp + g for (Dp, g) in zip(Dp, g[1:])]
        if normalize_grad:
            ret = [(z / (jaxm.norm(z) + 1e-7)) for z in ret]
        return ret[0] if len(ret) == 1 else ret

    @fn_with_sol_cache(opt_fn, sol_cache, jit=jit, custom_arg_serializer=custom_arg_serializer)
    def h_fn(z, *params, **nondiff_kw):
        g = JACOBIAN(loss_fn, argnums=range(len(params) + 1))(z, *params, **nondiff_kw)

        if optimizations.get("Hz_fn", None) is None:
            optimizations["Hz_fn"] = jaxm.hessian(loss_fn)
        Hz_fn = optimizations["Hz_fn"]
        Hz = Hz_fn(z, *params, **nondiff_kw)
        H = [Hz] + HESSIAN_DIAG(lambda *params: loss_fn(z, *params, **nondiff_kw))(*params)

        _, Dpp = implicit_hessian(
            k_fn,
            z,
            *params,
            nondiff_kw=None if len(nondiff_kw) == 0 else nondiff_kw,
            Dg=g[0],
            Hg=H[0],
            optimizations=optimizations,
        )
        Dpp = Dpp if len(params) != 1 else [Dpp]
        ret = [Dpp + H for (Dpp, H) in zip(Dpp, H[1:])]
        return ret[0] if len(ret) == 1 else ret

    return (f_fn, g_fn, h_fn)
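
As a reading aid for `g_fn` above, the gradient assembly follows the standard implicit function theorem identity (stated here for context, not quoted from the library). With the inner solution $z^*(\theta)$ defined by $k(z^*(\theta), \theta) = 0$,

$$
\frac{d\mathcal{L}}{d\theta}
= \frac{\partial \mathcal{L}}{\partial \theta}
+ \frac{\partial \mathcal{L}}{\partial z}\,\frac{\partial z^*}{\partial \theta},
\qquad
\frac{\partial z^*}{\partial \theta}
= -\left(\frac{\partial k}{\partial z}\right)^{-1}\frac{\partial k}{\partial \theta}.
$$

In the code, `g[0]` is $\partial\mathcal{L}/\partial z$, `g[1:]` holds $\partial\mathcal{L}/\partial \theta$ for each parameter, and `implicit_jacobian` returns the chain-rule contribution through $z^*$ as `Dp`, so `ret = [Dp + g for ...]` assembles this identity; `h_fn` performs the analogous second-order assembly via `implicit_hessian`.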