Skip to content

Commit

Permalink
move more things to kd plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Dec 30, 2024
1 parent fa3757c commit f98ef8b
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 13 deletions.
9 changes: 3 additions & 6 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
AxolotlTrainer,
ReLoRATrainer,
)
from axolotl.core.trainers.kd import AxolotlKDTrainer
from axolotl.core.training_args import (
AxolotlCPOConfig,
AxolotlDPOConfig,
Expand Down Expand Up @@ -77,7 +76,6 @@
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForKD,
DataCollatorForSeq2Seq,
MambaDataCollator,
V2BatchSamplerDataCollatorForSeq2Seq,
Expand Down Expand Up @@ -306,8 +304,6 @@ def _get_trainer_cls(self):
return AxolotlMambaTrainer
if self.cfg.reward_model:
return AxolotlRewardTrainer
if self.cfg.trainer == "kd":
return AxolotlKDTrainer
return AxolotlTrainer

def build(self, total_num_steps):
Expand Down Expand Up @@ -797,7 +793,6 @@ def build_collator(
Union[
V2BatchSamplerDataCollatorForSeq2Seq,
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForKD,
DataCollatorForSeq2Seq,
DataCollatorWithFlattening,
RewardDataCollatorWithPadding,
Expand Down Expand Up @@ -828,7 +823,9 @@ def build_collator(
collator_args.pop(0)
kwargs.pop("pad_to_multiple_of", None)
kwargs.pop("padding", None)
elif self.cfg.trainer == "kd":
elif self.cfg.kd_trainer:
from axolotl.integrations.kd.collator import DataCollatorForKD

collator = DataCollatorForKD
else:
collator = DataCollatorForSeq2Seq
Expand Down
22 changes: 22 additions & 0 deletions src/axolotl/integrations/kd/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""
Plugin init to add KD support to Axolotl.
"""
from axolotl.integrations.base import BasePlugin

from .args import KDArgs # pylint: disable=unused-import. # noqa: F401


class KDPlugin(BasePlugin):
"""
Plugin for KD support in Axolotl.
"""

def get_input_args(self):
return "axolotl.integrations.kd.KDArgs"

def get_trainer_cls(self, cfg):
if cfg.kd_trainer:
from .trainer import AxolotlKDTrainer

return AxolotlKDTrainer
return None
19 changes: 19 additions & 0 deletions src/axolotl/integrations/kd/args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""
Plugin args for KD support.
"""
from typing import Optional

from pydantic import BaseModel


class KDArgs(BaseModel):
"""
Input args for knowledge distillation.
"""

kd_trainer: Optional[bool] = None # whether to use KD trainer
kd_ce_alpha: Optional[
float
] = None # loss coefficient for cross-entropy loss during KD
kd_alpha: Optional[float] = None # loss coefficient for KD loss
kd_temperature: Optional[float] = None # temperature for sampling during KD
File renamed without changes.
7 changes: 0 additions & 7 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,13 +614,6 @@ class Config:
bool
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.

trainer: Optional[Literal["kd"]] = None
kd_ce_alpha: Optional[
float
] = None # loss coefficient for cross-entropy loss during KD
kd_alpha: Optional[float] = None # loss coefficient for KD loss
kd_temperature: Optional[float] = None # temperature for sampling during KD

datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
shuffle_merged_datasets: Optional[bool] = True
Expand Down

0 comments on commit f98ef8b

Please sign in to comment.