diff --git a/src/dvclive/env.py b/src/dvclive/env.py index 4b26b651..5230397f 100644 --- a/src/dvclive/env.py +++ b/src/dvclive/env.py @@ -1,7 +1,7 @@ DVCLIVE_LOGLEVEL = "DVCLIVE_LOGLEVEL" DVCLIVE_OPEN = "DVCLIVE_OPEN" DVCLIVE_RESUME = "DVCLIVE_RESUME" -DVC_CHECKPOINT = "DVC_CHECKPOINT" +DVCLIVE_TEST = "DVCLIVE_TEST" DVC_EXP_BASELINE_REV = "DVC_EXP_BASELINE_REV" DVC_EXP_NAME = "DVC_EXP_NAME" DVC_ROOT = "DVC_ROOT" diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 3cb6dce8..9fa81f0d 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -4,6 +4,7 @@ import math import os import shutil +import tempfile from pathlib import Path from typing import Any, Dict, List, Optional, Set, Union @@ -83,8 +84,6 @@ def __init__( self._dvcyaml = dvcyaml self._cache_images = cache_images - os.makedirs(self.dir, exist_ok=True) - self._report_mode: Optional[str] = report self._report_notebook = None self._init_report() @@ -97,7 +96,12 @@ def __init__( self._inside_dvc_pipeline: bool = False self._dvc_repo = None self._include_untracked: List[str] = [] - self._init_dvc() + if env2bool(env.DVCLIVE_TEST): + self._init_test() + else: + self._init_dvc() + + os.makedirs(self.dir, exist_ok=True) if self._resume: self._init_resume() @@ -266,6 +270,23 @@ def _init_report(self): self._report_mode = None logger.debug(f"{self._report_mode=}") + def _init_test(self): + """ + Enables test mode that writes to temp paths and doesn't depend on repo. + + Needed to run integration tests in external libraries like huggingface + accelerate. + """ + with tempfile.TemporaryDirectory() as dirpath: + self._dir = os.path.join(dirpath, self._dir) + if isinstance(self._dvcyaml, str): + self._dvc_file = os.path.join(dirpath, self._dvcyaml) + self._save_dvc_exp = False + logger.warning( + "DVCLive testing mode enabled." + f"Repo will be ignored and output will be written to {dirpath}." + ) + @property def dir(self) -> str: # noqa: A003 return self._dir diff --git a/tests/test_dvc.py b/tests/test_dvc.py index 73e16465..51ef8f36 100644 --- a/tests/test_dvc.py +++ b/tests/test_dvc.py @@ -8,7 +8,7 @@ from dvclive import Live from dvclive.dvc import get_dvc_repo -from dvclive.env import DVC_EXP_BASELINE_REV, DVC_EXP_NAME, DVC_ROOT +from dvclive.env import DVC_EXP_BASELINE_REV, DVC_EXP_NAME, DVC_ROOT, DVCLIVE_TEST def test_get_dvc_repo(tmp_dir): @@ -244,3 +244,14 @@ def test_get_exp_name_duplicate(tmp_dir, mocked_dvc_repo, mocker, caplog): assert live._exp_name == "random" msg = "Experiment conflicts with existing experiment 'duplicate'." assert msg in caplog.text + + +def test_test_mode(tmp_dir, monkeypatch, mocked_dvc_repo): + monkeypatch.setenv(DVCLIVE_TEST, "true") + live = Live("dir", dvcyaml="dvc.yaml") + live.make_dvcyaml() + assert live._dir != "dir" + assert live._dvc_file != "dvc.yaml" + assert live._save_dvc_exp is False + assert not os.path.exists("dir") + assert not os.path.exists("dvc.yaml")