Skip to content

Commit

Permalink
add backbone_type and tests for invalid weights and type
Browse files Browse the repository at this point in the history
  • Loading branch information
gqcpm committed Jan 7, 2025
1 parent cd18706 commit 8a30dd1
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 16 deletions.
23 changes: 23 additions & 0 deletions sleap_nn/config/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,12 @@ class ModelConfig:
init_weight: str = "default"
pre_trained_weights: str = None
backbone_config: BackboneConfig = attrs.field(factory=BackboneConfig)
backbone_type: str = None
head_configs: HeadConfig = attrs.field(factory=HeadConfig)

# post-initialization
def __attrs_post_init__(self):
self.validate_backbone_type()
self.validate_pre_trained_weights()

# validate the pre-trained weights
Expand All @@ -274,3 +276,24 @@ def validate_pre_trained_weights(self):
"ConvNeXt_Large_Weights",
]
swint_weights = ["Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights"]
if self.backbone_type == "convnext":
if self.pre_trained_weights not in convnext_weights:
raise ValueError(
f"Invalid pre-trained weights for ConvNext. Must be one of {convnext_weights}"
)
elif self.backbone_type == "swint":
if self.pre_trained_weights not in swint_weights:
raise ValueError(
f"Invalid pre-trained weights for SwinT. Must be one of {swint_weights}"
)
elif (
self.backbone_type == "unet"
and self.pre_trained_weights is not None
):
raise ValueError("UNet does not support pre-trained weights.")

def validate_backbone_type(self):
if self.backbone_type not in ["unet", "convnext", "swint"]:
raise ValueError(
'backbone_type must be one of "unet", "convnext", "swint"'
)
25 changes: 9 additions & 16 deletions tests/config/test_model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def default_config():
return ModelConfig(
init_weight="default",
pre_trained_weights=None,
backbone_type="unet",
backbone_config=BackboneConfig(),
head_configs=HeadConfig()
)
Expand All @@ -33,23 +34,15 @@ def test_default_initialization(default_config):
assert default_config.init_weight == "default"
assert default_config.pre_trained_weights == None

# def test_invalid_pre_trained_weights():
# """Test validation failure with an invalid pre_trained_weights."""
# with pytest.raises(ValueError):
# ModelConfig(pre_trained_weights="weights")
def test_invalid_pre_trained_weights():
"""Test validation failure with an invalid pre_trained_weights."""
with pytest.raises(ValueError):
ModelConfig(pre_trained_weights="here", backbone_type="unet")

# def test_invalid_input_size():
# """Test validation failure with an invalid input_size."""
# with pytest.raises(ValueError, match="input_size must be a tuple of two positive integers"):
# ModelConfig(model_type="default", input_size=(224, -1), num_classes=10)

# with pytest.raises(ValueError, match="input_size must be a tuple of two positive integers"):
# ModelConfig(model_type="default", input_size="224,224", num_classes=10)

# def test_invalid_num_classes():
# """Test validation failure with an invalid num_classes."""
# with pytest.raises(ValueError, match="num_classes must be a positive integer"):
# ModelConfig(model_type="default", input_size=(224, 224), num_classes=-5)
def test_invalid_backbonetype():
"""Test validation failure with an invalid pre_trained_weights."""
with pytest.raises(ValueError, match='backbone_type must be one of "unet", "convnext", "swint"'):
ModelConfig(backbone_type="net")

# def test_update_config(default_config):
# """Test updating configuration attributes."""
Expand Down

0 comments on commit 8a30dd1

Please sign in to comment.