Skip to content

Commit

Permalink
Refactoring and cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
rvandewater committed Oct 16, 2024
1 parent 136b4e1 commit ac86669
Show file tree
Hide file tree
Showing 10 changed files with 279 additions and 145 deletions.
39 changes: 18 additions & 21 deletions recipys/ingredients.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@
import pandas as pd
import polars as pl
from typing import overload

from pandas.io.sql import get_schema

from recipys.constants import Backend


class Ingredients:
"""Wrapper around either a pandas.DataFrame or a polars.DataFrame to store column roles (e.g., predictor)
Due to the workings of polars, we do not subclass pl.dataframe anymore, but instead store the dataframe as an attribute.
Due to the workings of polars, we do not subclass pl.dataframe anymore,
but instead store the dataframe as an attribute.
Args:
roles: roles of DataFrame columns as (list of) strings.
Defaults to None.
Expand Down Expand Up @@ -41,13 +39,13 @@ def __init__(
elif isinstance(data, Ingredients):
self.backend = data.get_backend()
else:
raise ValueError(f"Backend not specified and could not be inferred from data.")
raise ValueError("Backend not specified and could not be inferred from data.")
else:
self.backend = backend
if isinstance(data, pd.DataFrame) or isinstance(data, pl.DataFrame):
if self.backend == Backend.POLARS:
if isinstance(data, pd.DataFrame):
self.data = pl.DataFrame(data)
self.data = pl.DataFrame(data)
elif isinstance(data, pl.DataFrame):
self.data = data
else:
Expand Down Expand Up @@ -75,8 +73,10 @@ def __init__(
self.roles = {}
elif not isinstance(roles, dict):
raise TypeError(f"Expected dict object for roles, got {roles.__class__}")
elif check_roles and not all(set(k).issubset(set(self.data.columns)) for k,v in roles.items()):
raise ValueError(f"Roles contains variable names that are not in the data {list(roles.values())} {self.data.columns}.")
elif check_roles and not all(set(k).issubset(set(self.data.columns)) for k, v in roles.items()):
raise ValueError(
f"Roles contains variable names that are not in the data {list(roles.values())} {self.data.columns}."
)
# Todo: do we want to allow ingredients without grouping columns?
# elif check_roles and select_groups(self) == []:
# raise ValueError("Roles are given but no groups are found in the data.")
Expand All @@ -94,7 +94,7 @@ def _constructor(self):
def columns(self):
return self.data.columns

def to_df(self, output_format = None) -> pl.DataFrame:
def to_df(self, output_format=None) -> pl.DataFrame:
"""Return the underlying DataFrame.
Expand All @@ -114,8 +114,6 @@ def to_df(self, output_format = None) -> pl.DataFrame:
else:
return self.data



def _check_column(self, column):
if not isinstance(column, str):
raise ValueError(f"Expected string, got {column}")
Expand Down Expand Up @@ -178,21 +176,23 @@ def update_role(self, column: str, new_role: str, old_role: str = None):
f"Attempted to update role of {column} to {new_role} but "
f"{column} has more than one current roles: {self.roles[column]}"
)
def select_dtypes(self,include=None):

def select_dtypes(self, include=None):
# if(isinstance(include,[str])):
dtypes = self.get_str_dtypes()
selected = [key for key, value in dtypes.items() if value in include]
return selected

def get_dtypes(self):
dtypes = list(self.schema.values())
return dtypes

def get_str_dtypes(self):
""""
Helper function for polar dataframes to return schema with dtypes as strings
""" "
Helper function for polar dataframes to return schema with dtypes as strings
"""
dtypes = self.get_schema()
return {key:str(value) for key,value in dtypes.items()}
return {key: str(value) for key, value in dtypes.items()}
# return list(map(dtypes, cast()))

def get_schema(self):
Expand All @@ -204,10 +204,10 @@ def get_schema(self):
def get_df(self):
return self.to_df()

def set_df(self,df):
def set_df(self, df):
self.data = df

def groupby(self,by):
def groupby(self, by):
if self.backend == Backend.POLARS:
self.data.group_by(by)
else:
Expand All @@ -230,8 +230,5 @@ def __setitem__(self, idx, val):
def __getitem__(self, list: list[str]) -> pl.DataFrame:
return self.data[list]

def __getitem__(self, idx:int) -> pl.Series:
def __getitem__(self, idx: int) -> pl.Series:
return self.data[idx]



25 changes: 15 additions & 10 deletions recipys/recipe.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations
from collections import Counter
from copy import copy, deepcopy
from copy import copy
from itertools import chain
from typing import Union

Expand All @@ -12,6 +12,7 @@
from recipys.step import Step
from recipys.constants import Backend


class Recipe:
"""Recipe for preprocessing data
Expand Down Expand Up @@ -41,8 +42,8 @@ def __init__(
if not isinstance(data, Ingredients):
try:
data = Ingredients(data, backend=backend)
except:
raise (f"Expected Ingredients, got {data.__class__}")
except Exception as e:
raise (f"Expected Ingredients, got {data.__class__} {e}")
self.data = data
self.steps = []
self.original_columns = copy(data.columns)
Expand Down Expand Up @@ -113,14 +114,16 @@ def add_step(self, step: Step) -> Recipe:
def _check_data(self, data: Union[pl.DataFrame | pd.DataFrame, Ingredients]) -> Ingredients:
if data is None:
data = self.data
elif type(data) == pl.DataFrame or type(data) == pd.DataFrame:
# this is only executed when prep or bake recieve a DF that is different to the original data
elif isinstance(data, pl.DataFrame) or isinstance(data, pd.DataFrame):
# this is only executed when prep or bake receive a DF that is different to the original data
# don't check the roles here, because self.data can have more roles than data (post feature generation)
data = Ingredients(data, roles=self.data.roles, check_roles=False)
#if not data.columns.equals(self.data.columns):
# if not data.columns.equals(self.data.columns):
if not set(data.columns) == set(self.original_columns):
raise ValueError(f"Columns of data argument differs from recipe data: "
f"{[x for x in data.columns if x not in self.original_columns]}.")
raise ValueError(
f"Columns of data argument differs from recipe data: "
f"{[x for x in data.columns if x not in self.original_columns]}."
)
return data

def _apply_group(self, data, step):
Expand All @@ -130,7 +133,9 @@ def _apply_group(self, data, step):
data.groupby(group_vars)
return data

def prep(self, data: Union[pl.DataFrame | pd.DataFrame, Ingredients] = None, refit: bool = False) -> pl.DataFrame | pd.DataFrame:
def prep(
self, data: Union[pl.DataFrame | pd.DataFrame, Ingredients] = None, refit: bool = False
) -> pl.DataFrame | pd.DataFrame:
"""Fits and transforms, in other words preps, the data.
Args:
Expand All @@ -144,7 +149,7 @@ def prep(self, data: Union[pl.DataFrame | pd.DataFrame, Ingredients] = None, ref
# Todo: check why the roles disappear after copying
data = copy(data)
data = self._apply_fit_transform(data, refit)
#return pl.DataFrame(data)
# return pl.DataFrame(data)
return data.get_df()

def bake(self, data: Union[pl.DataFrame | pd.DataFrame, Ingredients] = None) -> pl.DataFrame | pd.DataFrame:
Expand Down
19 changes: 10 additions & 9 deletions recipys/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from polars import DataType
from recipys.constants import Backend


class Selector:
"""Class responsible for selecting the variables affected by a recipe step
Expand All @@ -16,12 +17,12 @@ class Selector:
"""

def __init__(
self,
description: str,
names: Union[str, list[str]] = None,
roles: Union[str, list[str]] = None,
types: Union[str, list[str]] = None,
pattern: re.Pattern = None,
self,
description: str,
names: Union[str, list[str]] = None,
roles: Union[str, list[str]] = None,
types: Union[str, list[str]] = None,
pattern: re.Pattern = None,
):
self.description = description
self.set_names(names)
Expand Down Expand Up @@ -54,7 +55,6 @@ def set_types(self, roles: Union[str, list[str]]):
self.types = enlist_str(roles)
# self.types = enlist_dt(roles)


def set_pattern(self, pattern: re.Pattern):
"""Set the pattern to search with this Selector
Expand Down Expand Up @@ -91,7 +91,7 @@ def __call__(self, ingr: Ingredients) -> list[str]:
vars = intersection(vars, sel_roles)

if self.types is not None:
sel_types = list(ingr.select_dtypes(include=self.types)) #.columns.tolist()
sel_types = list(ingr.select_dtypes(include=self.types)) # .columns.tolist()
vars = intersection(vars, sel_types)

if self.names is not None:
Expand Down Expand Up @@ -119,7 +119,7 @@ def enlist_dt(x: Union[DataType, list[DataType], None]) -> Union[list[DataType],
_description_
"""
if isinstance(x, DataType):
return [x]
return [x]
elif isinstance(x, list):
if not all(isinstance(i, DataType) for i in x):
raise TypeError("Only lists of datatypes are allowed.")
Expand Down Expand Up @@ -262,6 +262,7 @@ def has_type(types: Union[str, list[str]]) -> Selector:
"""
return Selector(description=f"types: {types}", types=types)


def all_predictors() -> Selector:
"""Define selector for all predictor columns
Expand Down
Loading

0 comments on commit ac86669

Please sign in to comment.