
sensitivity_torch.extras.optimization.minimize_lbfgs(f_fn, g_fn, *args, verbose=False, verbose_prefix='', lr=1.0, max_it=100, batched=False, full_output=False, callback_fn=None, use_writer=False, use_tqdm=True)

Minimize a loss function f_fn with respect to *args using L-BFGS. Uses PyTorch's L-BFGS implementation (torch.optim.LBFGS).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| f_fn | Callable | loss function | required |
| g_fn | Callable | gradient of the loss function (pass None to fall back to autograd) | required |
| *args | Tensor | arguments to be optimized | () |
| verbose | bool | whether to print output | False |
| verbose_prefix | str | prefix to prepend to verbose output, e.g. indentation | '' |
| lr | float | learning rate; 1.0 is often unstable, use 1e-1 in most cases | 1.0 |
| max_it | int | maximum number of iterations | 100 |
| batched | bool | whether to optimize a batch of arguments, with a batch of losses | False |
| full_output | bool | whether to also output the optimization history | False |
| callback_fn | Callable | callback function of the form cb_fn(*args, **kw); see the sketch below | None |
| use_writer | bool | whether to log to TensorBoard via PyTorch's SummaryWriter | False |
| use_tqdm | Union[bool, tqdm_module.std.tqdm, tqdm_module.notebook.tqdm_notebook] | whether to use tqdm (to estimate total runtime) | True |
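
When callback_fn is given, the source below calls it once before the first step and again after every L-BFGS step, passing the current iterates positionally. A minimal sketch of a logging callback (its body is purely illustrative):

def cb_fn(*args, **kw):
    # args are the current iterates (tensors with requires_grad=True)
    print("iterate norms:", [float(a.detach().norm()) for a in args])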

Returns:

Optimized args, or (args, args_hist, grads_hist) if full_output is True
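
A minimal usage sketch (the quadratic loss, problem data, and starting point below are illustrative and not part of the library):

import torch
from sensitivity_torch.extras.optimization import minimize_lbfgs

A = torch.randn(5, 5)
A = A.T @ A + torch.eye(5)  # random positive definite matrix
b = torch.randn(5)

def f_fn(x):
    return 0.5 * x @ A @ x - b @ x  # scalar quadratic loss, minimized at A x = b

x0 = torch.zeros(5)
# g_fn=None: gradients are computed with autograd
x_opt = minimize_lbfgs(f_fn, None, x0, lr=1e-1, max_it=50, use_tqdm=False)

# with full_output=True the per-iteration history is returned as well
x_opt, x_hist, g_hist = minimize_lbfgs(
    f_fn, None, x0, lr=1e-1, max_it=50, full_output=True, use_tqdm=False
)

Since the loss is a positive definite quadratic, x_opt should end up close to the solution of A x = b.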

Source code in sensitivity_torch/extras/optimization.py
def minimize_lbfgs(
    f_fn: Callable,
    g_fn: Callable,
    *args: Tensor,
    verbose: bool = False,
    verbose_prefix: str = "",
    lr: float = 1e0,
    max_it: int = 100,
    batched: bool = False,
    full_output: bool = False,
    callback_fn: Callable = None,
    use_writer: bool = False,
    use_tqdm: Union[bool, tqdm_module.std.tqdm, tqdm_module.notebook.tqdm_notebook] = True,
):
    """Minimize a loss function ``f_fn`` with L-BFGS with respect to ``*args``.
    Taken from PyTorch.

    Args:
        f_fn: loss function
        g_fn: gradient of the loss function
        *args: arguments to be optimized
        verbose: whether to print output
        verbose_prefix: prefix to append to verbose output, e.g. indentation
        lr: learning rate, where 1.0 is unstable, use 1e-1 in most cases
        max_it: maximum number of iterates
        batched: whether to optimize a batch of arguments, with batch of losses
        full_output: whether to output optimization history
        callback_fn: callback function of the form ``cb_fn(*args, **kw)``
        use_writer: whether to use tensorflow's Summary Writer (via PyTorch)
        use_tqdm: whether to use tqdm (to estimate total runtime)
    Returns:
        Optimized ``args`` or ``(args, args_hist)`` if ``full_output`` is ``True``
    """
    if isinstance(use_tqdm, bool):
        if use_tqdm:
            print_fn, rng_wrapper = tqdm_module.tqdm.write, tqdm_module.tqdm
        else:
            print_fn, rng_wrapper = print, lambda x: x
    else:
        print_fn, rng_wrapper = use_tqdm.write, use_tqdm

    assert len(args) > 0
    assert g_fn is not None or all([isinstance(arg, torch.Tensor) for arg in args])
    args = [arg.detach().clone() for arg in args]
    for arg in args:
        arg.requires_grad = True
    imprv = float("inf")
    it = 0
    opt = torch.optim.LBFGS(args, lr=lr)
    # args_hist = [[arg.detach().clone() for arg in args]]
    args_hist, grads_hist = [], []

    if callback_fn is not None:
        callback_fn(*args)

    def closure():
        opt.zero_grad()
        if g_fn is None:
            loss = torch.sum(f_fn(*args))
            loss.backward()
            if batched:
                loss = loss / args[0].shape[0]
        else:
            args_ = [arg for arg in args]
            loss = torch.mean(f_fn(*args_))
            gs = g_fn(*args_)
            gs = gs if isinstance(gs, list) or isinstance(gs, tuple) else [gs]
            for arg, g in zip(args, gs):
                arg.grad = torch.detach(g)
        return loss

    tp = TablePrinter(
        ["it", "imprv", "loss", "||g||_2"],
        ["%05d", "%9.4e", "%9.4e", "%9.4e"],
        prefix=verbose_prefix,
        use_writer=use_writer,
    )
    if verbose:
        print_fn(tp.make_header())
    for it in rng_wrapper(range(max_it)):
        args_prev = [arg.detach().clone() for arg in args]
        loss = opt.step(closure)
        if full_output:
            args_hist.append([arg.detach().clone() for arg in args])
            grads_hist.append(
                [arg.grad.detach().clone() if arg.grad is not None else None for arg in args]
            )
        if callback_fn is not None:
            callback_fn(*args)
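        # imprv: average change in the iterates over this step, used as the stopping criterion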
        if batched:
            imprv = sum(
                torch.mean(torch.norm(arg_prev - arg, dim=tuple(range(-(arg.ndim - 1), 0))))
                for (arg, arg_prev) in zip(args, args_prev)
            )
        else:
            imprv = sum(
                torch.norm(arg_prev - arg).detach() for (arg, arg_prev) in zip(args, args_prev)
            )
        if verbose:
            closure()
            g_norm = sum(arg.grad.norm().detach() for arg in args if arg.grad is not None)
            print_fn(tp.make_values([it, imprv.detach(), loss.detach(), g_norm]))
        if imprv < 1e-9:
            break
        it += 1
    if verbose:
        print_fn(tp.make_footer())
    ret = [arg.detach() for arg in args]
    ret = ret if len(args) > 1 else ret[0]
    args_hist = [z if len(args) > 1 else z[0] for z in args_hist]
    grads_hist = [z if z is None or len(args) > 1 else z[0] for z in grads_hist]
    if full_output:
        return ret, args_hist, grads_hist
    else:
        return ret
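
Two further sketches with illustrative problem data: passing an explicit gradient via g_fn, and batched mode, in which every optimized argument carries a leading batch dimension and f_fn returns one loss per batch element:

import torch
from sensitivity_torch.extras.optimization import minimize_lbfgs

A = torch.randn(5, 5)
A = A.T @ A + torch.eye(5)  # random positive definite matrix
b = torch.randn(5)

# explicit gradient: g_fn returns one gradient per optimized argument
f_fn = lambda x: 0.5 * x @ A @ x - b @ x
g_fn = lambda x: A @ x - b
x_opt = minimize_lbfgs(f_fn, g_fn, torch.zeros(5), lr=1e-1, use_tqdm=False)

# batched mode: 16 independent problems, one loss per row of X
b_batch = torch.randn(16, 5)
f_batched = lambda X: 0.5 * torch.einsum("bi,ij,bj->b", X, A, X) - (X * b_batch).sum(-1)
X_opt = minimize_lbfgs(
    f_batched, None, torch.zeros(16, 5), batched=True, lr=1e-1, use_tqdm=False
)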