Polars writing (#460)
* Hacking at polars support

* Add tests and add polars to env

* Add reader and tests
stephenpardy authored Jul 1, 2024
1 parent 8dc219d commit 3748bf6
Showing 4 changed files with 135 additions and 25 deletions.
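
Taken together, these changes let a Polars dataframe flow through the same logging and retrieval path as pandas and Dask dataframes. A minimal usage sketch, assuming the standard rubicon_ml client entry points and assuming the dataframe client's `get_data` forwards the new `df_type` option added to the repository below:

    import polars as pl
    from rubicon_ml import Rubicon

    rubicon = Rubicon(persistence="memory")
    project = rubicon.get_or_create_project("polars demo")

    # log a Polars dataframe exactly like a pandas or Dask one
    df = pl.DataFrame({"a": [0, 1, 2, 3], "b": [1, 1, 2, 2], "c": ["a", "b", "c", "d"]})
    project.log_dataframe(df, name="training data", tags=["polars"])

    # read it back, materialized as Polars via the new `df_type` option
    retrieved = project.dataframes(name="training data")[0]
    assert not retrieved.get_data(df_type="polars").is_empty()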
1 change: 1 addition & 0 deletions environment.yml
@@ -39,6 +39,7 @@ dependencies:
- pytest
- pytest-cov
- xgboost
- polars<1.0

# for versioning
- versioneer
20 changes: 14 additions & 6 deletions rubicon_ml/client/mixin.py
@@ -23,6 +23,7 @@
if TYPE_CHECKING:
import dask.dataframe as dd
import pandas as pd
import polars as pl

from rubicon_ml.client import Artifact, Dataframe
from rubicon_ml.domain import DOMAIN_TYPES
@@ -306,7 +307,10 @@ def log_pip_requirements(self, artifact_name: Optional[str] = None) -> Artifact:

@failsafe
def artifacts(
self, name: Optional[str] = None, tags: Optional[List[str]] = None, qtype: str = "or"
self,
name: Optional[str] = None,
tags: Optional[List[str]] = None,
qtype: str = "or",
) -> List[Artifact]:
"""Get the artifacts logged to this client object.
@@ -380,7 +384,8 @@ def artifact(self, name: Optional[str] = None, id: Optional[str] = None) -> Arti
for repo in self.repositories:
try:
artifact = client.Artifact(
repo.get_artifact_metadata(project_name, id, experiment_id), self
repo.get_artifact_metadata(project_name, id, experiment_id),
self,
)
except Exception as err:
return_err = err
@@ -455,7 +460,7 @@ class DataframeMixin:
@failsafe
def log_dataframe(
self,
df: Union[pd.DataFrame, dd.DataFrame],
df: Union[pd.DataFrame, "dd.DataFrame", "pl.DataFrame"],
description: Optional[str] = None,
name: Optional[str] = None,
tags: Optional[List[str]] = None,
@@ -465,8 +470,8 @@ def log_dataframe(
Parameters
----------
df : pandas.DataFrame or dask.dataframe.DataFrame
The `dask` or `pandas` dataframe to log.
df : pandas.DataFrame, dask.dataframe.DataFrame, or polars DataFrame
The dataframe to log.
description : str, optional
The dataframe's description. Use to provide
additional context.
@@ -508,7 +513,10 @@ def log_dataframe(

@failsafe
def dataframes(
self, name: Optional[str] = None, tags: Optional[List[str]] = None, qtype: str = "or"
self,
name: Optional[str] = None,
tags: Optional[List[str]] = None,
qtype: str = "or",
) -> List[Dataframe]:
"""Get the dataframes logged to this client object.
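
Note that `polars` is only imported inside the `TYPE_CHECKING` block above, so the new `pl.DataFrame` annotations do not make Polars a hard runtime dependency. A minimal sketch of that pattern in isolation (the `log_frame` function and module layout are illustrative, not rubicon_ml's actual code):

    from __future__ import annotations

    from typing import TYPE_CHECKING, Union

    if TYPE_CHECKING:
        # visible to type checkers only; neither library is required at runtime
        import pandas as pd
        import polars as pl

    def log_frame(df: Union["pd.DataFrame", "pl.DataFrame"]) -> None:
        # string (or postponed) annotations are never evaluated at import time,
        # so this module imports cleanly even without pandas or polars installed
        print(f"logging a {type(df).__name__}")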
64 changes: 49 additions & 15 deletions rubicon_ml/repository/base.py
@@ -4,7 +4,7 @@
import warnings
from datetime import datetime
from json import JSONDecodeError
from typing import Any, Dict, List, Optional, Type
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Type, Union
from zipfile import ZipFile

import fsspec
@@ -14,6 +14,10 @@
from rubicon_ml.exceptions import RubiconException
from rubicon_ml.repository.utils import json, slugify

if TYPE_CHECKING:
import dask.dataframe as dd
import polars as pl


class BaseRepository:
"""The base repository defines all the shared interactions
@@ -167,13 +171,13 @@ def _load_metadata_files(

# -------- Projects --------

def _get_project_metadata_path(self, project_name):
def _get_project_metadata_path(self, project_name: str):
"""Returns the path of the project with name `project_name`'s
metadata.
"""
return f"{self.root_dir}/{slugify(project_name)}/metadata.json"

def create_project(self, project):
def create_project(self, project: domain.Project):
"""Persist a project to the configured filesystem.
Parameters
@@ -221,7 +225,7 @@ def get_projects(self) -> List[domain.Project]:

# ------ Experiments -------

def _get_experiment_metadata_root(self, project_name):
def _get_experiment_metadata_root(self, project_name: str):
"""Returns the experiments directory of the project with
name `project_name`.
"""
@@ -235,7 +239,7 @@ def _get_experiment_metadata_path(self, project_name, experiment_id):

return f"{experiment_metadata_root}/{experiment_id}/metadata.json"

def create_experiment(self, experiment):
def create_experiment(self, experiment: domain.Experiment):
"""Persist an experiment to the configured filesystem.
Parameters
@@ -589,7 +593,9 @@ def _get_dataframe_data_path(self, project_name, experiment_id, dataframe_id):

return f"{dataframe_metadata_root}/{dataframe_id}/data"

def _persist_dataframe(self, df, path):
def _persist_dataframe(
self, df: Union[pd.DataFrame, "dd.DataFrame", "pl.DataFrame"], path: str
):
"""Persists the dataframe `df` to the configured filesystem.
Note
@@ -602,19 +608,33 @@ def _persist_dataframe(self, df, path):
self._mkdir(path)
path = f"{path}/data.parquet"

df.to_parquet(path, engine="pyarrow")
if hasattr(df, "write_parquet"):
# handle Polars
df.write_parquet(path)
else:
# Dask or pandas
df.to_parquet(path, engine="pyarrow")

def _read_dataframe(self, path, df_type="pandas"):
def _read_dataframe(self, path, df_type: Literal["pandas", "dask", "polars"] = "pandas"):
"""Reads the dataframe `df` from the configured filesystem."""
df = None
acceptable_types = ["pandas", "dask"]
if df_type not in acceptable_types:
raise ValueError(f"`df_type` must be one of {acceptable_types}")
acceptable_types = ["pandas", "dask", "polars"]

if df_type == "pandas":
path = f"{path}/data.parquet"
df = pd.read_parquet(path, engine="pyarrow")
else:
elif df_type == "polars":
try:
from polars import read_parquet
except ImportError:
raise RubiconException(
"`rubicon_ml` requires `polars` to be installed in the current environment "
"to read dataframes with `df_type`='polars'. `pip install polars` "
"or `conda install polars` to continue."
)
df = read_parquet(path)

elif df_type == "dask":
try:
from dask import dataframe as dd
except ImportError:
@@ -625,10 +645,18 @@ def _read_dataframe(self, path, df_type="pandas"):
)

df = dd.read_parquet(path, engine="pyarrow")
else:
raise ValueError(f"`df_type` must be one of {acceptable_types}")

return df

def create_dataframe(self, dataframe, data, project_name, experiment_id=None):
def create_dataframe(
self,
dataframe: domain.Dataframe,
data: Union[pd.DataFrame, "dd.DataFrame", "pl.DataFrame"],
project_name: str,
experiment_id: Optional[str] = None,
):
"""Persist a dataframe to the configured filesystem.
Parameters
@@ -709,7 +737,13 @@ def get_dataframes_metadata(

return self._load_metadata_files(dataframe_metadata_root, domain.Dataframe)

def get_dataframe_data(self, project_name, dataframe_id, experiment_id=None, df_type="pandas"):
def get_dataframe_data(
self,
project_name: str,
dataframe_id: str,
experiment_id: Optional[str] = None,
df_type: Literal["pandas", "dask", "polars"] = "pandas",
):
"""Retrieve a dataframe's raw data.
Parameters
@@ -724,7 +758,7 @@ def get_dataframe_data(self, project_name, dataframe_id, experiment_id=None, df_
`artifact_id` is logged to. Dataframes do not
need to belong to an experiment.
df_type : str, optional
The type of dataframe. Can be either `pandas` or `dask`.
The type of dataframe. Can be `pandas`, `dask`, or `polars`.
Returns
-------
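
On the repository side, writing dispatches on the dataframe's own API (Polars frames expose `write_parquet`, pandas and Dask frames expose `to_parquet`), and reading dispatches on the new `df_type` argument with a lazy import of the optional backend. A standalone sketch of that round trip, with hypothetical helper names, single-file paths, and only the behavior shown in the diff above:

    import pandas as pd
    import polars as pl

    def persist_any_dataframe(df, path: str) -> None:
        # mirrors `_persist_dataframe`: duck-type on the writer method
        if hasattr(df, "write_parquet"):
            df.write_parquet(path)  # Polars
        else:
            df.to_parquet(path, engine="pyarrow")  # pandas or Dask

    def read_any_dataframe(path: str, df_type: str = "pandas"):
        # mirrors `_read_dataframe`: import the optional backend only when requested
        if df_type == "pandas":
            return pd.read_parquet(path, engine="pyarrow")
        if df_type == "polars":
            import polars
            return polars.read_parquet(path)
        if df_type == "dask":
            from dask import dataframe as dd
            return dd.read_parquet(path, engine="pyarrow")
        raise ValueError("`df_type` must be one of ['pandas', 'dask', 'polars']")

    persist_any_dataframe(pl.DataFrame({"a": [1, 2], "b": [3, 4]}), "/tmp/example.parquet")
    print(read_any_dataframe("/tmp/example.parquet", df_type="polars"))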
75 changes: 71 additions & 4 deletions tests/unit/repository/test_base_repo.py
@@ -3,6 +3,7 @@
from unittest.mock import patch

import pandas as pd
import polars as pl
import pytest
from dask import dataframe as dd

@@ -26,7 +27,10 @@ def _create_experiment(repository, project=None, tags=[], comments=[]):
project = _create_project(repository)

experiment = domain.Experiment(
name=f"Test Experiment {uuid.uuid4()}", project_name=project.name, tags=[], comments=[]
name=f"Test Experiment {uuid.uuid4()}",
project_name=project.name,
tags=[],
comments=[],
)
repository.create_experiment(experiment)

@@ -52,7 +56,8 @@ def _create_pandas_dataframe(repository, project=None, dataframe_data=None, mult

if dataframe_data is None:
dataframe_data = pd.DataFrame(
[[0, 1, "a"], [1, 1, "b"], [2, 2, "c"], [3, 2, "d"]], columns=["a", "b", "c"]
[[0, 1, "a"], [1, 1, "b"], [2, 2, "c"], [3, 2, "d"]],
columns=["a", "b", "c"],
)
if multi_index:
dataframe_data = dataframe_data.set_index(["b", "a"]) # Set multiindex
@@ -76,6 +81,24 @@ def _create_dask_dataframe(repository, project=None):
return dataframe


def _create_polars_dataframe(repository, project=None):
if project is None:
project = _create_project(repository)

df = pl.DataFrame(
{
"a": [0, 1, 2, 3],
"b": [1, 1, 2, 2],
"c": ["a", "b", "c", "d"],
}
)

dataframe = domain.Dataframe(parent_id=project.id)
repository.create_dataframe(dataframe, df, project.name)

return dataframe


def _create_feature(repository, experiment=None):
if experiment is None:
experiment = _create_experiment(repository)
@@ -383,6 +406,18 @@ def test_persist_dataframe(mock_to_parquet, memory_repository):
mock_to_parquet.assert_called_once_with(f"{path}/data.parquet", engine="pyarrow")


@patch("polars.DataFrame.write_parquet")
def test_persist_dataframe_polars(mock_write_parquet, memory_repository):
repository = memory_repository
df = pl.DataFrame({"a": [1, 2], "b": [3, 4]})
path = "./local/root"

# calls `BaseRepository._persist_dataframe` despite class using `MemoryRepository`
super(MemoryRepository, repository)._persist_dataframe(df, path)

mock_write_parquet.assert_called_once_with(f"{path}/data.parquet")


@patch("pandas.read_parquet")
def test_read_dataframe(mock_read_parquet, memory_repository):
repository = memory_repository
@@ -488,6 +523,36 @@ def test_create_dask_dataframe(memory_repository):
assert dataframe.id == dataframe_json["id"]


def test_create_polars_dataframe(memory_repository):
repository = memory_repository
project = _create_project(repository)
dataframe = _create_polars_dataframe(repository, project=project)

dataframe_root = f"{repository.root_dir}/{slugify(project.name)}/dataframes/{dataframe.id}"
dataframe_metadata_path = f"{dataframe_root}/metadata.json"
dataframe_data_path = f"{dataframe_root}/data"

open_file = repository.filesystem.open(dataframe_metadata_path)
with open_file as f:
dataframe_json = json.load(f)

assert repository.filesystem.exists(dataframe_data_path)
assert dataframe.id == dataframe_json["id"]


def test_get_polars_dataframe(memory_repository):
repository = memory_repository
project = _create_project(repository)
written_dataframe = _create_polars_dataframe(repository, project=project)
dataframe = repository.get_dataframe_metadata(project.name, written_dataframe.id)

data = repository.get_dataframe_data(project.name, written_dataframe.id, df_type="polars")
assert not data.is_empty()

assert dataframe.id == written_dataframe.id
assert dataframe.parent_id == written_dataframe.parent_id


def test_get_pandas_dataframe(memory_repository):
repository = memory_repository
project = _create_project(repository)
@@ -520,7 +585,7 @@ def test_get_dask_dataframe(memory_repository):
written_dataframe = _create_dask_dataframe(repository, project=project)
dataframe = repository.get_dataframe_metadata(project.name, written_dataframe.id)

data = repository.get_dataframe_data(project.name, written_dataframe.id)
data = repository.get_dataframe_data(project.name, written_dataframe.id, df_type="dask")
assert not data.compute().empty

assert dataframe.id == written_dataframe.id
@@ -858,7 +923,9 @@ def test_get_dataframe_tags_with_project_parent_root(memory_repository):
project = _create_project(repository)
dataframe = _create_pandas_dataframe(repository, project=project)
dataframe_tags_root = repository._get_tag_metadata_root(
project.name, entity_identifier=dataframe.id, entity_type=dataframe.__class__.__name__
project.name,
entity_identifier=dataframe.id,
entity_type=dataframe.__class__.__name__,
)

assert (
