Skip to content

Commit

Permalink
updated default logic for patch_sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
scap3yvt authored Jan 15, 2024
1 parent 5104dd8 commit 33657df
Showing 1 changed file with 17 additions and 23 deletions.
40 changes: 17 additions & 23 deletions GANDLF/parseConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,35 +638,29 @@ def parseConfig(config_file_path, version_check_flag=True):
print("DeprecationWarning: 'opt' has been superseded by 'optimizer'")
params["optimizer"] = params["opt"]

# initialize defaults for patch sampler
temp_patch_sampler_dict = {
"type": "uniform",
"enable_padding": False,
"padding_mode": "symmetric",
"biased_sampling": False,
}
# check if patch_sampler is defined in the config
if "patch_sampler" in params:
# check if user has passed a dict
temp_dict = {}
# if "patch_sampler" is a string, then it is the type of sampler
if isinstance(params["patch_sampler"], str):
temp_dict["type"] = params["patch_sampler"].lower()
print(
"WARNING: Defining 'patch_sampler' as a string will be deprecated in a future release, please use a dictionary instead"
)
temp_patch_sampler_dict["type"] = params["patch_sampler"].lower()
elif isinstance(params["patch_sampler"], dict):
# dict requires special handling
temp_dict = params["patch_sampler"]

# ensure "type" is defined in the dict
if not ("type" in temp_dict):
for key in temp_dict:
if "label" in key:
temp_dict["type"] = "label"
elif "weight" in key:
temp_dict["type"] = "weight"
else:
# default
temp_dict["type"] = "uniform"
break

# initialize defaults for patch sampler
temp_dict[key]["enable_padding"] = temp_dict[key].get("enable_padding", False)
temp_dict[key]["padding_mode"] = temp_dict[key].get("padding_mode", "symmetric")
temp_dict[key]["biased_sampling"] = temp_dict[key].get("biased_sampling", False)
for key in params["patch_sampler"]:
temp_patch_sampler_dict[key] = params["patch_sampler"][key]

# now assign the dict back to the params
params["patch_sampler"] = temp_dict
# now assign the dict back to the params
params["patch_sampler"] = temp_patch_sampler_dict
del temp_patch_sampler_dict

# define defaults
for current_parameter in parameter_defaults:
Expand Down

0 comments on commit 33657df

Please sign in to comment.