sensitivity_jax.sensitivity.generate_optimization_fns(loss_fn, opt_fn, k_fn, normalize_grad=False, optimizations=None, jit=True, custom_arg_serializer=None)
Directly generates the upper/outer level loss and derivative functions of a bilevel program.
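Derivatives of this kind are commonly obtained with the implicit function theorem, which avoids differentiating through the inner solver. The identity below is a sketch of that textbook technique, not a statement of how this library implements it. It assumes all params are collected into a single vector $p$, that $z^\star(p)$ = `opt_fn(*params)` satisfies $k(z^\star(p), p) = 0$ for $k$ = `k_fn`, and that $\partial k / \partial z$ is invertible there. Writing $\ell$ for `loss_fn`:

$$
\frac{\partial z^\star}{\partial p} = -\left(\frac{\partial k}{\partial z}\right)^{-1} \frac{\partial k}{\partial p},
\qquad
\nabla_p \, \ell\big(z^\star(p), p\big) = \frac{\partial \ell}{\partial p} + \left(\frac{\partial z^\star}{\partial p}\right)^{\top} \frac{\partial \ell}{\partial z},
$$

with every partial derivative evaluated at $\big(z^\star(p), p\big)$.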
Parameters:
Name | Type | Description | Default |
---|---|---|---|
loss_fn | Callable | loss_fn(z, *params), upper/outer level loss | required |
opt_fn | Callable | opt_fn(*params) = z, lower/inner argmin function | required |
k_fn | Callable | k_fn(z, *params) = 0, lower/inner implicit function | required |
normalize_grad | bool | whether to normalize the gradient by its norm | False |
jit | bool | whether to apply just-in-time (jit) compilation to the functions | True |
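The table fixes only the call conventions for the three callables. As a minimal sketch, assuming a hypothetical toy problem with two parameters `params = (p, lam)`, the functions below follow those conventions: `opt_fn` returns the inner argmin, `k_fn` vanishes there, and `loss_fn` accepts every parameter even when it ignores some. The helper `normalized` and its `eps` guard are illustrative assumptions showing what dividing a gradient by its norm means; none of this calls `generate_optimization_fns`.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy problem, params = (p, lam):
# inner problem z*(p, lam) = argmin_z 0.5*||z - p||^2 + 0.5*lam*||z||^2
target = jnp.array([1.0, -1.0])

def opt_fn(p, lam):
    # opt_fn(*params) = z: closed-form inner argmin
    return p / (1.0 + lam)

def k_fn(z, p, lam):
    # k_fn(z, *params) = 0: gradient of the inner objective w.r.t. z
    return (z - p) + lam * z

def loss_fn(z, p, lam):
    # loss_fn(z, *params): outer loss; takes all params even if unused
    return 0.5 * jnp.sum((z - target) ** 2)

def normalized(g, eps=1e-12):
    # Divide a gradient by its Euclidean norm (eps guard is an assumption)
    return g / jnp.maximum(jnp.linalg.norm(g), eps)

params = (jnp.array([2.0, 0.5]), jnp.array(0.25))
z = opt_fn(*params)
residual = k_fn(z, *params)  # approximately zero at the inner solution

# Gradient of the parameters-only loss w.r.t. p, here obtained by
# differentiating straight through the closed-form opt_fn
lam = params[1]
g_p = jax.grad(lambda p: loss_fn(opt_fn(p, lam), p, lam))(params[0])
g_unit = normalized(g_p)
```

Because the toy `opt_fn` is closed-form, direct differentiation works here; with an iterative inner solver, the implicit route via `k_fn` is what makes these gradients practical.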
Returns:
Type | Description |
---|---|
 | parameters-only upper/outer level loss, gradient and Hessian functions. |
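A minimal sketch of what parameters-only loss, gradient and Hessian functions compute, assuming a single array parameter `p` and a hypothetical ridge-regression inner problem. The gradient uses the implicit function theorem, so it needs derivatives of `k_fn` and `loss_fn` only, never of `opt_fn`. The names `loss_p`, `grad_p` and `hess_p` are illustrative and are not the library's returned objects.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy bilevel problem: the inner problem is ridge regression
A = jnp.array([[2.0, 0.5], [0.0, 1.0], [1.0, -1.0]])
lam = 0.1
target = jnp.array([1.0, -2.0])

def opt_fn(p):
    # Closed-form argmin_z 0.5*||A z - p||^2 + 0.5*lam*||z||^2
    H = A.T @ A + lam * jnp.eye(A.shape[1])
    return jnp.linalg.solve(H, A.T @ p)

def k_fn(z, p):
    # Inner optimality condition: gradient of the inner objective, zero at z*
    return A.T @ (A @ z - p) + lam * z

def loss_fn(z, p):
    # Upper/outer loss evaluated at the inner solution
    return 0.5 * jnp.sum((z - target) ** 2)

def loss_p(p):
    # Parameters-only loss: p -> loss_fn(opt_fn(p), p)
    return loss_fn(opt_fn(p), p)

def grad_p(p):
    # Implicit gradient: only k_fn and loss_fn are differentiated here
    z = opt_fn(p)
    Dz_k = jax.jacobian(k_fn, argnums=0)(z, p)  # shape (n_z, n_z)
    Dp_k = jax.jacobian(k_fn, argnums=1)(z, p)  # shape (n_z, n_p)
    dz_dp = -jnp.linalg.solve(Dz_k, Dp_k)       # shape (n_z, n_p)
    gz, gp = jax.grad(loss_fn, argnums=(0, 1))(z, p)
    return gp + dz_dp.T @ gz

# Hessian by differentiating the gradient sketch; valid here only because
# this toy opt_fn is itself differentiable by JAX
hess_p = jax.jacobian(grad_p)

p = jnp.array([0.3, -1.2, 2.0])
H = hess_p(p)  # shape (3, 3)
```

Differentiating `grad_p` to get the Hessian is a shortcut that relies on the closed-form `opt_fn`; with a black-box inner solver, a second-order implicit formula would be needed instead.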