From 055365998e84f0da3192e7fb4a58e3bbe39248e2 Mon Sep 17 00:00:00 2001 From: Jayson Francis Date: Fri, 3 Jan 2025 00:21:46 -0800 Subject: [PATCH] Rename var --- torchtitan/config_manager.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 01c73491..400d86a4 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from dataclasses import asdict, dataclass, field, fields, is_dataclass -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch import tyro @@ -24,10 +24,6 @@ } -def string_list(raw_arg): - return raw_arg.split(",") - - @dataclass class Job: config_file: Optional[str] = None @@ -452,7 +448,7 @@ def _update(self, instance: "JobConfig") -> None: for f in fields(self): setattr(self, f.name, getattr(instance, f.name, getattr(self, f.name))) - def parse_args(self): + def parse_args(self) -> None: """ Parse CLI arguments, optionally load from a TOML file, merge with defaults, and return a JobConfig instance. @@ -479,7 +475,7 @@ def _load_toml(file_path: str) -> Dict[str, Any]: logger.exception(f"Error while loading config file: {file_path}") raise e - def _dict_to_dataclass(self, config_class, data: Dict[str, Any]) -> Any: + def _dict_to_dataclass(self, config_class: Callable, data: Dict[str, Any]) -> Any: """Recursively convert dictionaries to nested dataclasses.""" if not is_dataclass(config_class): return data @@ -494,23 +490,21 @@ def _dict_to_dataclass(self, config_class, data: Dict[str, Any]) -> Any: kwargs[f.name] = value return config_class(**kwargs) - def _merge_with_defaults( - self, source: "JobConfig", defaults: "JobConfig" - ) -> "JobConfig": + def _merge_with_defaults(self, target, defaults) -> Any: """Recursively merge two dataclass instances (source overrides defaults).""" merged_kwargs = {} - for f in fields(source): - source_val = getattr(source, f.name) + for f in fields(target): + target_val = getattr(target, f.name) default_val = getattr(defaults, f.name) - if is_dataclass(source_val) and is_dataclass(default_val): + if is_dataclass(target_val) and is_dataclass(default_val): merged_kwargs[f.name] = self._merge_with_defaults( - source_val, default_val + target_val, default_val ) else: merged_kwargs[f.name] = ( - source_val if source_val is not None else default_val + target_val if target_val is not None else default_val ) - return type(source)(**merged_kwargs) + return type(target)(**merged_kwargs) def _validate_config(self) -> None: # TODO: Add more mandatory validations