From 31a7908b916cc352a49c985679cd8875b26e3889 Mon Sep 17 00:00:00 2001 From: Ivan Shcheklein Date: Sun, 21 May 2023 03:26:15 -0700 Subject: [PATCH] fix(make_dvcyaml): make it idempotent for artifacts (#573) --- src/dvclive/dvc.py | 3 ++- tests/test_dvc.py | 15 +++++++++++++++ tests/test_log_artifact.py | 18 +++++++++++++++++- 3 files changed, 34 insertions(+), 2 deletions(-) diff --git a/src/dvclive/dvc.py b/src/dvclive/dvc.py index b1ba0ed5..4ba39f0e 100644 --- a/src/dvclive/dvc.py +++ b/src/dvclive/dvc.py @@ -1,4 +1,5 @@ # ruff: noqa: SLF001 +import copy import os import random from pathlib import Path @@ -106,7 +107,7 @@ def make_dvcyaml(live): dvcyaml["plots"] = plots if live._artifacts: - dvcyaml["artifacts"] = live._artifacts + dvcyaml["artifacts"] = copy.deepcopy(live._artifacts) for artifact in dvcyaml["artifacts"].values(): abs_path = os.path.abspath(artifact["path"]) abs_dir = os.path.realpath(live.dir) diff --git a/tests/test_dvc.py b/tests/test_dvc.py index 3b6c2710..eb584ad6 100644 --- a/tests/test_dvc.py +++ b/tests/test_dvc.py @@ -271,3 +271,18 @@ def test_errors_on_git_add_are_catched(tmp_dir, mocked_dvc_repo, monkeypatch): with Live(report=None) as live: live.summary["foo"] = 1 + + +def test_make_dvcyaml_idempotent(tmp_dir, mocked_dvc_repo): + (tmp_dir / "model.pth").touch() + + with Live() as live: + live.log_artifact("model.pth", type="model") + + live.make_dvcyaml() + + assert load_yaml(live.dvc_file) == { + "artifacts": { + "model": {"path": "../model.pth", "type": "model"}, + } + } diff --git a/tests/test_log_artifact.py b/tests/test_log_artifact.py index 03fdebab..a48137ed 100644 --- a/tests/test_log_artifact.py +++ b/tests/test_log_artifact.py @@ -149,7 +149,7 @@ def test_log_artifact_type_model_provided_name(tmp_dir, mocked_dvc_repo): } -def test_log_artifact_type_model_on_step(tmp_dir, mocked_dvc_repo): +def test_log_artifact_type_model_on_step_and_final(tmp_dir, mocked_dvc_repo): (tmp_dir / "model.pth").touch() with Live() as live: @@ -165,6 +165,22 @@ def test_log_artifact_type_model_on_step(tmp_dir, mocked_dvc_repo): } +def test_log_artifact_type_model_on_step(tmp_dir, mocked_dvc_repo): + (tmp_dir / "model.pth").touch() + + with Live() as live: + for _ in range(3): + live.log_artifact("model.pth", type="model") + live.next_step() + + assert load_yaml(live.dvc_file) == { + "artifacts": { + "model": {"path": "../model.pth", "type": "model"}, + }, + "metrics": ["metrics.json"], + } + + def test_log_artifact_attrs(tmp_dir, mocked_dvc_repo): (tmp_dir / "model.pth").touch()