From c317d83b9403a306d6c67880a670ac808283a752 Mon Sep 17 00:00:00 2001
From: brianreicher
Date: Mon, 19 Sep 2022 13:10:53 -0400
Subject: [PATCH] Debugging of ResNet #13

---
 src/raygun/jax/networks/ResNet.py        | 58 +++++++++++++-----------
 src/raygun/jax/networks/__init__.py      |  2 +-
 src/raygun/jax/tests/network_test_jax.py |  8 ++--
 3 files changed, 35 insertions(+), 33 deletions(-)

diff --git a/src/raygun/jax/networks/ResNet.py b/src/raygun/jax/networks/ResNet.py
index f36340d9..9291adc2 100644
--- a/src/raygun/jax/networks/ResNet.py
+++ b/src/raygun/jax/networks/ResNet.py
@@ -14,7 +14,7 @@ def __init__(self, output_nc=1, ngf=64, norm_layer= hk.BatchNorm, use_dropout=Fa
         """Construct a Resnet-based generator
         Parameters:
             input_nc (int)      -- the number of channels in input images
-            output_nc (int)     -- the number of channels in output imagesf
+            output_nc (int)     -- the number of channels in output images
             ngf (int)           -- the number of filters in the last conv layer
             norm_layer          -- normalization layer
             use_dropout (bool)  -- if use dropout layers
@@ -82,7 +82,7 @@ def __init__(self, output_nc=1, ngf=64, norm_layer= hk.BatchNorm, use_dropout=Fa
                             with_bias=use_bias),
                   norm_layer(create_offset=True, create_scale=True, decay_rate=0.0001),
                   activation]
-        model += padder.copy()
+        # model += padder.copy()
         model += [hk.Conv2D(output_nc, kernel_shape=7, padding=p)]
         model += [jax.nn.tanh]

@@ -141,7 +141,7 @@ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias,
             key = jax.random.PRNGKey(22)
             conv_block += [hk.dropout(key, 0.2)]   # TODO

-        conv_block += padder.copy()
+        # conv_block += padder.copy()
         conv_block += [hk.Conv2D(dim, kernel_shape=3, padding=p, with_bias=use_bias), norm_layer(create_offset=True, create_scale=True, decay_rate=0.0001)]
-        return hk.Sequential(*conv_block)
+        return hk.Sequential(conv_block)
@@ -170,13 +170,13 @@ def __call__(self, x):
         out = x + self.conv_block(x)  # add skip connections
         return out

-'''
+
 class ResnetGenerator3D(hk.Module):
     """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations, and (optionally) the injection of a feature map of random noise into the first upsampling layer.
     We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
     """
-    def __init__(self, input_nc=1, output_nc=1, ngf=64, norm_layer=hk.BatchNorm, use_dropout=False, n_blocks=6, padding_type='REFLECT', activation=jax.nn.relu, add_noise=False, n_downsampling=2):
+    def __init__(self, output_nc=1, ngf=64, norm_layer=hk.BatchNorm, use_dropout=False, n_blocks=6, padding_type='VALID', activation=jax.nn.relu, add_noise=False, n_downsampling=2):
         """Construct a Resnet-based generator
         Parameters:
             input_nc (int)      -- the number of channels in input images
@@ -203,29 +203,30 @@ def __init__(self, input_nc=1, output_nc=1, ngf=64, norm_layer=hk.BatchNorm, use
         # if padding_type.lower() == 'reflect': # TODO JAX parallel?
        #     padder = [torch.nn.ReflectionPad3d(3)]
         if padding_type.upper() == 'REPLICATE':
-            padder = [hk.pad.same(3)]
+            # padder = [hk.pad.same(3)]
+            pass
         elif padding_type.upper() == 'ZEROS':
-            p = 3
+            # p = 3
+            pass
         elif padding_type.upper() == 'VALID':
             p = 'VALID'
-            updown_p = 0 # TODO
+            updown_p = 'VALID' # TODO

         model = []
-        model += padder.copy()
         model += [hk.Conv3D(ngf, kernel_shape=7, padding=p, with_bias=use_bias),
-                  norm_layer(ngf),
-                  activation()]
+                  norm_layer(create_offset=True, create_scale=True, decay_rate=0.0001),
+                  activation]

         for i in range(n_downsampling):  # add downsampling layers
             mult = 2 ** i
-            model += [hk.Conv3D(ngf * mult * 2, kernel_shape=3, stride=2, padding=updown_p, with_bias=use_bias), #TODO: Make actually use padding_type for every convolution (currently does zeros if not valid)
-                      norm_layer(ngf * mult * 2),
-                      activation()]
+            model += [hk.Conv3D(output_channels=ngf * mult * 2, kernel_shape=3, stride=2, padding=updown_p, with_bias=use_bias), #TODO: Make actually use padding_type for every convolution (currently does zeros if not valid)
+                      norm_layer(create_offset=True, create_scale=True, decay_rate=0.0001),
+                      activation]

         mult = 2 ** n_downsampling
         for i in range(n_blocks):       # add ResNet blocks
-            model += [ResnetBlock3D(ngf * mult, padding_type=padding_type.upper(), norm_layer=norm_layer, use_dropout=use_dropout, with_bias=use_bias, activation=activation)]
+            model += [ResnetBlock3D(dim=ngf * mult, padding_type=p, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias, activation=activation)]

         if add_noise:
             model += [NoiseBlock()]
@@ -233,15 +234,15 @@ def __init__(self, input_nc=1, output_nc=1, ngf=64, norm_layer=hk.BatchNorm, use
         for i in range(n_downsampling):  # add upsampling layers
             mult = 2 ** (n_downsampling - i)
             model += [hk.Conv3DTranspose(
-                          int(ngf * mult / 2),
+                          output_channels=int(ngf * mult / 2),
                           kernel_shape=3,
                           stride=2,
                           padding=updown_p,
                           with_bias=use_bias),
-                      norm_layer(int(ngf * mult / 2)),
-                      activation()]
-        model += padder.copy()
-        model += [hk.Conv3D(output_nc, kernel_shape=7, padding=p)]
-        model += [jax.nn.tanh()]
+                      norm_layer(create_offset=True, create_scale=True, decay_rate=0.0001),
+                      activation]
+        # model += padder.copy()
+        model += [hk.Conv3D(output_channels=output_nc, kernel_shape=7, padding=p)]
+        model += [jax.nn.tanh]

-        self.model = hk.Sequential(*model)
+        self.model = hk.Sequential(model)
@@ -280,9 +281,11 @@ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias,
         # if padding_type == 'reflect':
         #     padder = [torch.nn.ReflectionPad3d(1)]
         if padding_type.upper() == 'REPLICATE':
-            padder = [hk.pad.same(1)]
+            # padder = [hk.pad.same(1)]
+            pass
         elif padding_type.upper() == 'ZEROS':
-            p = 1
-        elif padding_type.upper == 'VALID':
+            # p = 1
+            pass
+        elif padding_type.upper() == 'VALID':
             p = 'VALID'
         else:
@@ -291,13 +294,15 @@
         conv_block = []
-        conv_block += padder.copy()
-        conv_block += [hk.Conv3D(dim, kernel_shape=3, padding=p, with_bias=use_bias), norm_layer(dim), activation()]
+        # conv_block += padder.copy()
+        conv_block += [hk.Conv3D(dim, kernel_shape=3, padding=p, with_bias=use_bias), norm_layer(create_offset=True, create_scale=True, decay_rate=0.0001), activation]
+
+        print('dropout check')
         if use_dropout:     # TODO
             key = jax.random.PRNGKey(22)
             conv_block += [hk.dropout(key, 0.2)]   # TODO

-        conv_block += padder.copy()
-        conv_block += [hk.Conv3D(dim, kernel_shape=3, padding=p, with_bias=use_bias), norm_layer(dim)]
+        # conv_block += padder.copy()
+        conv_block += [hk.Conv3D(dim, kernel_shape=3, padding=p, with_bias=use_bias), norm_layer(create_offset=True, create_scale=True, decay_rate=0.0001)]
-        return hk.Sequential(*conv_block)
+        return hk.Sequential(conv_block)
@@ -350,4 +355,3 @@ def __init__(self, ndims, **kwargs):
             ResnetGenerator3D.__init__(self, **kwargs)
         else:
             raise ValueError(ndims, 'Only 2D or 3D currently implemented. Feel free to contribute more!')
-'''
\ No newline at end of file
diff --git a/src/raygun/jax/networks/__init__.py b/src/raygun/jax/networks/__init__.py
index ca9e1d66..d964d9de 100644
--- a/src/raygun/jax/networks/__init__.py
+++ b/src/raygun/jax/networks/__init__.py
@@ -4,6 +4,6 @@
 from .UNet import UNet
 from .ResidualUNet import ResidualUNet
 from .MIRNet2D import MIRNet
-from .ResNet import ResnetGenerator2D
+from .ResNet import ResNet
 from .utils import *
 from .NLayerDiscriminator import *
\ No newline at end of file
diff --git a/src/raygun/jax/tests/network_test_jax.py b/src/raygun/jax/tests/network_test_jax.py
index 1f09ee2d..b80783fc 100644
--- a/src/raygun/jax/tests/network_test_jax.py
+++ b/src/raygun/jax/tests/network_test_jax.py
@@ -32,7 +32,7 @@ def forward(self, inputs):

     def train_step(self, inputs, pmapped):
         raise RuntimeError("Unimplemented")

-# RIP Queen Elizabeth 9/8/22
+
 class JAXModel(GenericJaxModel):
     def __init__(self, learning_rate = 0.5e-4):   # TODO set up for **kwargs
@@ -44,15 +44,13 @@ class MyModel(hk.Module):
            def __init__(self, name=None):
                super().__init__(name=name)
-               # self.net = UNet(
+               # self.net = ResidualUNet(
                #     ngf=3,
                #     fmap_inc_factor=2,
                #     downsample_factors=[[2,2,2],[2,2,2],[2,2,2]]
                #     )
                # self.net = NLayerDiscriminator3D(ngf=3)
-               # net = getattr(raygun.jax.networks, network_type)
-               # self.net = net(net_kwargs)
-               self.net = ResnetGenerator2D(ngf=3)
+               self.net = ResNet(ndims=2, ngf=3)

            def __call__(self, x):
                return self.net(x)
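Usage note, not part of the patch: both generators now converge on one construction pattern, a plain Python list that mixes Haiku modules with uncalled jax.nn activations, wrapped in hk.Sequential. The sketch below is illustrative only and is not raygun code; the channel counts, kernel shapes, and NHWC toy input are assumptions. It also shows why the Sequential calls above take the list itself rather than unpacking it: hk.Sequential(layers, name=None) expects a single iterable of callables.

    import jax
    import jax.numpy as jnp
    import haiku as hk

    def forward(x):
        # Build the network as a list of callables, mixing hk modules with
        # plain jax.nn functions passed uncalled, as the patch does.
        layers = [
            hk.Conv2D(output_channels=8, kernel_shape=7, padding='VALID'),
            jax.nn.relu,
            hk.Conv2D(output_channels=1, kernel_shape=3, padding='VALID'),
            jax.nn.tanh,
        ]
        # hk.Sequential takes the iterable itself, not unpacked arguments.
        return hk.Sequential(layers)(x)

    model = hk.transform(forward)
    x = jnp.zeros((1, 32, 32, 1))  # assumed NHWC toy input, not from the patch
    params = model.init(jax.random.PRNGKey(0), x)
    y = model.apply(params, None, x)  # rng=None is fine; nothing here is stochastic
    print(y.shape)  # (1, 24, 24, 1): two VALID convs shrink 32 -> 26 -> 24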