diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..6d2a0ef
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,96 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+pytest-junit.xml
+
+# Sphinx documentation
+docs/_build/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+.python-version
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+mypy-report.xml
+
+# Pyre type checker
+.pyre/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+.idea/
+
+# VS Code
+.vscode/
diff --git a/pyproject.toml b/pyproject.toml
index 1d2c5ea..e649d2f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,11 +12,11 @@ dependencies = [
     "dagster>=1.7.4",
     "mlflow>=2.12.1",
     "joblib>=1.4.2",
-    "lakefs-spec>=0.9.0"
+    "lakefs-spec>=0.9.0",
 ]
 
 [project.optional-dependencies]
-dev = ["black", "ruff", "ruff-lsp"]
+dev = ["ruff"]
 
 [tool.ruff]
 src = ["src"]
diff --git a/requirements-dev.txt b/requirements-dev.txt
new file mode 100644
index 0000000..d8eddc7
--- /dev/null
+++ b/requirements-dev.txt
@@ -0,0 +1,95 @@
+# This file was autogenerated by uv via the following command:
+#    uv pip compile -o requirements-dev.txt --no-annotate --extra=dev pyproject.toml
+aenum==3.1.15
+alembic==1.13.1
+aniso8601==9.0.1
+annotated-types==0.7.0
+blinker==1.8.2
+cachetools==5.3.3
+certifi==2024.6.2
+charset-normalizer==3.3.2
+click==8.1.7
+cloudpickle==3.0.0
+coloredlogs==14.0
+contourpy==1.2.1
+croniter==2.0.5
+cycler==0.12.1
+dagster==1.7.8
+dagster-pipes==1.7.8
+deprecated==1.2.14
+docker==7.1.0
+docstring-parser==0.16
+entrypoints==0.4
+filelock==3.14.0
+flask==3.0.3
+fonttools==4.53.0
+fsspec==2024.6.0
+gitdb==4.0.11
+gitpython==3.1.43
+graphene==3.3
+graphql-core==3.2.3
+graphql-relay==3.2.0
+grpcio==1.64.1
+grpcio-health-checking==1.62.2
+gunicorn==22.0.0
+humanfriendly==10.0
+idna==3.7
+importlib-metadata==7.1.0
+itsdangerous==2.2.0
+jinja2==3.1.4
+joblib==1.4.2
+kiwisolver==1.4.5
+lakefs==0.6.2
+lakefs-sdk==1.25.0
+lakefs-spec==0.9.0
+mako==1.3.5
+markdown==3.6
+markdown-it-py==3.0.0
+markupsafe==2.1.5
+matplotlib==3.9.0
+mdurl==0.1.2
+mlflow==2.13.1
+numpy==1.26.4
+opentelemetry-api==1.25.0
+opentelemetry-sdk==1.25.0
+opentelemetry-semantic-conventions==0.46b0
+packaging==24.0
+pandas==2.2.2
+pendulum==3.0.0
+pillow==10.3.0
+protobuf==4.25.3
+pyarrow==15.0.2
+pydantic==2.7.3
+pydantic-core==2.18.4
+pygments==2.18.0
+pyparsing==3.1.2
+python-dateutil==2.9.0.post0
+python-dotenv==1.0.1
+pytz==2024.1
+pyyaml==6.0.1
+querystring-parser==1.2.4
+requests==2.32.3
+rich==13.7.1
+ruff==0.4.7
+scikit-learn==1.5.0
+scipy==1.13.1
+setuptools==70.0.0
+six==1.16.0
+smmap==5.0.1
+sqlalchemy==2.0.30
+sqlparse==0.5.0
+structlog==24.2.0
+tabulate==0.9.0
+threadpoolctl==3.5.0
+time-machine==2.14.1
+tomli==2.0.1
+toposort==1.10
+tqdm==4.66.4
+typing-extensions==4.12.1
+tzdata==2024.1
+universal-pathlib==0.2.2
+urllib3==2.0.7
+watchdog==4.0.1
+werkzeug==3.0.3
+wrapt==1.16.0
+zipp==3.19.1
diff --git a/requirements.txt b/requirements.txt
index b8410d2..bce739e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,281 +1,94 @@
-#
-# This file is autogenerated by pip-compile with Python 3.11
-# by the following command:
-#
-#    pip-compile --extra=dev
-#
+# This file was autogenerated by uv via the following command:
+#    uv pip compile -o requirements.txt --no-annotate --strip-extras pyproject.toml
 aenum==3.1.15
-    # via lakefs-sdk
 alembic==1.13.1
-    # via
-    #   dagster
-    #   mlflow
 aniso8601==9.0.1
-    # via graphene
 annotated-types==0.6.0
-    # via pydantic
-attrs==23.2.0
-    # via
-    #   cattrs
-    #   lsprotocol
-black==24.4.2
-    # via tentacles (pyproject.toml)
 blinker==1.8.1
-    # via flask
-cattrs==23.2.3
-    # via
-    #   lsprotocol
-    #   pygls
+cachetools==5.3.3
 certifi==2024.2.2
-    # via requests
 charset-normalizer==3.3.2
-    # via requests
 click==8.1.7
-    # via
-    #   black
-    #   dagster
-    #   flask
-    #   mlflow
 cloudpickle==3.0.0
-    # via mlflow
 coloredlogs==14.0
-    # via dagster
 contourpy==1.2.1
-    # via matplotlib
 croniter==2.0.5
-    # via dagster
 cycler==0.12.1
-    # via matplotlib
-dagster==1.7.4
-    # via tentacles (pyproject.toml)
-dagster-pipes==1.7.4
-    # via dagster
+dagster==1.7.8
+dagster-pipes==1.7.8
+deprecated==1.2.14
 docker==7.0.0
-    # via mlflow
 docstring-parser==0.16
-    # via dagster
 entrypoints==0.4
-    # via mlflow
 filelock==3.14.0
-    # via dagster
 flask==3.0.3
-    # via mlflow
 fonttools==4.51.0
-    # via matplotlib
 fsspec==2024.3.1
-    # via
-    #   lakefs-spec
-    #   universal-pathlib
 gitdb==4.0.11
-    # via gitpython
 gitpython==3.1.43
-    # via mlflow
 graphene==3.3
-    # via mlflow
 graphql-core==3.2.3
-    # via
-    #   graphene
-    #   graphql-relay
 graphql-relay==3.2.0
-    # via graphene
 grpcio==1.63.0
-    # via
-    #   dagster
-    #   grpcio-health-checking
 grpcio-health-checking==1.62.2
-    # via dagster
 gunicorn==21.2.0
-    # via mlflow
 humanfriendly==10.0
-    # via coloredlogs
 idna==3.7
-    # via requests
 importlib-metadata==7.1.0
-    # via mlflow
 itsdangerous==2.2.0
-    # via flask
 jinja2==3.1.3
-    # via
-    #   dagster
-    #   flask
-    #   mlflow
 joblib==1.4.2
-    # via
-    #   scikit-learn
-    #   tentacles (pyproject.toml)
 kiwisolver==1.4.5
-    # via matplotlib
-lakefs==0.6.0
-    # via lakefs-spec
-lakefs-sdk==1.21.0
-    # via lakefs
+lakefs==0.6.2
+lakefs-sdk==1.25.0
 lakefs-spec==0.9.0
-    # via tentacles (pyproject.toml)
-lsprotocol==2023.0.1
-    # via
-    #   pygls
-    #   ruff-lsp
 mako==1.3.3
-    # via alembic
 markdown==3.6
-    # via mlflow
 markdown-it-py==3.0.0
-    # via rich
 markupsafe==2.1.5
-    # via
-    #   jinja2
-    #   mako
-    #   werkzeug
 matplotlib==3.8.4
-    # via mlflow
 mdurl==0.1.2
-    # via markdown-it-py
-mlflow==2.12.1
-    # via tentacles (pyproject.toml)
-mypy-extensions==1.0.0
-    # via black
+mlflow==2.13.1
 numpy==1.26.4
-    # via
-    #   contourpy
-    #   matplotlib
-    #   mlflow
-    #   pandas
-    #   pyarrow
-    #   scikit-learn
-    #   scipy
+opentelemetry-api==1.25.0
+opentelemetry-sdk==1.25.0
+opentelemetry-semantic-conventions==0.46b0
 packaging==24.0
-    # via
-    #   black
-    #   dagster
-    #   docker
-    #   gunicorn
-    #   matplotlib
-    #   mlflow
-    #   ruff-lsp
 pandas==2.2.2
-    # via mlflow
-pathspec==0.12.1
-    # via black
 pendulum==3.0.0
-    # via dagster
 pillow==10.3.0
-    # via matplotlib
-platformdirs==4.2.1
-    # via black
 protobuf==4.25.3
-    # via
-    #   dagster
-    #   grpcio-health-checking
-    #   mlflow
 pyarrow==15.0.2
-    # via mlflow
 pydantic==2.7.1
-    # via
-    #   dagster
-    #   lakefs-sdk
 pydantic-core==2.18.2
-    # via pydantic
-pygls==1.3.1
-    # via ruff-lsp
 pygments==2.17.2
-    # via rich
 pyparsing==3.1.2
-    # via matplotlib
 python-dateutil==2.9.0.post0
-    # via
-    #   croniter
-    #   dagster
-    #   lakefs-sdk
-    #   matplotlib
-    #   pandas
-    #   pendulum
-    #   time-machine
 python-dotenv==1.0.1
-    # via dagster
 pytz==2024.1
-    # via
-    #   croniter
-    #   dagster
-    #   mlflow
-    #   pandas
 pyyaml==6.0.1
-    # via
-    #   dagster
-    #   lakefs
-    #   mlflow
 querystring-parser==1.2.4
-    # via mlflow
 requests==2.31.0
-    # via
-    #   dagster
-    #   docker
-    #   mlflow
 rich==13.7.1
-    # via dagster
-ruff==0.4.2
-    # via
-    #   ruff-lsp
-    #   tentacles (pyproject.toml)
-ruff-lsp==0.0.53
-    # via tentacles (pyproject.toml)
 scikit-learn==1.4.2
-    # via mlflow
 scipy==1.13.0
-    # via
-    #   mlflow
-    #   scikit-learn
+setuptools==70.0.0
 six==1.16.0
-    # via
-    #   python-dateutil
-    #   querystring-parser
 smmap==5.0.1
-    # via gitdb
 sqlalchemy==2.0.29
-    # via
-    #   alembic
-    #   dagster
-    #   mlflow
 sqlparse==0.5.0
-    # via mlflow
 structlog==24.1.0
-    # via dagster
 tabulate==0.9.0
-    # via dagster
 threadpoolctl==3.5.0
-    # via scikit-learn
 time-machine==2.14.1
-    # via pendulum
 tomli==2.0.1
-    # via dagster
 toposort==1.10
-    # via dagster
 tqdm==4.66.4
-    # via dagster
 typing-extensions==4.11.0
-    # via
-    #   alembic
-    #   dagster
-    #   pydantic
-    #   pydantic-core
-    #   ruff-lsp
-    #   sqlalchemy
 tzdata==2024.1
-    # via
-    #   pandas
-    #   pendulum
 universal-pathlib==0.2.2
-    # via dagster
 urllib3==2.0.7
-    # via
-    #   docker
-    #   lakefs-sdk
-    #   requests
 watchdog==4.0.0
-    # via dagster
 werkzeug==3.0.2
-    # via flask
+wrapt==1.16.0
 zipp==3.18.1
-    # via importlib-metadata
-
-# The following packages are considered to be unsafe in a requirements file:
-# setuptools
diff --git a/src/tentacles/resources/mlflow_session.py b/src/tentacles/resources/mlflow_session.py
index 0c23126..4ddf03d 100644
--- a/src/tentacles/resources/mlflow_session.py
+++ b/src/tentacles/resources/mlflow_session.py
@@ -22,12 +22,21 @@ class MlflowSession(ConfigurableResource):
         Optional password for authenticating against the MLflow tracking server.
     experiment : str
         Experiment name.
+    use_asset_run_key : bool
+        Whether the Dagster asset key should be included in the MLflow run name.
+    use_dagster_run_id : bool
+        Whether the Dagster run ID should be included in the MLflow run name.
+    run_name_prefix : Optional[str]
+        Optional prefix for the MLflow run name.
     """
 
     tracking_url: str
     username: Optional[str]
     password: Optional[str]
     experiment: str
+    use_asset_run_key: bool = True
+    use_dagster_run_id: bool = True
+    run_name_prefix: Optional[str] = None
 
     def setup_for_execution(self, context: InitResourceContext) -> None:
         """Setup the resource.
@@ -53,8 +62,8 @@ def _get_run_name_from_context(
         The run name is constructed as follows:
 
         - The run name prefix (when provided)
-        - The asset key name
-        - The run identifier
+        - The asset key name (when ``use_asset_run_key`` is enabled)
+        - The run identifier (when ``use_dagster_run_id`` is enabled)
 
         Parameters
         ----------
@@ -68,14 +77,30 @@
         str
             Run name
         """
-        asset_key = get_asset_key(context)
-        dagster_run_id = get_run_id(context, short=True)
-        run_name = f"{asset_key}-{dagster_run_id}"
+        if run_name_prefix is None:
+            run_name_prefix = self.run_name_prefix
+
+        parts: list[str] = []
+
+        if self.use_asset_run_key:
+            asset_key = get_asset_key(context)
+            if isinstance(asset_key, list):
+                raise ValueError("Cannot derive an MLflow run name for multi-assets.")
+
+            parts.append(asset_key)
+
+        if self.use_dagster_run_id:
+            run_id = get_run_id(context, short=True)
+            parts.append(run_id)
+
         if run_name_prefix is not None:
-            run_name = f"{run_name_prefix}-{run_name}"
+            # The prefix comes first, as documented above.
+            parts.insert(0, run_name_prefix)
+
+        if len(parts) == 0:
+            raise ValueError("Could not derive an MLflow run name.")
 
-        return run_name
+        return "-".join(parts)
 
     def get_run(
         self,
@@ -113,7 +138,12 @@
             return mlflow.start_run(run_id=run_id, run_name=run_name)
         else:
             tags["dagster.run_id"] = get_run_id(context)
-            tags["dagster.asset_name"] = get_asset_key(context)
+
+            # MLflow tags must be strings; join multi-asset keys into a comma-separated string.
+            asset_key = get_asset_key(context)
+            if isinstance(asset_key, list):
+                asset_key = ",".join(asset_key)
+            tags["dagster.asset_name"] = asset_key
 
             return mlflow.start_run(run_name=run_name, tags=tags)
 
diff --git a/src/tentacles/utils/dagster.py b/src/tentacles/utils/dagster.py
index 4260f98..2decbb0 100644
--- a/src/tentacles/utils/dagster.py
+++ b/src/tentacles/utils/dagster.py
@@ -28,7 +28,7 @@ def get_run_id(context: AssetExecutionContext, short: bool = False) -> str:
     return run_id[:SHORT_RUN_ID_LENGTH] if short else run_id
 
 
-def get_asset_key(context: AssetExecutionContext) -> str:
+def get_asset_key(context: AssetExecutionContext) -> str | list[str]:
     """Get the asset key from the dagster context.
 
     Parameters
@@ -41,12 +41,18 @@
-    str
-        Asset key.
+    str | list[str]
+        Asset key, or a list of asset keys when the context belongs to a multi-asset.
     """
-    return context.asset_key.to_user_string()
+    outputs = context.selected_output_names
+    if len(outputs) == 1:
+        return context.asset_key.to_user_string()
+    return [
+        context.asset_key_for_output(o).to_user_string()
+        for o in outputs
+    ]
 
 
 def get_metadata(context: Union[OutputContext, InputContext]) -> dict:
     """Get metadata from the dagster context.
-    
+
     Parameters
     ----------
     context : Union[OutputContext, InputContext]
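
Usage sketch (reviewer note, not part of the patch): the new MlflowSession fields above control how MLflow run names are derived. Below is a minimal, hypothetical example of wiring the resource into a Dagster code location; the tracking URL, experiment name, asset name, and the exact get_run call are illustrative assumptions, not part of this change.

    # Hypothetical wiring -- names and values are assumptions for illustration only.
    from dagster import AssetExecutionContext, Definitions, asset

    from tentacles.resources.mlflow_session import MlflowSession

    mlflow_session = MlflowSession(
        tracking_url="http://localhost:5000",  # assumed local tracking server
        username=None,
        password=None,
        experiment="demo-experiment",  # hypothetical experiment name
        use_asset_run_key=True,        # include the asset key in the run name
        use_dagster_run_id=False,      # leave the Dagster run ID out of the run name
        run_name_prefix="nightly",     # prefix is prepended: "nightly-train_model"
    )

    @asset
    def train_model(context: AssetExecutionContext, mlflow_session: MlflowSession) -> None:
        # get_run() wraps mlflow.start_run(), so it can be used as a context manager;
        # its full parameter list is not shown in this diff.
        with mlflow_session.get_run(context):
            ...

    defs = Definitions(
        assets=[train_model],
        resources={"mlflow_session": mlflow_session},
    )

With this configuration, _get_run_name_from_context yields "nightly-train_model"; enabling use_dagster_run_id appends the shortened Dagster run ID as well.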
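
For multi-assets, the updated get_asset_key returns a list of keys, so get_run tags the MLflow run with a comma-separated string, while _get_run_name_from_context raises ValueError if use_asset_run_key is still enabled (multi-assets therefore need use_asset_run_key=False plus a prefix and/or the run ID). A hypothetical multi-asset illustrating that path; the output names are assumptions:

    # Hypothetical multi-asset -- output names chosen only for illustration.
    from dagster import AssetExecutionContext, AssetOut, multi_asset

    from tentacles.utils.dagster import get_asset_key

    @multi_asset(outs={"features": AssetOut(), "labels": AssetOut()})
    def split_dataset(context: AssetExecutionContext):
        # get_asset_key(context) returns both keys here, e.g. ["features", "labels"],
        # so the "dagster.asset_name" MLflow tag becomes "features,labels".
        print(get_asset_key(context))
        return 1, 2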