From d99df362a3ccd2a6ae8b97aa39d885550b1440d1 Mon Sep 17 00:00:00 2001 From: Rosie Wood Date: Tue, 31 Oct 2023 10:51:05 +0000 Subject: [PATCH] add tests --- tests/test_classify/test_classifier.py | 71 +++++++++++++++++++++++++- 1 file changed, 70 insertions(+), 1 deletion(-) diff --git a/tests/test_classify/test_classifier.py b/tests/test_classify/test_classifier.py index 0ab78bb3..f76dcbfb 100644 --- a/tests/test_classify/test_classifier.py +++ b/tests/test_classify/test_classifier.py @@ -73,6 +73,12 @@ def test_init_models_string(inputs, infer_inputs): model, labels_map=annots.labels_map, dataloaders=dataloaders ) assert isinstance(classifier.model, model_type) + assert all(k in classifier.dataloaders.keys() for k in ["train", "test", "val"]) + classifier = ClassifierContainer( + model, labels_map=annots.labels_map + ) + assert isinstance(classifier.model, model_type) + assert classifier.dataloaders == {} def test_init_models_string_errors(inputs): @@ -96,6 +102,12 @@ def test_init_resnet18_torch(inputs): my_model, labels_map=annots.labels_map, dataloaders=dataloaders ) # resnet18 as nn.Module assert isinstance(classifier.model, models.ResNet) + assert all(k in classifier.dataloaders.keys() for k in ["train", "test", "val"]) + classifier = ClassifierContainer( + my_model, labels_map=annots.labels_map + ) + assert isinstance(classifier.model, models.ResNet) + assert classifier.dataloaders == {} # test loading model from pickle file using torch load @@ -109,6 +121,12 @@ def test_init_resnet18_pickle(inputs, sample_dir): my_model, labels_map=annots.labels_map, dataloaders=dataloaders ) # resnet18 as pkl (from sample files) assert isinstance(classifier.model, models.ResNet) + assert all(k in classifier.dataloaders.keys() for k in ["train", "test", "val"]) + classifier = ClassifierContainer( + my_model, labels_map=annots.labels_map + ) + assert isinstance(classifier.model, models.ResNet) + assert classifier.dataloaders == {} # test loading model from hugging face @@ -125,6 +143,12 @@ def test_init_resnet18_hf(inputs): my_model, labels_map=annots.labels_map, dataloaders=dataloaders ) assert isinstance(classifier.model, model_type) + assert all(k in classifier.dataloaders.keys() for k in ["train", "test", "val"]) + classifier = ClassifierContainer( + my_model, labels_map=annots.labels_map + ) + assert isinstance(classifier.model, model_type) + assert classifier.dataloaders == {} # test loading model using timm @@ -140,6 +164,12 @@ def test_init_resnet18_timm(inputs): my_model, labels_map=annots.labels_map, dataloaders=dataloaders ) assert isinstance(classifier.model, timm.models.ResNet) + assert all(k in classifier.dataloaders.keys() for k in ["train", "test", "val"]) + classifier = ClassifierContainer( + my_model, labels_map=annots.labels_map + ) + assert isinstance(classifier.model, timm.models.ResNet) + assert classifier.dataloaders == {} @pytest.mark.dependency(name="timm_models", scope="session") @@ -163,15 +193,54 @@ def test_init_models_timm(inputs): my_model, labels_map=annots.labels_map, dataloaders=dataloaders ) assert isinstance(classifier.model, model_type) + assert all(k in classifier.dataloaders.keys() for k in ["train", "test", "val"]) + classifier = ClassifierContainer( + model, labels_map=annots.labels_map + ) + assert isinstance(classifier.model, model_type) + assert classifier.dataloaders == {} # test loading object from pickle file +def test_load_no_dataloaders(inputs, sample_dir): + annots, dataloaders = inputs + classifier = ClassifierContainer( + None, None, None, load_path=f"{sample_dir}/test.pkl" + ) + assert all(k in classifier.dataloaders.keys() for k in ["train", "test", "val"]) + assert classifier.labels_map == annots.labels_map + assert isinstance(classifier.model, models.ResNet) + + # without explicitly passing dataloaders as None + classifier = ClassifierContainer( + None, None, load_path=f"{sample_dir}/test.pkl" + ) + assert all(k in classifier.dataloaders.keys() for k in ["train", "test", "val"]) + assert classifier.labels_map == annots.labels_map + assert isinstance(classifier.model, models.ResNet) + + +def test_load_w_dataloaders(inputs, sample_dir): + annots, dataloaders = inputs + # rename keys + dataloaders["new_train"]= dataloaders.pop("train") + dataloaders["new_val"]= dataloaders.pop("val") + dataloaders["new_test"]= dataloaders.pop("test") + + classifier = ClassifierContainer( + None, None, dataloaders=dataloaders, load_path=f"{sample_dir}/test.pkl" + ) + assert all(k in classifier.dataloaders.keys() for k in ["train", "test", "val", "new_train", "new_test", "new_val"]) + assert classifier.labels_map == annots.labels_map + assert isinstance(classifier.model, models.ResNet) + + def test_init_load(inputs, load_classifier): annots, dataloaders = inputs classifier = load_classifier - assert list(classifier.dataloaders.keys()) == list(dataloaders.keys()) + assert all(k in classifier.dataloaders.keys() for k in ["train", "test", "val"]) assert classifier.labels_map == annots.labels_map assert isinstance(classifier.model, models.ResNet)