Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

prepare for ZnTrack v0.8.0 release #356

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions apax/nodes/analysis.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import logging
import pathlib

import zntrack.utils
from ipsuite import base
import ase
import h5py
import znh5md
import zntrack

from .model import Apax
from apax.nodes.model import Apax

log = logging.getLogger(__name__)


class ApaxBatchPrediction(base.ProcessAtoms):
class ApaxBatchPrediction(zntrack.Node):
"""Create and Save the predictions from model on atoms.

Attributes
Expand All @@ -24,13 +27,19 @@ class ApaxBatchPrediction(base.ProcessAtoms):
predictions: list[Atoms] the atoms that have the predicted properties from model
"""

_module_ = "apax.nodes"
data: list[ase.Atoms] = zntrack.deps()

model: Apax = zntrack.deps()
batch_size: int = zntrack.params(1)
frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / "frames.h5")

def run(self):
self.atoms = []
calc = self.model.get_calculator()
data = self.get_data()
self.atoms = calc.batch_eval(data, self.batch_size)
frames = calc.batch_eval(self.data, self.batch_size)
znh5md.write(self.frames_path, frames)

@property
def frames(self) -> list[ase.Atoms]:
with self.state.fs.open(self.frames_path, "rb") as f:
with h5py.File(f, "r") as h5:
return znh5md.IO(file_handle=h5)[:]
29 changes: 14 additions & 15 deletions apax/nodes/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
import zntrack.utils

from apax.md.simulate import run_md

from .model import ApaxBase
from .utils import check_duplicate_keys
from apax.nodes.model import ApaxBase
from apax.nodes.utils import check_duplicate_keys

log = logging.getLogger(__name__)

Expand All @@ -38,7 +37,7 @@ class ApaxJaxMD(zntrack.Node):
data_id: int = zntrack.params(-1)

model: ApaxBase = zntrack.deps()
repeat = zntrack.params(None)
repeat: typing.Optional[bool] = zntrack.params(None)

config: str = zntrack.params_path(None)

Expand All @@ -47,14 +46,11 @@ class ApaxJaxMD(zntrack.Node):
zntrack.nwd / "initial_structure.extxyz"
)

_parameter: dict = None

def _post_load_(self) -> None:
self._handle_parameter_file()
_parameter: typing.Optional[dict] = None

def _handle_parameter_file(self):
with self.state.use_tmp_path():
self._parameter = yaml.safe_load(pathlib.Path(self.config).read_text())
with self.state.fs.open(self.config, "r") as f:
self._parameter = yaml.safe_load(f)

custom_parameters = {
"sim_dir": self.sim_dir.as_posix(),
Expand All @@ -63,14 +59,17 @@ def _handle_parameter_file(self):
check_duplicate_keys(custom_parameters, self._parameter, log)
self._parameter.update(custom_parameters)

def _write_initial_structure(self):
atoms = self.data[self.data_id]
if self.repeat is not None:
atoms = atoms.repeat(self.repeat)
ase.io.write(self.init_struc_dir.as_posix(), atoms)

def run(self):
"""Primary method to run which executes all steps of the model training"""

self._handle_parameter_file()
if not self.state.restarted:
atoms = self.data[self.data_id]
if self.repeat is not None:
atoms = atoms.repeat(self.repeat)
ase.io.write(self.init_struc_dir.as_posix(), atoms)
self._write_initial_structure()

run_md(self.model._parameter, self._parameter, log_level="info")

Expand Down
84 changes: 39 additions & 45 deletions apax/nodes/model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import logging
import pathlib
import typing as t
Expand All @@ -20,7 +21,10 @@


class ApaxBase(zntrack.Node):
pass
parameter: dict

def get_calculator(self, **kwargs):
raise NotImplementedError


class Apax(ApaxBase):
Expand All @@ -42,11 +46,11 @@ class Apax(ApaxBase):

data: list = zntrack.deps()
config: str = zntrack.params_path()
validation_data = zntrack.deps()
model: t.Optional[t.Any] = zntrack.deps(None)
validation_data: list[ase.Atoms] = zntrack.deps()
model: t.Optional[ApaxBase] = zntrack.deps(None)
nl_skin: float = zntrack.params(0.5)
transformations: t.Optional[list[dict[str, dict]]] = zntrack.params(None)
log_level: str = zntrack.meta.Text("info")
log_level: str = "info"

model_directory: pathlib.Path = zntrack.outs_path(zntrack.nwd / "apax_model")

Expand All @@ -55,38 +59,34 @@ class Apax(ApaxBase):
zntrack.nwd / "val_atoms.extxyz"
)

metrics = zntrack.metrics()
metrics: dict = zntrack.metrics()

_parameter: dict = None
@functools.cached_property
def parameter(self) -> dict:
parameter = yaml.safe_load(self.state.fs.read_text(self.config))

def _post_load_(self) -> None:
self._handle_parameter_file()
custom_parameters = {
"directory": self.model_directory.as_posix(),
"experiment": "",
"train_data_path": self.train_data_file.as_posix(),
"val_data_path": self.validation_data_file.as_posix(),
}

def _handle_parameter_file(self):
self._parameter = yaml.safe_load(self.state.fs.read_text(self.config))
if self.model is not None:
param_files = self.model.parameter["data"]["directory"]
base_path = {"base_model_checkpoint": param_files}
try:
parameter["checkpoints"].update(base_path)
except KeyError:
parameter["checkpoints"] = base_path

with self.state.use_tmp_path():
custom_parameters = {
"directory": self.model_directory.as_posix(),
"experiment": "",
"train_data_path": self.train_data_file.as_posix(),
"val_data_path": self.validation_data_file.as_posix(),
}

if self.model is not None:
param_files = self.model._parameter["data"]["directory"]
base_path = {"base_model_checkpoint": param_files}
try:
self._parameter["checkpoints"].update(base_path)
except KeyError:
self._parameter["checkpoints"] = base_path

check_duplicate_keys(custom_parameters, self._parameter["data"], log)
self._parameter["data"].update(custom_parameters)
check_duplicate_keys(custom_parameters, parameter["data"], log)
parameter["data"].update(custom_parameters)
return parameter

def train_model(self):
"""Train the model using `apax.train.run`"""
apax_run(self._parameter, log_level=self.log_level)
apax_run(self.parameter, log_level=self.log_level)

def get_metrics(self):
"""In addition to the plots write a model metric"""
Expand All @@ -104,7 +104,7 @@ def run(self):
if self.state.restarted and csv_path.is_file():
metrics_df = pd.read_csv(self.model_directory / "log.csv")

if metrics_df["epoch"].iloc[-1] >= self._parameter["n_epochs"] - 1:
if metrics_df["epoch"].iloc[-1] >= self.parameter["n_epochs"] - 1:
return

self.train_model()
Expand Down Expand Up @@ -156,8 +156,7 @@ def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator:
calc:
ase calculator object
"""

param_files = [m._parameter["data"]["directory"] for m in self.models]
param_files = [m.parameter["data"]["directory"] for m in self.models]

transformations = []
if self.transformations:
Expand Down Expand Up @@ -192,14 +191,9 @@ class ApaxImport(zntrack.Node):
nl_skin: float = zntrack.params(0.5)
transformations: t.Optional[list[dict[str, dict]]] = zntrack.params(None)

_parameter: dict = None

def _post_load_(self) -> None:
self._handle_parameter_file()

def _handle_parameter_file(self):
with self.state.use_tmp_path():
self._parameter = yaml.safe_load(pathlib.Path(self.config).read_text())
@functools.cached_property
def parameter(self) -> dict:
return yaml.safe_load(self.state.fs.read_text(self.config))

def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator:
"""Property to return a model specific ase calculator object.
Expand All @@ -210,8 +204,8 @@ def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator:
ase calculator object
"""

directory = self._parameter["data"]["directory"]
exp = self._parameter["data"]["experiment"]
directory = self.parameter["data"]["directory"]
exp = self.parameter["data"]["experiment"]
model_dir = directory + "/" + exp

transformations = []
Expand Down Expand Up @@ -251,7 +245,7 @@ class ApaxCalibrate(ApaxBase):
See the apax documentation for available methods.
"""

model: t.Any = zntrack.deps()
model: ApaxBase = zntrack.deps()
validation_data: list[Atoms] = zntrack.deps()
batch_size: int = zntrack.params(32)
criterion: str = zntrack.params("ma_cal")
Expand All @@ -262,7 +256,7 @@ class ApaxCalibrate(ApaxBase):

nl_skin: float = zntrack.params(0.5)

metrics = zntrack.metrics()
metrics: dict = zntrack.metrics()

def run(self):
"""Primary method to run which executes all steps of the model training"""
Expand Down Expand Up @@ -294,7 +288,7 @@ def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator:
e_factor = self.metrics["e_factor"]
f_factor = self.metrics["f_factor"]

config_file = self.model._parameter["data"]["directory"]
config_file = self.model.parameter["data"]["directory"]

calibration = GlobalCalibration(
energy_factor=e_factor,
Expand Down
Loading
Loading