Skip to content

Commit

Permalink
Support artifacts section. (#521)
Browse files Browse the repository at this point in the history
- Update `make_dvcyaml` to write `artifacts` section.
- Extend `log_artifact` to accept `type`, `name`, `desc`, `labels`, `meta`.
  • Loading branch information
daavoo authored Apr 21, 2023
1 parent 67a7ae3 commit a4df3a9
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 5 deletions.
13 changes: 11 additions & 2 deletions src/dvclive/dvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@ def make_dvcyaml(live):
if plots:
dvcyaml["plots"] = plots

if live._artifacts:
dvcyaml["artifacts"] = live._artifacts
for artifact in dvcyaml["artifacts"].values():
abs_path = os.path.realpath(artifact["path"])
abs_dir = os.path.realpath(live.dir)
relative_path = os.path.relpath(abs_path, abs_dir)
artifact["path"] = Path(relative_path).as_posix()

dump_yaml(dvcyaml, live.dvc_file)


Expand Down Expand Up @@ -164,9 +172,10 @@ def get_dvc_stage_template(live):
"cmd": "<python my_code_file.py my_args>",
"deps": ["<my_code_file.py>"],
}
if live._outs:
if live._artifacts:
stage["outs"] = []
for o in live._outs:
for artifact in live._artifacts.values():
o = artifact["path"]
artifact_path = Path(os.getcwd()) / o
artifact_path = artifact_path.relative_to(live._dvc_repo.root_dir).as_posix()
stage["outs"].append(artifact_path)
Expand Down
29 changes: 26 additions & 3 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
self._images: Dict[str, Any] = {}
self._params: Dict[str, Any] = {}
self._plots: Dict[str, Any] = {}
self._outs: Set[StrPath] = set()
self._artifacts: Dict[str, Dict] = {}
self._inside_with = False
self._dvcyaml = dvcyaml

Expand Down Expand Up @@ -321,19 +321,42 @@ def log_param(self, name: str, val: ParamLike):
"""Saves the given parameter value to yaml"""
self.log_params({name: val})

def log_artifact(self, path: StrPath):
def log_artifact(
self,
path: StrPath,
type: Optional[str] = None, # noqa: A002
name: Optional[str] = None,
desc: Optional[str] = None, # noqa: ARG002
labels: Optional[List[str]] = None, # noqa: ARG002
meta: Optional[Dict[str, Any]] = None, # noqa: ARG002
):
"""Tracks a local file or directory with DVC"""
if not isinstance(path, (str, Path)):
raise InvalidDataTypeError(path, type(path))

if self._dvc_repo is not None:
from dvc.repo.artifacts import name_is_compatible

try:
stage = self._dvc_repo.add(path)
except Exception as e: # noqa: BLE001
logger.warning(f"Failed to dvc add {path}: {e}")
return

self._outs.add(path)
name = name or Path(path).stem
if name_is_compatible(name):
self._artifacts[name] = {
k: v
for k, v in locals().items()
if k in ("path", "type", "desc", "labels", "meta") and v is not None
}
else:
logger.warning(
"Can't use '%s' as artifact name (ID)."
" It will not be included in the `artifacts` section.",
name,
)

dvc_file = stage[0].addressing

if self._save_dvc_exp:
Expand Down
59 changes: 59 additions & 0 deletions tests/test_log_artifact.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dvclive import Live
from dvclive.serialize import load_yaml


def test_log_artifact(tmp_dir, dvc_repo):
Expand Down Expand Up @@ -44,3 +45,61 @@ def test_log_artifact_with_save_dvc_exp(tmp_dir, mocker, mocked_dvc_repo):
include_untracked=[live.dir, "data", ".gitignore"],
force=True,
)


def test_log_artifact_type_model(tmp_dir, mocked_dvc_repo):
(tmp_dir / "model.pth").touch()

with Live() as live:
live.log_artifact("model.pth", type="model")

assert load_yaml(live.dvc_file) == {
"artifacts": {"model": {"path": "../model.pth", "type": "model"}}
}


def test_log_artifact_type_model_provided_name(tmp_dir, mocked_dvc_repo):
(tmp_dir / "model.pth").touch()

with Live() as live:
live.log_artifact("model.pth", type="model", name="custom")

assert load_yaml(live.dvc_file) == {
"artifacts": {"custom": {"path": "../model.pth", "type": "model"}}
}


def test_log_artifact_type_model_on_step(tmp_dir, mocked_dvc_repo):
(tmp_dir / "model.pth").touch()

with Live() as live:
for _ in range(3):
live.log_artifact("model.pth", type="model")
live.next_step()
live.log_artifact("model.pth", type="model", labels=["final"])
assert load_yaml(live.dvc_file) == {
"artifacts": {
"model": {"path": "../model.pth", "type": "model", "labels": ["final"]},
},
"metrics": ["metrics.json"],
}


def test_log_artifact_attrs(tmp_dir, mocked_dvc_repo):
(tmp_dir / "model.pth").touch()

attrs = {
"type": "model",
"name": "foo",
"desc": "bar",
"labels": ["foo"],
"meta": {"foo": "bar"},
}
with Live() as live:
live.log_artifact("model.pth", **attrs)
attrs.pop("name")
assert load_yaml(live.dvc_file) == {
"artifacts": {
"foo": {"path": "../model.pth", **attrs},
}
}

0 comments on commit a4df3a9

Please sign in to comment.