
Commit

Initial commit
rcap107 committed Nov 21, 2024
1 parent fc064cd commit 920546b
Showing 3 changed files with 26 additions and 23 deletions.
10_discovery.py: 6 changes (4 additions, 2 deletions)

@@ -18,5 +18,7 @@
 print("fitting")
 discover.fit(base_table)
 print("transforming")
-joined_table = discover.transform(base_table)
-print(joined_table)
+ranking = discover.transform(base_table)
+print(ranking)
+
+# %%
skrub/_dataframe/_common.py: 15 changes (9 additions, 6 deletions)

@@ -406,33 +406,36 @@ def _collect_polars_lazyframe(df):
 # Loading data
 # ============
 #


+# TODO: Adding X here as a placeholder to get around the type check,
 @dispatch
-def read_parquet(input_path):
+def read_parquet(X, input_path):
     raise NotImplementedError()


 @read_parquet.specialize("pandas", argument_type=["DataFrame"])
-def _read_parquet_pandas(input_path):
+def _read_parquet_pandas(X, input_path):
     return pd.read_parquet(input_path)


 @read_parquet.specialize("polars", argument_type=["DataFrame"])
-def _read_parquet_polars(input_path):
+def _read_parquet_polars(X, input_path):
     return pl.read_parquet(input_path)


 @dispatch
-def read_csv(input_path):
+def read_csv(X, input_path):
     raise NotImplementedError()


 @read_csv.specialize("pandas", argument_type=["DataFrame"])
-def _read_csv_pandas(input_path):
+def _read_csv_pandas(X, input_path):
     return pd.read_csv(input_path)


 @read_csv.specialize("polars", argument_type=["DataFrame"])
-def _read_csv_polars(input_path):
+def _read_csv_polars(X, input_path):
     return pl.read_csv(input_path)

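A note on the new X parameter: skrub's internal @dispatch helper selects the pandas or polars implementation from the dataframe type of the first argument, and a bare file path gives it nothing to dispatch on, hence the placeholder flagged in the TODO above. A minimal usage sketch, with a made-up path:

    import pandas as pd
    from skrub._dataframe import _common as sbd

    X = pd.DataFrame()  # any dataframe of the target backend; only its type is inspected
    table = sbd.read_parquet(X, "candidates/table_0.parquet")  # routes to pd.read_parquet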
skrub/_discover.py: 28 changes (13 additions, 15 deletions)
@@ -24,7 +24,6 @@ def find_unique_values(table: pl.DataFrame, columns: list[str] = None) -> dict:
     Returns:
         A dict that contains the unique values found for each selected column.
     """
-    print("eh")
     # select the columns of interest
     if columns is not None:
         # error checking columns
@@ -37,7 +36,7 @@ def find_unique_values(table: pl.DataFrame, columns: list[str] = None) -> dict:
     else:
         # Selecting only columns with strings
         # TODO: string? categorical? both?
-        columns = table.select(cs.string()).columns
+        columns = cs.select(table, cs.string()).columns

     unique_values = {}
     # find the unique values
@@ -73,13 +72,10 @@ def measure_containment_tables(
     for col_base, values_base in unique_values_base.items():
         for col_cand, values_cand in dict_cand.items():
             containment = measure_containment(values_base, values_cand)
-            tup = (col_base, path, col_cand, containment)
-            containment_list.append(tup)
-    # convert the containment list to a pl dataframe and return that
-    df_cont = pl.from_records(
-        containment_list, ["query_column", "cand_path", "cand_column", "containment"]
-    ).filter(pl.col("containment") > 0)
-    return df_cont
+            if containment > 0:
+                tup = (col_base, str(path), col_cand, containment)
+                containment_list.append(tup)
+    return containment_list


 def measure_containment(unique_values_query: set, unique_values_candidate: set):
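The body of measure_containment sits below the fold of this diff. Under the usual definition of containment, the fraction of the query column's unique values that also appear in the candidate column, a sketch would look like this (an assumption, not the committed code):

    def measure_containment(unique_values_query: set, unique_values_candidate: set):
        # Assumed sketch: share of query values covered by the candidate column.
        if not unique_values_query:
            return 0.0
        return len(unique_values_query & unique_values_candidate) / len(unique_values_query)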
@@ -111,14 +107,14 @@ def prepare_ranking(containment_list: list[tuple], budget: int):
     """

     # Sort the list
-    containment_list = containment_list.sort("containment", descending=True)
+    containment_list = sorted(containment_list, key=lambda x: x[3], reverse=True)

     # TODO: Somewhere here we might want to do some fancy filtering of the
     # candidates in the ranking (with profiling)

     # Return `budget` candidates
-    ranking = containment_list.top_k(budget, by="containment")
-    return ranking.rows()
+    ranking = containment_list[:budget]
+    return ranking


 def execute_join(
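With containment_list now a plain list of (query_column, cand_path, cand_column, containment) tuples instead of a polars frame, prepare_ranking reduces to a sort on the fourth field and a slice. A worked example with invented values:

    containment_list = [
        ("city", "tables/a.parquet", "municipality", 0.92),
        ("city", "tables/b.parquet", "region", 0.35),
        ("country", "tables/a.parquet", "nation", 0.80),
    ]
    prepare_ranking(containment_list, budget=2)
    # [('city', 'tables/a.parquet', 'municipality', 0.92),
    #  ('country', 'tables/a.parquet', 'nation', 0.80)]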
@@ -252,7 +248,7 @@ def fit(self, X: pl.DataFrame, y=None):
         for table_path in self._candidate_paths:
             # TODO: check type of the table, is it parquet or csv?
             # TODO: is this going to hold everything in memory?
-            table = sbd.read_parquet(table_path)
+            table = sbd.read_parquet(X, table_path)
             self._unique_values_candidates[table_path] = find_unique_values(table)

         # find unique values for the query columns
@@ -274,8 +270,10 @@ def transform(self, X) -> pl.DataFrame:
         Returns:
             pl.DataFrame: The joined table.
         """
-        _joined = execute_join(X, self._ranking, self.multiaggjoiner_params)
-        return _joined
+        return self._ranking
+
+        # _joined = execute_join(X, self._ranking, self.multiaggjoiner_params)
+        # return _joined

     def fit_transform(self, X, y) -> pl.DataFrame:
         """Execute fit and transform sequentially.
