Skip to content

Commit

Permalink
fix(make_dvcyaml): make it idempotent for artifacts (#573)
Browse files Browse the repository at this point in the history
  • Loading branch information
shcheklein authored May 21, 2023
1 parent 03ec6cf commit 31a7908
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/dvclive/dvc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# ruff: noqa: SLF001
import copy
import os
import random
from pathlib import Path
Expand Down Expand Up @@ -106,7 +107,7 @@ def make_dvcyaml(live):
dvcyaml["plots"] = plots

if live._artifacts:
dvcyaml["artifacts"] = live._artifacts
dvcyaml["artifacts"] = copy.deepcopy(live._artifacts)
for artifact in dvcyaml["artifacts"].values():
abs_path = os.path.abspath(artifact["path"])
abs_dir = os.path.realpath(live.dir)
Expand Down
15 changes: 15 additions & 0 deletions tests/test_dvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,3 +271,18 @@ def test_errors_on_git_add_are_catched(tmp_dir, mocked_dvc_repo, monkeypatch):

with Live(report=None) as live:
live.summary["foo"] = 1


def test_make_dvcyaml_idempotent(tmp_dir, mocked_dvc_repo):
(tmp_dir / "model.pth").touch()

with Live() as live:
live.log_artifact("model.pth", type="model")

live.make_dvcyaml()

assert load_yaml(live.dvc_file) == {
"artifacts": {
"model": {"path": "../model.pth", "type": "model"},
}
}
18 changes: 17 additions & 1 deletion tests/test_log_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def test_log_artifact_type_model_provided_name(tmp_dir, mocked_dvc_repo):
}


def test_log_artifact_type_model_on_step(tmp_dir, mocked_dvc_repo):
def test_log_artifact_type_model_on_step_and_final(tmp_dir, mocked_dvc_repo):
(tmp_dir / "model.pth").touch()

with Live() as live:
Expand All @@ -165,6 +165,22 @@ def test_log_artifact_type_model_on_step(tmp_dir, mocked_dvc_repo):
}


def test_log_artifact_type_model_on_step(tmp_dir, mocked_dvc_repo):
(tmp_dir / "model.pth").touch()

with Live() as live:
for _ in range(3):
live.log_artifact("model.pth", type="model")
live.next_step()

assert load_yaml(live.dvc_file) == {
"artifacts": {
"model": {"path": "../model.pth", "type": "model"},
},
"metrics": ["metrics.json"],
}


def test_log_artifact_attrs(tmp_dir, mocked_dvc_repo):
(tmp_dir / "model.pth").touch()

Expand Down

0 comments on commit 31a7908

Please sign in to comment.