Tour of sensitivity_jax

This example fits a logistic-regression model on random Fourier features with scikit-learn (a solver JAX cannot differentiate through), computes the sensitivity of the fitted weights to the feature-map parameters P and q, and then optimizes P to lower the test loss.

import jax
from jax import numpy as jnp
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import make_moons

from sensitivity_jax.sensitivity import implicit_jacobian, generate_optimization_fns
from sensitivity_jax.extras.optimization.sqp import minimize_sqp
from sensitivity_jax.extras.optimization.lbfgs import minimize_lbfgs
from sensitivity_jax.extras.optimization.agd import minimize_agd


def feature_map(X, P, q):
    """Random Fourier-style feature map: Phi = cos(X @ P + q)."""
    return jnp.cos(X @ P + q)


def optimize_fn(P, q, X, Y):
    """Fit the inner model with sklearn, a solver JAX cannot differentiate through.

    jax.pure_callback hands the fit to ordinary Python/NumPy, so JAX's autodiff cannot
    see inside it; the derivatives are instead recovered from the optimality conditions
    in k_fn by the sensitivity machinery below.
    """
    Phi = feature_map(X, P, q)

    def optimize_python(Phi, Y):
        model = LogisticRegression(max_iter=200, C=1e-2)
        model.fit(Phi, Y)
        return np.concatenate([model.coef_.reshape(-1), model.intercept_])

    z = jax.pure_callback(
        optimize_python, jax.ShapeDtypeStruct((Phi.shape[-1] + 1,), Phi.dtype), Phi, Y
    )
    return z


def loss_fn(z, P, q, X, Y):
    """Regularized binary cross-entropy: the outer measure of model performance and,
    through its gradient in z, the source of the optimality conditions in k_fn."""
    Phi = feature_map(X, P, q)
    W, b = z[:-1], z[-1]
    prob = jax.nn.sigmoid(Phi @ W + b)
    loss = jnp.sum(-(Y * jnp.log(prob) + (1 - Y) * jnp.log(1 - prob)))
    regularization = 0.5 * jnp.sum(W**2) / 1e-2  # L2 penalty consistent with sklearn's C=1e-2 above
    return (loss + regularization) / X.shape[-2]


def k_fn(z, P, q, X, Y):
    """Optimality conditions of the model: k(z, P, q) = 0 at the fixed point of the inner optimization."""
    return jax.grad(loss_fn)(z, P, q, X, Y)


def main():
    # enable 64-bit floats; otherwise JAX falls back to float32 regardless of the requested dtype
    jax.config.update("jax_enable_x64", True)
    dtype = jnp.float64
    ######################## SETUP ####################################
    # generate data and split into train and test sets
    X, Y = make_moons(n_samples=1000, noise=1e-1)
    X, Y = jnp.array(X, dtype=dtype), jnp.array(Y, dtype=dtype)
    n_train = 200
    Xtr, Ytr = X[:n_train, :], Y[:n_train]
    Xts, Yts = X[n_train:, :], Y[n_train:]

    # generate the Fourier feature map parameters Phi = cos(X @ P + q)
    n_features = 200
    P = jnp.array(np.random.randn(2, n_features), dtype=dtype)
    q = jnp.array(np.zeros(n_features), dtype=dtype)

    # check that the fixed-point is numerically close to zero
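    # (the implicit differentiation below is only valid at a point where k(z, P, q) is numerically zero)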
    z = optimize_fn(P, q, X, Y)
    k = k_fn(z, P, q, X, Y)
    assert jnp.max(jnp.abs(k)) < 1e-4

    ######################## SENSITIVITY ##############################
    # generate sensitivity Jacobians and check their shape
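    # implicit_jacobian returns dz*/dP and dz*/dq for the solution z*(P, q) of k(z, P, q) = 0,
    # i.e. the standard implicit-function-theorem sensitivities dz*/dtheta = -(dk/dz)^{-1} dk/dtheta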
    JP, Jq = implicit_jacobian(lambda z, P, q: k_fn(z, P, q, X, Y), z, P, q)
    assert JP.shape == z.shape + P.shape
    assert Jq.shape == z.shape + q.shape

    ######################## OPTIMIZATION #############################
    # generate necessary functions for optimization
    opt_fn_ = lambda P: optimize_fn(P, q, Xtr, Ytr)  # model optimization
    k_fn_ = lambda z, P: k_fn(z, P, q, Xtr, Ytr)  # fixed-point
    loss_fn_ = lambda z, P: loss_fn(z, P, q, Xts, Yts)  # loss to improve
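    # f_fn, g_fn, h_fn evaluate the outer objective P -> loss_fn_(opt_fn_(P), P) and its
    # gradient and Hessian in P; derivatives flow through the fixed point k_fn_ = 0 rather
    # than through the sklearn call inside opt_fn_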
    f_fn, g_fn, h_fn = generate_optimization_fns(loss_fn_, opt_fn_, k_fn_)

    # choose any optimization routine
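    # minimize_agd and minimize_lbfgs only need f_fn and g_fn; minimize_sqp also uses the Hessian h_fn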
    # Ps = minimize_agd(f_fn, g_fn, P, verbose=True, ai=1e-1, af=1e-1, max_it=100)
    # Ps = minimize_lbfgs(f_fn, g_fn, P, verbose=True, lr=1e-1, max_it=10)
    Ps = minimize_sqp(f_fn, g_fn, h_fn, P, verbose=True, max_it=10, ls_pts_nb=5)

    def predict(z, P, q, X):
        Phi = feature_map(X, P, q)
        W, b = z[:-1], z[-1]
        prob = jax.nn.sigmoid(Phi @ W + b)
        return jnp.round(prob)

    # evaluate the results
    acc0 = jnp.mean(predict(opt_fn_(P), P, q, Xts) == Yts)
    accf = jnp.mean(predict(opt_fn_(Ps), Ps, q, Xts) == Yts)
    print("Accuracy before: %4.2f%%" % (1e2 * acc0))
    print("Accuracy after:  %4.2f%%" % (1e2 * accf))
    print("Loss before:     %9.4e" % loss_fn(opt_fn_(P), P, q, Xts, Yts))
    print("Loss after:      %9.4e" % loss_fn(opt_fn_(Ps), Ps, q, Xts, Yts))


if __name__ == "__main__":
    main()
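

For intuition, the sensitivity Jacobians above are the implicit-function-theorem derivatives: if k(z*(theta), theta) = 0, then dz*/dtheta = -(dk/dz)^{-1} dk/dtheta. The snippet below is a minimal, dense sketch of that identity in plain JAX; the function name is ours, it assumes a 1-D z as in this tour, and it is not the library's implementation.

import jax
from jax import numpy as jnp


def implicit_jacobian_sketch(k_fn, z, theta):
    """Dense dz*/dtheta at a root of k_fn(z, theta); assumes z is a 1-D array."""
    Dz = jax.jacobian(k_fn, argnums=0)(z, theta)  # dk/dz, shape (z.size, z.size)
    Dt = jax.jacobian(k_fn, argnums=1)(z, theta)  # dk/dtheta, shape z.shape + theta.shape
    # solve (dk/dz) @ J = -dk/dtheta for J = dz*/dtheta, then restore theta's shape
    J = -jnp.linalg.solve(Dz, Dt.reshape(z.size, -1))
    return J.reshape(z.shape + theta.shape)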