diff --git a/10_discovery.py b/10_discovery.py index 5fc411de4..be229d9bd 100644 --- a/10_discovery.py +++ b/10_discovery.py @@ -18,5 +18,7 @@ print("fitting") discover.fit(base_table) print("transforming") -joined_table = discover.transform(base_table) -print(joined_table) +ranking = discover.transform(base_table) +print(ranking) + +# %% diff --git a/skrub/_dataframe/_common.py b/skrub/_dataframe/_common.py index 8b19c2015..c2bdc0f66 100644 --- a/skrub/_dataframe/_common.py +++ b/skrub/_dataframe/_common.py @@ -406,33 +406,36 @@ def _collect_polars_lazyframe(df): # Loading data # ============ # + + +# TODO: X is an unused placeholder argument, added only so @dispatch can infer the dataframe backend (specializations key on argument_type=["DataFrame"]); drop it once dispatch supports path-only signatures. @dispatch -def read_parquet(input_path): +def read_parquet(X, input_path): raise NotImplementedError() @read_parquet.specialize("pandas", argument_type=["DataFrame"]) -def _read_parquet_pandas(input_path): +def _read_parquet_pandas(X, input_path): return pd.read_parquet(input_path) @read_parquet.specialize("polars", argument_type=["DataFrame"]) -def _read_parquet_polars(input_path): +def _read_parquet_polars(X, input_path): return pl.read_parquet(input_path) @dispatch -def read_csv(input_path): +def read_csv(X, input_path): raise NotImplementedError() @read_csv.specialize("pandas", argument_type=["DataFrame"]) -def _read_csv_pandas(input_path): +def _read_csv_pandas(X, input_path): return pd.read_csv(input_path) @read_csv.specialize("polars", argument_type=["DataFrame"]) -def _read_csv_polars(input_path): +def _read_csv_polars(X, input_path): return pl.read_csv(input_path) diff --git a/skrub/_discover.py b/skrub/_discover.py index 02d4908a5..f0c70a786 100644 --- a/skrub/_discover.py +++ b/skrub/_discover.py @@ -24,7 +24,6 @@ def find_unique_values(table: pl.DataFrame, columns: list[str] = None) -> dict: Returns: A dict that contains the unique values found for each selected column. 
""" - print("eh") # select the columns of interest if columns is not None: # error checking columns @@ -37,7 +36,7 @@ def find_unique_values(table: pl.DataFrame, columns: list[str] = None) -> dict: else: # Selecting only columns with strings # TODO: string? categorical? both? - columns = table.select(cs.string()).columns + columns = cs.select(table, cs.string()).columns unique_values = {} # find the unique values @@ -73,13 +72,10 @@ def measure_containment_tables( for col_base, values_base in unique_values_base.items(): for col_cand, values_cand in dict_cand.items(): containment = measure_containment(values_base, values_cand) - tup = (col_base, path, col_cand, containment) - containment_list.append(tup) - # convert the containment list to a pl dataframe and return that - df_cont = pl.from_records( - containment_list, ["query_column", "cand_path", "cand_column", "containment"] - ).filter(pl.col("containment") > 0) - return df_cont + if containment > 0: + tup = (col_base, str(path), col_cand, containment) + containment_list.append(tup) + return containment_list def measure_containment(unique_values_query: set, unique_values_candidate: set): @@ -111,14 +107,14 @@ def prepare_ranking(containment_list: list[tuple], budget: int): """ # Sort the list - containment_list = containment_list.sort("containment", descending=True) + containment_list = sorted(containment_list, key=lambda x: x[3], reverse=True) # TODO: Somewhere here we might want to do some fancy filtering of the # candidates in the ranking (with profiling) # Return `budget` candidates - ranking = containment_list.top_k(budget, by="containment") - return ranking.rows() + ranking = containment_list[:budget] + return ranking def execute_join( @@ -252,7 +248,7 @@ def fit(self, X: pl.DataFrame, y=None): for table_path in self._candidate_paths: # TODO: check type of the table, is it parquet or csv? # TODO: is this going to hold everything in memory? 
- table = sbd.read_parquet(table_path) + table = sbd.read_parquet(X, table_path) self._unique_values_candidates[table_path] = find_unique_values(table) # find unique values for the query columns @@ -274,8 +270,10 @@ def transform(self, X) -> pl.DataFrame: Returns: pl.DataFrame: The joined table. """ - _joined = execute_join(X, self._ranking, self.multiaggjoiner_params) - return _joined + return self._ranking + + # _joined = execute_join(X, self._ranking, self.multiaggjoiner_params) + # return _joined def fit_transform(self, X, y) -> pl.DataFrame: """Execute fit and transform sequentially.