
Debugging of ResNet #13
brianreicher committed Sep 19, 2022
1 parent b5ed009 commit c317d83
Showing 3 changed files with 35 additions and 33 deletions.
58 changes: 31 additions & 27 deletions src/raygun/jax/networks/ResNet.py
@@ -14,7 +14,7 @@ def __init__(self, output_nc=1, ngf=64, norm_layer= hk.BatchNorm, use_dropout=Fa
"""Construct a Resnet-based generator
Parameters:
input_nc (int) -- the number of channels in input images
-output_nc (int) -- the number of channels in output images
+output_nc (int) -- the number of channels in output imagesf
ngf (int) -- the number of filters in the last conv layer
norm_layer -- normalization layer
use_dropout (bool) -- if use dropout layers
@@ -82,7 +82,7 @@ def __init__(self, output_nc=1, ngf=64, norm_layer= hk.BatchNorm, use_dropout=Fa
with_bias=use_bias),
norm_layer(create_offset=True, create_scale=True, decay_rate=0.0001),
activation]
-model += padder.copy()
+# model += padder.copy()
model += [hk.Conv2D(output_nc, kernel_shape=7, padding=p)]
model += [jax.nn.tanh]

@@ -141,7 +141,7 @@ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias,
key = jax.random.PRNGKey(22)
conv_block += [hk.dropout(key, 0.2)] # TODO

-conv_block += padder.copy()
+# conv_block += padder.copy()
conv_block += [hk.Conv2D(dim, kernel_shape=3, padding=p, with_bias=use_bias), norm_layer(create_offset=True, create_scale=True, decay_rate=0.0001)]

return hk.Sequential(*conv_block)
@@ -170,13 +170,13 @@ def __call__(self, x):
out = x + self.conv_block(x) # add skip connections
return out

-'''
+
class ResnetGenerator3D(hk.Module):
"""Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations, and (optionally) the injection of a feature map of random noise into the first upsampling layer.
We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
"""

-def __init__(self, input_nc=1, output_nc=1, ngf=64, norm_layer=hk.BatchNorm, use_dropout=False, n_blocks=6, padding_type='REFLECT', activation=jax.nn.relu, add_noise=False, n_downsampling=2):
+def __init__(self, output_nc=1, ngf=64, norm_layer=hk.BatchNorm, use_dropout=False, n_blocks=6, padding_type='VALID', activation=jax.nn.relu, add_noise=False, n_downsampling=2):
"""Construct a Resnet-based generator
Parameters:
input_nc (int) -- the number of channels in input images
@@ -203,45 +203,46 @@ def __init__(self, input_nc=1, output_nc=1, ngf=64, norm_layer=hk.BatchNorm, use
# if padding_type.lower() == 'reflect': # TODO JAX parallel?
# padder = [torch.nn.ReflectionPad3d(3)]
if padding_type.upper() == 'REPLICATE':
-padder = [hk.pad.same(3)]
+# padder = [hk.pad.same(3)]
+pass
elif padding_type.upper() == 'ZEROS':
-p = 3
+# p = 3
+pass
elif padding_type.upper() == 'VALID':
p = 'VALID'
-updown_p = 0 # TODO
+updown_p = 'VALID' # TODO

model = []
-model += padder.copy()
model += [hk.Conv3D(ngf, kernel_shape=7, padding=p, with_bias=use_bias),
-norm_layer(ngf),
-activation()]
+norm_layer(create_offset=True, create_scale=True, decay_rate=0.0001),
+activation]

for i in range(n_downsampling): # add downsampling layers
mult = 2 ** i
-model += [hk.Conv3D(ngf * mult * 2, kernel_shape=3, stride=2, padding=updown_p, with_bias=use_bias), #TODO: Make actually use padding_type for every convolution (currently does zeros if not valid)
-norm_layer(ngf * mult * 2),
-activation()]
+model += [hk.Conv3D(output_channels=ngf * mult * 2, kernel_shape=3, stride=2, padding=updown_p, with_bias=use_bias), #TODO: Make actually use padding_type for every convolution (currently does zeros if not valid)
+norm_layer(create_offset=True, create_scale=True, decay_rate=0.0001),
+activation]

mult = 2 ** n_downsampling
for i in range(n_blocks): # add ResNet blocks

-model += [ResnetBlock3D(ngf * mult, padding_type=padding_type.upper(), norm_layer=norm_layer, use_dropout=use_dropout, with_bias=use_bias, activation=activation)]
+model += [ResnetBlock3D(dim=ngf * mult, padding_type=p, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias, activation=activation)]

if add_noise:
model += [NoiseBlock()]

for i in range(n_downsampling): # add upsampling layers
mult = 2 ** (n_downsampling - i)
model += [hk.Conv3DTranspose(
-int(ngf * mult / 2),
+output_channels=int(ngf * mult / 2),
kernel_shape=3, stride=2,
padding=updown_p,
with_bias=use_bias),
-norm_layer(int(ngf * mult / 2)),
-activation()]
-model += padder.copy()
-model += [hk.Conv3D(output_nc, kernel_shape=7, padding=p)]
-model += [jax.nn.tanh()]
+norm_layer(create_offset=True, create_scale=True, decay_rate=0.0001),
+activation]
+# model += padder.copy()
+model += [hk.Conv3D(output_channels=output_nc, kernel_shape=7, padding=p)]
+model += [jax.nn.tanh]

self.model = hk.Sequential(*model)

@@ -280,9 +281,11 @@ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias,
# if padding_type == 'reflect':
# padder = [torch.nn.ReflectionPad3d(1)]
if padding_type.upper() == 'REPLICATE':
-padder = [hk.pad.same(1)]
+# padder = [hk.pad.same(1)]
+pass
elif padding_type.upper() == 'ZEROS':
-p = 1
+# p = 1
+pass
elif padding_type.upper == 'VALID':
p = 'VALID'
else:
@@ -291,13 +294,15 @@ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias,
conv_block = []
conv_block += padder.copy()

-conv_block += [hk.Conv3D(dim, kernel_shape=3, padding=p, with_bias=use_bias), norm_layer(dim), activation()]
+conv_block += [hk.Conv3D(dim, kernel_shape=3, padding=p, with_bias=use_bias), norm_layer(create_offset=True, create_scale=True, decay_rate=0.0001), activation]
+
+print('dropout check')
if use_dropout: # TODO
key = jax.random.PRNGKey(22)
conv_block += [hk.dropout(key, 0.2)] # TODO

-conv_block += padder.copy()
-conv_block += [hk.Conv3D(dim, kernel_shape=3, padding=p, with_bias=use_bias), norm_layer(dim)]
+# conv_block += padder.copy()
+conv_block += [hk.Conv3D(dim, kernel_shape=3, padding=p, with_bias=use_bias), norm_layer(create_offset=True, create_scale=True, decay_rate=0.0001)]

return hk.Sequential(*conv_block)

Expand Down Expand Up @@ -350,4 +355,3 @@ def __init__(self, ndims, **kwargs):
ResnetGenerator3D.__init__(self, **kwargs)
else:
raise ValueError(ndims, 'Only 2D or 3D currently implemented. Feel free to contribute more!')
-'''
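
Note on the dropout "# TODO" markers in the diff above: Haiku's hk.dropout is a plain function, hk.dropout(rng, rate, x), applied to activations, so appending hk.dropout(key, 0.2) to an hk.Sequential list cannot work as a layer. A minimal sketch of the usual pattern, with illustrative names (ConvBlock is not part of this repository):

import jax
import haiku as hk

class ConvBlock(hk.Module):
    """Illustrative block: function-style dropout belongs in __call__."""

    def __init__(self, dim, use_dropout=False, name=None):
        super().__init__(name=name)
        self.conv = hk.Conv2D(output_channels=dim, kernel_shape=3, padding='VALID')
        self.use_dropout = use_dropout

    def __call__(self, x, is_training=True):
        x = jax.nn.relu(self.conv(x))
        if self.use_dropout and is_training:
            # Draw a key from the enclosing hk.transform's RNG stream
            # rather than hard-coding jax.random.PRNGKey(22).
            x = hk.dropout(hk.next_rng_key(), 0.2, x)
        return x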
2 changes: 1 addition & 1 deletion src/raygun/jax/networks/__init__.py
@@ -4,6 +4,6 @@
from .UNet import UNet
from .ResidualUNet import ResidualUNet
from .MIRNet2D import MIRNet
-from .ResNet import ResnetGenerator2D
+from .ResNet import ResNet
from .utils import *
from .NLayerDiscriminator import *
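
With the package export switched from ResnetGenerator2D to the ndims-dispatching ResNet wrapper, downstream code selects a 2D or 3D generator through one class. A sketch (the forward_2d/forward_3d names are illustrative, and Haiku modules must be built inside a transformed function):

from raygun.jax.networks import ResNet

def forward_2d(x):
    return ResNet(ndims=2, ngf=3)(x)  # dispatches to ResnetGenerator2D

def forward_3d(x):
    return ResNet(ndims=3, ngf=3)(x)  # dispatches to ResnetGenerator3D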
8 changes: 3 additions & 5 deletions src/raygun/jax/tests/network_test_jax.py
@@ -32,7 +32,7 @@ def forward(self, inputs):
def train_step(self, inputs, pmapped):
raise RuntimeError("Unimplemented")

-# RIP Queen Elizabeth 9/8/22
+
class JAXModel(GenericJaxModel):

def __init__(self, learning_rate = 0.5e-4): # TODO set up for **kwargs
@@ -44,15 +44,13 @@ class MyModel(hk.Module):

def __init__(self, name=None):
super().__init__(name=name)
-# self.net = UNet(
+# self.net = ResidualUNet(
# ngf=3,
# fmap_inc_factor=2,
# downsample_factors=[[2,2,2],[2,2,2],[2,2,2]]
# )
# self.net = NLayerDiscriminator3D(ngf=3)
-# net = getattr(raygun.jax.networks, network_type)
-# self.net = net(net_kwargs)
-self.net = ResnetGenerator2D(ngf=3)
+self.net = ResNet(ndims=2,ngf=3)

def __call__(self, x):
return self.net(x)
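
For reference, a minimal sketch (not part of this commit) of exercising the swapped-in network end to end, assuming the module under debug builds cleanly. hk.transform_with_state is used because hk.BatchNorm tracks moving statistics, and the NHWC input shape is a placeholder assumption:

import jax
import jax.numpy as jnp
import haiku as hk
from raygun.jax.networks import ResNet

def forward(x):
    return ResNet(ndims=2, ngf=3)(x)

model = hk.transform_with_state(forward)
rng = jax.random.PRNGKey(22)
x = jnp.zeros((1, 64, 64, 1))  # placeholder NHWC batch
params, state = model.init(rng, x)
out, state = model.apply(params, state, rng, x)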
