add test_wandb_init + test_wandb_log_frequency
janosh committed Jun 27, 2024
1 parent 97b6679 · commit 96ffdc2
Showing 1 changed file with 78 additions and 8 deletions.
tests/test_trainer.py: 86 changes (78 additions & 8 deletions)
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from typing import TYPE_CHECKING
+from unittest.mock import patch
 
 import numpy as np
 import pytest
@@ -36,13 +37,13 @@
     stresses=stresses,
     magmoms=magmoms,
 )
+train_loader, val_loader, _test_loader = get_train_val_test_loader(
+    data, batch_size=16, train_ratio=0.9, val_ratio=0.05
+)
+chgnet = CHGNet.load()
 
 
 def test_trainer(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
-    chgnet = CHGNet.load()
-    train_loader, val_loader, _test_loader = get_train_val_test_loader(
-        data, batch_size=16, train_ratio=0.9, val_ratio=0.05
-    )
     extra_run_config = dict(some_other_hyperparam=42)
     trainer = Trainer(
         model=chgnet,
@@ -81,12 +82,8 @@ def test_trainer(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
 
 
 def test_trainer_composition_model(tmp_path: Path) -> None:
-    chgnet = CHGNet.load()
     for param in chgnet.composition_model.parameters():
         assert param.requires_grad is False
-    train_loader, val_loader, _test_loader = get_train_val_test_loader(
-        data, batch_size=16, train_ratio=0.9, val_ratio=0.05
-    )
     trainer = Trainer(
         model=chgnet,
         targets="efsm",
@@ -115,3 +112,76 @@ def test_trainer_composition_model(tmp_path: Path) -> None:
     expect[0][10] = 0
     expect[0][16] = 0
     assert torch.all(comparison == expect)
+
+
+@pytest.fixture()
+def mock_wandb():
+    with patch("chgnet.trainer.trainer.wandb") as mock:
+        yield mock
+
+
+def test_wandb_init(mock_wandb):
+    chgnet = CHGNet.load()
+    _trainer = Trainer(
+        model=chgnet,
+        wandb_path="test-project/test-run",
+        wandb_init_kwargs={"tags": ["test"]},
+    )
+    expected_config = {
+        "targets": "ef",
+        "energy_loss_ratio": 1,
+        "force_loss_ratio": 1,
+        "stress_loss_ratio": 0.1,
+        "mag_loss_ratio": 0.1,
+        "optimizer": "Adam",
+        "scheduler": "CosLR",
+        "criterion": "MSE",
+        "epochs": 50,
+        "starting_epoch": 0,
+        "learning_rate": 0.001,
+        "print_freq": 100,
+        "torch_seed": None,
+        "data_seed": None,
+        "use_device": None,
+        "check_cuda_mem": False,
+        "wandb_path": "test-project/test-run",
+        "wandb_init_kwargs": {"tags": ["test"]},
+        "extra_run_config": None,
+    }
+    mock_wandb.init.assert_called_once_with(
+        project="test-project", name="test-run", config=expected_config, tags=["test"]
+    )
+
+
+def test_wandb_log_frequency(mock_wandb):
+    trainer = Trainer(model=chgnet, wandb_path="test-project/test-run", epochs=1)
+
+    # Test epoch logging
+    trainer.train(train_loader, val_loader, wandb_log_freq="epoch", save_dir="")
+    assert (
+        mock_wandb.log.call_count == 2 * trainer.epochs
+    ), "Expected one train and one val log per epoch"
+
+    mock_wandb.log.reset_mock()
+
+    # Test batch logging
+    trainer.train(train_loader, val_loader, wandb_log_freq="batch", save_dir="")
+    expected_batch_calls = trainer.epochs * len(train_loader)
+    assert (
+        mock_wandb.log.call_count > expected_batch_calls
+    ), "Expected more calls for batch logging"
+
+    # Test log content (for both epoch and batch logging)
+    for call_args in mock_wandb.log.call_args_list:
+        logged_data = call_args[0][0]
+        assert isinstance(logged_data, dict), "Logged data should be a dictionary"
+        assert any(
+            key.endswith("_mae") for key in logged_data
+        ), "Logged data should contain MAE metrics"
+
+    mock_wandb.log.reset_mock()
+
+    # Test no logging when wandb_path is not provided
+    trainer_no_wandb = Trainer(model=chgnet, epochs=1)
+    trainer_no_wandb.train(train_loader, val_loader)
+    mock_wandb.log.assert_not_called()
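Aside: the assert_called_once_with check in test_wandb_init pins down how Trainer turns the single wandb_path string into wandb.init arguments. Below is a minimal sketch of the implied parsing, assuming a split on the first "/"; the helper name is hypothetical and not chgnet's actual internals.

import wandb

def init_wandb_from_path(wandb_path: str, config: dict, **init_kwargs) -> None:
    # Hypothetical sketch, not chgnet's real code.
    # "test-project/test-run" -> project="test-project", name="test-run";
    # extra kwargs such as tags=["test"] are forwarded verbatim to wandb.init.
    project, run_name = wandb_path.split("/", maxsplit=1)
    wandb.init(project=project, name=run_name, config=config, **init_kwargs)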

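Likewise, test_wandb_log_frequency fixes the logging cadence: two wandb.log calls per epoch (one train summary, one val summary) in "epoch" mode, plus one call per batch in "batch" mode, so total calls exceed epochs * len(train_loader). A hedged sketch of gating logic consistent with those assertions (hypothetical helper, not chgnet's real implementation):

import wandb

def maybe_log(metrics: dict, *, wandb_log_freq: str, event: str) -> None:
    # Hypothetical sketch, not chgnet's real code.
    # Epoch-end summaries are always logged -> 2 calls per epoch;
    # per-batch metrics are logged only when wandb_log_freq == "batch".
    if event == "epoch" or wandb_log_freq == "batch":
        wandb.log(metrics)  # each logged dict carries *_mae keys, per the test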