Skip to content

Commit

Permalink
134 model and architecture bug (#138)
Browse files Browse the repository at this point in the history
* add test for no model or arch

---------

Co-authored-by: ElliottKasoar <[email protected]>
Co-authored-by: Alin Marin Elena <[email protected]>
Co-authored-by: Jacob Wilkins <[email protected]>
  • Loading branch information
4 people authored Jun 7, 2024
1 parent b250ed7 commit c54d881
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 59 deletions.
110 changes: 75 additions & 35 deletions aiida_mlip/calculations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def validate_inputs(
The inputs dictionary.
port_namespace : `aiida.engine.processes.ports.PortNamespace`
An instance of aiida's `PortNameSpace`.
An instance of aiida's `PortNamespace`.
Raises
------
Expand All @@ -44,6 +44,30 @@ def validate_inputs(
raise InputValidationError(
"Structure must be specified through 'struct' or 'config'"
)
if (
"arch" not in inputs
and "model" not in inputs
and ("config" not in inputs or "arch" not in inputs["config"])
):
raise InputValidationError(
"'arch' must be specified in inputs, config file or ModelData"
)

if "model" not in inputs and (
"config" not in inputs or "model" not in inputs["config"]
):
raise InputValidationError(
"'model' must be specified either in the inputs or in the config file"
)

if (
"arch" in inputs
and "model" in inputs
and inputs["arch"].value is not inputs["model"].architecture
):
raise InputValidationError(
"'arch' in ModelData and in 'arch' input must be the same"
)


class BaseJanus(CalcJob): # numpydoc ignore=PR01
Expand Down Expand Up @@ -206,42 +230,10 @@ def prepare_for_submission(

# Define architecture from model if model is given,
# otherwise get architecture from inputs and download default model
architecture = None
architecture = (
str((self.inputs.model).architecture)
if "model" in self.inputs and hasattr(self.inputs.model, "architecture")
else str(self.inputs.arch.value) if "arch" in self.inputs else None
)

if architecture:
cmd_line["arch"] = architecture

model_path = None
if "model" in self.inputs:
# Raise error if model is None
if self.inputs.model is None:
raise ValueError("Model cannot be None")
model_path = self.inputs.model.filepath
else:
if "config" in self.inputs and "model" in self.inputs.config:
model_path = None
else:
if "arch" in self.inputs:
# if model is not given (which is different than it being None)
model_path = ModelData.download(
"https://github.com/stfc/janus-core/raw/main/tests/models/mace_mp_small.model", # pylint: disable=line-too-long
architecture,
).filepath
if model_path:
cmd_line.setdefault("calc-kwargs", {})["model"] = model_path
self._add_arch_to_cmdline(cmd_line)
self._add_model_to_cmdline(cmd_line)

if "config" in self.inputs:
# Check if there are values in the config file that are also in the command
# line and do not store them as only the cmd line parameters will be used
config_dict = self.inputs.config.as_dictionary
overlapping_params = cmd_line.keys() & config_dict.keys()
# Store the other parameters
self.inputs.config.store_content(skip=overlapping_params)
# Add config file to command line
cmd_line["config"] = "config.yaml"
config_parse = self.inputs.config.get_content()
Expand Down Expand Up @@ -274,3 +266,51 @@ def prepare_for_submission(
]

return calcinfo

def _add_arch_to_cmdline(self, cmd_line: dict) -> dict:
"""
Find architecture in inputs or config file and add to command line if needed.
Parameters
----------
cmd_line : dict
Dictionary containing the cmd line keys.
Returns
-------
dict
Dictionary containing the cmd line keys updated with the architecture.
"""
architecture = None
if "model" in self.inputs and hasattr(self.inputs.model, "architecture"):
architecture = str((self.inputs.model).architecture)
elif "arch" in self.inputs:
architecture = str(self.inputs.arch.value)
if architecture:
cmd_line["arch"] = architecture

def _add_model_to_cmdline(
self,
cmd_line: dict,
) -> dict:
"""
Find model in inputs or config file and add to command line if needed.
Parameters
----------
cmd_line : dict
Dictionary containing the cmd line keys.
Returns
-------
dict
Dictionary containing the cmd line keys updated with the model.
"""
model_path = None
if "model" in self.inputs:
# Raise error if model is None (different than model not given as input)
if self.inputs.model is None:
raise ValueError("Model cannot be None")
model_path = self.inputs.model.filepath
if model_path:
cmd_line.setdefault("calc-kwargs", {})["model"] = model_path
3 changes: 3 additions & 0 deletions tests/calculations/configs/config_noarch.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
properties:
- "energy"
model: "small"
3 changes: 3 additions & 0 deletions tests/calculations/configs/config_nomodel.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
properties:
- "energy"
arch: "mace_mp"
67 changes: 43 additions & 24 deletions tests/calculations/test_singlepoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from aiida.orm import Str, StructureData
from aiida.plugins import CalculationFactory

from aiida_mlip.data.config import JanusConfigfile
from aiida_mlip.data.model import ModelData


Expand Down Expand Up @@ -61,48 +62,66 @@ def test_singlepoint(fixture_sandbox, generate_calc_job, janus_code, model_folde
assert sorted(calc_info.retrieve_list) == sorted(retrieve_list)


def test_singlepoint_model_download(fixture_sandbox, generate_calc_job, janus_code):
"""Test generating singlepoint calculation job."""

def test_sp_nostruct(fixture_sandbox, generate_calc_job, model_folder, janus_code):
"""Test singlepoint calculation with error input"""
entry_point_name = "mlip.sp"
model_file = model_folder / "mace_mp_small.model"
# pylint:disable=line-too-long
inputs = {
"metadata": {"options": {"resources": {"num_machines": 1}}},
"code": janus_code,
"arch": Str("mace"),
"precision": Str("float64"),
"struct": StructureData(ase=bulk("NaCl", "rocksalt", 5.63)),
"model": ModelData.local_file(model_file, architecture="mace"),
"device": Str("cpu"),
}
with pytest.raises(InputValidationError):
generate_calc_job(fixture_sandbox, entry_point_name, inputs)

calc_info = generate_calc_job(fixture_sandbox, entry_point_name, inputs)

retrieve_list = [
calc_info.uuid,
"aiida.log",
"aiida-results.xyz",
"aiida-stdout.txt",
]
def test_sp_nomodel(fixture_sandbox, generate_calc_job, config_folder, janus_code):
"""Test singlepoint calculation with missing model"""
entry_point_name = "mlip.sp"

# Check the attributes of the returned `CalcInfo`
assert fixture_sandbox.get_content_list() == ["aiida.xyz"]
assert isinstance(calc_info, datastructures.CalcInfo)
assert isinstance(calc_info.codes_info[0], datastructures.CodeInfo)
assert sorted(calc_info.retrieve_list) == sorted(retrieve_list)
inputs = {
"code": janus_code,
"metadata": {"options": {"resources": {"num_machines": 1}}},
"config": JanusConfigfile(config_folder / "config_nomodel.yml"),
"struct": StructureData(ase=bulk("NaCl", "rocksalt", 5.63)),
}

with pytest.raises(InputValidationError):
generate_calc_job(fixture_sandbox, entry_point_name, inputs)

def test_sp_nostruct(fixture_sandbox, generate_calc_job, model_folder, janus_code):
"""Test singlepoint calculation with error input"""

def test_sp_noarch(fixture_sandbox, generate_calc_job, config_folder, janus_code):
"""Test singlepoint calculation with missing architecture"""
entry_point_name = "mlip.sp"
model_file = model_folder / "mace_mp_small.model"
# pylint:disable=line-too-long

inputs = {
"code": janus_code,
"metadata": {"options": {"resources": {"num_machines": 1}}},
"config": JanusConfigfile(config_folder / "config_noarch.yml"),
"struct": StructureData(ase=bulk("NaCl", "rocksalt", 5.63)),
}

with pytest.raises(InputValidationError):
generate_calc_job(fixture_sandbox, entry_point_name, inputs)


def test_two_arch(fixture_sandbox, generate_calc_job, model_folder, janus_code):
"""Test singlepoint calculation with two defined architectures"""
entry_point_name = "mlip.sp"
model_file = model_folder / "mace_mp_small.model"

inputs = {
"code": janus_code,
"arch": Str("mace"),
"precision": Str("float64"),
"model": ModelData.local_file(model_file, architecture="mace"),
"device": Str("cpu"),
"metadata": {"options": {"resources": {"num_machines": 1}}},
"model": ModelData.local_file(model_file, architecture="mace_mp"),
"arch": Str("chgnet"),
"struct": StructureData(ase=bulk("NaCl", "rocksalt", 5.63)),
}

with pytest.raises(InputValidationError):
generate_calc_job(fixture_sandbox, entry_point_name, inputs)

Expand Down

0 comments on commit c54d881

Please sign in to comment.