Skip to content

Commit

Permalink
update 3.0 args in lightning
Browse files Browse the repository at this point in the history
  • Loading branch information
dberenbaum committed Oct 18, 2023
1 parent f2924bd commit 66ef8b7
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
7 changes: 4 additions & 3 deletions src/dvclive/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,13 @@ def __init__( # noqa: PLR0913
prefix="",
log_model: Union[str, bool] = False,
experiment=None,
dir: Optional[str] = None, # noqa: A002
dir: str = "dvclive", # noqa: A002
resume: bool = False,
report: Optional[str] = None,
save_dvc_exp: bool = False,
dvcyaml: bool = True,
save_dvc_exp: bool = True,
dvcyaml: Union[str, bool] = True,
cache_images: bool = False,
exp_message: Optional[str] = None,
):
super().__init__()
self._prefix = prefix
Expand Down
4 changes: 2 additions & 2 deletions tests/frameworks/test_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def validation_step(self, *args, **kwargs):
return loss


def test_lightning_val_udpates_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post):
def test_lightning_val_updates_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post):
"""Test the `self.experiment._latest_studio_step -= 1` logic."""
mocked_post, _ = mocked_studio_post

Expand All @@ -281,7 +281,7 @@ def test_lightning_val_udpates_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio
# 2: update_train_step_metrics
# 3: log_eval_end_metrics
plots = calls[3][1]["json"]["plots"]
val_loss = plots["dvclive/dvc.yaml::dvclive/plots/metrics/val/loss.tsv"]
val_loss = plots["dvclive/plots/metrics/val/loss.tsv"]
# Without `self.experiment._latest_studio_step -= 1`
# This would be empty
assert len(val_loss["data"]) == 1
Expand Down

0 comments on commit 66ef8b7

Please sign in to comment.