From e7a1e5ed593b89c310528ae86967e0d8a5e515c3 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sun, 10 Sep 2023 11:37:02 +0200 Subject: [PATCH] Add example program `tracking_merlion.py`, and a corresponding test case --- .github/workflows/main.yml | 2 +- CHANGES.md | 1 + README.md | 8 +- examples/tracking_merlion.py | 186 +++++++++++++++++++++++++++++++++++ pyproject.toml | 4 + tests/test_examples.py | 57 +++++++++++ tests/util.py | 23 +++++ 7 files changed, 278 insertions(+), 3 deletions(-) create mode 100644 examples/tracking_merlion.py create mode 100644 tests/test_examples.py diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 0f94775..92e109e 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -59,7 +59,7 @@ jobs: pip install "setuptools>=64" --upgrade # Install package in editable mode. - pip install --use-pep517 --prefer-binary --editable=.[develop,docs,test] + pip install --use-pep517 --prefer-binary --editable=.[examples,develop,docs,test] - name: Run linter and software tests run: | diff --git a/CHANGES.md b/CHANGES.md index d6a9c0a..6b51f4c 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -11,3 +11,4 @@ - Project: Add `versioningit`, for effortless versioning - Add patch for SQLAlchemy Inspector's `get_table_names` - Reorder CrateDB SQLAlchemy Dialect polyfills +- Add example experiment program `tracking_merlion.py`, and corresponding tests diff --git a/README.md b/README.md index 61c753a..865a997 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ for [MLflow Tracking]. Install the most recent version of the `mlflow-cratedb` package. ```shell -pip install --upgrade 'git+https://github.com/crate-workbench/mlflow-cratedb' +pip install --upgrade 'git+https://github.com/crate-workbench/mlflow-cratedb#egg=mlflow-cratedb[examples]' ``` To verify if the installation worked, you can inspect the version numbers @@ -54,7 +54,7 @@ git clone https://github.com/crate-workbench/mlflow-cratedb cd mlflow-cratedb python3 -m venv .venv source .venv/bin/activate -pip install --editable='.[develop,docs,test]' +pip install --editable='.[examples,develop,docs,test]' ``` Run linters and software tests, skipping slow tests: @@ -74,13 +74,17 @@ pytest -m slow [Siddharth Murching], [Corey Zumar], [Harutaka Kawamura], [Ben Wilson], and all other contributors for conceiving and maintaining [MLflow]. +[Andreas Nigg] for contributing the [tracking_merlion.py](./examples/tracking_merlion.py) +ML experiment program, which is using [Merlion]. +[Andreas Nigg]: https://github.com/andnig [Ben Wilson]: https://github.com/BenWilson2 [Corey Zumar]: https://github.com/dbczumar [CrateDB]: https://github.com/crate/crate [CrateDB Cloud]: https://console.cratedb.cloud/ [Harutaka Kawamura]: https://github.com/harupy +[Merlion]: https://github.com/salesforce/Merlion [MLflow]: https://mlflow.org/ [MLflow Tracking]: https://mlflow.org/docs/latest/tracking.html [Siddharth Murching]: https://github.com/smurching diff --git a/examples/tracking_merlion.py b/examples/tracking_merlion.py new file mode 100644 index 0000000..a1a00a6 --- /dev/null +++ b/examples/tracking_merlion.py @@ -0,0 +1,186 @@ +""" +About + +Use MLflow and CrateDB to track the metrics, parameters, and outcomes of an ML +experiment program using Merlion. It uses the `machine_temperature_system_failure.csv` +dataset from the Numenta Anomaly Benchmark data. + +- https://github.com/crate-workbench/mlflow-cratedb +- https://mlflow.org/docs/latest/tracking.html + +Usage + +Before running the program, optionally define the `MLFLOW_TRACKING_URI` environment +variable, in order to record events and metrics either directly into the database, +or by submitting them to an MLflow Tracking Server. + + # Use CrateDB database directly + export MLFLOW_TRACKING_URI="crate://crate@localhost/?schema=mlflow" + + # Use MLflow Tracking Server + export MLFLOW_TRACKING_URI=http://127.0.0.1:5000 + +Resources + +- https://mlflow.org/ +- https://github.com/crate/crate +- https://github.com/salesforce/Merlion +- https://github.com/numenta/NAB +""" + +import os + +import mlflow +import numpy as np +import pandas as pd +from crate import client +from merlion.evaluate.anomaly import TSADMetric +from merlion.models.defaults import DefaultDetector, DefaultDetectorConfig +from merlion.utils import TimeSeries + + +def connect_database(): + """ + Connect to CrateDB, and return database connection object. + """ + dburi = os.getenv("CRATEDB_HTTP_URL", "http://crate@localhost:4200") + return client.connect(dburi) + + +def table_exists(table_name: str, schema_name: str = "doc") -> bool: + """ + Check if database table exists. + """ + conn = connect_database() + cursor = conn.cursor() + sql = ( + f"SELECT table_name FROM information_schema.tables " # noqa: S608 + f"WHERE table_name = '{table_name}' AND table_schema = '{schema_name}'" + ) + cursor.execute(sql) + rowcount = cursor.rowcount + cursor.close() + conn.close() + return rowcount > 0 + + +def import_data(table_name: str): + """ + Download Numenta Anomaly Benchmark data, and load into database. + """ + + data = pd.read_csv( + "https://raw.githubusercontent.com/numenta/NAB/master/data/realKnownCause/machine_temperature_system_failure.csv" + ) + + # Split the data into chunks of 1000 rows each for better insert performance. + chunk_size = 1000 + chunks = np.array_split(data, int(len(data) / chunk_size)) + + # Insert data into CrateDB. + with connect_database() as conn: + cursor = conn.cursor() + # Create the table if it doesn't exist. + cursor.execute(f"CREATE TABLE IF NOT EXISTS {table_name} (timestamp TIMESTAMP, temperature DOUBLE)") + # Insert the data in chunks. + for chunk in chunks: + sql = f"INSERT INTO {table_name} (timestamp, temperature) VALUES (?, ?)" # noqa: S608 + cursor.executemany(sql, list(chunk.itertuples(index=False, name=None))) + + +def read_data(table_name: str) -> pd.DataFrame: + """ + Read data from database into pandas DataFrame. + """ + conn = connect_database() + with conn: + cursor = conn.cursor() + cursor.execute( + f"""SELECT + DATE_BIN('5 min'::INTERVAL, "timestamp", 0) AS timestamp, + MAX(temperature) AS value + FROM {table_name} + GROUP BY timestamp + ORDER BY timestamp ASC""" + ) + data = cursor.fetchall() + + # Convert database response to pandas DataFrame. + time_series = pd.DataFrame( + [{"timestamp": pd.Timestamp.fromtimestamp(item[0] / 1000), "value": item[1]} for item in data] + ) + # Set the timestamp as the index + return time_series.set_index("timestamp") + + +def run_experiment(time_series: pd.DataFrame): + """ + Run experiment on DataFrame, using Merlion. Track it using MLflow. + """ + mlflow.set_experiment("numenta-merlion-experiment") + + with mlflow.start_run(): + train_data = TimeSeries.from_pd(time_series[time_series.index < pd.to_datetime("2013-12-15")]) + test_data = TimeSeries.from_pd(time_series[time_series.index >= pd.to_datetime("2013-12-15")]) + + model = DefaultDetector(DefaultDetectorConfig()) + model.train(train_data=train_data) + + test_pred = model.get_anomaly_label(time_series=test_data) + + # Prepare the test labels + time_frames = [ + ["2013-12-15 17:50:00.000000", "2013-12-17 17:00:00.000000"], + ["2014-01-27 14:20:00.000000", "2014-01-29 13:30:00.000000"], + ["2014-02-07 14:55:00.000000", "2014-02-09 14:05:00.000000"], + ] + + time_frames = [[pd.to_datetime(start), pd.to_datetime(end)] for start, end in time_frames] + time_series["test_labels"] = 0 + for start, end in time_frames: + time_series.loc[(time_series.index >= start) & (time_series.index <= end), "test_labels"] = 1 + + test_labels = TimeSeries.from_pd(time_series["test_labels"]) + + p = TSADMetric.Precision.value(ground_truth=test_labels, predict=test_pred) + r = TSADMetric.Recall.value(ground_truth=test_labels, predict=test_pred) + f1 = TSADMetric.F1.value(ground_truth=test_labels, predict=test_pred) + mttd = TSADMetric.MeanTimeToDetect.value(ground_truth=test_labels, predict=test_pred) + print(f"Precision: {p:.4f}, Recall: {r:.4f}, F1: {f1:.4f}\n" f"Mean Time To Detect: {mttd}") # noqa: T201 + + mlflow.log_metric("precision", p) + mlflow.log_metric("recall", r) + mlflow.log_metric("f1", f1) + mlflow.log_metric("mttd", mttd.total_seconds()) + mlflow.log_param("anomaly_threshold", model.config.threshold.alm_threshold) + mlflow.log_param("min_alm_window", model.config.threshold.min_alm_in_window) + mlflow.log_param("alm_window_minutes", model.config.threshold.alm_window_minutes) + mlflow.log_param("alm_suppress_minutes", model.config.threshold.alm_suppress_minutes) + mlflow.log_param("ensemble_size", model.config.model.combiner.n_models) + + # Save the model to MLflow. + model.save("model") + mlflow.log_artifact("model") + + +def main(): + """ + Provision dataset, and run experiment. + """ + + # Table name where the actual data is stored. + data_table = "machine_data" + + # Provision data to operate on, only once. + if not table_exists(data_table): + import_data(data_table) + + # Read data into pandas DataFrame. + data = read_data(data_table) + + # Run experiment on data. + run_experiment(data) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 780abcc..f4a8fa2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,6 +77,9 @@ develop = [ "ruff==0.0.287", "validate-pyproject<0.15", ] +examples = [ + "salesforce-merlion<2.1", +] release = [ "build<2", 'minibump<1; python_version >= "3.10"', @@ -84,6 +87,7 @@ release = [ ] test = [ "coverage<8", + "psutil<6", "pytest<8", ] [project.scripts] diff --git a/tests/test_examples.py b/tests/test_examples.py new file mode 100644 index 0000000..befb718 --- /dev/null +++ b/tests/test_examples.py @@ -0,0 +1,57 @@ +import logging +import sys +import time +from pathlib import Path + +import mlflow +import pytest +import sqlalchemy as sa + +from mlflow_cratedb.adapter.setup_db import _setup_db_create_tables, _setup_db_drop_tables +from tests.util import process + +# The canonical database schema used for example purposes is `examples`. +DB_URI = "crate://crate@localhost/?schema=examples" +MLFLOW_TRACKING_URI = "http://127.0.0.1:5000" + + +logger = logging.getLogger(__name__) + + +@pytest.fixture +def engine(): + yield mlflow.store.db.utils.create_sqlalchemy_engine_with_retry(DB_URI) + + +def test_tracking_merlion(engine: sa.Engine): + _setup_db_drop_tables(engine=engine) + _setup_db_create_tables(engine=engine) + tracking_merlion = Path(__file__).parent.parent.joinpath("examples").joinpath("tracking_merlion.py") + cmd_server = [ + "mlflow-cratedb", + "server", + "--workers=1", + f"--backend-store-uri={DB_URI}", + "--gunicorn-opts='--log-level=debug'", + ] + cmd_client = [ + sys.executable, + tracking_merlion, + ] + + logger.info("Starting server") + with process(cmd_server, stdout=sys.stdout.buffer, stderr=sys.stderr.buffer, close_fds=True) as server_process: + logger.info(f"Started server with process id: {server_process.pid}") + # TODO: Wait for HTTP response. + time.sleep(4) + logger.info("Starting client") + with process( + cmd_client, + env={"MLFLOW_TRACKING_URI": MLFLOW_TRACKING_URI}, + stdout=sys.stdout.buffer, + stderr=sys.stderr.buffer, + ) as client_process: + client_process.wait(timeout=120) + assert client_process.returncode == 0 + + # TODO: Verify database content. diff --git a/tests/util.py b/tests/util.py index 8741ada..201bef3 100644 --- a/tests/util.py +++ b/tests/util.py @@ -1,6 +1,9 @@ # Source: mlflow:tests/integration/utils.py and mlflow:tests/store/tracking/test_file_store.py +import subprocess +from contextlib import contextmanager from typing import List +import psutil from click.testing import CliRunner from mlflow.entities import DatasetInput @@ -28,3 +31,23 @@ def assert_dataset_inputs_equal(inputs1: List[DatasetInput], inputs2: List[Datas tag2 = tags2[idx] assert tag1.key == tag1.key assert tag1.value == tag2.value + + +@contextmanager +def process(*args, **kwargs) -> subprocess.Popen: + """ + Wrapper around `subprocess.Popen` to also terminate child processes after exiting. + + https://gist.github.com/jizhilong/6687481#gistcomment-3057122 + """ + proc = subprocess.Popen(*args, **kwargs) # noqa: S603 + try: + yield proc + finally: + try: + children = psutil.Process(proc.pid).children(recursive=True) + except psutil.NoSuchProcess: + return + for child in children: + child.kill() + proc.kill()