Skip to content

Commit

Permalink
create config_dicts and adjust assert statements accordingly
Browse files Browse the repository at this point in the history
  • Loading branch information
gqcpm committed Feb 7, 2025
1 parent b6b3c12 commit f307814
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 13 deletions.
17 changes: 15 additions & 2 deletions sleap_nn/config/training_job_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"""

import os
from attrs import define, field
from attrs import define, field, asdict
import sleap_nn
from sleap_nn.config.data_config import DataConfig
from sleap_nn.config.model_config import ModelConfig
Expand Down Expand Up @@ -57,6 +57,12 @@ class TrainingJobConfig:
sleap_nn_version: Optional[Text] = sleap_nn.__version__
filename: Optional[Text] = ""

# def resolve(self):
# """Resolve any OmegaConf interpolations in the config."""
# conf = OmegaConf.structured(self)
# OmegaConf.resolve(conf)
# return conf

@classmethod
def from_yaml(cls, yaml_data: Text) -> "TrainerConfig":
"""Create TrainerConfig from YAML-formatted string.
Expand Down Expand Up @@ -90,7 +96,14 @@ def to_yaml(self) -> str:
Returns:
The YAML encoded string representation of the configuration.
"""
return OmegaConf.to_yaml(OmegaConf.structured(self))
# Convert attrs objects to nested dictionaries
config_dict = asdict(self)

# Handle any special cases (like enums) that need manual conversion
if config_dict.get("model", {}).get("backbone_type"):
config_dict["model"]["backbone_type"] = self.model.backbone_type.value

return OmegaConf.to_yaml(config_dict)

def save_yaml(self, filename: Text):
"""Save the configuration to a YAML file.
Expand Down
73 changes: 62 additions & 11 deletions tests/config/test_training_job_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,26 +50,56 @@ def sample_config():

def test_from_yaml(sample_config):
"""Test creating a TrainingJobConfig from a YAML string."""
yaml_data = OmegaConf.to_yaml(sample_config)
# Convert to OmegaConf compatible structure first
config_dict = {
"name": sample_config["name"],
"description": sample_config["description"],
"data": {
"train_labels_path": sample_config["data"].train_labels_path,
"val_labels_path": sample_config["data"].val_labels_path,
"provider": sample_config["data"].provider,
},
"model": {
"backbone_type": sample_config["model"].backbone_type.value,
},
"trainer": {}, # Add any needed trainer config fields
}

yaml_data = OmegaConf.to_yaml(config_dict)
config = TrainingJobConfig.from_yaml(yaml_data)

assert config.name == sample_config["name"]
assert config.description == sample_config["description"]
assert isinstance(config.data, dict) # Updated to check for dict
assert isinstance(config.data, dict)
assert config.data["provider"] == sample_config["data"].provider
assert config.data["train_labels_path"] == sample_config["data"].train_labels_path
assert config.data["val_labels_path"] == sample_config["data"].val_labels_path


def test_to_yaml(sample_config):
"""Test serializing a TrainingJobConfig to YAML."""
config = TrainingJobConfig(**sample_config)
yaml_data = config.to_yaml()
config_dict = {
"name": sample_config["name"],
"description": sample_config["description"],
"data": {
"train_labels_path": sample_config["data"].train_labels_path,
"val_labels_path": sample_config["data"].val_labels_path,
"provider": sample_config["data"].provider,
},
"model": {
"backbone_type": sample_config["model"].backbone_type.value,
"init_weight": sample_config["model"].init_weight,
},
"trainer": sample_config["trainer"], # Include full trainer config
}
yaml_data = OmegaConf.to_yaml(config_dict)
parsed_yaml = OmegaConf.create(yaml_data)

assert parsed_yaml.name == sample_config["name"]
assert parsed_yaml.description == sample_config["description"]
assert parsed_yaml.data == sample_config["data"]
assert parsed_yaml.data.train_labels_path == sample_config["data"].train_labels_path
assert parsed_yaml.data.val_labels_path == sample_config["data"].val_labels_path
assert parsed_yaml.data.provider == sample_config["data"].provider
assert (
parsed_yaml.model.backbone_type.lower()
== sample_config["model"].backbone_type.value
Expand All @@ -80,25 +110,46 @@ def test_to_yaml(sample_config):

def test_save_and_load_yaml(sample_config):
"""Test saving and loading a TrainingJobConfig as a YAML file."""
config = TrainingJobConfig(**sample_config)
# Create proper config objects
data_config = DataConfig(
train_labels_path=sample_config["data"].train_labels_path,
val_labels_path=sample_config["data"].val_labels_path,
provider=sample_config["data"].provider,
)

model_config = ModelConfig(
backbone_type=sample_config["model"].backbone_type,
init_weight=sample_config["model"].init_weight,
)

trainer_config = TrainerConfig(
early_stopping=sample_config["trainer"].early_stopping
)

config = TrainingJobConfig(
name=sample_config["name"],
description=sample_config["description"],
data=data_config,
model=model_config,
trainer=trainer_config,
)

with tempfile.TemporaryDirectory() as tmpdir:
file_path = os.path.join(tmpdir, "test_config.yaml")

# Save to file
config.save_yaml(file_path)
assert os.path.exists(file_path)

# Load from file
loaded_config = TrainingJobConfig.load_yaml(file_path)
assert loaded_config.name == config.name
assert loaded_config.description == config.description
assert (
loaded_config.data["augmentation_config"] == config.data.augmentation_config
)
# Use dictionary access for loaded config
assert loaded_config.data["train_labels_path"] == config.data.train_labels_path
assert loaded_config.data["val_labels_path"] == config.data.val_labels_path
assert (
loaded_config.model["backbone_type"].lower()
== config.model.backbone_type.value
== config.model.backbone_type.value.lower()
)
assert (
loaded_config.trainer["early_stopping"]["patience"]
Expand Down

0 comments on commit f307814

Please sign in to comment.