Bug fixes for channels in JAX tester #13 #14
brianreicher committed Aug 30, 2022
1 parent 02d0728 commit 9677a77
Showing 2 changed files with 47 additions and 50 deletions.
2 changes: 1 addition & 1 deletion src/raygun/jax/networks/UNet.py
@@ -493,7 +493,7 @@ def __init__(self,
for level in range(self.num_levels - 1)]
for _ in range(self.num_heads)]

-    def rec_forward(self, level, f_in, total_level):
+    def rec_forward(self, level, f_in, total_level):  # TODO fix unet build

prefix = " "*(total_level-1-level)
print(prefix + "Creating U-Net layer %i" % (total_level-1-level))
95 changes: 46 additions & 49 deletions src/raygun/jax/tests/network_test_jax.py
@@ -1,5 +1,6 @@
#%%
-from raygun.jax.networks import UNet, NLayerDiscriminator2D
+import raygun
+from raygun.jax.networks import *
import jax
import jax.numpy as jnp
from jax import jit
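The import change (adding import raygun next to a wildcard import) looks like preparation for looking network classes up by name, as in the commented-out lines further down this file. A minimal sketch of that pattern, assuming every network class lives in raygun.jax.networks and takes keyword arguments (note the committed comment passes net_kwargs positionally; it would need to be unpacked with **):

    import raygun.jax.networks

    def build_network(network_type: str, **net_kwargs):
        # Look the class up by name on the networks module;
        # an unknown name raises AttributeError.
        net_cls = getattr(raygun.jax.networks, network_type)
        return net_cls(**net_kwargs)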
@@ -12,13 +13,6 @@
import matplotlib.pyplot as plt
from tqdm import trange

-from raygun.jax.networks.NLayerDiscriminator import NLayerDiscriminator3D
-
-
-# PARAMETERS
-mp_training = True  # mixed-precision training using `jmp`
-learning_rate = 0.5e-4


class Params(NamedTuple):
weight: jnp.ndarray
@@ -42,33 +36,34 @@ def train_step(self, inputs, pmapped):
raise RuntimeError("Unimplemented")


-class Model(GenericJaxModel):
+class JAXModel(GenericJaxModel):

-    def __init__(self):
+    def __init__(self, network_type='UNet', learning_rate=0.5e-4, **net_kwargs):  # TODO set up for **kwargs
super().__init__()
+        self.learning_rate = learning_rate

# to make assigning precision policy easier
class MyModel(hk.Module):

def __init__(self, name=None):
super().__init__(name=name)
self.net = UNet(
-                    ngf=5,
+                    ngf=3,
fmap_inc_factor=2,
downsample_factors=[[2,2,2],[2,2,2],[2,2,2]]
)
+            # net = getattr(raygun.jax.networks, network_type)
+            # self.net = net(net_kwargs)

def __call__(self, x):
return self.net(x)

-        def _forward(x):
-            net = MyModel()
+        def _forward(x):  # Temporary set of _forward()
+            net = MyModel
             return net(x)

-        if mp_training:
-            policy = jmp.get_policy('p=f32,c=f16,o=f32')
-        else:
-            policy = jmp.get_policy('p=f32,c=f32,o=f32')
+        policy = jmp.get_policy('p=f32,c=f32,o=f32')
hk.mixed_precision.set_policy(MyModel, policy)

self.model = hk.without_apply_rng(hk.transform(_forward))
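This hunk pins the jmp policy to full precision and drops the mp_training flag that used to select it. For reference, a sketch of the removed selection logic, which kept parameters and outputs in f32 while computing in f16 (MyModel is the inner Haiku module defined above):

    import haiku as hk
    import jmp

    mp_training = True  # hypothetical toggle, mirroring the removed flag

    if mp_training:
        # params stored in f32, compute in f16, outputs cast back to f32
        policy = jmp.get_policy('p=f32,c=f16,o=f32')
    else:
        policy = jmp.get_policy('p=f32,c=f32,o=f32')

    hk.mixed_precision.set_policy(MyModel, policy)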
@@ -84,8 +79,8 @@ def _forward(params, inputs):
def _loss_fn(weight, raw, gt, mask, loss_scale):
pred_affs = self.model.apply(weight, x=raw)
loss = optax.l2_loss(predictions=pred_affs, targets=gt)
-            loss = loss*2*mask  # optax divides loss by 2 so we mult it back
-            loss_mean = loss.mean(where=mask)
+            # loss = loss*2*mask  # optax divides loss by 2 so we mult it back
+            loss_mean = loss.mean()
return loss_scale.scale(loss_mean), (pred_affs, loss, loss_mean)
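With the mask commented out, every voxel now contributes to the mean loss. A sketch of the masked variant this hunk disables, assuming mask is broadcastable to the prediction shape:

    import jax.numpy as jnp
    import optax

    def masked_l2(pred, gt, mask):
        # optax.l2_loss returns 0.5 * (pred - gt) ** 2 elementwise
        loss = optax.l2_loss(predictions=pred, targets=gt)
        loss = loss * 2 * mask  # undo optax's 1/2 and zero out masked voxels
        # average only over the voxels the mask keeps
        return loss.mean(where=mask.astype(bool))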

@jit
@@ -94,20 +89,14 @@ def _apply_optimizer(params, grads):
new_weight = optax.apply_updates(params.weight, updates)
return new_weight, new_opt_state

-        def _train_step(params, inputs, pmapped=False) -> Tuple[Params, Dict[str, jnp.ndarray], Any]:
+        def _train_step(params, inputs) -> Tuple[Params, Dict[str, jnp.ndarray], Any]:

-            raw, gt, mask = inputs['raw'], inputs['gt'], inputs['mask']
+            raw, gt = inputs['raw'], inputs['gt']

grads, (pred_affs, loss, loss_mean) = jax.grad(
-                _loss_fn, has_aux=True)(params.weight, raw, gt, mask,
+                _loss_fn, has_aux=True)(params.weight, raw, gt,
params.loss_scale)

-            if pmapped:
-                # sync grads, casting to compute precision (f16) for efficiency
-                grads = policy.cast_to_compute(grads)
-                grads = jax.lax.pmean(grads, axis_name='num_devices')
-                grads = policy.cast_to_param(grads)

# dynamic mixed precision loss scaling
grads = params.loss_scale.unscale(grads)
new_weight, new_opt_state = _apply_optimizer(params, grads)
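The removed pmapped branch handled multi-device training, where each device holds its own gradients and they must be averaged before the optimizer step. A sketch of that path, which only works inside a function mapped with jax.pmap(..., axis_name='num_devices'):

    import jax

    def sync_grads(grads, policy):
        # cast to compute precision (f16) so the all-reduce moves less data
        grads = policy.cast_to_compute(grads)
        # average gradients across all devices in the pmap
        grads = jax.lax.pmean(grads, axis_name='num_devices')
        # cast back to parameter precision before applying updates
        return policy.cast_to_param(grads)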
@@ -127,18 +116,23 @@ def _train_step(params, inputs, pmapped=False) -> Tuple[Params, Dict[str, jnp.ndarray], Any]:
self.train_step = _train_step


-    def initialize(self, rng_key, inputs, is_training=True):
+    def initialize(self, rng_key, inputs):
weight = self.model.init(rng_key, inputs['raw'])
opt_state = self.opt.init(weight)
-        if mp_training:
-            loss_scale = jmp.DynamicLossScale(jmp.half_dtype()(2 ** 15))
-        else:
-            loss_scale = jmp.NoOpLossScale()
+        loss_scale = jmp.NoOpLossScale()
return Params(weight, opt_state, loss_scale)
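initialize() now always returns a no-op loss scale, matching the full-precision policy above. A sketch of both options, with the removed dynamic scale that f16 training needs:

    import jmp

    mp_training = False  # hypothetical toggle, mirroring the removed flag

    if mp_training:
        # start at 2**15 and let jmp grow or shrink the scale as
        # gradients overflow or stay finite
        loss_scale = jmp.DynamicLossScale(jmp.half_dtype()(2 ** 15))
    else:
        loss_scale = jmp.NoOpLossScale()  # full precision needs no scaling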

-class NetworkTestJAX():
+class NetworkTestJAX():  # TODO setup for **kwargs

-    def __init__(self, task=None, im='astronaut', batch_size=None, noise_factor=3, model=Model(), num_epochs=15) -> None:
+    def __init__(self, net_type='UNet',
+                 task=None, im='astronaut',
+                 batch_size=None,
+                 noise_factor=3,
+                 model=JAXModel(),
+                 num_epochs=15,
+                 **net_kwargs) -> None:

self.task = task
self.im = im
n_devices = jax.local_device_count()
@@ -147,15 +141,18 @@ def __init__(self, task=None, im='astronaut', batch_size=None, noise_factor=3, model=Model(), num_epochs=15) -> None:
else:
self.batch_size = batch_size
self.noise_factor = noise_factor

self.model = model
self.num_epochs = num_epochs

+        # TODO
+        # net_type, net_kwargs

self.inputs = None
self.model_params = None

-    def im2batch(self):
+    def im2batch(self, im):
im = jnp.expand_dims(im, 0)
batch = []
for i in range(self.batch_size):
@@ -168,31 +165,30 @@ def data_engine(self):
self.inputs = {
'raw': jnp.ones([self.batch_size, 1, 132, 132, 132]),
'gt': jnp.zeros([self.batch_size, 3, 40, 40, 40]),
-                'mask': jnp.ones([self.batch_size, 3, 40, 40, 40])
}
else:
gt_import = getattr(data, self.im)()
if len(gt_import.shape) > 2: # Strips to use only one image
gt_import = gt_import[...,0]
-            gt = self.im2batch(im=jnp.asarray(gt_import), batch_size=self.batch_size)
+            gt = self.im2batch(im=jnp.asarray(gt_import))

noise_key = jax.random.PRNGKey(22)
-            noise = self.im2batch(im=jax.random.uniform(key=noise_key, shape=gt_import.shape), batch_size=batch_size)
+            noise = self.im2batch(im=jax.random.uniform(key=noise_key, shape=gt_import.shape))

raw = (gt*noise) / self.noise_factor + (gt/self.noise_factor)

self.inputs = {
'raw': raw,
-            'gt': gt
+            'gt': gt,
}

# init model
def init_model(self):
if self.inputs is None: # Create data engine if it does not exist
self.data_engine()

-        rng, inputs = jax.random.PRNGKey(42), self.inputs
-        self.model_params = self.model.initialize(rng, inputs, is_training=True)
+        self.rng = jax.random.PRNGKey(42)
+        self.model_params = self.model.initialize(rng_key=self.rng, inputs=self.inputs)
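init_model now stores the PRNG key on the instance. If that key is ever reused (data_engine currently uses its own fixed key for noise), splitting it first avoids drawing correlated random streams; a small sketch:

    import jax

    rng = jax.random.PRNGKey(42)
    # derive independent keys for initialization and later sampling
    init_key, noise_key = jax.random.split(rng)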

# test train loop
def train(self) -> None:
@@ -201,10 +197,11 @@ def train(self) -> None:
for _ in range(self.num_epochs):
t0 = time.time()

-            model_params, outputs, loss = jax.jit(
-                self.model.train_step,
-                donate_argnums=(0,),
-                static_argnums=(2,))(
-                    self.model_params, self.inputs, False)
+            self.model_params, outputs, loss = jax.jit(self.model.train_step,
+                                                       donate_argnums=(0,),
+                                                       static_argnums=(2,))(self.model_params, self.inputs, False)


            print(f'Loss: {loss}, took {time.time()-t0}s')
# %%
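Note that _train_step no longer takes the pmapped argument, yet the jitted call above still declares static_argnums=(2,) and passes a third positional False, which would fail at call time. A sketch of a call matching the new two-argument signature:

    step = jax.jit(self.model.train_step, donate_argnums=(0,))
    self.model_params, outputs, loss = step(self.model_params, self.inputs)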
