diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 01e07640f9..a6cbbd25f3 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -205,6 +205,10 @@ def terminate_handler(_, __, model): if cfg.flash_optimum and BetterTransformer: model = BetterTransformer.reverse(model) + if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model: + trainer.model.save_pretrained( + cfg.output_dir, safe_serialization=safe_serialization + ) model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) if not cfg.hub_model_id: