diff --git a/models/wholeBody_ct_segmentation/configs/inference.json b/models/wholeBody_ct_segmentation/configs/inference.json index 312ef5a1..afbb7050 100644 --- a/models/wholeBody_ct_segmentation/configs/inference.json +++ b/models/wholeBody_ct_segmentation/configs/inference.json @@ -96,10 +96,10 @@ 96 ], "sw_batch_size": 1, - "overlap": 0.4, + "overlap": 0.3, "padding_mode": "replicate", "mode": "gaussian", - "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')" + "device": "$torch.device('cpu')" }, "postprocessing": { "_target_": "Compose", diff --git a/models/wholeBody_ct_segmentation/configs/train.json b/models/wholeBody_ct_segmentation/configs/train.json index ede6ff0d..9011492e 100755 --- a/models/wholeBody_ct_segmentation/configs/train.json +++ b/models/wholeBody_ct_segmentation/configs/train.json @@ -7,10 +7,10 @@ "bundle_root": ".", "ckpt_dir": "$@bundle_root + '/models'", "output_dir": "$@bundle_root + '/eval'", - "dataset_dir": "../datasets/sampleTrain", - "images": "$list(sorted(glob.glob(@dataset_dir + '/imagesTr/*.nii.gz')))", - "labels": "$list(sorted(glob.glob(@dataset_dir + '/labelsTr/*.nii.gz')))", - "val_interval": 2, + "dataset_dir": "../datasets/totalSegmentator", + "images": "$list(sorted(glob.glob(@dataset_dir + '/images/*.nii.gz')))", + "labels": "$list(sorted(glob.glob(@dataset_dir + '/labels/*.nii.gz')))", + "val_interval": 5, "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')", "network_def": { "_target_": "SegResNet", @@ -150,16 +150,18 @@ "transforms": "$@train#deterministic_transforms + @train#random_transforms" }, "dataset": { - "_target_": "Dataset", - "data": "$[{'image': i, 'label': l} for i, l in zip(@images[:-1], @labels[:-1])]", - "transform": "@train#preprocessing" + "_target_": "CacheDataset", + "data": "$[{'image': i, 'label': l} for i, l in zip(@images[:-200], @labels[:-200])]", + "transform": "@train#preprocessing", + "cache_rate": 0.5, + "num_workers": 8 }, "dataloader": { "_target_": "DataLoader", "dataset": "@train#dataset", - "batch_size": 1, + "batch_size": 16, "shuffle": true, - "num_workers": 4 + "num_workers": 8 }, "inferer": { "_target_": "SimpleInferer" @@ -213,7 +215,7 @@ }, "trainer": { "_target_": "SupervisedTrainer", - "max_epochs": 5, + "max_epochs": 1000, "device": "@device", "train_data_loader": "@train#dataloader", "network": "@network", @@ -229,11 +231,123 @@ "validate": { "preprocessing": { "_target_": "Compose", - "transforms": "%train#deterministic_transforms" + "transforms": [ + { + "_target_": "LoadImaged", + "keys": [ + "image", + "label" + ] + }, + { + "_target_": "EnsureChannelFirstd", + "keys": [ + "image", + "label" + ] + }, + { + "_target_": "EnsureTyped", + "keys": [ + "image", + "label" + ] + }, + { + "_target_": "Orientationd", + "keys": [ + "image", + "label" + ], + "axcodes": "RAS" + }, + { + "_target_": "Spacingd", + "keys": [ + "image", + "label" + ], + "pixdim": [ + 1.5, + 1.5, + 1.5 + ], + "mode": [ + "bilinear", + "nearest" + ] + }, + { + "_target_": "NormalizeIntensityd", + "keys": "image", + "nonzero": true + }, + { + "_target_": "CropForegroundd", + "keys": [ + "image", + "label" + ], + "source_key": "image", + "margin": 10, + "k_divisible": [ + 96, + 96, + 96 + ] + }, + { + "_target_": "GaussianSmoothd", + "keys": [ + "image" + ], + "sigma": 0.4 + }, + { + "_target_": "ScaleIntensityd", + "keys": "image", + "minv": -1.0, + "maxv": 1.0 + }, + { + "_target_": "CenterSpatialCropd", + "keys": [ + "image", + "label" + ], + "roi_size": [ + 160, + 160, + 160 + ] + } + ] + }, + "postprocessing": { + "_target_": "Compose", + "transforms": [ + { + "_target_": "Activationsd", + "keys": "pred", + "softmax": true + }, + { + "_target_": "AsDiscreted", + "keys": [ + "pred", + "label" + ], + "argmax": [ + true, + false + ], + "to_onehot": 105 + } + ] }, "dataset": { "_target_": "Dataset", - "data": "$[{'image': i, 'label': l} for i, l in zip(@images[-1:], @labels[-1:])]", + "data": "$[{'image': i, 'label': l} for i, l in zip(@images[-200:], @labels[-200:])]", "transform": "@validate#preprocessing" }, "dataloader": { @@ -241,7 +355,7 @@ "dataset": "@validate#dataset", "batch_size": 1, "shuffle": false, - "num_workers": 2 + "num_workers": 4 }, "inferer": { "_target_": "SlidingWindowInferer", @@ -251,10 +365,8 @@ 96 ], "sw_batch_size": 1, - "overlap": 0.4, - "device": "$torch.device('cpu')" + "overlap": 0.25 }, - "postprocessing": "%train#postprocessing", "handlers": [ { "_target_": "StatsHandler",