From 4aabe9f4f6587fac8696eaacb50e2cbf9ff99dee Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 11 Oct 2023 14:42:37 -0400 Subject: [PATCH] Add nequip-deploy build --checkpoint --- CHANGELOG.md | 1 + nequip/scripts/deploy.py | 20 +++++++++++++++----- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 24812ad5..321c2ee4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ Most recent change on the bottom. - `include_file_as_baseline_config` for simple modifications of existing configs - `nequip-deploy --using-dataset` to support data-dependent deployment steps - Support for Gaussian Mixture Model uncertainty quantification (https://doi.org/10.1063/5.0136574) +- `nequip-deploy build --checkpoint` to deploy specific checkpoints easily ### Changed - Always require explicit `seed` diff --git a/nequip/scripts/deploy.py b/nequip/scripts/deploy.py index bcee8255..ba238cba 100644 --- a/nequip/scripts/deploy.py +++ b/nequip/scripts/deploy.py @@ -204,6 +204,12 @@ def main(args=None): help="Path to a working directory from a training session to deploy.", type=pathlib.Path, ) + build_parser.add_argument( + "--checkpoint", + help="Which model checkpoint from --train-dir to deploy. Defaults to `best_model.pth`. If --train-dir is provided, this is a relative path; if --model is provided instead, this is an absolute path.", + type=str, + default=None, + ) build_parser.add_argument( "--using-dataset", help="Allow model builders to use a dataset during deployment. By default uses the training dataset, but can point to a YAML file for another dataset.", @@ -246,12 +252,13 @@ def main(args=None): state_dict = None if args.model and args.train_dir: raise ValueError("--model and --train-dir cannot both be specified.") + checkpoint_file = args.checkpoint if args.train_dir is not None: - logging.info("Loading best_model from training session...") + if checkpoint_file is None: + checkpoint_file = "best_model.pth" + logging.info(f"Loading {checkpoint_file} from training session...") + checkpoint_file = str(args.train_dir / "best_model.pth") config = Config.from_file(str(args.train_dir / "config.yaml")) - state_dict = torch.load( - str(args.train_dir / "best_model.pth"), map_location="cpu" - ) elif args.model is not None: logging.info("Building model from config...") config = Config.from_file(str(args.model), defaults=default_config) @@ -278,7 +285,10 @@ def main(args=None): global _current_metadata _current_metadata = {} model = model_from_config(config, dataset=dataset, deploy=True) - if state_dict is not None: + if checkpoint_file is not None: + state_dict = torch.load( + str(args.train_dir / "best_model.pth"), map_location="cpu" + ) model.load_state_dict(state_dict, strict=True) # -- compile --