diff --git a/src/raygun/jax/networks/UNet.py b/src/raygun/jax/networks/UNet.py
index ea77e3bc..e42a9204 100644
--- a/src/raygun/jax/networks/UNet.py
+++ b/src/raygun/jax/networks/UNet.py
@@ -493,7 +493,7 @@ def __init__(self,
                  for level in range(self.num_levels - 1)]
             for _ in range(self.num_heads)]
 
-    def rec_forward(self, level, f_in, total_level):
+    def rec_forward(self, level, f_in, total_level):  # TODO: fix U-Net build
 
         prefix = " "*(total_level-1-level)
         print(prefix + "Creating U-Net layer %i" % (total_level-1-level))
diff --git a/src/raygun/jax/tests/network_test_jax.py b/src/raygun/jax/tests/network_test_jax.py
index 02a344e5..cbc0494e 100644
--- a/src/raygun/jax/tests/network_test_jax.py
+++ b/src/raygun/jax/tests/network_test_jax.py
@@ -1,5 +1,6 @@
 #%%
-from raygun.jax.networks import UNet, NLayerDiscriminator2D
+import raygun
+from raygun.jax.networks import *
 import jax
 import jax.numpy as jnp
 from jax import jit
@@ -12,13 +13,6 @@
 import matplotlib.pyplot as plt
 from tqdm import trange
 
-from raygun.jax.networks.NLayerDiscriminator import NLayerDiscriminator3D
-
-
-# PARAMETERS
-mp_training = True  # mixed-precision training using `jmp`
-learning_rate = 0.5e-4
-
 
 class Params(NamedTuple):
     weight: jnp.ndarray
@@ -42,10 +36,11 @@ def train_step(self, inputs, pmapped):
         raise RuntimeError("Unimplemented")
 
 
-class Model(GenericJaxModel):
+class JAXModel(GenericJaxModel):
 
-    def __init__(self):
+    def __init__(self, network_type='UNet', learning_rate=0.5e-4, **net_kwargs):  # TODO: set up for **net_kwargs
         super().__init__()
+        self.learning_rate = learning_rate
 
         # to make assigning precision policy easier
         class MyModel(hk.Module):
@@ -53,22 +48,22 @@ class MyModel(hk.Module):
 
             def __init__(self, name=None):
                 super().__init__(name=name)
                 self.net = UNet(
-                    ngf=5,
+                    ngf=3,
                     fmap_inc_factor=2,
                     downsample_factors=[[2,2,2],[2,2,2],[2,2,2]]
                 )
+                # net = getattr(raygun.jax.networks, network_type)
+                # self.net = net(**net_kwargs)
 
             def __call__(self, x):
                 return self.net(x)
 
-        def _forward(x):
+        def _forward(x):  # temporary placeholder for _forward()
             net = MyModel()
             return net(x)
 
-        if mp_training:
-            policy = jmp.get_policy('p=f32,c=f16,o=f32')
-        else:
-            policy = jmp.get_policy('p=f32,c=f32,o=f32')
+        # full precision for now; mixed precision to be restored later
+        policy = jmp.get_policy('p=f32,c=f32,o=f32')
         hk.mixed_precision.set_policy(MyModel, policy)
 
         self.model = hk.without_apply_rng(hk.transform(_forward))
@@ -84,8 +79,8 @@ def _forward(params, inputs):
-        def _loss_fn(weight, raw, gt, mask, loss_scale):
+        def _loss_fn(weight, raw, gt, loss_scale):  # mask dropped to match the new inputs
             pred_affs = self.model.apply(weight, x=raw)
             loss = optax.l2_loss(predictions=pred_affs, targets=gt)
-            loss = loss*2*mask  # optax divides loss by 2 so we mult it back
-            loss_mean = loss.mean(where=mask)
+            # loss = loss*2*mask  # optax divides loss by 2 so we mult it back
+            loss_mean = loss.mean()
             return loss_scale.scale(loss_mean), (pred_affs, loss, loss_mean)
 
         @jit
@@ -94,20 +89,14 @@ def _apply_optimizer(params, grads):
             new_weight = optax.apply_updates(params.weight, updates)
             return new_weight, new_opt_state
 
-        def _train_step(params, inputs, pmapped=False) -> Tuple[Params, Dict[str, jnp.ndarray], Any]:
+        def _train_step(params, inputs) -> Tuple[Params, Dict[str, jnp.ndarray], Any]:
 
-            raw, gt, mask = inputs['raw'], inputs['gt'], inputs['mask']
+            raw, gt = inputs['raw'], inputs['gt']
 
             grads, (pred_affs, loss, loss_mean) = jax.grad(
-                _loss_fn, has_aux=True)(params.weight, raw, gt, mask,
+                _loss_fn, has_aux=True)(params.weight, raw, gt,
                                         params.loss_scale)
 
-            if pmapped:
-                # sync grads, casting to compute precision (f16) for efficiency
-                grads = policy.cast_to_compute(grads)
-                grads = jax.lax.pmean(grads, axis_name='num_devices')
-                grads = policy.cast_to_param(grads)
-
             # dynamic mixed precision loss scaling
             grads = params.loss_scale.unscale(grads)
             new_weight, new_opt_state = _apply_optimizer(params, grads)
@@ -127,18 +116,23 @@ def _train_step(params, inputs, pmapped=False) -> Tuple[Params, Dict[str, jnp.nd
 
         self.train_step = _train_step
 
-    def initialize(self, rng_key, inputs, is_training=True):
+    def initialize(self, rng_key, inputs):
         weight = self.model.init(rng_key, inputs['raw'])
         opt_state = self.opt.init(weight)
-        if mp_training:
-            loss_scale = jmp.DynamicLossScale(jmp.half_dtype()(2 ** 15))
-        else:
-            loss_scale = jmp.NoOpLossScale()
+        loss_scale = jmp.NoOpLossScale()
         return Params(weight, opt_state, loss_scale)
 
 
-class NetworkTestJAX():
+class NetworkTestJAX():  # TODO: set up for **net_kwargs
 
-    def __init__(self, task=None, im='astronaut', batch_size=None, noise_factor=3, model=Model(), num_epochs=15) -> None:
+    def __init__(self, net_type='UNet',
+                 task=None, im='astronaut',
+                 batch_size=None,
+                 noise_factor=3,
+                 model=JAXModel(),
+                 num_epochs=15,
+                 **net_kwargs) -> None:
+
         self.task = task
         self.im = im
         n_devices = jax.local_device_count()
@@ -147,15 +141,18 @@ def __init__(self, task=None, im='astronaut', batch_size=None, noise_factor=3, m
         else:
             self.batch_size = batch_size
         self.noise_factor = noise_factor
 
         self.model = model
         self.num_epochs = num_epochs
 
         # TODO
+        # net_type, net_kwargs
+
         self.inputs = None
         self.model_params = None
 
-    def im2batch(self):
+    def im2batch(self, im):
         im = jnp.expand_dims(im, 0)
         batch = []
         for i in range(self.batch_size):
@@ -168,22 +165,21 @@ def data_engine(self):
             self.inputs = {
                 'raw': jnp.ones([self.batch_size, 1, 132, 132, 132]),
                 'gt': jnp.zeros([self.batch_size, 3, 40, 40, 40]),
-                'mask': jnp.ones([self.batch_size, 3, 40, 40, 40])
             }
         else:
             gt_import = getattr(data, self.im)()
             if len(gt_import.shape) > 2:  # Strips to use only one image
                 gt_import = gt_import[...,0]
-            gt = self.im2batch(im= jnp.asarray(gt_import), batch_size=self.batch_size)
+            gt = self.im2batch(im=jnp.asarray(gt_import))
 
             noise_key = jax.random.PRNGKey(22)
-            noise = self.im2batch(im=jax.random.uniform(key=noise_key, shape=gt_import.shape), batch_size=batch_size)
+            noise = self.im2batch(im=jax.random.uniform(key=noise_key, shape=gt_import.shape))
             raw = (gt*noise) / self.noise_factor + (gt/self.noise_factor)
 
             self.inputs = {
                 'raw': raw,
-                'gt': gt
+                'gt': gt,
             }
 
     # init model
@@ -191,8 +187,8 @@ def init_model(self):
        if self.inputs is None:  # Create data engine if it does not exist
            self.data_engine()
 
-       rng, inputs = jax.random.PRNGKey(42), self.inputs
-       self.model_params = self.model.initialize(rng, inputs, is_training=True)
+       self.rng = jax.random.PRNGKey(42)
+       self.model_params = self.model.initialize(rng_key=self.rng, inputs=self.inputs)
 
    # test train loop
    def train(self) -> None:
@@ -201,10 +197,11 @@ def train(self) -> None:
 
        for _ in range(self.num_epochs):
            t0 = time.time()
-           model_params, outputs, loss = jax.jit(
-               self.model.train_step,
-               donate_argnums=(0,),
-               static_argnums=(2,))(
-                   self.model_params, self.inputs, False)
+           # _train_step no longer takes `pmapped`, so static_argnums and the
+           # trailing False argument are gone
+           self.model_params, outputs, loss = jax.jit(
+               self.model.train_step,
+               donate_argnums=(0,))(self.model_params, self.inputs)
+           print(f'Loss: {loss}, took {time.time()-t0}s')
 
-           print(f'Loss: {loss}, took {time.time()-t0}s')
\ No newline at end of file
+# %%
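
Note: this patch pins the jmp policy to full precision and swaps DynamicLossScale
for NoOpLossScale, but keeps the scale()/unscale() plumbing in _loss_fn and
_train_step. A minimal sketch of why that still works, using only jmp calls that
already appear in the file (the scalar value is illustrative):

    import jmp
    import jax.numpy as jnp

    loss_scale = jmp.NoOpLossScale()              # what the patch uses now
    scaled = loss_scale.scale(jnp.float32(0.5))   # identity under NoOpLossScale
    grads = loss_scale.unscale(scaled)            # identity as well

    # Restoring mixed precision later should only mean swapping these back in:
    # policy = jmp.get_policy('p=f32,c=f16,o=f32')
    # loss_scale = jmp.DynamicLossScale(jmp.half_dtype()(2 ** 15))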
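
Note: the commented-out getattr lines in MyModel.__init__ sketch the
network_type/net_kwargs construction the TODOs ask for. A hypothetical version
of that selection logic (build_net is illustrative, not part of this patch; the
kwargs shown are the ones the patch currently hard-codes):

    import raygun.jax.networks

    def build_net(network_type='UNet', **net_kwargs):
        # look up the network class by name, then construct it from kwargs
        net_cls = getattr(raygun.jax.networks, network_type)
        return net_cls(**net_kwargs)

    # e.g. the UNet currently hard-coded in MyModel:
    # build_net('UNet', ngf=3, fmap_inc_factor=2,
    #           downsample_factors=[[2,2,2],[2,2,2],[2,2,2]])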