sensitivity_torch.extras.optimization.minimize_sqp(f_fn, g_fn, h_fn, *args, reg0=1e-07, verbose=False, verbose_prefix='', max_it=100, ls_pts_nb=5, force_step=False, batched=False, full_output=False, callback_fn=None, use_writer=False, use_tqdm=True)

Minimize a loss function `f_fn` with Unconstrained Sequential Quadratic Programming (SQP) with respect to a single `arg`.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `f_fn` | `Callable` | loss function | required |
| `g_fn` | `Callable` | gradient of the loss function | required |
| `h_fn` | `Callable` | Hessian of the loss function | required |
| `*args` | `Tensor` | arguments to be optimized | `()` |
| `reg0` | `float` | Hessian regularization; regularizes each optimization step | `1e-07` |
| `verbose` | `bool` | whether to print output | `False` |
| `verbose_prefix` | `str` | prefix to prepend to verbose output, e.g. indentation | `''` |
| `max_it` | `int` | maximum number of iterations | `100` |
| `ls_pts_nb` | `int` | number of linesearch points to consider per optimization step | `5` |
| `force_step` | `bool` | whether to always take a non-zero optimization step, even if the loss worsens | `False` |
| `batched` | `bool` | whether to optimize a batch of arguments, with a batch of losses | `False` |
| `full_output` | `bool` | whether to also return the optimization history | `False` |
| `callback_fn` | `Callable` | callback function of the form `cb_fn(*args, **kw)` | `None` |
| `use_writer` | `bool` | whether to log to a TensorBoard `SummaryWriter` (via PyTorch) | `False` |
| `use_tqdm` | `Union[bool, tqdm_module.std.tqdm, tqdm_module.notebook.tqdm_notebook]` | whether to use tqdm (to estimate total runtime) | `True` |

Returns:

Optimized `args`, or `(args, args_hist)` if `full_output` is `True`.
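A minimal usage sketch (not from the package's docs): it assumes PyTorch ≥ 2.0 so that `torch.func.grad` and `torch.func.hessian` can supply `g_fn` and `h_fn`, and the toy loss and data `A`, `b` are made up for illustration.

```python
import torch
from torch.func import grad, hessian

from sensitivity_torch.extras.optimization import minimize_sqp

# Made-up problem data: a strongly convex quadratic plus a quartic term.
torch.manual_seed(0)
A = torch.randn(5, 5)
A = A @ A.T + 5.0 * torch.eye(5)  # symmetric positive definite
b = torch.randn(5)

def f_fn(x):
    return 0.5 * x @ A @ x - b @ x + 0.1 * (x**4).sum()

x0 = torch.zeros(5)
x_opt = minimize_sqp(f_fn, grad(f_fn), hessian(f_fn), x0, verbose=True)
```

Passing `full_output=True` instead returns `(x_opt, x_hist)`, where `x_hist` is the list of iterates, which is convenient for plotting convergence.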

Source code in `sensitivity_torch/extras/optimization.py`:

```python
def minimize_sqp(
    f_fn: Callable,
    g_fn: Callable,
    h_fn: Callable,
    *args: Tensor,
    reg0: float = 1e-7,
    verbose: bool = False,
    verbose_prefix: str = "",
    max_it: int = 100,
    ls_pts_nb: int = 5,
    force_step: bool = False,
    batched: bool = False,
    full_output: bool = False,
    callback_fn: Callable = None,
    use_writer: bool = False,
    use_tqdm: Union[bool, tqdm_module.std.tqdm, tqdm_module.notebook.tqdm_notebook] = True,
):
    """Minimize a loss function ``f_fn`` with Unconstrained Sequential Quadratic
    Programming (SQP) with respect to a single ``arg``.

    Args:
        f_fn: loss function
        g_fn: gradient of the loss function
        h_fn: Hessian of the loss function
        *args: arguments to be optimized
        reg0: Hessian regularization; regularizes each optimization step
        verbose: whether to print output
        verbose_prefix: prefix to prepend to verbose output, e.g. indentation
        max_it: maximum number of iterations
        ls_pts_nb: number of linesearch points to consider per optimization step
        force_step: whether to always take a non-zero optimization step, even if
            the loss worsens
        batched: whether to optimize a batch of arguments, with a batch of losses
        full_output: whether to also return the optimization history
        callback_fn: callback function of the form ``cb_fn(*args, **kw)``
        use_writer: whether to log to a TensorBoard ``SummaryWriter`` (via PyTorch)
        use_tqdm: whether to use tqdm (to estimate total runtime)
    Returns:
        Optimized ``args`` or ``(args, args_hist)`` if ``full_output`` is ``True``
    """
    if isinstance(use_tqdm, bool):
        if use_tqdm:
            print_fn, rng_wrapper = tqdm_module.tqdm.write, tqdm_module.tqdm
        else:
            print_fn, rng_wrapper = print, lambda x: x
    else:
        print_fn, rng_wrapper = use_tqdm.write, use_tqdm

    if len(args) > 1:
        raise ValueError("SQP only only supports single variable functions")
    x = args[0]
    x_shape = x.shape
    if batched:
        M, x_size = x_shape[0], np.prod(x_shape[1:])
    else:
        M, x_size = 1, x.numel()
    it, imprv = 0, float("inf")
    x_best, f_best = x, torch.atleast_1d(f_fn(x))
    f_hist, x_hist = [f_best], [x.detach().clone()]

    if callback_fn is not None:
        callback_fn(x)

    tp = TablePrinter(
        ["it", "imprv", "loss", "reg_it", "bet", "||g_prev||_2"],
        ["%05d", "%9.4e", "%9.4e", "%02d", "%9.4e", "%9.4e"],
        prefix=verbose_prefix,
        use_writer=use_writer,
    )
    if verbose:
        print_fn(tp.make_header())
    for it in rng_wrapper(range(max_it)):
        g = g_fn(x).reshape((M, x_size))
        H = h_fn(x).reshape((M, x_size, x_size))
        if torch.any(torch.isnan(g)):
            raise RuntimeError("Gradient is NaN")
        if torch.any(torch.isnan(H)):
            raise RuntimeError("Hessian is NaN")

        # factorize a positive definite surrogate of the Hessian, F ~ H + reg * I
        # alternative: F, (reg_it_max, _) = _positive_factorization_cholesky(H, reg0)
        F, (reg_it_max, _) = _positive_factorization_lobpcg(H, reg0)

        # SQP/Newton direction: d = -(H + reg * I)^{-1} g
        # alternative: F = H + reg0 * I
        #              d = torch.solve(F, -g[..., None])[..., 0].reshape(x_shape)
        d = torch.cholesky_solve(-g[..., None], F)[..., 0].reshape(x_shape)
        f = f_hist[-1]
        bet, data = _linesearch(
            f,
            x,
            d,
            f_fn,
            g_fn,
            ls_pts_nb=ls_pts_nb,
            force_step=force_step,
        )

        x = x + torch.reshape(bet, (M,) + (1,) * len(x_shape[1:])) * d
        x_hist.append(x.clone().detach())
        imprv = torch.mean(bet * data["d_norm"]).detach()
        if callback_fn is not None:
            callback_fn(x)
        if batched:
            x_bests = [None for _ in range(M)]
            f_bests = [None for _ in range(M)]
            for i in range(M):
                if data["f_best"][i] < f_best[i]:
                    x_bests[i], f_bests[i] = x[i, ...], data["f_best"][i]
                else:
                    x_bests[i], f_bests[i] = x_best[i, ...], f_best[i]
            x_best, f_best = torch.stack(x_bests), torch.stack(f_bests)
        else:
            if data["f_best"][0] < f_best[0]:
                x_best, f_best = x, data["f_best"]
        f_hist.append(data["f_best"])
        if verbose:
            print_fn(
                tp.make_values(
                    [
                        it,
                        imprv,
                        torch.mean(data["f_best"]),
                        reg_it_max,
                        bet[0],
                        torch.norm(g),
                    ]
                )
            )
        if imprv < 1e-9:
            break
    if verbose:
        print_fn(tp.make_footer())
    if full_output:
        return x_best, x_hist + [x_best.detach().clone()]
    else:
        return x_best
```
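
For intuition, each pass through the loop above is a regularized Newton step followed by a linesearch. Below is a stripped-down, unbatched sketch of one such step; it is illustrative only, and the real code's `_positive_factorization_lobpcg` and `_linesearch` helpers handle regularization and step selection more carefully.

```python
import torch

def sqp_step(x, f_fn, g_fn, h_fn, reg=1e-7, ls_pts_nb=5):
    """One simplified SQP step: regularized Newton direction + grid linesearch."""
    n = x.numel()
    g = g_fn(x).reshape(n)
    H = h_fn(x).reshape(n, n)
    # Increase the regularization until H + reg * I admits a Cholesky factor,
    # i.e. until it is positive definite.
    while True:
        try:
            F = torch.linalg.cholesky(H + reg * torch.eye(n))
            break
        except RuntimeError:  # torch.linalg.LinAlgError subclasses RuntimeError
            reg *= 10.0
    # Descent direction d = -(H + reg * I)^{-1} g.
    d = torch.cholesky_solve(-g[:, None], F)[:, 0].reshape(x.shape)
    # Crude linesearch: evaluate the loss on a log-spaced grid of step sizes.
    betas = torch.logspace(-3, 0, ls_pts_nb)
    losses = torch.stack([f_fn(x + bet * d) for bet in betas])
    return x + betas[losses.argmin()] * d
```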