Skip to content

Commit

Permalink
Finish initial JAX NLayerDiscrim build #13
Browse files Browse the repository at this point in the history
  • Loading branch information
brianreicher committed Aug 17, 2022
1 parent 31df822 commit a850a3e
Showing 1 changed file with 25 additions and 27 deletions.
52 changes: 25 additions & 27 deletions raygun/jax/networks/NLayerDiscriminator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import math
import numpy as np
import jax
import haiku as hk
import functools
Expand Down Expand Up @@ -28,28 +26,28 @@ def __init__(self, output_nc=1, ngf=64, n_layers=3, norm_layer=hk.BatchNorm,

padw = "VALID"
ds_kw = downsampling_kw
sequence = [hk.Conv2D(output_nc, kernel_size=ds_kw, stride=2, padding=padw), jax.nn.leaky_relu(0.2, True)]
sequence = [hk.Conv2D(output_channels=ngf, kernel_size=ds_kw, stride=2, padding=padw), jax.nn.leaky_relu(0.2, True)]
nf_mult = 1
nf_mult_prev = 1
# nf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
nf_mult_prev = nf_mult
nf_mult = min(2 ** n, 8)
sequence += [
hk.Conv2D(ngf * nf_mult_prev, kernel_size=ds_kw, stride=2, padding=padw, with_bias=True),
hk.Conv2D(output_channels=ngf * nf_mult, kernel_size=ds_kw, stride=2, padding=padw, with_bias=True),
norm_layer(ngf * nf_mult),
jax.nn.leaky_relu(0.2, True)
]

nf_mult_prev = nf_mult
# nf_mult_prev = nf_mult
nf_mult = min(2 ** n_layers, 8)
sequence += [
torch.nn.Conv2d(ngf * nf_mult_prev, ngf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
hk.Conv2d(ngf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=True),
norm_layer(ngf * nf_mult),
torch.nn.LeakyReLU(0.2, True)
jax.nn.leaky_relu(0.2, True)
]

sequence += [torch.nn.Conv2d(ngf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
self.model = torch.nn.Sequential(*sequence)
sequence += [hk.Conv2d(1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
self.model = hk.Sequential(*sequence)

@property
def FOV(self):
Expand All @@ -71,15 +69,15 @@ def FOV(self):

return r

def forward(self, input):
def __cal__(self, input):
"""Standard forward."""
return self.model(input)


class NLayerDiscriminator3D(torch.nn.Module):
class NLayerDiscriminator3D(hk.Module):
"""Defines a PatchGAN discriminator"""

def __init__(self, input_nc=1, ngf=64, n_layers=3, norm_layer=torch.nn.BatchNorm3d,
def __init__(self, input_nc=1, ngf=64, n_layers=3, norm_layer=hk.BatchNorm,
kw=4, downsampling_kw=None,
):
"""Construct a PatchGAN discriminator
Expand All @@ -90,40 +88,40 @@ def __init__(self, input_nc=1, ngf=64, n_layers=3, norm_layer=torch.nn.BatchNorm
norm_layer -- normalization layer
"""
super().__init__()
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm3d has affine parameters
use_bias = norm_layer.func == torch.nn.InstanceNorm3d
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
use_bias = norm_layer.func == hk.InstanceNorm
else:
use_bias = norm_layer == torch.nn.InstanceNorm3d
use_bias = norm_layer == hk.InstanceNorm

if downsampling_kw is None:
downsampling_kw = kw

padw = 1
padw = "VALID"
ds_kw = downsampling_kw
sequence = [torch.nn.Conv3d(input_nc, ngf, kernel_size=ds_kw, stride=2, padding=padw), torch.nn.LeakyReLU(0.2, True)]
sequence = [hk.Conv3D(output_channels=ngf, kernel_size=ds_kw, stride=2, padding=padw), jax.nn.leaky_relu(0.2, True)]
nf_mult = 1
nf_mult_prev = 1
# nf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
nf_mult_prev = nf_mult
nf_mult = min(2 ** n, 8)
sequence += [
torch.nn.Conv3d(ngf * nf_mult_prev, ngf * nf_mult, kernel_size=ds_kw, stride=2, padding=padw, bias=use_bias),
hk.Conv3D(output_channels=ngf * nf_mult, kernel_size=ds_kw, stride=2, padding=padw, with_bias=True),
norm_layer(ngf * nf_mult),
torch.nn.LeakyReLU(0.2, True)
jax.nn.leaky_relu(0.2, True)
]

nf_mult_prev = nf_mult
# nf_mult_prev = nf_mult
nf_mult = min(2 ** n_layers, 8)
sequence += [
torch.nn.Conv3d(ngf * nf_mult_prev, ngf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
hk.Conv3D(ngf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=True),
norm_layer(ngf * nf_mult),
torch.nn.LeakyReLU(0.2, True)
jax.nn.leaky_relu(0.2, True)
]

sequence += [torch.nn.Conv3d(ngf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
self.model = torch.nn.Sequential(*sequence)
sequence += [hk.Conv3D(1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
self.model = hk.Sequential(*sequence)

def forward(self, input):
def __call__(self, input):
"""Standard forward."""
return self.model(input)

Expand Down

0 comments on commit a850a3e

Please sign in to comment.