fix: refactor custom optimizer into enum
NanoCode012 committed Dec 9, 2024
1 parent b373818 commit 119514b
Showing 2 changed files with 15 additions and 19 deletions.
11 changes: 3 additions & 8 deletions src/axolotl/core/trainer_builder.py
@@ -74,6 +74,7 @@
     V2BatchSamplerDataCollatorForSeq2Seq,
 )
 from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
+from axolotl.utils.config.models.input.v0_4_1 import CustomSupportedOptimizers
 from axolotl.utils.models import ensure_dtype
 from axolotl.utils.optimizers.embedding_scaled import create_embedding_scaled_optimizer
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -1726,14 +1727,8 @@ def build(self, total_num_steps):
             trainer_kwargs["max_length"] = self.cfg.sequence_len

         # Handle custom optimizer
-        if self.cfg.optimizer in [
-            "optimi_adamw",
-            "ao_adamw_4bit",
-            "ao_adamw_8bit",
-            "ao_adamw_fp8",
-            "adopt_adamw",
-            "lion_pytorch",
-        ]:
+        custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]
+        if self.cfg.optimizer in custom_supported_optimizers:
             # Common optimizer kwargs
             optimizer_kwargs = {
                 "lr": training_arguments_kwargs.get("learning_rate"),
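For context, a minimal self-contained sketch (not part of the diff; the trimmed two-member enum below is illustrative only) of why the membership check in build() keeps working once the hard-coded list is replaced by enum values:

from enum import Enum

class CustomSupportedOptimizers(str, Enum):
    """Trimmed stand-in for the enum added in this commit."""

    optimi_adamw = "optimi_adamw"
    lion_pytorch = "lion_pytorch"

# Iterating over the Enum yields its members; .value recovers the plain
# string, so a config value such as "lion_pytorch" still matches with `in`.
custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]
assert "lion_pytorch" in custom_supported_optimizers
assert "adamw_torch" not in custom_supported_optimizers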
23 changes: 12 additions & 11 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -62,6 +62,17 @@ class ChatTemplate(str, Enum):
     metharme = "metharme"  # pylint: disable=invalid-name


+class CustomSupportedOptimizers(str, Enum):
+    """Custom supported optimizers"""
+
+    optimi_adamw = "optimi_adamw"  # pylint: disable=invalid-name
+    ao_adamw_4bit = "ao_adamw_4bit"  # pylint: disable=invalid-name
+    ao_adamw_8bit = "ao_adamw_8bit"  # pylint: disable=invalid-name
+    ao_adamw_fp8 = "ao_adamw_fp8"  # pylint: disable=invalid-name
+    adopt_adamw = "adopt_adamw"  # pylint: disable=invalid-name
+    lion_pytorch = "lion_pytorch"  # pylint: disable=invalid-name
+
+
 class DeprecatedParameters(BaseModel):
     """configurations that are deprecated"""

@@ -445,17 +456,7 @@ class HyperparametersConfig(BaseModel):
     embedding_lr_scale: Optional[float] = None
     weight_decay: Optional[float] = 0.0
     optimizer: Optional[
-        Union[
-            OptimizerNames,
-            Literal[
-                "lion_pytorch",
-                "optimi_adamw",
-                "ao_adamw_4bit",
-                "ao_adamw_8bit",
-                "ao_adamw_fp8",
-                "adopt_adamw",
-            ],
-        ]
+        Union[OptimizerNames, CustomSupportedOptimizers]
     ] = OptimizerNames.ADAMW_HF.value
     optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
         default=None,
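A rough sketch (not part of the diff) of how the narrowed Union validates; the stub OptimizerNames enum below is a stand-in for the real one with only one member shown, and the assertions assume Pydantic's standard string-to-enum coercion for Union fields:

from enum import Enum
from typing import Optional, Union

from pydantic import BaseModel


class OptimizerNames(str, Enum):
    """Stand-in for the upstream OptimizerNames enum; one member shown."""

    ADAMW_HF = "adamw_hf"


class CustomSupportedOptimizers(str, Enum):
    optimi_adamw = "optimi_adamw"
    lion_pytorch = "lion_pytorch"


class HyperparametersConfig(BaseModel):
    optimizer: Optional[
        Union[OptimizerNames, CustomSupportedOptimizers]
    ] = OptimizerNames.ADAMW_HF.value


# Strings for either enum validate against the single Union annotation,
# so the long Literal[...] list in the old config model is no longer needed.
assert (
    HyperparametersConfig(optimizer="lion_pytorch").optimizer
    is CustomSupportedOptimizers.lion_pytorch
)
assert HyperparametersConfig(optimizer="adamw_hf").optimizer is OptimizerNames.ADAMW_HF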
