diff --git a/src/dvclive/huggingface.py b/src/dvclive/huggingface.py index 883c149..ed91fef 100644 --- a/src/dvclive/huggingface.py +++ b/src/dvclive/huggingface.py @@ -73,7 +73,10 @@ def on_train_end( ): if self._log_model is True and state.is_world_process_zero: fake_trainer = Trainer( - args=args, model=kwargs.get("model"), tokenizer=kwargs.get("tokenizer") + args=args, + model=kwargs.get("model"), + tokenizer=kwargs.get("tokenizer"), + eval_dataset=["fake"], ) name = "best" if args.load_best_model_at_end else "last" output_dir = os.path.join(args.output_dir, name) diff --git a/tests/frameworks/test_huggingface.py b/tests/frameworks/test_huggingface.py index 18b3990..42db6d6 100644 --- a/tests/frameworks/test_huggingface.py +++ b/tests/frameworks/test_huggingface.py @@ -162,6 +162,7 @@ def test_huggingface_log_model( live_callback = callback(live=live, log_model=log_model) args.load_best_model_at_end = best + args.metric_for_best_model = "loss" trainer = Trainer( model,