Skip to content

Commit

Permalink
make sure everything stays in the same dtype when using dpo + FSDP (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian authored Apr 22, 2024
1 parent 60f5ce0 commit 68601ec
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
MambaDataCollator,
V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.models import ensure_dtype
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import (
get_cosine_schedule_with_min_lr,
Expand Down Expand Up @@ -1569,6 +1570,9 @@ def build(self, total_num_steps):
callbacks=self.get_callbacks(),
**dpo_trainer_kwargs,
)
if self.cfg.fsdp:
ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype)

dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
dpo_trainer.add_callback(callback)
Expand Down
10 changes: 10 additions & 0 deletions src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,3 +993,13 @@ def load_lora(model, cfg, inference=False, config_only=False):
setup_quantized_peft_meta_for_training(model)

return model, lora_config


def ensure_dtype(model, dtype=torch.bfloat16):
    """Cast every module whose weight or bias is not already in ``dtype``.

    Walks ``model.named_modules()`` and, for each module that exposes a
    ``weight`` and/or ``bias`` carrying a ``dtype`` different from the
    requested one, calls ``module.to(dtype)`` on it. Modules without such
    attributes (or where they are ``None``) are skipped.

    Args:
        model: Object exposing ``named_modules()`` (e.g. a ``torch.nn.Module``).
        dtype: Target dtype; defaults to ``torch.bfloat16``.

    Returns:
        None. The model is modified in place.
    """
    for name, module in model.named_modules():
        # Probe with getattr instead of a broad try/except so that a genuine
        # AttributeError raised inside ``module.to`` is no longer swallowed.
        weight_dtype = getattr(getattr(module, "weight", None), "dtype", None)
        bias_dtype = getattr(getattr(module, "bias", None), "dtype", None)

        weight_mismatch = weight_dtype is not None and weight_dtype != dtype
        # A bias can be left in another dtype even when the weight already
        # matches; checking only the weight would leave the model mixed.
        bias_mismatch = bias_dtype is not None and bias_dtype != dtype

        if weight_mismatch:
            print(f"Converting module {name}: {weight_dtype} -> {dtype}")
        if bias_mismatch:
            print(f"Converting module {name}.bias: {bias_dtype} -> {dtype}")
        if weight_mismatch or bias_mismatch:
            module.to(dtype)

0 comments on commit 68601ec

Please sign in to comment.