From b19a8ca0aab36a6705c0dca84fa9ceb11679241c Mon Sep 17 00:00:00 2001 From: ElliottKasoar <45317199+ElliottKasoar@users.noreply.github.com> Date: Mon, 21 Oct 2024 11:25:18 +0100 Subject: [PATCH] Save model info --- janus_core/helpers/mlip_calculators.py | 36 ++++++++++++++------------ janus_core/helpers/utils.py | 2 ++ tests/test_mlip_calculators.py | 2 ++ tests/test_utils.py | 9 +++++++ 4 files changed, 32 insertions(+), 17 deletions(-) diff --git a/janus_core/helpers/mlip_calculators.py b/janus_core/helpers/mlip_calculators.py index b1800c0d..e1ad1a54 100644 --- a/janus_core/helpers/mlip_calculators.py +++ b/janus_core/helpers/mlip_calculators.py @@ -123,20 +123,20 @@ def choose_calculator( from mace.calculators import mace_mp # Default to "small" model and float64 precision - model = model_path if model_path else "small" + model_path = model_path if model_path else "small" kwargs.setdefault("default_dtype", "float64") - calculator = mace_mp(model=model, device=device, **kwargs) + calculator = mace_mp(model=model_path, device=device, **kwargs) elif arch == "mace_off": from mace import __version__ from mace.calculators import mace_off # Default to "small" model and float64 precision - model = model_path if model_path else "small" + model_path = model_path if model_path else "small" kwargs.setdefault("default_dtype", "float64") - calculator = mace_off(model=model, device=device, **kwargs) + calculator = mace_off(model=model_path, device=device, **kwargs) elif arch == "m3gnet": from matgl import __version__, load_model @@ -152,6 +152,7 @@ def choose_calculator( # Otherwise, load the model if given a path, else use a default model if isinstance(model_path, Potential): potential = model_path + model_path = "loaded_Potential" elif isinstance(model_path, Path): if model_path.is_file(): model_path = model_path.parent @@ -159,7 +160,8 @@ def choose_calculator( elif isinstance(model_path, str): potential = load_model(model_path) else: - potential = load_model("M3GNet-MP-2021.2.8-DIRECT-PES") + model_path = "M3GNet-MP-2021.2.8-DIRECT-PES" + potential = load_model(model_path) calculator = M3GNetCalculator(potential=potential, **kwargs) @@ -176,11 +178,13 @@ def choose_calculator( # Otherwise, load the model if given a path, else use a default model if isinstance(model_path, CHGNet): model = model_path + model_path = "loaded_CHGNet" elif isinstance(model_path, Path): model = CHGNet.from_file(model_path) elif isinstance(model_path, str): model = CHGNet.load(model_name=model_path, use_device=device) else: + model_path = "0.3.0" model = None calculator = CHGNetCalculator(model=model, use_device=device, **kwargs) @@ -195,16 +199,15 @@ def choose_calculator( # Set default path to directory containing config and model location if isinstance(model_path, Path): - path = model_path - if path.is_file(): - path = path.parent + if model_path.is_file(): + model_path = model_path.parent # If a string, assume referring to model_name e.g. "v5.27.2024" elif isinstance(model_path, str): - path = get_figshare_model_ff(model_name=model_path) + model_path = get_figshare_model_ff(model_name=model_path) else: - path = default_path() + model_path = default_path() - calculator = AlignnAtomwiseCalculator(path=path, device=device, **kwargs) + calculator = AlignnAtomwiseCalculator(path=model_path, device=device, **kwargs) elif arch == "sevennet": # Disable constant-imported-as-non-constant @@ -212,15 +215,13 @@ def choose_calculator( from sevenn.sevennet_calculator import SevenNetCalculator if isinstance(model_path, Path): - model = str(model_path) - elif isinstance(model_path, str): - model = model_path - else: - model = "SevenNet-0_11July2024" + model_path = str(model_path) + elif not isinstance(model_path, str): + model_path = "SevenNet-0_11July2024" kwargs.setdefault("file_type", "checkpoint") kwargs.setdefault("sevennet_config", None) - calculator = SevenNetCalculator(model=model, device=device, **kwargs) + calculator = SevenNetCalculator(model=model_path, device=device, **kwargs) else: raise ValueError( @@ -230,5 +231,6 @@ def choose_calculator( calculator.parameters["version"] = __version__ calculator.parameters["arch"] = arch + calculator.parameters["model"] = str(model_path) return calculator diff --git a/janus_core/helpers/utils.py b/janus_core/helpers/utils.py index 3573465c..591a980b 100644 --- a/janus_core/helpers/utils.py +++ b/janus_core/helpers/utils.py @@ -257,6 +257,7 @@ def results_to_info( if struct.calc and "arch" in struct.calc.parameters: arch = struct.calc.parameters["arch"] struct.info["arch"] = arch + struct.info["mlip_model"] = struct.calc.parameters["model"] for key in properties & struct.calc.results.keys(): tag = f"{arch}_{key}" @@ -474,6 +475,7 @@ def output_structs( for image in images: if image.calc and "arch" in image.calc.parameters: image.info["arch"] = image.calc.parameters["arch"] + image.info["mlip_model"] = image.calc.parameters["model"] # Add label for system for image in images: diff --git a/tests/test_mlip_calculators.py b/tests/test_mlip_calculators.py index 6d1c7754..5a976ea8 100644 --- a/tests/test_mlip_calculators.py +++ b/tests/test_mlip_calculators.py @@ -56,6 +56,7 @@ def test_mlips(arch, device, kwargs): """Test mace calculators can be configured.""" calculator = choose_calculator(arch=arch, device=device, **kwargs) assert calculator.parameters["version"] is not None + assert calculator.parameters["model"] is not None def test_invalid_arch(): @@ -127,6 +128,7 @@ def test_extra_mlips(arch, device, kwargs): **kwargs, ) assert calculator.parameters["version"] is not None + assert calculator.parameters["model"] is not None except BadZipFile: pytest.skip() diff --git a/tests/test_utils.py b/tests/test_utils.py index aa37d41f..408edbd6 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -76,6 +76,13 @@ def test_output_structs( else: results_keys = {"energy", "forces", "stress"} + if arch == "mace_mp": + model = "small" + if arch == "m3gnet": + model = "M3GNet-MP-2021.2.8-DIRECT-PES" + if arch == "chgnet": + model = "0.3.0" + label_keys = {f"{arch}_{key}" for key in results_keys} write_kwargs = {} @@ -114,6 +121,7 @@ def test_output_structs( if "set_info" not in write_kwargs or write_kwargs["set_info"]: assert label_keys <= struct.info.keys() | struct.arrays.keys() assert struct.info["arch"] == arch + assert struct.info["mlip_model"] == model # Check file written correctly if write_results if write_results: @@ -125,6 +133,7 @@ def test_output_structs( if "set_info" not in write_kwargs or write_kwargs["set_info"]: assert label_keys <= atoms.info.keys() | atoms.arrays.keys() assert atoms.info["arch"] == arch + assert atoms.info["mlip_model"] == model # Check calculator results depend on invalidate_calc if invalidate_calc: