Skip to content

sensitivity_torch.differentiation.BATCH_JACOBIAN(fn, args, **config)

Computes the Jacobian, assuming the first dimension of every input and of the output is the batch dimension.

Source code in sensitivity_torch/differentiation.py
def BATCH_JACOBIAN(fn, args, **config):
    """Computes the Jacobian, assuming the first in/out dimension is the batch.

    The output of ``fn`` is summed over its first (batch) dimension and that
    sum is differentiated with ``JACOBIAN``. The result equals the stack of
    per-sample Jacobians only when output sample ``i`` depends solely on input
    sample ``i``; any cross-batch dependence is summed into the result.

    Args:
        fn: Callable taking ``*args`` and returning a tensor whose first
            dimension is the batch.
        args: A single tensor, or a list/tuple of tensors, each with the batch
            as its first dimension.
        **config: Forwarded unchanged to ``JACOBIAN``.

    Returns:
        For each input ``arg``, a tensor of shape
        ``(arg.shape[0],) + out_shape + arg.shape[1:]``, where ``out_shape`` is
        the shape of ``fn``'s output with its batch dimension removed. A single
        tensor is returned when ``args`` is a single tensor; otherwise a list
        with one entry per input.
    """
    single_input = not isinstance(args, (list, tuple))
    args = (args,) if single_input else tuple(args)
    # Summing over the batch dimension yields a non-batched output whose
    # Jacobian w.r.t. the batched inputs contains every per-sample Jacobian.
    Js = JACOBIAN(lambda *xs: torch.sum(fn(*xs), 0), args, **config)
    # NOTE(review): assumes JACOBIAN returns one tensor per input, shaped
    # out_shape + arg.shape, so the leading dims form the output shape.
    out_shapes = [J.shape[: -len(arg.shape)] for (J, arg) in zip(Js, args)]
    Js = [
        # (*out_shape, B, *in_rest) -> (prod(out_shape), B, *in_rest)
        J.reshape((prod(out_shape),) + arg.shape)
        # -> (B, prod(out_shape), *in_rest): move the batch dim to the front
        .swapaxes(0, 1)
        # -> (B, *out_shape, *in_rest): restore the output dims
        .reshape((arg.shape[0],) + out_shape + arg.shape[1:])
        for (J, out_shape, arg) in zip(Js, out_shapes, args)
    ]
    return Js[0] if single_input else Js