Skip to content

Commit

Permalink
update training pipeline and supports multi-gpu train
Browse files Browse the repository at this point in the history
Signed-off-by: tangy5 <[email protected]>
  • Loading branch information
tangy5 committed Feb 1, 2023
1 parent e535525 commit e2a2409
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 18 deletions.
4 changes: 2 additions & 2 deletions models/wholeBody_ct_segmentation/configs/inference.json
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,10 @@
96
],
"sw_batch_size": 1,
"overlap": 0.4,
"overlap": 0.3,
"padding_mode": "replicate",
"mode": "gaussian",
"device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
"device": "$torch.device('cpu')"
},
"postprocessing": {
"_target_": "Compose",
Expand Down
144 changes: 128 additions & 16 deletions models/wholeBody_ct_segmentation/configs/train.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
"bundle_root": ".",
"ckpt_dir": "$@bundle_root + '/models'",
"output_dir": "$@bundle_root + '/eval'",
"dataset_dir": "../datasets/sampleTrain",
"images": "$list(sorted(glob.glob(@dataset_dir + '/imagesTr/*.nii.gz')))",
"labels": "$list(sorted(glob.glob(@dataset_dir + '/labelsTr/*.nii.gz')))",
"val_interval": 2,
"dataset_dir": "../datasets/totalSegmentator",
"images": "$list(sorted(glob.glob(@dataset_dir + '/images/*.nii.gz')))",
"labels": "$list(sorted(glob.glob(@dataset_dir + '/labels/*.nii.gz')))",
"val_interval": 5,
"device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
"network_def": {
"_target_": "SegResNet",
Expand Down Expand Up @@ -150,16 +150,18 @@
"transforms": "$@train#deterministic_transforms + @train#random_transforms"
},
"dataset": {
"_target_": "Dataset",
"data": "$[{'image': i, 'label': l} for i, l in zip(@images[:-1], @labels[:-1])]",
"transform": "@train#preprocessing"
"_target_": "CacheDataset",
"data": "$[{'image': i, 'label': l} for i, l in zip(@images[:-200], @labels[:-200])]",
"transform": "@train#preprocessing",
"cache_rate": 0.5,
"num_workers": 8
},
"dataloader": {
"_target_": "DataLoader",
"dataset": "@train#dataset",
"batch_size": 1,
"batch_size": 16,
"shuffle": true,
"num_workers": 4
"num_workers": 8
},
"inferer": {
"_target_": "SimpleInferer"
Expand Down Expand Up @@ -213,7 +215,7 @@
},
"trainer": {
"_target_": "SupervisedTrainer",
"max_epochs": 5,
"max_epochs": 1000,
"device": "@device",
"train_data_loader": "@train#dataloader",
"network": "@network",
Expand All @@ -229,19 +231,131 @@
"validate": {
"preprocessing": {
"_target_": "Compose",
"transforms": "%train#deterministic_transforms"
"transforms": [
{
"_target_": "LoadImaged",
"keys": [
"image",
"label"
]
},
{
"_target_": "EnsureChannelFirstd",
"keys": [
"image",
"label"
]
},
{
"_target_": "EnsureTyped",
"keys": [
"image",
"label"
]
},
{
"_target_": "Orientationd",
"keys": [
"image",
"label"
],
"axcodes": "RAS"
},
{
"_target_": "Spacingd",
"keys": [
"image",
"label"
],
"pixdim": [
1.5,
1.5,
1.5
],
"mode": [
"bilinear",
"nearest"
]
},
{
"_target_": "NormalizeIntensityd",
"keys": "image",
"nonzero": true
},
{
"_target_": "CropForegroundd",
"keys": [
"image",
"label"
],
"source_key": "image",
"margin": 10,
"k_divisible": [
96,
96,
96
]
},
{
"_target_": "GaussianSmoothd",
"keys": [
"image"
],
"sigma": 0.4
},
{
"_target_": "ScaleIntensityd",
"keys": "image",
"minv": -1.0,
"maxv": 1.0
},
{
"_target_": "CenterSpatialCropd",
"keys": [
"image",
"label"
],
"roi_size": [
160,
160,
160
]
}
]
},
"postprocessing": {
"_target_": "Compose",
"transforms": [
{
"_target_": "Activationsd",
"keys": "pred",
"softmax": true
},
{
"_target_": "AsDiscreted",
"keys": [
"pred",
"label"
],
"argmax": [
true,
false
],
"to_onehot": 105
}
]
},
"dataset": {
"_target_": "Dataset",
"data": "$[{'image': i, 'label': l} for i, l in zip(@images[-1:], @labels[-1:])]",
"data": "$[{'image': i, 'label': l} for i, l in zip(@images[-200:], @labels[-200:])]",
"transform": "@validate#preprocessing"
},
"dataloader": {
"_target_": "DataLoader",
"dataset": "@validate#dataset",
"batch_size": 1,
"shuffle": false,
"num_workers": 2
"num_workers": 4
},
"inferer": {
"_target_": "SlidingWindowInferer",
Expand All @@ -251,10 +365,8 @@
96
],
"sw_batch_size": 1,
"overlap": 0.4,
"device": "$torch.device('cpu')"
"overlap": 0.25
},
"postprocessing": "%train#postprocessing",
"handlers": [
{
"_target_": "StatsHandler",
Expand Down

0 comments on commit e2a2409

Please sign in to comment.