Tests: Add test infrastructure from MLflow repository
amotl committed Sep 9, 2023
1 parent 6d2cde8 commit 6efe947
Showing 4 changed files with 97 additions and 3 deletions.
Empty file added tests/__init__.py
63 changes: 63 additions & 0 deletions tests/abstract.py
@@ -0,0 +1,63 @@
# Source: mlflow:tests/store/tracking/__init__.py
import json

import pytest
from mlflow.entities import RunTag
from mlflow.models import Model
from mlflow.utils.mlflow_tags import MLFLOW_LOGGED_MODELS


class AbstractStoreTest:
    def create_test_run(self):
        raise Exception("this should be overridden")

    def get_store(self):
        raise Exception("this should be overridden")

    def test_record_logged_model(self):
        store = self.get_store()
        run_id = self.create_test_run().info.run_id
        m = Model(artifact_path="model/path", run_id=run_id, flavors={"tf": "flavor body"})
        store.record_logged_model(run_id, m)
        self._verify_logged(
            store,
            run_id=run_id,
            params=[],
            metrics=[],
            tags=[RunTag(MLFLOW_LOGGED_MODELS, json.dumps([m.to_dict()]))],
        )
        m2 = Model(artifact_path="some/other/path", run_id=run_id, flavors={"R": {"property": "value"}})
        store.record_logged_model(run_id, m2)
        self._verify_logged(
            store,
            run_id,
            params=[],
            metrics=[],
            tags=[RunTag(MLFLOW_LOGGED_MODELS, json.dumps([m.to_dict(), m2.to_dict()]))],
        )
        m3 = Model(artifact_path="some/other/path2", run_id=run_id, flavors={"R2": {"property": "value"}})
        store.record_logged_model(run_id, m3)
        self._verify_logged(
            store,
            run_id,
            params=[],
            metrics=[],
            tags=[RunTag(MLFLOW_LOGGED_MODELS, json.dumps([m.to_dict(), m2.to_dict(), m3.to_dict()]))],
        )
        with pytest.raises(
            TypeError,
            match="Argument 'mlflow_model' should be mlflow.models.Model, got '<class 'dict'>'",
        ):
            store.record_logged_model(run_id, m.to_dict())

    @staticmethod
    def _verify_logged(store, run_id, metrics, params, tags):
        run = store.get_run(run_id)
        all_metrics = sum([store.get_metric_history(run_id, key) for key in run.data.metrics], [])
        assert len(all_metrics) == len(metrics)
        logged_metrics = [(m.key, m.value, m.timestamp, m.step) for m in all_metrics]
        assert set(logged_metrics) == {(m.key, m.value, m.timestamp, m.step) for m in metrics}
        logged_tags = set(run.data.tags.items())
        assert {(tag.key, tag.value) for tag in tags} <= logged_tags
        assert len(run.data.params) == len(params)
        assert set(run.data.params.items()) == {(param.key, param.value) for param in params}
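
A concrete test class is expected to subclass AbstractStoreTest and override get_store and create_test_run. The sketch below is a hypothetical illustration rather than part of this commit: the class name, the SqlAlchemyStore backend, the store caching, and the create_run arguments are assumptions, loosely echoing the DB_URI and ARTIFACT_URI constants that appear in tests/test_tracking.py.

# Hypothetical sketch (not part of this commit): a concrete test class supplies
# a real tracking store plus a helper run for AbstractStoreTest to exercise.
from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore

from tests.abstract import AbstractStoreTest

DB_URI = "sqlite:///"  # mirrors the constant used in tests/test_tracking.py
ARTIFACT_URI = "artifact_folder"


class TestSqlAlchemyStore(AbstractStoreTest):
    _store = None

    def get_store(self):
        # Reuse one store instance so create_test_run() and the assertions
        # in AbstractStoreTest talk to the same backing database.
        if TestSqlAlchemyStore._store is None:
            TestSqlAlchemyStore._store = SqlAlchemyStore(DB_URI, ARTIFACT_URI)
        return TestSqlAlchemyStore._store

    def create_test_run(self):
        # Create a run in the default experiment so there is a run_id to log against.
        return self.get_store().create_run(
            experiment_id="0",
            user_id="tester",
            start_time=0,
            tags=[],
            run_name="test-run",
        )
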
7 changes: 4 additions & 3 deletions tests/test_tracking.py
@@ -1,3 +1,4 @@
+# Source: mlflow:tests/tracking/test_tracking.py
import json
import math
import os
@@ -68,9 +69,9 @@
from mlflow.utils.time_utils import get_current_time_millis
from mlflow.utils.uri import extract_db_type_from_uri

-from tests.integration.utils import invoke_cli_runner
-from tests.store.tracking import AbstractStoreTest
-from tests.store.tracking.test_file_store import assert_dataset_inputs_equal
+from mlflow_cratedb.adapter.db import CRATEDB
+from .abstract import AbstractStoreTest
+from .util import invoke_cli_runner, assert_dataset_inputs_equal

DB_URI = "sqlite:///"
ARTIFACT_URI = "artifact_folder"
30 changes: 30 additions & 0 deletions tests/util.py
@@ -0,0 +1,30 @@
# Source: mlflow:tests/integration/utils.py and mlflow:tests/store/tracking/test_file_store.py
from typing import List

from click.testing import CliRunner
from mlflow.entities import DatasetInput


def invoke_cli_runner(*args, **kwargs):
    """
    Helper method to invoke the CliRunner while asserting that the exit code is actually 0.
    """

    res = CliRunner().invoke(*args, **kwargs)
    assert res.exit_code == 0, f"Got non-zero exit code {res.exit_code}. Output is: {res.output}"
    return res


def assert_dataset_inputs_equal(inputs1: List[DatasetInput], inputs2: List[DatasetInput]):
    inputs1 = sorted(inputs1, key=lambda inp: (inp.dataset.name, inp.dataset.digest))
    inputs2 = sorted(inputs2, key=lambda inp: (inp.dataset.name, inp.dataset.digest))
    assert len(inputs1) == len(inputs2)
    for idx, inp1 in enumerate(inputs1):
        inp2 = inputs2[idx]
        assert dict(inp1.dataset) == dict(inp2.dataset)
        tags1 = sorted(inp1.tags, key=lambda tag: tag.key)
        tags2 = sorted(inp2.tags, key=lambda tag: tag.key)
        for idx, tag1 in enumerate(tags1):
            tag2 = tags2[idx]
            assert tag1.key == tag2.key
            assert tag1.value == tag2.value
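
For orientation, a hedged usage sketch of the two helpers follows; it is not part of this commit, and the test names and the choice of the mlflow server CLI command are assumptions made for illustration only.

# Hypothetical usage of the helpers above (illustrative only).
from mlflow.cli import server

from tests.util import assert_dataset_inputs_equal, invoke_cli_runner


def test_cli_help_runs():
    # invoke_cli_runner fails the test if the command exits non-zero.
    result = invoke_cli_runner(server, ["--help"])
    assert "Usage" in result.output


def test_dataset_inputs_trivially_equal():
    # assert_dataset_inputs_equal sorts both lists before comparing,
    # so ordering does not matter; two empty lists compare equal.
    assert_dataset_inputs_equal([], [])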
