From 23705bb9c964b95a323f2f82631ff5168f1df2b9 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Thu, 7 Sep 2023 13:12:31 +0200 Subject: [PATCH 01/18] Add adapter wrapper for MLflow/CrateDB, based on monkeypatching --- .gitignore | 3 + MANIFEST.in | 1 + README.md | 21 ++- mlflow_cratedb/__init__.py | 7 + mlflow_cratedb/adapter/__init__.py | 0 mlflow_cratedb/adapter/db.py | 30 ++++ mlflow_cratedb/adapter/ddl/cratedb.sql | 137 ++++++++++++++++++ mlflow_cratedb/adapter/ddl/drop.sql | 15 ++ mlflow_cratedb/adapter/util.py | 8 + mlflow_cratedb/cli.py | 2 + mlflow_cratedb/monkey/__init__.py | 25 ++++ mlflow_cratedb/monkey/db_types.py | 10 ++ mlflow_cratedb/monkey/db_utils.py | 50 +++++++ .../monkey/environment_variables.py | 8 + mlflow_cratedb/monkey/models.py | 29 ++++ mlflow_cratedb/monkey/server.py | 34 +++++ mlflow_cratedb/monkey/tracking.py | 56 +++++++ mlflow_cratedb/server.py | 2 + pyproject.toml | 15 +- 19 files changed, 445 insertions(+), 8 deletions(-) create mode 100644 MANIFEST.in create mode 100644 mlflow_cratedb/adapter/__init__.py create mode 100644 mlflow_cratedb/adapter/db.py create mode 100644 mlflow_cratedb/adapter/ddl/cratedb.sql create mode 100644 mlflow_cratedb/adapter/ddl/drop.sql create mode 100644 mlflow_cratedb/adapter/util.py create mode 100644 mlflow_cratedb/cli.py create mode 100644 mlflow_cratedb/monkey/__init__.py create mode 100644 mlflow_cratedb/monkey/db_types.py create mode 100644 mlflow_cratedb/monkey/db_utils.py create mode 100644 mlflow_cratedb/monkey/environment_variables.py create mode 100644 mlflow_cratedb/monkey/models.py create mode 100644 mlflow_cratedb/monkey/server.py create mode 100644 mlflow_cratedb/monkey/tracking.py create mode 100644 mlflow_cratedb/server.py diff --git a/.gitignore b/.gitignore index 675cdcb..0a3dedd 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,6 @@ coverage.xml *.egg-info *.pyc __pycache__ + +/mlartifacts +/mlruns diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..103fd31 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +recursive-include mlflow_cratedb *.sql diff --git a/README.md b/README.md index 137e93d..cb9c937 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,8 @@ ## About -A wrapper around [MLflow] to use [CrateDB] as storage database for [MLflow Tracking]. +An adapter wrapper for [MLflow] to use [CrateDB] as a storage database +for [MLflow Tracking]. ## Setup @@ -16,7 +17,23 @@ pip install --upgrade 'git+https://github.com/crate-workbench/mlflow-cratedb' ## Usage -TODO. +In order to spin up a CrateDB instance without further ado, you can use +Docker or Podman. +```shell +docker run --rm -it --publish=4200:4200 --publish=5432:5432 \ + --env=CRATE_HEAP_SIZE=4g crate \ + -Cdiscovery.type=single-node \ + -Ccluster.routing.allocation.disk.threshold_enabled=false +``` + +Start the MLflow server, pointing it to your [CrateDB] instance, +running on `localhost`. +```shell +mlflow-cratedb server --backend-store-uri='crate://crate@localhost' --dev +``` + +Please note that you need to invoke the `mlflow-cratedb` command, which +runs MLflow amalgamated with the necessary changes to support CrateDB. ## Development diff --git a/mlflow_cratedb/__init__.py b/mlflow_cratedb/__init__.py index e69de29..6c2eec6 100644 --- a/mlflow_cratedb/__init__.py +++ b/mlflow_cratedb/__init__.py @@ -0,0 +1,7 @@ +from mlflow.utils import logging_utils + +from mlflow_cratedb.monkey import patch_all + +# Enable logging, and activate monkeypatch. 
+logging_utils.enable_logging() +patch_all() diff --git a/mlflow_cratedb/adapter/__init__.py b/mlflow_cratedb/adapter/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mlflow_cratedb/adapter/db.py b/mlflow_cratedb/adapter/db.py new file mode 100644 index 0000000..a7f4639 --- /dev/null +++ b/mlflow_cratedb/adapter/db.py @@ -0,0 +1,30 @@ +import importlib.resources + +import sqlalchemy as sa +import sqlparse + + +def _setup_db_create_tables(engine: sa.Engine): + """ + Because CrateDB does not play well with a full-fledged SQLAlchemy data model and + corresponding Alembic migrations, shortcut that and replace it with a classic + database schema provisioning based on SQL DDL. + + It will cause additional maintenance, but well, c'est la vie. + + TODO: Currently, the path is hardcoded to `cratedb.sql`. + """ + schema_name = engine.url.query.get("schema") + with importlib.resources.path("mlflow_cratedb.adapter", "ddl") as ddl: + schema = ddl.joinpath("cratedb.sql") + sql_statements = schema.read_text().format(schema_name=schema_name) + with engine.connect() as connection: + for statement in sqlparse.split(sql_statements): + connection.execute(sa.text(statement)) + + +def _setup_db_drop_tables(): + """ + TODO: Not implemented yet. + """ + pass diff --git a/mlflow_cratedb/adapter/ddl/cratedb.sql b/mlflow_cratedb/adapter/ddl/cratedb.sql new file mode 100644 index 0000000..2efd362 --- /dev/null +++ b/mlflow_cratedb/adapter/ddl/cratedb.sql @@ -0,0 +1,137 @@ +CREATE TABLE IF NOT EXISTS "{schema_name}"."datasets" ( + "dataset_uuid" TEXT NOT NULL, + "experiment_id" BIGINT NOT NULL, + "name" TEXT NOT NULL, + "digest" TEXT NOT NULL, + "dataset_source_type" TEXT NOT NULL, + "dataset_source" TEXT NOT NULL, + "dataset_schema" TEXT, + "dataset_profile" TEXT, + PRIMARY KEY ("experiment_id", "name", "digest") +); + +CREATE TABLE IF NOT EXISTS "{schema_name}"."experiment_tags" ( + "key" TEXT NOT NULL, + "value" TEXT, + "experiment_id" BIGINT NOT NULL, + PRIMARY KEY ("key", "experiment_id") +); + +CREATE TABLE IF NOT EXISTS "{schema_name}"."experiments" ( + "experiment_id" BIGINT NOT NULL, -- default=autoincrement + "name" TEXT NOT NULL, + "artifact_location" TEXT, + "lifecycle_stage" TEXT, + "creation_time" BIGINT, -- default=get_current_time_millis + "last_update_time" BIGINT, -- default=get_current_time_millis + PRIMARY KEY ("experiment_id") +); + +CREATE TABLE IF NOT EXISTS "{schema_name}"."inputs" ( + "input_uuid" TEXT NOT NULL, + "source_type" TEXT NOT NULL, + "source_id" TEXT NOT NULL, + "destination_type" TEXT NOT NULL, + "destination_id" TEXT NOT NULL, + PRIMARY KEY ("source_type", "source_id", "destination_type", "destination_id") +); + +CREATE TABLE IF NOT EXISTS "{schema_name}"."input_tags" ( + "input_uuid" TEXT NOT NULL, + "name" TEXT NOT NULL, + "value" TEXT NOT NULL, + PRIMARY KEY ("input_uuid", "name") +); + +CREATE TABLE IF NOT EXISTS "{schema_name}"."latest_metrics" ( + "key" TEXT NOT NULL, + "value" REAL NOT NULL, + "timestamp" BIGINT NOT NULL, + "step" BIGINT NOT NULL, + "is_nan" BOOLEAN NOT NULL, + "run_uuid" TEXT NOT NULL, + PRIMARY KEY ("key", "run_uuid") +); + +CREATE TABLE IF NOT EXISTS "{schema_name}"."metrics" ( + "key" TEXT NOT NULL, + "value" REAL NOT NULL, + "timestamp" BIGINT NOT NULL, + "step" BIGINT NOT NULL, + "is_nan" BOOLEAN NOT NULL, + "run_uuid" TEXT NOT NULL, + PRIMARY KEY ("key", "timestamp", "step", "run_uuid", "is_nan") +); + +CREATE TABLE IF NOT EXISTS "{schema_name}"."model_versions" ( + "name" TEXT NOT NULL, + "version" INTEGER NOT NULL, + 
"creation_time" BIGINT, -- default=get_current_time_millis + "last_update_time" BIGINT, -- default=get_current_time_millis + "description" TEXT, + "user_id" TEXT, + "current_stage" TEXT, + "source" TEXT, + "run_id" TEXT, + "run_link" TEXT, + "status" TEXT, + "status_message" TEXT +); + +CREATE TABLE IF NOT EXISTS "{schema_name}"."model_version_tags" ( + "name" TEXT NOT NULL, + "version" INTEGER NOT NULL, + "key" TEXT NOT NULL, + "value" TEXT +); + +CREATE TABLE IF NOT EXISTS "{schema_name}"."params" ( + "key" TEXT NOT NULL, + "value" TEXT NOT NULL, + "run_uuid" TEXT NOT NULL, + PRIMARY KEY ("key", "run_uuid") +); + +CREATE TABLE IF NOT EXISTS "{schema_name}"."registered_models" ( + "name" TEXT NOT NULL, + "key" TEXT NOT NULL, + "value" TEXT +); + +CREATE TABLE IF NOT EXISTS "{schema_name}"."registered_model_aliases" ( + "name" TEXT NOT NULL, + "alias" TEXT NOT NULL, + "version" TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS "{schema_name}"."registered_model_tags" ( + "name" TEXT NOT NULL, + "creation_time" BIGINT, -- default=get_current_time_millis + "last_update_time" BIGINT, -- default=get_current_time_millis + "description" TEXT +); + +CREATE TABLE IF NOT EXISTS "{schema_name}"."runs" ( + "run_uuid" TEXT NOT NULL, + "name" TEXT, + "source_type" TEXT, + "source_name" TEXT, + "entry_point_name" TEXT, + "user_id" TEXT, + "status" TEXT, + "start_time" BIGINT, + "end_time" BIGINT, + "deleted_time" BIGINT, + "source_version" TEXT, + "lifecycle_stage" TEXT, + "artifact_uri" TEXT, + "experiment_id" BIGINT, + PRIMARY KEY ("run_uuid") +); + +CREATE TABLE IF NOT EXISTS "{schema_name}"."tags" ( + "key" TEXT NOT NULL, + "value" TEXT, + "run_uuid" TEXT NOT NULL, + PRIMARY KEY ("key", "run_uuid") +); diff --git a/mlflow_cratedb/adapter/ddl/drop.sql b/mlflow_cratedb/adapter/ddl/drop.sql new file mode 100644 index 0000000..a3ca17b --- /dev/null +++ b/mlflow_cratedb/adapter/ddl/drop.sql @@ -0,0 +1,15 @@ +DROP TABLE IF EXISTS "{schema_name}"."datasets"; +DROP TABLE IF EXISTS "{schema_name}"."experiment_tags"; +DROP TABLE IF EXISTS "{schema_name}"."experiments"; +DROP TABLE IF EXISTS "{schema_name}"."inputs"; +DROP TABLE IF EXISTS "{schema_name}"."input_tags"; +DROP TABLE IF EXISTS "{schema_name}"."latest_metrics"; +DROP TABLE IF EXISTS "{schema_name}"."metrics"; +DROP TABLE IF EXISTS "{schema_name}"."model_versions"; +DROP TABLE IF EXISTS "{schema_name}"."model_version_tags"; +DROP TABLE IF EXISTS "{schema_name}"."params"; +DROP TABLE IF EXISTS "{schema_name}"."registered_models"; +DROP TABLE IF EXISTS "{schema_name}"."registered_model_aliases"; +DROP TABLE IF EXISTS "{schema_name}"."registered_model_tags"; +DROP TABLE IF EXISTS "{schema_name}"."runs"; +DROP TABLE IF EXISTS "{schema_name}"."tags"; diff --git a/mlflow_cratedb/adapter/util.py b/mlflow_cratedb/adapter/util.py new file mode 100644 index 0000000..8f4cf65 --- /dev/null +++ b/mlflow_cratedb/adapter/util.py @@ -0,0 +1,8 @@ +from vasuki import generate_nagamani19_int + + +def generate_unique_integer() -> int: + """ + Produce a short, unique, non-sequential identifier based on Hashids. + """ + return generate_nagamani19_int(size=10) diff --git a/mlflow_cratedb/cli.py b/mlflow_cratedb/cli.py new file mode 100644 index 0000000..2fbb14c --- /dev/null +++ b/mlflow_cratedb/cli.py @@ -0,0 +1,2 @@ +# Intercept CLI entrypoint for monkeypatching. 
+from mlflow.cli import cli  # noqa: F401
diff --git a/mlflow_cratedb/monkey/__init__.py b/mlflow_cratedb/monkey/__init__.py
new file mode 100644
index 0000000..f1c0b8f
--- /dev/null
+++ b/mlflow_cratedb/monkey/__init__.py
@@ -0,0 +1,25 @@
+import logging
+
+from mlflow_cratedb.monkey.db_types import patch_dbtypes
+from mlflow_cratedb.monkey.db_utils import patch_db_utils
+from mlflow_cratedb.monkey.environment_variables import patch_environment_variables
+from mlflow_cratedb.monkey.models import patch_models
+from mlflow_cratedb.monkey.server import patch_run_server
+from mlflow_cratedb.monkey.tracking import patch_sqlalchemy_store
+
+logger = logging.getLogger("mlflow")
+
+ANSI_YELLOW = "\033[93m"
+ANSI_RESET = "\033[0m"
+
+
+def patch_all():
+    logger.info(f"{ANSI_YELLOW}Amalgamating MLflow for CrateDB{ANSI_RESET}")
+    logger.debug("To undo that, run `pip uninstall mlflow-cratedb`")
+
+    patch_environment_variables()
+    patch_models()
+    patch_sqlalchemy_store()
+    patch_dbtypes()
+    patch_db_utils()
+    patch_run_server()
diff --git a/mlflow_cratedb/monkey/db_types.py b/mlflow_cratedb/monkey/db_types.py
new file mode 100644
index 0000000..5a1a4ef
--- /dev/null
+++ b/mlflow_cratedb/monkey/db_types.py
@@ -0,0 +1,10 @@
+def patch_dbtypes():
+    """
+    Register CrateDB as available database type.
+    """
+    import mlflow.store.db.db_types as db_types
+
+    db_types.CRATEDB = "crate"
+
+    if db_types.CRATEDB not in db_types.DATABASE_ENGINES:
+        db_types.DATABASE_ENGINES.append(db_types.CRATEDB)
diff --git a/mlflow_cratedb/monkey/db_utils.py b/mlflow_cratedb/monkey/db_utils.py
new file mode 100644
index 0000000..0e6ff6d
--- /dev/null
+++ b/mlflow_cratedb/monkey/db_utils.py
@@ -0,0 +1,50 @@
+import typing as t
+
+import sqlalchemy as sa
+
+
+def patch_db_utils():
+    import mlflow.store.db.utils as db_utils
+
+    db_utils._initialize_tables = _initialize_tables
+    db_utils._verify_schema = _verify_schema
+
+
+def _initialize_tables(engine: sa.Engine):
+    """
+    Skip SQLAlchemy schema provisioning and Alembic migrations.
+    Both don't play well with CrateDB.
+    """
+    from mlflow.store.db.utils import _logger
+
+    from mlflow_cratedb.adapter.db import _setup_db_create_tables
+
+    patch_sqlalchemy_inspector(engine)
+    _logger.info("Creating initial MLflow database tables...")
+    _setup_db_create_tables(engine)
+
+
+def _verify_schema(engine: sa.Engine):
+    """
+    Skipping Alembic, that's a no-op.
+    """
+    pass
+
+
+def patch_sqlalchemy_inspector(engine: sa.Engine):
+    """
+    When using `get_table_names()`, make sure the correct schema name gets used.
+
+    TODO: Submit this to SQLAlchemy?
+    """
+    get_table_names_dist = engine.dialect.get_table_names
+    schema_name = engine.url.query.get("schema")
+    if isinstance(schema_name, tuple):
+        schema_name = schema_name[0]
+
+    def get_table_names(connection: sa.Connection, schema: t.Optional[str] = None, **kw: t.Any) -> t.List[str]:
+        if schema is None:
+            schema = schema_name
+        return get_table_names_dist(connection=connection, schema=schema, **kw)
+
+    engine.dialect.get_table_names = get_table_names  # type: ignore
diff --git a/mlflow_cratedb/monkey/environment_variables.py b/mlflow_cratedb/monkey/environment_variables.py
new file mode 100644
index 0000000..3da4d72
--- /dev/null
+++ b/mlflow_cratedb/monkey/environment_variables.py
@@ -0,0 +1,8 @@
+def patch_environment_variables():
+    """
+    Disable HTTP request retries, so that requests are not sent multiple times when the connection is unstable.
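Editor's note: For illustration, a minimal sketch of how the patched setting can be inspected. It assumes `patch_environment_variables()` has already run, and that `MLFLOW_HTTP_REQUEST_MAX_RETRIES` is not set in the process environment; `get()` is the regular accessor of MLflow's `_EnvironmentVariable` helper.

```python
# Editor's sketch, not part of the patch: inspect the patched default.
import mlflow.environment_variables as envvars

# With no environment override present, the replaced variable now
# resolves to 0, effectively disabling HTTP request retries.
assert envvars.MLFLOW_HTTP_REQUEST_MAX_RETRIES.get() == 0
```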
+ """ + import mlflow.environment_variables as envvars + from mlflow.environment_variables import _EnvironmentVariable + + envvars.MLFLOW_HTTP_REQUEST_MAX_RETRIES = _EnvironmentVariable("MLFLOW_HTTP_REQUEST_MAX_RETRIES", int, 0) diff --git a/mlflow_cratedb/monkey/models.py b/mlflow_cratedb/monkey/models.py new file mode 100644 index 0000000..814bdb5 --- /dev/null +++ b/mlflow_cratedb/monkey/models.py @@ -0,0 +1,29 @@ +from abc import ABC + +from mlflow_cratedb.adapter.util import generate_unique_integer + + +def patch_models(): + """ + Configure SQLAlchemy model columns with an alternative to `autoincrement=True`. + + In this case, use a random identifier: Nagamani19, a short, unique, + non-sequential identifier based on Hashids. + """ + import sqlalchemy as sa + import sqlalchemy.sql.schema as schema + + ColumnDist: type = schema.Column + + class Column(ColumnDist, ABC): + inherit_cache = False + + def __init__(self, *args, **kwargs): + if "autoincrement" in kwargs: + del kwargs["autoincrement"] + if "default" not in kwargs: + kwargs["default"] = generate_unique_integer + ColumnDist.__init__(self, *args, **kwargs) # type: ignore + + schema.Column = Column # type: ignore + sa.Column = Column # type: ignore diff --git a/mlflow_cratedb/monkey/server.py b/mlflow_cratedb/monkey/server.py new file mode 100644 index 0000000..4ec15f1 --- /dev/null +++ b/mlflow_cratedb/monkey/server.py @@ -0,0 +1,34 @@ +# Use another WSGI application entrypoint instead of `mlflow.server:app`. +# It is defined in `pyproject.toml` at `[project.entry-points."mlflow.app"]`. +MLFLOW_APP_NAME = "mlflow-cratedb" + + +def patch_run_server(): + """ + Intercept `mlflow.server._run_server`, and set `--app-name` to + a wrapper application. This is needed to run the monkeypatching also + within the gunicorn workers. + """ + import mlflow.server as server + + _run_server_dist = server._run_server + + def run_server(*args, **kwargs): + args_dict = _get_args_dict(_run_server_dist, args, kwargs) + args_effective = list(args) + if "app_name" in args_dict and args_dict["app_name"] is None: + args_effective.pop() + kwargs["app_name"] = MLFLOW_APP_NAME + return _run_server_dist(*args_effective, **kwargs) + + server._run_server = run_server + + +def _get_args_dict(fn, args, kwargs): + """ + Returns a dictionary containing both args and kwargs. + + https://stackoverflow.com/a/40363565 + """ + args_names = fn.__code__.co_varnames[: fn.__code__.co_argcount] + return {**dict(zip(args_names, args)), **kwargs} diff --git a/mlflow_cratedb/monkey/tracking.py b/mlflow_cratedb/monkey/tracking.py new file mode 100644 index 0000000..03072a9 --- /dev/null +++ b/mlflow_cratedb/monkey/tracking.py @@ -0,0 +1,56 @@ +def patch_sqlalchemy_store(): + from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore + + SqlAlchemyStore.create_experiment = create_experiment + + +def create_experiment(self, name, artifact_location=None, tags=None): + """ + MLflow's `create_experiment`, but with a synchronization patch for CrateDB. + It is annotated with "Patch begin|end" in the code below. 
+ """ + import sqlalchemy + from mlflow import MlflowException + from mlflow.entities import LifecycleStage + from mlflow.protos.databricks_pb2 import RESOURCE_ALREADY_EXISTS + from mlflow.store.tracking.dbmodels.models import SqlExperiment, SqlExperimentTag + from mlflow.utils.time_utils import get_current_time_millis + from mlflow.utils.uri import resolve_uri_if_local + from mlflow.utils.validation import _validate_experiment_name + + _validate_experiment_name(name) + if artifact_location: + artifact_location = resolve_uri_if_local(artifact_location) + with self.ManagedSessionMaker() as session: + try: + creation_time = get_current_time_millis() + experiment = SqlExperiment( + name=name, + lifecycle_stage=LifecycleStage.ACTIVE, + artifact_location=artifact_location, + creation_time=creation_time, + last_update_time=creation_time, + ) + experiment.tags = [SqlExperimentTag(key=tag.key, value=tag.value) for tag in tags] if tags else [] + session.add(experiment) + + # Patch begin. + # TODO: Submit upstream? + session.flush() + + # TODO: This is specific to CrateDB. Implement as an SQLAlchemy hook in some way? + session.execute(sqlalchemy.text(f"REFRESH TABLE {SqlExperiment.__tablename__};")) + # Patch end. + + if not artifact_location: + # this requires a double write. The first one to generate an autoincrement-ed ID + eid = session.query(SqlExperiment).filter_by(name=name).first().experiment_id + experiment.artifact_location = self._get_artifact_location(eid) + except sqlalchemy.exc.IntegrityError as e: + raise MlflowException( + f"Experiment(name={name}) already exists. Error: {e}", + RESOURCE_ALREADY_EXISTS, + ) from e + + session.flush() + return str(experiment.experiment_id) diff --git a/mlflow_cratedb/server.py b/mlflow_cratedb/server.py new file mode 100644 index 0000000..58e83e2 --- /dev/null +++ b/mlflow_cratedb/server.py @@ -0,0 +1,2 @@ +# Intercept server entrypoint for monkeypatching. +from mlflow.server import app # noqa: F401 diff --git a/pyproject.toml b/pyproject.toml index 0d3fe58..3f0112b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,11 +12,11 @@ readme = "README.md" requires-python = ">=3.8" license = {text = "Apache License 2.0"} keywords = [ - "mlflow", "cratedb", - "mlops", - "mlflow-tracking", "machine learning", + "mlflow", + "mlflow-tracking", + "mlops", ] authors = [ {name = "Andreas Motl", email = "andreas.motl@crate.io"}, @@ -48,9 +48,9 @@ classifiers = [ dependencies = [ "crash", "crate[sqlalchemy]", - "mlflow==2.6.0", + "mlflow==2.6", "sqlparse<0.5", - "vasuki>=0.4,<1", + "vasuki<1,>=0.4", ] [project.optional-dependencies] @@ -71,6 +71,10 @@ test = [ "coverage<8", "pytest<8", ] +[project.scripts] +mlflow-cratedb = "mlflow_cratedb.cli:cli" +[project.entry-points."mlflow.app"] +mlflow-cratedb = "mlflow_cratedb.server:app" [tool.setuptools] # https://setuptools.pypa.io/en/latest/userguide/package_discovery.html packages = ["mlflow_cratedb"] @@ -80,7 +84,6 @@ changelog = "https://github.com/crate-workbench/mlflow-cratedb/blob/main/CHANGES documentation = "https://github.com/crate-workbench/mlflow-cratedb" homepage = "https://github.com/crate-workbench/mlflow-cratedb" repository = "https://github.com/crate-workbench/mlflow-cratedb" - [tool.black] line-length = 120 From 93af416650d65ed1051e1c72b1458687d5c16cf0 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Thu, 7 Sep 2023 22:11:20 +0200 Subject: [PATCH 02/18] Documentation: Recommend to use a dedicated database schema, e.g. 
mlflow

---
 README.md | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index cb9c937..2ef5dc3 100644
--- a/README.md
+++ b/README.md
@@ -29,12 +29,16 @@ docker run --rm -it --publish=4200:4200 --publish=5432:5432 \
 Start the MLflow server, pointing it to your [CrateDB] instance,
 running on `localhost`.
 ```shell
-mlflow-cratedb server --backend-store-uri='crate://crate@localhost' --dev
+mlflow-cratedb server --backend-store-uri='crate://crate@localhost/?schema=mlflow' --dev
 ```
 
 Please note that you need to invoke the `mlflow-cratedb` command, which
 runs MLflow amalgamated with the necessary changes to support CrateDB.
 
+Also note that we recommend using a dedicated schema for storing MLflow's
+tables. That way, the default schema `"doc"` is not populated with
+tables of third-party systems.
+
 
 ## Development
 

From 4c5891ba3083f87085dc8539c9dcea31ab305610 Mon Sep 17 00:00:00 2001
From: Andreas Motl
Date: Thu, 7 Sep 2023 22:12:17 +0200
Subject: [PATCH 03/18] Improve running without dedicated database schema
 (None)

---
 mlflow_cratedb/adapter/db.py           |  7 ++++--
 mlflow_cratedb/adapter/ddl/cratedb.sql | 30 +++++++++++++-------------
 mlflow_cratedb/adapter/ddl/drop.sql    | 30 +++++++++++++-------------
 3 files changed, 35 insertions(+), 32 deletions(-)

diff --git a/mlflow_cratedb/adapter/db.py b/mlflow_cratedb/adapter/db.py
index a7f4639..25398b2 100644
--- a/mlflow_cratedb/adapter/db.py
+++ b/mlflow_cratedb/adapter/db.py
@@ -15,9 +15,12 @@ def _setup_db_create_tables(engine: sa.Engine):
     TODO: Currently, the path is hardcoded to `cratedb.sql`.
     """
     schema_name = engine.url.query.get("schema")
+    schema_prefix = ""
+    if schema_name is not None:
+        schema_prefix = f'"{schema_name}".'
     with importlib.resources.path("mlflow_cratedb.adapter", "ddl") as ddl:
-        schema = ddl.joinpath("cratedb.sql")
-        sql_statements = schema.read_text().format(schema_name=schema_name)
+        sql_file = ddl.joinpath("cratedb.sql")
+        sql_statements = sql_file.read_text().format(schema_prefix=schema_prefix)
     with engine.connect() as connection:
         for statement in sqlparse.split(sql_statements):
             connection.execute(sa.text(statement))
diff --git a/mlflow_cratedb/adapter/ddl/cratedb.sql b/mlflow_cratedb/adapter/ddl/cratedb.sql
index 2efd362..a7f1be6 100644
--- a/mlflow_cratedb/adapter/ddl/cratedb.sql
+++ b/mlflow_cratedb/adapter/ddl/cratedb.sql
@@ -1,4 +1,4 @@
-CREATE TABLE IF NOT EXISTS "{schema_name}"."datasets" (
+CREATE TABLE IF NOT EXISTS {schema_prefix}"datasets" (
     "dataset_uuid" TEXT NOT NULL,
     "experiment_id" BIGINT NOT NULL,
     "name" TEXT NOT NULL,
@@ -10,14 +10,14 @@
     PRIMARY KEY ("experiment_id", "name", "digest")
 );
 
-CREATE TABLE IF NOT EXISTS "{schema_name}"."experiment_tags" (
+CREATE TABLE IF NOT EXISTS {schema_prefix}"experiment_tags" (
     "key" TEXT NOT NULL,
     "value" TEXT,
     "experiment_id" BIGINT NOT NULL,
     PRIMARY KEY ("key", "experiment_id")
 );
 
-CREATE TABLE IF NOT EXISTS "{schema_name}"."experiments" (
+CREATE TABLE IF NOT EXISTS {schema_prefix}"experiments" (
     "experiment_id" BIGINT NOT NULL, -- default=autoincrement
     "name" TEXT NOT NULL,
     "artifact_location" TEXT,
@@ -27,7 +27,7 @@
     PRIMARY KEY ("experiment_id")
 );
 
-CREATE TABLE IF NOT EXISTS "{schema_name}"."inputs" (
+CREATE TABLE IF NOT EXISTS {schema_prefix}"inputs" (
     "input_uuid" TEXT NOT NULL,
     "source_type" TEXT NOT NULL,
     "source_id" TEXT NOT NULL,
     "destination_type" TEXT NOT NULL,
     "destination_id" TEXT NOT NULL,
PRIMARY KEY ("source_type", "source_id", "destination_type", "destination_id") ); -CREATE TABLE IF NOT EXISTS "{schema_name}"."input_tags" ( +CREATE TABLE IF NOT EXISTS {schema_prefix}"input_tags" ( "input_uuid" TEXT NOT NULL, "name" TEXT NOT NULL, "value" TEXT NOT NULL, PRIMARY KEY ("input_uuid", "name") ); -CREATE TABLE IF NOT EXISTS "{schema_name}"."latest_metrics" ( +CREATE TABLE IF NOT EXISTS {schema_prefix}"latest_metrics" ( "key" TEXT NOT NULL, "value" REAL NOT NULL, "timestamp" BIGINT NOT NULL, @@ -53,7 +53,7 @@ CREATE TABLE IF NOT EXISTS "{schema_name}"."latest_metrics" ( PRIMARY KEY ("key", "run_uuid") ); -CREATE TABLE IF NOT EXISTS "{schema_name}"."metrics" ( +CREATE TABLE IF NOT EXISTS {schema_prefix}"metrics" ( "key" TEXT NOT NULL, "value" REAL NOT NULL, "timestamp" BIGINT NOT NULL, @@ -63,7 +63,7 @@ CREATE TABLE IF NOT EXISTS "{schema_name}"."metrics" ( PRIMARY KEY ("key", "timestamp", "step", "run_uuid", "is_nan") ); -CREATE TABLE IF NOT EXISTS "{schema_name}"."model_versions" ( +CREATE TABLE IF NOT EXISTS {schema_prefix}"model_versions" ( "name" TEXT NOT NULL, "version" INTEGER NOT NULL, "creation_time" BIGINT, -- default=get_current_time_millis @@ -78,40 +78,40 @@ CREATE TABLE IF NOT EXISTS "{schema_name}"."model_versions" ( "status_message" TEXT ); -CREATE TABLE IF NOT EXISTS "{schema_name}"."model_version_tags" ( +CREATE TABLE IF NOT EXISTS {schema_prefix}"model_version_tags" ( "name" TEXT NOT NULL, "version" INTEGER NOT NULL, "key" TEXT NOT NULL, "value" TEXT ); -CREATE TABLE IF NOT EXISTS "{schema_name}"."params" ( +CREATE TABLE IF NOT EXISTS {schema_prefix}"params" ( "key" TEXT NOT NULL, "value" TEXT NOT NULL, "run_uuid" TEXT NOT NULL, PRIMARY KEY ("key", "run_uuid") ); -CREATE TABLE IF NOT EXISTS "{schema_name}"."registered_models" ( +CREATE TABLE IF NOT EXISTS {schema_prefix}"registered_models" ( "name" TEXT NOT NULL, "key" TEXT NOT NULL, "value" TEXT ); -CREATE TABLE IF NOT EXISTS "{schema_name}"."registered_model_aliases" ( +CREATE TABLE IF NOT EXISTS {schema_prefix}"registered_model_aliases" ( "name" TEXT NOT NULL, "alias" TEXT NOT NULL, "version" TEXT NOT NULL ); -CREATE TABLE IF NOT EXISTS "{schema_name}"."registered_model_tags" ( +CREATE TABLE IF NOT EXISTS {schema_prefix}"registered_model_tags" ( "name" TEXT NOT NULL, "creation_time" BIGINT, -- default=get_current_time_millis "last_update_time" BIGINT, -- default=get_current_time_millis "description" TEXT ); -CREATE TABLE IF NOT EXISTS "{schema_name}"."runs" ( +CREATE TABLE IF NOT EXISTS {schema_prefix}"runs" ( "run_uuid" TEXT NOT NULL, "name" TEXT, "source_type" TEXT, @@ -129,7 +129,7 @@ CREATE TABLE IF NOT EXISTS "{schema_name}"."runs" ( PRIMARY KEY ("run_uuid") ); -CREATE TABLE IF NOT EXISTS "{schema_name}"."tags" ( +CREATE TABLE IF NOT EXISTS {schema_prefix}"tags" ( "key" TEXT NOT NULL, "value" TEXT, "run_uuid" TEXT NOT NULL, diff --git a/mlflow_cratedb/adapter/ddl/drop.sql b/mlflow_cratedb/adapter/ddl/drop.sql index a3ca17b..5c73c22 100644 --- a/mlflow_cratedb/adapter/ddl/drop.sql +++ b/mlflow_cratedb/adapter/ddl/drop.sql @@ -1,15 +1,15 @@ -DROP TABLE IF EXISTS "{schema_name}"."datasets"; -DROP TABLE IF EXISTS "{schema_name}"."experiment_tags"; -DROP TABLE IF EXISTS "{schema_name}"."experiments"; -DROP TABLE IF EXISTS "{schema_name}"."inputs"; -DROP TABLE IF EXISTS "{schema_name}"."input_tags"; -DROP TABLE IF EXISTS "{schema_name}"."latest_metrics"; -DROP TABLE IF EXISTS "{schema_name}"."metrics"; -DROP TABLE IF EXISTS "{schema_name}"."model_versions"; -DROP TABLE IF EXISTS 
"{schema_name}"."model_version_tags"; -DROP TABLE IF EXISTS "{schema_name}"."params"; -DROP TABLE IF EXISTS "{schema_name}"."registered_models"; -DROP TABLE IF EXISTS "{schema_name}"."registered_model_aliases"; -DROP TABLE IF EXISTS "{schema_name}"."registered_model_tags"; -DROP TABLE IF EXISTS "{schema_name}"."runs"; -DROP TABLE IF EXISTS "{schema_name}"."tags"; +DROP TABLE IF EXISTS {schema_prefix}"datasets"; +DROP TABLE IF EXISTS {schema_prefix}"experiment_tags"; +DROP TABLE IF EXISTS {schema_prefix}"experiments"; +DROP TABLE IF EXISTS {schema_prefix}"inputs"; +DROP TABLE IF EXISTS {schema_prefix}"input_tags"; +DROP TABLE IF EXISTS {schema_prefix}"latest_metrics"; +DROP TABLE IF EXISTS {schema_prefix}"metrics"; +DROP TABLE IF EXISTS {schema_prefix}"model_versions"; +DROP TABLE IF EXISTS {schema_prefix}"model_version_tags"; +DROP TABLE IF EXISTS {schema_prefix}"params"; +DROP TABLE IF EXISTS {schema_prefix}"registered_models"; +DROP TABLE IF EXISTS {schema_prefix}"registered_model_aliases"; +DROP TABLE IF EXISTS {schema_prefix}"registered_model_tags"; +DROP TABLE IF EXISTS {schema_prefix}"runs"; +DROP TABLE IF EXISTS {schema_prefix}"tags"; From 14d7c0d80f6686e4f313e03ae3ce5c02c108aace Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Thu, 7 Sep 2023 22:13:28 +0200 Subject: [PATCH 04/18] Run database provisioning only once per process instance --- mlflow_cratedb/monkey/db_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlflow_cratedb/monkey/db_utils.py b/mlflow_cratedb/monkey/db_utils.py index 0e6ff6d..5c22be4 100644 --- a/mlflow_cratedb/monkey/db_utils.py +++ b/mlflow_cratedb/monkey/db_utils.py @@ -1,3 +1,4 @@ +import functools import typing as t import sqlalchemy as sa @@ -10,6 +11,7 @@ def patch_db_utils(): db_utils._verify_schema = _verify_schema +@functools.cache def _initialize_tables(engine: sa.Engine): """ Skip SQLAlchemy schema provisioning and Alembic migrations. From fb8c7404cf6d695fccd747be5ccdc2ab0860085d Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Thu, 7 Sep 2023 22:17:19 +0200 Subject: [PATCH 05/18] Transparently invoke `REFRESH TABLE` after inserts, updates, and deletes Remove previous strategy, where hooking into the corresponding database wrapper function was extremely invasive, and not sustainable. --- mlflow_cratedb/adapter/db.py | 28 ++++++++++++++++ mlflow_cratedb/monkey/__init__.py | 2 -- mlflow_cratedb/monkey/db_utils.py | 3 ++ mlflow_cratedb/monkey/tracking.py | 56 ------------------------------- 4 files changed, 31 insertions(+), 58 deletions(-) delete mode 100644 mlflow_cratedb/monkey/tracking.py diff --git a/mlflow_cratedb/adapter/db.py b/mlflow_cratedb/adapter/db.py index 25398b2..000e9b9 100644 --- a/mlflow_cratedb/adapter/db.py +++ b/mlflow_cratedb/adapter/db.py @@ -2,6 +2,7 @@ import sqlalchemy as sa import sqlparse +from sqlalchemy.event import listen def _setup_db_create_tables(engine: sa.Engine): @@ -31,3 +32,30 @@ def _setup_db_drop_tables(): TODO: Not implemented yet. """ pass + + +def enable_refresh_after_dml(): + """ + Run `REFRESH TABLE ` after each INSERT, UPDATE, and DELETE operation. + + CrateDB is eventually consistent, i.e. write operations are not flushed to + disk immediately, so readers may see stale data. In a traditional OLTP-like + application, this is not applicable. + + This SQLAlchemy extension makes sure that data is synchronized after each + operation manipulating data. 
+ """ + from mlflow.store.db.base_sql_model import Base + + for mapper in Base.registry.mappers: + listen(mapper.class_, "after_insert", do_refresh) + listen(mapper.class_, "after_update", do_refresh) + listen(mapper.class_, "after_delete", do_refresh) + + +def do_refresh(mapper, connection, target): + """ + SQLAlchemy event handler for `after_{insert,update,delete}` events, invoking `REFRESH TABLE`. + """ + sql = f"REFRESH TABLE {target.__tablename__}" + connection.execute(sa.text(sql)) diff --git a/mlflow_cratedb/monkey/__init__.py b/mlflow_cratedb/monkey/__init__.py index f1c0b8f..5390fd6 100644 --- a/mlflow_cratedb/monkey/__init__.py +++ b/mlflow_cratedb/monkey/__init__.py @@ -5,7 +5,6 @@ from mlflow_cratedb.monkey.environment_variables import patch_environment_variables from mlflow_cratedb.monkey.models import patch_models from mlflow_cratedb.monkey.server import patch_run_server -from mlflow_cratedb.monkey.tracking import patch_sqlalchemy_store logger = logging.getLogger("mlflow") @@ -19,7 +18,6 @@ def patch_all(): patch_environment_variables() patch_models() - patch_sqlalchemy_store() patch_dbtypes() patch_db_utils() patch_run_server() diff --git a/mlflow_cratedb/monkey/db_utils.py b/mlflow_cratedb/monkey/db_utils.py index 5c22be4..eca26fa 100644 --- a/mlflow_cratedb/monkey/db_utils.py +++ b/mlflow_cratedb/monkey/db_utils.py @@ -3,6 +3,8 @@ import sqlalchemy as sa +from mlflow_cratedb.adapter.db import enable_refresh_after_dml + def patch_db_utils(): import mlflow.store.db.utils as db_utils @@ -21,6 +23,7 @@ def _initialize_tables(engine: sa.Engine): from mlflow_cratedb.adapter.db import _setup_db_create_tables + enable_refresh_after_dml() patch_sqlalchemy_inspector(engine) _logger.info("Creating initial MLflow database tables...") _setup_db_create_tables(engine) diff --git a/mlflow_cratedb/monkey/tracking.py b/mlflow_cratedb/monkey/tracking.py deleted file mode 100644 index 03072a9..0000000 --- a/mlflow_cratedb/monkey/tracking.py +++ /dev/null @@ -1,56 +0,0 @@ -def patch_sqlalchemy_store(): - from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore - - SqlAlchemyStore.create_experiment = create_experiment - - -def create_experiment(self, name, artifact_location=None, tags=None): - """ - MLflow's `create_experiment`, but with a synchronization patch for CrateDB. - It is annotated with "Patch begin|end" in the code below. - """ - import sqlalchemy - from mlflow import MlflowException - from mlflow.entities import LifecycleStage - from mlflow.protos.databricks_pb2 import RESOURCE_ALREADY_EXISTS - from mlflow.store.tracking.dbmodels.models import SqlExperiment, SqlExperimentTag - from mlflow.utils.time_utils import get_current_time_millis - from mlflow.utils.uri import resolve_uri_if_local - from mlflow.utils.validation import _validate_experiment_name - - _validate_experiment_name(name) - if artifact_location: - artifact_location = resolve_uri_if_local(artifact_location) - with self.ManagedSessionMaker() as session: - try: - creation_time = get_current_time_millis() - experiment = SqlExperiment( - name=name, - lifecycle_stage=LifecycleStage.ACTIVE, - artifact_location=artifact_location, - creation_time=creation_time, - last_update_time=creation_time, - ) - experiment.tags = [SqlExperimentTag(key=tag.key, value=tag.value) for tag in tags] if tags else [] - session.add(experiment) - - # Patch begin. - # TODO: Submit upstream? - session.flush() - - # TODO: This is specific to CrateDB. Implement as an SQLAlchemy hook in some way? 
- session.execute(sqlalchemy.text(f"REFRESH TABLE {SqlExperiment.__tablename__};")) - # Patch end. - - if not artifact_location: - # this requires a double write. The first one to generate an autoincrement-ed ID - eid = session.query(SqlExperiment).filter_by(name=name).first().experiment_id - experiment.artifact_location = self._get_artifact_location(eid) - except sqlalchemy.exc.IntegrityError as e: - raise MlflowException( - f"Experiment(name={name}) already exists. Error: {e}", - RESOURCE_ALREADY_EXISTS, - ) from e - - session.flush() - return str(experiment.experiment_id) From aefc5fd803f60f7e0b41bc64b8c84a2f1a397eb2 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sat, 9 Sep 2023 12:17:05 +0200 Subject: [PATCH 06/18] Fix CrateDB database schema DDL - Use `DOUBLE` instead of `REAL` - Add missing `metrics.value` to primary key --- mlflow_cratedb/adapter/ddl/cratedb.sql | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlflow_cratedb/adapter/ddl/cratedb.sql b/mlflow_cratedb/adapter/ddl/cratedb.sql index a7f1be6..e8e8182 100644 --- a/mlflow_cratedb/adapter/ddl/cratedb.sql +++ b/mlflow_cratedb/adapter/ddl/cratedb.sql @@ -45,7 +45,7 @@ CREATE TABLE IF NOT EXISTS {schema_prefix}"input_tags" ( CREATE TABLE IF NOT EXISTS {schema_prefix}"latest_metrics" ( "key" TEXT NOT NULL, - "value" REAL NOT NULL, + "value" DOUBLE NOT NULL, "timestamp" BIGINT NOT NULL, "step" BIGINT NOT NULL, "is_nan" BOOLEAN NOT NULL, @@ -55,12 +55,12 @@ CREATE TABLE IF NOT EXISTS {schema_prefix}"latest_metrics" ( CREATE TABLE IF NOT EXISTS {schema_prefix}"metrics" ( "key" TEXT NOT NULL, - "value" REAL NOT NULL, + "value" DOUBLE NOT NULL, "timestamp" BIGINT NOT NULL, "step" BIGINT NOT NULL, "is_nan" BOOLEAN NOT NULL, "run_uuid" TEXT NOT NULL, - PRIMARY KEY ("key", "timestamp", "step", "run_uuid", "is_nan") + PRIMARY KEY ("key", "timestamp", "step", "run_uuid", "value", "is_nan") ); CREATE TABLE IF NOT EXISTS {schema_prefix}"model_versions" ( From c7ff38f2c02959e6c1b1098a1ae6f04144348930 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sat, 9 Sep 2023 12:49:52 +0200 Subject: [PATCH 07/18] SA: Improve autoincrement polyfill --- mlflow_cratedb/monkey/models.py | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/mlflow_cratedb/monkey/models.py b/mlflow_cratedb/monkey/models.py index 814bdb5..507ef74 100644 --- a/mlflow_cratedb/monkey/models.py +++ b/mlflow_cratedb/monkey/models.py @@ -1,5 +1,3 @@ -from abc import ABC - from mlflow_cratedb.adapter.util import generate_unique_integer @@ -9,21 +7,19 @@ def patch_models(): In this case, use a random identifier: Nagamani19, a short, unique, non-sequential identifier based on Hashids. + + TODO: Submit patch to `crate-python`, to be enabled by a + dialect parameter `crate_translate_autoincrement` or such. 
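Editor's note: As an illustration of the intended behavior, a sketch of what the patched constructor does. The column name is arbitrary, and the assertions reflect the expectation after `patch_models()` has been applied; they are not taken from the patch itself.

```python
# Editor's sketch, not part of the patch: behavior after patch_models().
import sqlalchemy as sa
from mlflow_cratedb.adapter.util import generate_unique_integer

# `autoincrement=True` is dropped, and a callable default is injected
# instead, producing a short, random, non-sequential integer identifier.
col = sa.Column("experiment_id", sa.BigInteger, autoincrement=True)
assert col.default is not None
assert col.default.arg is generate_unique_integer
```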
""" - import sqlalchemy as sa import sqlalchemy.sql.schema as schema - ColumnDist: type = schema.Column - - class Column(ColumnDist, ABC): - inherit_cache = False + init_dist = schema.Column.__init__ - def __init__(self, *args, **kwargs): - if "autoincrement" in kwargs: - del kwargs["autoincrement"] - if "default" not in kwargs: - kwargs["default"] = generate_unique_integer - ColumnDist.__init__(self, *args, **kwargs) # type: ignore + def __init__(self, *args, **kwargs): + if "autoincrement" in kwargs: + del kwargs["autoincrement"] + if "default" not in kwargs: + kwargs["default"] = generate_unique_integer + init_dist(self, *args, **kwargs) - schema.Column = Column # type: ignore - sa.Column = Column # type: ignore + schema.Column.__init__ = __init__ # type: ignore[method-assign] From 089f4c71103fd2aee7c254f64895477196fcbe5c Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sat, 9 Sep 2023 12:50:50 +0200 Subject: [PATCH 08/18] SA: Add patch to remove `FOR UPDATE` clauses --- mlflow_cratedb/monkey/models.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/mlflow_cratedb/monkey/models.py b/mlflow_cratedb/monkey/models.py index 507ef74..f39a05c 100644 --- a/mlflow_cratedb/monkey/models.py +++ b/mlflow_cratedb/monkey/models.py @@ -23,3 +23,20 @@ def __init__(self, *args, **kwargs): init_dist(self, *args, **kwargs) schema.Column.__init__ = __init__ # type: ignore[method-assign] + + +def patch_compiler(): + """ + Patch CrateDB SQLAlchemy dialect to not omit the `FOR UPDATE` clause on + `SELECT ... FOR UPDATE` statements. + + https://github.com/crate-workbench/mlflow-cratedb/issues/7 + + TODO: Submit to `crate-python` as a bugfix patch. + """ + from crate.client.sqlalchemy.compiler import CrateCompiler + + def for_update_clause(self, select, **kw): + return "" + + CrateCompiler.for_update_clause = for_update_clause From 4a6d6c6ecb6e6f8ba5eb7e4dd9097a7bfd44cb40 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sat, 9 Sep 2023 13:51:38 +0200 Subject: [PATCH 09/18] Add more patches and polyfills --- mlflow_cratedb/adapter/db.py | 3 + mlflow_cratedb/monkey/__init__.py | 10 +- mlflow_cratedb/monkey/db_utils.py | 3 +- mlflow_cratedb/monkey/mlflow/__init__.py | 12 ++ mlflow_cratedb/monkey/mlflow/model.py | 11 ++ mlflow_cratedb/monkey/mlflow/search_utils.py | 27 ++++ mlflow_cratedb/monkey/mlflow/tracking.py | 157 +++++++++++++++++++ mlflow_cratedb/patch/__init__.py | 0 mlflow_cratedb/patch/crate_python.py | 59 +++++++ 9 files changed, 280 insertions(+), 2 deletions(-) create mode 100644 mlflow_cratedb/monkey/mlflow/__init__.py create mode 100644 mlflow_cratedb/monkey/mlflow/model.py create mode 100644 mlflow_cratedb/monkey/mlflow/search_utils.py create mode 100644 mlflow_cratedb/monkey/mlflow/tracking.py create mode 100644 mlflow_cratedb/patch/__init__.py create mode 100644 mlflow_cratedb/patch/crate_python.py diff --git a/mlflow_cratedb/adapter/db.py b/mlflow_cratedb/adapter/db.py index 000e9b9..c040966 100644 --- a/mlflow_cratedb/adapter/db.py +++ b/mlflow_cratedb/adapter/db.py @@ -44,6 +44,9 @@ def enable_refresh_after_dml(): This SQLAlchemy extension makes sure that data is synchronized after each operation manipulating data. + + TODO: Submit patch to `crate-python`, to be enabled by a + dialect parameter `crate_dml_refresh` or such. 
""" from mlflow.store.db.base_sql_model import Base diff --git a/mlflow_cratedb/monkey/__init__.py b/mlflow_cratedb/monkey/__init__.py index 5390fd6..da3cbf7 100644 --- a/mlflow_cratedb/monkey/__init__.py +++ b/mlflow_cratedb/monkey/__init__.py @@ -3,8 +3,10 @@ from mlflow_cratedb.monkey.db_types import patch_dbtypes from mlflow_cratedb.monkey.db_utils import patch_db_utils from mlflow_cratedb.monkey.environment_variables import patch_environment_variables -from mlflow_cratedb.monkey.models import patch_models +from mlflow_cratedb.monkey.mlflow import patch_mlflow +from mlflow_cratedb.monkey.models import patch_compiler, patch_models from mlflow_cratedb.monkey.server import patch_run_server +from mlflow_cratedb.patch.crate_python import patch_raise_for_status logger = logging.getLogger("mlflow") @@ -16,8 +18,14 @@ def patch_all(): logger.info(f"{ANSI_YELLOW}Amalgamating MLflow for CrateDB{ANSI_RESET}") logger.debug("To undo that, run `pip uninstall mlflow-cratedb`") + # crate-python + patch_raise_for_status() + + # MLflow patch_environment_variables() + patch_compiler() patch_models() patch_dbtypes() patch_db_utils() + patch_mlflow() patch_run_server() diff --git a/mlflow_cratedb/monkey/db_utils.py b/mlflow_cratedb/monkey/db_utils.py index eca26fa..880d9ca 100644 --- a/mlflow_cratedb/monkey/db_utils.py +++ b/mlflow_cratedb/monkey/db_utils.py @@ -9,6 +9,7 @@ def patch_db_utils(): import mlflow.store.db.utils as db_utils + enable_refresh_after_dml() db_utils._initialize_tables = _initialize_tables db_utils._verify_schema = _verify_schema @@ -23,7 +24,6 @@ def _initialize_tables(engine: sa.Engine): from mlflow_cratedb.adapter.db import _setup_db_create_tables - enable_refresh_after_dml() patch_sqlalchemy_inspector(engine) _logger.info("Creating initial MLflow database tables...") _setup_db_create_tables(engine) @@ -40,6 +40,7 @@ def patch_sqlalchemy_inspector(engine: sa.Engine): """ When using `get_table_names()`, make sure the correct schema name gets used. + TODO: Verify if this is really needed. SQLAlchemy should use the `search_path` properly already. TODO: Submit this to SQLAlchemy? """ get_table_names_dist = engine.dialect.get_table_names diff --git a/mlflow_cratedb/monkey/mlflow/__init__.py b/mlflow_cratedb/monkey/mlflow/__init__.py new file mode 100644 index 0000000..d6d76b1 --- /dev/null +++ b/mlflow_cratedb/monkey/mlflow/__init__.py @@ -0,0 +1,12 @@ +from mlflow_cratedb.monkey.mlflow.model import polyfill_uniqueness_constraints +from mlflow_cratedb.monkey.mlflow.search_utils import patch_mlflow_search_utils +from mlflow_cratedb.monkey.mlflow.tracking import patch_mlflow_tracking + + +def patch_mlflow(): + """ + Patch the MLflow package. + """ + polyfill_uniqueness_constraints() + patch_mlflow_search_utils() + patch_mlflow_tracking() diff --git a/mlflow_cratedb/monkey/mlflow/model.py b/mlflow_cratedb/monkey/mlflow/model.py new file mode 100644 index 0000000..5480e9d --- /dev/null +++ b/mlflow_cratedb/monkey/mlflow/model.py @@ -0,0 +1,11 @@ +from mlflow_cratedb.patch.crate_python import check_uniqueness_factory + + +def polyfill_uniqueness_constraints(): + """ + Establish a manual uniqueness check on the `SqlExperiment.name` column. 
+ """ + from mlflow.store.tracking.dbmodels.models import SqlExperiment + from sqlalchemy.event import listen + + listen(SqlExperiment, "before_insert", check_uniqueness_factory(SqlExperiment, "name")) diff --git a/mlflow_cratedb/monkey/mlflow/search_utils.py b/mlflow_cratedb/monkey/mlflow/search_utils.py new file mode 100644 index 0000000..8b7dfc4 --- /dev/null +++ b/mlflow_cratedb/monkey/mlflow/search_utils.py @@ -0,0 +1,27 @@ +def patch_mlflow_search_utils(): + """ + Patch MLflow's `SearchUtils` to return a comparison function for CrateDB. + """ + from mlflow.utils.search_utils import SearchUtils + + get_sql_comparison_func_dist = SearchUtils.get_sql_comparison_func + + def get_sql_comparison_func(comparator, dialect): + try: + return get_sql_comparison_func_dist(comparator, dialect) + except KeyError: + + def comparison_func(column, value): + if comparator == "LIKE": + return column.like(value) + elif comparator == "ILIKE": # noqa: RET505 + return column.ilike(value) + elif comparator == "IN": + return column.in_(value) + elif comparator == "NOT IN": + return ~column.in_(value) + return SearchUtils.get_comparison_func(comparator)(column, value) + + return comparison_func + + SearchUtils.get_sql_comparison_func = get_sql_comparison_func diff --git a/mlflow_cratedb/monkey/mlflow/tracking.py b/mlflow_cratedb/monkey/mlflow/tracking.py new file mode 100644 index 0000000..76ad947 --- /dev/null +++ b/mlflow_cratedb/monkey/mlflow/tracking.py @@ -0,0 +1,157 @@ +import math +from functools import partial + + +def patch_mlflow_tracking(): + """ + Patch the experiment tracking subsystem of MLflow. + """ + patch_create_default_experiment() + patch_get_orderby_clauses() + patch_search_runs() + + +def patch_create_default_experiment(): + """ + The `_create_default_experiment` function runs an SQL query sidetracking + the SQLAlchemy ORM. Thus, it needs to be explicitly patched to invoke + a corresponding `REFRESH TABLE` statement afterwards. + """ + import sqlalchemy as sa + from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore + + create_default_experiment_dist = SqlAlchemyStore._create_default_experiment + + def _create_default_experiment(self, session): + from mlflow.store.tracking.dbmodels.models import SqlExperiment + + outcome = create_default_experiment_dist(self, session) + session.execute(sa.text(f"REFRESH TABLE {SqlExperiment.__tablename__}")) + return outcome + + SqlAlchemyStore._create_default_experiment = _create_default_experiment + + +def patch_get_orderby_clauses(): + """ + MLflow's `_get_orderby_clauses` adds an `sql.case(...)` clause, which CrateDB does not understand. 
+ + https://github.com/crate-workbench/mlflow-cratedb/issues/8 + """ + import mlflow.store.tracking.sqlalchemy_store as sqlalchemy_store + + _get_orderby_clauses_dist = sqlalchemy_store._get_orderby_clauses + + def filter_case_clauses(items): + new_list = [] + for item in items: + label = None + if isinstance(item, str): + label = item + elif hasattr(item, "name"): + label = item.name + if label is None or not label.startswith("clause_"): + new_list.append(item) + return new_list + + def _get_orderby_clauses(order_by_list, session): + cases_orderby, parsed_orderby, sorting_joins = _get_orderby_clauses_dist(order_by_list, session) + cases_orderby = filter_case_clauses(cases_orderby) + parsed_orderby = filter_case_clauses(parsed_orderby) + return cases_orderby, parsed_orderby, sorting_joins + + sqlalchemy_store._get_orderby_clauses = _get_orderby_clauses + + +def patch_search_runs(): + """ + Patch MLflow's `_search_runs` function to invoke `fix_sort_order` afterwards, + compensating the other patch to `_get_orderby_clauses`. + """ + from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore + + search_runs_dist = SqlAlchemyStore._search_runs + + def _search_runs(self, experiment_ids, filter_string, run_view_type, max_results, order_by, page_token): + runs_with_inputs, next_page_token = search_runs_dist( + self, experiment_ids, filter_string, run_view_type, max_results, order_by, page_token + ) + runs_with_inputs = fix_sort_order(order_by, runs_with_inputs) + return runs_with_inputs, next_page_token + + SqlAlchemyStore._search_runs = _search_runs + + +def fix_sort_order(order_by, runs_with_inputs): + """ + Attempts to fix the sort order of returned tracking results, trying to + compensate the patch to MLflow's `_get_orderby_clauses`. + + Covered by test cases `test_order_by_attributes` and `test_order_by_metric_tag_param`. + """ + import functools + + from mlflow.utils.search_utils import SearchUtils + + if order_by is None: + return runs_with_inputs + + def attribute_getter(key, item): + return getattr(item.info, key) + + def metrics_getter(key, item): + return item.data.metrics.get(key) + + def tags_getter(key, item): + return item.data.tags.get(key) + + def parameters_getter(key, item): + return item.data.params.get(key) + + def isnan(value): + return value is None or (isinstance(value, float) and math.isnan(value)) + + def compare_special(getter, i1, i2): + """ + Comparison function which can accept None or NaN values. + + Otherwise, Python would raise:: + + TypeError: '<' not supported between instances of 'NoneType' and 'int' + """ + # + i1 = getter(i1) + i2 = getter(i2) + if isnan(i1): + return 1 + if isnan(i2): + return -1 + i1 = str(i1).lower() + i2 = str(i2).lower() + return i1 < i2 + + attribute_order_count = 0 + for order_by_clause in order_by: + # Special case: When sorting using multiple attributes, something goes south. + # So, limit (proper) sorting to the use of a single attribute only. 
+ if attribute_order_count >= 1: + continue + + (key_type, key, ascending) = SearchUtils.parse_order_by_for_search_runs(order_by_clause) + key_translated = SearchUtils.translate_key_alias(key) + + if key_type == "attribute": + getter = partial(attribute_getter, key_translated) + attribute_order_count += 1 + elif key_type == "metric": + getter = partial(metrics_getter, key_translated) + elif key_type == "tag": + getter = partial(tags_getter, key_translated) + elif key_type == "parameter": + getter = partial(parameters_getter, key_translated) + else: + raise NotImplementedError(f"Need to implement getter for key_type={key_type}. clause={order_by_clause}") + + runs_with_inputs = sorted(runs_with_inputs, key=functools.cmp_to_key(partial(compare_special, getter))) + + return runs_with_inputs diff --git a/mlflow_cratedb/patch/__init__.py b/mlflow_cratedb/patch/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mlflow_cratedb/patch/crate_python.py b/mlflow_cratedb/patch/crate_python.py new file mode 100644 index 0000000..d84fbbf --- /dev/null +++ b/mlflow_cratedb/patch/crate_python.py @@ -0,0 +1,59 @@ +def patch_raise_for_status(): + """ + Patch the `crate.client.http._raise_for_status` function to properly raise + SQLAlchemy's `IntegrityError` exceptions for CrateDB's `DuplicateKeyException` + errors. + + It is needed to make the `check_uniqueness` machinery work, which is emulating + UNIQUE constraints on table columns. + + https://github.com/crate-workbench/mlflow-cratedb/issues/9 + + TODO: Submit to `crate-python` as a bugfix patch. + """ + import crate.client.http as http + + _raise_for_status_dist = http._raise_for_status + + def _raise_for_status(response): + from crate.client.exceptions import IntegrityError, ProgrammingError + + try: + return _raise_for_status_dist(response) + except ProgrammingError as ex: + if "DuplicateKeyException" in ex.message: + raise IntegrityError(ex.message, error_trace=ex.error_trace) from ex + raise + + http._raise_for_status = _raise_for_status + + +def check_uniqueness_factory(sa_entity, attribute_name): + """ + Run a manual column value uniqueness check on a table, and raise an IntegrityError if applicable. + + CrateDB does not support the UNIQUE constraint on columns. This attempts to emulate it. + + TODO: Submit patch to `crate-python`, to be enabled by a + dialect parameter `crate_translate_unique` or such. + """ + + def check_uniqueness(mapper, connection, target): + from sqlalchemy.exc import IntegrityError + + if isinstance(target, sa_entity): + # TODO: How to use `session.query(SqlExperiment)` here? + stmt = ( + mapper.selectable.select() + .filter(getattr(sa_entity, attribute_name) == getattr(target, attribute_name)) + .compile(bind=connection.engine) + ) + results = connection.execute(stmt) + if results.rowcount > 0: + raise IntegrityError( + statement=stmt, + params=[], + orig=Exception(f"DuplicateKeyException on column: {target.__tablename__}.{attribute_name}"), + ) + + return check_uniqueness From e8a70844a6b2a89cc852ff58aba945d36ddbf399 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sat, 9 Sep 2023 13:51:44 +0200 Subject: [PATCH 10/18] Sandbox: Add `poe check-fast` vs. `poe check` There will be tests marked with `pytest.mark.slow`. `poe check-fast` will omit them. 
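Editor's note: As a sketch of how the `slow` marker this commit registers in
`pyproject.toml` is applied in a test module (test name illustrative):

```python
# Editor's sketch: a test excluded by `poe check-fast` (pytest -m 'not slow').
import pytest

@pytest.mark.slow
def test_roundtrip_against_real_cratedb():
    ...
```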
--- README.md | 9 +++++++-- docs/backlog.md | 1 + pyproject.toml | 6 ++++++ 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 2ef5dc3..0b21a75 100644 --- a/README.md +++ b/README.md @@ -51,10 +51,15 @@ source .venv/bin/activate pip install --editable='.[develop,docs,test]' ``` -Run linters and software tests: +Run linters and software tests, skipping slow tests: ```shell source .venv/bin/activate -poe check +poe check-fast +``` + +Exclusively run "slow" tests. +```shell +pytest -m slow ``` diff --git a/docs/backlog.md b/docs/backlog.md index 6827573..a45af63 100644 --- a/docs/backlog.md +++ b/docs/backlog.md @@ -5,3 +5,4 @@ - Run an MLflow project from the given URI, using `mlflow run` - Explore `mlflow experiments search` for testing purposes - CLI shortcut for `ddl/drop.sql` +- Use `search_path` instead of populating the schema name into SQL DDL files diff --git a/pyproject.toml b/pyproject.toml index 3f0112b..51dd95f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,6 +102,8 @@ log_cli_level = "DEBUG" testpaths = ["tests"] xfail_strict = true markers = [ + "notrackingurimock", + "slow", ] [tool.coverage.run] @@ -186,8 +188,12 @@ lint = [ test = [ { cmd = "pytest" }, ] +test-fast = [ + { cmd = "pytest -m 'not slow'" }, +] build = { cmd = "python -m build" } check = ["lint", "test"] +check-fast = ["lint", "test-fast"] release = [ { cmd = "minibump bump --relax patch" }, From 83883c0145a7f4dbed9c6a362e795cba388b8f96 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sat, 9 Sep 2023 14:03:45 +0200 Subject: [PATCH 11/18] Refactor adapters and patches --- mlflow_cratedb/__init__.py | 2 +- mlflow_cratedb/boot.py | 17 +++++++ mlflow_cratedb/monkey/__init__.py | 31 ------------ mlflow_cratedb/monkey/mlflow/__init__.py | 12 ----- mlflow_cratedb/monkey/models.py | 42 ---------------- mlflow_cratedb/patch/crate_python.py | 50 +++++++++++++++++++ mlflow_cratedb/patch/mlflow/__init__.py | 20 ++++++++ .../{monkey => patch/mlflow}/db_types.py | 0 .../{monkey => patch/mlflow}/db_utils.py | 22 +------- .../{monkey => patch}/mlflow/model.py | 0 .../{monkey => patch}/mlflow/search_utils.py | 2 +- .../{monkey => patch/mlflow}/server.py | 0 .../mlflow/settings.py} | 0 .../{monkey => patch}/mlflow/tracking.py | 2 +- mlflow_cratedb/patch/sqlalchemy.py | 23 +++++++++ 15 files changed, 114 insertions(+), 109 deletions(-) create mode 100644 mlflow_cratedb/boot.py delete mode 100644 mlflow_cratedb/monkey/__init__.py delete mode 100644 mlflow_cratedb/monkey/mlflow/__init__.py delete mode 100644 mlflow_cratedb/monkey/models.py create mode 100644 mlflow_cratedb/patch/mlflow/__init__.py rename mlflow_cratedb/{monkey => patch/mlflow}/db_types.py (100%) rename mlflow_cratedb/{monkey => patch/mlflow}/db_utils.py (50%) rename mlflow_cratedb/{monkey => patch}/mlflow/model.py (100%) rename mlflow_cratedb/{monkey => patch}/mlflow/search_utils.py (96%) rename mlflow_cratedb/{monkey => patch/mlflow}/server.py (100%) rename mlflow_cratedb/{monkey/environment_variables.py => patch/mlflow/settings.py} (100%) rename mlflow_cratedb/{monkey => patch}/mlflow/tracking.py (99%) create mode 100644 mlflow_cratedb/patch/sqlalchemy.py diff --git a/mlflow_cratedb/__init__.py b/mlflow_cratedb/__init__.py index 6c2eec6..85d6d27 100644 --- a/mlflow_cratedb/__init__.py +++ b/mlflow_cratedb/__init__.py @@ -1,6 +1,6 @@ from mlflow.utils import logging_utils -from mlflow_cratedb.monkey import patch_all +from mlflow_cratedb.boot import patch_all # Enable logging, and activate monkeypatch. 
 logging_utils.enable_logging()
 patch_all()
diff --git a/mlflow_cratedb/boot.py b/mlflow_cratedb/boot.py
new file mode 100644
index 0000000..436af52
--- /dev/null
+++ b/mlflow_cratedb/boot.py
@@ -0,0 +1,17 @@
+import logging
+
+from mlflow_cratedb.patch.crate_python import patch_crate_python
+from mlflow_cratedb.patch.mlflow import patch_mlflow
+
+logger = logging.getLogger("mlflow")
+
+ANSI_YELLOW = "\033[93m"
+ANSI_RESET = "\033[0m"
+
+
+def patch_all():
+    logger.info(f"{ANSI_YELLOW}Amalgamating MLflow for CrateDB{ANSI_RESET}")
+    logger.debug("To undo that, run `pip uninstall mlflow-cratedb`")
+
+    patch_crate_python()
+    patch_mlflow()
diff --git a/mlflow_cratedb/monkey/__init__.py b/mlflow_cratedb/monkey/__init__.py
deleted file mode 100644
index da3cbf7..0000000
--- a/mlflow_cratedb/monkey/__init__.py
+++ /dev/null
@@ -1,31 +0,0 @@
-import logging
-
-from mlflow_cratedb.monkey.db_types import patch_dbtypes
-from mlflow_cratedb.monkey.db_utils import patch_db_utils
-from mlflow_cratedb.monkey.environment_variables import patch_environment_variables
-from mlflow_cratedb.monkey.mlflow import patch_mlflow
-from mlflow_cratedb.monkey.models import patch_compiler, patch_models
-from mlflow_cratedb.monkey.server import patch_run_server
-from mlflow_cratedb.patch.crate_python import patch_raise_for_status
-
-logger = logging.getLogger("mlflow")
-
-ANSI_YELLOW = "\033[93m"
-ANSI_RESET = "\033[0m"
-
-
-def patch_all():
-    logger.info(f"{ANSI_YELLOW}Amalgamating MLflow for CrateDB{ANSI_RESET}")
-    logger.debug("To undo that, run `pip uninstall mlflow-cratedb`")
-
-    # crate-python
-    patch_raise_for_status()
-
-    # MLflow
-    patch_environment_variables()
-    patch_compiler()
-    patch_models()
-    patch_dbtypes()
-    patch_db_utils()
-    patch_mlflow()
-    patch_run_server()
diff --git a/mlflow_cratedb/monkey/mlflow/__init__.py b/mlflow_cratedb/monkey/mlflow/__init__.py
deleted file mode 100644
index d6d76b1..0000000
--- a/mlflow_cratedb/monkey/mlflow/__init__.py
+++ /dev/null
@@ -1,12 +0,0 @@
-from mlflow_cratedb.monkey.mlflow.model import polyfill_uniqueness_constraints
-from mlflow_cratedb.monkey.mlflow.search_utils import patch_mlflow_search_utils
-from mlflow_cratedb.monkey.mlflow.tracking import patch_mlflow_tracking
-
-
-def patch_mlflow():
-    """
-    Patch the MLflow package.
-    """
-    polyfill_uniqueness_constraints()
-    patch_mlflow_search_utils()
-    patch_mlflow_tracking()
diff --git a/mlflow_cratedb/monkey/models.py b/mlflow_cratedb/monkey/models.py
deleted file mode 100644
index f39a05c..0000000
--- a/mlflow_cratedb/monkey/models.py
+++ /dev/null
@@ -1,42 +0,0 @@
-from mlflow_cratedb.adapter.util import generate_unique_integer
-
-
-def patch_models():
-    """
-    Configure SQLAlchemy model columns with an alternative to `autoincrement=True`.
-
-    In this case, use a random identifier: Nagamani19, a short, unique,
-    non-sequential identifier based on Hashids.
-
-    TODO: Submit patch to `crate-python`, to be enabled by a
-    dialect parameter `crate_translate_autoincrement` or such.
-    """
-    import sqlalchemy.sql.schema as schema
-
-    init_dist = schema.Column.__init__
-
-    def __init__(self, *args, **kwargs):
-        if "autoincrement" in kwargs:
-            del kwargs["autoincrement"]
-        if "default" not in kwargs:
-            kwargs["default"] = generate_unique_integer
-        init_dist(self, *args, **kwargs)
-
-    schema.Column.__init__ = __init__  # type: ignore[method-assign]
-
-
-def patch_compiler():
-    """
-    Patch CrateDB SQLAlchemy dialect to not emit the `FOR UPDATE` clause on
-    `SELECT ... FOR UPDATE` statements.
- - https://github.com/crate-workbench/mlflow-cratedb/issues/7 - - TODO: Submit to `crate-python` as a bugfix patch. - """ - from crate.client.sqlalchemy.compiler import CrateCompiler - - def for_update_clause(self, select, **kw): - return "" - - CrateCompiler.for_update_clause = for_update_clause diff --git a/mlflow_cratedb/patch/crate_python.py b/mlflow_cratedb/patch/crate_python.py index d84fbbf..d566e33 100644 --- a/mlflow_cratedb/patch/crate_python.py +++ b/mlflow_cratedb/patch/crate_python.py @@ -1,3 +1,53 @@ +from mlflow_cratedb.adapter.util import generate_unique_integer + + +def patch_crate_python(): + patch_compiler() + patch_models() + patch_raise_for_status() + + +def patch_models(): + """ + Configure SQLAlchemy model columns with an alternative to `autoincrement=True`. + + In this case, use a random identifier: Nagamani19, a short, unique, + non-sequential identifier based on Hashids. + + TODO: Submit patch to `crate-python`, to be enabled by a + dialect parameter `crate_translate_autoincrement` or such. + """ + import sqlalchemy.sql.schema as schema + + init_dist = schema.Column.__init__ + + def __init__(self, *args, **kwargs): + if "autoincrement" in kwargs: + del kwargs["autoincrement"] + if "default" not in kwargs: + kwargs["default"] = generate_unique_integer + init_dist(self, *args, **kwargs) + + schema.Column.__init__ = __init__ # type: ignore[method-assign] + + +def patch_compiler(): + """ + Patch CrateDB SQLAlchemy dialect to not emit the `FOR UPDATE` clause on + `SELECT ... FOR UPDATE` statements. + + https://github.com/crate-workbench/mlflow-cratedb/issues/7 + + TODO: Submit to `crate-python` as a bugfix patch. + """ + from crate.client.sqlalchemy.compiler import CrateCompiler + + def for_update_clause(self, select, **kw): + return "" + + CrateCompiler.for_update_clause = for_update_clause + + def patch_raise_for_status(): """ Patch the `crate.client.http._raise_for_status` function to properly raise diff --git a/mlflow_cratedb/patch/mlflow/__init__.py b/mlflow_cratedb/patch/mlflow/__init__.py new file mode 100644 index 0000000..c6dbe6b --- /dev/null +++ b/mlflow_cratedb/patch/mlflow/__init__.py @@ -0,0 +1,20 @@ +from mlflow_cratedb.patch.mlflow.db_types import patch_dbtypes +from mlflow_cratedb.patch.mlflow.db_utils import patch_db_utils +from mlflow_cratedb.patch.mlflow.model import polyfill_uniqueness_constraints +from mlflow_cratedb.patch.mlflow.search_utils import patch_search_utils +from mlflow_cratedb.patch.mlflow.server import patch_run_server +from mlflow_cratedb.patch.mlflow.settings import patch_environment_variables +from mlflow_cratedb.patch.mlflow.tracking import patch_tracking + + +def patch_mlflow(): + """ + Patch the MLflow package.
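As an illustrative aside, activating this whole patch stack requires nothing more than importing the package before using MLflow; a minimal sketch, assuming a CrateDB instance on `localhost` and an example schema name:

```python
# Sketch only: importing `mlflow_cratedb` runs `patch_all()` as an import
# side effect, applying both the crate-python and the MLflow patches.
import mlflow_cratedb  # noqa: F401

import mlflow

# Example values; adjust host, credentials, and schema to your environment.
mlflow.set_tracking_uri("crate://crate@localhost/?schema=mlflow")

with mlflow.start_run():
    mlflow.log_param("answer", 42)
```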
+ """ + patch_dbtypes() + patch_db_utils() + patch_run_server() + patch_environment_variables() + patch_search_utils() + patch_tracking() + polyfill_uniqueness_constraints() diff --git a/mlflow_cratedb/monkey/db_types.py b/mlflow_cratedb/patch/mlflow/db_types.py similarity index 100% rename from mlflow_cratedb/monkey/db_types.py rename to mlflow_cratedb/patch/mlflow/db_types.py diff --git a/mlflow_cratedb/monkey/db_utils.py b/mlflow_cratedb/patch/mlflow/db_utils.py similarity index 50% rename from mlflow_cratedb/monkey/db_utils.py rename to mlflow_cratedb/patch/mlflow/db_utils.py index 880d9ca..ba656ae 100644 --- a/mlflow_cratedb/monkey/db_utils.py +++ b/mlflow_cratedb/patch/mlflow/db_utils.py @@ -1,9 +1,9 @@ import functools -import typing as t import sqlalchemy as sa from mlflow_cratedb.adapter.db import enable_refresh_after_dml +from mlflow_cratedb.patch.sqlalchemy import patch_sqlalchemy_inspector def patch_db_utils(): @@ -34,23 +34,3 @@ def _verify_schema(engine: sa.Engine): Skipping Alembic, that's a no-op. """ pass - - -def patch_sqlalchemy_inspector(engine: sa.Engine): - """ - When using `get_table_names()`, make sure the correct schema name gets used. - - TODO: Verify if this is really needed. SQLAlchemy should use the `search_path` properly already. - TODO: Submit this to SQLAlchemy? - """ - get_table_names_dist = engine.dialect.get_table_names - schema_name = engine.url.query.get("schema") - if isinstance(schema_name, tuple): - schema_name = schema_name[0] - - def get_table_names(connection: sa.Connection, schema: t.Optional[str] = None, **kw: t.Any) -> t.List[str]: - if schema is None: - schema = schema_name - return get_table_names_dist(connection=connection, schema=schema, **kw) - - engine.dialect.get_table_names = get_table_names # type: ignore diff --git a/mlflow_cratedb/monkey/mlflow/model.py b/mlflow_cratedb/patch/mlflow/model.py similarity index 100% rename from mlflow_cratedb/monkey/mlflow/model.py rename to mlflow_cratedb/patch/mlflow/model.py diff --git a/mlflow_cratedb/monkey/mlflow/search_utils.py b/mlflow_cratedb/patch/mlflow/search_utils.py similarity index 96% rename from mlflow_cratedb/monkey/mlflow/search_utils.py rename to mlflow_cratedb/patch/mlflow/search_utils.py index 8b7dfc4..d3d4902 100644 --- a/mlflow_cratedb/monkey/mlflow/search_utils.py +++ b/mlflow_cratedb/patch/mlflow/search_utils.py @@ -1,4 +1,4 @@ -def patch_mlflow_search_utils(): +def patch_search_utils(): """ Patch MLflow's `SearchUtils` to return a comparison function for CrateDB. """ diff --git a/mlflow_cratedb/monkey/server.py b/mlflow_cratedb/patch/mlflow/server.py similarity index 100% rename from mlflow_cratedb/monkey/server.py rename to mlflow_cratedb/patch/mlflow/server.py diff --git a/mlflow_cratedb/monkey/environment_variables.py b/mlflow_cratedb/patch/mlflow/settings.py similarity index 100% rename from mlflow_cratedb/monkey/environment_variables.py rename to mlflow_cratedb/patch/mlflow/settings.py diff --git a/mlflow_cratedb/monkey/mlflow/tracking.py b/mlflow_cratedb/patch/mlflow/tracking.py similarity index 99% rename from mlflow_cratedb/monkey/mlflow/tracking.py rename to mlflow_cratedb/patch/mlflow/tracking.py index 76ad947..51a97a3 100644 --- a/mlflow_cratedb/monkey/mlflow/tracking.py +++ b/mlflow_cratedb/patch/mlflow/tracking.py @@ -2,7 +2,7 @@ from functools import partial -def patch_mlflow_tracking(): +def patch_tracking(): """ Patch the experiment tracking subsystem of MLflow. 
""" diff --git a/mlflow_cratedb/patch/sqlalchemy.py b/mlflow_cratedb/patch/sqlalchemy.py new file mode 100644 index 0000000..e48e677 --- /dev/null +++ b/mlflow_cratedb/patch/sqlalchemy.py @@ -0,0 +1,23 @@ +import typing as t + +import sqlalchemy as sa + + +def patch_sqlalchemy_inspector(engine: sa.Engine): + """ + When using `get_table_names()`, make sure the correct schema name gets used. + + TODO: Verify if this is really needed. SQLAlchemy should use the `search_path` properly already. + TODO: Submit this to SQLAlchemy? + """ + get_table_names_dist = engine.dialect.get_table_names + schema_name = engine.url.query.get("schema") + if isinstance(schema_name, tuple): + schema_name = schema_name[0] + + def get_table_names(connection: sa.Connection, schema: t.Optional[str] = None, **kw: t.Any) -> t.List[str]: + if schema is None: + schema = schema_name + return get_table_names_dist(connection=connection, schema=schema, **kw) + + engine.dialect.get_table_names = get_table_names # type: ignore From 9306a1b27749629a83ec0364f6b9c843b9b2ce66 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sat, 9 Sep 2023 14:44:52 +0200 Subject: [PATCH 12/18] Do not propagate schema into SQL DDL (`_setup_db_{drop,create}_tables`) --- docs/backlog.md | 1 - mlflow_cratedb/adapter/db.py | 17 ++++++++------- mlflow_cratedb/adapter/ddl/cratedb.sql | 30 +++++++++++++------------- mlflow_cratedb/adapter/ddl/drop.sql | 30 +++++++++++++------------- 4 files changed, 39 insertions(+), 39 deletions(-) diff --git a/docs/backlog.md b/docs/backlog.md index a45af63..6827573 100644 --- a/docs/backlog.md +++ b/docs/backlog.md @@ -5,4 +5,3 @@ - Run an MLflow project from the given URI, using `mlflow run` - Explore `mlflow experiments search` for testing purposes - CLI shortcut for `ddl/drop.sql` -- Use `search_path` instead of populating the schema name into SQL DDL files diff --git a/mlflow_cratedb/adapter/db.py b/mlflow_cratedb/adapter/db.py index c040966..d13c558 100644 --- a/mlflow_cratedb/adapter/db.py +++ b/mlflow_cratedb/adapter/db.py @@ -15,23 +15,24 @@ def _setup_db_create_tables(engine: sa.Engine): TODO: Currently, the path is hardcoded to `cratedb.sql`. """ - schema_name = engine.url.query.get("schema") - schema_prefix = "" - if schema_name is not None: - schema_prefix = f'"{schema_name}".' with importlib.resources.path("mlflow_cratedb.adapter", "ddl") as ddl: sql_file = ddl.joinpath("cratedb.sql") - sql_statements = sql_file.read_text().format(schema_prefix=schema_prefix) + sql_statements = sql_file.read_text() with engine.connect() as connection: for statement in sqlparse.split(sql_statements): connection.execute(sa.text(statement)) -def _setup_db_drop_tables(): +def _setup_db_drop_tables(engine: sa.Engine): """ - TODO: Not implemented yet. + Drop all relevant database tables. Handle with care. 
""" - pass + with importlib.resources.path("mlflow_cratedb.adapter", "ddl") as ddl: + sql_file = ddl.joinpath("drop.sql") + sql_statements = sql_file.read_text() + with engine.connect() as connection: + for statement in sqlparse.split(sql_statements): + connection.execute(sa.text(statement)) def enable_refresh_after_dml(): diff --git a/mlflow_cratedb/adapter/ddl/cratedb.sql b/mlflow_cratedb/adapter/ddl/cratedb.sql index e8e8182..bead087 100644 --- a/mlflow_cratedb/adapter/ddl/cratedb.sql +++ b/mlflow_cratedb/adapter/ddl/cratedb.sql @@ -1,4 +1,4 @@ -CREATE TABLE IF NOT EXISTS {schema_prefix}"datasets" ( +CREATE TABLE IF NOT EXISTS "datasets" ( "dataset_uuid" TEXT NOT NULL, "experiment_id" BIGINT NOT NULL, "name" TEXT NOT NULL, @@ -10,14 +10,14 @@ CREATE TABLE IF NOT EXISTS {schema_prefix}"datasets" ( PRIMARY KEY ("experiment_id", "name", "digest") ); -CREATE TABLE IF NOT EXISTS {schema_prefix}"experiment_tags" ( +CREATE TABLE IF NOT EXISTS "experiment_tags" ( "key" TEXT NOT NULL, "value" TEXT, "experiment_id" BIGINT NOT NULL, PRIMARY KEY ("key", "experiment_id") ); -CREATE TABLE IF NOT EXISTS {schema_prefix}"experiments" ( +CREATE TABLE IF NOT EXISTS "experiments" ( "experiment_id" BIGINT NOT NULL, -- default=autoincrement "name" TEXT NOT NULL, "artifact_location" TEXT, @@ -27,7 +27,7 @@ CREATE TABLE IF NOT EXISTS {schema_prefix}"experiments" ( PRIMARY KEY ("experiment_id") ); -CREATE TABLE IF NOT EXISTS {schema_prefix}"inputs" ( +CREATE TABLE IF NOT EXISTS "inputs" ( "input_uuid" TEXT NOT NULL, "source_type" TEXT NOT NULL, "source_id" TEXT NOT NULL, @@ -36,14 +36,14 @@ CREATE TABLE IF NOT EXISTS {schema_prefix}"inputs" ( PRIMARY KEY ("source_type", "source_id", "destination_type", "destination_id") ); -CREATE TABLE IF NOT EXISTS {schema_prefix}"input_tags" ( +CREATE TABLE IF NOT EXISTS "input_tags" ( "input_uuid" TEXT NOT NULL, "name" TEXT NOT NULL, "value" TEXT NOT NULL, PRIMARY KEY ("input_uuid", "name") ); -CREATE TABLE IF NOT EXISTS {schema_prefix}"latest_metrics" ( +CREATE TABLE IF NOT EXISTS "latest_metrics" ( "key" TEXT NOT NULL, "value" DOUBLE NOT NULL, "timestamp" BIGINT NOT NULL, @@ -53,7 +53,7 @@ CREATE TABLE IF NOT EXISTS {schema_prefix}"latest_metrics" ( PRIMARY KEY ("key", "run_uuid") ); -CREATE TABLE IF NOT EXISTS {schema_prefix}"metrics" ( +CREATE TABLE IF NOT EXISTS "metrics" ( "key" TEXT NOT NULL, "value" DOUBLE NOT NULL, "timestamp" BIGINT NOT NULL, @@ -63,7 +63,7 @@ CREATE TABLE IF NOT EXISTS {schema_prefix}"metrics" ( PRIMARY KEY ("key", "timestamp", "step", "run_uuid", "value", "is_nan") ); -CREATE TABLE IF NOT EXISTS {schema_prefix}"model_versions" ( +CREATE TABLE IF NOT EXISTS "model_versions" ( "name" TEXT NOT NULL, "version" INTEGER NOT NULL, "creation_time" BIGINT, -- default=get_current_time_millis @@ -78,40 +78,40 @@ CREATE TABLE IF NOT EXISTS {schema_prefix}"model_versions" ( "status_message" TEXT ); -CREATE TABLE IF NOT EXISTS {schema_prefix}"model_version_tags" ( +CREATE TABLE IF NOT EXISTS "model_version_tags" ( "name" TEXT NOT NULL, "version" INTEGER NOT NULL, "key" TEXT NOT NULL, "value" TEXT ); -CREATE TABLE IF NOT EXISTS {schema_prefix}"params" ( +CREATE TABLE IF NOT EXISTS "params" ( "key" TEXT NOT NULL, "value" TEXT NOT NULL, "run_uuid" TEXT NOT NULL, PRIMARY KEY ("key", "run_uuid") ); -CREATE TABLE IF NOT EXISTS {schema_prefix}"registered_models" ( +CREATE TABLE IF NOT EXISTS "registered_models" ( "name" TEXT NOT NULL, "key" TEXT NOT NULL, "value" TEXT ); -CREATE TABLE IF NOT EXISTS {schema_prefix}"registered_model_aliases" ( +CREATE TABLE IF NOT 
EXISTS "registered_model_aliases" ( "name" TEXT NOT NULL, "alias" TEXT NOT NULL, "version" TEXT NOT NULL ); -CREATE TABLE IF NOT EXISTS {schema_prefix}"registered_model_tags" ( +CREATE TABLE IF NOT EXISTS "registered_model_tags" ( "name" TEXT NOT NULL, "creation_time" BIGINT, -- default=get_current_time_millis "last_update_time" BIGINT, -- default=get_current_time_millis "description" TEXT ); -CREATE TABLE IF NOT EXISTS {schema_prefix}"runs" ( +CREATE TABLE IF NOT EXISTS "runs" ( "run_uuid" TEXT NOT NULL, "name" TEXT, "source_type" TEXT, @@ -129,7 +129,7 @@ CREATE TABLE IF NOT EXISTS {schema_prefix}"runs" ( PRIMARY KEY ("run_uuid") ); -CREATE TABLE IF NOT EXISTS {schema_prefix}"tags" ( +CREATE TABLE IF NOT EXISTS "tags" ( "key" TEXT NOT NULL, "value" TEXT, "run_uuid" TEXT NOT NULL, diff --git a/mlflow_cratedb/adapter/ddl/drop.sql b/mlflow_cratedb/adapter/ddl/drop.sql index 5c73c22..a439e50 100644 --- a/mlflow_cratedb/adapter/ddl/drop.sql +++ b/mlflow_cratedb/adapter/ddl/drop.sql @@ -1,15 +1,15 @@ -DROP TABLE IF EXISTS {schema_prefix}"datasets"; -DROP TABLE IF EXISTS {schema_prefix}"experiment_tags"; -DROP TABLE IF EXISTS {schema_prefix}"experiments"; -DROP TABLE IF EXISTS {schema_prefix}"inputs"; -DROP TABLE IF EXISTS {schema_prefix}"input_tags"; -DROP TABLE IF EXISTS {schema_prefix}"latest_metrics"; -DROP TABLE IF EXISTS {schema_prefix}"metrics"; -DROP TABLE IF EXISTS {schema_prefix}"model_versions"; -DROP TABLE IF EXISTS {schema_prefix}"model_version_tags"; -DROP TABLE IF EXISTS {schema_prefix}"params"; -DROP TABLE IF EXISTS {schema_prefix}"registered_models"; -DROP TABLE IF EXISTS {schema_prefix}"registered_model_aliases"; -DROP TABLE IF EXISTS {schema_prefix}"registered_model_tags"; -DROP TABLE IF EXISTS {schema_prefix}"runs"; -DROP TABLE IF EXISTS {schema_prefix}"tags"; +DROP TABLE IF EXISTS "datasets"; +DROP TABLE IF EXISTS "experiment_tags"; +DROP TABLE IF EXISTS "experiments"; +DROP TABLE IF EXISTS "inputs"; +DROP TABLE IF EXISTS "input_tags"; +DROP TABLE IF EXISTS "latest_metrics"; +DROP TABLE IF EXISTS "metrics"; +DROP TABLE IF EXISTS "model_versions"; +DROP TABLE IF EXISTS "model_version_tags"; +DROP TABLE IF EXISTS "params"; +DROP TABLE IF EXISTS "registered_models"; +DROP TABLE IF EXISTS "registered_model_aliases"; +DROP TABLE IF EXISTS "registered_model_tags"; +DROP TABLE IF EXISTS "runs"; +DROP TABLE IF EXISTS "tags"; From e2e8fcc4b9e9c95dffbd39ba387e36d585876a42 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sat, 9 Sep 2023 14:45:26 +0200 Subject: [PATCH 13/18] Update README not to advertise not applicable CrateDB options --- README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index 0b21a75..0c5b22a 100644 --- a/README.md +++ b/README.md @@ -22,8 +22,7 @@ Docker or Podman. 
```shell docker run --rm -it --publish=4200:4200 --publish=5432:5432 \ --env=CRATE_HEAP_SIZE=4g crate \ - -Cdiscovery.type=single-node \ - -Ccluster.routing.allocation.disk.threshold_enabled=false + -Cdiscovery.type=single-node ``` Start the MLflow server, pointing it to your [CrateDB] instance, From 34486ead25a11ebdf01a03ff2e3580a34f15cb93 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sat, 9 Sep 2023 14:46:05 +0200 Subject: [PATCH 14/18] Tests: Add a few adapter tests, verifying basic database conversations --- tests/test_adapter.py | 69 +++++++++++++++++++++++++++++++++++++++++++ tests/test_foo.py | 2 -- 2 files changed, 69 insertions(+), 2 deletions(-) create mode 100644 tests/test_adapter.py delete mode 100644 tests/test_foo.py diff --git a/tests/test_adapter.py b/tests/test_adapter.py new file mode 100644 index 0000000..8d4e1ff --- /dev/null +++ b/tests/test_adapter.py @@ -0,0 +1,69 @@ +import mlflow +import pytest +import sqlalchemy as sa +from mlflow.store.tracking.dbmodels.initial_models import Base +from mlflow.store.tracking.dbmodels.models import SqlExperiment +from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore + +from mlflow_cratedb.adapter.db import _setup_db_create_tables, _setup_db_drop_tables + +DB_URI = "crate://crate@localhost/?schema=testdrive" +ARTIFACT_URI = "artifact_folder" + + +@pytest.fixture +def engine(): + yield mlflow.store.db.utils.create_sqlalchemy_engine_with_retry(DB_URI) + + +@pytest.fixture +def store(): + """ + A fixture for providing an instance of `SqlAlchemyStore`. + """ + yield SqlAlchemyStore(DB_URI, ARTIFACT_URI) + + +@pytest.fixture +def store_empty(store): + """ + A fixture for providing an instance of `SqlAlchemyStore`, + after pruning all database tables. + """ + with store.ManagedSessionMaker() as session: + session.query(SqlExperiment).delete() + for mapper in Base.registry.mappers: + session.query(mapper.class_).delete() + sql = f"REFRESH TABLE testdrive.{mapper.class_.__tablename__};" + session.execute(sa.text(sql)) + yield store + + +def test_setup_tables(engine: sa.Engine): + """ + Test if creating database tables works, and that they use the correct schema. + """ + _setup_db_drop_tables(engine=engine) + _setup_db_create_tables(engine=engine) + with engine.connect() as connection: + result = connection.execute(sa.text("SELECT * FROM testdrive.experiments;")) + assert result.rowcount == 0 + + +def test_query_model(store: SqlAlchemyStore): + """ + Verify setting up MLflow database tables works well. + """ + + with store.ManagedSessionMaker() as session: + # Verify table has one record, the "Default" experiment. + assert session.query(SqlExperiment).count() == 1 + + # Run a basic ORM-based query. + experiment: SqlExperiment = session.query(SqlExperiment).one() + assert experiment.name == "Default" + + # Run the same query using plain SQL. + # This makes sure the designated schema is properly used through `search_path`. 
+ record = session.execute(sa.text("SELECT * FROM testdrive.experiments;")).mappings().one() + assert record["name"] == "Default" diff --git a/tests/test_foo.py b/tests/test_foo.py deleted file mode 100644 index 8270b66..0000000 --- a/tests/test_foo.py +++ /dev/null @@ -1,2 +0,0 @@ -def test_foo(): - assert 42 == 42 From 25919b52c3e10c741fa29d073cd03be4f0c9b072 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Thu, 7 Sep 2023 18:49:36 +0200 Subject: [PATCH 15/18] CI: Provide CrateDB nightly to the test suite on GHA --- .github/workflows/main.yml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 76d9d5e..0f94775 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -29,6 +29,14 @@ jobs: OS: ${{ matrix.os }} PYTHON: ${{ matrix.python-version }} + # https://docs.github.com/en/free-pro-team@latest/actions/guides/about-service-containers + services: + cratedb: + image: crate/crate:nightly + ports: + - 4200:4200 + - 5432:5432 + name: Python ${{ matrix.python-version }} on OS ${{ matrix.os }} steps: From 440ef2e0d69d8e0991df709d61dbeb086cac8d67 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sat, 9 Sep 2023 15:26:50 +0200 Subject: [PATCH 16/18] Refactor adapters and patches once again --- mlflow_cratedb/adapter/db.py | 65 ------------------------- mlflow_cratedb/adapter/setup_db.py | 34 +++++++++++++ mlflow_cratedb/adapter/util.py | 8 --- mlflow_cratedb/patch/crate_python.py | 16 ++++-- mlflow_cratedb/patch/mlflow/__init__.py | 5 +- mlflow_cratedb/patch/mlflow/db_utils.py | 4 +- mlflow_cratedb/patch/mlflow/model.py | 37 +++++++++++++- tests/test_adapter.py | 2 +- 8 files changed, 86 insertions(+), 85 deletions(-) delete mode 100644 mlflow_cratedb/adapter/db.py create mode 100644 mlflow_cratedb/adapter/setup_db.py delete mode 100644 mlflow_cratedb/adapter/util.py diff --git a/mlflow_cratedb/adapter/db.py b/mlflow_cratedb/adapter/db.py deleted file mode 100644 index d13c558..0000000 --- a/mlflow_cratedb/adapter/db.py +++ /dev/null @@ -1,65 +0,0 @@ -import importlib.resources - -import sqlalchemy as sa -import sqlparse -from sqlalchemy.event import listen - - -def _setup_db_create_tables(engine: sa.Engine): - """ - Because CrateDB does not play well with a full-fledged SQLAlchemy data model and - corresponding Alembic migrations, shortcut that and replace it with a classic - database schema provisioning based on SQL DDL. - - It will cause additional maintenance, but well, c'est la vie. - - TODO: Currently, the path is hardcoded to `cratedb.sql`. - """ - with importlib.resources.path("mlflow_cratedb.adapter", "ddl") as ddl: - sql_file = ddl.joinpath("cratedb.sql") - sql_statements = sql_file.read_text() - with engine.connect() as connection: - for statement in sqlparse.split(sql_statements): - connection.execute(sa.text(statement)) - - -def _setup_db_drop_tables(engine: sa.Engine): - """ - Drop all relevant database tables. Handle with care. - """ - with importlib.resources.path("mlflow_cratedb.adapter", "ddl") as ddl: - sql_file = ddl.joinpath("drop.sql") - sql_statements = sql_file.read_text() - with engine.connect() as connection: - for statement in sqlparse.split(sql_statements): - connection.execute(sa.text(statement)) - - -def enable_refresh_after_dml(): - """ - Run `REFRESH TABLE ` after each INSERT, UPDATE, and DELETE operation. - - CrateDB is eventually consistent, i.e. write operations are not flushed to - disk immediately, so readers may see stale data. 
In a traditional OLTP-like - application, this is not applicable. - - This SQLAlchemy extension makes sure that data is synchronized after each - operation manipulating data. - - TODO: Submit patch to `crate-python`, to be enabled by a - dialect parameter `crate_dml_refresh` or such. - """ - from mlflow.store.db.base_sql_model import Base - - for mapper in Base.registry.mappers: - listen(mapper.class_, "after_insert", do_refresh) - listen(mapper.class_, "after_update", do_refresh) - listen(mapper.class_, "after_delete", do_refresh) - - -def do_refresh(mapper, connection, target): - """ - SQLAlchemy event handler for `after_{insert,update,delete}` events, invoking `REFRESH TABLE`. - """ - sql = f"REFRESH TABLE {target.__tablename__}" - connection.execute(sa.text(sql)) diff --git a/mlflow_cratedb/adapter/setup_db.py b/mlflow_cratedb/adapter/setup_db.py new file mode 100644 index 0000000..41ad7be --- /dev/null +++ b/mlflow_cratedb/adapter/setup_db.py @@ -0,0 +1,34 @@ +import importlib.resources + +import sqlalchemy as sa +import sqlparse + + +def read_ddl(filename: str): + return importlib.resources.files("mlflow_cratedb.adapter.ddl").joinpath(filename).read_text() + + +def _setup_db_create_tables(engine: sa.Engine): + """ + Because CrateDB does not play well with a full-fledged SQLAlchemy data model and + corresponding Alembic migrations, shortcut that and replace it with a classic + database schema provisioning based on SQL DDL. + + It will cause additional maintenance, but well, c'est la vie. + + TODO: Currently, the path is hardcoded to `cratedb.sql`. + """ + sql_statements = read_ddl("cratedb.sql") + with engine.connect() as connection: + for statement in sqlparse.split(sql_statements): + connection.execute(sa.text(statement)) + + +def _setup_db_drop_tables(engine: sa.Engine): + """ + Drop all relevant database tables. Handle with care. + """ + sql_statements = read_ddl("drop.sql") + with engine.connect() as connection: + for statement in sqlparse.split(sql_statements): + connection.execute(sa.text(statement)) diff --git a/mlflow_cratedb/adapter/util.py b/mlflow_cratedb/adapter/util.py deleted file mode 100644 index 8f4cf65..0000000 --- a/mlflow_cratedb/adapter/util.py +++ /dev/null @@ -1,8 +0,0 @@ -from vasuki import generate_nagamani19_int - - -def generate_unique_integer() -> int: - """ - Produce a short, unique, non-sequential identifier based on Hashids. - """ - return generate_nagamani19_int(size=10) diff --git a/mlflow_cratedb/patch/crate_python.py b/mlflow_cratedb/patch/crate_python.py index d566e33..6155f5b 100644 --- a/mlflow_cratedb/patch/crate_python.py +++ b/mlflow_cratedb/patch/crate_python.py @@ -1,6 +1,3 @@ -from mlflow_cratedb.adapter.util import generate_unique_integer - - def patch_crate_python(): patch_compiler() patch_models() @@ -15,7 +12,7 @@ def patch_models(): non-sequential identifier based on Hashids. TODO: Submit patch to `crate-python`, to be enabled by a - dialect parameter `crate_translate_autoincrement` or such. + dialect parameter `crate_polyfill_autoincrement` or such. """ import sqlalchemy.sql.schema as schema @@ -85,7 +82,7 @@ def check_uniqueness_factory(sa_entity, attribute_name): CrateDB does not support the UNIQUE constraint on columns. This attempts to emulate it. TODO: Submit patch to `crate-python`, to be enabled by a - dialect parameter `crate_translate_unique` or such. + dialect parameter `crate_polyfill_unique` or such. 
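A sketch of how this factory is wired up; the `SqlExperiment`/`name` pairing below mirrors what `polyfill_uniqueness_constraints()` registers for MLflow:

```python
# Sketch only: emulate a UNIQUE constraint on `SqlExperiment.name` by
# validating candidate rows in a `before_insert` event handler.
from mlflow.store.tracking.dbmodels.models import SqlExperiment
from sqlalchemy.event import listen

from mlflow_cratedb.patch.crate_python import check_uniqueness_factory

listen(SqlExperiment, "before_insert", check_uniqueness_factory(SqlExperiment, "name"))
```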
""" def check_uniqueness(mapper, connection, target): @@ -107,3 +104,12 @@ def check_uniqueness(mapper, connection, target): ) return check_uniqueness + + +def generate_unique_integer() -> int: + """ + Produce a short, unique, non-sequential identifier based on Hashids. + """ + from vasuki import generate_nagamani19_int + + return generate_nagamani19_int(size=10) diff --git a/mlflow_cratedb/patch/mlflow/__init__.py b/mlflow_cratedb/patch/mlflow/__init__.py index c6dbe6b..acd3a7a 100644 --- a/mlflow_cratedb/patch/mlflow/__init__.py +++ b/mlflow_cratedb/patch/mlflow/__init__.py @@ -1,6 +1,6 @@ from mlflow_cratedb.patch.mlflow.db_types import patch_dbtypes from mlflow_cratedb.patch.mlflow.db_utils import patch_db_utils -from mlflow_cratedb.patch.mlflow.model import polyfill_uniqueness_constraints +from mlflow_cratedb.patch.mlflow.model import polyfill_refresh_after_dml, polyfill_uniqueness_constraints from mlflow_cratedb.patch.mlflow.search_utils import patch_search_utils from mlflow_cratedb.patch.mlflow.server import patch_run_server from mlflow_cratedb.patch.mlflow.settings import patch_environment_variables @@ -12,9 +12,10 @@ def patch_mlflow(): Patch the MLflow package. """ patch_dbtypes() + polyfill_refresh_after_dml() + polyfill_uniqueness_constraints() patch_db_utils() patch_run_server() patch_environment_variables() patch_search_utils() patch_tracking() - polyfill_uniqueness_constraints() diff --git a/mlflow_cratedb/patch/mlflow/db_utils.py b/mlflow_cratedb/patch/mlflow/db_utils.py index ba656ae..a622782 100644 --- a/mlflow_cratedb/patch/mlflow/db_utils.py +++ b/mlflow_cratedb/patch/mlflow/db_utils.py @@ -2,14 +2,12 @@ import sqlalchemy as sa -from mlflow_cratedb.adapter.db import enable_refresh_after_dml from mlflow_cratedb.patch.sqlalchemy import patch_sqlalchemy_inspector def patch_db_utils(): import mlflow.store.db.utils as db_utils - enable_refresh_after_dml() db_utils._initialize_tables = _initialize_tables db_utils._verify_schema = _verify_schema @@ -22,7 +20,7 @@ def _initialize_tables(engine: sa.Engine): """ from mlflow.store.db.utils import _logger - from mlflow_cratedb.adapter.db import _setup_db_create_tables + from mlflow_cratedb.adapter.setup_db import _setup_db_create_tables patch_sqlalchemy_inspector(engine) _logger.info("Creating initial MLflow database tables...") diff --git a/mlflow_cratedb/patch/mlflow/model.py b/mlflow_cratedb/patch/mlflow/model.py index 5480e9d..d55f79d 100644 --- a/mlflow_cratedb/patch/mlflow/model.py +++ b/mlflow_cratedb/patch/mlflow/model.py @@ -1,11 +1,46 @@ +import sqlalchemy as sa +from sqlalchemy.event import listen + from mlflow_cratedb.patch.crate_python import check_uniqueness_factory def polyfill_uniqueness_constraints(): """ Establish a manual uniqueness check on the `SqlExperiment.name` column. + + TODO: Submit patch to `crate-python`, to be enabled by a + dialect parameter `crate_polyfill_unique` or such. """ from mlflow.store.tracking.dbmodels.models import SqlExperiment - from sqlalchemy.event import listen listen(SqlExperiment, "before_insert", check_uniqueness_factory(SqlExperiment, "name")) + + +def polyfill_refresh_after_dml(): + """ + Run `REFRESH TABLE ` after each INSERT, UPDATE, and DELETE operation. + + CrateDB is eventually consistent, i.e. write operations are not flushed to + disk immediately, so readers may see stale data. In a traditional OLTP-like + application, this is not applicable. + + This SQLAlchemy extension makes sure that data is synchronized after each + operation manipulating data. 
+ + TODO: Submit patch to `crate-python`, to be enabled by a + dialect parameter `crate_dml_refresh` or such. + """ + from mlflow.store.db.base_sql_model import Base + + for mapper in Base.registry.mappers: + listen(mapper.class_, "after_insert", do_refresh) + listen(mapper.class_, "after_update", do_refresh) + listen(mapper.class_, "after_delete", do_refresh) + + +def do_refresh(mapper, connection, target): + """ + SQLAlchemy event handler for `after_{insert,update,delete}` events, invoking `REFRESH TABLE`. + """ + sql = f"REFRESH TABLE {target.__tablename__}" + connection.execute(sa.text(sql)) diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 8d4e1ff..5d9a5d7 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -5,7 +5,7 @@ from mlflow.store.tracking.dbmodels.models import SqlExperiment from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore -from mlflow_cratedb.adapter.db import _setup_db_create_tables, _setup_db_drop_tables +from mlflow_cratedb.adapter.setup_db import _setup_db_create_tables, _setup_db_drop_tables DB_URI = "crate://crate@localhost/?schema=testdrive" ARTIFACT_URI = "artifact_folder" From 186c9e4d01742a7ae62e4c471a8cf7df90505428 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sat, 9 Sep 2023 16:48:30 +0200 Subject: [PATCH 17/18] Remove unnecessary patching of SQLAlchemy inspector --- mlflow_cratedb/patch/mlflow/db_utils.py | 3 --- mlflow_cratedb/patch/sqlalchemy.py | 23 ----------------------- 2 files changed, 26 deletions(-) delete mode 100644 mlflow_cratedb/patch/sqlalchemy.py diff --git a/mlflow_cratedb/patch/mlflow/db_utils.py b/mlflow_cratedb/patch/mlflow/db_utils.py index a622782..dc312ac 100644 --- a/mlflow_cratedb/patch/mlflow/db_utils.py +++ b/mlflow_cratedb/patch/mlflow/db_utils.py @@ -2,8 +2,6 @@ import sqlalchemy as sa -from mlflow_cratedb.patch.sqlalchemy import patch_sqlalchemy_inspector - def patch_db_utils(): import mlflow.store.db.utils as db_utils @@ -22,7 +20,6 @@ def _initialize_tables(engine: sa.Engine): from mlflow_cratedb.adapter.setup_db import _setup_db_create_tables - patch_sqlalchemy_inspector(engine) _logger.info("Creating initial MLflow database tables...") _setup_db_create_tables(engine) diff --git a/mlflow_cratedb/patch/sqlalchemy.py b/mlflow_cratedb/patch/sqlalchemy.py deleted file mode 100644 index e48e677..0000000 --- a/mlflow_cratedb/patch/sqlalchemy.py +++ /dev/null @@ -1,23 +0,0 @@ -import typing as t - -import sqlalchemy as sa - - -def patch_sqlalchemy_inspector(engine: sa.Engine): - """ - When using `get_table_names()`, make sure the correct schema name gets used. - - TODO: Verify if this is really needed. SQLAlchemy should use the `search_path` properly already. - TODO: Submit this to SQLAlchemy? 
- """ - get_table_names_dist = engine.dialect.get_table_names - schema_name = engine.url.query.get("schema") - if isinstance(schema_name, tuple): - schema_name = schema_name[0] - - def get_table_names(connection: sa.Connection, schema: t.Optional[str] = None, **kw: t.Any) -> t.List[str]: - if schema is None: - schema = schema_name - return get_table_names_dist(connection=connection, schema=schema, **kw) - - engine.dialect.get_table_names = get_table_names # type: ignore From 68813334a6728c7c7165cf2dacbf34d840729d40 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sat, 9 Sep 2023 20:56:11 +0200 Subject: [PATCH 18/18] Fix random integer generation --- mlflow_cratedb/patch/crate_python.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlflow_cratedb/patch/crate_python.py b/mlflow_cratedb/patch/crate_python.py index 6155f5b..3f7ed2d 100644 --- a/mlflow_cratedb/patch/crate_python.py +++ b/mlflow_cratedb/patch/crate_python.py @@ -112,4 +112,4 @@ def generate_unique_integer() -> int: """ from vasuki import generate_nagamani19_int - return generate_nagamani19_int(size=10) + return generate_nagamani19_int()