From 3748bf6cd0d43ae88d1b671125c3b9d3936c5e86 Mon Sep 17 00:00:00 2001 From: Stephen Pardy Date: Mon, 1 Jul 2024 14:42:33 -0500 Subject: [PATCH] Polars writing (#460) * Hacking at polars support * Add tests and add polars to env * Add reader and tests --- environment.yml | 1 + rubicon_ml/client/mixin.py | 20 +++++-- rubicon_ml/repository/base.py | 64 ++++++++++++++++----- tests/unit/repository/test_base_repo.py | 75 +++++++++++++++++++++++-- 4 files changed, 135 insertions(+), 25 deletions(-) diff --git a/environment.yml b/environment.yml index ea31b482..2a2aac77 100644 --- a/environment.yml +++ b/environment.yml @@ -39,6 +39,7 @@ dependencies: - pytest - pytest-cov - xgboost + - polars<1.0 # for versioning - versioneer diff --git a/rubicon_ml/client/mixin.py b/rubicon_ml/client/mixin.py index 48c29bc5..4a4ac228 100644 --- a/rubicon_ml/client/mixin.py +++ b/rubicon_ml/client/mixin.py @@ -23,6 +23,7 @@ if TYPE_CHECKING: import dask.dataframe as dd import pandas as pd + import polars as pl from rubicon_ml.client import Artifact, Dataframe from rubicon_ml.domain import DOMAIN_TYPES @@ -306,7 +307,10 @@ def log_pip_requirements(self, artifact_name: Optional[str] = None) -> Artifact: @failsafe def artifacts( - self, name: Optional[str] = None, tags: Optional[List[str]] = None, qtype: str = "or" + self, + name: Optional[str] = None, + tags: Optional[List[str]] = None, + qtype: str = "or", ) -> List[Artifact]: """Get the artifacts logged to this client object. 
@@ -380,7 +384,8 @@ def artifact(self, name: Optional[str] = None, id: Optional[str] = None) -> Arti for repo in self.repositories: try: artifact = client.Artifact( - repo.get_artifact_metadata(project_name, id, experiment_id), self + repo.get_artifact_metadata(project_name, id, experiment_id), + self, ) except Exception as err: return_err = err @@ -455,7 +460,7 @@ class DataframeMixin: @failsafe def log_dataframe( self, - df: Union[pd.DataFrame, dd.DataFrame], + df: Union[pd.DataFrame, "dd.DataFrame", "pl.DataFrame"], description: Optional[str] = None, name: Optional[str] = None, tags: Optional[List[str]] = None, @@ -465,8 +470,8 @@ def log_dataframe( Parameters ---------- - df : pandas.DataFrame or dask.dataframe.DataFrame - The `dask` or `pandas` dataframe to log. + df : pandas.DataFrame, dask.dataframe.DataFrame, or polars DataFrame + The dataframe to log. description : str, optional The dataframe's description. Use to provide additional context. @@ -508,7 +513,10 @@ def log_dataframe( @failsafe def dataframes( - self, name: Optional[str] = None, tags: Optional[List[str]] = None, qtype: str = "or" + self, + name: Optional[str] = None, + tags: Optional[List[str]] = None, + qtype: str = "or", ) -> List[Dataframe]: """Get the dataframes logged to this client object. 
diff --git a/rubicon_ml/repository/base.py b/rubicon_ml/repository/base.py index dca05a81..a5603a5c 100644 --- a/rubicon_ml/repository/base.py +++ b/rubicon_ml/repository/base.py @@ -4,7 +4,7 @@ import warnings from datetime import datetime from json import JSONDecodeError -from typing import Any, Dict, List, Optional, Type +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Type, Union from zipfile import ZipFile import fsspec @@ -14,6 +14,10 @@ from rubicon_ml.exceptions import RubiconException from rubicon_ml.repository.utils import json, slugify +if TYPE_CHECKING: + import dask.dataframe as dd + import polars as pl + class BaseRepository: """The base repository defines all the shared interactions @@ -167,13 +171,13 @@ def _load_metadata_files( # -------- Projects -------- - def _get_project_metadata_path(self, project_name): + def _get_project_metadata_path(self, project_name: str): """Returns the path of the project with name `project_name`'s metadata. """ return f"{self.root_dir}/{slugify(project_name)}/metadata.json" - def create_project(self, project): + def create_project(self, project: domain.Project): """Persist a project to the configured filesystem. Parameters @@ -221,7 +225,7 @@ def get_projects(self) -> List[domain.Project]: # ------ Experiments ------- - def _get_experiment_metadata_root(self, project_name): + def _get_experiment_metadata_root(self, project_name: str): """Returns the experiments directory of the project with name `project_name`. """ @@ -235,7 +239,7 @@ def _get_experiment_metadata_path(self, project_name, experiment_id): return f"{experiment_metadata_root}/{experiment_id}/metadata.json" - def create_experiment(self, experiment): + def create_experiment(self, experiment: domain.Experiment): """Persist an experiment to the configured filesystem. 
Parameters @@ -589,7 +593,9 @@ def _get_dataframe_data_path(self, project_name, experiment_id, dataframe_id): return f"{dataframe_metadata_root}/{dataframe_id}/data" - def _persist_dataframe(self, df, path): + def _persist_dataframe( + self, df: Union[pd.DataFrame, "dd.DataFrame", "pl.DataFrame"], path: str + ): """Persists the dataframe `df` to the configured filesystem. Note @@ -602,19 +608,33 @@ def _persist_dataframe(self, df, path): self._mkdir(path) path = f"{path}/data.parquet" - df.to_parquet(path, engine="pyarrow") + if hasattr(df, "write_parquet"): + # handle Polars + df.write_parquet(path) + else: + # Dask or pandas + df.to_parquet(path, engine="pyarrow") - def _read_dataframe(self, path, df_type="pandas"): + def _read_dataframe(self, path, df_type: Literal["pandas", "dask", "polars"] = "pandas"): """Reads the dataframe `df` from the configured filesystem.""" df = None - acceptable_types = ["pandas", "dask"] - if df_type not in acceptable_types: - raise ValueError(f"`df_type` must be one of {acceptable_types}") + acceptable_types = ["pandas", "dask", "polars"] if df_type == "pandas": path = f"{path}/data.parquet" df = pd.read_parquet(path, engine="pyarrow") - else: + elif df_type == "polars": + try: + from polars import read_parquet + except ImportError: + raise RubiconException( + "`rubicon_ml` requires `polars` to be installed in the current environment " + "to read dataframes with `df_type`='polars'. `pip install polars` " + "or `conda install polars` to continue." 
+ ) + df = read_parquet(path) + + elif df_type == "dask": try: from dask import dataframe as dd except ImportError: @@ -625,10 +645,18 @@ def _read_dataframe(self, path, df_type="pandas"): ) df = dd.read_parquet(path, engine="pyarrow") + else: + raise ValueError(f"`df_type` must be one of {acceptable_types}") return df - def create_dataframe(self, dataframe, data, project_name, experiment_id=None): + def create_dataframe( + self, + dataframe: domain.Dataframe, + data: Union[pd.DataFrame, "dd.DataFrame", "pl.DataFrame"], + project_name: str, + experiment_id: Optional[str] = None, + ): """Persist a dataframe to the configured filesystem. Parameters @@ -709,7 +737,13 @@ def get_dataframes_metadata( return self._load_metadata_files(dataframe_metadata_root, domain.Dataframe) - def get_dataframe_data(self, project_name, dataframe_id, experiment_id=None, df_type="pandas"): + def get_dataframe_data( + self, + project_name: str, + dataframe_id: str, + experiment_id: Optional[str] = None, + df_type: Literal["pandas", "dask", "polars"] = "pandas", + ): """Retrieve a dataframe's raw data. Parameters @@ -724,7 +758,7 @@ def get_dataframe_data(self, project_name, dataframe_id, experiment_id=None, df_ `artifact_id` is logged to. Dataframes do not need to belong to an experiment. df_type : str, optional - The type of dataframe. Can be either `pandas` or `dask`. + The type of dataframe. Can be `pandas`, `dask`, or `polars`. 
Returns ------- diff --git a/tests/unit/repository/test_base_repo.py b/tests/unit/repository/test_base_repo.py index b3086811..df3b9c88 100644 --- a/tests/unit/repository/test_base_repo.py +++ b/tests/unit/repository/test_base_repo.py @@ -3,6 +3,7 @@ from unittest.mock import patch import pandas as pd +import polars as pl import pytest from dask import dataframe as dd @@ -26,7 +27,10 @@ def _create_experiment(repository, project=None, tags=[], comments=[]): project = _create_project(repository) experiment = domain.Experiment( - name=f"Test Experiment {uuid.uuid4()}", project_name=project.name, tags=[], comments=[] + name=f"Test Experiment {uuid.uuid4()}", + project_name=project.name, + tags=[], + comments=[], ) repository.create_experiment(experiment) @@ -52,7 +56,8 @@ def _create_pandas_dataframe(repository, project=None, dataframe_data=None, mult if dataframe_data is None: dataframe_data = pd.DataFrame( - [[0, 1, "a"], [1, 1, "b"], [2, 2, "c"], [3, 2, "d"]], columns=["a", "b", "c"] + [[0, 1, "a"], [1, 1, "b"], [2, 2, "c"], [3, 2, "d"]], + columns=["a", "b", "c"], ) if multi_index: dataframe_data = dataframe_data.set_index(["b", "a"]) # Set multiindex @@ -76,6 +81,24 @@ def _create_dask_dataframe(repository, project=None): return dataframe +def _create_polars_dataframe(repository, project=None): + if project is None: + project = _create_project(repository) + + df = pl.DataFrame( + { + "a": [0, 1, 2, 3], + "b": [1, 1, 2, 2], + "c": ["a", "b", "c", "d"], + } + ) + + dataframe = domain.Dataframe(parent_id=project.id) + repository.create_dataframe(dataframe, df, project.name) + + return dataframe + + def _create_feature(repository, experiment=None): if experiment is None: experiment = _create_experiment(repository) @@ -383,6 +406,18 @@ def test_persist_dataframe(mock_to_parquet, memory_repository): mock_to_parquet.assert_called_once_with(f"{path}/data.parquet", engine="pyarrow") +@patch("polars.DataFrame.write_parquet") +def 
test_persist_dataframe_polars(mock_write_parquet, memory_repository): + repository = memory_repository + df = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + path = "./local/root" + + # calls `BaseRepository._persist_dataframe` despite class using `MemoryRepository` + super(MemoryRepository, repository)._persist_dataframe(df, path) + + mock_write_parquet.assert_called_once_with(f"{path}/data.parquet") + + @patch("pandas.read_parquet") def test_read_dataframe(mock_read_parquet, memory_repository): repository = memory_repository @@ -488,6 +523,36 @@ def test_create_dask_dataframe(memory_repository): assert dataframe.id == dataframe_json["id"] + +def test_create_polars_dataframe(memory_repository): + repository = memory_repository + project = _create_project(repository) + dataframe = _create_polars_dataframe(repository, project=project) + + dataframe_root = f"{repository.root_dir}/{slugify(project.name)}/dataframes/{dataframe.id}" + dataframe_metadata_path = f"{dataframe_root}/metadata.json" + dataframe_data_path = f"{dataframe_root}/data" + + open_file = repository.filesystem.open(dataframe_metadata_path) + with open_file as f: + dataframe_json = json.load(f) + + assert repository.filesystem.exists(dataframe_data_path) + assert dataframe.id == dataframe_json["id"] + + +def test_get_polars_dataframe(memory_repository): + repository = memory_repository + project = _create_project(repository) + written_dataframe = _create_polars_dataframe(repository, project=project) + dataframe = repository.get_dataframe_metadata(project.name, written_dataframe.id) + + data = repository.get_dataframe_data(project.name, written_dataframe.id, df_type="polars") + assert not data.is_empty() + + assert dataframe.id == written_dataframe.id + assert dataframe.parent_id == written_dataframe.parent_id + + def test_get_pandas_dataframe(memory_repository): repository = memory_repository project = _create_project(repository) @@ -520,7 +585,7 @@ def test_get_dask_dataframe(memory_repository): written_dataframe = 
_create_dask_dataframe(repository, project=project) dataframe = repository.get_dataframe_metadata(project.name, written_dataframe.id) - data = repository.get_dataframe_data(project.name, written_dataframe.id) + data = repository.get_dataframe_data(project.name, written_dataframe.id, df_type="dask") assert not data.compute().empty assert dataframe.id == written_dataframe.id @@ -858,7 +923,9 @@ def test_get_dataframe_tags_with_project_parent_root(memory_repository): project = _create_project(repository) dataframe = _create_pandas_dataframe(repository, project=project) dataframe_tags_root = repository._get_tag_metadata_root( - project.name, entity_identifier=dataframe.id, entity_type=dataframe.__class__.__name__ + project.name, + entity_identifier=dataframe.id, + entity_type=dataframe.__class__.__name__, ) assert (