From bc87fd0bf78c77f70c32e1ca96261bd9f4937a6a Mon Sep 17 00:00:00 2001
From: Janosh Riebesell
Date: Thu, 13 Jun 2024 12:40:32 -0400
Subject: [PATCH] add and test extra_run_config: dict | None = None keyword to
 Trainer for recording run params like batch_size that the trainer_args dict
 doesn't already capture

Also rename wandb_kwargs to wandb_init_kwargs for clarity and validate that
wandb_path contains exactly one slash before splitting it into project and
run name.
---
 chgnet/trainer/trainer.py | 21 ++++++++++++++++-----
 tests/test_trainer.py     | 10 +++++++---
 2 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/chgnet/trainer/trainer.py b/chgnet/trainer/trainer.py
index 6c5bc07b..8575b372 100644
--- a/chgnet/trainer/trainer.py
+++ b/chgnet/trainer/trainer.py
@@ -57,7 +57,8 @@ def __init__(
         use_device: str | None = None,
         check_cuda_mem: bool = False,
         wandb_path: str | None = None,
-        wandb_kwargs: dict | None = None,
+        wandb_init_kwargs: dict | None = None,
+        extra_run_config: dict | None = None,
         **kwargs,
     ) -> None:
         """Initialize all hyper-parameters for trainer.
@@ -99,7 +100,11 @@ def __init__(
             wandb_path (str | None): The project and run name separated by a slash:
                 "project/run_name". If None, wandb logging is not used.
                 Default = None
-            wandb_kwargs (dict): additional kwargs for wandb.init. Default = None
+            wandb_init_kwargs (dict): Additional kwargs to pass to wandb.init.
+                Default = None
+            extra_run_config (dict): Additional hyper-params to be recorded by wandb
+                that are not included in trainer_args. Default = None
+
             **kwargs (dict): additional hyper-params for optimizer, scheduler, etc.
         """
         # Store trainer args for reproducibility
@@ -213,12 +218,18 @@ def __init__(
                     "Weights and Biases not installed. pip install wandb to use "
                     "wandb logging."
                 )
-            project, run_name = wandb_path.split("/")
+            if wandb_path.count("/") == 1:
+                project, run_name = wandb_path.split("/")
+            else:
+                raise ValueError(
+                    f"{wandb_path=} should be in the format 'project/run_name' "
+                    "(no extra slashes)"
+                )
             wandb.init(
                 project=project,
                 name=run_name,
-                config=self.trainer_args,
-                **(wandb_kwargs or {}),
+                config=self.trainer_args | (extra_run_config or {}),
+                **(wandb_init_kwargs or {}),
             )
 
     def train(
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index a7d3f22c..eb0e8a86 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -7,6 +7,7 @@
 import torch
 from pymatgen.core import Lattice, Structure
 
+import wandb
 from chgnet.data.dataset import StructureData, get_train_val_test_loader
 from chgnet.model import CHGNet
 from chgnet.trainer import Trainer
@@ -42,6 +43,7 @@ def test_trainer(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
     train_loader, val_loader, _test_loader = get_train_val_test_loader(
         data, batch_size=16, train_ratio=0.9, val_ratio=0.05
     )
+    extra_run_config = dict(some_other_hyperparam=42)
     trainer = Trainer(
         model=chgnet,
         targets="efsm",
@@ -49,9 +51,11 @@ def test_trainer(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
         criterion="MSE",
         learning_rate=1e-2,
         epochs=5,
-        wandb_path="",
-        wandb_kwargs=dict(anonymous="allow"),
+        wandb_path="test/run",
+        wandb_init_kwargs=dict(anonymous="must"),
+        extra_run_config=extra_run_config,
     )
+    assert dict(wandb.config).items() >= extra_run_config.items()
     dir_name = "test_tmp_dir"
     test_dir = tmp_path / dir_name
     trainer.train(train_loader, val_loader, save_dir=test_dir)
@@ -70,7 +74,7 @@ def test_trainer(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
     err_msg = "Weights and Biases not installed. pip install wandb to use wandb logging"
     with monkeypatch.context() as ctx, pytest.raises(ImportError, match=err_msg):  # noqa: PT012
         ctx.setattr("chgnet.trainer.trainer.wandb", None)
-        _ = Trainer(model=chgnet, wandb_path="radicalai/chgnet-test-finetune")
+        _ = Trainer(model=chgnet, wandb_path="some-org/some-project")
 
 
 def test_trainer_composition_model(tmp_path: Path) -> None:
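
Usage note (illustrative sketch, not part of the patch): run params that live
outside trainer_args, such as the data-loader batch_size, can now be surfaced
in the wandb run config via extra_run_config. The project/run names below are
hypothetical placeholders; everything else uses only keywords that appear in
this patch or the existing chgnet API.

    from chgnet.model import CHGNet
    from chgnet.trainer import Trainer

    trainer = Trainer(
        model=CHGNet.load(),  # pretrained CHGNet
        targets="efsm",
        optimizer="Adam",
        criterion="MSE",
        learning_rate=1e-2,
        epochs=5,
        wandb_path="my-project/my-run",  # must contain exactly one slash
        wandb_init_kwargs=dict(anonymous="allow"),  # forwarded to wandb.init
        # batch_size belongs to the data loaders, not the Trainer, so it never
        # enters trainer_args; extra_run_config merges it into the logged config
        extra_run_config=dict(batch_size=16),
    )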