Skip to content

Commit

Permalink
utilize omegaconf.save instead of save_yaml function
Browse files Browse the repository at this point in the history
  • Loading branch information
gqcpm committed Feb 7, 2025
1 parent feebca0 commit 08707d8
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 18 deletions.
24 changes: 10 additions & 14 deletions sleap_nn/config/training_job_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,12 @@ def load_yaml(cls, filename: Text) -> "TrainerConfig":
config = OmegaConf.load(filename)
return cls(**OmegaConf.to_container(config, resolve=True))

def to_yaml(self) -> str:
"""Serialize the configuration into YAML-encoded string format.
def to_yaml(self, filename: Optional[Text] = None) -> None:
"""Serialize and optionally save the configuration to YAML format.
Returns:
The YAML encoded string representation of the configuration.
Args:
filename: Optional path to save the YAML file to. If not provided,
the configuration will only be converted to YAML format.
"""
# Convert attrs objects to nested dictionaries
config_dict = asdict(self)
Expand All @@ -103,16 +104,11 @@ def to_yaml(self) -> str:
if config_dict.get("model", {}).get("backbone_type"):
config_dict["model"]["backbone_type"] = self.model.backbone_type.value

return OmegaConf.to_yaml(config_dict)

def save_yaml(self, filename: Text):
"""Save the configuration to a YAML file.
Arguments:
filename: Path to save the training job file to.
"""
with open(filename, "w") as f:
f.write(self.to_yaml())
# Create OmegaConf object and save if filename provided
conf = OmegaConf.create(config_dict)
if filename is not None:
OmegaConf.save(conf, filename)
return


def load_config(filename: Text, load_training_config: bool = True) -> TrainingJobConfig:
Expand Down
9 changes: 5 additions & 4 deletions tests/config/test_training_job_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from sleap_nn.config.data_config import DataConfig
from sleap_nn.config.trainer_config import TrainerConfig
from omegaconf import OmegaConf
from dataclasses import asdict


@pytest.fixture
Expand Down Expand Up @@ -108,8 +109,8 @@ def test_to_yaml(sample_config):
assert parsed_yaml.trainer == sample_config["trainer"]


def test_save_and_load_yaml(sample_config):
"""Test saving and loading a TrainingJobConfig as a YAML file."""
def test_load_yaml(sample_config):
"""Test loading a TrainingJobConfig from a YAML file."""
# Create proper config objects
data_config = DataConfig(
train_labels_path=sample_config["data"].train_labels_path,
Expand Down Expand Up @@ -137,8 +138,8 @@ def test_save_and_load_yaml(sample_config):
with tempfile.TemporaryDirectory() as tmpdir:
file_path = os.path.join(tmpdir, "test_config.yaml")

# Save to file
config.save_yaml(file_path)
# Use the to_yaml method to save the file
config.to_yaml(filename=file_path)

# Load from file
loaded_config = TrainingJobConfig.load_yaml(file_path)
Expand Down

0 comments on commit 08707d8

Please sign in to comment.