Fix UNet to build half model #13
brianreicher committed Aug 10, 2022
1 parent cf65101 commit 95865d6
Showing 9 changed files with 147 additions and 114 deletions.
1 change: 0 additions & 1 deletion raygun.egg-info/PKG-INFO
@@ -2,7 +2,6 @@ Metadata-Version: 2.1
Name: raygun
Version: 0.0.1
Summary: Library for testing and implementing image denoising, enhancement, and segmentation techniques on large, volumetric datasets. Built with Gunpowder, Daisy, PyTorch, and eventually JAX (hopefully).
Home-page: UNKNOWN
Author: Jeff Rhoades
Author-email: [email protected]
License: AGPL-3.0
3 changes: 2 additions & 1 deletion raygun/__init__.py
@@ -2,4 +2,5 @@
Main RAYGUN package
"""
from .data import *
from .torch import *
from .torch import *
from .jax import *
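Both backends are now star-imported at the package top level, so classes that share a name across raygun.torch and raygun.jax (each plausibly exports a UNet) can shadow one another depending on import order. A minimal sketch of the unambiguous style, assuming the export added to the networks __init__ later in this diff:

from raygun.jax.networks import UNet as JaxUNet  # explicit backend, no shadowing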
2 changes: 1 addition & 1 deletion raygun/jax/__init__.py
@@ -1,3 +1,3 @@
"""
Main PyTorch package
Main JAX package
"""
2 changes: 1 addition & 1 deletion raygun/jax/networks/MIRNet2D.py
@@ -13,7 +13,7 @@
import numpy as np
# from pdb import set_trace as stx

from raygun.utils.antialias import Downsample as downsamp
from raygun.jax.utils.antialias import Downsample as downsamp



123 changes: 81 additions & 42 deletions raygun/jax/networks/UNet.py
@@ -14,22 +14,22 @@ def __init__(
activation,
padding='VALID',
residual=False,
# padding_mode='reflect',
norm_layer=None,
data_format='NCDHW'):

super().__init__()

# if activation is not None:
# if isinstance(activation, str):
# self.activation = getattr(jax.nn, activation)
# else:
# self.activation = activation # assume activation is a defined function
# else:
# self.activation = jax.numpy.identity
if activation is not None:
if isinstance(activation, str):
self.activation = getattr(jax.nn, activation)
else:
self.activation = activation() # assume activation is a defined function
else:
self.activation = jax.numpy.identity()

activation = getattr(jax.nn, activation)

self.residual = residual
self.padding = padding

layers = []
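The newly uncommented block resolves the activation at construction time: a string is looked up on jax.nn, a callable is treated as a zero-argument factory, and None falls back to the identity. A standalone sketch of the same dispatch; note that jax.numpy.identity builds an identity matrix (it takes a size argument), so the sketch substitutes a plain lambda for the None case:

import jax
import jax.numpy as jnp

def resolve_activation(activation):
    # Mirrors the dispatch above, outside the module (illustrative only)
    if activation is None:
        return lambda x: x                  # identity fallback
    if isinstance(activation, str):
        return getattr(jax.nn, activation)  # e.g. 'relu' -> jax.nn.relu
    return activation()                     # assumed zero-arg factory

act = resolve_activation('relu')
print(act(jnp.array([-1.0, 2.0])))  # [0. 2.]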

@@ -99,7 +99,7 @@ def crop(self, x, shape):

return x[slices]

def forward(self, x):
def __call__(self, x):
if not self.residual:
return self.conv_pass(x)
else:
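The forward → __call__ renames in this file are the heart of the fix: PyTorch's nn.Module.__call__ dispatches to forward, but hk.Module has no such hook, so a Haiku module that only defines forward is not callable and the network cannot be built. A toy illustration:

import haiku as hk
import jax
import jax.numpy as jnp

class WithForward(hk.Module):
    def forward(self, x):   # torch idiom; Haiku never dispatches to this
        return x + 1

class WithCall(hk.Module):
    def __call__(self, x):  # haiku idiom; makes module(x) work
        return x + 1

def fn(x):
    # WithForward()(x) would raise TypeError: object is not callable
    return WithCall()(x)

model = hk.without_apply_rng(hk.transform(fn))
params = model.init(jax.random.PRNGKey(0), jnp.ones(()))
print(model.apply(params, jnp.ones(())))  # 2.0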
@@ -109,7 +109,60 @@
else:
init_x = self.x_init_map(x)
return self.activation(init_x + res)
# class ConvPass(hk.Module):

# def __init__(
# self,
# out_channels,
# kernel_sizes,
# activation,
# padding='VALID',
# data_format='NCDHW'):

# super().__init__()

# if activation is not None:
# activation = getattr(jax.nn, activation)

# layers = []

# for kernel_size in kernel_sizes:

# self.dims = len(kernel_size)

# conv = {
# 2: hk.Conv2D,
# 3: hk.Conv3D,
# # 4: Conv4d # TODO
# }[self.dims]

# if data_format is None:
# in_data_format = {
# 2: 'NCHW',
# 3: 'NCDHW'
# }[self.dims]
# else:
# in_data_format = data_format

# try:
# layers.append(
# conv(
# output_channels=out_channels,
# kernel_shape=kernel_size,
# padding=padding,
# data_format=in_data_format))
# except KeyError:
# raise RuntimeError(
# "%dD convolution not implemented" % self.dims)

# if activation is not None:
# layers.append(activation)

# self.conv_pass = hk.Sequential(layers)

# def __call__(self, x):

# return self.conv_pass(x)

class ConvDownsample(hk.Module):

@@ -175,7 +228,7 @@ def __init__(
layers.append(self.activation)
self.conv_pass = hk.Sequential(layers)

def forward(self, x):
def __call__(self, x):
return self.conv_pass(x)


@@ -196,39 +249,26 @@
strides=downsample_factor,
padding='VALID')

# def forward(self, x):
# if self.flexible:
# try:
# return self.down(x)
# except:
# self.check_mismatch(x.size())
# else:
# self.check_mismatch(x.size())
# return self.down(x)

# def check_mismatch(self, size):
# for d in range(1, self.dims+1):
# if size[-d] % self.downsample_factor[-d] != 0:
# raise RuntimeError(
# "Can not downsample shape %s with factor %s, mismatch "
# "in spatial dimension %d" % (
# size,
# self.downsample_factor,
# self.dims - d))
# return self.down(size)
def __call__(self, x):
if self.flexible:
try:
return self.down(x)
except:
self.check_mismatch(x.size())
else:
self.check_mismatch(x.size())
return self.down(x)

for d in range(1, self.dims + 1):
if x.shape[-d] % self.downsample_factor[-d] != 0:
def check_mismatch(self, size):
for d in range(1, self.dims+1):
if size[-d] % self.downsample_factor[-d] != 0:
raise RuntimeError(
"Can not downsample shape %s with factor %s, mismatch "
"in spatial dimension %d" % (
x.shape,
size,
self.downsample_factor,
self.dims - d))

return self.down(x)
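ConvDownsample.__call__ now either trusts self.down outright when flexible is set, or validates the input first via check_mismatch: every spatial extent must divide evenly by its downsample factor, otherwise the strided convolution would silently drop voxels. The check in isolation, with arbitrary example shapes:

downsample_factor = (2, 2, 2)
shape = (1, 1, 48, 48, 47)  # NCDHW; the last spatial dim (47) fails for factor 2

dims = len(downsample_factor)
try:
    for d in range(1, dims + 1):
        if shape[-d] % downsample_factor[-d] != 0:
            raise RuntimeError(
                "Can not downsample shape %s with factor %s, mismatch "
                "in spatial dimension %d" % (shape, downsample_factor, dims - d))
except RuntimeError as e:
    print(e)  # ... mismatch in spatial dimension 2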



class Upsample(hk.Module):

@@ -357,13 +397,13 @@ def __init__(self,
num_heads=1,
constant_upsample=False,
downsample_method='max',
padding_type='valid',
padding_type='VALID',
residual=False,
norm_layer=None,
# add_noise=False
name=None
):

super().__init__()
super().__init__(name=name)
self.ndims = len(downsample_factors[0])
self.num_levels = len(downsample_factors) + 1
self.num_heads = num_heads
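Threading name through super().__init__(name=name) matters because Haiku scopes parameter names by module name. A toy module showing the effect (names are illustrative):

import haiku as hk
import jax
import jax.numpy as jnp

class Toy(hk.Module):
    def __init__(self, name=None):
        super().__init__(name=name)  # forward the name so Haiku can scope params

    def __call__(self, x):
        w = hk.get_parameter("w", x.shape[-1:], init=jnp.ones)
        return x * w

model = hk.transform(lambda x: Toy(name="half_model")(x))
params = model.init(jax.random.PRNGKey(0), jnp.ones((3,)))
print(list(params))  # ['half_model'] -- parameters live under the given scope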
@@ -411,7 +451,6 @@ def __init__(self,
norm_layer=norm_layer)
for level in range(self.num_levels)
]
self.dims = self.l_conv[0].dims

# Left downsample
if downsample_method.lower() == 'max':
@@ -499,7 +538,7 @@ def rec_forward(self, level, f_in, total_level):
return fs_out


def forward(self, x):
def __call__(self, x):

y = self.rec_forward(self.num_levels - 1, x, total_level=self.num_levels)

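Taken together, these changes let the UNet be constructed and called under hk.transform. A hedged sketch of end-to-end usage: downsample_factors appears in the hunks above, but any other required constructor arguments are assumptions and may need to be filled in:

import haiku as hk
import jax
import jax.numpy as jnp

from raygun.jax.networks import UNet  # exported by the networks __init__ below

def forward_fn(x):
    unet = UNet(downsample_factors=[(2, 2, 2), (2, 2, 2)])  # other args omitted
    return unet(x)

model = hk.transform(forward_fn)
x = jnp.ones((1, 1, 48, 48, 48))  # NCDHW, spatial dims divisible by all factors
params = model.init(jax.random.PRNGKey(0), x)
y = model.apply(params, None, x)  # rng=None: no randomness needed at apply time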
6 changes: 5 additions & 1 deletion raygun/jax/networks/__init__.py
@@ -1,3 +1,7 @@
"""
JAX network architectures
"""
"""
from .UNet import UNet
from .ResidualUNet import ResidualUNet
from .MIRNet2D import MIRNet
from .utils import *