Attempting inheritance solutions #20
brianreicher committed Sep 8, 2022
1 parent 4b454e5 commit b5ed009
Showing 4 changed files with 57 additions and 52 deletions.
11 changes: 5 additions & 6 deletions src/raygun/jax/networks/NLayerDiscriminator.py
@@ -26,7 +26,7 @@ def __init__(self, ngf=64, n_layers=3, norm_layer=hk.BatchNorm,

padw = "VALID"
ds_kw = downsampling_kw
- sequence = [hk.Conv2D(output_channels=ngf, kernel_shape=ds_kw, stride=2, padding=padw), jax.nn.leaky_relu(0.2, True)]
+ sequence = [hk.Conv2D(output_channels=ngf, kernel_shape=ds_kw, stride=2, padding=padw), jax.nn.leaky_relu()]
nf_mult = 1
# nf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
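Note: unlike Torch's `nn.LeakyReLU(0.2, inplace=True)` module, `jax.nn.leaky_relu` is a plain function with signature `leaky_relu(x, negative_slope=0.01)`, so the removed call passed 0.2 as the input array, and the new `jax.nn.leaky_relu()` still raises for want of an input. A minimal sketch of keeping the 0.2 slope in a Haiku layer list via `functools.partial` (sizes illustrative, not repo code):

```python
import functools
import haiku as hk
import jax
import jax.numpy as jnp

# Bind the slope once; hk.Sequential then applies the bare function
# to the activations at call time.
leaky = functools.partial(jax.nn.leaky_relu, negative_slope=0.2)

def stem(x):
    # First discriminator stage only, as a sketch.
    return hk.Sequential([
        hk.Conv2D(output_channels=64, kernel_shape=4, stride=2, padding="VALID"),
        leaky,
    ])(x)

net = hk.transform(stem)
params = net.init(jax.random.PRNGKey(0), jnp.ones([1, 64, 64, 1]))
```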
@@ -35,16 +35,15 @@ def __init__(self, ngf=64, n_layers=3, norm_layer=hk.BatchNorm,
sequence += [
hk.Conv2D(output_channels=ngf * nf_mult, kernel_shape=ds_kw, stride=2, padding=padw, with_bias=True),
norm_layer(create_scale=False, create_offset=False, decay_rate=0.999), # TODO FIX OFFSET AND DECAY RATE
- jax.nn.leaky_relu(0.2, True)
+ jax.nn.leaky_relu()
]

# nf_mult_prev = nf_mult
nf_mult = min(2 ** n_layers, 8)
sequence += [
hk.Conv2D(output_channels=ngf * nf_mult, kernel_shape=kw, stride=1, padding=padw, with_bias=True),
- # norm_layer(ngf * nf_mult),
norm_layer(create_scale=True, create_offset=False, decay_rate=0.999), # TODO FIX OFFSET AND DECAY RATE
- jax.nn.leaky_relu(0.2, True)
+ jax.nn.leaky_relu()
]

sequence += [hk.Conv2D(output_channels=1, kernel_shape=kw, stride=1, padding=padw)] # output 1 channel prediction map
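Note: `hk.BatchNorm` is constructed with `create_scale`, `create_offset`, and `decay_rate`, but its `__call__` also requires `is_training`, which `hk.Sequential` does not forward to inner layers — so the norm layers staged here would fail once the sequence is applied. A sketch of one workaround, binding the flag with a lambda (assumed usage, not the repo's API):

```python
import haiku as hk
import jax
import jax.numpy as jnp

def block(x, is_training=True):
    norm = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.999)
    return hk.Sequential([
        hk.Conv2D(output_channels=64, kernel_shape=4, stride=2, padding="VALID"),
        lambda h: norm(h, is_training=is_training),  # bind is_training here
        jax.nn.leaky_relu,
    ])(x)

# BatchNorm keeps moving averages as Haiku state:
net = hk.transform_with_state(block)
params, state = net.init(jax.random.PRNGKey(0), jnp.ones([1, 64, 64, 1]))
```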
@@ -70,7 +69,7 @@ def FOV(self):

return r

- def __cal__(self, input):
+ def __call__(self, input):
"""Standard forward."""
return self.model(input)
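Note: the `__cal__` → `__call__` rename is load-bearing: Haiku modules are invoked through Python's `__call__`, so with the misspelled name the discriminator instance was simply not callable and calling it raised a `TypeError`.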

@@ -99,7 +98,7 @@ def __init__(self, ngf=64, n_layers=3, norm_layer=hk.BatchNorm,

padw = "VALID"
ds_kw = downsampling_kw
- sequence = [hk.Conv3D(output_channels=ngf, kernel_shape=ds_kw, stride=2, padding=padw), jax.nn.leaky_relu(0.2, True)]
+ sequence = [hk.Conv3D(output_channels=ngf, kernel_shape=ds_kw, stride=2, padding=padw), jax.nn.leaky_relu]
nf_mult = 1
# nf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
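Note: this 3D hunk ends with the bare function `jax.nn.leaky_relu`, which does compose with a layer list, while the 2D hunk above keeps `jax.nn.leaky_relu()`, which raises immediately; the `functools.partial` sketch above would restore the original 0.2 slope in both variants.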
72 changes: 40 additions & 32 deletions src/raygun/jax/networks/ResNet.py
@@ -10,7 +10,7 @@ class ResnetGenerator2D(hk.Module):
We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
"""

- def __init__(self, input_nc=1, output_nc=1, ngf=64, norm_layer= hk.BatchNorm, use_dropout=False, n_blocks=6, padding_type='VALID', activation=jax.nn.relu, add_noise=False, n_downsampling=2):
+ def __init__(self, output_nc=1, ngf=64, norm_layer= hk.BatchNorm, use_dropout=False, n_blocks=6, padding_type='VALID', activation=jax.nn.relu, add_noise=False, n_downsampling=2):
"""Construct a Resnet-based generator
Parameters:
input_nc (int) -- the number of channels in input images
@@ -30,37 +30,43 @@ def __init__(self, input_nc=1, output_nc=1, ngf=64, norm_layer= hk.BatchNorm, us
use_bias = norm_layer.func == hk.InstanceNorm
else:
use_bias = norm_layer == hk.InstanceNorm

p = 0
updown_p = 1
padder = []
# if padding_type.lower() == 'reflect': # TODO parallel in JAX?
# padder = [hk.pad.same(3)]
if padding_type.lower() == 'replicate':
- padder = [hk.pad.same(3)]
+ # padder = [hk.pad.same(3)]
+ pass
elif padding_type.lower() == 'zeros':
- p = 3
+ # p = 3
+ pass
elif padding_type.lower() == 'valid':
- p = 'valid'
- updown_p = 0
+ p = 'VALID'
+ updown_p = 'VALID'

model = []
- model += padder.copy()
+ # model += padder.copy()
model += [hk.Conv2D(ngf, kernel_shape=7, padding=p, with_bias=use_bias),
- norm_layer(ngf),
- activation()]

+ norm_layer(create_offset=True, create_scale=True, decay_rate=0.0001),
+ activation]
+ print('pass 1')

for i in range(n_downsampling): # add downsampling layers
mult = 2 ** i
- model += [hk.Conv2D(ngf * mult, ngf * mult * 2, kernel_shape=3, stride=2, padding=updown_p, with_bias=use_bias),
- norm_layer(ngf * mult * 2),
- activation()]

- mult = 2 ** n_downsampling
+ model += [hk.Conv2D(output_channels=ngf * mult * 2, kernel_shape=3, stride=2, padding=updown_p, with_bias=use_bias),
+ norm_layer(create_offset=True, create_scale=True, decay_rate=0.0001),
+ activation]

+ print('pass 2')

+ mult = 2 ** n_downsampling # TODO INHERITANCE ISSUE WITH ADDING RESNET 3D Block POSITIONAL ARGUMENTS
for i in range(n_blocks): # add ResNet blocks
- model += [ResnetBlock2D(dim=(ngf * mult), padding_type=p, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias, activation=activation)]

+ model += [ResnetBlock2D(ngf * mult, padding_type=padding_type.lower(), norm_layer=norm_layer, use_dropout=use_dropout, with_bias=use_bias, activation=activation)]

+ print('pass 3')
if add_noise == 'param': # add noise feature if necessary
# model += [ParameterizedNoiseBlock()]
pass
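Note: the switch from integer padding to string padding matches Haiku's convolution API: `padding` accepts `"SAME"`/`"VALID"`, `hk.pad` callables, or explicit per-dimension `(before, after)` pairs, but not bare ints as in Torch. If the `'zeros'` branch is ever revived, a hedged sketch of the native spelling:

```python
import haiku as hk

# Explicit zero padding of 3 on each side of both spatial dims — the Haiku
# analogue of Torch's padding=3 (an assumption, not repo code):
conv = hk.Conv2D(output_channels=64, kernel_shape=7, padding=((3, 3), (3, 3)))
```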
@@ -69,16 +75,16 @@ def __init__(self, input_nc=1, output_nc=1, ngf=64, norm_layer= hk.BatchNorm, us

for i in range(n_downsampling): # add upsampling layers
mult = 2 ** (n_downsampling - i)
- model += [hk.Conv2DTranspose(ngf * mult + (i==0 and (add_noise is not False)),
+ model += [hk.Conv2DTranspose(
int(ngf * mult / 2),
kernel_shape=3, stride=2,
padding=updown_p,
with_bias=use_bias),
- norm_layer(int(ngf * mult / 2)),
- activation()]
+ norm_layer(create_offset=True, create_scale=True, decay_rate=0.0001),
+ activation]
model += padder.copy()
model += [hk.Conv2D(output_nc, kernel_shape=7, padding=p)]
- model += [jax.nn.tanh()]
+ model += [jax.nn.tanh]

self.model = hk.Sequential(*model)
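Note: `hk.Sequential` takes a single iterable, `Sequential(layers, name=None)`, so `hk.Sequential(*model)` unpacks the list into positional arguments and breaks as soon as `model` holds more than two entries. The bare `jax.nn.tanh` on the new line is the right form, though; only the parenthesized call was a bug. A sketch:

```python
import haiku as hk
import jax

model = [
    hk.Conv2D(output_channels=1, kernel_shape=7, padding="VALID"),
    jax.nn.tanh,               # bare function, applied to the activations
]
net = hk.Sequential(model)     # one iterable argument
# net = hk.Sequential(*model)  # TypeError: the second layer lands in `name`
```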

@@ -116,25 +122,27 @@ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias,
padder = []
# if padding_type == 'reflect': # TODO parallel in JAX?
# padder = [torch.nn.ReflectionPad2d(1)]
- if padding_type == 'replicate':
- padder = [hk.pad.same(1)]
- elif padding_type == 'zeros':
- p = 1
- elif padding_type == 'valid':
- p = 'valid'
+ if padding_type.upper() == 'REPLICATE':
+ # padder = [hk.pad.same(1)]
+ pass
+ elif padding_type.upper() == 'ZEROS':
+ # p = 1
+ pass
+ elif padding_type.upper() == 'VALID':
+ p = 'VALID'
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)

conv_block = []
- conv_block += padder.copy()
+ # conv_block += padder.copy()

- conv_block += [hk.Conv2D(dim, kernel_shape=3, padding=p, with_bias=use_bias), norm_layer(dim), activation()]
+ conv_block += [hk.Conv2D(dim, kernel_shape=3, padding=p, with_bias=use_bias), norm_layer(create_offset=True, create_scale=True, decay_rate=0.0001), activation]
if use_dropout:
key = jax.random.PRNGKey(22)
conv_block += [hk.dropout(key, 0.2)] # TODO

conv_block += padder.copy()
- conv_block += [hk.Conv2D(dim, kernel_shape=3, padding=p, with_bias=use_bias), norm_layer(dim)]
+ conv_block += [hk.Conv2D(dim, kernel_shape=3, padding=p, with_bias=use_bias), norm_layer(create_offset=True, create_scale=True, decay_rate=0.0001)]

return hk.Sequential(*conv_block)
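Note: `hk.dropout(rng, rate, x)` applies dropout immediately rather than returning a layer, so `conv_block += [hk.dropout(key, 0.2)]` is missing its input array and would reuse one fixed key besides. A sketch of staging it (assumes the block runs inside `hk.transform`, where `hk.next_rng_key()` is available):

```python
import haiku as hk

# Wrap the call so the activations (and a fresh key) arrive at call time:
dropout_layer = lambda x: hk.dropout(hk.next_rng_key(), 0.2, x)
conv_block = [dropout_layer]  # spliced among the conv/norm layers
```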

@@ -162,7 +170,7 @@ def __call__(self, x):
out = x + self.conv_block(x) # add skip connections
return out


+ '''
class ResnetGenerator3D(hk.Module):
"""Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations, and (optionally) the injection of a feature map of random noise into the first upsampling layer.
We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
@@ -294,7 +302,6 @@ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias,
return hk.Sequential(*conv_block)
def crop(self, x, shape):
'''Center-crop x to match spatial dimensions given by shape.'''
x_target_size = x.size()[:-3] + shape
@@ -343,3 +350,4 @@ def __init__(self, ndims, **kwargs):
ResnetGenerator3D.__init__(self, **kwargs)
else:
raise ValueError(ndims, 'Only 2D or 3D currently implemented. Feel free to contribute more!')
+ '''
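Note: the now-commented `ResNet` wrapper calls `ResnetGenerator3D.__init__(self, **kwargs)` on a class that doesn't inherit from either generator, which is exactly the positional-argument clash the TODO above flags. A hedged sketch of a factory that sidesteps inheritance altogether (hypothetical helper, assuming both generators are restored):

```python
def make_resnet(ndims, **kwargs):
    # Dispatch on dimensionality instead of rebinding another class's __init__.
    if ndims == 2:
        return ResnetGenerator2D(**kwargs)
    elif ndims == 3:
        return ResnetGenerator3D(**kwargs)
    raise ValueError(ndims, 'Only 2D or 3D currently implemented. Feel free to contribute more!')
```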
2 changes: 1 addition & 1 deletion src/raygun/jax/networks/__init__.py
@@ -4,6 +4,6 @@
from .UNet import UNet
from .ResidualUNet import ResidualUNet
from .MIRNet2D import MIRNet
- from .ResNet import ResNet
+ from .ResNet import ResnetGenerator2D
from .utils import *
from .NLayerDiscriminator import *
24 changes: 11 additions & 13 deletions src/raygun/jax/tests/network_test_jax.py
@@ -1,6 +1,5 @@
#%%
- import raygun
- from raygun.jax.networks import ResidualUNet, UNet, NLayerDiscriminator
+ from raygun.jax.networks import *
import jax
import jax.numpy as jnp
from jax import jit
@@ -10,8 +9,6 @@
import time
from typing import Tuple, Any, NamedTuple, Dict
from skimage import data
- import matplotlib.pyplot as plt
- from tqdm import trange


class Params(NamedTuple):
@@ -35,10 +32,10 @@ def forward(self, inputs):
def train_step(self, inputs, pmapped):
raise RuntimeError("Unimplemented")


+ # RIP Queen Elizabeth 9/8/22
class JAXModel(GenericJaxModel):

- def __init__(self, network_type = 'UNet', learning_rate = 0.5e-4, **net_kwargs): # TODO set up for **kwargs
+ def __init__(self, learning_rate = 0.5e-4): # TODO set up for **kwargs
super().__init__()
self.learning_rate = learning_rate

@@ -47,14 +44,15 @@ class MyModel(hk.Module):

def __init__(self, name=None):
super().__init__(name=name)
- # self.net = ResidualUNet(
+ # self.net = UNet(
# ngf=3,
# fmap_inc_factor=2,
# downsample_factors=[[2,2,2],[2,2,2],[2,2,2]]
# )
- self.net = NLayerDiscriminator(ndims=2, ngf=3)
# self.net = NLayerDiscriminator3D(ngf=3)
# net = getattr(raygun.jax.networks, network_type)
# self.net = net(net_kwargs)
+ self.net = ResnetGenerator2D(ngf=3)

def __call__(self, x):
return self.net(x)
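Note: if the commented-out dynamic constructor is revived, `net(net_kwargs)` would pass the kwargs dict as a single positional argument; it presumably wants unpacking. A sketch (hypothetical `build_net` helper, not repo code):

```python
import raygun.jax.networks

def build_net(network_type='UNet', **net_kwargs):
    net_cls = getattr(raygun.jax.networks, network_type)
    return net_cls(**net_kwargs)   # note the **, which `net(net_kwargs)` omits
```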
@@ -77,7 +75,7 @@ def _forward(params, inputs):
self.forward = _forward

@jit
- def _loss_fn(weight, raw, gt, mask, loss_scale):
+ def _loss_fn(weight, raw, gt, loss_scale):
pred_affs = self.model.apply(weight, x=raw)
loss = optax.l2_loss(predictions=pred_affs, targets=gt)
# loss = loss*2*mask # optax divides loss by 2 so we mult it back
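Note: `optax.l2_loss(predictions, targets)` returns `0.5 * (predictions - targets)**2` elementwise, which is what the retired "mult it back" comment compensated for. A hedged sketch of how the dropped `mask` argument was presumably combined with it:

```python
import jax.numpy as jnp
import optax

def masked_l2(pred, gt, mask):
    loss = optax.l2_loss(predictions=pred, targets=gt)
    loss = loss * 2 * mask   # undo the 1/2 and zero out unlabeled voxels
    return jnp.mean(loss)
```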
@@ -166,8 +164,10 @@ def data_engine(self):
self.inputs = {
'raw': jnp.ones([self.batch_size, 1, 132, 132, 132]),
'gt': jnp.zeros([self.batch_size, 3, 40, 40, 40]),
+ # 'raw': jnp.ones([16, 1, 512, 512, 512]),
+ # 'gt': jnp.zeros([16, 1, 512, 512, 512])
}
- else:
+ else: # TODO raw/gt shape are off and creating fmaps which are too small for valid convolutions
gt_import = getattr(data, self.im)()
if len(gt_import.shape) > 2: # Strips to use only one image
gt_import = gt_import[...,0]
@@ -198,9 +198,7 @@ def train(self) -> None:
for _ in range(self.num_epochs):
t0 = time.time()

- self.model_params, outputs, loss = jax.jit(self.model.train_step,
- donate_argnums=(0,),
- static_argnums=(2,))(self.model_params, self.inputs, False)
+ self.model_params, outputs, loss = jax.jit(self.model.train_step)(self.model_params, self.inputs)


print(f'Loss: {loss}, took {time.time()-t0}s')
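Note: wrapping `self.model.train_step` in `jax.jit` inside the loop builds a fresh jitted callable every epoch, which can retrace and recompile each time; hoisting it out (and keeping `donate_argnums=(0,)` so XLA may reuse the old parameter buffers) is cheaper. A sketch against the new two-argument signature:

```python
import jax

train_step = jax.jit(self.model.train_step, donate_argnums=(0,))
for _ in range(self.num_epochs):
    self.model_params, outputs, loss = train_step(self.model_params, self.inputs)
```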
