Skip to content

Commit

Permalink
add test mode (#735)
Browse files Browse the repository at this point in the history
* add test mode

* docstring for _init_test

* add f str in logger
  • Loading branch information
Dave Berenbaum authored Nov 14, 2023
1 parent b71c78e commit d2f861e
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/dvclive/env.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
DVCLIVE_LOGLEVEL = "DVCLIVE_LOGLEVEL"
DVCLIVE_OPEN = "DVCLIVE_OPEN"
DVCLIVE_RESUME = "DVCLIVE_RESUME"
DVC_CHECKPOINT = "DVC_CHECKPOINT"
DVCLIVE_TEST = "DVCLIVE_TEST"
DVC_EXP_BASELINE_REV = "DVC_EXP_BASELINE_REV"
DVC_EXP_NAME = "DVC_EXP_NAME"
DVC_ROOT = "DVC_ROOT"
27 changes: 24 additions & 3 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import math
import os
import shutil
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Union

Expand Down Expand Up @@ -83,8 +84,6 @@ def __init__(
self._dvcyaml = dvcyaml
self._cache_images = cache_images

os.makedirs(self.dir, exist_ok=True)

self._report_mode: Optional[str] = report
self._report_notebook = None
self._init_report()
Expand All @@ -97,7 +96,12 @@ def __init__(
self._inside_dvc_pipeline: bool = False
self._dvc_repo = None
self._include_untracked: List[str] = []
self._init_dvc()
if env2bool(env.DVCLIVE_TEST):
self._init_test()
else:
self._init_dvc()

os.makedirs(self.dir, exist_ok=True)

if self._resume:
self._init_resume()
Expand Down Expand Up @@ -266,6 +270,23 @@ def _init_report(self):
self._report_mode = None
logger.debug(f"{self._report_mode=}")

def _init_test(self):
"""
Enables test mode that writes to temp paths and doesn't depend on repo.
Needed to run integration tests in external libraries like huggingface
accelerate.
"""
with tempfile.TemporaryDirectory() as dirpath:
self._dir = os.path.join(dirpath, self._dir)
if isinstance(self._dvcyaml, str):
self._dvc_file = os.path.join(dirpath, self._dvcyaml)
self._save_dvc_exp = False
logger.warning(
"DVCLive testing mode enabled."
f"Repo will be ignored and output will be written to {dirpath}."
)

@property
def dir(self) -> str: # noqa: A003
return self._dir
Expand Down
13 changes: 12 additions & 1 deletion tests/test_dvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from dvclive import Live
from dvclive.dvc import get_dvc_repo
from dvclive.env import DVC_EXP_BASELINE_REV, DVC_EXP_NAME, DVC_ROOT
from dvclive.env import DVC_EXP_BASELINE_REV, DVC_EXP_NAME, DVC_ROOT, DVCLIVE_TEST


def test_get_dvc_repo(tmp_dir):
Expand Down Expand Up @@ -244,3 +244,14 @@ def test_get_exp_name_duplicate(tmp_dir, mocked_dvc_repo, mocker, caplog):
assert live._exp_name == "random"
msg = "Experiment conflicts with existing experiment 'duplicate'."
assert msg in caplog.text


def test_test_mode(tmp_dir, monkeypatch, mocked_dvc_repo):
monkeypatch.setenv(DVCLIVE_TEST, "true")
live = Live("dir", dvcyaml="dvc.yaml")
live.make_dvcyaml()
assert live._dir != "dir"
assert live._dvc_file != "dvc.yaml"
assert live._save_dvc_exp is False
assert not os.path.exists("dir")
assert not os.path.exists("dvc.yaml")

0 comments on commit d2f861e

Please sign in to comment.