From 3fd2213ac3b35f9254ca6e934431d3a81fd64701 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 11 Oct 2023 15:34:23 -0400 Subject: [PATCH] nequip-deploy --override --- CHANGELOG.md | 2 +- nequip/scripts/deploy.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 321c2ee4..f0025969 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,7 +22,7 @@ Most recent change on the bottom. - `include_file_as_baseline_config` for simple modifications of existing configs - `nequip-deploy --using-dataset` to support data-dependent deployment steps - Support for Gaussian Mixture Model uncertainty quantification (https://doi.org/10.1063/5.0136574) -- `nequip-deploy build --checkpoint` to deploy specific checkpoints easily +- `nequip-deploy build --checkpoint` and `--override` to avoid many largely duplicated YAML files ### Changed - Always require explicit `seed` diff --git a/nequip/scripts/deploy.py b/nequip/scripts/deploy.py index ba238cba..a0772df9 100644 --- a/nequip/scripts/deploy.py +++ b/nequip/scripts/deploy.py @@ -210,6 +210,12 @@ def main(args=None): type=str, default=None, ) + build_parser.add_argument( + "--override", + help="Override top-level configuration keys from the `--train-dir`/`--model`'s config YAML file. This should be a valid YAML string. Unless you know why you need to, do not use this option.", + type=str, + default=None, + ) build_parser.add_argument( "--using-dataset", help="Allow model builders to use a dataset during deployment. By default uses the training dataset, but can point to a YAML file for another dataset.", @@ -265,6 +271,20 @@ def main(args=None): else: raise ValueError("one of --train-dir or --model must be given") + # Set override options before _set_global_options so that things like allow_tf32 are correctly handled + if args.override is not None: + override_options = yaml.load(args.override, Loader=yaml.Loader) + assert isinstance( + override_options, dict + ), "--override's YAML string must define a dictionary of top-level options" + overridden_keys = set(config.keys()).intersection(override_options.keys()) + set_keys = set(override_options.keys()) - set(overridden_keys) + logging.info( + f"--override: overrode keys {list(overridden_keys)} and set new keys {list(set_keys)}" + ) + config.update(override_options) + del override_options, overridden_keys, set_keys + _set_global_options(config) check_code_version(config)