Skip to content

Commit

Permalink
nequip-deploy --override
Browse files Browse the repository at this point in the history
  • Loading branch information
Linux-cpp-lisp committed Oct 11, 2023
1 parent 4aabe9f commit 3fd2213
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Most recent change on the bottom.
- `include_file_as_baseline_config` for simple modifications of existing configs
- `nequip-deploy --using-dataset` to support data-dependent deployment steps
- Support for Gaussian Mixture Model uncertainty quantification (https://doi.org/10.1063/5.0136574)
- `nequip-deploy build --checkpoint` to deploy specific checkpoints easily
- `nequip-deploy build --checkpoint` and `--override` to avoid many largely duplicated YAML files

### Changed
- Always require explicit `seed`
Expand Down
20 changes: 20 additions & 0 deletions nequip/scripts/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,12 @@ def main(args=None):
type=str,
default=None,
)
build_parser.add_argument(
"--override",
help="Override top-level configuration keys from the `--train-dir`/`--model`'s config YAML file. This should be a valid YAML string. Unless you know why you need to, do not use this option.",
type=str,
default=None,
)
build_parser.add_argument(
"--using-dataset",
help="Allow model builders to use a dataset during deployment. By default uses the training dataset, but can point to a YAML file for another dataset.",
Expand Down Expand Up @@ -265,6 +271,20 @@ def main(args=None):
else:
raise ValueError("one of --train-dir or --model must be given")

# Set override options before _set_global_options so that things like allow_tf32 are correctly handled
if args.override is not None:
override_options = yaml.load(args.override, Loader=yaml.Loader)
assert isinstance(
override_options, dict
), "--override's YAML string must define a dictionary of top-level options"
overridden_keys = set(config.keys()).intersection(override_options.keys())
set_keys = set(override_options.keys()) - set(overridden_keys)
logging.info(
f"--override: overrode keys {list(overridden_keys)} and set new keys {list(set_keys)}"
)
config.update(override_options)
del override_options, overridden_keys, set_keys

_set_global_options(config)
check_code_version(config)

Expand Down

0 comments on commit 3fd2213

Please sign in to comment.