import functools
import copy
from pathlib import Path
import torch
import torch.nn as nn
import jax
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P, NamedSharding
from torch2jax import torch2jax, torch2jax_with_vjp, tree_j2t, tree_t2j
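
# Example: call a PyTorch model from JAX with data parallelism. The input batch is
# sharded across the available CUDA devices and the model parameters are replicated.
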
def _setattr(mod, key, delim: str = "."):
    # walk the dotted parameter name and set the leaf attribute to None
    if delim not in key:
        setattr(mod, key, None)
    else:
        key, key_remaining = key.split(delim, 1)
        _setattr(getattr(mod, key), key_remaining, delim=delim)


def _strip_model(model):
    # replace every registered parameter with None, leaving only a module skeleton
    for key in dict(model.named_parameters()).keys():
        _setattr(model, key, delim=".")


if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(1024 * 1024, 1024), nn.SiLU(), nn.Linear(1024, 16)).to("cuda:0")
    params = dict(model.named_parameters())
    for p in params.values():
        p.requires_grad_(False)
    _strip_model(model)  # remove params from the model, leaving only a skeleton
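
    # call_model_torch is an ordinary PyTorch function of (input, params); torch2jax
    # wraps it below so it can be called from jitted JAX code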
    def call_model_torch(x, params):
        ys = []
        for _ in range(30):
            # functional_call temporarily swaps the parameters into the module in
            # place, so give each call its own copy of the skeleton
            local_model_skeleton = copy.deepcopy(model)
            ys.append(torch.func.functional_call(local_model_skeleton, params, x))
        return sum(ys)

    # jax init
    devices = jax.devices("cuda")
    mesh = jax.make_mesh((len(devices),), ("x",), devices=devices)
    params_sharding = NamedSharding(mesh, P())  # fully replicated
    batch_sharding = NamedSharding(mesh, P("x", None))  # sharded along the batch axis
    x = jax.jit(
        lambda: jax.random.normal(jax.random.key(0), (128, 1024 * 1024)),
        out_shardings=batch_sharding,
    )()
    # convert the torch parameters to JAX arrays and replicate them on every device
    params = jax.tree.map(lambda p: jax.device_put(p, params_sharding), tree_t2j(params))
    params_spec = jax.tree.map(lambda _: params_sharding.spec, params)
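
    # Version 1: shard_map + torch2jax_with_vjp (supports gradients). shard_map runs
    # fwd_fn once per device, each call receiving the local shard of the batch.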
    @jax.jit
    @functools.partial(
        shard_map,
        mesh=mesh,
        in_specs=(batch_sharding.spec, params_spec),
        out_specs=batch_sharding.spec,
        check_rep=False,
    )
    def fwd_fn(x, params):
        # x is the per-device shard here, so x[:, :16] describes the per-device output shape
        return torch2jax_with_vjp(call_model_torch, x, params, output_shapes=x[:, :16])(x, params)
    y = fwd_fn(x, params)

    # Version 2: plain torch2jax with an output sharding spec (no shard_map), but
    # without gradient support
    fwd_fn = jax.jit(
        torch2jax(
            call_model_torch, x, params, output_shapes=x[:, :16], output_sharding_spec=P("x", None)
        )
    )
    y = fwd_fn(x, params)
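
    # if the result is needed back in PyTorch, tree_j2t (imported above, assumed to be
    # the inverse of tree_t2j) can convert the JAX output to torch tensors, e.g.
    # y_torch = tree_j2t(y)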

    # profile the computation
    _ = fwd_fn(x, params)  # warm-up call so compilation does not end up in the trace
    path = Path("/tmp/profiles/data_parallel")
    path.mkdir(parents=True, exist_ok=True)
    with jax.profiler.trace(str(path)):
        for _ in range(10):
            fwd_fn(x, params).block_until_ready()
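    # the resulting trace can be inspected with TensorBoard's profiler plugin pointed
    # at /tmp/profiles/data_parallel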