Skip to content

Commit

Permalink
Support message. (#33)
Browse files Browse the repository at this point in the history
Optional string in the `start` event.
Internally cutting to 72 characters to match commit tile limit.
  • Loading branch information
daavoo authored May 4, 2023
1 parent 37d5e26 commit db168bd
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 3 deletions.
11 changes: 9 additions & 2 deletions src/dvc_studio_client/post_live_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,14 @@ def get_studio_token_and_repo_url(studio_token=None, studio_repo_url=None):
return studio_token, studio_repo_url


def post_live_metrics(
def post_live_metrics( # noqa: C901
event_type: Literal["start", "data", "done"],
baseline_sha: str,
name: str,
client: Literal["dvc", "dvclive"],
experiment_rev: Optional[str] = None,
machine: Optional[Dict[str, Any]] = None,
message: Optional[str] = None,
metrics: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
plots: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -110,6 +111,8 @@ def post_live_metrics(
"instance": "t2.micro"
}
```
message: (Optional[str]): Custom message to be displayed as the commit
message in Studio UI.
metrics (Optional[Dict[str, Any]]): Updates to DVC metric files.
Defaults to `None`.
Only used when `event_type="data"`.
Expand Down Expand Up @@ -178,7 +181,11 @@ def post_live_metrics(
if machine:
body["machine"] = machine

if event_type == "data":
if event_type == "start":
if message:
# Cutting the message to match the commit title length limit.
body["message"] = message[:72]
elif event_type == "data":
if step is None:
logger.warning("Missing `step` in `data` event.")
return None
Expand Down
6 changes: 5 additions & 1 deletion src/dvc_studio_client/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ def Choices(*choices):
}
)
SCHEMAS_BY_TYPE = {
"start": BASE_SCHEMA,
"start": BASE_SCHEMA.extend(
{
"message": str,
}
),
"data": BASE_SCHEMA.extend(
{
Required("step"): int,
Expand Down
61 changes: 61 additions & 0 deletions tests/test_post_live_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,3 +378,64 @@ def test_post_live_metrics_token_and_repo_url_args(mocker, monkeypatch):
},
timeout=5,
)


def test_post_live_metrics_message(mocker, monkeypatch):
monkeypatch.setenv(DVC_STUDIO_URL, "https://0.0.0.0")
monkeypatch.setenv(DVC_STUDIO_TOKEN, "FOO_TOKEN")
monkeypatch.setenv(STUDIO_REPO_URL, "FOO_REPO_URL")

mocked_response = mocker.MagicMock()
mocked_response.status_code = 200
mocked_post = mocker.patch("requests.post", return_value=mocked_response)

assert post_live_metrics(
"start",
"f" * 40,
"fooname",
"fooclient",
message="FOO_MESSAGE",
)

mocked_post.assert_called_with(
"https://0.0.0.0/api/live",
json={
"type": "start",
"repo_url": "FOO_REPO_URL",
"baseline_sha": "f" * 40,
"name": "fooname",
"client": "fooclient",
"message": "FOO_MESSAGE",
},
headers={
"Authorization": "token FOO_TOKEN",
"Content-type": "application/json",
},
timeout=5,
)

# Test message length limit
assert post_live_metrics(
"start",
"f" * 40,
"fooname",
"fooclient",
message="X" * 100,
)

mocked_post.assert_called_with(
"https://0.0.0.0/api/live",
json={
"type": "start",
"repo_url": "FOO_REPO_URL",
"baseline_sha": "f" * 40,
"name": "fooname",
"client": "fooclient",
"message": "X" * 72,
},
headers={
"Authorization": "token FOO_TOKEN",
"Content-type": "application/json",
},
timeout=5,
)

0 comments on commit db168bd

Please sign in to comment.