diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 0d3a84c7..2b19fe6e 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -559,6 +559,10 @@ def end(self): catch_and_warn(DvcException, logger)(ensure_dir_is_tracked)( self.dir, self._dvc_repo ) + if self._dvcyaml: + catch_and_warn(DvcException, logger)(self._dvc_repo.scm.add)( + self.dvc_file + ) self.make_report() @@ -589,6 +593,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): @catch_and_warn(DvcException, logger, mark_dvclive_only_ended) def save_dvc_exp(self): if self._save_dvc_exp: + if self._dvcyaml: + self._include_untracked.append(self.dvc_file) self._experiment_rev = self._dvc_repo.experiments.save( name=self._exp_name, include_untracked=self._include_untracked, diff --git a/tests/test_dvc.py b/tests/test_dvc.py index 3b552c3d..ed7e4938 100644 --- a/tests/test_dvc.py +++ b/tests/test_dvc.py @@ -34,7 +34,7 @@ def test_exp_save_on_end(tmp_dir, save, mocked_dvc_repo): assert live._exp_name is not None mocked_dvc_repo.experiments.save.assert_called_with( name=live._exp_name, - include_untracked=[live.dir], + include_untracked=[live.dir, str(tmp_dir / "dvc.yaml")], force=True, message=None, ) @@ -67,6 +67,7 @@ def test_exp_save_run_on_dvc_repro(tmp_dir, mocker): dvc_repo.scm.get_ref.return_value = None dvc_repo.scm.no_commits = False dvc_repo.config = {} + dvc_repo.root_dir = tmp_dir mocker.patch("dvclive.live.get_dvc_repo", return_value=dvc_repo) live = Live() assert live._save_dvc_exp @@ -75,7 +76,10 @@ def test_exp_save_run_on_dvc_repro(tmp_dir, mocker): live.end() dvc_repo.experiments.save.assert_called_with( - name=live._exp_name, include_untracked=[live.dir], force=True, message=None + name=live._exp_name, + include_untracked=[live.dir, str(tmp_dir / "dvc.yaml")], + force=True, + message=None, ) @@ -87,6 +91,7 @@ def test_exp_save_with_dvc_files(tmp_dir, mocker): dvc_repo.scm.get_rev.return_value = "current_rev" dvc_repo.scm.get_ref.return_value = None dvc_repo.scm.no_commits = False + dvc_repo.root_dir = tmp_dir dvc_repo.config = {} mocker.patch("dvclive.live.get_dvc_repo", return_value=dvc_repo) @@ -94,7 +99,10 @@ def test_exp_save_with_dvc_files(tmp_dir, mocker): live.end() dvc_repo.experiments.save.assert_called_with( - name=live._exp_name, include_untracked=[live.dir], force=True, message=None + name=live._exp_name, + include_untracked=[live.dir, str(tmp_dir / "dvc.yaml")], + force=True, + message=None, ) @@ -126,7 +134,7 @@ def test_untracked_dvclive_files_inside_dvc_exp_run_are_added( with Live() as live: live.log_metric("foo", 1) live.next_step() - live._dvc_repo.scm.add.assert_called_with(["dvclive/metrics.json", plot_file]) + live._dvc_repo.scm.add.assert_any_call(["dvclive/metrics.json", plot_file]) def test_dvc_outs_are_not_added(tmp_dir, mocked_dvc_repo, monkeypatch): @@ -144,7 +152,7 @@ def test_dvc_outs_are_not_added(tmp_dir, mocked_dvc_repo, monkeypatch): live.log_metric("foo", 1) live.next_step() - live._dvc_repo.scm.add.assert_called_with(["dvclive/metrics.json"]) + live._dvc_repo.scm.add.assert_any_call(["dvclive/metrics.json"]) def test_errors_on_git_add_are_catched(tmp_dir, mocked_dvc_repo, monkeypatch): @@ -153,7 +161,7 @@ def test_errors_on_git_add_are_catched(tmp_dir, mocked_dvc_repo, monkeypatch): mocked_dvc_repo.scm.untracked_files.return_value = ["dvclive/metrics.json"] mocked_dvc_repo.scm.add.side_effect = DvcException("foo") - with Live() as live: + with Live(dvcyaml=False) as live: live.summary["foo"] = 1 @@ -162,7 +170,7 @@ def test_exp_save_message(tmp_dir, mocked_dvc_repo): live.end() mocked_dvc_repo.experiments.save.assert_called_with( name=live._exp_name, - include_untracked=[live.dir], + include_untracked=[live.dir, str(tmp_dir / "dvc.yaml")], force=True, message="Custom message", ) diff --git a/tests/test_log_artifact.py b/tests/test_log_artifact.py index 3622d4b3..061c2bd7 100644 --- a/tests/test_log_artifact.py +++ b/tests/test_log_artifact.py @@ -58,7 +58,7 @@ def test_log_artifact_with_save_dvc_exp(tmp_dir, mocker, mocked_dvc_repo): live.log_artifact("data") mocked_dvc_repo.experiments.save.assert_called_with( name=live._exp_name, - include_untracked=[live.dir, "data", ".gitignore"], + include_untracked=[live.dir, "data", ".gitignore", str(tmp_dir / "dvc.yaml")], force=True, message=None, )