Skip to content

Commit

Permalink
move membrane net to github
Browse files — browse the repository at this point in the history
  • Loading branch information
jhennies committed Mar 13, 2020
1 parent f543b44 commit 59bfb20
Show file tree
Hide file tree
Showing 19 changed files with 4,035 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,7 @@ dmypy.json

# Pyre type checker
.pyre/

# Pycharm
*.pyc
.idea/
Empty file added experiments/__init__.py
Empty file.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@

"""Training script: piled 3D U-Net for membrane prediction.

Loads raw and ground-truth volumes from h5 files, sets up augmenting
training/validation data generators, builds a three-stage PiledUnet and
trains it, writing model checkpoints and intermediate predictions to the
results folder.
"""

from pytorch.pytorch_tools.data_generation import parallel_data_generator
import os
from h5py import File
from pytorch.pytorch_tools.piled_unets import PiledUnet
from pytorch.pytorch_tools.losses import WeightMatrixWeightedBCE, CombinedLosses
from pytorch.pytorch_tools.training import train_model_with_generators, cb_save_model, cb_run_model_on_data

from torchsummary import summary
import numpy as np


experiment_name = 'unet3d_200311_00_membranes'
results_folder = os.path.join(
    '/g/schwab/hennies/phd_project/image_analysis/autoseg/cnn_3d_devel',
    'unet3d_200311_pytorch_for_membranes',
    experiment_name
)


def _load_channel_stacks(filepaths, membrane_last_channel=False):
    """Load the 'data' dataset of every h5 file in a nested [volume][channel] list.

    :param filepaths: list of lists of h5 file paths
    :param membrane_last_channel: if True, channel index 1 of each volume is
        treated as a membrane-prediction stack and only its last channel is
        loaded; all other channels are loaded in full
    :return: nested list of numpy arrays mirroring the input structure
    """
    channel_data = []
    for volumes in filepaths:
        volume_data = []
        for chid, channel in enumerate(volumes):
            # Close each file right after reading to avoid leaking handles
            with File(channel, 'r') as f:
                if membrane_last_channel and chid == 1:
                    # Specifically only load last channel of the membrane prediction
                    volume_data.append(f['data'][..., -1])
                else:
                    volume_data.append(f['data'][:])
        channel_data.append(volume_data)
    return channel_data


# Training data ---------------------------------------------------------------

raw_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/'
raw_filepaths = [
    [
        os.path.join(raw_path, 'raw.h5'),
    ],
]
gt_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/'
gt_filepaths = [
    [
        os.path.join(gt_path, 'gt_mem.h5'),
        os.path.join(gt_path, 'gt_mask_organelle_insides.h5')
    ],
]

raw_channels = _load_channel_stacks(raw_filepaths, membrane_last_channel=True)
gt_channels = _load_channel_stacks(gt_filepaths)

# Validation data -------------------------------------------------------------

val_raw_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/'
val_raw_filepaths = [
    [
        os.path.join(val_raw_path, 'val_raw_512.h5'),
    ],
    [
        os.path.join(
            '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/psp_200107_00_ds_20141002_hela_wt_xyz8nm_as_multiple_scales/step0_datasets/psp0_200108_02_select_test_and_val_cubes',
            'val0_x1390_y742_z345_pad.h5'
        )
    ]
]
val_gt_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/'
val_gt_filepaths = [
    [
        os.path.join(val_gt_path, 'val_gt_mem.h5'),
        os.path.join(val_gt_path, 'val_gt_mask_organelle_insides.h5')
    ]
]

val_raw_channels = _load_channel_stacks(val_raw_filepaths, membrane_last_channel=True)
val_gt_channels = _load_channel_stacks(val_gt_filepaths)

# Data generators -------------------------------------------------------------

data_gen_args = dict(
    rotation_range=180,  # Angle in degrees
    shear_range=20,  # Angle in degrees
    zoom_range=[0.8, 1.2],  # [0.75, 1.5]
    horizontal_flip=True,
    vertical_flip=True,
    noise_var_range=1e-1,
    random_smooth_range=[0.6, 1.5],
    smooth_output_sigma=0.5,
    displace_slices_range=2,
    fill_mode='reflect',
    cval=0,
    brightness_range=92,
    contrast_range=(0.5, 2)
)

# Validation uses only the output smoothing, no augmentation
aug_dict_preprocessing = dict(
    smooth_output_sigma=0.5
)

train_gen = parallel_data_generator(
    raw_channels=raw_channels,
    gt_channels=gt_channels,
    spacing=(32, 32, 32),
    area_size=(32, 512, 512),
    target_shape=(64, 64, 64),
    gt_target_shape=(64, 64, 64),
    stop_after_epoch=False,
    aug_dict=data_gen_args,
    transform_ratio=0.9,
    batch_size=2,
    shuffle=True,
    add_pad_mask=False,
    n_workers=16,
    noise_load_dict=None,
    gt_target_channels=None,
    areas_and_spacings=None,
    n_workers_noise=16,
    noise_on_channels=None,
    yield_epoch_info=True
)

val_gen = parallel_data_generator(
    # Only the first validation volume has ground truth; the second is used
    # by the prediction callbacks below
    raw_channels=val_raw_channels[:1],
    gt_channels=val_gt_channels,
    spacing=(64, 64, 64),
    area_size=(256, 256, 256),
    target_shape=(64, 64, 64),
    gt_target_shape=(64, 64, 64),
    stop_after_epoch=False,
    aug_dict=aug_dict_preprocessing,
    transform_ratio=0.,
    batch_size=1,
    shuffle=False,
    add_pad_mask=False,
    n_workers=16,
    gt_target_channels=None,
    yield_epoch_info=True
)

# Model -----------------------------------------------------------------------

model = PiledUnet(
    n_nets=3,
    in_channels=1,
    out_channels=[1, 1, 1],
    filter_sizes_down=(
        ((16, 32), (32, 64), (64, 128)),
        ((16, 32), (32, 64), (64, 128)),
        ((16, 32), (32, 64), (64, 128))
    ),
    filter_sizes_bottleneck=(
        (128, 256),
        (128, 256),
        (128, 256)
    ),
    filter_sizes_up=(
        ((256, 128, 128), (128, 64, 64), (64, 32, 32)),
        ((256, 128, 128), (128, 64, 64), (64, 32, 32)),
        ((256, 128, 128), (128, 64, 64), (64, 32, 32))
    ),
    batch_norm=True,
    output_activation='sigmoid'
)
model.cuda()
summary(model, (1, 64, 64, 64))

# makedirs is race-free and also creates missing parent directories
os.makedirs(results_folder, exist_ok=True)

# Training --------------------------------------------------------------------

train_model_with_generators(
    model,
    train_gen,
    val_gen,
    n_epochs=100,
    loss_func=CombinedLosses(
        # One weighted BCE per U-Net stage; later stages get a more balanced
        # class weighting and a larger share of the combined loss
        losses=(
            WeightMatrixWeightedBCE(((0.1, 0.9),)),
            WeightMatrixWeightedBCE(((0.2, 0.8),)),
            WeightMatrixWeightedBCE(((0.3, 0.7),))),
        y_pred_channels=(np.s_[:1], np.s_[1:2], np.s_[2:3]),
        y_true_channels=(np.s_[:], np.s_[:], np.s_[:]),
        weigh_losses=np.array([0.2, 0.3, 0.5])
    ),
    l2_reg_param=1e-5,
    callbacks=[
        cb_run_model_on_data(
            results_filepath=os.path.join(results_folder, 'improved_result1_{epoch:04d}.h5'),
            raw_channels=val_raw_channels[:1],
            spacing=(32, 32, 32),
            area_size=(64, 256, 256),
            target_shape=(64, 64, 64),
            num_result_channels=3,
            smooth_output_sigma=aug_dict_preprocessing['smooth_output_sigma'],
            n_workers=16,
            compute_empty_volumes=True,
            thresh=None,
            write_at_area=False,
            offset=None,
            full_dataset_shape=None,
            min_epoch=5
        ),
        cb_run_model_on_data(
            results_filepath=os.path.join(results_folder, 'improved_result2_{epoch:04d}.h5'),
            raw_channels=val_raw_channels[1:],
            spacing=(32, 32, 32),
            area_size=(64, 256, 256),
            target_shape=(64, 64, 64),
            num_result_channels=3,
            smooth_output_sigma=aug_dict_preprocessing['smooth_output_sigma'],
            n_workers=16,
            compute_empty_volumes=True,
            thresh=None,
            write_at_area=False,
            offset=None,
            full_dataset_shape=None,
            min_epoch=5
        ),
        cb_save_model(
            filepath=os.path.join(results_folder, 'model_{epoch:04d}.h5'),
            min_epoch=5
        )
    ],
    writer_path=os.path.join(results_folder, 'tb')
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@

"""Inference script: run a trained piled 3D U-Net membrane model on test cubes.

Loads the checkpoint from epoch 73 and predicts every h5 test sample in the
test-sample folder, writing one result file per input into the results folder.
"""

from pytorch.pytorch_tools.data_generation import parallel_data_generator
import os
from h5py import File
from pytorch.pytorch_tools.piled_unets import PiledUnet
from pytorch.pytorch_tools.losses import WeightMatrixWeightedBCE, CombinedLosses
from pytorch.pytorch_tools.training import train_model_with_generators, cb_save_model, cb_run_model_on_data
import torch as t

from torchsummary import summary
import numpy as np
from pytorch.pytorch_tools.run_models import predict_model_from_h5_parallel_generator
from glob import glob


experiment_name = 'unet3d_200311_00_membranes'
results_folder = os.path.join(
    '/g/schwab/hennies/phd_project/image_analysis/autoseg/cnn_3d_devel',
    'unet3d_200311_pytorch_for_membranes',
    experiment_name
)

# Same architecture as in training, but constructed in predict mode
model = PiledUnet(
    n_nets=3,
    in_channels=1,
    out_channels=[1, 1, 1],
    filter_sizes_down=(
        ((16, 32), (32, 64), (64, 128)),
        ((16, 32), (32, 64), (64, 128)),
        ((16, 32), (32, 64), (64, 128))
    ),
    filter_sizes_bottleneck=(
        (128, 256),
        (128, 256),
        (128, 256)
    ),
    filter_sizes_up=(
        ((256, 128, 128), (128, 64, 64), (64, 32, 32)),
        ((256, 128, 128), (128, 64, 64), (64, 32, 32)),
        ((256, 128, 128), (128, 64, 64), (64, 32, 32))
    ),
    batch_norm=True,
    output_activation='sigmoid',
    predict=True
)
model.cuda()
summary(model, (1, 64, 64, 64))

model.load_state_dict(t.load(os.path.join(results_folder, 'model_0073.h5')))
# NOTE(review): the model uses batch norm, so inference should run with the
# learned running statistics — presumably predict=True does not switch this
# internally; calling eval() here to be safe. Confirm against PiledUnet.
model.eval()

# makedirs is race-free and also creates missing parent directories
os.makedirs(results_folder, exist_ok=True)

aug_dict_preprocessing = dict(
    smooth_output_sigma=0.5
)

im_list = sorted(glob(os.path.join(
    '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/all_datasets_test_samples',
    '*.h5'
)))

with t.no_grad():
    for filepath in im_list:
        # Load the full volume and close the file before predicting
        with File(filepath, mode='r') as f:
            area_size = list(f['data'].shape)
            # Cap the processing area at 256 voxels along each spatial axis
            area_size[:3] = [min(s, 256) for s in area_size[:3]]
            channels = [[f['data'][:]]]

        predict_model_from_h5_parallel_generator(
            model=model,
            results_filepath=os.path.join(results_folder, os.path.split(filepath)[1]),
            raw_channels=channels,
            spacing=(32, 32, 32),
            area_size=area_size,
            target_shape=(64, 64, 64),
            num_result_channels=1,
            smooth_output_sigma=aug_dict_preprocessing['smooth_output_sigma'],
            n_workers=16,
            compute_empty_volumes=True,
            thresh=None,
            write_at_area=False,
            offset=None,
            full_dataset_shape=None
        )
Loading

0 comments on commit 59bfb20

Please sign in to comment.