diff --git a/apax/nodes/model.py b/apax/nodes/model.py
index 773f7e6c..f9323362 100644
--- a/apax/nodes/model.py
+++ b/apax/nodes/model.py
@@ -47,9 +47,12 @@ class Apax(ApaxBase):
         verbosity of logging during training
     """
 
-    data: list[ase.Atoms] = zntrack.deps()
+    data: list[ase.Atoms] | None = zntrack.deps()
+    data_path: str | pathlib.Path | None = zntrack.deps_path(None)
+
     config: str = zntrack.params_path()
-    validation_data: list[ase.Atoms] = zntrack.deps()
+    validation_data: list[ase.Atoms] | None = zntrack.deps()
+    validation_data_path: str | pathlib.Path | None = zntrack.deps_path(None)
     model: t.Optional[ApaxBase] = zntrack.deps(None)
     nl_skin: float = zntrack.params(0.5)
     log_level: str = zntrack.params("info")
@@ -61,6 +64,17 @@ class Apax(ApaxBase):
 
     metrics: dict = zntrack.metrics()
 
+    def __post_init__(self):
+        """Reject a dataset that is given both in memory and as a path."""
+        super().__post_init__()
+
+        if self.data is not None and self.data_path is not None:
+            raise ValueError("You can either provide `data` or `data_path`, not both.")
+        if self.validation_data is not None and self.validation_data_path is not None:
+            raise ValueError(
+                "You can either provide `validation_data` or `validation_data_path`, not both."
+            )
+
     @property
     def parameter(self) -> dict:
         parameter = yaml.safe_load(self.state.fs.read_text(self.config))
@@ -68,8 +82,13 @@ def parameter(self) -> dict:
         custom_parameters = {
             "directory": self.model_directory.as_posix(),
             "experiment": "",
-            "train_data_path": self.train_data_file.as_posix(),
-            "val_data_path": self.validation_data_file.as_posix(),
+            # Use the file written by run() for in-memory data, else the given path.
+            "train_data_path": self.train_data_file.as_posix()
+            if self.data is not None
+            else self.data_path,
+            "val_data_path": self.validation_data_file.as_posix()
+            if self.validation_data is not None
+            else self.validation_data_path,
         }
 
         if self.model is not None:
@@ -98,11 +117,18 @@ def run(self):
         """Primary method to run which executes all steps of the model training"""
 
         if not self.state.restarted:
-            train_db = znh5md.IO(self.train_data_file.as_posix())
-            train_db.extend(self.data)
-
-            val_db = znh5md.IO(self.validation_data_file.as_posix())
-            val_db.extend(self.validation_data)
+            if self.data is not None:
+                train_db = znh5md.IO(self.train_data_file.as_posix())
+                train_db.extend(self.data)
+            else:
+                self.train_data_file.write_text(f"Using {self.data_path} instead")
+            if self.validation_data is not None:
+                val_db = znh5md.IO(self.validation_data_file.as_posix())
+                val_db.extend(self.validation_data)
+            else:
+                self.validation_data_file.write_text(
+                    f"Using {self.validation_data_path} instead"
+                )
 
         csv_path = self.model_directory / "log.csv"
         if self.state.restarted and csv_path.is_file():