diff --git a/raygun.egg-info/PKG-INFO b/raygun.egg-info/PKG-INFO
index 6e9784f1..58d52f51 100644
--- a/raygun.egg-info/PKG-INFO
+++ b/raygun.egg-info/PKG-INFO
@@ -2,7 +2,6 @@ Metadata-Version: 2.1
 Name: raygun
 Version: 0.0.1
 Summary: Library for testing and implementing image denoising, enhancement, and segmentation techniques on large, volumetric datasets. Built with Gunpowder, Daisy, PyTorch, and eventually JAX (hopefully).
-Home-page: UNKNOWN
 Author: Jeff Rhoades
 Author-email: rhoades@g.harvard.edu
 License: AGPL-3.0
diff --git a/raygun/__init__.py b/raygun/__init__.py
index 43f4943a..4cc263d5 100644
--- a/raygun/__init__.py
+++ b/raygun/__init__.py
@@ -2,4 +2,5 @@
     Main RAYGUN package
 """
 from .data import *
-from .torch import *
\ No newline at end of file
+from .torch import *
+from .jax import *
\ No newline at end of file
diff --git a/raygun/jax/__init__.py b/raygun/jax/__init__.py
index 939f1f5d..89ea7e08 100644
--- a/raygun/jax/__init__.py
+++ b/raygun/jax/__init__.py
@@ -1,3 +1,3 @@
 """
-    Main PyTorch package
+    Main JAX package
 """
\ No newline at end of file
diff --git a/raygun/jax/networks/MIRNet2D.py b/raygun/jax/networks/MIRNet2D.py
index b3cf2136..a4c9f5e0 100644
--- a/raygun/jax/networks/MIRNet2D.py
+++ b/raygun/jax/networks/MIRNet2D.py
@@ -13,7 +13,7 @@
 import numpy as np
 
 # from pdb import set_trace as stx
-from raygun.utils.antialias import Downsample as downsamp
+from raygun.jax.utils.antialias import Downsample as downsamp
 
 
 
diff --git a/raygun/jax/networks/UNet.py b/raygun/jax/networks/UNet.py
index bd3098aa..f671a4b0 100644
--- a/raygun/jax/networks/UNet.py
+++ b/raygun/jax/networks/UNet.py
@@ -14,22 +14,22 @@ def __init__(
             activation,
             padding='VALID',
             residual=False,
-            # padding_mode='reflect',
             norm_layer=None,
             data_format='NCDHW'):
 
         super().__init__()
 
+        # if activation is not None:
+        #     if isinstance(activation, str):
+        #         self.activation = getattr(jax.nn, activation)
+        #     else:
+        #         self.activation = activation  # assume activation is a defined function
+        # else:
+        #     self.activation = jax.numpy.identity
         if activation is not None:
-            if isinstance(activation, str):
-                self.activation = getattr(jax.nn, activation)
-            else:
-                self.activation = activation()  # assume activation is a defined function
-        else:
-            self.activation = jax.numpy.identity()
-
+            activation = getattr(jax.nn, activation)
+
         self.residual = residual
-        self.padding = padding
 
         layers = []
 
@@ -99,7 +99,7 @@ def crop(self, x, shape):
 
         return x[slices]
 
-    def forward(self, x):
+    def __call__(self, x):
         if not self.residual:
             return self.conv_pass(x)
         else:
@@ -109,7 +109,60 @@
             else:
                 init_x = self.x_init_map(x)
             return self.activation(init_x + res)
+# class ConvPass(hk.Module):
+
+#     def __init__(
+#             self,
+#             out_channels,
+#             kernel_sizes,
+#             activation,
+#             padding='VALID',
+#             data_format='NCDHW'):
+
+#         super().__init__()
+
+#         if activation is not None:
+#             activation = getattr(jax.nn, activation)
+
+#         layers = []
+
+#         for kernel_size in kernel_sizes:
+#             self.dims = len(kernel_size)
+
+#             conv = {
+#                 2: hk.Conv2D,
+#                 3: hk.Conv3D,
+#                 # 4: Conv4d  # TODO
+#             }[self.dims]
+
+#             if data_format is None:
+#                 in_data_format = {
+#                     2: 'NCHW',
+#                     3: 'NCDHW'
+#                 }[self.dims]
+#             else:
+#                 in_data_format = data_format
+
+#             try:
+#                 layers.append(
+#                     conv(
+#                         output_channels=out_channels,
+#                         kernel_shape=kernel_size,
+#                         padding=padding,
+#                         data_format=in_data_format))
+#             except KeyError:
+#                 raise RuntimeError(
+#                     "%dD convolution not implemented" % self.dims)
+
+#             if activation is not None:
+#                 layers.append(activation)
+
+#         self.conv_pass = hk.Sequential(layers)
+
+#     def __call__(self, x):
+
+#         return self.conv_pass(x)
 
 
 class ConvDownsample(hk.Module):
 
@@ -175,7 +228,7 @@ def __init__(
         layers.append(self.activation)
         self.conv_pass = hk.Sequential(layers)
 
-    def forward(self, x):
+    def __call__(self, x):
         return self.conv_pass(x)
 
 
@@ -196,39 +249,26 @@ def __init__(
             strides=downsample_factor,
             padding='VALID')
 
-    # def forward(self, x):
-    #     if self.flexible:
-    #         try:
-    #             return self.down(x)
-    #         except:
-    #             self.check_mismatch(x.size())
-    #     else:
-    #         self.check_mismatch(x.size())
-    #         return self.down(x)
-
-    # def check_mismatch(self, size):
-    #     for d in range(1, self.dims+1):
-    #         if size[-d] % self.downsample_factor[-d] != 0:
-    #             raise RuntimeError(
-    #                 "Can not downsample shape %s with factor %s, mismatch "
-    #                 "in spatial dimension %d" % (
-    #                     size,
-    #                     self.downsample_factor,
-    #                     self.dims - d))
-    #         return self.down(size)
     def __call__(self, x):
+        if self.flexible:
+            try:
+                return self.down(x)
+            except:
+                self.check_mismatch(x.size())
+        else:
+            self.check_mismatch(x.size())
+            return self.down(x)
 
-        for d in range(1, self.dims + 1):
-            if x.shape[-d] % self.downsample_factor[-d] != 0:
+    def check_mismatch(self, size):
+        for d in range(1, self.dims+1):
+            if size[-d] % self.downsample_factor[-d] != 0:
                 raise RuntimeError(
                     "Can not downsample shape %s with factor %s, mismatch "
                     "in spatial dimension %d" % (
-                        x.shape,
+                        size,
                         self.downsample_factor,
                         self.dims - d))
-
-        return self.down(x)
-
+
 
 class Upsample(hk.Module):
 
@@ -357,13 +397,13 @@ def __init__(self,
                  num_heads=1,
                  constant_upsample=False,
                  downsample_method='max',
-                 padding_type='valid',
+                 padding_type='VALID',
                  residual=False,
                  norm_layer=None,
-                 # add_noise=False
+                 name=None
                  ):
 
-        super().__init__()
+        super().__init__(name=name)
         self.ndims = len(downsample_factors[0])
         self.num_levels = len(downsample_factors) + 1
         self.num_heads = num_heads
@@ -411,7 +451,6 @@ def __init__(self,
                 norm_layer=norm_layer)
             for level in range(self.num_levels)
         ]
-        self.dims = self.l_conv[0].dims
 
         # Left downsample
        if downsample_method.lower() == 'max':
@@ -499,7 +538,7 @@ def rec_forward(self, level, f_in, total_level):
 
         return fs_out
 
-    def forward(self, x):
+    def __call__(self, x):
         y = self.rec_forward(self.num_levels - 1, x, total_level=self.num_levels)
 
 
diff --git a/raygun/jax/networks/__init__.py b/raygun/jax/networks/__init__.py
index c86a9990..c6fde06e 100644
--- a/raygun/jax/networks/__init__.py
+++ b/raygun/jax/networks/__init__.py
@@ -1,3 +1,7 @@
 """
     JAX network architectures
-"""
\ No newline at end of file
+"""
+from .UNet import UNet
+from .ResidualUNet import ResidualUNet
+from .MIRNet2D import MIRNet
+from .utils import *
\ No newline at end of file
diff --git a/raygun/jax/tests/network_test_jax.py b/raygun/jax/tests/network_test_jax.py
index 21cbd981..0aeca1f1 100644
--- a/raygun/jax/tests/network_test_jax.py
+++ b/raygun/jax/tests/network_test_jax.py
@@ -1,7 +1,5 @@
-import sys
-sys.path.append('/n/groups/htem/users/br128/raygun/')
-from raygun.jax.networks.UNet import UNet
-import os
+#%%
+from raygun.jax.networks import UNet
 import jax
 import jax.numpy as jnp
 from jax import jit
@@ -14,7 +12,6 @@
 from typing import Tuple, Any, NamedTuple, Dict
 
 
-
 '''To test model with some dummy input and output, run with command
 
 `CUDA_VISIBLE_DEVICES=0 python unet_example.py`
@@ -53,13 +50,12 @@ def forward(self, inputs):
 
     def train_step(self, inputs, pmapped):
         raise RuntimeError("Unimplemented")
 
-
+#%%
 class Model(GenericJaxModel):
 
     def __init__(self):
         super().__init__()
 
-        # we encapsulate the UNet and the ConvPass in one hk.Module
         # to make assigning precision policy easier
         class MyModel(hk.Module):
 
@@ -72,7 +68,6 @@ def __init__(self, name=None):
                 )
 
            def __call__(self, x):
-                # return self.conv(self.unet(x))
                return self.unet(x)
 
        def _forward(x):
@@ -140,6 +135,7 @@ def _train_step(params, inputs, pmapped=False) -> Tuple[Params, Dict[str, jnp.nd
 
         self.train_step = _train_step
 
+
     def initialize(self, rng_key, inputs, is_training=True):
         weight = self.model.init(rng_key, inputs['raw'])
         opt_state = self.opt.init(weight)
@@ -148,72 +144,69 @@ def initialize(self, rng_key, inputs, is_training=True):
         else:
             loss_scale = jmp.NoOpLossScale()
         return Params(weight, opt_state, loss_scale)
-
+#%%
 def split(arr, n_devices):
     """Splits the first axis of `arr` evenly across the number of devices."""
     return arr.reshape(n_devices, arr.shape[0] // n_devices, *arr.shape[1:])
-
 def create_network():
     # returns a model that Gunpowder `Predict` and `Train` node can use
     return Model()
-
-if __name__ == "__main__":
-
-    my_model = Model()
-
-    n_devices = jax.local_device_count()
-    batch_size = 4*n_devices
-
-    raw = jnp.ones([batch_size, 1, 132, 132, 132])
-    gt = jnp.zeros([batch_size, 3, 40, 40, 40])
-    mask = jnp.ones([batch_size, 3, 40, 40, 40])
-    inputs = {
+#%%
+my_model = Model()
+
+n_devices = jax.local_device_count()
+batch_size = 4*n_devices
+
+raw = jnp.ones([batch_size, 1, 132, 132, 132])
+gt = jnp.zeros([batch_size, 3, 40, 40, 40])
+mask = jnp.ones([batch_size, 3, 40, 40, 40])
+inputs = {
+    'raw': raw,
+    'gt': gt,
+    'mask': mask,
+}
+rng= jax.random.PRNGKey(42)
+#%%
+# init model
+if n_devices > 1:
+    # split input for pmap
+    raw = split(raw, n_devices)
+    gt = split(gt, n_devices)
+    mask = split(mask, n_devices)
+    single_device_inputs = {
         'raw': raw,
         'gt': gt,
-        'mask': mask,
+        'mask': mask
     }
-    rng: KeyArray= jax.random.PRNGKey(42)
+    rng = jnp.broadcast_to(rng, (n_devices,) + rng.shape)
+    model_params = jax.pmap(my_model.initialize)(rng, single_device_inputs)
 
-    # init model
-    if n_devices > 1:
-        # split input for pmap
-        raw = split(raw, n_devices)
-        gt = split(gt, n_devices)
-        mask = split(mask, n_devices)
-        single_device_inputs = {
-            'raw': raw,
-            'gt': gt,
-            'mask': mask
-        }
-        rng = jnp.broadcast_to(rng, (n_devices,) + rng.shape)
-        model_params = jax.pmap(my_model.initialize)(rng, single_device_inputs)
+else:
+    model_params = my_model.initialize(rng, inputs, is_training=True)
+#%%
+# test forward
+y = jit(my_model.forward)(model_params, {'raw': raw})
+assert y['affs'].shape == (batch_size, 3, 40, 40, 40)
+# test train loop
+for _ in range(10):
+    t0 = time.time()
+
+    if n_devices > 1:
+        model_params, outputs, loss = jax.pmap(
+            my_model.train_step,
+            axis_name='num_devices',
+            donate_argnums=(0,),
+            static_broadcasted_argnums=(2,))(
+            model_params, inputs, True)
     else:
-        model_params = my_model.initialize(rng, inputs, is_training=True)
-
-    # test forward
-    y = jit(my_model.forward)(model_params, {'raw': raw})
-    assert y['affs'].shape == (batch_size, 3, 40, 40, 40)
-
-    # test train loop
-    for _ in range(10):
-        t0 = time.time()
-
-        if n_devices > 1:
-            model_params, outputs, loss = jax.pmap(
-                my_model.train_step,
-                axis_name='num_devices',
-                donate_argnums=(0,),
-                static_broadcasted_argnums=(2,))(
-                model_params, inputs, True)
-        else:
-            model_params, outputs, loss = jax.jit(
-                my_model.train_step,
-                donate_argnums=(0,),
-                static_argnums=(2,))(
-                model_params, inputs, False)
+        model_params, outputs, loss = jax.jit(
+            my_model.train_step,
+            donate_argnums=(0,),
+            static_argnums=(2,))(
+            model_params, inputs, False)
 
-        print(f'Loss: {loss}, took {time.time()-t0}s')
\ No newline at end of file
+    print(f'Loss: {loss}, took {time.time()-t0}s')
\ No newline at end of file
diff --git a/raygun/jax/tests/network_test_torch.py b/raygun/jax/tests/network_test_torch.py
index c4be1d41..5884ea18 100644
--- a/raygun/jax/tests/network_test_torch.py
+++ b/raygun/jax/tests/network_test_torch.py
@@ -12,6 +12,7 @@
 from tqdm import trange
 
 torch.cuda.set_device(1)
+
 # %%
 class Test():
     def __init__(self,
diff --git a/raygun/jax/tests/network_test_tri.py b/raygun/jax/tests/network_test_tri.py
index 76a6600f..c3a01aec 100644
--- a/raygun/jax/tests/network_test_tri.py
+++ b/raygun/jax/tests/network_test_tri.py
@@ -1,3 +1,4 @@
+#%%
 import os
 import jax
 import jax.numpy as jnp
@@ -67,11 +68,6 @@ def __init__(self, name=None):
                 fmap_inc_factor=3,
                 downsample_factors=[[2,2,2],[2,2,2],[2,2,2]],
                 )
-            self.conv = ConvPass(
-                kernel_sizes=[[1,1,1]],
-                out_channels=3,
-                activation='sigmoid',
-                )
            def __call__(self, x):
                # return self.conv(self.unet(x))
                return self.unet(x)
@@ -150,7 +146,7 @@ def initialize(self, rng_key, inputs, is_training=True):
         else:
             loss_scale = jmp.NoOpLossScale()
         return Params(weight, opt_state, loss_scale)
-
+#%%
 def split(arr, n_devices):
     """Splits the first axis of `arr` evenly across the number of devices."""
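
Note (not part of the patch): both test scripts now define the `split` helper at module level so that the leading batch axis can be folded into a device axis before `jax.pmap` maps `initialize` and `train_step` across local devices. A minimal sketch of that pattern follows; it reuses the `split` helper exactly as it appears in the diff, while the array shapes and the commented pmap call are illustrative assumptions, not values from the tests.

# Sketch only: `split` is copied from network_test_jax.py; shapes are dummies.
import jax
import jax.numpy as jnp


def split(arr, n_devices):
    """Splits the first axis of `arr` evenly across the number of devices."""
    return arr.reshape(n_devices, arr.shape[0] // n_devices, *arr.shape[1:])


n_devices = jax.local_device_count()
batch = jnp.ones([4 * n_devices, 1, 8, 8, 8])  # leading axis must divide evenly by n_devices
sharded = split(batch, n_devices)              # shape: (n_devices, 4, 1, 8, 8, 8)
assert sharded.shape == (n_devices, 4, 1, 8, 8, 8)

# Each per-device slice would then feed a pmapped function, e.g. (hypothetical):
# jax.pmap(my_model.initialize)(rng_per_device, {'raw': sharded})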