Skip to content

Commit

Permalink
Add nequip-deploy build --checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
Linux-cpp-lisp committed Oct 11, 2023
1 parent 3f03c77 commit 4aabe9f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Most recent change on the bottom.
- `include_file_as_baseline_config` for simple modifications of existing configs
- `nequip-deploy --using-dataset` to support data-dependent deployment steps
- Support for Gaussian Mixture Model uncertainty quantification (https://doi.org/10.1063/5.0136574)
- `nequip-deploy build --checkpoint` to deploy specific checkpoints easily

### Changed
- Always require explicit `seed`
Expand Down
20 changes: 15 additions & 5 deletions nequip/scripts/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,12 @@ def main(args=None):
help="Path to a working directory from a training session to deploy.",
type=pathlib.Path,
)
build_parser.add_argument(
"--checkpoint",
help="Which model checkpoint from --train-dir to deploy. Defaults to `best_model.pth`. If --train-dir is provided, this is a relative path; if --model is provided instead, this is an absolute path.",
type=str,
default=None,
)
build_parser.add_argument(
"--using-dataset",
help="Allow model builders to use a dataset during deployment. By default uses the training dataset, but can point to a YAML file for another dataset.",
Expand Down Expand Up @@ -246,12 +252,13 @@ def main(args=None):
state_dict = None
if args.model and args.train_dir:
raise ValueError("--model and --train-dir cannot both be specified.")
checkpoint_file = args.checkpoint
     if args.train_dir is not None:
-        logging.info("Loading best_model from training session...")
+        if checkpoint_file is None:
+            checkpoint_file = "best_model.pth"
+        logging.info(f"Loading {checkpoint_file} from training session...")
+        checkpoint_file = str(args.train_dir / checkpoint_file)
         config = Config.from_file(str(args.train_dir / "config.yaml"))
-        state_dict = torch.load(
-            str(args.train_dir / "best_model.pth"), map_location="cpu"
-        )
elif args.model is not None:
logging.info("Building model from config...")
config = Config.from_file(str(args.model), defaults=default_config)
Expand All @@ -278,7 +285,10 @@ def main(args=None):
global _current_metadata
_current_metadata = {}
         model = model_from_config(config, dataset=dataset, deploy=True)
-        if state_dict is not None:
+        if checkpoint_file is not None:
+            state_dict = torch.load(checkpoint_file, map_location="cpu")
             model.load_state_dict(state_dict, strict=True)

# -- compile --
Expand Down

0 comments on commit 4aabe9f

Please sign in to comment.