Skip to content

Commit

Permalink
Tests: Remove tests specific to SQLite
Browse files Browse the repository at this point in the history
  • Loading branch information
amotl committed Sep 7, 2023
1 parent 95c2236 commit 967848f
Showing 1 changed file with 2 additions and 305 deletions.
307 changes: 2 additions & 305 deletions tests/test_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -2738,71 +2738,6 @@ def test_upgrade_cli_idempotence(self):
assert _get_schema_version(engine) == _get_latest_schema_revision()
engine.dispose()

def test_metrics_materialization_upgrade_succeeds_and_produces_expected_latest_metric_values(
    self,
):
    """
    Tests the ``89d4b8295536_create_latest_metrics_table`` migration by migrating and querying
    the MLflow Tracking SQLite database located at
    mlflow/tests/resources/db/db_version_7ac759974ad8_with_metrics.sql. This database contains
    metric entries populated by the following metrics generation script:
    https://gist.github.com/dbczumar/343173c6b8982a0cc9735ff19b5571d9.
    First, the database is upgraded from its HEAD revision of
    ``7ac759974ad8_update_run_tags_with_larger_limit`` to the latest revision via
    ``mlflow db upgrade``.
    Then, the test confirms that the metric entries returned by calls
    to ``SqlAlchemyStore.get_run()`` are consistent between the latest revision and the
    ``7ac759974ad8_update_run_tags_with_larger_limit`` revision. This is confirmed by
    invoking ``SqlAlchemyStore.get_run()`` for each run id that is listed in the expected-values
    JSON dump and comparing the resulting runs' metric entries to that dump, which was taken
    from the SQLite database prior to the upgrade (located at
    mlflow/tests/resources/db/db_version_7ac759974ad8_with_metrics_expected_values.json).
    This JSON dump can be replicated by installing MLflow version 1.2.0 and executing the
    following code from the directory containing this test suite:
    .. code-block:: python
        import json
        import mlflow
        from mlflow import MlflowClient
        mlflow.set_tracking_uri(
            "sqlite:///../../resources/db/db_version_7ac759974ad8_with_metrics.sql"
        )
        client = MlflowClient()
        summary_metrics = {
            run.info.run_id: run.data.metrics for run in client.search_runs(experiment_ids="0")
        }
        with open("dump.json", "w") as dump_file:
            json.dump(summary_metrics, dump_file, indent=4)
    """
    # The fixtures live in ``resources/db`` two directory levels above this test module.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    db_resources_path = os.path.normpath(
        os.path.join(current_dir, os.pardir, os.pardir, "resources", "db")
    )
    expected_metric_values_path = os.path.join(
        db_resources_path, "db_version_7ac759974ad8_with_metrics_expected_values.json"
    )
    with TempDir() as tmp_db_dir:
        # Upgrade a throwaway copy so the checked-in fixture database is never modified.
        db_path = tmp_db_dir.path("tmp_db.sql")
        db_url = "sqlite:///" + db_path
        shutil.copyfile(
            src=os.path.join(db_resources_path, "db_version_7ac759974ad8_with_metrics.sql"),
            dst=db_path,
        )

        invoke_cli_runner(mlflow.db.commands, ["upgrade", db_url])
        store = self._get_store(db_uri=db_url)
        # JSON is UTF-8 by specification; do not depend on the platform's locale encoding.
        with open(expected_metric_values_path, encoding="utf-8") as f:
            expected_metric_values = json.load(f)

        for run_id, expected_metrics in expected_metric_values.items():
            fetched_run = store.get_run(run_id=run_id)
            assert fetched_run.data.metrics == expected_metrics

def _generate_large_data(self, nb_runs=1000):
experiment_id = self.store.create_experiment("test_experiment")

Expand Down Expand Up @@ -3423,33 +3358,6 @@ def test_log_inputs_with_duplicates_in_single_request(self):
)


def test_sqlalchemy_store_behaves_as_expected_with_inmemory_sqlite_db(monkeypatch):
    """
    Round-trip a run, one metric and one param through a store backed by ``sqlite:///:memory:``.

    The ``MLFLOW_SQLALCHEMYSTORE_POOLCLASS`` environment variable is set to
    ``SingletonThreadPool`` before the store is constructed.
    NOTE(review): presumably this keeps every operation on the same in-memory connection --
    confirm against the store implementation.
    """
    monkeypatch.setenv("MLFLOW_SQLALCHEMYSTORE_POOLCLASS", "SingletonThreadPool")
    in_memory_store = SqlAlchemyStore("sqlite:///:memory:", ARTIFACT_URI)

    exp_id = in_memory_store.create_experiment(name="exp1")
    created_run = in_memory_store.create_run(
        experiment_id=exp_id, user_id="user", start_time=0, tags=[], run_name="name"
    )
    created_run_id = created_run.info.run_id

    # Log one metric and one param against the freshly created run.
    logged_metric = entities.Metric("mymetric", 1, 0, 0)
    in_memory_store.log_metric(run_id=created_run_id, metric=logged_metric)
    logged_param = entities.Param("myparam", "A")
    in_memory_store.log_param(run_id=created_run_id, param=logged_param)

    # Everything written above must be visible when the run is read back.
    reloaded = in_memory_store.get_run(run_id=created_run_id)
    assert reloaded.info.run_id == created_run_id
    assert logged_metric.key in reloaded.data.metrics
    assert logged_param.key in reloaded.data.params


def test_sqlalchemy_store_can_be_initialized_when_default_experiment_has_been_deleted(
    tmp_sqlite_uri,
):
    """
    Check that a second ``SqlAlchemyStore`` can be constructed on a database whose
    experiment ``"0"`` has been deleted.
    """
    default_experiment_id = "0"

    initial_store = SqlAlchemyStore(tmp_sqlite_uri, ARTIFACT_URI)
    initial_store.delete_experiment(default_experiment_id)
    deleted_experiment = initial_store.get_experiment(default_experiment_id)
    assert deleted_experiment.lifecycle_stage == entities.LifecycleStage.DELETED

    # Re-initialising against the same database must not raise.
    SqlAlchemyStore(tmp_sqlite_uri, ARTIFACT_URI)


class TestSqlAlchemyStoreMigratedDB(TestSqlAlchemyStore):
"""
Test case where user has an existing DB with schema generated before MLflow 1.0,
Expand Down Expand Up @@ -3526,8 +3434,8 @@ def test_get_attribute_name():
assert len(entities.RunInfo.get_orderable_attributes()) == 7


def test_get_orderby_clauses(tmp_sqlite_uri):
store = SqlAlchemyStore(tmp_sqlite_uri, ARTIFACT_URI)
def test_get_orderby_clauses():
store = SqlAlchemyStore(DB_URI, ARTIFACT_URI)
with store.ManagedSessionMaker() as session:
# test that ['runs.start_time DESC', 'SqlRun.run_uuid'] is returned by default
parsed = [str(x) for x in _get_orderby_clauses([], session)[1]]
Expand Down Expand Up @@ -3563,214 +3471,3 @@ def test_get_orderby_clauses(tmp_sqlite_uri):
assert "value IS NULL" in select_clause[0]
# test that clause name is in parsed
assert "clause_1" in parsed[0]


def _assert_create_experiment_appends_to_artifact_uri_path_correctly(
    artifact_root_uri, expected_artifact_uri_format
):
    """
    Create an experiment in a temporary SQLite-backed store whose default artifact root is
    ``artifact_root_uri`` and assert that its artifact location matches the expected format.

    ``expected_artifact_uri_format`` is a ``str.format`` template; it is rendered with ``e``
    (the experiment id), ``cwd`` (current working directory in POSIX form) and ``drive``
    (drive of the current working directory).
    """
    # `is_local_uri` is patched to return False so that the SqlAlchemy store does not attempt
    # to create local filesystem directories for file URI and POSIX path test cases.
    with mock.patch(
        "mlflow.store.tracking.sqlalchemy_store.is_local_uri", return_value=False
    ), TempDir() as scratch_dir:
        sqlite_uri = "sqlite:///" + scratch_dir.path("db")
        tracking_store = SqlAlchemyStore(
            db_uri=sqlite_uri, default_artifact_root=artifact_root_uri
        )
        experiment_id = tracking_store.create_experiment(name="exp")
        experiment = tracking_store.get_experiment(experiment_id)

        working_dir = Path.cwd()
        cwd_part = working_dir.as_posix()
        drive_part = working_dir.drive
        # For "file:" templates on Windows, the cwd gets a leading slash and the drive a
        # trailing slash before they are substituted into the template.
        if is_windows() and expected_artifact_uri_format.startswith("file:"):
            cwd_part = "/" + cwd_part
            drive_part = drive_part + "/"

        expected_location = expected_artifact_uri_format.format(
            e=experiment_id, cwd=cwd_part, drive=drive_part
        )
        assert experiment.artifact_location == expected_location


@pytest.mark.skipif(not is_windows(), reason="This test only passes on Windows")
@pytest.mark.parametrize(
    ("input_uri", "expected_uri"),
    [
        # file URI with a server component
        ("file://my_server/my_path/my_sub_path", "file://my_server/my_path/my_sub_path/{e}"),
        # plain relative and absolute paths
        ("path/to/local/folder", "file://{cwd}/path/to/local/folder/{e}"),
        ("/path/to/local/folder", "file:///{drive}path/to/local/folder/{e}"),
        # input that begins with "#"
        ("#path/to/local/folder?", "file://{cwd}/{e}#path/to/local/folder?"),
        # "file:" URIs, relative and absolute
        ("file:path/to/local/folder", "file://{cwd}/path/to/local/folder/{e}"),
        ("file:///path/to/local/folder", "file:///{drive}path/to/local/folder/{e}"),
        # "file:" URIs carrying a query string and/or fragment
        (
            "file:path/to/local/folder?param=value",
            "file://{cwd}/path/to/local/folder/{e}?param=value",
        ),
        (
            "file:///path/to/local/folder?param=value#fragment",
            "file:///{drive}path/to/local/folder/{e}?param=value#fragment",
        ),
    ],
)
def test_create_experiment_appends_to_artifact_local_path_file_uri_correctly_on_windows(
    input_uri, expected_uri
):
    """Windows-only: experiment artifact locations for local-path and file-URI roots."""
    _assert_create_experiment_appends_to_artifact_uri_path_correctly(
        artifact_root_uri=input_uri, expected_artifact_uri_format=expected_uri
    )


@pytest.mark.skipif(is_windows(), reason="This test fails on Windows")
@pytest.mark.parametrize(
    ("input_uri", "expected_uri"),
    [
        # plain relative and absolute paths
        ("path/to/local/folder", "{cwd}/path/to/local/folder/{e}"),
        ("/path/to/local/folder", "/path/to/local/folder/{e}"),
        # input that begins with "#"
        ("#path/to/local/folder?", "{cwd}/#path/to/local/folder?/{e}"),
        # "file:" URIs, relative and absolute
        ("file:path/to/local/folder", "file://{cwd}/path/to/local/folder/{e}"),
        ("file:///path/to/local/folder", "file:///path/to/local/folder/{e}"),
        # "file:" URIs carrying a query string and/or fragment
        (
            "file:path/to/local/folder?param=value",
            "file://{cwd}/path/to/local/folder/{e}?param=value",
        ),
        (
            "file:///path/to/local/folder?param=value#fragment",
            "file:///path/to/local/folder/{e}?param=value#fragment",
        ),
    ],
)
def test_create_experiment_appends_to_artifact_local_path_file_uri_correctly(
    input_uri, expected_uri
):
    """Non-Windows: experiment artifact locations for local-path and file-URI roots."""
    _assert_create_experiment_appends_to_artifact_uri_path_correctly(
        artifact_root_uri=input_uri, expected_artifact_uri_format=expected_uri
    )


# NOTE(review): the "[email protected]" fragments in the cases below look like an e-mail
# obfuscation placeholder introduced when this text was scraped, not the intended
# "<password>@<host>" text -- confirm against the upstream test file before relying on them.
@pytest.mark.parametrize(
    ("input_uri", "expected_uri"),
    [
        ("s3://bucket/path/to/root", "s3://bucket/path/to/root/{e}"),
        (
            "s3://bucket/path/to/root?creds=mycreds",
            "s3://bucket/path/to/root/{e}?creds=mycreds",
        ),
        (
            "dbscheme+driver://root@host/dbname?creds=mycreds#myfragment",
            "dbscheme+driver://root@host/dbname/{e}?creds=mycreds#myfragment",
        ),
        (
            "dbscheme+driver://root:[email protected]?creds=mycreds#myfragment",
            "dbscheme+driver://root:[email protected]/{e}?creds=mycreds#myfragment",
        ),
        (
            "dbscheme+driver://root:[email protected]/mydb?creds=mycreds#myfragment",
            "dbscheme+driver://root:[email protected]/mydb/{e}?creds=mycreds#myfragment",
        ),
    ],
)
def test_create_experiment_appends_to_artifact_uri_path_correctly(input_uri, expected_uri):
    # Runs the shared experiment artifact-location check for non-file URI schemes; each
    # expected value keeps the input's query string and fragment after the appended "{e}".
    _assert_create_experiment_appends_to_artifact_uri_path_correctly(input_uri, expected_uri)


def _assert_create_run_appends_to_artifact_uri_path_correctly(
    artifact_root_uri, expected_artifact_uri_format
):
    """
    Create an experiment and a run in a temporary SQLite-backed store whose default artifact
    root is ``artifact_root_uri`` and assert that the run's artifact URI matches the expected
    format.

    ``expected_artifact_uri_format`` is a ``str.format`` template; it is rendered with ``e``
    (the experiment id), ``r`` (the run id), ``cwd`` (current working directory in POSIX
    form) and ``drive`` (drive of the current working directory).
    """
    # `is_local_uri` is patched to return False so that the SqlAlchemy store does not attempt
    # to create local filesystem directories for file URI and POSIX path test cases.
    with mock.patch(
        "mlflow.store.tracking.sqlalchemy_store.is_local_uri", return_value=False
    ), TempDir() as scratch_dir:
        sqlite_uri = "sqlite:///" + scratch_dir.path("db")
        tracking_store = SqlAlchemyStore(
            db_uri=sqlite_uri, default_artifact_root=artifact_root_uri
        )
        experiment_id = tracking_store.create_experiment(name="exp")
        created_run = tracking_store.create_run(
            experiment_id=experiment_id, user_id="user", start_time=0, tags=[], run_name="name"
        )

        working_dir = Path.cwd()
        cwd_part = working_dir.as_posix()
        drive_part = working_dir.drive
        # For "file:" templates on Windows, the cwd gets a leading slash and the drive a
        # trailing slash before they are substituted into the template.
        if is_windows() and expected_artifact_uri_format.startswith("file:"):
            cwd_part = "/" + cwd_part
            drive_part = drive_part + "/"

        expected_uri = expected_artifact_uri_format.format(
            e=experiment_id, r=created_run.info.run_id, cwd=cwd_part, drive=drive_part
        )
        assert created_run.info.artifact_uri == expected_uri


@pytest.mark.skipif(not is_windows(), reason="This test only passes on Windows")
@pytest.mark.parametrize(
    ("input_uri", "expected_uri"),
    [
        # file URI with a server component
        (
            "file://my_server/my_path/my_sub_path",
            "file://my_server/my_path/my_sub_path/{e}/{r}/artifacts",
        ),
        # plain relative and absolute paths
        ("path/to/local/folder", "file://{cwd}/path/to/local/folder/{e}/{r}/artifacts"),
        ("/path/to/local/folder", "file:///{drive}path/to/local/folder/{e}/{r}/artifacts"),
        # input that begins with "#"
        ("#path/to/local/folder?", "file://{cwd}/{e}/{r}/artifacts#path/to/local/folder?"),
        # "file:" URIs, relative and absolute
        ("file:path/to/local/folder", "file://{cwd}/path/to/local/folder/{e}/{r}/artifacts"),
        (
            "file:///path/to/local/folder",
            "file:///{drive}path/to/local/folder/{e}/{r}/artifacts",
        ),
        # "file:" URIs carrying a query string and/or fragment
        (
            "file:path/to/local/folder?param=value",
            "file://{cwd}/path/to/local/folder/{e}/{r}/artifacts?param=value",
        ),
        (
            "file:///path/to/local/folder?param=value#fragment",
            "file:///{drive}path/to/local/folder/{e}/{r}/artifacts?param=value#fragment",
        ),
    ],
)
def test_create_run_appends_to_artifact_local_path_file_uri_correctly_on_windows(
    input_uri, expected_uri
):
    """Windows-only: run artifact URIs for local-path and file-URI artifact roots."""
    _assert_create_run_appends_to_artifact_uri_path_correctly(
        artifact_root_uri=input_uri, expected_artifact_uri_format=expected_uri
    )


@pytest.mark.skipif(is_windows(), reason="This test fails on Windows")
@pytest.mark.parametrize(
    ("input_uri", "expected_uri"),
    [
        # plain relative and absolute paths
        ("path/to/local/folder", "{cwd}/path/to/local/folder/{e}/{r}/artifacts"),
        ("/path/to/local/folder", "/path/to/local/folder/{e}/{r}/artifacts"),
        # input that begins with "#"
        ("#path/to/local/folder?", "{cwd}/#path/to/local/folder?/{e}/{r}/artifacts"),
        # "file:" URIs, relative and absolute
        ("file:path/to/local/folder", "file://{cwd}/path/to/local/folder/{e}/{r}/artifacts"),
        (
            "file:///path/to/local/folder",
            "file:///path/to/local/folder/{e}/{r}/artifacts",
        ),
        # "file:" URIs carrying a query string and/or fragment
        (
            "file:path/to/local/folder?param=value",
            "file://{cwd}/path/to/local/folder/{e}/{r}/artifacts?param=value",
        ),
        (
            "file:///path/to/local/folder?param=value#fragment",
            "file:///path/to/local/folder/{e}/{r}/artifacts?param=value#fragment",
        ),
    ],
)
def test_create_run_appends_to_artifact_local_path_file_uri_correctly(input_uri, expected_uri):
    """Non-Windows: run artifact URIs for local-path and file-URI artifact roots."""
    _assert_create_run_appends_to_artifact_uri_path_correctly(
        artifact_root_uri=input_uri, expected_artifact_uri_format=expected_uri
    )


# NOTE(review): the "[email protected]" fragments in the cases below look like an e-mail
# obfuscation placeholder introduced when this text was scraped, not the intended
# "<password>@<host>" text -- confirm against the upstream test file before relying on them.
@pytest.mark.parametrize(
    ("input_uri", "expected_uri"),
    [
        ("s3://bucket/path/to/root", "s3://bucket/path/to/root/{e}/{r}/artifacts"),
        (
            "s3://bucket/path/to/root?creds=mycreds",
            "s3://bucket/path/to/root/{e}/{r}/artifacts?creds=mycreds",
        ),
        (
            "dbscheme+driver://root@host/dbname?creds=mycreds#myfragment",
            "dbscheme+driver://root@host/dbname/{e}/{r}/artifacts?creds=mycreds#myfragment",
        ),
        (
            "dbscheme+driver://root:[email protected]?creds=mycreds#myfragment",
            "dbscheme+driver://root:[email protected]/{e}/{r}/artifacts"
            "?creds=mycreds#myfragment",
        ),
        (
            "dbscheme+driver://root:[email protected]/mydb?creds=mycreds#myfragment",
            "dbscheme+driver://root:[email protected]/mydb/{e}/{r}/artifacts"
            "?creds=mycreds#myfragment",
        ),
    ],
)
def test_create_run_appends_to_artifact_uri_path_correctly(input_uri, expected_uri):
    # Runs the shared run artifact-URI check for non-file URI schemes; each expected value
    # keeps the input's query string and fragment after the appended "{e}/{r}/artifacts".
    _assert_create_run_appends_to_artifact_uri_path_correctly(input_uri, expected_uri)

0 comments on commit 967848f

Please sign in to comment.