Skip to content

Commit

Permalink
Fix model download (stfc#150)
Browse files Browse the repository at this point in the history
* Fix model download

---------

Co-authored-by: ElliottKasoar <[email protected]>
Co-authored-by: Alin Marin Elena <[email protected]>
Co-authored-by: Jacob Wilkins <[email protected]>
  • Loading branch information
4 people authored Jul 19, 2024
1 parent 57cc38e commit 9336017
Show file tree
Hide file tree
Showing 22 changed files with 200 additions and 197 deletions.
35 changes: 23 additions & 12 deletions aiida_mlip/calculations/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Base class for features common to most calculations."""

import shutil

from ase.io import read, write

from aiida.common import InputValidationError, datastructures
Expand Down Expand Up @@ -63,10 +65,13 @@ def validate_inputs(
if (
"arch" in inputs
and "model" in inputs
and inputs["arch"].value is not inputs["model"].architecture
and inputs["arch"].value != inputs["model"].architecture
):
inputvalue = inputs["arch"].value
modelvalue = inputs["model"].architecture
raise InputValidationError(
"'arch' in ModelData and in 'arch' input must be the same"
"'arch' in ModelData and in inputs must be the same, "
f"but they are {modelvalue} and {inputvalue}"
)


Expand Down Expand Up @@ -191,16 +196,14 @@ def prepare_for_submission(
Parameters
----------
folder : aiida.common.folders.Folder
An `aiida.common.folders.Folder` to temporarily write files on disk.
Folder where the calculation is run.
Returns
-------
aiida.common.datastructures.CalcInfo
An instance of `aiida.common.datastructures.CalcInfo`.
"""

# Create needed inputs

if "struct" in self.inputs:
structure = self.inputs.struct
elif "config" in self.inputs and "struct" in self.inputs.config.as_dictionary:
Expand All @@ -211,8 +214,8 @@ def prepare_for_submission(
# Transform the structure data in xyz file called input_filename
input_filename = self.inputs.metadata.options.input_filename
atoms = structure.get_ase()
# with folder.open(input_filename, mode="w", encoding='utf8') as file:
write(folder.abspath + "/" + input_filename, images=atoms)
with folder.open(input_filename, mode="w", encoding="utf8") as file:
write(file.name, images=atoms)

log_filename = (self.inputs.log_filename).value
cmd_line = {
Expand All @@ -231,7 +234,7 @@ def prepare_for_submission(
# Define architecture from model if model is given,
# otherwise get architecture from inputs and download default model
self._add_arch_to_cmdline(cmd_line)
self._add_model_to_cmdline(cmd_line)
self._add_model_to_cmdline(cmd_line, folder)

if "config" in self.inputs:
# Add config file to command line
Expand Down Expand Up @@ -290,8 +293,7 @@ def _add_arch_to_cmdline(self, cmd_line: dict) -> dict:
cmd_line["arch"] = architecture

def _add_model_to_cmdline(
self,
cmd_line: dict,
self, cmd_line: dict, folder: aiida.common.folders.Folder
) -> dict:
"""
Find model in inputs or config file and add to command line if needed.
Expand All @@ -301,6 +303,9 @@ def _add_model_to_cmdline(
cmd_line : dict
Dictionary containing the cmd line keys.
folder : ~aiida.common.folders.Folder
Folder where the calculation is run.
Returns
-------
dict
Expand All @@ -311,6 +316,12 @@ def _add_model_to_cmdline(
# Raise error if model is None (different than model not given as input)
if self.inputs.model is None:
raise ValueError("Model cannot be None")
model_path = self.inputs.model.filepath
if model_path:

with (
self.inputs.model.open(mode="rb") as source,
folder.open("mlff.model", mode="wb") as target,
):
shutil.copyfileobj(source, target)

model_path = "mlff.model"
cmd_line.setdefault("calc-kwargs", {})["model"] = model_path
2 changes: 1 addition & 1 deletion aiida_mlip/calculations/geomopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def prepare_for_submission(
Parameters
----------
folder : aiida.common.folders.Folder
An `aiida.common.folders.Folder` to temporarily write files on disk.
Folder where the calculation is run.
Returns
-------
Expand Down
2 changes: 1 addition & 1 deletion aiida_mlip/calculations/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def prepare_for_submission(
Parameters
----------
folder : aiida.common.folders.Folder
An `aiida.common.folders.Folder` to temporarily write files on disk.
Folder where the calculation is run.
Returns
-------
Expand Down
2 changes: 1 addition & 1 deletion aiida_mlip/calculations/singlepoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def prepare_for_submission(
Parameters
----------
folder : aiida.common.folders.Folder
An `aiida.common.folders.Folder` to temporarily write files on disk.
Folder where the calculation is run.
Returns
-------
Expand Down
10 changes: 6 additions & 4 deletions aiida_mlip/calculations/train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Class for training machine learning models."""

from pathlib import Path
import shutil

from aiida.common import InputValidationError, datastructures
import aiida.common.folders
Expand Down Expand Up @@ -154,7 +155,7 @@ def prepare_for_submission(
Parameters
----------
folder : aiida.common.folders.Folder
An `aiida.common.folders.Folder` to temporarily write files on disk.
Folder where the calculation is run.
Returns
-------
Expand All @@ -175,9 +176,10 @@ def prepare_for_submission(

# Add foundation_model to the config file if fine-tuning is enabled
if self.inputs.fine_tune and "foundation_model" in self.inputs:
model_data = self.inputs.foundation_model
foundation_model_path = model_data.filepath
config_parse += f"\nfoundation_model: {foundation_model_path}"
with self.inputs.foundation_model.open(mode="rb") as source:
with folder.open("mlff.model", mode="wb") as target:
shutil.copyfileobj(source, target)
config_parse += "foundation_model: mlff.model"

# Copy config file content inside the folder where the calculation is run
config_copy = "mlip_train.yml"
Expand Down
132 changes: 49 additions & 83 deletions aiida_mlip/data/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
from pathlib import Path
from typing import Any, Optional, Union
from urllib import request
from urllib.parse import urlparse

from aiida.orm import SinglefileData
from aiida.orm import QueryBuilder, SinglefileData, load_node


class ModelData(SinglefileData):
Expand All @@ -26,17 +25,17 @@ class ModelData(SinglefileData):
----------
architecture : str
Architecture of the mlip model.
filepath : str
Path of the mlip model.
model_hash : str
Hash of the model.
Methods
-------
set_file(file, filename=None, architecture=None, **kwargs)
Set the file for the node.
local_file(file, architecture, filename=None):
from_local(file, architecture, filename=None):
Create a ModelData instance from a local file.
download(url, architecture, filename=None, cache_dir=None, force_download=False)
Download a file from a URL and save it as ModelData.
from_uri(uri, architecture, filename=None, cache_dir=None, keep_file=False)
Download a file from a URI and save it as ModelData.
Other Parameters
----------------
Expand Down Expand Up @@ -69,47 +68,6 @@ def _calculate_hash(file: Union[str, Path]) -> str:
file_hash = sha256.hexdigest()
return file_hash

@classmethod
def _check_existing_file(cls, file: Union[str, Path]) -> Path:
"""
Check if a file already exists and return the path of the existing file.
Parameters
----------
file : Union[str, Path]
Path to the downloaded model file.
Returns
-------
Path
The path of the model file of interest (same as input path if no duplicates
were found).
"""
file_hash = cls._calculate_hash(file)

def is_diff_file(curr_path: Path) -> bool:
"""
Filter to check if two files are different.
Parameters
----------
curr_path : Path
Path to the file to compare with.
Returns
-------
bool
True if the files are different, False otherwise.
"""
return curr_path.is_file() and not curr_path.samefile(file)

file_folder = Path(file).parent
for existing_file in filter(is_diff_file, file_folder.rglob("*")):
if cls._calculate_hash(existing_file) == file_hash:
file.unlink()
return existing_file
return Path(file)

def __init__(
self,
file: Union[str, Path],
Expand All @@ -136,7 +94,6 @@ def __init__(
"""
super().__init__(file, filename, **kwargs)
self.base.attributes.set("architecture", architecture)
self.base.attributes.set("filepath", str(file))

def set_file(
self,
Expand Down Expand Up @@ -164,10 +121,12 @@ def set_file(
"""
super().set_file(file, filename, **kwargs)
self.base.attributes.set("architecture", architecture)
self.base.attributes.set("filepath", str(file))
# here compute hash and set attribute
model_hash = self._calculate_hash(file)
self.base.attributes.set("model_hash", model_hash)

@classmethod
def local_file(
def from_local(
cls,
file: Union[str, Path],
architecture: str,
Expand Down Expand Up @@ -195,31 +154,31 @@ def local_file(

@classmethod
# pylint: disable=too-many-arguments
def download(
def from_uri(
cls,
url: str,
uri: str,
architecture: str,
filename: Optional[str] = None,
filename: Optional[str] = "tmp_file.model",
cache_dir: Optional[Union[str, Path]] = None,
force_download: Optional[bool] = False,
keep_file: Optional[bool] = False,
):
"""
Download a file from a URL and save it as ModelData.
Download a file from a URI and save it as ModelData.
Parameters
----------
url : str
URL of the file to download.
uri : str
URI of the file to download.
architecture : [str]
Architecture of the mlip model.
filename : Optional[str], optional
Name to be used for the file (defaults to the name of provided file).
Name to be used for the file defaults to tmp_file.model.
cache_dir : Optional[Union[str, Path]], optional
Path to the folder where the file has to be saved
(defaults to "~/.cache/mlips/").
force_download : Optional[bool], optional
True to keep the downloaded model even if there are duplicates
(default: False).
keep_file : Optional[bool], optional
True to keep the downloaded model, even if there are duplicates.
(default: False, the file is deleted and only saved in the database).
Returns
-------
Expand All @@ -231,32 +190,39 @@ def download(
)
arch_dir = (cache_dir / architecture) if architecture else cache_dir

# cache_path = cache_dir.resolve()
arch_path = arch_dir.resolve()
arch_path.mkdir(parents=True, exist_ok=True)

model_name = urlparse(url).path.split("/")[-1]
file = arch_path / filename

file = arch_path / filename if filename else arch_path / model_name
# Download file
request.urlretrieve(uri, file)

# If file already exists, use next indexed name
stem = file.stem
i = 1
while file.exists():
i += 1
file = file.with_stem(f"{stem}_{i}")
model = cls.from_local(file=file, architecture=architecture)

# Download file
request.urlretrieve(url, file)
if keep_file:
return model

if force_download:
print(f"filename changed to {file}")
return cls.local_file(file=file, architecture=architecture)
file.unlink(missing_ok=True)

# Check if the same model was used previously
qb = QueryBuilder()
qb.append(
ModelData,
filters={
"attributes.model_hash": model.model_hash,
"attributes.architecture": model.architecture,
"ctime": {"!in": [model.ctime]},
},
project=["attributes", "pk", "ctime"],
)

# Check if the hash of the just downloaded file matches any other file
filepath = cls._check_existing_file(file)
if qb.count() != 0:
model = load_node(
qb.first()[1]
) # This gets the pk of the first model in the query

return cls.local_file(file=filepath, architecture=architecture)
return model

@property
def architecture(self) -> str:
Expand All @@ -271,13 +237,13 @@ def architecture(self) -> str:
return self.base.attributes.get("architecture")

@property
def filepath(self) -> str:
def model_hash(self) -> str:
"""
Return the filepath.
Return hash of the architecture.
Returns
-------
str
Path of the mlip model.
Hash of the MLIP model.
"""
return self.base.attributes.get("filepath")
return self.base.attributes.get("model_hash")
Loading

0 comments on commit 9336017

Please sign in to comment.