diff --git a/sleap_nn/config/model_config.py b/sleap_nn/config/model_config.py index 1bd48cd1..8aae8628 100644 --- a/sleap_nn/config/model_config.py +++ b/sleap_nn/config/model_config.py @@ -259,10 +259,12 @@ class ModelConfig: init_weight: str = "default" pre_trained_weights: str = None backbone_config: BackboneConfig = attrs.field(factory=BackboneConfig) + backbone_type: str = None head_configs: HeadConfig = attrs.field(factory=HeadConfig) # post-initialization def __attrs_post_init__(self): + self.validate_backbone_type() self.validate_pre_trained_weights() # validate the pre-trained weights @@ -274,3 +276,24 @@ def validate_pre_trained_weights(self): "ConvNeXt_Large_Weights", ] swint_weights = ["Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights"] + if self.backbone_type == "convnext": + if self.pre_trained_weights not in convnext_weights: + raise ValueError( + f"Invalid pre-trained weights for ConvNext. Must be one of {convnext_weights}" + ) + elif self.backbone_type == "swint": + if self.pre_trained_weights not in swint_weights: + raise ValueError( + f"Invalid pre-trained weights for SwinT. Must be one of {swint_weights}" + ) + elif ( + self.backbone_type == "unet" + and self.pre_trained_weights is not None + ): + raise ValueError("UNet does not support pre-trained weights.") + + def validate_backbone_type(self): + if self.backbone_type not in ["unet", "convnext", "swint"]: + raise ValueError( + 'backbone_type must be one of "unet", "convnext", "swint"' + ) diff --git a/tests/config/test_model_config.py b/tests/config/test_model_config.py index 377d6f82..8ecf00c8 100644 --- a/tests/config/test_model_config.py +++ b/tests/config/test_model_config.py @@ -23,6 +23,7 @@ def default_config(): return ModelConfig( init_weight="default", pre_trained_weights=None, + backbone_type="unet", backbone_config=BackboneConfig(), head_configs=HeadConfig() ) @@ -33,23 +34,15 @@ def test_default_initialization(default_config): assert default_config.init_weight == "default" assert default_config.pre_trained_weights == None -# def test_invalid_pre_trained_weights(): -# """Test validation failure with an invalid pre_trained_weights.""" -# with pytest.raises(ValueError): -# ModelConfig(pre_trained_weights="weights") +def test_invalid_pre_trained_weights(): + """Test validation failure with an invalid pre_trained_weights.""" + with pytest.raises(ValueError): + ModelConfig(pre_trained_weights="here", backbone_type="unet") -# def test_invalid_input_size(): -# """Test validation failure with an invalid input_size.""" -# with pytest.raises(ValueError, match="input_size must be a tuple of two positive integers"): -# ModelConfig(model_type="default", input_size=(224, -1), num_classes=10) - -# with pytest.raises(ValueError, match="input_size must be a tuple of two positive integers"): -# ModelConfig(model_type="default", input_size="224,224", num_classes=10) - -# def test_invalid_num_classes(): -# """Test validation failure with an invalid num_classes.""" -# with pytest.raises(ValueError, match="num_classes must be a positive integer"): -# ModelConfig(model_type="default", input_size=(224, 224), num_classes=-5) +def test_invalid_backbonetype(): + """Test validation failure with an invalid pre_trained_weights.""" + with pytest.raises(ValueError, match='backbone_type must be one of "unet", "convnext", "swint"'): + ModelConfig(backbone_type="net") # def test_update_config(default_config): # """Test updating configuration attributes."""