Skip to content

Commit

Permalink
Tests: Add support for CrateDB
Browse files Browse the repository at this point in the history
  • Loading branch information
amotl committed Sep 7, 2023
1 parent 5a623ba commit 867ad6e
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 2 deletions.
3 changes: 3 additions & 0 deletions mlflow_cratedb/adapter/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@

import sqlalchemy as sa
import sqlparse
from crate.client.sqlalchemy import CrateDialect
from sqlalchemy.event import listen

CRATEDB = CrateDialect.name


def _setup_db_create_tables(engine: sa.Engine):
"""
Expand Down
5 changes: 4 additions & 1 deletion mlflow_cratedb/monkey/db_types.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from mlflow_cratedb.adapter.db import CRATEDB


def patch_dbtypes():
"""
Register CrateDB as available database type.
"""
import mlflow.store.db.db_types as db_types

db_types.CRATEDB = "crate"
db_types.CRATEDB = CRATEDB

if db_types.CRATEDB not in db_types.DATABASE_ENGINES:
db_types.DATABASE_ENGINES.append(db_types.CRATEDB)
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ repository = "https://github.com/crate-workbench/mlflow-cratedb"
[tool.black]
line-length = 120

extend-exclude = "tests/test_tracking.py"

[tool.isort]
profile = "black"
skip_glob = "**/site-packages/**"
Expand All @@ -102,6 +104,7 @@ log_cli_level = "DEBUG"
testpaths = ["tests"]
xfail_strict = true
markers = [
"notrackingurimock",
]

[tool.coverage.run]
Expand Down Expand Up @@ -160,6 +163,11 @@ select = [
"RET",
]

extend-exclude = [
"tests/test_tracking.py",
]


[tool.ruff.per-file-ignores]
"tests/*" = ["S101"] # Use of `assert` detected

Expand Down
3 changes: 3 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from mlflow_cratedb import patch_all

patch_all()
10 changes: 9 additions & 1 deletion tests/test_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
from tests.store.tracking import AbstractStoreTest
from tests.store.tracking.test_file_store import assert_dataset_inputs_equal

DB_URI = "sqlite:///"
DB_URI = "crate://crate@localhost/?schema=testdrive"
ARTIFACT_URI = "artifact_folder"

pytestmark = pytest.mark.notrackingurimock
Expand Down Expand Up @@ -150,6 +150,8 @@ def create_test_run(self):
return self._run_factory()

def _setup_db_uri(self):
# Original code
"""
if uri := MLFLOW_TRACKING_URI.get():
self.temp_dbfile = None
self.db_url = uri
Expand All @@ -158,6 +160,9 @@ def _setup_db_uri(self):
# Close handle immediately so that we can remove the file later on in Windows
os.close(fd)
self.db_url = f"{DB_URI}{self.temp_dbfile}"
"""
self.temp_dbfile = None
self.db_url = DB_URI

def setUp(self):
self._setup_db_uri()
Expand All @@ -177,6 +182,8 @@ def _get_query_to_reset_experiment_id(self):
elif dialect == SQLITE:
# In SQLite, deleting all experiments resets experiment_id
return None
elif dialect == CRATEDB:
return None
raise ValueError(f"Invalid dialect: {dialect}")

def tearDown(self):
Expand Down Expand Up @@ -1090,6 +1097,7 @@ def test_log_null_param(self):
POSTGRES: r"null value in column .+ of relation .+ violates not-null constrain",
MYSQL: r"Column .+ cannot be null",
MSSQL: r"Cannot insert the value NULL into column .+, table .+",
CRATEDB: r'".+" must not be null',
}[dialect]
with pytest.raises(MlflowException, match=regex) as exception_context:
self.store.log_param(run.info.run_id, param)
Expand Down

0 comments on commit 867ad6e

Please sign in to comment.