Skip to content

Commit

Permalink
✅ Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
o-laurent committed Aug 25, 2023
1 parent a0f1348 commit 588a1a0
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 9 deletions.
8 changes: 4 additions & 4 deletions tests/models/test_deep_ensembles.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ def test_list_and_num_estimators(self):
with pytest.raises(ValueError):
deep_ensembles([model_1, model_2], num_estimators=2)

# def test_list_singleton(self):
# model_1 = dummy_model(1, 10, 1)
# with pytest.raises(ValueError):
# deep_ensembles([model_1], num_estimators=1)
def test_list_singleton(self):
model_1 = dummy_model(1, 10, 1)
with pytest.raises(ValueError):
deep_ensembles([model_1], num_estimators=1)

def test_model_and_no_num_estimator(self):
model_1 = dummy_model(1, 10, 1)
Expand Down
8 changes: 4 additions & 4 deletions tests/routines/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class TestClassificationSingle:

def test_cli_main_dummy_binary(self):
root = Path(__file__).parent.absolute().parents[0]
with ArgvContext("file.py --logits"):
with ArgvContext("file.py", "--logits"):
args = init_args(
DummyClassificationBaseline, DummyClassificationDataModule
)
Expand All @@ -46,7 +46,7 @@ def test_cli_main_dummy_binary(self):

def test_cli_main_dummy_ood(self):
root = Path(__file__).parent.absolute().parents[0]
with ArgvContext("file.py --evaluate_ood --entropy"):
with ArgvContext("file.py", "--evaluate_ood", "--entropy"):
args = init_args(
DummyClassificationBaseline, DummyClassificationDataModule
)
Expand Down Expand Up @@ -78,7 +78,7 @@ class TestClassificationEnsemble:

def test_cli_main_dummy_binary(self):
root = Path(__file__).parent.absolute().parents[0]
with ArgvContext("file.py --mutual_information"):
with ArgvContext("file.py", "--mutual_information"):
args = init_args(
DummyClassificationBaseline, DummyClassificationDataModule
)
Expand All @@ -100,7 +100,7 @@ def test_cli_main_dummy_binary(self):

def test_cli_main_dummy_ood(self):
root = Path(__file__).parent.absolute().parents[0]
with ArgvContext("file.py --evaluate_ood --variation_ratio"):
with ArgvContext("file.py", "--evaluate_ood", "--variation_ratio"):
args = init_args(
DummyClassificationBaseline, DummyClassificationDataModule
)
Expand Down
8 changes: 7 additions & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# fmt: off
import sys
from pathlib import Path

import pytest
Expand Down Expand Up @@ -47,8 +48,13 @@ def test_cli_main_resnet(self):
def test_cli_main_other_arguments(self):
root = Path(__file__).parent.absolute().parents[0]
with ArgvContext(
"file.py --seed 42 --max_epochs 1 --channels_last",
"file.py",
"--seed",
"42",
"--max_epochs",
"1",
):
print(sys.orig_argv, sys.argv)
args = init_args(ResNet, CIFAR10DataModule)

# datamodule
Expand Down

0 comments on commit 588a1a0

Please sign in to comment.