
ResNet18 example

from __future__ import annotations

from pprint import pprint

from tqdm import tqdm
from datasets import load_dataset
import torch
from torchvision.models.resnet import resnet18
from torch import nn
from torch.func import functional_call
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import jax
from jax import numpy as jnp
import optax

from torch2jax import tree_t2j, torch2jax_with_vjp, tree_j2t, t2j, j2t

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_jax = jax.devices(device.type)[0]

Loading the dataset and the model (in PyTorch)

dataset = load_dataset("mnist", split="train")

def collate_torch_fn(batch):
    imgs = torch.stack([ToTensor()(x["image"]).repeat((3, 1, 1)) for x in batch]).to(device)
    labels = torch.tensor([x["label"] for x in batch]).to(device)
    return imgs, labels

collate_jax_fn = lambda batch: tree_t2j(collate_torch_fn(batch))
model = nn.Sequential(resnet18(), nn.Linear(1000, 10))
model.to(device)
model.eval()

opts = dict(batch_size=32, shuffle=True, num_workers=0)
dl_jax = DataLoader(dataset, **dict(opts, collate_fn=collate_jax_fn))
dl_torch = DataLoader(dataset, **dict(opts, collate_fn=collate_torch_fn))
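
As a quick sanity check (not part of the original example), we can inspect one batch from the JAX dataloader; the shapes below assume the standard 28x28 MNIST images and the batch size of 32 set above.

Xj, yj = next(iter(dl_jax))
print(Xj.shape, Xj.dtype, yj.shape)  # expected: (32, 3, 28, 28) float32 and (32,)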

Let's convert the PyTorch model to a pure function using torch.func.functional_call

params, buffers = dict(model.named_parameters()), dict(model.named_buffers())


def torch_fwd_fn(params, buffers, input):
    # clone the buffers so any in-place updates (e.g., BatchNorm running statistics) don't mutate the originals
    buffers = {k: torch.clone(v) for k, v in buffers.items()}
    return functional_call(model, (params, buffers), args=input)


Xt, yt = next(iter(dl_torch))
nondiff_argnums = (1, 2)  # buffers, input
jax_fwd_fn = jax.jit(
    torch2jax_with_vjp(torch_fwd_fn, params, buffers, Xt, nondiff_argnums=nondiff_argnums)
)
params_jax, buffers_jax = tree_t2j(params), tree_t2j(buffers)
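
Before building the loss, we can verify that the converted forward function runs on JAX arrays; this is a minimal sketch, and the output shape assumes the batch size of 32 and the 10-class head defined above.

logits = jax_fwd_fn(params_jax, buffers_jax, tree_t2j(Xt))
print(logits.shape)  # expected: (32, 10)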

Let's use torch's CrossEntropyLoss

Xt, yt = next(iter(dl_torch))
torch_ce_fn = lambda yp, y: nn.CrossEntropyLoss()(yp, y)
jax_ce_fn = torch2jax_with_vjp(torch_ce_fn, model(Xt), yt)

jax_l_fn = jax.jit(
    lambda params_jax, X, y: jnp.mean(jax_ce_fn(jax_fwd_fn(params_jax, buffers_jax, X), y))
)
jax_g_fn = jax.jit(jax.grad(jax_l_fn))
torch_g_fn = torch.func.grad(
    lambda params, Xt, yt: torch_ce_fn(torch_fwd_fn(params, buffers, Xt), yt)
)
X, y = next(iter(dl_jax))
gs_jax = jax_g_fn(params_jax, X, y)
gs_torch = torch_g_fn(params, *tree_j2t((X, y)))

# let's compute the error in gradients between JAX and Torch (the errors are negligible, at the level of float32 precision)
errors = {k: float(jnp.linalg.norm(v - t2j(gs_torch[k]))) for k, v in gs_jax.items()}
pprint(errors)

{'0.bn1.bias': 6.606649449736324e-09,
'0.bn1.weight': 1.0237145575686668e-09,
'0.conv1.weight': 1.9232666659263487e-07,
'0.fc.bias': 0.0,
'0.fc.weight': 0.0,
'0.layer1.0.bn1.bias': 4.424356436771859e-09,
'0.layer1.0.bn1.weight': 5.933196711715993e-10,
'0.layer1.0.bn2.bias': 2.3588471176339e-09,
'0.layer1.0.bn2.weight': 4.533372566228877e-10,
'0.layer1.0.conv1.weight': 1.4028480599392879e-08,
'0.layer1.0.conv2.weight': 1.1964990775936712e-08,
'0.layer1.1.bn1.bias': 8.75052974524948e-10,
'0.layer1.1.bn1.weight': 2.0072446482721773e-10,
'0.layer1.1.bn2.bias': 5.820766091346741e-11,
'0.layer1.1.bn2.weight': 2.9103830456733704e-11,
'0.layer1.1.conv1.weight': 1.1259264631746646e-08,
'0.layer1.1.conv2.weight': 1.1262083710050774e-08,
'0.layer2.0.bn1.bias': 0.0,
'0.layer2.0.bn1.weight': 0.0,
'0.layer2.0.bn2.bias': 0.0,
'0.layer2.0.bn2.weight': 0.0,
'0.layer2.0.conv1.weight': 0.0,
'0.layer2.0.conv2.weight': 0.0,
'0.layer2.0.downsample.0.weight': 6.819701248161891e-09,
'0.layer2.0.downsample.1.bias': 0.0,
'0.layer2.0.downsample.1.weight': 0.0,
'0.layer2.1.bn1.bias': 0.0,
'0.layer2.1.bn1.weight': 0.0,
'0.layer2.1.bn2.bias': 0.0,
'0.layer2.1.bn2.weight': 5.820766091346741e-11,
'0.layer2.1.conv1.weight': 0.0,
'0.layer2.1.conv2.weight': 0.0,
'0.layer3.0.bn1.bias': 0.0,
'0.layer3.0.bn1.weight': 0.0,
'0.layer3.0.bn2.bias': 0.0,
'0.layer3.0.bn2.weight': 0.0,
'0.layer3.0.conv1.weight': 0.0,
'0.layer3.0.conv2.weight': 0.0,
'0.layer3.0.downsample.0.weight': 0.0,
'0.layer3.0.downsample.1.bias': 0.0,
'0.layer3.0.downsample.1.weight': 0.0,
'0.layer3.1.bn1.bias': 0.0,
'0.layer3.1.bn1.weight': 0.0,
'0.layer3.1.bn2.bias': 0.0,
'0.layer3.1.bn2.weight': 0.0,
'0.layer3.1.conv1.weight': 0.0,
'0.layer3.1.conv2.weight': 0.0,
'0.layer4.0.bn1.bias': 0.0,
'0.layer4.0.bn1.weight': 0.0,
'0.layer4.0.bn2.bias': 0.0,
'0.layer4.0.bn2.weight': 0.0,
'0.layer4.0.conv1.weight': 0.0,
'0.layer4.0.conv2.weight': 0.0,
'0.layer4.0.downsample.0.weight': 0.0,
'0.layer4.0.downsample.1.bias': 0.0,
'0.layer4.0.downsample.1.weight': 0.0,
'0.layer4.1.bn1.bias': 0.0,
'0.layer4.1.bn1.weight': 0.0,
'0.layer4.1.bn2.bias': 0.0,
'0.layer4.1.bn2.weight': 0.0,
'0.layer4.1.conv1.weight': 0.0,
'0.layer4.1.conv2.weight': 0.0,
'1.bias': 0.0,
'1.weight': 0.0}

Train loop

This isn't very efficient, because PyTorch has to synchronize on every batch when it is called from JAX. For real training, prefer training directly in PyTorch; inference through the wrapped function in JAX, however, is fast (see the sketch after the loop).

optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params_jax)
update_fn, apply_updates = jax.jit(optimizer.update), jax.jit(optax.apply_updates)
for i, (X, y) in enumerate(tqdm(dl_jax, total=len(dl_jax))):
    gs = jax_g_fn(params_jax, X, y)
    updates, opt_state = update_fn(gs, opt_state)
    params_jax = apply_updates(params_jax, updates)
    if i > 10:
        break
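
For fast inference in JAX, the wrapped forward function can be jitted end-to-end. The snippet below is a minimal sketch of evaluating accuracy on a few batches; predict_fn and the short evaluation loop are illustrative additions, not part of the example above.

# classify with the converted model entirely inside jit
predict_fn = jax.jit(lambda params, X: jnp.argmax(jax_fwd_fn(params, buffers_jax, X), axis=-1))

correct, total = 0, 0
for i, (X, y) in enumerate(dl_jax):
    correct += int(jnp.sum(predict_fn(params_jax, X) == y))
    total += y.shape[0]
    if i > 10:
        break
print(f"accuracy on a few batches: {correct / total:.3f}")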