add and test extra_run_config: dict | None = None keyword to Trainer to specify run params like batch_size that aren't already recorded by the trainer_args dict
janosh committed Jun 13, 2024
1 parent fb2e75c commit bc87fd0
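
For illustration, a minimal sketch of how the new keyword is meant to be used (the project/run name and hyper-param values here are made up for the example, not taken from this commit):

from chgnet.model import CHGNet
from chgnet.trainer import Trainer

chgnet = CHGNet.load()  # load the pretrained CHGNet model
trainer = Trainer(
    model=chgnet,
    targets="efsm",
    epochs=5,
    wandb_path="my-project/my-run",  # hypothetical project/run name
    # batch_size is set on the data loaders, not in trainer_args, so record it here
    extra_run_config=dict(batch_size=16),
)
# wandb.init now receives trainer_args merged with extra_run_config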
Showing 2 changed files with 23 additions and 8 deletions.
21 changes: 16 additions & 5 deletions chgnet/trainer/trainer.py
@@ -57,7 +57,8 @@ def __init__(
         use_device: str | None = None,
         check_cuda_mem: bool = False,
         wandb_path: str | None = None,
-        wandb_kwargs: dict | None = None,
+        wandb_init_kwargs: dict | None = None,
+        extra_run_config: dict | None = None,
         **kwargs,
     ) -> None:
         """Initialize all hyper-parameters for trainer.
@@ -99,7 +100,11 @@ def __init__(
             wandb_path (str | None): The project and run name separated by a slash:
                 "project/run_name". If None, wandb logging is not used.
                 Default = None
-            wandb_kwargs (dict): additional kwargs for wandb.init. Default = None
+            wandb_init_kwargs (dict): Additional kwargs to pass to wandb.init.
+                Default = None
+            extra_run_config (dict): Additional hyper-params to be recorded by wandb
+                that are not included in the trainer_args. Default = None
             **kwargs (dict): additional hyper-params for optimizer, scheduler, etc.
         """

         # Store trainer args for reproducibility
@@ -213,12 +218,18 @@ def __init__(
                     "Weights and Biases not installed. pip install wandb to use "
                     "wandb logging."
                 )
-            project, run_name = wandb_path.split("/")
+            if wandb_path.count("/") == 1:
+                project, run_name = wandb_path.split("/")
+            else:
+                raise ValueError(
+                    f"{wandb_path=} should be in the format 'project/run_name' "
+                    "(no extra slashes)"
+                )
             wandb.init(
                 project=project,
                 name=run_name,
-                config=self.trainer_args,
-                **(wandb_kwargs or {}),
+                config=self.trainer_args | (extra_run_config or {}),
+                **(wandb_init_kwargs or {}),
             )

     def train(
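
Note on the merged config above: self.trainer_args | (extra_run_config or {}) uses the PEP 584 dict union operator (Python 3.9+). On duplicate keys the right-hand operand wins, so extra_run_config entries override trainer_args, and the "or {}" guards against the None default. A standalone sketch of that semantics (illustrative values only):

trainer_args = {"epochs": 5, "learning_rate": 1e-2}
extra_run_config = {"batch_size": 16, "epochs": 50}
# right operand wins on key collisions; new keys are appended
merged = trainer_args | (extra_run_config or {})
print(merged)  # {'epochs': 50, 'learning_rate': 0.01, 'batch_size': 16}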
10 changes: 7 additions & 3 deletions tests/test_trainer.py
@@ -7,6 +7,7 @@
 import torch
 from pymatgen.core import Lattice, Structure
 
+import wandb
 from chgnet.data.dataset import StructureData, get_train_val_test_loader
 from chgnet.model import CHGNet
 from chgnet.trainer import Trainer
@@ -42,16 +43,19 @@ def test_trainer(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
     train_loader, val_loader, _test_loader = get_train_val_test_loader(
         data, batch_size=16, train_ratio=0.9, val_ratio=0.05
     )
+    extra_run_config = dict(some_other_hyperparam=42)
     trainer = Trainer(
         model=chgnet,
         targets="efsm",
         optimizer="Adam",
         criterion="MSE",
         learning_rate=1e-2,
         epochs=5,
-        wandb_path="",
-        wandb_kwargs=dict(anonymous="allow"),
+        wandb_path="test/run",
+        wandb_init_kwargs=dict(anonymous="must"),
+        extra_run_config=extra_run_config,
     )
+    assert dict(wandb.config).items() >= extra_run_config.items()
     dir_name = "test_tmp_dir"
     test_dir = tmp_path / dir_name
     trainer.train(train_loader, val_loader, save_dir=test_dir)
@@ -70,7 +74,7 @@ def test_trainer(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
     err_msg = "Weights and Biases not installed. pip install wandb to use wandb logging"
     with monkeypatch.context() as ctx, pytest.raises(ImportError, match=err_msg):  # noqa: PT012
         ctx.setattr("chgnet.trainer.trainer.wandb", None)
-        _ = Trainer(model=chgnet, wandb_path="radicalai/chgnet-test-finetune")
+        _ = Trainer(model=chgnet, wandb_path="some-org/some-project")
 
 
 def test_trainer_composition_model(tmp_path: Path) -> None:
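
Aside on the new assert in test_trainer: dict .items() views are set-like, so a.items() >= b.items() checks that every key/value pair of b also appears in a (superset containment, not equality). A self-contained sketch of that comparison (illustrative values only):

recorded = {"epochs": 5, "some_other_hyperparam": 42}
extra = {"some_other_hyperparam": 42}
assert recorded.items() >= extra.items()  # passes: extra is a sub-dict of recorded
assert not recorded.items() >= {"some_other_hyperparam": 0}.items()  # value mismatch fails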
