make sure to capture non-null defaults from config validation (#1415)
winglian authored Mar 26, 2024
1 parent ff939d8 commit 601b77b
Showing 3 changed files with 26 additions and 18 deletions.
4 changes: 2 additions & 2 deletions src/axolotl/utils/config/__init__.py
@@ -208,11 +208,11 @@ def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
             dict(
                 AxolotlConfigWCapabilities(
                     **cfg.to_dict(), capabilities=capabilities
-                ).model_dump(exclude_unset=True)
+                ).model_dump(exclude_none=True)
             )
         )
     return DictDefault(
-        dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_unset=True))
+        dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True))
     )
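
A minimal sketch of why the dump flag changes (not part of the commit; assumes pydantic v2): exclude_unset drops every field the caller never passed, including fields with meaningful non-None defaults, while exclude_none keeps any field whose value is not None.

from typing import Optional
from pydantic import BaseModel

class Sketch(BaseModel):
    train_on_inputs: Optional[bool] = False
    weight_decay: Optional[float] = 0.0
    group_by_length: Optional[bool] = None

hp = Sketch()  # the user supplied nothing

print(hp.model_dump(exclude_unset=True))
# {} -- defaults count as "unset" and silently disappear

print(hp.model_dump(exclude_none=True))
# {'train_on_inputs': False, 'weight_decay': 0.0}
# non-null defaults are captured; None-valued fields drop out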


28 changes: 12 additions & 16 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -151,12 +151,6 @@ class PeftConfig(BaseModel):
     loftq_config: Optional[LoftQConfig] = None
 
 
-class AutoType(str, Enum):
-    """auto type string configuration subset - used for bf16"""
-
-    AUTO = "auto"
-
-
 class SpecialTokensConfig(BaseModel):
     """Special tokens configuration subset"""
 
@@ -307,12 +301,14 @@ class HyperparametersConfig(BaseModel):
         },
     )
 
-    train_on_inputs: Optional[bool] = None
+    train_on_inputs: Optional[bool] = False
     group_by_length: Optional[bool] = None
 
     learning_rate: Union[str, float]
-    weight_decay: Optional[float] = None
-    optimizer: Optional[Union[OptimizerNames, Literal["lion_pytorch"]]] = None
+    weight_decay: Optional[float] = 0.0
+    optimizer: Optional[
+        Union[OptimizerNames, Literal["lion_pytorch"]]
+    ] = OptimizerNames.ADAMW_HF.value
     optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
         default=None, metadata={"help": "Optional arguments to supply to optimizer."}
     )
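
The optimizer default is now the enum's string value rather than None, so it survives an exclude_none dump. A small sketch (assumes pydantic v2 and that OptimizerNames is the transformers optimizer enum):

from typing import Literal, Optional, Union
from pydantic import BaseModel
from transformers.training_args import OptimizerNames

class OptSketch(BaseModel):
    optimizer: Optional[
        Union[OptimizerNames, Literal["lion_pytorch"]]
    ] = OptimizerNames.ADAMW_HF.value

print(OptSketch().model_dump(exclude_none=True))
# {'optimizer': 'adamw_hf'} -- previously the None default was dropped
print(OptSketch(optimizer="lion_pytorch").optimizer)
# lion_pytorch -- the Literal arm still accepts the non-enum option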
@@ -323,7 +319,7 @@ class HyperparametersConfig(BaseModel):
         },
     )
     torchdistx_path: Optional[str] = None
-    lr_scheduler: Optional[SchedulerType] = None
+    lr_scheduler: Optional[SchedulerType] = "cosine"
     lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
     lr_quadratic_warmup: Optional[bool] = None
     cosine_min_lr_ratio: Optional[float] = None
@@ -473,7 +469,7 @@ class Config:
     loss_watchdog_threshold: Optional[float] = None
     loss_watchdog_patience: Optional[int] = None
 
-    bf16: Optional[Union[AutoType, bool]] = AutoType.AUTO
+    bf16: Optional[Union[Literal["auto"], bool]] = "auto"
     fp16: Optional[bool] = None
     bfloat16: Optional[bool] = None  # for non-AMP cases
     float16: Optional[bool] = None  # for non-AMP cases
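
Replacing the one-member AutoType enum with Literal["auto"] keeps the same validation behavior without a dedicated class, and the plain-string default dumps as "auto" rather than as an enum member. A minimal sketch (assumes pydantic v2):

from typing import Literal, Optional, Union
from pydantic import BaseModel

class Precision(BaseModel):
    bf16: Optional[Union[Literal["auto"], bool]] = "auto"

print(Precision().model_dump(exclude_none=True))           # {'bf16': 'auto'}
print(Precision(bf16=True).model_dump(exclude_none=True))  # {'bf16': True}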
@@ -487,7 +483,7 @@
 
     unfrozen_parameters: Optional[List[str]] = None
 
-    sequence_len: int = Field(default=1024)
+    sequence_len: int = Field(default=512)
     sample_packing: Optional[bool] = None
     eval_sample_packing: Optional[bool] = None
     pad_to_sequence_len: Optional[bool] = None
@@ -548,10 +544,10 @@ class Config:
     sample_packing_eff_est: Optional[float] = None
     axolotl_config_path: Optional[str] = None
 
-    is_falcon_derived_model: Optional[bool] = Field(default=False)
-    is_llama_derived_model: Optional[bool] = Field(default=False)
-    is_mistral_derived_model: Optional[bool] = Field(default=False)
-    is_qwen_derived_model: Optional[bool] = Field(default=False)
+    is_falcon_derived_model: Optional[bool] = Field(default=None)
+    is_llama_derived_model: Optional[bool] = Field(default=None)
+    is_mistral_derived_model: Optional[bool] = Field(default=None)
+    is_qwen_derived_model: Optional[bool] = Field(default=None)
 
     @field_validator("datasets", mode="before")
     @classmethod
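
The derived-model flags move in the opposite direction: with default=None, exclude_none drops them from the dumped config instead of pinning them to False, presumably leaving room for later model-type auto-detection. A toy illustration (assumes pydantic v2):

from typing import Optional
from pydantic import BaseModel, Field

class Flags(BaseModel):
    is_llama_derived_model: Optional[bool] = Field(default=None)

print(Flags().model_dump(exclude_none=True))
# {} -- with the old default=False this would have dumped
# {'is_llama_derived_model': False}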
12 changes: 12 additions & 0 deletions tests/test_validation.py
@@ -54,6 +54,18 @@ class TestValidation(BaseValidation):
     Test the validation module
     """
 
+    def test_defaults(self, minimal_cfg):
+        test_cfg = DictDefault(
+            {
+                "weight_decay": None,
+            }
+            | minimal_cfg
+        )
+        cfg = validate_config(test_cfg)
+
+        assert cfg.train_on_inputs is False
+        assert cfg.weight_decay is None
+
     def test_datasets_min_length(self):
         cfg = DictDefault(
             {
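
Two details in the test are worth sketching. The `|` operator is PEP 584 dict union (Python 3.9+), where the right-hand minimal_cfg wins on duplicate keys; and the final assert passes even though exclude_none strips the explicit weight_decay None from the dump, since DictDefault (axolotl's dict wrapper) appears to return None for missing keys:

# hypothetical stand-ins for the fixture contents
overrides = {"weight_decay": None}
minimal = {"learning_rate": 0.0001}

merged = overrides | minimal  # right operand wins on duplicate keys
print(merged)  # {'weight_decay': None, 'learning_rate': 0.0001}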
