Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 33 additions & 9 deletions apax/nodes/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,12 @@ class Apax(ApaxBase):
verbosity of logging during training
"""

data: list[ase.Atoms] = zntrack.deps()
data: list[ase.Atoms] | None = zntrack.deps()
data_path: str | pathlib.Path | None = zntrack.deps_path(None)

config: str = zntrack.params_path()
validation_data: list[ase.Atoms] = zntrack.deps()
validation_data: list[ase.Atoms] | None = zntrack.deps()
validation_data_path: str | pathlib.Path | None = zntrack.deps_path(None)
model: t.Optional[ApaxBase] = zntrack.deps(None)
nl_skin: float = zntrack.params(0.5)
log_level: str = zntrack.params("info")
Expand All @@ -61,15 +64,29 @@ class Apax(ApaxBase):

metrics: dict = zntrack.metrics()

def __post_init__(self):
super().__post_init__()

if self.data is not None and self.data_path is not None:
raise ValueError("You can either provide `data` or `data_path`, not both.")
if self.validation_data is not None and self.validation_data_path is not None:
raise ValueError(
"You can either provide `validation_data` or `validation_data_path`, not both."
)

@property
def parameter(self) -> dict:
parameter = yaml.safe_load(self.state.fs.read_text(self.config))

custom_parameters = {
"directory": self.model_directory.as_posix(),
"experiment": "",
"train_data_path": self.train_data_file.as_posix(),
"val_data_path": self.validation_data_file.as_posix(),
"train_data_path": self.train_data_file.as_posix()
if self.data is None
else self.data_path,
"val_data_path": self.validation_data_file.as_posix()
if self.validation_data is None
else self.validation_data_path,
}

if self.model is not None:
Expand Down Expand Up @@ -98,11 +115,18 @@ def run(self):
"""Primary method to run which executes all steps of the model training"""

if not self.state.restarted:
train_db = znh5md.IO(self.train_data_file.as_posix())
train_db.extend(self.data)

val_db = znh5md.IO(self.validation_data_file.as_posix())
val_db.extend(self.validation_data)
if self.data is not None:
train_db = znh5md.IO(self.train_data_file.as_posix())
train_db.extend(self.data)
else:
self.train_data_file.write_text(f"Using {self.data_path} instead")
if self.validation_data is not None:
val_db = znh5md.IO(self.validation_data_file.as_posix())
val_db.extend(self.validation_data)
else:
self.validation_data_file.write_text(
f"Using {self.validation_data_path} instead"
)

csv_path = self.model_directory / "log.csv"
if self.state.restarted and csv_path.is_file():
Expand Down
Loading