diff --git a/tests/test_classify/test_classifier.py b/tests/test_classify/test_classifier.py index f76dcbfb..67483a00 100644 --- a/tests/test_classify/test_classifier.py +++ b/tests/test_classify/test_classifier.py @@ -195,7 +195,7 @@ def test_init_models_timm(inputs): assert isinstance(classifier.model, model_type) assert all(k in classifier.dataloaders.keys() for k in ["train", "test", "val"]) classifier = ClassifierContainer( - model, labels_map=annots.labels_map + my_model, labels_map=annots.labels_map ) assert isinstance(classifier.model, model_type) assert classifier.dataloaders == {}