diff --git a/tests/test_trainer.py b/tests/test_trainer.py index bcf44f64..7124556b 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -49,14 +49,17 @@ def test_trainer(tmp_path: Path) -> None: learning_rate=1e-2, epochs=5, ) - dir_name = "test_tmp_dir" - test_dir = tmp_path / dir_name - trainer.train(train_loader, val_loader, save_dir=test_dir) + trainer.train( + train_loader, + val_loader, + save_dir=tmp_path, + save_test_result=tmp_path / "test-preds.json", + ) for param in chgnet.composition_model.parameters(): assert param.requires_grad is False - assert test_dir.is_dir(), "Training dir was not created" + assert tmp_path.is_dir(), "Training dir was not created" - output_files = [file.name for file in test_dir.iterdir()] + output_files = [file.name for file in tmp_path.iterdir()] for prefix in ("epoch", "bestE_", "bestF_"): n_matches = sum(file.startswith(prefix) for file in output_files) assert ( @@ -79,16 +82,14 @@ def test_trainer_composition_model(tmp_path: Path) -> None: learning_rate=1e-2, epochs=5, ) - dir_name = "test_tmp_dir2" - test_dir = tmp_path / dir_name initial_weights = chgnet.composition_model.state_dict()["fc.weight"].clone() trainer.train( - train_loader, val_loader, save_dir=test_dir, train_composition_model=True + train_loader, val_loader, save_dir=tmp_path, train_composition_model=True ) for param in chgnet.composition_model.parameters(): assert param.requires_grad is True - output_files = list(test_dir.iterdir()) + output_files = list(tmp_path.iterdir()) weights_path = next(file for file in output_files if file.name.startswith("epoch")) new_chgnet = CHGNet.from_file(weights_path) for param in new_chgnet.composition_model.parameters():