Revert "[auto3dseg] Update DiNTS algorithm template (#295)"
This reverts commit c1d81b4.
wyli authored Sep 21, 2023
1 parent c1d81b4 commit 442a1ab
Showing 5 changed files with 746 additions and 995 deletions.
@@ -12,7 +12,6 @@ training:
amp: true
auto_scale_allowed: true
data_list_key: null
epoch_divided_factor: 36
input_channels: null
learning_rate: 0.2
log_output_file: "$@bundle_root + '/model_fold' + str(@fold) + '/training.log'"
@@ -61,13 +60,11 @@ training:
batch: true
smooth_nr: 1.0e-05
smooth_dr: 1.0e-05

optimizer:
_target_: torch.optim.SGD
lr: "@training#learning_rate"
momentum: 0.9
weight_decay: 4.0e-05

lr_scheduler:
_target_: torch.optim.lr_scheduler.PolynomialLR
optimizer: "$@training#optimizer"
@@ -76,21 +73,9 @@ training:

# fine-tuning
finetune:
activate_finetune: false
activate: false
pretrained_ckpt_name: "$@bundle_root + '/model_fold' + str(@fold) + '/best_metric_model.pt'"

overwrite:
learning_rate: 0.001
lr_scheduler:
_target_: torch.optim.lr_scheduler.ConstantLR
optimizer: "$@training#optimizer"
factor: 1.0
total_iters: '$@training#num_epochs // @training#num_epochs_per_validation + 1'
adapt_valid_mode: false
early_stop_mode: false
num_epochs: 20
num_epochs_per_validation: 1

# validation
validate:
ckpt_name: "$@bundle_root + '/model_fold' + str(@fold) + '/best_metric_model.pt'"
@@ -102,7 +87,6 @@ validate:
# inference
infer:
ckpt_name: "$@bundle_root + '/model_fold' + str(@fold) + '/best_metric_model.pt'"
save_prob: false
fast: true
data_list_key: testing
log_output_file: "$@bundle_root + '/model_fold' + str(@fold) + '/inference.log'"
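
For orientation, the _target_ and "$@...#..." entries in the training section of this config (the optimizer and lr_scheduler blocks above) are MONAI bundle references that resolve to ordinary PyTorch objects at run time. A minimal plain-PyTorch sketch of that pairing, reusing the learning rate, momentum, and weight decay shown in the diff (the model and the total_iters/power values are illustrative placeholders, not taken from the template):

import torch

model = torch.nn.Linear(4, 2)  # stand-in for the DiNTS network
optimizer = torch.optim.SGD(
    model.parameters(), lr=0.2, momentum=0.9, weight_decay=4.0e-05
)
lr_scheduler = torch.optim.lr_scheduler.PolynomialLR(
    optimizer, total_iters=100, power=1.0  # placeholder values
)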
77 changes: 16 additions & 61 deletions auto3dseg/algorithm_templates/dints/scripts/infer.py
@@ -80,36 +80,24 @@ def pre_operation(config_file, **override):

if auto_scale_allowed:
output_classes = parser["training"]["output_classes"]

try:
mem = get_mem_from_visible_gpus()
mem = min(mem) if isinstance(mem, list) else mem
mem = float(mem) / (1024.0**3)
except BaseException:
mem = 16.0

mem = get_mem_from_visible_gpus()
mem = min(mem) if isinstance(mem, list) else mem
mem = float(mem) / (1024.0**3)
mem = max(1.0, mem - 1.0)
mem_bs2 = 6.0 + (20.0 - 6.0) * \
(output_classes - 2) / (105 - 2)
mem_bs9 = 24.0 + (74.0 - 24.0) * \
(output_classes - 2) / (105 - 2)
batch_size = 2 + (9 - 2) * \
(mem - mem_bs2) / (mem_bs9 - mem_bs2)
mem_bs2 = 6.0 + (20.0 - 6.0) * (output_classes - 2) / (105 - 2)
mem_bs9 = 24.0 + (74.0 - 24.0) * (output_classes - 2) / (105 - 2)
batch_size = 2 + (9 - 2) * (mem - mem_bs2) / (mem_bs9 - mem_bs2)
batch_size = int(batch_size)
batch_size = max(batch_size, 1)

parser["training"].update(
{"num_patches_per_iter": batch_size})
parser["training"].update(
{"num_patches_per_image": 2 * batch_size})
parser["training"].update({"num_patches_per_iter": batch_size})
parser["training"].update({"num_patches_per_image": 2 * batch_size})

# estimate data size based on number of images and image
# size
# estimate data size based on number of images and image size
_factor = 1.0

try:
_factor *= 1251.0 / \
float(parser["stats_summary"]["n_cases"])
_factor *= 1251.0 / float(parser["stats_summary"]["n_cases"])
_mean_shape = parser["stats_summary"]["image_stats"]["shape"]["mean"]
_factor *= float(_mean_shape[0]) / 240.0
_factor *= float(_mean_shape[1]) / 240.0
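
As a reading aid for the hunk above: the batch size is linearly interpolated between two anchor points, batch size 2 at mem_bs2 gigabytes and batch size 9 at mem_bs9 gigabytes, with both memory anchors themselves interpolated over the number of output classes (2 to 105). A standalone sketch of that heuristic (the function name is assumed, not part of the template):

def estimate_batch_size(mem_gb, output_classes):
    # reserve roughly 1 GB of headroom, never going below 1 GB
    mem = max(1.0, mem_gb - 1.0)
    # memory needed for batch sizes 2 and 9, interpolated over 2..105 classes
    mem_bs2 = 6.0 + (20.0 - 6.0) * (output_classes - 2) / (105 - 2)
    mem_bs9 = 24.0 + (74.0 - 24.0) * (output_classes - 2) / (105 - 2)
    batch_size = 2 + (9 - 2) * (mem - mem_bs2) / (mem_bs9 - mem_bs2)
    return max(int(batch_size), 1)

# e.g. 16 GB of GPU memory and 2 output classes:
# mem = 15, mem_bs2 = 6, mem_bs9 = 24 -> 2 + 7 * 9 / 18 = 5.5 -> batch size 5
print(estimate_batch_size(16.0, 2))  # 5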
@@ -122,23 +110,15 @@ def pre_operation(config_file, **override):
_factor *= 96.0 / float(_patch_size[1])
_factor *= 96.0 / float(_patch_size[2])

if "training#epoch_divided_factor" in override:
epoch_divided_factor = override["training#epoch_divided_factor"]
else:
epoch_divided_factor = parser["training"]["epoch_divided_factor"]
epoch_divided_factor = float(epoch_divided_factor)
_factor /= epoch_divided_factor

_factor /= 6.0
_factor = max(1.0, _factor)

_estimated_epochs = 400.0
_estimated_epochs *= _factor

parser["training"].update(
{"num_epochs": int(_estimated_epochs / float(batch_size))})
parser["training"].update({"num_epochs": int(_estimated_epochs / float(batch_size))})

ConfigParser.export_config_file(
parser.get(), _file, fmt="yaml", default_flow_style=None)
ConfigParser.export_config_file(parser.get(), _file, fmt="yaml", default_flow_style=None)

return
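
The second half of pre_operation scales a nominal 400 training epochs by dataset size, mean image shape, and patch size; this revert replaces the configurable epoch_divided_factor with the fixed divisor 6.0. A rough standalone restatement (only the terms visible in this diff are reproduced, lines hidden in the collapsed hunks are omitted, and the function name is assumed):

def estimate_num_epochs(n_cases, mean_shape, patch_size, batch_size):
    factor = 1251.0 / float(n_cases)
    factor *= float(mean_shape[0]) / 240.0
    factor *= float(mean_shape[1]) / 240.0
    factor *= 96.0 / float(patch_size[1])
    factor *= 96.0 / float(patch_size[2])
    factor /= 6.0  # fixed divisor restored by this revert
    factor = max(1.0, factor)
    return int(400.0 * factor / float(batch_size))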

@@ -171,7 +151,6 @@ def __init__(self, config_file: Optional[Union[str, Sequence[str]]] = None, **ov
ckpt_name = parser.get_parsed_content("infer")["ckpt_name"]
data_list_key = parser.get_parsed_content("infer")["data_list_key"]
output_path = parser.get_parsed_content("infer")["output_path"]
save_prob = parser.get_parsed_content("infer#save_prob")

if not os.path.exists(output_path):
os.makedirs(output_path, exist_ok=True)
@@ -219,29 +198,6 @@ def __init__(self, config_file: Optional[Union[str, Sequence[str]]] = None, **ov
]
self.post_transforms_prob = transforms.Compose(post_transforms)

if save_prob:
post_transforms += [
transforms.CopyItemsd(
keys="pred",
times=1,
names="prob",
),
transforms.Lambdad(
keys="prob",
func=lambda x: torch.floor(x * 255.0).type(torch.uint8)
),
transforms.SaveImaged(
keys="prob",
meta_keys="pred_meta_dict",
output_dir=os.path.join(output_path, "prob"),
output_postfix="",
resample=False,
print_log=False,
data_root_dir=data_file_base_dir,
output_dtype=np.uint8,
),
]

if softmax:
post_transforms += [transforms.AsDiscreted(keys="pred", argmax=True)]
else:
@@ -252,12 +208,11 @@ def __init__(self, config_file: Optional[Union[str, Sequence[str]]] = None, **ov
keys="pred",
meta_keys="pred_meta_dict",
output_dir=output_path,
output_postfix="",
output_postfix="seg",
resample=False,
print_log=False,
data_root_dir=data_file_base_dir,
output_dtype=np.uint8,
),
)
]
self.post_transforms = transforms.Compose(post_transforms)
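
For context on the save_prob branch removed above: it duplicated the prediction, mapped per-voxel probabilities from [0, 1] to 8-bit values with torch.floor(x * 255.0), and saved them as uint8 images. A quick standalone check of that mapping (not part of the template):

import torch

probs = torch.tensor([0.0, 0.5, 1.0])
print(torch.floor(probs * 255.0).type(torch.uint8))
# tensor([  0, 127, 255], dtype=torch.uint8)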

@@ -267,7 +222,7 @@ def __init__(self, config_file: Optional[Union[str, Sequence[str]]] = None, **ov
def infer(self, image_file, save_mask=False):
self.model.eval()

batch_data = self.infer_transforms({"image": image_file})
batch_data = self.infer_transforms(image_file)
batch_data = list_data_collate([batch_data])

finished = None
