Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[auto3dseg] Update DiNTS algorithm template #295

Merged
merged 18 commits into from
Sep 21, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ training:
amp: true
auto_scale_allowed: true
data_list_key: null
epoch_divided_factor: 36
input_channels: null
learning_rate: 0.2
log_output_file: "$@bundle_root + '/model_fold' + str(@fold) + '/training.log'"
Expand Down Expand Up @@ -60,11 +61,13 @@ training:
batch: true
smooth_nr: 1.0e-05
smooth_dr: 1.0e-05

optimizer:
_target_: torch.optim.SGD
lr: "@training#learning_rate"
momentum: 0.9
weight_decay: 4.0e-05

lr_scheduler:
_target_: torch.optim.lr_scheduler.PolynomialLR
optimizer: "$@training#optimizer"
Expand All @@ -73,9 +76,21 @@ training:

# fine-tuning
finetune:
activate: false
activate_finetune: false
pretrained_ckpt_name: "$@bundle_root + '/model_fold' + str(@fold) + '/best_metric_model.pt'"

overwrite:
learning_rate: 0.001
lr_scheduler:
_target_: torch.optim.lr_scheduler.ConstantLR
optimizer: "$@training#optimizer"
factor: 1.0
total_iters: '$@training#num_epochs // @training#num_epochs_per_validation + 1'
adapt_valid_mode: false
early_stop_mode: false
num_epochs: 20
num_epochs_per_validation: 1

# validation
validate:
ckpt_name: "$@bundle_root + '/model_fold' + str(@fold) + '/best_metric_model.pt'"
Expand All @@ -87,6 +102,7 @@ validate:
# inference
infer:
ckpt_name: "$@bundle_root + '/model_fold' + str(@fold) + '/best_metric_model.pt'"
save_prob: false
fast: true
data_list_key: testing
log_output_file: "$@bundle_root + '/model_fold' + str(@fold) + '/inference.log'"
Expand Down
77 changes: 61 additions & 16 deletions auto3dseg/algorithm_templates/dints/scripts/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,24 +80,36 @@ def pre_operation(config_file, **override):

if auto_scale_allowed:
output_classes = parser["training"]["output_classes"]
mem = get_mem_from_visible_gpus()
mem = min(mem) if isinstance(mem, list) else mem
mem = float(mem) / (1024.0**3)

try:
mem = get_mem_from_visible_gpus()
mem = min(mem) if isinstance(mem, list) else mem
mem = float(mem) / (1024.0**3)
except BaseException:
mem = 16.0

mem = max(1.0, mem - 1.0)
mem_bs2 = 6.0 + (20.0 - 6.0) * (output_classes - 2) / (105 - 2)
mem_bs9 = 24.0 + (74.0 - 24.0) * (output_classes - 2) / (105 - 2)
batch_size = 2 + (9 - 2) * (mem - mem_bs2) / (mem_bs9 - mem_bs2)
mem_bs2 = 6.0 + (20.0 - 6.0) * \
(output_classes - 2) / (105 - 2)
mem_bs9 = 24.0 + (74.0 - 24.0) * \
(output_classes - 2) / (105 - 2)
batch_size = 2 + (9 - 2) * \
(mem - mem_bs2) / (mem_bs9 - mem_bs2)
batch_size = int(batch_size)
batch_size = max(batch_size, 1)

parser["training"].update({"num_patches_per_iter": batch_size})
parser["training"].update({"num_patches_per_image": 2 * batch_size})
parser["training"].update(
{"num_patches_per_iter": batch_size})
parser["training"].update(
{"num_patches_per_image": 2 * batch_size})

# estimate data size based on number of images and image size
# estimate data size based on number of images and image
# size
_factor = 1.0

try:
_factor *= 1251.0 / float(parser["stats_summary"]["n_cases"])
_factor *= 1251.0 / \
float(parser["stats_summary"]["n_cases"])
_mean_shape = parser["stats_summary"]["image_stats"]["shape"]["mean"]
_factor *= float(_mean_shape[0]) / 240.0
_factor *= float(_mean_shape[1]) / 240.0
Expand All @@ -110,15 +122,23 @@ def pre_operation(config_file, **override):
_factor *= 96.0 / float(_patch_size[1])
_factor *= 96.0 / float(_patch_size[2])

_factor /= 6.0
if "training#epoch_divided_factor" in override:
epoch_divided_factor = override["training#epoch_divided_factor"]
else:
epoch_divided_factor = parser["training"]["epoch_divided_factor"]
epoch_divided_factor = float(epoch_divided_factor)
_factor /= epoch_divided_factor

_factor = max(1.0, _factor)

_estimated_epochs = 400.0
_estimated_epochs *= _factor

parser["training"].update({"num_epochs": int(_estimated_epochs / float(batch_size))})
parser["training"].update(
{"num_epochs": int(_estimated_epochs / float(batch_size))})

ConfigParser.export_config_file(parser.get(), _file, fmt="yaml", default_flow_style=None)
ConfigParser.export_config_file(
parser.get(), _file, fmt="yaml", default_flow_style=None)

return

Expand Down Expand Up @@ -151,6 +171,7 @@ def __init__(self, config_file: Optional[Union[str, Sequence[str]]] = None, **ov
ckpt_name = parser.get_parsed_content("infer")["ckpt_name"]
data_list_key = parser.get_parsed_content("infer")["data_list_key"]
output_path = parser.get_parsed_content("infer")["output_path"]
save_prob = parser.get_parsed_content("infer#save_prob")

if not os.path.exists(output_path):
os.makedirs(output_path, exist_ok=True)
Expand Down Expand Up @@ -198,6 +219,29 @@ def __init__(self, config_file: Optional[Union[str, Sequence[str]]] = None, **ov
]
self.post_transforms_prob = transforms.Compose(post_transforms)

if save_prob:
post_transforms += [
transforms.CopyItemsd(
keys="pred",
times=1,
names="prob",
),
transforms.Lambdad(
keys="prob",
func=lambda x: torch.floor(x * 255.0).type(torch.uint8)
),
transforms.SaveImaged(
keys="prob",
meta_keys="pred_meta_dict",
output_dir=os.path.join(output_path, "prob"),
output_postfix="",
resample=False,
print_log=False,
data_root_dir=data_file_base_dir,
output_dtype=np.uint8,
),
]

if softmax:
post_transforms += [transforms.AsDiscreted(keys="pred", argmax=True)]
else:
Expand All @@ -208,11 +252,12 @@ def __init__(self, config_file: Optional[Union[str, Sequence[str]]] = None, **ov
keys="pred",
meta_keys="pred_meta_dict",
output_dir=output_path,
output_postfix="seg",
output_postfix="",
resample=False,
print_log=False,
data_root_dir=data_file_base_dir,
)
output_dtype=np.uint8,
),
]
self.post_transforms = transforms.Compose(post_transforms)

Expand All @@ -222,7 +267,7 @@ def __init__(self, config_file: Optional[Union[str, Sequence[str]]] = None, **ov
def infer(self, image_file, save_mask=False):
self.model.eval()

batch_data = self.infer_transforms(image_file)
batch_data = self.infer_transforms({"image": image_file})
batch_data = list_data_collate([batch_data])

finished = None
Expand Down
Loading