diff --git a/src/matchbox/server/postgresql/utils/selector.py b/src/matchbox/server/postgresql/utils/selector.py index ec7cffb..99d213c 100644 --- a/src/matchbox/server/postgresql/utils/selector.py +++ b/src/matchbox/server/postgresql/utils/selector.py @@ -3,6 +3,7 @@ import pyarrow as pa from pandas import DataFrame from sqlalchemy import Engine, func, select +from sqlalchemy.dialects.postgresql import array from sqlalchemy.orm import Session, aliased from sqlalchemy.sql.selectable import Select @@ -17,6 +18,16 @@ T = TypeVar("T") +def hash_to_hex_decode(hash: bytes) -> bytes: + """A workround for PostgreSQL so we can compile the query and use ConnectorX.""" + return func.decode(hash.hex(), "hex") + + +def key_to_sqlalchemy_label(key: str, source: Source) -> str: + """Converts a key to the SQLAlchemy LABEL_STYLE_TABLENAME_PLUS_COL.""" + return f"{source.db_schema}_{source.db_table}_{key}" + + def get_all_parents(model: Models | list[Models]) -> list[Models]: """ Takes a Models object and returns all items in its parent tree. @@ -118,12 +129,15 @@ def _tree_to_reachable_stmt(model_tree: list[bytes]) -> Select: c1 = aliased(Clusters) c2 = aliased(Clusters) + bytea_array = array([hash_to_hex_decode(m) for m in model_tree]) + subquery = select(func.unnest(bytea_array).label("hash")).subquery() + dd_stmt = ( select(DDupeContains.parent, DDupeContains.child) .join(c1, DDupeContains.parent == c1.sha1) .join(clusters_association, clusters_association.c.child == c1.sha1) .join(Models, clusters_association.c.parent == Models.sha1) - .where(Models.sha1.in_(model_tree)) + .where(Models.sha1.in_(select(subquery.c.hash))) ) lk_stmt = ( @@ -132,7 +146,7 @@ def _tree_to_reachable_stmt(model_tree: list[bytes]) -> Select: .join(c2, LinkContains.child == c2.sha1) .join(clusters_association, clusters_association.c.child == c1.sha1) .join(Models, clusters_association.c.parent == Models.sha1) - .where(Models.sha1.in_(model_tree)) + .where(Models.sha1.in_(select(subquery.c.hash))) ) return dd_stmt.union(lk_stmt) @@ -164,7 +178,7 @@ def _reachable_to_parent_data_stmt( .join(Clusters, Clusters.sha1 == allowed.c.parent) .join(clusters_association, clusters_association.c.child == Clusters.sha1) .join(Models, clusters_association.c.parent == Models.sha1) - .where(Models.sha1 == parent_hash) + .where(Models.sha1 == hash_to_hex_decode(parent_hash)) .cte("root") ) @@ -236,7 +250,7 @@ def _model_to_hashes( the hash key of the SourceData """ parent, child = _parent_to_tree(model, engine=engine) - if len(parent) == 0: + if not parent: raise ValueError(f"Model {model} not found") tree = [parent] + child reachable_stmt = _tree_to_reachable_stmt(tree) @@ -316,17 +330,20 @@ def query( fields=set([source.db_pk] + fields), pks=mb_hashes["id"].to_pylist() ) - # Tablename plus column SQLAlchemy label style - right_key = f"{source.db_schema}_{source.db_table}_{source.db_pk}" - joined_table = raw_data.join( right_table=mb_hashes, - keys=right_key, + keys=key_to_sqlalchemy_label(key=source.db_pk, source=source), right_keys="id", join_type="inner", ) - tables.append(joined_table) + # Keep only the columns we want + keep_cols = ["cluster_hash", "data_hash"] + [ + key_to_sqlalchemy_label(f, source) for f in fields + ] + match_cols = [col for col in joined_table.column_names if col in keep_cols] + + tables.append(joined_table.select(match_cols)) result = pa.concat_tables(tables, promote_options="default") diff --git a/test/server/test_adapter.py b/test/server/test_adapter.py index bf7ed6e..3e89665 100644 --- a/test/server/test_adapter.py +++ b/test/server/test_adapter.py @@ -150,7 +150,7 @@ def test_query_with_dedupe_model( select_crn = selector( table=str(crn), - fields=["id", "crn"], + fields=["company_name", "crn"], engine=crn.database.engine, ) @@ -164,13 +164,13 @@ def test_query_with_dedupe_model( assert isinstance(df_crn, DataFrame) assert df_crn.shape[0] == 3000 assert set(df_crn.columns) == { - "cluster_sha1", - "data_sha1", + "cluster_hash", + "data_hash", "test_crn_crn", "test_crn_company_name", } - assert df_crn.data_sha1.nunique() == 3000 - assert df_crn.cluster_sha1.nunique() == 1000 + assert df_crn.data_hash.nunique() == 3000 + assert df_crn.cluster_hash.nunique() == 1000 def test_query_with_link_model():