-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Tests: Remove tests specific to SQLite
- Loading branch information
Showing
1 changed file
with
2 additions
and
305 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2738,71 +2738,6 @@ def test_upgrade_cli_idempotence(self): | |
assert _get_schema_version(engine) == _get_latest_schema_revision() | ||
engine.dispose() | ||
|
||
def test_metrics_materialization_upgrade_succeeds_and_produces_expected_latest_metric_values( | ||
self, | ||
): | ||
""" | ||
Tests the ``89d4b8295536_create_latest_metrics_table`` migration by migrating and querying | ||
the MLflow Tracking SQLite database located at | ||
/mlflow/tests/resources/db/db_version_7ac759974ad8_with_metrics.sql. This database contains | ||
metric entries populated by the following metrics generation script: | ||
https://gist.github.com/dbczumar/343173c6b8982a0cc9735ff19b5571d9. | ||
First, the database is upgraded from its HEAD revision of | ||
``7ac755974ad8_update_run_tags_with_larger_limit`` to the latest revision via | ||
``mlflow db upgrade``. | ||
Then, the test confirms that the metric entries returned by calls | ||
to ``SqlAlchemyStore.get_run()`` are consistent between the latest revision and the | ||
``7ac755974ad8_update_run_tags_with_larger_limit`` revision. This is confirmed by | ||
invoking ``SqlAlchemyStore.get_run()`` for each run id that is present in the upgraded | ||
database and comparing the resulting runs' metric entries to a JSON dump taken from the | ||
SQLite database prior to the upgrade (located at | ||
mlflow/tests/resources/db/db_version_7ac759974ad8_with_metrics_expected_values.json). | ||
This JSON dump can be replicated by installing MLflow version 1.2.0 and executing the | ||
following code from the directory containing this test suite: | ||
.. code-block:: python | ||
import json | ||
import mlflow | ||
from mlflow import MlflowClient | ||
mlflow.set_tracking_uri( | ||
"sqlite:///../../resources/db/db_version_7ac759974ad8_with_metrics.sql" | ||
) | ||
client = MlflowClient() | ||
summary_metrics = { | ||
run.info.run_id: run.data.metrics for run in client.search_runs(experiment_ids="0") | ||
} | ||
with open("dump.json", "w") as dump_file: | ||
json.dump(summary_metrics, dump_file, indent=4) | ||
""" | ||
current_dir = os.path.dirname(os.path.abspath(__file__)) | ||
db_resources_path = os.path.normpath( | ||
os.path.join(current_dir, os.pardir, os.pardir, "resources", "db") | ||
) | ||
expected_metric_values_path = os.path.join( | ||
db_resources_path, "db_version_7ac759974ad8_with_metrics_expected_values.json" | ||
) | ||
with TempDir() as tmp_db_dir: | ||
db_path = tmp_db_dir.path("tmp_db.sql") | ||
db_url = "sqlite:///" + db_path | ||
shutil.copyfile( | ||
src=os.path.join(db_resources_path, "db_version_7ac759974ad8_with_metrics.sql"), | ||
dst=db_path, | ||
) | ||
|
||
invoke_cli_runner(mlflow.db.commands, ["upgrade", db_url]) | ||
store = self._get_store(db_uri=db_url) | ||
with open(expected_metric_values_path) as f: | ||
expected_metric_values = json.load(f) | ||
|
||
for run_id, expected_metrics in expected_metric_values.items(): | ||
fetched_run = store.get_run(run_id=run_id) | ||
assert fetched_run.data.metrics == expected_metrics | ||
|
||
def _generate_large_data(self, nb_runs=1000): | ||
experiment_id = self.store.create_experiment("test_experiment") | ||
|
||
|
@@ -3423,33 +3358,6 @@ def test_log_inputs_with_duplicates_in_single_request(self): | |
) | ||
|
||
|
||
def test_sqlalchemy_store_behaves_as_expected_with_inmemory_sqlite_db(monkeypatch): | ||
monkeypatch.setenv("MLFLOW_SQLALCHEMYSTORE_POOLCLASS", "SingletonThreadPool") | ||
store = SqlAlchemyStore("sqlite:///:memory:", ARTIFACT_URI) | ||
experiment_id = store.create_experiment(name="exp1") | ||
run = store.create_run( | ||
experiment_id=experiment_id, user_id="user", start_time=0, tags=[], run_name="name" | ||
) | ||
run_id = run.info.run_id | ||
metric = entities.Metric("mymetric", 1, 0, 0) | ||
store.log_metric(run_id=run_id, metric=metric) | ||
param = entities.Param("myparam", "A") | ||
store.log_param(run_id=run_id, param=param) | ||
fetched_run = store.get_run(run_id=run_id) | ||
assert fetched_run.info.run_id == run_id | ||
assert metric.key in fetched_run.data.metrics | ||
assert param.key in fetched_run.data.params | ||
|
||
|
||
def test_sqlalchemy_store_can_be_initialized_when_default_experiment_has_been_deleted( | ||
tmp_sqlite_uri, | ||
): | ||
store = SqlAlchemyStore(tmp_sqlite_uri, ARTIFACT_URI) | ||
store.delete_experiment("0") | ||
assert store.get_experiment("0").lifecycle_stage == entities.LifecycleStage.DELETED | ||
SqlAlchemyStore(tmp_sqlite_uri, ARTIFACT_URI) | ||
|
||
|
||
class TestSqlAlchemyStoreMigratedDB(TestSqlAlchemyStore): | ||
""" | ||
Test case where user has an existing DB with schema generated before MLflow 1.0, | ||
|
@@ -3526,8 +3434,8 @@ def test_get_attribute_name(): | |
assert len(entities.RunInfo.get_orderable_attributes()) == 7 | ||
|
||
|
||
def test_get_orderby_clauses(tmp_sqlite_uri): | ||
store = SqlAlchemyStore(tmp_sqlite_uri, ARTIFACT_URI) | ||
def test_get_orderby_clauses(): | ||
store = SqlAlchemyStore(DB_URI, ARTIFACT_URI) | ||
with store.ManagedSessionMaker() as session: | ||
# test that ['runs.start_time DESC', 'SqlRun.run_uuid'] is returned by default | ||
parsed = [str(x) for x in _get_orderby_clauses([], session)[1]] | ||
|
@@ -3563,214 +3471,3 @@ def test_get_orderby_clauses(tmp_sqlite_uri): | |
assert "value IS NULL" in select_clause[0] | ||
# test that clause name is in parsed | ||
assert "clause_1" in parsed[0] | ||
|
||
|
||
def _assert_create_experiment_appends_to_artifact_uri_path_correctly( | ||
artifact_root_uri, expected_artifact_uri_format | ||
): | ||
# Patch `is_local_uri` to prevent the SqlAlchemy store from attempting to create local | ||
# filesystem directories for file URI and POSIX path test cases | ||
with mock.patch("mlflow.store.tracking.sqlalchemy_store.is_local_uri", return_value=False): | ||
with TempDir() as tmp: | ||
dbfile_path = tmp.path("db") | ||
store = SqlAlchemyStore( | ||
db_uri="sqlite:///" + dbfile_path, default_artifact_root=artifact_root_uri | ||
) | ||
exp_id = store.create_experiment(name="exp") | ||
exp = store.get_experiment(exp_id) | ||
cwd = Path.cwd().as_posix() | ||
drive = Path.cwd().drive | ||
if is_windows() and expected_artifact_uri_format.startswith("file:"): | ||
cwd = f"/{cwd}" | ||
drive = f"{drive}/" | ||
assert exp.artifact_location == expected_artifact_uri_format.format( | ||
e=exp_id, cwd=cwd, drive=drive | ||
) | ||
|
||
|
||
@pytest.mark.skipif(not is_windows(), reason="This test only passes on Windows") | ||
@pytest.mark.parametrize( | ||
("input_uri", "expected_uri"), | ||
[ | ||
("file://my_server/my_path/my_sub_path", "file://my_server/my_path/my_sub_path/{e}"), | ||
("path/to/local/folder", "file://{cwd}/path/to/local/folder/{e}"), | ||
("/path/to/local/folder", "file:///{drive}path/to/local/folder/{e}"), | ||
("#path/to/local/folder?", "file://{cwd}/{e}#path/to/local/folder?"), | ||
("file:path/to/local/folder", "file://{cwd}/path/to/local/folder/{e}"), | ||
("file:///path/to/local/folder", "file:///{drive}path/to/local/folder/{e}"), | ||
( | ||
"file:path/to/local/folder?param=value", | ||
"file://{cwd}/path/to/local/folder/{e}?param=value", | ||
), | ||
( | ||
"file:///path/to/local/folder?param=value#fragment", | ||
"file:///{drive}path/to/local/folder/{e}?param=value#fragment", | ||
), | ||
], | ||
) | ||
def test_create_experiment_appends_to_artifact_local_path_file_uri_correctly_on_windows( | ||
input_uri, expected_uri | ||
): | ||
_assert_create_experiment_appends_to_artifact_uri_path_correctly(input_uri, expected_uri) | ||
|
||
|
||
@pytest.mark.skipif(is_windows(), reason="This test fails on Windows") | ||
@pytest.mark.parametrize( | ||
("input_uri", "expected_uri"), | ||
[ | ||
("path/to/local/folder", "{cwd}/path/to/local/folder/{e}"), | ||
("/path/to/local/folder", "/path/to/local/folder/{e}"), | ||
("#path/to/local/folder?", "{cwd}/#path/to/local/folder?/{e}"), | ||
("file:path/to/local/folder", "file://{cwd}/path/to/local/folder/{e}"), | ||
("file:///path/to/local/folder", "file:///path/to/local/folder/{e}"), | ||
( | ||
"file:path/to/local/folder?param=value", | ||
"file://{cwd}/path/to/local/folder/{e}?param=value", | ||
), | ||
( | ||
"file:///path/to/local/folder?param=value#fragment", | ||
"file:///path/to/local/folder/{e}?param=value#fragment", | ||
), | ||
], | ||
) | ||
def test_create_experiment_appends_to_artifact_local_path_file_uri_correctly( | ||
input_uri, expected_uri | ||
): | ||
_assert_create_experiment_appends_to_artifact_uri_path_correctly(input_uri, expected_uri) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
("input_uri", "expected_uri"), | ||
[ | ||
("s3://bucket/path/to/root", "s3://bucket/path/to/root/{e}"), | ||
( | ||
"s3://bucket/path/to/root?creds=mycreds", | ||
"s3://bucket/path/to/root/{e}?creds=mycreds", | ||
), | ||
( | ||
"dbscheme+driver://root@host/dbname?creds=mycreds#myfragment", | ||
"dbscheme+driver://root@host/dbname/{e}?creds=mycreds#myfragment", | ||
), | ||
( | ||
"dbscheme+driver://root:[email protected]?creds=mycreds#myfragment", | ||
"dbscheme+driver://root:[email protected]/{e}?creds=mycreds#myfragment", | ||
), | ||
( | ||
"dbscheme+driver://root:[email protected]/mydb?creds=mycreds#myfragment", | ||
"dbscheme+driver://root:[email protected]/mydb/{e}?creds=mycreds#myfragment", | ||
), | ||
], | ||
) | ||
def test_create_experiment_appends_to_artifact_uri_path_correctly(input_uri, expected_uri): | ||
_assert_create_experiment_appends_to_artifact_uri_path_correctly(input_uri, expected_uri) | ||
|
||
|
||
def _assert_create_run_appends_to_artifact_uri_path_correctly( | ||
artifact_root_uri, expected_artifact_uri_format | ||
): | ||
# Patch `is_local_uri` to prevent the SqlAlchemy store from attempting to create local | ||
# filesystem directories for file URI and POSIX path test cases | ||
with mock.patch("mlflow.store.tracking.sqlalchemy_store.is_local_uri", return_value=False): | ||
with TempDir() as tmp: | ||
dbfile_path = tmp.path("db") | ||
store = SqlAlchemyStore( | ||
db_uri="sqlite:///" + dbfile_path, default_artifact_root=artifact_root_uri | ||
) | ||
exp_id = store.create_experiment(name="exp") | ||
run = store.create_run( | ||
experiment_id=exp_id, user_id="user", start_time=0, tags=[], run_name="name" | ||
) | ||
cwd = Path.cwd().as_posix() | ||
drive = Path.cwd().drive | ||
if is_windows() and expected_artifact_uri_format.startswith("file:"): | ||
cwd = f"/{cwd}" | ||
drive = f"{drive}/" | ||
assert run.info.artifact_uri == expected_artifact_uri_format.format( | ||
e=exp_id, r=run.info.run_id, cwd=cwd, drive=drive | ||
) | ||
|
||
|
||
@pytest.mark.skipif(not is_windows(), reason="This test only passes on Windows") | ||
@pytest.mark.parametrize( | ||
("input_uri", "expected_uri"), | ||
[ | ||
( | ||
"file://my_server/my_path/my_sub_path", | ||
"file://my_server/my_path/my_sub_path/{e}/{r}/artifacts", | ||
), | ||
("path/to/local/folder", "file://{cwd}/path/to/local/folder/{e}/{r}/artifacts"), | ||
("/path/to/local/folder", "file:///{drive}path/to/local/folder/{e}/{r}/artifacts"), | ||
("#path/to/local/folder?", "file://{cwd}/{e}/{r}/artifacts#path/to/local/folder?"), | ||
("file:path/to/local/folder", "file://{cwd}/path/to/local/folder/{e}/{r}/artifacts"), | ||
( | ||
"file:///path/to/local/folder", | ||
"file:///{drive}path/to/local/folder/{e}/{r}/artifacts", | ||
), | ||
( | ||
"file:path/to/local/folder?param=value", | ||
"file://{cwd}/path/to/local/folder/{e}/{r}/artifacts?param=value", | ||
), | ||
( | ||
"file:///path/to/local/folder?param=value#fragment", | ||
"file:///{drive}path/to/local/folder/{e}/{r}/artifacts?param=value#fragment", | ||
), | ||
], | ||
) | ||
def test_create_run_appends_to_artifact_local_path_file_uri_correctly_on_windows( | ||
input_uri, expected_uri | ||
): | ||
_assert_create_run_appends_to_artifact_uri_path_correctly(input_uri, expected_uri) | ||
|
||
|
||
@pytest.mark.skipif(is_windows(), reason="This test fails on Windows") | ||
@pytest.mark.parametrize( | ||
("input_uri", "expected_uri"), | ||
[ | ||
("path/to/local/folder", "{cwd}/path/to/local/folder/{e}/{r}/artifacts"), | ||
("/path/to/local/folder", "/path/to/local/folder/{e}/{r}/artifacts"), | ||
("#path/to/local/folder?", "{cwd}/#path/to/local/folder?/{e}/{r}/artifacts"), | ||
("file:path/to/local/folder", "file://{cwd}/path/to/local/folder/{e}/{r}/artifacts"), | ||
( | ||
"file:///path/to/local/folder", | ||
"file:///path/to/local/folder/{e}/{r}/artifacts", | ||
), | ||
( | ||
"file:path/to/local/folder?param=value", | ||
"file://{cwd}/path/to/local/folder/{e}/{r}/artifacts?param=value", | ||
), | ||
( | ||
"file:///path/to/local/folder?param=value#fragment", | ||
"file:///path/to/local/folder/{e}/{r}/artifacts?param=value#fragment", | ||
), | ||
], | ||
) | ||
def test_create_run_appends_to_artifact_local_path_file_uri_correctly(input_uri, expected_uri): | ||
_assert_create_run_appends_to_artifact_uri_path_correctly(input_uri, expected_uri) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
("input_uri", "expected_uri"), | ||
[ | ||
("s3://bucket/path/to/root", "s3://bucket/path/to/root/{e}/{r}/artifacts"), | ||
( | ||
"s3://bucket/path/to/root?creds=mycreds", | ||
"s3://bucket/path/to/root/{e}/{r}/artifacts?creds=mycreds", | ||
), | ||
( | ||
"dbscheme+driver://root@host/dbname?creds=mycreds#myfragment", | ||
"dbscheme+driver://root@host/dbname/{e}/{r}/artifacts?creds=mycreds#myfragment", | ||
), | ||
( | ||
"dbscheme+driver://root:[email protected]?creds=mycreds#myfragment", | ||
"dbscheme+driver://root:[email protected]/{e}/{r}/artifacts" | ||
"?creds=mycreds#myfragment", | ||
), | ||
( | ||
"dbscheme+driver://root:[email protected]/mydb?creds=mycreds#myfragment", | ||
"dbscheme+driver://root:[email protected]/mydb/{e}/{r}/artifacts" | ||
"?creds=mycreds#myfragment", | ||
), | ||
], | ||
) | ||
def test_create_run_appends_to_artifact_uri_path_correctly(input_uri, expected_uri): | ||
_assert_create_run_appends_to_artifact_uri_path_correctly(input_uri, expected_uri) |