diff --git a/.gitignore b/.gitignore index b6e4761..0131b77 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,7 @@ dmypy.json # Pyre type checker .pyre/ + +# Pycharm +*.pyc +.idea/ diff --git a/experiments/__init__.py b/experiments/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/experiments/unet3d_200311_pytorch_for_membranes/__init__.py b/experiments/unet3d_200311_pytorch_for_membranes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/experiments/unet3d_200311_pytorch_for_membranes/unet3d_200311_00_membranes.py b/experiments/unet3d_200311_pytorch_for_membranes/unet3d_200311_00_membranes.py new file mode 100644 index 0000000..21fed37 --- /dev/null +++ b/experiments/unet3d_200311_pytorch_for_membranes/unet3d_200311_00_membranes.py @@ -0,0 +1,222 @@ + +from pytorch.pytorch_tools.data_generation import parallel_data_generator +import os +from h5py import File +from pytorch.pytorch_tools.piled_unets import PiledUnet +from pytorch.pytorch_tools.losses import WeightMatrixWeightedBCE, CombinedLosses +from pytorch.pytorch_tools.training import train_model_with_generators, cb_save_model, cb_run_model_on_data + +from torchsummary import summary +import numpy as np + + +experiment_name = 'unet3d_200311_00_membranes' +results_folder = os.path.join( + '/g/schwab/hennies/phd_project/image_analysis/autoseg/cnn_3d_devel', + 'unet3d_200311_pytorch_for_membranes', + experiment_name +) + +if True: + + raw_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/' + raw_filepaths = [ + [ + os.path.join(raw_path, 'raw.h5'), + ], + ] + gt_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/' + gt_filepaths = [ + [ + os.path.join(gt_path, 'gt_mem.h5'), + os.path.join(gt_path, 'gt_mask_organelle_insides.h5') + ], + ] + raw_channels = [] + for volumes in raw_filepaths: + raws_data = [] + for chid, channel in enumerate(volumes): + if chid == 1: + # Specifically only load last channel of the membrane prediction + raws_data.append(File(channel, 'r')['data'][..., -1]) + else: + raws_data.append(File(channel, 'r')['data'][:]) + raw_channels.append(raws_data) + gt_channels = [[File(channel, 'r')['data'][:] for channel in volumes] for volumes in gt_filepaths] + + val_raw_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/' + val_raw_filepaths = [ + [ + os.path.join(val_raw_path, 'val_raw_512.h5'), + ], + [ + os.path.join( + '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/psp_200107_00_ds_20141002_hela_wt_xyz8nm_as_multiple_scales/step0_datasets/psp0_200108_02_select_test_and_val_cubes', + 'val0_x1390_y742_z345_pad.h5' + + ) + ] + ] + val_gt_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/' + val_gt_filepaths = [ + [ + os.path.join(val_gt_path, 'val_gt_mem.h5'), + os.path.join(val_gt_path, 'val_gt_mask_organelle_insides.h5') + ] + ] + val_raw_channels = [] + for volumes in val_raw_filepaths: + val_raws_data = [] + for chid, channel in enumerate(volumes): + if chid == 1: + # Specifically only load last channel of the membrane prediction + val_raws_data.append(File(channel, 'r')['data'][..., -1]) + else: + val_raws_data.append(File(channel, 'r')['data'][:]) + val_raw_channels.append(val_raws_data) + val_gt_channels = [[File(channel, 'r')['data'][:] for channel in volumes] for volumes in val_gt_filepaths] + +if True: + data_gen_args = dict( + rotation_range=180, # Angle in degrees + shear_range=20, # Angle in degrees + zoom_range=[0.8, 1.2], # [0.75, 1.5] + horizontal_flip=True, + vertical_flip=True, + noise_var_range=1e-1, + random_smooth_range=[0.6, 1.5], + smooth_output_sigma=0.5, + displace_slices_range=2, + fill_mode='reflect', + cval=0, + brightness_range=92, + contrast_range=(0.5, 2) + ) + + aug_dict_preprocessing = dict( + smooth_output_sigma=0.5 + ) + + train_gen = parallel_data_generator( + raw_channels=raw_channels, + gt_channels=gt_channels, + spacing=(32, 32, 32), + area_size=(32, 512, 512), + target_shape=(64, 64, 64), + gt_target_shape=(64, 64, 64), + stop_after_epoch=False, + aug_dict=data_gen_args, + transform_ratio=0.9, + batch_size=2, + shuffle=True, + add_pad_mask=False, + n_workers=16, + noise_load_dict=None, + gt_target_channels=None, + areas_and_spacings=None, + n_workers_noise=16, + noise_on_channels=None, + yield_epoch_info=True + ) + + val_gen = parallel_data_generator( + raw_channels=val_raw_channels[:1], + gt_channels=val_gt_channels, + spacing=(64, 64, 64), + area_size=(256, 256, 256), + target_shape=(64, 64, 64), + gt_target_shape=(64, 64, 64), + stop_after_epoch=False, + aug_dict=aug_dict_preprocessing, + transform_ratio=0., + batch_size=1, + shuffle=False, + add_pad_mask=False, + n_workers=16, + gt_target_channels=None, + yield_epoch_info=True + ) + +model = PiledUnet( + n_nets=3, + in_channels=1, + out_channels=[1, 1, 1], + filter_sizes_down=( + ((16, 32), (32, 64), (64, 128)), + ((16, 32), (32, 64), (64, 128)), + ((16, 32), (32, 64), (64, 128)) + ), + filter_sizes_bottleneck=( + (128, 256), + (128, 256), + (128, 256) + ), + filter_sizes_up=( + ((256, 128, 128), (128, 64, 64), (64, 32, 32)), + ((256, 128, 128), (128, 64, 64), (64, 32, 32)), + ((256, 128, 128), (128, 64, 64), (64, 32, 32)) + ), + batch_norm=True, + output_activation='sigmoid' +) +model.cuda() +summary(model, (1, 64, 64, 64)) + +if not os.path.exists(results_folder): + os.mkdir(results_folder) + +train_model_with_generators( + model, + train_gen, + val_gen, + n_epochs=100, + loss_func=CombinedLosses( + losses=( + WeightMatrixWeightedBCE(((0.1, 0.9),)), + WeightMatrixWeightedBCE(((0.2, 0.8),)), + WeightMatrixWeightedBCE(((0.3, 0.7),))), + y_pred_channels=(np.s_[:1], np.s_[1:2], np.s_[2:3]), + y_true_channels=(np.s_[:], np.s_[:], np.s_[:]), + weigh_losses=np.array([0.2, 0.3, 0.5]) + ), + l2_reg_param=1e-5, + callbacks=[ + cb_run_model_on_data( + results_filepath=os.path.join(results_folder, 'improved_result1_{epoch:04d}.h5'), + raw_channels=val_raw_channels[:1], + spacing=(32, 32, 32), + area_size=(64, 256, 256), + target_shape=(64, 64, 64), + num_result_channels=3, + smooth_output_sigma=aug_dict_preprocessing['smooth_output_sigma'], + n_workers=16, + compute_empty_volumes=True, + thresh=None, + write_at_area=False, + offset=None, + full_dataset_shape=None, + min_epoch=5 + ), + cb_run_model_on_data( + results_filepath=os.path.join(results_folder, 'improved_result2_{epoch:04d}.h5'), + raw_channels=val_raw_channels[1:], + spacing=(32, 32, 32), + area_size=(64, 256, 256), + target_shape=(64, 64, 64), + num_result_channels=3, + smooth_output_sigma=aug_dict_preprocessing['smooth_output_sigma'], + n_workers=16, + compute_empty_volumes=True, + thresh=None, + write_at_area=False, + offset=None, + full_dataset_shape=None, + min_epoch=5 + ), + cb_save_model( + filepath=os.path.join(results_folder, 'model_{epoch:04d}.h5'), + min_epoch=5 + ) + ], + writer_path=os.path.join(results_folder, 'tb') +) diff --git a/experiments/unet3d_200311_pytorch_for_membranes/unet3d_200311_00b_membranes_run_on_test_sets.py b/experiments/unet3d_200311_pytorch_for_membranes/unet3d_200311_00b_membranes_run_on_test_sets.py new file mode 100644 index 0000000..f639a8e --- /dev/null +++ b/experiments/unet3d_200311_pytorch_for_membranes/unet3d_200311_00b_membranes_run_on_test_sets.py @@ -0,0 +1,90 @@ + +from pytorch.pytorch_tools.data_generation import parallel_data_generator +import os +from h5py import File +from pytorch.pytorch_tools.piled_unets import PiledUnet +from pytorch.pytorch_tools.losses import WeightMatrixWeightedBCE, CombinedLosses +from pytorch.pytorch_tools.training import train_model_with_generators, cb_save_model, cb_run_model_on_data +import torch as t + +from torchsummary import summary +import numpy as np +from pytorch.pytorch_tools.run_models import predict_model_from_h5_parallel_generator +from glob import glob + + +experiment_name = 'unet3d_200311_00_membranes' +results_folder = os.path.join( + '/g/schwab/hennies/phd_project/image_analysis/autoseg/cnn_3d_devel', + 'unet3d_200311_pytorch_for_membranes', + experiment_name +) + +model = PiledUnet( + n_nets=3, + in_channels=1, + out_channels=[1, 1, 1], + filter_sizes_down=( + ((16, 32), (32, 64), (64, 128)), + ((16, 32), (32, 64), (64, 128)), + ((16, 32), (32, 64), (64, 128)) + ), + filter_sizes_bottleneck=( + (128, 256), + (128, 256), + (128, 256) + ), + filter_sizes_up=( + ((256, 128, 128), (128, 64, 64), (64, 32, 32)), + ((256, 128, 128), (128, 64, 64), (64, 32, 32)), + ((256, 128, 128), (128, 64, 64), (64, 32, 32)) + ), + batch_norm=True, + output_activation='sigmoid', + predict=True +) +model.cuda() +summary(model, (1, 64, 64, 64)) + +model.load_state_dict(t.load(os.path.join(results_folder, 'model_0073.h5'))) + +if not os.path.exists(results_folder): + os.mkdir(results_folder) + +aug_dict_preprocessing = dict( + smooth_output_sigma=0.5 +) + +im_list = sorted(glob(os.path.join( + '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/all_datasets_test_samples', + '*.h5' +))) + +with t.no_grad(): + for filepath in im_list: + with File(filepath, mode='r') as f: + area_size = list(f['data'].shape) + if area_size[0] > 256: + area_size[0] = 256 + if area_size[1] > 256: + area_size[1] = 256 + if area_size[2] > 256: + area_size[2] = 256 + channels = [[f['data'][:]]] + + predict_model_from_h5_parallel_generator( + model=model, + results_filepath=os.path.join(results_folder, os.path.split(filepath)[1]), + raw_channels=channels, + spacing=(32, 32, 32), + area_size=area_size, + target_shape=(64, 64, 64), + num_result_channels=1, + smooth_output_sigma=aug_dict_preprocessing['smooth_output_sigma'], + n_workers=16, + compute_empty_volumes=True, + thresh=None, + write_at_area=False, + offset=None, + full_dataset_shape=None + ) diff --git a/experiments/unet3d_200311_pytorch_for_membranes/unet3d_200312_00_membranes_epochs300.py b/experiments/unet3d_200311_pytorch_for_membranes/unet3d_200312_00_membranes_epochs300.py new file mode 100644 index 0000000..353a057 --- /dev/null +++ b/experiments/unet3d_200311_pytorch_for_membranes/unet3d_200312_00_membranes_epochs300.py @@ -0,0 +1,223 @@ + +from pytorch.pytorch_tools.data_generation import parallel_data_generator +import os +from h5py import File +from pytorch.pytorch_tools.piled_unets import PiledUnet +from pytorch.pytorch_tools.losses import WeightMatrixWeightedBCE, CombinedLosses +from pytorch.pytorch_tools.training import train_model_with_generators, cb_save_model, cb_run_model_on_data + +from torchsummary import summary +import numpy as np + + +experiment_name = 'unet3d_200312_00_membranes_epochs300' +results_folder = os.path.join( + '/g/schwab/hennies/phd_project/image_analysis/autoseg/cnn_3d_devel', + 'unet3d_200311_pytorch_for_membranes', + experiment_name +) + +if True: + + raw_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/' + raw_filepaths = [ + [ + os.path.join(raw_path, 'raw.h5'), + ], + ] + gt_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/' + gt_filepaths = [ + [ + os.path.join(gt_path, 'gt_mem.h5'), + os.path.join(gt_path, 'gt_mask_organelle_insides.h5') + ], + ] + raw_channels = [] + for volumes in raw_filepaths: + raws_data = [] + for chid, channel in enumerate(volumes): + if chid == 1: + # Specifically only load last channel of the membrane prediction + raws_data.append(File(channel, 'r')['data'][..., -1]) + else: + raws_data.append(File(channel, 'r')['data'][:]) + raw_channels.append(raws_data) + gt_channels = [[File(channel, 'r')['data'][:] for channel in volumes] for volumes in gt_filepaths] + + val_raw_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/' + val_raw_filepaths = [ + [ + os.path.join(val_raw_path, 'val_raw_512.h5'), + ], + [ + os.path.join( + '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/psp_200107_00_ds_20141002_hela_wt_xyz8nm_as_multiple_scales/step0_datasets/psp0_200108_02_select_test_and_val_cubes', + 'val0_x1390_y742_z345_pad.h5' + + ) + ] + ] + val_gt_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/' + val_gt_filepaths = [ + [ + os.path.join(val_gt_path, 'val_gt_mem.h5'), + os.path.join(val_gt_path, 'val_gt_mask_organelle_insides.h5') + ] + ] + val_raw_channels = [] + for volumes in val_raw_filepaths: + val_raws_data = [] + for chid, channel in enumerate(volumes): + if chid == 1: + # Specifically only load last channel of the membrane prediction + val_raws_data.append(File(channel, 'r')['data'][..., -1]) + else: + val_raws_data.append(File(channel, 'r')['data'][:]) + val_raw_channels.append(val_raws_data) + val_gt_channels = [[File(channel, 'r')['data'][:] for channel in volumes] for volumes in val_gt_filepaths] + +if True: + data_gen_args = dict( + rotation_range=180, # Angle in degrees + shear_range=20, # Angle in degrees + zoom_range=[0.8, 1.2], # [0.75, 1.5] + horizontal_flip=True, + vertical_flip=True, + noise_var_range=1e-1, + random_smooth_range=[0.6, 1.5], + smooth_output_sigma=0.5, + displace_slices_range=2, + fill_mode='reflect', + cval=0, + brightness_range=92, + contrast_range=(0.5, 2) + ) + + aug_dict_preprocessing = dict( + smooth_output_sigma=0.5 + ) + + train_gen = parallel_data_generator( + raw_channels=raw_channels, + gt_channels=gt_channels, + spacing=(32, 32, 32), + area_size=(32, 512, 512), + # area_size=(32, 128, 128), + target_shape=(64, 64, 64), + gt_target_shape=(64, 64, 64), + stop_after_epoch=False, + aug_dict=data_gen_args, + transform_ratio=0.9, + batch_size=2, + shuffle=True, + add_pad_mask=False, + n_workers=16, + noise_load_dict=None, + gt_target_channels=None, + areas_and_spacings=None, + n_workers_noise=16, + noise_on_channels=None, + yield_epoch_info=True + ) + + val_gen = parallel_data_generator( + raw_channels=val_raw_channels[:1], + gt_channels=val_gt_channels, + spacing=(64, 64, 64), + area_size=(256, 256, 256), + target_shape=(64, 64, 64), + gt_target_shape=(64, 64, 64), + stop_after_epoch=False, + aug_dict=aug_dict_preprocessing, + transform_ratio=0., + batch_size=1, + shuffle=False, + add_pad_mask=False, + n_workers=16, + gt_target_channels=None, + yield_epoch_info=True + ) + +model = PiledUnet( + n_nets=3, + in_channels=1, + out_channels=[1, 1, 1], + filter_sizes_down=( + ((16, 32), (32, 64), (64, 128)), + ((16, 32), (32, 64), (64, 128)), + ((16, 32), (32, 64), (64, 128)) + ), + filter_sizes_bottleneck=( + (128, 256), + (128, 256), + (128, 256) + ), + filter_sizes_up=( + ((256, 128, 128), (128, 64, 64), (64, 32, 32)), + ((256, 128, 128), (128, 64, 64), (64, 32, 32)), + ((256, 128, 128), (128, 64, 64), (64, 32, 32)) + ), + batch_norm=True, + output_activation='sigmoid' +) +model.cuda() +summary(model, (1, 64, 64, 64)) + +if not os.path.exists(results_folder): + os.mkdir(results_folder) + +train_model_with_generators( + model, + train_gen, + val_gen, + n_epochs=300, + loss_func=CombinedLosses( + losses=( + WeightMatrixWeightedBCE(((0.1, 0.9),)), + WeightMatrixWeightedBCE(((0.2, 0.8),)), + WeightMatrixWeightedBCE(((0.3, 0.7),))), + y_pred_channels=(np.s_[:1], np.s_[1:2], np.s_[2:3]), + y_true_channels=(np.s_[:], np.s_[:], np.s_[:]), + weigh_losses=np.array([0.2, 0.3, 0.5]) + ), + l2_reg_param=1e-5, + callbacks=[ + cb_run_model_on_data( + results_filepath=os.path.join(results_folder, 'improved_result1_{epoch:04d}.h5'), + raw_channels=val_raw_channels[:1], + spacing=(32, 32, 32), + area_size=(64, 256, 256), + target_shape=(64, 64, 64), + num_result_channels=3, + smooth_output_sigma=aug_dict_preprocessing['smooth_output_sigma'], + n_workers=16, + compute_empty_volumes=True, + thresh=None, + write_at_area=False, + offset=None, + full_dataset_shape=None, + min_epoch=5 + ), + cb_run_model_on_data( + results_filepath=os.path.join(results_folder, 'improved_result2_{epoch:04d}.h5'), + raw_channels=val_raw_channels[1:], + spacing=(32, 32, 32), + area_size=(64, 256, 256), + target_shape=(64, 64, 64), + num_result_channels=3, + smooth_output_sigma=aug_dict_preprocessing['smooth_output_sigma'], + n_workers=16, + compute_empty_volumes=True, + thresh=None, + write_at_area=False, + offset=None, + full_dataset_shape=None, + min_epoch=5 + ), + cb_save_model( + filepath=os.path.join(results_folder, 'model_{epoch:04d}.h5'), + min_epoch=5 + ) + ], + writer_path=os.path.join(results_folder, 'tb') +) diff --git a/experiments/unet3d_200311_pytorch_for_membranes/unet3d_200312_00b_run_on_test_sets.py b/experiments/unet3d_200311_pytorch_for_membranes/unet3d_200312_00b_run_on_test_sets.py new file mode 100644 index 0000000..d1790d2 --- /dev/null +++ b/experiments/unet3d_200311_pytorch_for_membranes/unet3d_200312_00b_run_on_test_sets.py @@ -0,0 +1,94 @@ + +from pytorch.pytorch_tools.data_generation import parallel_data_generator +import os +from h5py import File +from pytorch.pytorch_tools.piled_unets import PiledUnet +from pytorch.pytorch_tools.losses import WeightMatrixWeightedBCE, CombinedLosses +from pytorch.pytorch_tools.training import train_model_with_generators, cb_save_model, cb_run_model_on_data +import torch as t + +from torchsummary import summary +import numpy as np +from pytorch.pytorch_tools.run_models import predict_model_from_h5_parallel_generator +from glob import glob + + +experiment_name = 'unet3d_200312_00_membranes_epochs300' +net_folder = os.path.join( + '/g/schwab/hennies/phd_project/image_analysis/autoseg/cnn_3d_devel', + 'unet3d_200311_pytorch_for_membranes', + experiment_name +) +results_folder = os.path.join( + net_folder, + 'results_0036' +) + +model = PiledUnet( + n_nets=3, + in_channels=1, + out_channels=[1, 1, 1], + filter_sizes_down=( + ((16, 32), (32, 64), (64, 128)), + ((16, 32), (32, 64), (64, 128)), + ((16, 32), (32, 64), (64, 128)) + ), + filter_sizes_bottleneck=( + (128, 256), + (128, 256), + (128, 256) + ), + filter_sizes_up=( + ((256, 128, 128), (128, 64, 64), (64, 32, 32)), + ((256, 128, 128), (128, 64, 64), (64, 32, 32)), + ((256, 128, 128), (128, 64, 64), (64, 32, 32)) + ), + batch_norm=True, + output_activation='sigmoid', + predict=True +) +model.cuda() +summary(model, (1, 64, 64, 64)) + +model.load_state_dict(t.load(os.path.join(net_folder, 'model_0036.h5'))) + +if not os.path.exists(results_folder): + os.mkdir(results_folder) + +aug_dict_preprocessing = dict( + smooth_output_sigma=0.5 +) + +im_list = sorted(glob(os.path.join( + '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/all_datasets_test_samples', + '*.h5' +))) + +with t.no_grad(): + for filepath in im_list: + with File(filepath, mode='r') as f: + area_size = list(f['data'].shape) + if area_size[0] > 256: + area_size[0] = 256 + if area_size[1] > 256: + area_size[1] = 256 + if area_size[2] > 256: + area_size[2] = 256 + channels = [[f['data'][:]]] + + predict_model_from_h5_parallel_generator( + model=model, + results_filepath=os.path.join(results_folder, os.path.split(filepath)[1]), + raw_channels=channels, + spacing=(32, 32, 32), + area_size=area_size, + target_shape=(64, 64, 64), + num_result_channels=1, + smooth_output_sigma=aug_dict_preprocessing['smooth_output_sigma'], + n_workers=16, + compute_empty_volumes=True, + thresh=None, + write_at_area=False, + offset=None, + full_dataset_shape=None + ) diff --git a/experiments/unet3d_200311_pytorch_for_membranes/unet3d_200312_01_membranes_epochs100_weigh_with_matrix_sum.py b/experiments/unet3d_200311_pytorch_for_membranes/unet3d_200312_01_membranes_epochs100_weigh_with_matrix_sum.py new file mode 100644 index 0000000..3b64b7f --- /dev/null +++ b/experiments/unet3d_200311_pytorch_for_membranes/unet3d_200312_01_membranes_epochs100_weigh_with_matrix_sum.py @@ -0,0 +1,223 @@ + +from pytorch.pytorch_tools.data_generation import parallel_data_generator +import os +from h5py import File +from pytorch.pytorch_tools.piled_unets import PiledUnet +from pytorch.pytorch_tools.losses import WeightMatrixWeightedBCE, CombinedLosses +from pytorch.pytorch_tools.training import train_model_with_generators, cb_save_model, cb_run_model_on_data + +from torchsummary import summary +import numpy as np + + +experiment_name = 'unet3d_200312_01_membranes_epochs100_weigh_with_matrix_sum' +results_folder = os.path.join( + '/g/schwab/hennies/phd_project/image_analysis/autoseg/cnn_3d_devel', + 'unet3d_200311_pytorch_for_membranes', + experiment_name +) + +if True: + + raw_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/' + raw_filepaths = [ + [ + os.path.join(raw_path, 'raw.h5'), + ], + ] + gt_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/' + gt_filepaths = [ + [ + os.path.join(gt_path, 'gt_mem.h5'), + os.path.join(gt_path, 'gt_mask_organelle_insides.h5') + ], + ] + raw_channels = [] + for volumes in raw_filepaths: + raws_data = [] + for chid, channel in enumerate(volumes): + if chid == 1: + # Specifically only load last channel of the membrane prediction + raws_data.append(File(channel, 'r')['data'][..., -1]) + else: + raws_data.append(File(channel, 'r')['data'][:]) + raw_channels.append(raws_data) + gt_channels = [[File(channel, 'r')['data'][:] for channel in volumes] for volumes in gt_filepaths] + + val_raw_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/' + val_raw_filepaths = [ + [ + os.path.join(val_raw_path, 'val_raw_512.h5'), + ], + [ + os.path.join( + '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/psp_200107_00_ds_20141002_hela_wt_xyz8nm_as_multiple_scales/step0_datasets/psp0_200108_02_select_test_and_val_cubes', + 'val0_x1390_y742_z345_pad.h5' + + ) + ] + ] + val_gt_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/' + val_gt_filepaths = [ + [ + os.path.join(val_gt_path, 'val_gt_mem.h5'), + os.path.join(val_gt_path, 'val_gt_mask_organelle_insides.h5') + ] + ] + val_raw_channels = [] + for volumes in val_raw_filepaths: + val_raws_data = [] + for chid, channel in enumerate(volumes): + if chid == 1: + # Specifically only load last channel of the membrane prediction + val_raws_data.append(File(channel, 'r')['data'][..., -1]) + else: + val_raws_data.append(File(channel, 'r')['data'][:]) + val_raw_channels.append(val_raws_data) + val_gt_channels = [[File(channel, 'r')['data'][:] for channel in volumes] for volumes in val_gt_filepaths] + +if True: + data_gen_args = dict( + rotation_range=180, # Angle in degrees + shear_range=20, # Angle in degrees + zoom_range=[0.8, 1.2], # [0.75, 1.5] + horizontal_flip=True, + vertical_flip=True, + noise_var_range=1e-1, + random_smooth_range=[0.6, 1.5], + smooth_output_sigma=0.5, + displace_slices_range=2, + fill_mode='reflect', + cval=0, + brightness_range=92, + contrast_range=(0.5, 2) + ) + + aug_dict_preprocessing = dict( + smooth_output_sigma=0.5 + ) + + train_gen = parallel_data_generator( + raw_channels=raw_channels, + gt_channels=gt_channels, + spacing=(32, 32, 32), + area_size=(32, 512, 512), + # area_size=(32, 128, 128), + target_shape=(64, 64, 64), + gt_target_shape=(64, 64, 64), + stop_after_epoch=False, + aug_dict=data_gen_args, + transform_ratio=0.9, + batch_size=2, + shuffle=True, + add_pad_mask=False, + n_workers=16, + noise_load_dict=None, + gt_target_channels=None, + areas_and_spacings=None, + n_workers_noise=16, + noise_on_channels=None, + yield_epoch_info=True + ) + + val_gen = parallel_data_generator( + raw_channels=val_raw_channels[:1], + gt_channels=val_gt_channels, + spacing=(64, 64, 64), + area_size=(256, 256, 256), + target_shape=(64, 64, 64), + gt_target_shape=(64, 64, 64), + stop_after_epoch=False, + aug_dict=aug_dict_preprocessing, + transform_ratio=0., + batch_size=1, + shuffle=False, + add_pad_mask=False, + n_workers=16, + gt_target_channels=None, + yield_epoch_info=True + ) + +model = PiledUnet( + n_nets=3, + in_channels=1, + out_channels=[1, 1, 1], + filter_sizes_down=( + ((16, 32), (32, 64), (64, 128)), + ((16, 32), (32, 64), (64, 128)), + ((16, 32), (32, 64), (64, 128)) + ), + filter_sizes_bottleneck=( + (128, 256), + (128, 256), + (128, 256) + ), + filter_sizes_up=( + ((256, 128, 128), (128, 64, 64), (64, 32, 32)), + ((256, 128, 128), (128, 64, 64), (64, 32, 32)), + ((256, 128, 128), (128, 64, 64), (64, 32, 32)) + ), + batch_norm=True, + output_activation='sigmoid' +) +model.cuda() +summary(model, (1, 64, 64, 64)) + +if not os.path.exists(results_folder): + os.mkdir(results_folder) + +train_model_with_generators( + model, + train_gen, + val_gen, + n_epochs=100, + loss_func=CombinedLosses( + losses=( + WeightMatrixWeightedBCE(((0.1, 0.9),), weigh_with_matrix_sum=True), + WeightMatrixWeightedBCE(((0.2, 0.8),), weigh_with_matrix_sum=True), + WeightMatrixWeightedBCE(((0.3, 0.7),), weigh_with_matrix_sum=True)), + y_pred_channels=(np.s_[:1], np.s_[1:2], np.s_[2:3]), + y_true_channels=(np.s_[:], np.s_[:], np.s_[:]), + weigh_losses=np.array([0.2, 0.3, 0.5]) + ), + l2_reg_param=1e-5, + callbacks=[ + cb_run_model_on_data( + results_filepath=os.path.join(results_folder, 'improved_result1_{epoch:04d}.h5'), + raw_channels=val_raw_channels[:1], + spacing=(32, 32, 32), + area_size=(64, 256, 256), + target_shape=(64, 64, 64), + num_result_channels=3, + smooth_output_sigma=aug_dict_preprocessing['smooth_output_sigma'], + n_workers=16, + compute_empty_volumes=True, + thresh=None, + write_at_area=False, + offset=None, + full_dataset_shape=None, + min_epoch=5 + ), + cb_run_model_on_data( + results_filepath=os.path.join(results_folder, 'improved_result2_{epoch:04d}.h5'), + raw_channels=val_raw_channels[1:], + spacing=(32, 32, 32), + area_size=(64, 256, 256), + target_shape=(64, 64, 64), + num_result_channels=3, + smooth_output_sigma=aug_dict_preprocessing['smooth_output_sigma'], + n_workers=16, + compute_empty_volumes=True, + thresh=None, + write_at_area=False, + offset=None, + full_dataset_shape=None, + min_epoch=5 + ), + cb_save_model( + filepath=os.path.join(results_folder, 'model_{epoch:04d}.h5'), + min_epoch=5 + ) + ], + writer_path=os.path.join(results_folder, 'tb') +) diff --git a/experiments/unet3d_200311_pytorch_for_membranes/unet3d_200313_00_membranes_epochs100_weigh_with_matrix_avg_flipz.py b/experiments/unet3d_200311_pytorch_for_membranes/unet3d_200313_00_membranes_epochs100_weigh_with_matrix_avg_flipz.py new file mode 100644 index 0000000..72a949c --- /dev/null +++ b/experiments/unet3d_200311_pytorch_for_membranes/unet3d_200313_00_membranes_epochs100_weigh_with_matrix_avg_flipz.py @@ -0,0 +1,224 @@ + +from pytorch.pytorch_tools.data_generation import parallel_data_generator +import os +from h5py import File +from pytorch.pytorch_tools.piled_unets import PiledUnet +from pytorch.pytorch_tools.losses import WeightMatrixWeightedBCE, CombinedLosses +from pytorch.pytorch_tools.training import train_model_with_generators, cb_save_model, cb_run_model_on_data + +from torchsummary import summary +import numpy as np + + +experiment_name = 'unet3d_200313_00_membranes_epochs100_weigh_with_matrix_avg_flipz' +results_folder = os.path.join( + '/g/schwab/hennies/phd_project/image_analysis/autoseg/cnn_3d_devel', + 'unet3d_200311_pytorch_for_membranes', + experiment_name +) + +if True: + + raw_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/' + raw_filepaths = [ + [ + os.path.join(raw_path, 'raw.h5'), + ], + ] + gt_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/' + gt_filepaths = [ + [ + os.path.join(gt_path, 'gt_mem.h5'), + os.path.join(gt_path, 'gt_mask_organelle_insides.h5') + ], + ] + raw_channels = [] + for volumes in raw_filepaths: + raws_data = [] + for chid, channel in enumerate(volumes): + if chid == 1: + # Specifically only load last channel of the membrane prediction + raws_data.append(File(channel, 'r')['data'][..., -1]) + else: + raws_data.append(File(channel, 'r')['data'][:]) + raw_channels.append(raws_data) + gt_channels = [[File(channel, 'r')['data'][:] for channel in volumes] for volumes in gt_filepaths] + + val_raw_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/' + val_raw_filepaths = [ + [ + os.path.join(val_raw_path, 'val_raw_512.h5'), + ], + [ + os.path.join( + '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/psp_200107_00_ds_20141002_hela_wt_xyz8nm_as_multiple_scales/step0_datasets/psp0_200108_02_select_test_and_val_cubes', + 'val0_x1390_y742_z345_pad.h5' + + ) + ] + ] + val_gt_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/' + val_gt_filepaths = [ + [ + os.path.join(val_gt_path, 'val_gt_mem.h5'), + os.path.join(val_gt_path, 'val_gt_mask_organelle_insides.h5') + ] + ] + val_raw_channels = [] + for volumes in val_raw_filepaths: + val_raws_data = [] + for chid, channel in enumerate(volumes): + if chid == 1: + # Specifically only load last channel of the membrane prediction + val_raws_data.append(File(channel, 'r')['data'][..., -1]) + else: + val_raws_data.append(File(channel, 'r')['data'][:]) + val_raw_channels.append(val_raws_data) + val_gt_channels = [[File(channel, 'r')['data'][:] for channel in volumes] for volumes in val_gt_filepaths] + +if True: + data_gen_args = dict( + rotation_range=180, # Angle in degrees + shear_range=20, # Angle in degrees + zoom_range=[0.8, 1.2], # [0.75, 1.5] + horizontal_flip=True, + vertical_flip=True, + depth_flip=True, + noise_var_range=1e-1, + random_smooth_range=[0.6, 1.5], + smooth_output_sigma=0.5, + displace_slices_range=2, + fill_mode='reflect', + cval=0, + brightness_range=92, + contrast_range=(0.5, 2) + ) + + aug_dict_preprocessing = dict( + smooth_output_sigma=0.5 + ) + + train_gen = parallel_data_generator( + raw_channels=raw_channels, + gt_channels=gt_channels, + spacing=(32, 32, 32), + area_size=(32, 512, 512), + # area_size=(32, 128, 128), + target_shape=(64, 64, 64), + gt_target_shape=(64, 64, 64), + stop_after_epoch=False, + aug_dict=data_gen_args, + transform_ratio=0.9, + batch_size=2, + shuffle=True, + add_pad_mask=False, + n_workers=16, + noise_load_dict=None, + gt_target_channels=None, + areas_and_spacings=None, + n_workers_noise=16, + noise_on_channels=None, + yield_epoch_info=True + ) + + val_gen = parallel_data_generator( + raw_channels=val_raw_channels[:1], + gt_channels=val_gt_channels, + spacing=(64, 64, 64), + area_size=(256, 256, 256), + target_shape=(64, 64, 64), + gt_target_shape=(64, 64, 64), + stop_after_epoch=False, + aug_dict=aug_dict_preprocessing, + transform_ratio=0., + batch_size=1, + shuffle=False, + add_pad_mask=False, + n_workers=16, + gt_target_channels=None, + yield_epoch_info=True + ) + +model = PiledUnet( + n_nets=3, + in_channels=1, + out_channels=[1, 1, 1], + filter_sizes_down=( + ((16, 32), (32, 64), (64, 128)), + ((16, 32), (32, 64), (64, 128)), + ((16, 32), (32, 64), (64, 128)) + ), + filter_sizes_bottleneck=( + (128, 256), + (128, 256), + (128, 256) + ), + filter_sizes_up=( + ((256, 128, 128), (128, 64, 64), (64, 32, 32)), + ((256, 128, 128), (128, 64, 64), (64, 32, 32)), + ((256, 128, 128), (128, 64, 64), (64, 32, 32)) + ), + batch_norm=True, + output_activation='sigmoid' +) +model.cuda() +summary(model, (1, 64, 64, 64)) + +if not os.path.exists(results_folder): + os.mkdir(results_folder) + +train_model_with_generators( + model, + train_gen, + val_gen, + n_epochs=100, + loss_func=CombinedLosses( + losses=( + WeightMatrixWeightedBCE(((0.1, 0.9),), weigh_with_matrix_sum=True), + WeightMatrixWeightedBCE(((0.2, 0.8),), weigh_with_matrix_sum=True), + WeightMatrixWeightedBCE(((0.3, 0.7),), weigh_with_matrix_sum=True)), + y_pred_channels=(np.s_[:1], np.s_[1:2], np.s_[2:3]), + y_true_channels=(np.s_[:], np.s_[:], np.s_[:]), + weigh_losses=np.array([0.2, 0.3, 0.5]) + ), + l2_reg_param=1e-5, + callbacks=[ + cb_run_model_on_data( + results_filepath=os.path.join(results_folder, 'improved_result1_{epoch:04d}.h5'), + raw_channels=val_raw_channels[:1], + spacing=(32, 32, 32), + area_size=(64, 256, 256), + target_shape=(64, 64, 64), + num_result_channels=3, + smooth_output_sigma=aug_dict_preprocessing['smooth_output_sigma'], + n_workers=16, + compute_empty_volumes=True, + thresh=None, + write_at_area=False, + offset=None, + full_dataset_shape=None, + min_epoch=5 + ), + cb_run_model_on_data( + results_filepath=os.path.join(results_folder, 'improved_result2_{epoch:04d}.h5'), + raw_channels=val_raw_channels[1:], + spacing=(32, 32, 32), + area_size=(64, 256, 256), + target_shape=(64, 64, 64), + num_result_channels=3, + smooth_output_sigma=aug_dict_preprocessing['smooth_output_sigma'], + n_workers=16, + compute_empty_volumes=True, + thresh=None, + write_at_area=False, + offset=None, + full_dataset_shape=None, + min_epoch=5 + ), + cb_save_model( + filepath=os.path.join(results_folder, 'model_{epoch:04d}.h5'), + min_epoch=5 + ) + ], + writer_path=os.path.join(results_folder, 'tb') +) diff --git a/experiments/unet3d_200311_pytorch_for_membranes/unet3d_200313_00b_membranes_epochs200_weigh_with_matrix_avg_flipz.py b/experiments/unet3d_200311_pytorch_for_membranes/unet3d_200313_00b_membranes_epochs200_weigh_with_matrix_avg_flipz.py new file mode 100644 index 0000000..035d5e5 --- /dev/null +++ b/experiments/unet3d_200311_pytorch_for_membranes/unet3d_200313_00b_membranes_epochs200_weigh_with_matrix_avg_flipz.py @@ -0,0 +1,224 @@ + +from pytorch.pytorch_tools.data_generation import parallel_data_generator +import os +from h5py import File +from pytorch.pytorch_tools.piled_unets import PiledUnet +from pytorch.pytorch_tools.losses import WeightMatrixWeightedBCE, CombinedLosses +from pytorch.pytorch_tools.training import train_model_with_generators, cb_save_model, cb_run_model_on_data + +from torchsummary import summary +import numpy as np + + +experiment_name = 'unet3d_200313_00b_membranes_epochs200_weigh_with_matrix_avg_flipz' +results_folder = os.path.join( + '/g/schwab/hennies/phd_project/image_analysis/autoseg/cnn_3d_devel', + 'unet3d_200311_pytorch_for_membranes', + experiment_name +) + +if True: + + raw_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/' + raw_filepaths = [ + [ + os.path.join(raw_path, 'raw.h5'), + ], + ] + gt_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/' + gt_filepaths = [ + [ + os.path.join(gt_path, 'gt_mem.h5'), + os.path.join(gt_path, 'gt_mask_organelle_insides.h5') + ], + ] + raw_channels = [] + for volumes in raw_filepaths: + raws_data = [] + for chid, channel in enumerate(volumes): + if chid == 1: + # Specifically only load last channel of the membrane prediction + raws_data.append(File(channel, 'r')['data'][..., -1]) + else: + raws_data.append(File(channel, 'r')['data'][:]) + raw_channels.append(raws_data) + gt_channels = [[File(channel, 'r')['data'][:] for channel in volumes] for volumes in gt_filepaths] + + val_raw_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/' + val_raw_filepaths = [ + [ + os.path.join(val_raw_path, 'val_raw_512.h5'), + ], + [ + os.path.join( + '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/psp_200107_00_ds_20141002_hela_wt_xyz8nm_as_multiple_scales/step0_datasets/psp0_200108_02_select_test_and_val_cubes', + 'val0_x1390_y742_z345_pad.h5' + + ) + ] + ] + val_gt_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/' + val_gt_filepaths = [ + [ + os.path.join(val_gt_path, 'val_gt_mem.h5'), + os.path.join(val_gt_path, 'val_gt_mask_organelle_insides.h5') + ] + ] + val_raw_channels = [] + for volumes in val_raw_filepaths: + val_raws_data = [] + for chid, channel in enumerate(volumes): + if chid == 1: + # Specifically only load last channel of the membrane prediction + val_raws_data.append(File(channel, 'r')['data'][..., -1]) + else: + val_raws_data.append(File(channel, 'r')['data'][:]) + val_raw_channels.append(val_raws_data) + val_gt_channels = [[File(channel, 'r')['data'][:] for channel in volumes] for volumes in val_gt_filepaths] + +if True: + data_gen_args = dict( + rotation_range=180, # Angle in degrees + shear_range=20, # Angle in degrees + zoom_range=[0.8, 1.2], # [0.75, 1.5] + horizontal_flip=True, + vertical_flip=True, + depth_flip=True, + noise_var_range=1e-1, + random_smooth_range=[0.6, 1.5], + smooth_output_sigma=0.5, + displace_slices_range=2, + fill_mode='reflect', + cval=0, + brightness_range=92, + contrast_range=(0.5, 2) + ) + + aug_dict_preprocessing = dict( + smooth_output_sigma=0.5 + ) + + train_gen = parallel_data_generator( + raw_channels=raw_channels, + gt_channels=gt_channels, + spacing=(32, 32, 32), + area_size=(32, 512, 512), + # area_size=(32, 128, 128), + target_shape=(64, 64, 64), + gt_target_shape=(64, 64, 64), + stop_after_epoch=False, + aug_dict=data_gen_args, + transform_ratio=0.9, + batch_size=2, + shuffle=True, + add_pad_mask=False, + n_workers=16, + noise_load_dict=None, + gt_target_channels=None, + areas_and_spacings=None, + n_workers_noise=16, + noise_on_channels=None, + yield_epoch_info=True + ) + + val_gen = parallel_data_generator( + raw_channels=val_raw_channels[:1], + gt_channels=val_gt_channels, + spacing=(64, 64, 64), + area_size=(256, 256, 256), + target_shape=(64, 64, 64), + gt_target_shape=(64, 64, 64), + stop_after_epoch=False, + aug_dict=aug_dict_preprocessing, + transform_ratio=0., + batch_size=1, + shuffle=False, + add_pad_mask=False, + n_workers=16, + gt_target_channels=None, + yield_epoch_info=True + ) + +model = PiledUnet( + n_nets=3, + in_channels=1, + out_channels=[1, 1, 1], + filter_sizes_down=( + ((16, 32), (32, 64), (64, 128)), + ((16, 32), (32, 64), (64, 128)), + ((16, 32), (32, 64), (64, 128)) + ), + filter_sizes_bottleneck=( + (128, 256), + (128, 256), + (128, 256) + ), + filter_sizes_up=( + ((256, 128, 128), (128, 64, 64), (64, 32, 32)), + ((256, 128, 128), (128, 64, 64), (64, 32, 32)), + ((256, 128, 128), (128, 64, 64), (64, 32, 32)) + ), + batch_norm=True, + output_activation='sigmoid' +) +model.cuda() +summary(model, (1, 64, 64, 64)) + +if not os.path.exists(results_folder): + os.mkdir(results_folder) + +train_model_with_generators( + model, + train_gen, + val_gen, + n_epochs=200, + loss_func=CombinedLosses( + losses=( + WeightMatrixWeightedBCE(((0.1, 0.9),), weigh_with_matrix_sum=True), + WeightMatrixWeightedBCE(((0.2, 0.8),), weigh_with_matrix_sum=True), + WeightMatrixWeightedBCE(((0.3, 0.7),), weigh_with_matrix_sum=True)), + y_pred_channels=(np.s_[:1], np.s_[1:2], np.s_[2:3]), + y_true_channels=(np.s_[:], np.s_[:], np.s_[:]), + weigh_losses=np.array([0.2, 0.3, 0.5]) + ), + l2_reg_param=1e-5, + callbacks=[ + cb_run_model_on_data( + results_filepath=os.path.join(results_folder, 'improved_result1_{epoch:04d}.h5'), + raw_channels=val_raw_channels[:1], + spacing=(32, 32, 32), + area_size=(64, 256, 256), + target_shape=(64, 64, 64), + num_result_channels=3, + smooth_output_sigma=aug_dict_preprocessing['smooth_output_sigma'], + n_workers=16, + compute_empty_volumes=True, + thresh=None, + write_at_area=False, + offset=None, + full_dataset_shape=None, + min_epoch=5 + ), + cb_run_model_on_data( + results_filepath=os.path.join(results_folder, 'improved_result2_{epoch:04d}.h5'), + raw_channels=val_raw_channels[1:], + spacing=(32, 32, 32), + area_size=(64, 256, 256), + target_shape=(64, 64, 64), + num_result_channels=3, + smooth_output_sigma=aug_dict_preprocessing['smooth_output_sigma'], + n_workers=16, + compute_empty_volumes=True, + thresh=None, + write_at_area=False, + offset=None, + full_dataset_shape=None, + min_epoch=5 + ), + cb_save_model( + filepath=os.path.join(results_folder, 'model_{epoch:04d}.h5'), + min_epoch=5 + ) + ], + writer_path=os.path.join(results_folder, 'tb') +) diff --git a/experiments/unet3d_200312_membranes_and_disttransf/__init__.py b/experiments/unet3d_200312_membranes_and_disttransf/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/experiments/unet3d_200312_membranes_and_disttransf/unet3d_200312_00_mem_and_dt_epochs200.py b/experiments/unet3d_200312_membranes_and_disttransf/unet3d_200312_00_mem_and_dt_epochs200.py new file mode 100644 index 0000000..dfa8fcf --- /dev/null +++ b/experiments/unet3d_200312_membranes_and_disttransf/unet3d_200312_00_mem_and_dt_epochs200.py @@ -0,0 +1,229 @@ + +from pytorch.pytorch_tools.data_generation import parallel_data_generator +import os +from h5py import File +from pytorch.pytorch_tools.piled_unets import PiledUnet +from pytorch.pytorch_tools.losses import WeightMatrixWeightedBCE, WeightMatrixMSELoss, CombinedLosses +from pytorch.pytorch_tools.training import train_model_with_generators, cb_save_model, cb_run_model_on_data + +from torchsummary import summary +import numpy as np + +from torch.nn import MSELoss + + +experiment_name = 'unet3d_200312_00_mem_and_dt_epochs200' +results_folder = os.path.join( + '/g/schwab/hennies/phd_project/image_analysis/autoseg/cnn_3d_devel', + 'unet3d_200312_membranes_and_disttransf', + experiment_name +) + +if True: + + raw_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/' + raw_filepaths = [ + [ + os.path.join(raw_path, 'raw.h5'), + ], + ] + gt_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/' + gt_filepaths = [ + [ + os.path.join(gt_path, 'gt_mem.h5'), + os.path.join(gt_path, 'gt_dt.h5'), + os.path.join(gt_path, 'gt_mask_organelle_insides.h5') + ], + ] + raw_channels = [] + for volumes in raw_filepaths: + raws_data = [] + for chid, channel in enumerate(volumes): + if chid == 1: + # Specifically only load last channel of the membrane prediction + raws_data.append(File(channel, 'r')['data'][..., -1]) + else: + raws_data.append(File(channel, 'r')['data'][:]) + raw_channels.append(raws_data) + gt_channels = [[File(channel, 'r')['data'][:] for channel in volumes] for volumes in gt_filepaths] + + val_raw_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/' + val_raw_filepaths = [ + [ + os.path.join(val_raw_path, 'val_raw_512.h5'), + ], + [ + os.path.join( + '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/psp_200107_00_ds_20141002_hela_wt_xyz8nm_as_multiple_scales/step0_datasets/psp0_200108_02_select_test_and_val_cubes', + 'val0_x1390_y742_z345_pad.h5' + + ) + ] + ] + val_gt_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/' + val_gt_filepaths = [ + [ + os.path.join(val_gt_path, 'val_gt_mem.h5'), + os.path.join(val_gt_path, 'val_gt_dt.h5'), + os.path.join(val_gt_path, 'val_gt_mask_organelle_insides.h5') + ] + ] + val_raw_channels = [] + for volumes in val_raw_filepaths: + val_raws_data = [] + for chid, channel in enumerate(volumes): + if chid == 1: + # Specifically only load last channel of the membrane prediction + val_raws_data.append(File(channel, 'r')['data'][..., -1]) + else: + val_raws_data.append(File(channel, 'r')['data'][:]) + val_raw_channels.append(val_raws_data) + val_gt_channels = [[File(channel, 'r')['data'][:] for channel in volumes] for volumes in val_gt_filepaths] + +if True: + data_gen_args = dict( + rotation_range=180, # Angle in degrees + shear_range=20, # Angle in degrees + zoom_range=[0.8, 1.2], # [0.75, 1.5] + horizontal_flip=True, + vertical_flip=True, + noise_var_range=1e-1, + random_smooth_range=[0.6, 1.5], + smooth_output_sigma=0.5, + displace_slices_range=2, + fill_mode='reflect', + cval=0, + brightness_range=92, + contrast_range=(0.5, 2) + ) + + aug_dict_preprocessing = dict( + smooth_output_sigma=0.5 + ) + + train_gen = parallel_data_generator( + raw_channels=raw_channels, + gt_channels=gt_channels, + spacing=(32, 32, 32), + area_size=(32, 512, 512), + # area_size=(32, 128, 128), + target_shape=(64, 64, 64), + gt_target_shape=(64, 64, 64), + stop_after_epoch=False, + aug_dict=data_gen_args, + transform_ratio=0.9, + batch_size=2, + shuffle=True, + add_pad_mask=False, + n_workers=16, + noise_load_dict=None, + gt_target_channels=None, + areas_and_spacings=None, + n_workers_noise=16, + noise_on_channels=None, + yield_epoch_info=True + ) + + val_gen = parallel_data_generator( + raw_channels=val_raw_channels[:1], + gt_channels=val_gt_channels, + spacing=(64, 64, 64), + area_size=(256, 256, 256), + target_shape=(64, 64, 64), + gt_target_shape=(64, 64, 64), + stop_after_epoch=False, + aug_dict=aug_dict_preprocessing, + transform_ratio=0., + batch_size=1, + shuffle=False, + add_pad_mask=False, + n_workers=16, + gt_target_channels=None, + yield_epoch_info=True + ) + +model = PiledUnet( + n_nets=3, + in_channels=1, + out_channels=[1, 1, 2], + filter_sizes_down=( + ((16, 32), (32, 64), (64, 128)), + ((16, 32), (32, 64), (64, 128)), + ((16, 32), (32, 64), (64, 128)) + ), + filter_sizes_bottleneck=( + (128, 256), + (128, 256), + (128, 256) + ), + filter_sizes_up=( + ((256, 128, 128), (128, 64, 64), (64, 32, 32)), + ((256, 128, 128), (128, 64, 64), (64, 32, 32)), + ((256, 128, 128), (128, 64, 64), (64, 32, 32)) + ), + batch_norm=True, + output_activation='sigmoid' +) +model.cuda() +summary(model, (1, 64, 64, 64)) + +if not os.path.exists(results_folder): + os.mkdir(results_folder) + +train_model_with_generators( + model, + train_gen, + val_gen, + n_epochs=300, + loss_func=CombinedLosses( + losses=( + WeightMatrixWeightedBCE(((0.1, 0.9),)), + WeightMatrixWeightedBCE(((0.2, 0.8),)), + WeightMatrixWeightedBCE(((0.3, 0.7),)), + WeightMatrixMSELoss() + ), + y_pred_channels=(np.s_[:1], np.s_[1:2], np.s_[2:3], np.s_[3:4]), + y_true_channels=((0, 2), (0, 2), (0, 2), (1, 2)), + weigh_losses=np.array([0.15, 0.25, 0.3, 0.3]) + ), + l2_reg_param=1e-5, + callbacks=[ + cb_run_model_on_data( + results_filepath=os.path.join(results_folder, 'improved_result1_{epoch:04d}.h5'), + raw_channels=val_raw_channels[:1], + spacing=(32, 32, 32), + area_size=(64, 256, 256), + target_shape=(64, 64, 64), + num_result_channels=4, + smooth_output_sigma=aug_dict_preprocessing['smooth_output_sigma'], + n_workers=16, + compute_empty_volumes=True, + thresh=None, + write_at_area=False, + offset=None, + full_dataset_shape=None, + min_epoch=5 + ), + cb_run_model_on_data( + results_filepath=os.path.join(results_folder, 'improved_result2_{epoch:04d}.h5'), + raw_channels=val_raw_channels[1:], + spacing=(32, 32, 32), + area_size=(64, 256, 256), + target_shape=(64, 64, 64), + num_result_channels=4, + smooth_output_sigma=aug_dict_preprocessing['smooth_output_sigma'], + n_workers=16, + compute_empty_volumes=True, + thresh=None, + write_at_area=False, + offset=None, + full_dataset_shape=None, + min_epoch=5 + ), + cb_save_model( + filepath=os.path.join(results_folder, 'model_{epoch:04d}.h5'), + min_epoch=5 + ) + ], + writer_path=os.path.join(results_folder, 'tb') +) diff --git a/pytorch_tools/__init__.py b/pytorch_tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pytorch_tools/data_generation.py b/pytorch_tools/data_generation.py new file mode 100644 index 0000000..bce3939 --- /dev/null +++ b/pytorch_tools/data_generation.py @@ -0,0 +1,1266 @@ + +import os +from h5py import File +import numpy as np +import scipy.ndimage as ndi +from skimage.util import random_noise +from skimage.transform import downscale_local_mean +from scipy.ndimage import zoom +import sys + +from concurrent.futures import ThreadPoolExecutor +from multiprocessing.pool import ThreadPool + + +def _build_equally_spaced_volume_list( + spacing, + area_size, + n_volumes, + transform_ratio, + set_volume=None +): + + # Components + spacing = spacing + half_area_size = (np.array(area_size) / 2).astype(int) + + # This generates the list of all positions, equally spaced and centered around zero + mg = np.mgrid[ + -half_area_size[0]: half_area_size[0]: spacing[0], + -half_area_size[1]: half_area_size[1]: spacing[1], + -half_area_size[2]: half_area_size[2]: spacing[2] + ].squeeze() + mg[0] -= int((mg[0].max() + mg[0].min()) / 2) + mg[1] -= int((mg[1].max() + mg[1].min()) / 2) + mg[2] -= int((mg[2].max() + mg[2].min()) / 2) + mg = mg.reshape(3, np.prod(np.array(mg.shape)[1:])) + positions = mg.swapaxes(0, 1) + + n_transform = int(n_volumes * len(positions) * transform_ratio) + transform = [True] * n_transform + [False] * (n_volumes * len(positions) - n_transform) + np.random.shuffle(transform) + + index_array = [] + + idx = 0 + for volume in range(n_volumes): + for position in positions: + + if set_volume: + index_array.append( + [ + position, + set_volume, # Always volume 0 + transform[idx] + ] + ) + else: + index_array.append( + [ + position, + volume, # Always volume 0 + transform[idx] + ] + ) + idx += 1 + + print('Equally spaced volumes:') + print(' Total samples: {}'.format(len(positions) * n_volumes)) + print(' Volumes: {}'.format(n_volumes)) + print(' Transformed samples: {}'.format(n_transform)) + print('Actual size of index_array: {}'.format(len(index_array))) + + return index_array + + +def _find_bounds(position, crop_shape, full_shape): + + position = np.array(position) + crop_shape = np.array(crop_shape) + full_shape = np.array(full_shape) + + # Start and stop in full volume (not considering volume boundaries) + start = (position - crop_shape / 2 + full_shape / 2).astype('int16') + stop = start + crop_shape + + # Checking for out of bounds + start_corrected = start.copy() + start_corrected[start < 0] = 0 + start_oob = start_corrected - start + stop_corrected = stop.copy() + stop_corrected[stop > full_shape] = full_shape[stop > full_shape] + stop_oob = stop - stop_corrected + + # Making slicings ... + # ... where to take the data from in the full shape ... + s_source = np.s_[ + start_corrected[0]: stop_corrected[0], + start_corrected[1]: stop_corrected[1], + start_corrected[2]: stop_corrected[2] + ] + # ... and where to put it into the crop + s_target = np.s_[ + start_oob[0]: crop_shape[0] - stop_oob[0], + start_oob[1]: crop_shape[1] - stop_oob[1], + start_oob[2]: crop_shape[2] - stop_oob[2] + ] + + return s_source, s_target + + +def _load_data_with_padding( + channels, + position, + target_shape, + auto_pad=False, + return_pad_mask=False, + return_shape_only=False +): + source_shape = np.array(channels[0].shape) + + shape = np.array(target_shape) + if auto_pad: + shape[1:] = np.ceil(np.array(target_shape[1:]) * np.sqrt(2) / 2).astype(int) * 2 + + if return_shape_only: + return shape.tolist() + [len(channels)] + + s_source, s_target = _find_bounds(position, shape, source_shape) + + # Defines the position of actual target data within the padded data + pos_in_pad = ((shape - target_shape) / 2).astype(int) + s_pos_in_pad = np.s_[pos_in_pad[0]: pos_in_pad[0] + target_shape[0], + pos_in_pad[1]: pos_in_pad[1] + target_shape[1], + pos_in_pad[2]: pos_in_pad[2] + target_shape[2]] + + x = [] + for cid, channel in enumerate(channels): + # Load the data according to the definitions above + + vol_pad = np.zeros(shape, dtype=channel.dtype) + + vol_pad[s_target] = channel[s_source] + x.append(vol_pad[..., None]) + + if return_pad_mask: + pad_mask = np.zeros(x[0].shape, dtype=channels[0].dtype) + pad_mask[s_target] = 255 + x.append(pad_mask) + + x = np.concatenate(x, axis=3) + + return x, s_pos_in_pad + + +def _load_data_with_padding_old( + channels, + position, + target_shape, + auto_pad=False, + return_pad_mask=False, + return_shape_only=False, + downsample_output=1 +): + + source_shape = np.array(channels[0].shape) + + shape = np.array(target_shape) * downsample_output + if auto_pad: + shape[1:] = np.ceil(np.array(target_shape[1:]) * np.sqrt(2) / 2).astype(int) * 2 * downsample_output + + # These are used to load the data with zero padding if necessary + start_pos = (position - (shape / 2) + (source_shape / 2)).astype('int16') + stop_pos = start_pos + shape + start_out_of_bounds = np.zeros(start_pos.shape, dtype='int16') + start_out_of_bounds[start_pos < 0] = start_pos[start_pos < 0] + stop_out_of_bounds = stop_pos - source_shape + stop_out_of_bounds[stop_out_of_bounds < 0] = 0 + start_pos[start_pos < 0] = 0 + stop_pos[stop_out_of_bounds > 0] = source_shape[stop_out_of_bounds > 0] + + if return_shape_only: + return (stop_pos[0] + stop_out_of_bounds[0] - start_pos[0] - start_out_of_bounds[0], + stop_pos[1] + stop_out_of_bounds[1] - start_pos[1] - start_out_of_bounds[1], + stop_pos[2] + stop_out_of_bounds[2] - start_pos[2] - start_out_of_bounds[2], len(channels)) + + s_source = np.s_[ + start_pos[0]:stop_pos[0], + start_pos[1]:stop_pos[1], + start_pos[2]:stop_pos[2] + ] + s_target = np.s_[ + stop_out_of_bounds[0]:stop_pos[0] + stop_out_of_bounds[0] - start_pos[0], + stop_out_of_bounds[1]:stop_pos[1] + stop_out_of_bounds[1] - start_pos[1], + stop_out_of_bounds[2]:stop_pos[2] + stop_out_of_bounds[2] - start_pos[2], + ] + + # Defines the position of actual target data within the padded data + pos_in_pad = ((shape / downsample_output - target_shape) / 2).astype(int) + s_pos_in_pad = np.s_[pos_in_pad[0]: pos_in_pad[0] + target_shape[0], + pos_in_pad[1]: pos_in_pad[1] + target_shape[1], + pos_in_pad[2]: pos_in_pad[2] + target_shape[2]] + + x = [] + for cid, channel in enumerate(channels): + # Load the data according to the definitions above + + vol_pad = np.zeros(shape, dtype=channel.dtype) + + vol_pad[s_target] = channel[s_source] + x.append(vol_pad[..., None]) + + if return_pad_mask: + pad_mask = np.zeros(x[0].shape, dtype=channels[0].dtype) + pad_mask[s_target] = 255 + x.append(pad_mask) + + x = np.concatenate(x, axis=3) + # if merge_output is not None and x.shape[3] > 1: + # if merge_output == 'max': + # x = np.max(x, axis=3)[..., None] + # else: + # raise NotImplementedError + + # if downsample_output > 1: + # # x = downscale_local_mean(x, (downsample_output,) * 3 + (1,)).astype(dtype=channels[0].dtype) + # x = zoom(x, (1 / downsample_output,) * 3 + (1,)) + # # x = np.zeros(shape, dtype='uint8') + + return x, s_pos_in_pad + + +def apply_transform(x, + transform_matrix, + channel_axis=0, + fill_mode='nearest', + cval=0., + ndim=3): + """Apply the image transformation specified by a matrix. + + # Arguments + x: 2D numpy array, single image. + transform_matrix: Numpy array specifying the geometric transformation. + channel_axis: Index of axis for channels in the input tensor. + fill_mode: Points outside the boundaries of the input + are filled according to the given mode + (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). + cval: Value used for points outside the boundaries + of the input if `mode='constant'`. + + # Returns + The transformed version of the input. + """ + x = np.rollaxis(x, channel_axis, 0) + final_affine_matrix = transform_matrix[:ndim, :ndim] + final_offset = transform_matrix[:ndim, ndim] + channel_images = [ndi.interpolation.affine_transform( + x_channel, + final_affine_matrix, + final_offset, + order=1, + mode=fill_mode, + cval=cval) for x_channel in x] + x = np.stack(channel_images, axis=0) + x = np.rollaxis(x, 0, channel_axis + 1) + return x + + +def smooth_output(x, smooth_output_sigma): + + for ch in range(x.shape[3]): + x[ch] = ndi.gaussian_filter(x[ch], smooth_output_sigma) + + return x + + +def preprocessing(x, smooth_output_sigma): + + x = smooth_output(x, smooth_output_sigma) + + return x + + +def random_displace_slices(x, displace_slices, fill_mode, cval): + + tx = displace_slices[0] + ty = displace_slices[1] + + if tx > 0 or ty > 0: + + img_channel_axis = 2 + + new_x = [] + + for slc in x: + + shift_matrix = np.array([[1, 0, tx], + [0, 1, ty], + [0, 0, 1]]) + + new_x.append(apply_transform(slc, shift_matrix, img_channel_axis, + fill_mode=fill_mode, cval=cval, ndim=2)) + + x = np.array(new_x) + + return x + + +def add_random_noise(x, noise): + + if noise is not None: + + noisy = x.astype('float32') + (noise.astype('float32') - 128) + noisy[noisy > 255] = 255 + noisy[noisy < 0] = 0 + + # # Does not multithread, so make sure not to add noise to large data samples + # noisy = random_noise( + # x[crop].astype('float32') / 255, + # mode='gaussian', + # seed=seed, + # clip=True, + # mean=0, + # var=var + # ) + + x = noisy.astype('uint8') + + return x + + +def random_brightness_contrast(x, brightness, contrast): + + if brightness > 0 or contrast > 0: + + x = (x.astype('float32') - 128) * contrast + 128 + x += brightness + + x[x > 255] = 255 + x[x < 0] = 0 + + x = x.astype('uint8') + + return x + + +def random_smooth(x, random_smooth): + """ + + a: angle + s_0: sigma + s_1: sigma + + exp (-( ((x * sin(a) - y * cos(a))^2/(2*s_0^2)) + ((x * cos(a) + y * sin(a))^2/(2*s_1^2)) )) + + :param x: + :param random_smooth_range: (s_0, s_1) + :param seed: + :return: + """ + + # self.random_smooth_s0_range = 0.3 + # self.random_smooth_s1_range = 1.5 + a = random_smooth[0] + s_0 = random_smooth[1] + s_1 = random_smooth[2] + + if s_0 > 0 or s_1 > 0: + + mx, my = np.mgrid[-4:5, -4:5] + + kernel = np.exp(-(((mx * np.sin(a) - my * np.cos(a)) ** 2 / (2 * s_0 ** 2)) + ((mx * np.cos(a) + my * np.sin(a)) ** 2 / (2 * s_1 ** 2)))) + # Like this, the kernel will only work on the x-y planes + kernel = kernel[None, :, :, None] + kernel /= kernel.sum() + + x = ndi.filters.convolve(x, kernel) + + return x + + +def random_transform( + x, + rotation, + shear, + zoom, + horizontal_flip, + vertical_flip, + depth_flip, + fill_mode, + cval +): + """Randomly augment a single image tensor. + + # Arguments + x: 3D tensor, single image. + seed: random seed. + + # Returns + A randomly transformed version of the input (same shape). + """ + + def flip_axis(x, axis): + x = np.asarray(x).swapaxes(axis, 0) + x = x[::-1, ...] + x = x.swapaxes(0, axis) + return x + + def transform_matrix_offset_center(matrix, x, y): + o_x = float(x) / 2 + 0.5 + o_y = float(y) / 2 + 0.5 + offset_matrix = np.array([[1, 0, 0, 0], + [0, 1, 0, o_x], + [0, 0, 1, o_y], + [0, 0, 0, 1]]) + reset_matrix = np.array([[1, 0, 0, 0], + [0, 1, 0, -o_x], + [0, 0, 1, -o_y], + [0, 0, 0, 1]]) + transform_matrix = np.dot(np.dot(offset_matrix, matrix), reset_matrix) + return transform_matrix + + # x is a single image, so it doesn't have image number at index 0 + img_x_axis = 2 + img_y_axis = 1 + img_z_axis = 0 + img_channel_axis = 3 + + # use composition of homographies + # to generate final transform that needs to be applied + theta = rotation + + zx = zoom[0] + zy = zoom[1] + + # Building the transform matrix + transform_matrix = None + # if theta != 0: + rotation_matrix = np.array([[1, 0, 0, 0], + [0, np.cos(theta), -np.sin(theta), 0], + [0, np.sin(theta), np.cos(theta), 0], + [0, 0, 0, 1]]) + transform_matrix = rotation_matrix + + # if shear != 0: + shear_matrix = np.array([[1, 0, 0, 0], + [0, 1, -np.sin(shear), 0], + [0, 0, np.cos(shear), 0], + [0, 0, 0, 1]]) + transform_matrix = shear_matrix if transform_matrix is None else np.dot(transform_matrix, shear_matrix) + + # if zx != 1 or zy != 1: + zoom_matrix = np.array([[1, 0, 0, 0], + [0, zy, 0, 0], + [0, 0, zx, 0], + [0, 0, 0, 1]]) + transform_matrix = zoom_matrix if transform_matrix is None else np.dot(transform_matrix, zoom_matrix) + + if transform_matrix is not None: + h, w = x.shape[img_x_axis], x.shape[img_y_axis] + transform_matrix = transform_matrix_offset_center(transform_matrix, h, w) + x = apply_transform(x, transform_matrix, img_channel_axis, + fill_mode=fill_mode, cval=cval) + + if horizontal_flip: + x = flip_axis(x, img_x_axis) + + if vertical_flip: + x = flip_axis(x, img_y_axis) + + if depth_flip: + x = flip_axis(x, img_z_axis) + + return x + + +default_aug_dict = dict( + noise_var_range=0., + random_smooth_range=(0, 0), + displace_slices_range=0, + smooth_output_sigma=0., + rotation_range=0., + shear_range=0., + zoom_range=(0., 0.), + horizontal_flip=False, + vertical_flip=False, + depth_flip=False, + fill_mode='nearest', + cval=0., + brightness_range=0., + contrast_range=None +) + + +def _prep_channel(ch, channel_definitions): + dtype = ch[0].dtype + + new_new_channels = [] + + print('Preparing channels ...') + + for gtc in channel_definitions: + # print('Preparing channels {}'.format(gtc['channel'])) + + if type(gtc['channel'][0]) == str: + if gtc['channel'][0] == 'inv': + data = [ch[idx] for idx in gtc['channel'][1:]] + data = 255 - np.sum(data, axis=0) + else: + raise ValueError + else: + data = [ch[idx] for idx in gtc['channel']] + data = np.max(data, axis=0) + data = downscale_local_mean(data, (gtc['downsample'],) * 3).astype(dtype) + if dtype == 'float32': + data[data < -0.5] = -1 + data[np.logical_and(data >= -0.5, data < 0.5)] = 0 + data[data >= 0.5] = 1 + elif dtype == 'uint8': + data[data < 128] = 0 + data[data >= 128] = 255 + else: + raise NotImplementedError + + new_new_channels.append(data) + + return new_new_channels + + +def _prepare_channels(channels, channel_definitions, n_workers=1): + if n_workers == 1: + + new_channels = [] + + for ch in channels: + new_new_channels = _prep_channel(ch, channel_definitions) + + new_channels.append(new_new_channels) + + else: + + with ThreadPoolExecutor(max_workers=n_workers) as tpe: + tasks = [ + tpe.submit(_prep_channel, ch, channel_definitions) + for ch in channels + ] + + new_channels = [task.result() for task in tasks] + + return new_channels + + +def _pre_load_data( + index_array, + raw_channels, + gt_channels, + target_size, + gt_target_size, + add_pad_mask +): + + xs = [] + ys = [] + s_pads_x = [] + s_pads_y = [] + + # Iterate over the positions + for idx, (position, volume, transform) in enumerate(index_array): + + # Load data + x, s_pad_x = _load_data_with_padding(raw_channels[volume], position, target_size, auto_pad=transform) + + if gt_channels is not None: + + # Load data + y, s_pad_y = _load_data_with_padding(gt_channels[volume], position, gt_target_size, + auto_pad=transform, + return_pad_mask=add_pad_mask) + + ys.append(y) + s_pads_y.append(s_pad_y) + xs.append(x) + s_pads_x.append(s_pad_x) + + if not ys: + ys = None + s_pads_y = None + + return xs, ys, s_pads_x, s_pads_y + + +import cv2 +def _get_random_args(aug_dict, shape, noise_load_dict=None, noise_on_channel=None): + """ + :param aug_dict: + :param shape: + :param noise_load_dict: + dict( + filepath='/path/to/noise_file', + size=number_of_elements + ) + :return: + """ + + def _load_noise_from_data(): + # Randomly select chunk position + pos = int(np.random.uniform(0, noise_load_dict['size'] - np.prod(shape))) + # Load the data + noise = noise_load_dict['data'][pos: pos + np.prod(shape)] + # Reshape to match the images + noise = np.reshape(noise, shape) + # Get the proper standard variation + var = np.random.uniform(0, aug_dict['noise_var_range']) + noise *= (var ** 0.5) + return noise + + # FIXME this is still the major bottleneck + if aug_dict['noise_var_range'] > 0: + if noise_load_dict is not None: + if 'data' not in noise_load_dict or noise_load_dict['data'] is None: + print('Trying to load some noise ...') + if os.path.exists(noise_load_dict['filepath']): + with File(noise_load_dict['filepath'], mode='r') as f: + noise_load_dict['data'] = f['data'][:] + else: + print('Noise file does not exist, creating it now ... This may take a while ...') + print('Generating a lot of noise ...') + noise_load_dict['data'] = np.random.normal(0, 1, (noise_load_dict['size'],)) + print('Make some noise!!!') + with File(noise_load_dict['filepath'], mode='w') as f: + f.create_dataset('data', data=noise_load_dict['data']) + + noise = _load_noise_from_data() + + else: + var = np.random.uniform(0, aug_dict['noise_var_range']) + # noise = np.random.normal(0, var ** 0.5, shape) + if noise_on_channel is None: + im = np.zeros((np.prod(shape),)) + noise = cv2.randn(im, 0, var ** 0.5) + noise = (np.reshape(noise, shape) * 127 + 128).astype('uint8') + else: + noise = np.ones(shape, dtype='uint8') * 128 + for ch in noise_on_channel: + n_im = np.zeros((int(np.prod(shape) / shape[3]),)) + n_im = cv2.randn(n_im, 0, var ** 0.5) + n_im = np.reshape(n_im, shape[:3]) + noise[..., ch] = (n_im * 127 + 128).astype('uint8') + else: + noise = None + + # print('Noise.shape = {}'.format(noise.shape)) + + random_smoothing = [0, 0, 0] + random_smoothing[0] = np.random.uniform(0, 1) * np.pi + if aug_dict['random_smooth_range'][0] > 0: + random_smoothing[1] = np.random.uniform(0, aug_dict['random_smooth_range'][0]) + if aug_dict['random_smooth_range'][1] > 0: + random_smoothing[2] = np.random.uniform(0, aug_dict['random_smooth_range'][1]) + + displace_slices = [0, 0] + if aug_dict['displace_slices_range'] > 0: + displace_slices[0] = np.random.uniform(-aug_dict['displace_slices_range'], aug_dict['displace_slices_range']) + displace_slices[1] = np.random.uniform(-aug_dict['displace_slices_range'], aug_dict['displace_slices_range']) + + brightness = 0 + if aug_dict['brightness_range'] > 0: + brightness = np.random.uniform(-aug_dict['brightness_range'], aug_dict['brightness_range']) + contrast = 0 + if aug_dict['contrast_range']: + contrast = np.random.uniform(aug_dict['contrast_range'][0], aug_dict['contrast_range'][1]) + + rotation = 0 + if aug_dict['rotation_range']: + rotation = np.deg2rad(np.random.uniform(-aug_dict['rotation_range'], aug_dict['rotation_range'])) + + shear = 0 + if aug_dict['shear_range'] > 0: + shear = np.deg2rad(np.random.uniform(-aug_dict['shear_range'], aug_dict['shear_range'])) + + zoom = [1, 1] + if aug_dict['zoom_range'][0] != 1 and aug_dict['zoom_range'][1] != 1: + zoom = list(np.random.uniform(aug_dict['zoom_range'][0], aug_dict['zoom_range'][1], 2)) + + horizontal_flip = False + if aug_dict['horizontal_flip']: + horizontal_flip = np.random.random() < 0.5 + + vertical_flip = False + if aug_dict['vertical_flip']: + vertical_flip = np.random.random() < 0.5 + + depth_flip = False + if aug_dict['depth_flip']: + depth_flip = np.random.random() < 0.5 + + return dict( + noise=noise, + random_smooth=random_smoothing, + displace_slices=displace_slices, + rotation=rotation, + shear=shear, + zoom=zoom, + horizontal_flip=horizontal_flip, + vertical_flip=vertical_flip, + depth_flip=depth_flip, + brightness=brightness, + contrast=contrast + ) + + +def _pre_generate_random_values(raw_channels, aug_dict, transform, volume, position, target_shape, noise_load_dict, noise_on_channel, idx): + + # print('Noise for {}'.format(idx)) + + if transform: + shape = _load_data_with_padding( + channels=raw_channels[volume], + position=position, + target_shape=target_shape, + auto_pad=transform, + return_shape_only=True + ) + random_args = _get_random_args(aug_dict, shape, noise_load_dict=noise_load_dict, noise_on_channel=noise_on_channel) + else: + random_args = None + + return random_args + + +def _get_batches_of_transformed_samples( + index_array, + raw_channels, + gt_channels, + target_size, + gt_target_size, + gt_target_channels=None, + yield_xyz=False, + aug_dict=default_aug_dict, + aug_args=None, + add_pad_mask=False, + batch_no=None, + batches=None +): + """ + The steps are: + 1. Simulate bad imaging quality (raw channels, transformed only) + a) random noise + b) random smooth + 2. Smooth the input (raw channels, all samples) + 3. Random transformations (all channels, transformed only) + + :param gt_target_channels: + + example + + gt_target_channels=( + dict(channel=0, downsample=4), + dict(channel=0, downsample=2), + dict(channel=1, downsample=2), + dict(channel=0, downsample=1), + dict(channel=1, downsample=1), + dict(channel=2, downsample=1) + ) + """ + + if batch_no is not None: + assert batches is not None + sys.stdout.write('\r' + 'Started data generation for batch {}/{}'.format(batch_no + 1, batches)) + + if gt_target_channels is not None: + assert not add_pad_mask + + # Initialize the data volumes for raw and groundtruth + batch_x = [] + batch_y = [] + positions = [] + + # # Iterate over the positions + for idx, (position, volume, transform) in enumerate(index_array): + # position, volume, transform = index_array + + aug = aug_args[idx] + + position = position.astype('int16') + positions.append(position) + + # __________________________________ + # Load and transform the raw volumes + + # # Load data + x, s_pad_x = _load_data_with_padding(raw_channels[volume], + position, target_size, + auto_pad=transform) + + # Simulation of bad imaging quality on samples that are supposed to be transformed + if transform: + x = add_random_noise(x, aug['noise']) + x = random_smooth(x, aug['random_smooth']) + x = random_displace_slices(x, + aug['displace_slices'], + aug_dict['fill_mode'], + aug_dict['cval']) + x = random_brightness_contrast(x, aug['brightness'], aug['contrast']) + + # Smoothing is a general pre-processing of the data and is performed on all samples + x = preprocessing(x, aug_dict['smooth_output_sigma']) + + # Save the random state here to be able to perform exactly the same transformations on the GT + rdm_state = np.random.get_state() + + # Random transformations on the raw samples that are set to be transformed + if transform: + x = random_transform(x, + aug['rotation'], + aug['shear'], + aug['zoom'], + aug['horizontal_flip'], + aug['vertical_flip'], + aug['depth_flip'], + aug_dict['fill_mode'], + aug_dict['cval']) + + # Crop x to the target size + batch_x.append(x[s_pad_x]) + # batch_x[idx] = x[s_pad_x] + + # _______________________________ + # Now the same for the GT volumes + if gt_channels is not None: + + # # Load data + if gt_target_channels is None: + y, s_pad_y = _load_data_with_padding(gt_channels[volume], position, gt_target_size, + auto_pad=transform, + return_pad_mask=add_pad_mask) + + else: + data_s_pad_y = np.array([ + _load_data_with_padding( + [gt_channels[volume][gidx]], + (position / gtc['downsample']).astype(int), + gt_target_size, + auto_pad=transform, + return_pad_mask=add_pad_mask + ) + for gidx, gtc in enumerate(gt_target_channels) + ]) + y = data_s_pad_y[:, 0] + s_pad_y = data_s_pad_y[0, 1] + + y = np.concatenate(y, axis=3) + + # Set the random state which was saved initially to transforming the raw images + np.random.set_state(rdm_state) + + # Random transformations on the samples that are set to be transformed + if transform: + y = random_transform(y, + aug['rotation'], + aug['shear'], + aug['zoom'], + aug['horizontal_flip'], + aug['vertical_flip'], + aug['depth_flip'], + aug_dict['fill_mode'], + aug_dict['cval']) + + # Crop y to the target size + batch_y.append(y[s_pad_y]) + # batch_y[idx] = y[s_pad_y] + + batch_x = np.array(batch_x) + if gt_channels is not None: + batch_y = np.array(batch_y) + + # Convert gt to bool + if batch_y.dtype == 'float32': + # Let's assume this only happens if there are the three states (-1, 0, 1) + # now, the first item is true if y = -1, 1 and the second if y = -1, 0 + # That means [0] and [1] = -1 + # [0] and not [1] = 1 + # not [0] and [1] = 0 + batch_y = [batch_y.astype(bool), (batch_y - 1).astype(bool)] + elif batch_y.dtype == 'uint8': + batch_y[batch_y < 128] = 0 + batch_y[batch_y >= 128] = 1 + batch_y = batch_y.astype('bool') + else: + raise ValueError + + if yield_xyz: + if gt_channels is not None: + return batch_x, batch_y, positions + else: + return batch_x, positions + else: + if gt_channels is not None: + return batch_x, batch_y + else: + return batch_x + + +def _initialize( + raw_channels, + gt_channels, + spacing=None, + area_size=None, + areas_and_spacings=None, + target_shape=(64, 64, 64), + gt_target_shape=(64, 64, 64), + gt_target_channels=None, + aug_dict=default_aug_dict, + transform_ratio=0., + batch_size=2, + shuffle=False, + add_pad_mask=False, + n_workers=1, + epoch_idx=0, + q=None, + noise_load_dict=None, + yield_xyz=False, + n_workers_noise=1, + noise_on_channels=None +): + + assert ((spacing is not None and area_size is not None) + or areas_and_spacings is not None), 'The areas and spacings have to be specified either with the ' \ + 'parameters area_size and spacing, or with areas_and_spacings.' + if areas_and_spacings is not None: + assert spacing is None and area_size is None + + n_volumes = len(raw_channels) + + if areas_and_spacings is None: + # Generate array of transformations + if type(area_size[0]) == int: + transformation_array = _build_equally_spaced_volume_list( + spacing, + area_size, + n_volumes, + transform_ratio + ) + elif type(area_size[0]) == tuple: + assert n_volumes == len(area_size) + transformation_array = [] + for idx, asize in enumerate(area_size): + transformation_array += _build_equally_spaced_volume_list( + spacing, + asize, + 1, + transform_ratio, + set_volume=idx + ) + else: + raise ValueError + else: + transformation_array = [] + for aas in areas_and_spacings: + + transformation_array += _build_equally_spaced_volume_list( + aas['spacing'], + aas['area_size'], + 1, + transform_ratio, + set_volume=aas['vol'] + ) + + steps_per_epoch = int(len(transformation_array) / batch_size) + # steps_per_epoch = len(transformation_array) + + # Shuffle array + if shuffle: + np.random.shuffle(transformation_array) + + # Pre-generate random values + print('Pre-generating random values...') + + if n_workers_noise == 1: + random_args = [] + for tidx, (position, volume, transform) in enumerate(transformation_array): + + random_args.append(_pre_generate_random_values(raw_channels, aug_dict, transform, volume, position, target_shape, noise_load_dict, noise_on_channels, tidx)) + else: + with ThreadPoolExecutor(max_workers=n_workers_noise) as tpe: + tasks = [ + tpe.submit(_pre_generate_random_values, + raw_channels, aug_dict, transform, volume, position, target_shape, noise_load_dict, + noise_on_channels, tidx) + for tidx, (position, volume, transform) in enumerate(transformation_array) + ] + random_args = [task.result() for task in tasks] + + # Reshape for batch size + random_args = np.reshape(random_args, (steps_per_epoch, batch_size)) + transformation_array = np.reshape(transformation_array, (steps_per_epoch, batch_size, 3)) + + # Generate data for epoch + if n_workers == 1: + + print('Fetching data with one worker...') + print(' ') + + results = [ + _get_batches_of_transformed_samples( + index_array, + raw_channels, + gt_channels, + target_shape, + gt_target_shape, + gt_target_channels=gt_target_channels, + yield_xyz=yield_xyz, + aug_dict=aug_dict, + aug_args=random_args[idx], + add_pad_mask=add_pad_mask, + batch_no=idx, + batches=len(transformation_array) + ) + for idx, index_array in enumerate(transformation_array) + ] + + print(' ') + + else: + + print('Fetching data with {} workers...'.format(n_workers)) + print(' ') + + # with Pool(processes=n_workers) as p: + with ThreadPoolExecutor(max_workers=n_workers) as p: + + print('Submitting tasks') + tasks = [ + p.submit( + _get_batches_of_transformed_samples, + index_array, + raw_channels, + gt_channels, + target_shape, + gt_target_shape, + gt_target_channels, + yield_xyz, + aug_dict, + random_args[idx], + add_pad_mask, + idx, + len(transformation_array) + ) + for idx, index_array in enumerate(transformation_array) + ] + results = [task.result() for task in tasks] + + print(' ') + + return results, steps_per_epoch + + +def parallel_data_generator( + raw_channels, + gt_channels, + spacing=None, + area_size=None, # Can now be a tuple of a shape for each input volume + areas_and_spacings=None, + target_shape=(64, 64, 64), + gt_target_shape=(64, 64, 64), + gt_target_channels=None, + stop_after_epoch=True, + aug_dict=default_aug_dict, + transform_ratio=0., + batch_size=2, + shuffle=False, + add_pad_mask=False, + noise_load_dict=None, + n_workers=1, + n_workers_noise=1, + noise_on_channels=None, + yield_epoch_info=False +): + """ + + :param raw_channels: + :param gt_channels: + :param spacing: + :param area_size: + :param areas_and_spacings: The new definition of areas and spacings, individually for each volume. Also one volume + can be defined multiple times to enable different area/spacing combinations. + + example: + + area_and_spacings=( + dict(vol=0, area_size=(256, 256, 256), spacing=(32, 32, 32)), + dict(vol=0, area_size=(512, 512, 512), spacing=(128, 128, 128)), + dict(vol=1, area_size=(512, 512, 512), spacing=(128, 128, 128)) + ) + + :param target_shape: + :param gt_target_shape: + :param gt_target_channels: + + example: + + gt_target_channels=( + dict(chanenl=0, downsample=2), + dict(channel=0, downsample=1), + dict(channel=1, downsample=1) + ) + + :param stop_after_epoch: + :param aug_dict: + :param transform_ratio: + :param batch_size: + :param shuffle: + :param add_pad_mask: + :param noise_load_dict: + :param n_workers: + :return: + """ + + assert ((spacing is not None and area_size is not None) + or areas_and_spacings is not None), 'The areas and spacings have to be specified either with the ' \ + 'parameters area_size and spacing, or with areas_and_spacings.' + if areas_and_spacings is not None: + assert spacing is None and area_size is None + + # Prepare the gt channels if necessary + if gt_target_channels is not None: + gt_channels = _prepare_channels(gt_channels, gt_target_channels, n_workers) + + # Start the generator + n = 0 + results, steps_per_epoch = _initialize( + raw_channels=raw_channels, + gt_channels=gt_channels, + spacing=spacing, + area_size=area_size, + areas_and_spacings=areas_and_spacings, + target_shape=target_shape, + gt_target_shape=gt_target_shape, + gt_target_channels=gt_target_channels, + aug_dict=aug_dict, + transform_ratio=transform_ratio, + batch_size=batch_size, + shuffle=shuffle, + add_pad_mask=add_pad_mask, + n_workers=n_workers, + noise_load_dict=noise_load_dict, + n_workers_noise=n_workers_noise, + noise_on_channels=noise_on_channels + ) + + epoch = 0 + + while True: + + if n == 0 and not stop_after_epoch: + print('Submitting new job') + + p = ThreadPool(processes=1) + res = p.apply_async(_initialize, ( + raw_channels, + gt_channels, + spacing, + area_size, + areas_and_spacings, + target_shape, + gt_target_shape, + gt_target_channels, + aug_dict, + transform_ratio, + batch_size, + shuffle, + add_pad_mask, + n_workers - 1, + epoch, + None, + noise_load_dict, + False, + n_workers_noise, + noise_on_channels + )) + + last_of_epoch = False + if n == steps_per_epoch - 1: + last_of_epoch = True + if stop_after_epoch and n == steps_per_epoch: + break + if n == steps_per_epoch: + + n = 0 + print('Fetching results') + results, _ = res.get() + print('Joining job') + epoch += 1 + + else: + + # Convert to float + batch_x = results[n][0] + batch_y = results[n][1] + batch_x = batch_x.astype('float32') / 255 + + if type(batch_y) is list: + # Let's assume this only happens if there are the three states (-1, 0, 1) + # now, the first item is true if y = -1, 1 and the second if y = -1, 0 + # That means [0] and [1] = -1 + # [0] and not [1] = 1 + # not [0] and [1] = 0 + t_batch_y = np.zeros(batch_y[0].shape, dtype='float32') + t_batch_y[np.logical_and(batch_y[0], batch_y[1])] = -1 + t_batch_y[np.logical_and(batch_y[0], np.logical_not(batch_y[1]))] = 1 + # t_batch_y[not batch_y[0] and batch_y[1]] = 0 <- not needed, right? + batch_y = t_batch_y + else: + batch_y = batch_y.astype('float32') + + if yield_epoch_info: + yield batch_x, batch_y, epoch, n, last_of_epoch + else: + yield batch_x, batch_y + n += 1 + + +def parallel_test_data_generator( + raw_channels, + spacing=(32, 32, 32), + area_size=(256, 256, 256), + target_shape=(64, 64, 64), + smooth_output_sigma=0.5, + n_workers=1): + + # Start the generator + n = 0 + results, steps_per_epoch = _initialize( + raw_channels=raw_channels, + gt_channels=None, + spacing=spacing, + area_size=area_size, + target_shape=target_shape, + gt_target_shape=None, + aug_dict=dict( + smooth_output_sigma=smooth_output_sigma + ), + transform_ratio=0, + batch_size=1, + shuffle=False, + add_pad_mask=False, + n_workers=n_workers, + noise_load_dict=None, + yield_xyz=True + ) + + epoch = 0 + + while True: + + if n == steps_per_epoch: + break + + else: + + # Convert to float + batch_x = results[n][0] + batch_x = batch_x.astype('float32') / 255 + xyz = results[n][1] + + yield batch_x, xyz + n += 1 diff --git a/pytorch_tools/losses.py b/pytorch_tools/losses.py new file mode 100644 index 0000000..334f349 --- /dev/null +++ b/pytorch_tools/losses.py @@ -0,0 +1,115 @@ + +import torch as t +import torch.nn as nn + + +class WeightMatrixMSELoss(nn.Module): + + def __init__(self): + super(WeightMatrixMSELoss, self).__init__() + + def forward(self, y_pred, y_true): + + num_channels = y_pred.shape[1] + + w = y_true[:, -1, :][:, None, :] + + loss = 0. + for c in range(num_channels): + loss += w * (y_pred[:, c, :] - y_true[:, c, :]) ** 2 + + return t.mean(loss) + + +class WeightMatrixWeightedBCE(nn.Module): + + def __init__(self, class_weights, weigh_with_matrix_sum=False): + super(WeightMatrixWeightedBCE, self).__init__() + + self.class_weights = class_weights + self.weigh_with_matrix_sum = weigh_with_matrix_sum + + def forward(self, y_pred, y_true): + + cw = self.class_weights + + num_channels = y_pred.shape[1] + assert len(cw) == num_channels, 'Class weight sets and number of channels have to match!' + + _epsilon = 1e-7 + y_pred = t.clamp(y_pred, _epsilon, 1 - _epsilon) + + w = y_true[:, -1, :][:, None, :] + + loss = 0. + if not self.weigh_with_matrix_sum: + for c in range(num_channels): + loss += w * -(cw[c][1] * y_true[:, c, :] * t.log(y_pred[:, c, :]) + cw[c][0] * (1.0 - y_true[:, c, :]) * t.log(- y_pred[:, c, :] + 1.0)) + else: + for c in range(num_channels): + loss += t.sum(w) / w.nelement() * w * -(cw[c][1] * y_true[:, c, :] * t.log(y_pred[:, c, :]) + cw[c][0] * (1.0 - y_true[:, c, :]) * t.log(- y_pred[:, c, :] + 1.0)) + + return t.mean(loss) + + +class CombinedLosses(nn.Module): + + def __init__(self, losses, y_pred_channels, y_true_channels, weigh_losses=None): + super(CombinedLosses, self).__init__() + + self.losses = losses + self.y_pred_channels = y_pred_channels + self.y_true_channels = y_true_channels + if weigh_losses is None: + weigh_losses = (1,) * len(y_pred_channels) + self.weigh_losses = weigh_losses + + def forward(self, y_pred, y_true): + + loss = 0. + + for idx in range(len(self.y_pred_channels)): + + ypch = self.y_pred_channels[idx] + ytch = self.y_true_channels[idx] + + if type(ypch) is tuple: + raise NotImplementedError + # # TODO: This is from the keras version and has to be translated + # yp = [] + # for slidx, sl in enumerate(ypch): + # if type(ytch[slidx]) == tuple: + # for xidx in range(len(ytch[slidx])): + # yp.append(y_pred[..., sl][..., None]) + # else: + # yp.append(y_pred[..., sl][..., None]) + # yp = t.cat(yp, dim=1) + elif type(ypch) is slice: + yp = y_pred[:, ypch, :] + else: + raise NotImplementedError + + if type(ytch) is tuple: + yt = [] + for sl in ytch: + if type(sl) == int: + yt.append(y_true[:, sl, :][:, None, :]) + elif type(sl) == tuple: + for tsl in sl: + yt.append(y_true[:, sl, :][:, None, :]) + else: + raise ValueError + yt = t.cat(yt, dim=1) + elif type(ytch) is slice: + yt = y_true[:, ytch, :] + else: + raise NotImplementedError + + loss += self.weigh_losses[idx] * self.losses[idx](yp, yt) + + return loss / len(self.y_pred_channels) + + +if __name__ == '__main__': + + pass diff --git a/pytorch_tools/modules.py b/pytorch_tools/modules.py new file mode 100644 index 0000000..a84a527 --- /dev/null +++ b/pytorch_tools/modules.py @@ -0,0 +1,350 @@ + +import torch.nn as nn +from torch import cat +from torch import optim +from torchsummary import summary + +ConvND = {2: nn.Conv2d, 3: nn.Conv3d} +ConvNDTranspose = {2: nn.ConvTranspose2d, 3: nn.ConvTranspose3d} +MaxPoolingND = {2: nn.MaxPool2d, 3: nn.MaxPool3d} +BatchNormND = {2: nn.BatchNorm2d, 3: nn.BatchNorm3d} + + +class Downsampling(nn.Module): + + def __init__( + self, + num_convs, + in_channels, # The number of input channels for the first conv layer + out_channels, # int or List of output channels for each conv layer + kernel_size, + level, + conv_strides=1, + padding=1, + activation='relu', + batch_norm=False, + ndims=2 + ): + super(Downsampling, self).__init__() + + # Parameters + self.num_convs = num_convs + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.level = level + self.conv_strides = conv_strides + self.padding = padding + self.activation = activation + self.batch_norm = batch_norm + self.ndims = ndims + + # _________________________ + # Layer initializations + + self.convolutions = nn.Sequential() + in_ch = in_channels + + for idx in range(num_convs): + if type(out_channels) is tuple: + out_ch = out_channels[idx] + else: + out_ch = out_channels + self.convolutions.add_module('conv{}'.format(idx), ConvND[ndims]( + in_channels=in_ch, + out_channels=out_ch, + kernel_size=kernel_size, + stride=conv_strides, + padding=padding + )) + + if batch_norm: + self.convolutions.add_module('bn{}'.format(idx), BatchNormND[ndims](out_ch)) + + if activation == 'relu': + self.convolutions.add_module('relu{}'.format(idx), nn.ReLU(inplace=True)) + else: + raise NotImplementedError + + in_ch = out_ch + + self.max_pool = MaxPoolingND[ndims](2) + + def forward(self, x): + + x = self.convolutions(x) + skip = x + x = self.max_pool(x) + return x, skip + + +class Upsampling(nn.Module): + + def __init__( + self, + num_convs, + in_channels, # The number of input channels for the first conv layer + out_channels, # int or List of output channels for each conv layer + skip_channels, + kernel_size, + level, + upsampling_size=2, + upsampling_strides=2, + conv_strides=1, + padding=1, + activation='relu', + batch_norm=False, + ndims=2 + ): + super(Upsampling, self).__init__() + + # Parameters + self.num_convs = num_convs + self.in_channels = in_channels + self.out_channels = out_channels + self.skip_channels = skip_channels + self.kernel_size = kernel_size + self.level = level + self.upsampling_size = upsampling_size + self.upsampling_strides = upsampling_strides + self.conv_strides = conv_strides + self.padding = padding + self.activation = activation + self.batch_norm = batch_norm + self.ndims = ndims + + # _________________________ + # Layer initializations + + self.conv_transpose = ConvNDTranspose[ndims]( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=upsampling_size, + stride=upsampling_strides, + padding=0 + ) + + self.convolutions = nn.Sequential() + + in_ch = in_channels + skip_channels + + for idx in range(num_convs): + if type(out_channels) is tuple: + out_ch = out_channels[idx] + else: + out_ch = out_channels + self.convolutions.add_module('conv{}'.format(idx), ConvND[ndims]( + in_channels=in_ch, + out_channels=out_ch, + kernel_size=kernel_size, + stride=conv_strides, + padding=padding + )) + + if batch_norm: + self.convolutions.add_module('bn{}'.format(idx), BatchNormND[ndims](out_ch)) + + if activation == 'relu': + self.convolutions.add_module('relu{}'.format(idx), nn.ReLU(inplace=True)) + else: + raise NotImplementedError + + in_ch = out_ch + + def forward(self, x, skip): + + x = self.conv_transpose(x) + x = cat((skip, x), dim=1) + return self.convolutions(x) + + +class Bottleneck(nn.Module): + + def __init__( + self, + num_convs, + in_channels, + out_channels, + kernel_size, + conv_strides=1, + padding=1, + activation='relu', + batch_norm=False, + ndims=2 + ): + super(Bottleneck, self).__init__() + + # Parameters + self.num_convs = num_convs + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.conv_strides = conv_strides + self.padding = padding + self.activation = activation + self.batch_norm = batch_norm + self.ndims = ndims + + # _________________________ + # Layer initializations + + self.layers = nn.Sequential() + in_ch = in_channels + + for idx in range(num_convs): + if type(out_channels) is tuple: + out_ch = out_channels[idx] + else: + out_ch = out_channels + self.layers.add_module('conv{}'.format(idx), ConvND[ndims]( + in_channels=in_ch, + out_channels=out_ch, + kernel_size=kernel_size, + stride=conv_strides, + padding=padding + )) + + if batch_norm: + self.layers.add_module('bn{}'.format(idx), BatchNormND[ndims](out_ch)) + + if activation == 'relu': + self.layers.add_module('relu{}'.format(idx), nn.ReLU(inplace=True)) + else: + raise NotImplementedError + + in_ch = out_ch + + def forward(self, x): + return self.layers(x) + + +class Unet(nn.Module): + + def __init__( + self, + num_classes, + in_channels, + filter_sizes_down=((16, 32), (32, 64), (64, 128)), + filter_sizes_up=((128, 128), (64, 64), (32, 32)), + filter_sizes_bottleneck=(128, 256), + kernel_size=3, + batch_norm=True, + ndims=2, + return_last_upsampling=False, + output_activation='sigmoid' + ): + super(Unet, self).__init__() + + # Parameters + self.num_classes = num_classes + self.filter_sizes_down = filter_sizes_down + self.filter_sizes_up = filter_sizes_up + self.filter_sizes_bottleneck = filter_sizes_bottleneck + self.kernel_size = kernel_size + self.batch_norm = batch_norm + self.ndims = ndims + self.return_last_upsampling = return_last_upsampling + self.output_activation = output_activation + + # _________________________________ + # Network layer initialization + + in_ch = in_channels + self.downs = nn.ModuleList() + for down_level in range(len(filter_sizes_down)): + self.downs.append(Downsampling( + num_convs=len(filter_sizes_down[down_level]), + in_channels=in_ch, + out_channels=filter_sizes_down[down_level], + kernel_size=kernel_size, + level=down_level, + conv_strides=1, + padding=1, + activation='relu', + batch_norm=batch_norm, + ndims=ndims + )) + in_ch = filter_sizes_down[down_level][-1] + + self.bottleneck = Bottleneck( + num_convs=len(filter_sizes_bottleneck), + in_channels=in_ch, + out_channels=filter_sizes_bottleneck, + kernel_size=kernel_size, + conv_strides=1, + padding=1, + activation='relu', + batch_norm=batch_norm, + ndims=ndims + ) + in_ch = filter_sizes_bottleneck[-1] + + self.ups = nn.ModuleList() + for up_level in range(len(filter_sizes_up)): + self.ups.append(Upsampling( + num_convs=len(filter_sizes_up[up_level]), + in_channels=in_ch, + out_channels=filter_sizes_up[up_level], + skip_channels=filter_sizes_down[len(filter_sizes_up) - up_level - 1][-1], + kernel_size=kernel_size, + level=up_level, + upsampling_strides=2, + upsampling_size=2, + conv_strides=1, + padding=1, + activation='relu', + batch_norm=batch_norm, + ndims=ndims + )) + in_ch = filter_sizes_up[up_level][-1] + + self.output_conv = ConvND[ndims]( + in_channels=in_ch, + out_channels=num_classes, + kernel_size=kernel_size, + stride=1, + padding=1 + ) + if self.output_activation == 'sigmoid': + self.output_act = nn.Sigmoid() + elif self.output_activation == 'softmax': + self.output_act = nn.Softmax() + + def forward(self, x): + + # Auto-encoder + skips = [] + for down_level in range(len(self.filter_sizes_down)): + + x, skip = self.downs[down_level](x) + skips.append(skip) + + # Bottleneck + x = self.bottleneck(x) + + # Auto-decoder + for up_level in range(len(self.filter_sizes_up)): + + x = self.ups[up_level](x, skips[len(self.filter_sizes_up) - up_level - 1]) + + last_upsampling = x + + # Output layer + x = self.output_conv(x) + x = self.output_act(x) + + if not self.return_last_upsampling: + return x + else: + return x, last_upsampling + + +if __name__ == '__main__': + + unet = Unet(1, 1, ndims=3) + + # optimizer = optim.Adam(unet.parameters(), 0.003) + + unet.cuda() + summary(unet, (1, 64, 64, 64)) + diff --git a/pytorch_tools/piled_unets.py b/pytorch_tools/piled_unets.py new file mode 100644 index 0000000..a03e963 --- /dev/null +++ b/pytorch_tools/piled_unets.py @@ -0,0 +1,239 @@ + +import numpy as np + +# from .modules import nn +import torch.nn as nn +from torch import cat +from pytorch.pytorch_tools.modules import Unet +from torchsummary import summary + + +class cNnet(nn.Module): + + def __init__( + self, + n_nets=3, + # default scale ratio: 2 + # + # EXAMPLE with three modules and the last with input shape = (64, 64, 64) + # First net input (256, 256, 256) -> downsample -> (64, 64, 64) + # Second net input (128, 128, 128) -> downsample -> (64, 64, 64) + # Third net input (64, 64, 64) + # + # i.e. scale_ratio=2, n_nets=3: 256 -> 128 -> 64 + # scale_ratio=4, n_nets=2: 256 -> 64 + scale_ratio=2, + module_shape=(64, 64, 64), + input_shapes=None, + initial_ds=None, + crop_and_ds_inputs=False, + crop_and_us_outputs=True, + in_channels=1, + out_channels=None, + num_inputs=1, + filter_sizes_down=( + ((4, 8), (8, 16), (16, 32)), + ((8, 16), (16, 32), (32, 64)), + ((32, 64), (64, 128), (128, 256)) + ), + filter_sizes_bottleneck=( + (32, 64), + (64, 128), + (256, 512) + ), + filter_sizes_up=( + ((32, 32), (16, 16), (8, 8)), + ((64, 64), (32, 32), (16, 16)), + ((256, 256), (128, 128), (64, 64)) + ), + batch_norm=True, + output_activation='softmax', + verbose=False + ): + + super(cNnet, self).__init__() + + # ______________________________ + # Parameters and settings + + # Assertions + assert out_channels is not None, 'The number of output classes for each module needs to be specified!' + assert in_channels is not None, 'The number of input channels needs to be specified!' + assert len(filter_sizes_down) == n_nets + assert len(filter_sizes_up) == n_nets + assert len(filter_sizes_bottleneck) == n_nets + + # ______________________________ + # Define layers + + self.avg_pool_inputs = nn.ModuleDict() + self.unets = nn.ModuleList() + + for net_idx in range(n_nets): + + if crop_and_ds_inputs: + + if 0 < net_idx < n_nets - 1: + + # First, the tensor will be cropped + # and then downsampled + self.avg_pool_inputs[net_idx] = nn.AvgPool3d( + kernel_size=(scale_ratio ** (n_nets - net_idx - 1),) * 3 + ) + + self.unets.append( + Unet( + num_classes=out_channels[net_idx], + in_channels=in_channels + out_channels[net_idx], + filter_sizes_down=filter_sizes_down[net_idx], + filter_sizes_up=filter_sizes_up[net_idx], + filter_sizes_bottleneck=filter_sizes_bottleneck, + kernel_size=3, + batch_norm=batch_norm, + ndims=3, + return_last_upsampling=True, + output_activation=output_activation + ) + ) + + def forward(self, x): + + # TODO: Define the architecture + raise NotImplementedError + + output = x + return output + + +class PiledUnet(nn.Module): + """ + Special case of the cNnet where all nets have the same input scale. + To achieve the same architecture using cNnet, set scale_ratio=1 + """ + + def __init__( + self, + n_nets=3, + in_channels=1, + out_channels=None, + filter_sizes_down=( + ((4, 8), (8, 16), (16, 32)), + ((8, 16), (16, 32), (32, 64)), + ((32, 64), (64, 128), (128, 256)) + ), + filter_sizes_bottleneck=( + (32, 64), + (64, 128), + (256, 512) + ), + filter_sizes_up=( + ((32, 32), (16, 16), (8, 8)), + ((64, 64), (32, 32), (16, 16)), + ((256, 256), (128, 128), (64, 64)) + ), + batch_norm=True, + output_activation='softmax', + predict=False + ): + + super(PiledUnet, self).__init__() + + # ______________________________ + # Parameters and settings + + # Assertions + assert out_channels is not None, 'The number of output classes for each module needs to be specified!' + assert in_channels is not None, 'The number of input channels needs to be specified!' + assert len(filter_sizes_down) == n_nets + assert len(filter_sizes_up) == n_nets + assert len(filter_sizes_bottleneck) == n_nets + + self.n_nets = n_nets + self.predict = predict + + # ______________________________ + # Define layers + + self.unets = nn.ModuleList() + + for net_idx in range(n_nets): + + if net_idx == 0: + in_ch = in_channels + else: + in_ch = in_channels + filter_sizes_up[net_idx][-1][-1] + + self.unets.append( + Unet( + num_classes=out_channels[net_idx], + in_channels=in_ch, + filter_sizes_down=filter_sizes_down[net_idx], + filter_sizes_up=filter_sizes_up[net_idx], + filter_sizes_bottleneck=filter_sizes_bottleneck[net_idx], + kernel_size=3, + batch_norm=batch_norm, + ndims=3, + return_last_upsampling=True, + output_activation=output_activation + ) + ) + + def forward(self, in_x): + + # Define the architecture + outputs = None + out_x = None + + x = in_x + + for net_idx in range(self.n_nets): + + if net_idx > 0: + x = cat((in_x, x), dim=1) + + out_x, x = self.unets[net_idx](x) + + if not self.predict: + if net_idx == 0: + outputs = out_x + else: + outputs = cat((outputs, out_x), dim=1) + + if self.predict: + assert out_x is not None + outputs = out_x + + return outputs + + +if __name__ == '__main__': + + piled_unet = PiledUnet( + n_nets=3, + in_channels=1, + out_channels=[1, 1, 1], + filter_sizes_down=( + ((16, 32), (32, 64), (64, 128)), + ((16, 32), (32, 64), (64, 128)), + ((16, 32), (32, 64), (64, 128)) + ), + filter_sizes_bottleneck=( + (128, 256), + (128, 256), + (128, 256) + ), + filter_sizes_up=( + ((256, 128, 128), (128, 64, 64), (64, 32, 32)), + ((256, 128, 128), (128, 64, 64), (64, 32, 32)), + ((256, 128, 128), (128, 64, 64), (64, 32, 32)) + ), + batch_norm=True, + output_activation='sigmoid' + ) + + # # optimizer = optim.Adam(unet.parameters(), 0.003) + # + # piled_unet.cuda() + # summary(piled_unet, (1, 64, 64, 64)) + # + # pass diff --git a/pytorch_tools/run_models.py b/pytorch_tools/run_models.py new file mode 100644 index 0000000..ed4958b --- /dev/null +++ b/pytorch_tools/run_models.py @@ -0,0 +1,133 @@ + +import h5py +import os +import numpy as np +import sys +import torch as t +from pytorch.pytorch_tools.data_generation import parallel_test_data_generator + + +def predict_model_from_h5_parallel_generator( + model, + results_filepath, + raw_channels, + spacing, + area_size, + target_shape, + num_result_channels, + smooth_output_sigma=0.5, + n_workers=16, + compute_empty_volumes=True, + thresh=0, + write_at_area=False, + offset=None, + full_dataset_shape=None +): + + model.eval() + + print('offset = {}'.format(offset)) + print('write_at_area = {}'.format(write_at_area)) + print('full_dataset_shape = {}'.format(full_dataset_shape)) + + def _write_result(dataset, result, position, spacing): + + spacing = np.array(spacing) + spacing_half = (spacing / 2).astype(int) + shape = np.array(dataset.shape[:3]) + shape_half = (shape / 2).astype(int) + result_shape = np.array(result.shape[:3]) + result_shape_half = (result_shape / 2).astype(int) + + # Pre-crop the result + start_crop = result_shape_half - spacing_half + stop_crop = result_shape_half + spacing_half + s_pre_crop = np.s_[ + start_crop[0]: stop_crop[0], + start_crop[1]: stop_crop[1], + start_crop[2]: stop_crop[2] + ] + result_cropped = result[s_pre_crop] + + # All the shapes and positions + result_shape = np.array(result_cropped.shape[:3]) + result_shape_half = (result_shape / 2).astype(int) + position = np.array(position) + + start_pos = position + shape_half - result_shape_half + stop_pos = start_pos + spacing + # print('') + # print('Before correction ...') + # print('start_pos = {}'.format(start_pos)) + # print('stop_pos = {}'.format(stop_pos)) + start_out_of_bounds = np.zeros(start_pos.shape, dtype=start_pos.dtype) + start_out_of_bounds[start_pos < 0] = start_pos[start_pos < 0] + stop_out_of_bounds = stop_pos - shape + stop_out_of_bounds[stop_out_of_bounds < 0] = 0 + start_pos[start_pos < 0] = 0 + stop_pos[stop_out_of_bounds > 0] = shape[stop_out_of_bounds > 0] + # print('After correction ...') + # print('start_pos = {}'.format(start_pos)) + # print('stop_pos = {}'.format(stop_pos)) + + # For the results volume + s_source = np.s_[ + -start_out_of_bounds[0]:stop_pos[0] - start_pos[0] - start_out_of_bounds[0], + -start_out_of_bounds[1]:stop_pos[1] - start_pos[1] - start_out_of_bounds[1], + -start_out_of_bounds[2]:stop_pos[2] - start_pos[2] - start_out_of_bounds[2], + : + ] + # For the target dataset + s_target = np.s_[ + start_pos[0]:stop_pos[0], + start_pos[1]:stop_pos[1], + start_pos[2]:stop_pos[2], + : + ] + + dataset[s_target] = (result_cropped * 255).astype('uint8')[s_source] + + if offset is None: + offset = (0, 0, 0) + + # Generate results file + if not write_at_area: + with h5py.File(results_filepath, 'w') as f: + f.create_dataset('data', shape=tuple(area_size) + (num_result_channels,), dtype='uint8', compression='gzip', chunks=(32, 32, 32, 1)) + else: + if not os.path.exists(results_filepath): + with h5py.File(results_filepath, 'w') as f: + f.create_dataset('data', shape=tuple(full_dataset_shape) + (num_result_channels,), dtype='uint8', compression='gzip') + + for idx, element in enumerate(parallel_test_data_generator( + raw_channels=raw_channels, + spacing=spacing, + area_size=area_size, + target_shape=target_shape, + smooth_output_sigma=smooth_output_sigma, + n_workers=n_workers + )): + im = element[0] + xyz = element[1][0] + np.array(offset) + + # xyz = np.array(xyz) + (np.array(source_size) / 2).astype(int) - (np.array(spacing) / 2).astype(int) + x = xyz[2] + y = xyz[1] + z = xyz[0] + + sys.stdout.write('\r' + 'x = {}; y = {}, z = {}'.format(x, y, z)) + + if compute_empty_volumes or (im < thresh).sum(): + + imx = t.tensor(np.moveaxis(im, 4, 1), dtype=t.float32).cuda() + result = model(imx) + result = np.moveaxis(result.cpu().numpy(), 1, 4) + # overlap = np.array(result.shape[1:4]) - np.array(spacing) + # + with h5py.File(results_filepath, 'a') as f: + # write_test_h5_generator_result(f['data'], result, x, y, z, overlap, ndim=ndim) + _write_result(f['data'], result[0, :], xyz, spacing) + + else: + + print(' skipped...') diff --git a/pytorch_tools/training.py b/pytorch_tools/training.py new file mode 100644 index 0000000..5cc9725 --- /dev/null +++ b/pytorch_tools/training.py @@ -0,0 +1,399 @@ + +from torchsummary import summary +import torch as t +import numpy as np +from torch.utils.tensorboard import SummaryWriter +import os + +from pytorch.pytorch_tools.run_models import predict_model_from_h5_parallel_generator + +from matplotlib import pyplot as plt + + +def cb_save_model( + filepath, + min_epoch=0 +): + + def run(model, epoch): + if epoch >= min_epoch: + print('Saving model ...') + t.save(model.state_dict(), filepath.format(epoch=epoch)) + else: + print('Not saving model: Minimum number of epochs not reached.') + + return run + + +def cb_run_model_on_data( + results_filepath, + raw_channels, + spacing, + area_size, + target_shape, + num_result_channels, + smooth_output_sigma, + n_workers, + compute_empty_volumes, + thresh, + write_at_area, + offset, + full_dataset_shape, + min_epoch=0 +): + + def run(model, epoch): + if epoch >= min_epoch: + print('Running model on data ...') + predict_model_from_h5_parallel_generator( + model=model, + results_filepath=results_filepath.format(epoch=epoch), + raw_channels=raw_channels, + spacing=spacing, + area_size=area_size, + target_shape=target_shape, + num_result_channels=num_result_channels, + smooth_output_sigma=smooth_output_sigma, + n_workers=n_workers, + compute_empty_volumes=compute_empty_volumes, + thresh=thresh, + write_at_area=write_at_area, + offset=offset, + full_dataset_shape=full_dataset_shape + ) + else: + print('Not running model on data: Minimum number of epochs not reached.') + + return run + + +def train_model_with_generators( + model, + train_generator, + val_generator, + n_epochs, + loss_func, + optimizer=None, + l2_reg_param=1e-3, + callbacks=None, + writer_path=None +): + + if writer_path is not None: + if not os.path.exists(writer_path): + os.mkdir(writer_path) + writer_train = SummaryWriter(os.path.join(writer_path, 'train')) + writer_val = SummaryWriter(os.path.join(writer_path, 'val')) + else: + writer_train = None + writer_val = None + + def _on_epoch_end(model, best_val_loss): + print('------------------------------------------') + print('Epoch ended, evaluating model ...') + + # Evaluation at the end of an epoch + with t.no_grad(): + model.eval() + + sum_loss = 0. + eval_not_done = True + it_idx = 0. + + while eval_not_done: + + valx, valy, val_epoch, valn, loe = next(val_generator) + + # Forward pass + tensx = t.tensor(np.moveaxis(valx, 4, 1), dtype=t.float32).cuda() + predy = model(tensx) + + # Compute loss + tensy = t.tensor(np.moveaxis(valy, 4, 1), dtype=t.float32).cuda() + loss = loss_func(predy, tensy) + sum_loss += loss + + it_idx += 1 + + if loe: + break + + val_loss = sum_loss / (it_idx + 1) + + if writer_val: + writer_val.add_scalar('losses/loss', val_loss, epoch) + + if best_val_loss is None or val_loss < best_val_loss: + best_val_loss = val_loss + improvement = True + else: + improvement = False + + print('Epoch: {} | Train loss: {} | Val loss: {} | Best val loss: {}'.format( + epoch, train_loss, val_loss, best_val_loss) + ) + + print('------------------------------------------') + + return best_val_loss, improvement + + if optimizer is None: + # optimizer = t.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), + # lr=1e-4, betas=(0.9, 0.999), weight_decay=l2_reg_param) + optimizer = t.optim.Adam(model.parameters(), + lr=1e-4, betas=(0.9, 0.999), weight_decay=l2_reg_param) + + print(optimizer) + + best_val_loss = None + print('__________________________________________') + print('Epoch = 0') + sum_loss = 0. + + # The train loop + model.train() + for it_idx, (x, y, epoch, n, loe) in enumerate(train_generator): + + optimizer.zero_grad() + + # Forward pass + tensx = t.tensor(np.moveaxis(x, 4, 1), dtype=t.float32).cuda() + predy = model(tensx) + + # Compute loss + tensy = t.tensor(np.moveaxis(y, 4, 1), dtype=t.float32).cuda() + loss = loss_func(predy, tensy) + + # Back propagation + loss.backward() + + optimizer.step() + sum_loss += loss.item() + + train_loss = sum_loss / (n + 1) + print('Iteration = {} | Loss = {} | Average loss = {}'.format(n + 1, loss.item(), train_loss)) + + if loe: + + if writer_train: + writer_train.add_scalar('losses/loss', train_loss, epoch) + + # Evaluate previous epoch + best_val_loss, improvement = _on_epoch_end(model, best_val_loss) + + # Callbacks when the model improved + if callbacks is not None and improvement: + with t.no_grad(): + model.eval() + print('Validation loss improved! Computing callbacks ...') + with t.no_grad(): + for callback in callbacks: + callback(model, epoch) + + # Break if specified number of epochs is reached + if epoch + 1 == n_epochs: + break + + # Initialize new epoch + print('__________________________________________') + print('Epoch = {}'.format(epoch + 1)) + + sum_loss = 0. + model.train() + + +if __name__ == '__main__': + + from pytorch.pytorch_tools.data_generation import parallel_data_generator + import os + from h5py import File + from pytorch.pytorch_tools.piled_unets import PiledUnet + from pytorch.pytorch_tools.losses import WeightMatrixWeightedBCE, CombinedLosses + + if True: + + raw_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/' + raw_filepaths = [ + [ + os.path.join(raw_path, 'raw.h5'), + ], + ] + gt_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/' + gt_filepaths = [ + [ + os.path.join(gt_path, 'gt_mem.h5'), + os.path.join(gt_path, 'gt_mask_organelle_insides.h5') + ], + ] + raw_channels = [] + for volumes in raw_filepaths: + raws_data = [] + for chid, channel in enumerate(volumes): + if chid == 1: + # Specifically only load last channel of the membrane prediction + raws_data.append(File(channel, 'r')['data'][..., -1]) + else: + raws_data.append(File(channel, 'r')['data'][:]) + raw_channels.append(raws_data) + gt_channels = [[File(channel, 'r')['data'][:] for channel in volumes] for volumes in gt_filepaths] + + val_raw_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/' + val_raw_filepaths = [ + [ + os.path.join(val_raw_path, 'val_raw_512.h5'), + ], + [ + os.path.join( + '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/psp_200107_00_ds_20141002_hela_wt_xyz8nm_as_multiple_scales/step0_datasets/psp0_200108_02_select_test_and_val_cubes', + 'val0_x1390_y742_z345_pad.h5' + + ) + ] + ] + val_gt_path = '/g/schwab/hennies/phd_project/image_analysis/psp_full_experiments/boundary_raw_and_gt/' + val_gt_filepaths = [ + [ + os.path.join(val_gt_path, 'val_gt_mem.h5'), + os.path.join(val_gt_path, 'val_gt_mask_organelle_insides.h5') + ] + ] + val_raw_channels = [] + for volumes in val_raw_filepaths: + val_raws_data = [] + for chid, channel in enumerate(volumes): + if chid == 1: + # Specifically only load last channel of the membrane prediction + val_raws_data.append(File(channel, 'r')['data'][..., -1]) + else: + val_raws_data.append(File(channel, 'r')['data'][:]) + val_raw_channels.append(val_raws_data) + val_gt_channels = [[File(channel, 'r')['data'][:] for channel in volumes] for volumes in val_gt_filepaths] + + if True: + + data_gen_args = dict( + rotation_range=180, # Angle in degrees + shear_range=20, # Angle in degrees + zoom_range=[0.8, 1.2], # [0.75, 1.5] + horizontal_flip=True, + vertical_flip=True, + noise_var_range=1e-1, + random_smooth_range=[0.6, 1.5], + smooth_output_sigma=0.5, + displace_slices_range=2, + fill_mode='reflect', + cval=0, + brightness_range=92, + contrast_range=(0.5, 2) + ) + + aug_dict_preprocessing = dict( + smooth_output_sigma=0.5 + ) + + train_gen = parallel_data_generator( + raw_channels=raw_channels, + gt_channels=gt_channels, + spacing=(32, 32, 32), + area_size=(32, 128, 128), # (32, 512, 512), + target_shape=(64, 64, 64), + gt_target_shape=(64, 64, 64), + stop_after_epoch=False, + aug_dict=data_gen_args, + transform_ratio=0.9, + batch_size=2, + shuffle=True, + add_pad_mask=False, + n_workers=16, + noise_load_dict=None, + gt_target_channels=None, + areas_and_spacings=None, + n_workers_noise=16, + noise_on_channels=None, + yield_epoch_info=True + ) + + val_gen = parallel_data_generator( + raw_channels=val_raw_channels[:1], + gt_channels=val_gt_channels, + spacing=(64, 64, 64), + area_size=(256, 256, 256), + target_shape=(64, 64, 64), + gt_target_shape=(64, 64, 64), + stop_after_epoch=False, + aug_dict=aug_dict_preprocessing, + transform_ratio=0., + batch_size=1, + shuffle=False, + add_pad_mask=False, + n_workers=16, + gt_target_channels=None, + yield_epoch_info=True + ) + + model = PiledUnet( + n_nets=3, + in_channels=1, + out_channels=[1, 1, 1], + filter_sizes_down=( + ((16, 32), (32, 64), (64, 128)), + ((16, 32), (32, 64), (64, 128)), + ((16, 32), (32, 64), (64, 128)) + ), + filter_sizes_bottleneck=( + (128, 256), + (128, 256), + (128, 256) + ), + filter_sizes_up=( + ((256, 128, 128), (128, 64, 64), (64, 32, 32)), + ((256, 128, 128), (128, 64, 64), (64, 32, 32)), + ((256, 128, 128), (128, 64, 64), (64, 32, 32)) + ), + batch_norm=True, + output_activation='sigmoid' + ) + model.cuda() + summary(model, (1, 64, 64, 64)) + + results_folder = '/g/schwab/hennies/tmp/pytorch_test2' + if not os.path.exists(results_folder): + os.mkdir(results_folder) + + train_model_with_generators( + model, + train_gen, + val_gen, + n_epochs=100, + loss_func=CombinedLosses( + losses=( + WeightMatrixWeightedBCE(((0.1, 0.9),)), + WeightMatrixWeightedBCE(((0.2, 0.8),)), + WeightMatrixWeightedBCE(((0.3, 0.7),))), + y_pred_channels=(np.s_[:1], np.s_[1:2], np.s_[2:3]), + y_true_channels=(np.s_[:], np.s_[:], np.s_[:]), + weigh_losses=np.array([0.2, 0.3, 0.5]) + ), + l2_reg_param=1e-5, + callbacks=[ + cb_run_model_on_data( + results_filepath=os.path.join(results_folder, 'improved_result2_{epoch:04d}.h5'), + raw_channels=val_raw_channels[:1], + spacing=(32, 32, 32), + area_size=(64, 256, 256), + target_shape=(64, 64, 64), + num_result_channels=3, + smooth_output_sigma=aug_dict_preprocessing['smooth_output_sigma'], + n_workers=16, + compute_empty_volumes=True, + thresh=None, + write_at_area=False, + offset=None, + full_dataset_shape=None + ), + cb_save_model( + filepath=os.path.join(results_folder, 'model_{epoch:04d}.h5') + ) + ], + writer_path=os.path.join(results_folder, 'run') + )