Skip to content

Commit

Permalink
Working query of deduplicated data
Browse files Browse the repository at this point in the history
  • Loading branch information
Will Langdale committed Oct 21, 2024
1 parent 72c7e61 commit 9dd78b2
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 14 deletions.
35 changes: 26 additions & 9 deletions src/matchbox/server/postgresql/utils/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pyarrow as pa
from pandas import DataFrame
from sqlalchemy import Engine, func, select
from sqlalchemy.dialects.postgresql import array
from sqlalchemy.orm import Session, aliased
from sqlalchemy.sql.selectable import Select

Expand All @@ -17,6 +18,16 @@
T = TypeVar("T")


def hash_to_hex_decode(hash: bytes) -> bytes:
    """Wrap a binary hash in a SQL ``decode(<hex string>, 'hex')`` call.

    A workaround for PostgreSQL so the query can be compiled and used with
    ConnectorX.

    Args:
        hash: The raw hash bytes to embed in the query.

    Returns:
        The result of ``func.decode`` applied to the hex representation of
        ``hash``.

    NOTE(review): the ``-> bytes`` annotation does not describe the value
    actually returned, which is whatever ``func.decode(...)`` produces (a SQL
    function expression, not raw bytes) -- confirm and correct the annotation.
    """
    # Render the bytes as a hex string first, then ask the database to decode it.
    hex_digest = hash.hex()
    return func.decode(hex_digest, "hex")


def key_to_sqlalchemy_label(key: str, source: Source) -> str:
    """Build the SQLAlchemy LABEL_STYLE_TABLENAME_PLUS_COL label for a key.

    Args:
        key: The column name to convert.
        source: The source whose ``db_schema`` and ``db_table`` prefix the key.

    Returns:
        The string ``<db_schema>_<db_table>_<key>``.
    """
    schema, table = source.db_schema, source.db_table
    # Each part is formatted individually, then joined with underscores.
    return "_".join([f"{schema}", f"{table}", f"{key}"])


def get_all_parents(model: Models | list[Models]) -> list[Models]:
"""
Takes a Models object and returns all items in its parent tree.
Expand Down Expand Up @@ -118,12 +129,15 @@ def _tree_to_reachable_stmt(model_tree: list[bytes]) -> Select:
c1 = aliased(Clusters)
c2 = aliased(Clusters)

bytea_array = array([hash_to_hex_decode(m) for m in model_tree])
subquery = select(func.unnest(bytea_array).label("hash")).subquery()

dd_stmt = (
select(DDupeContains.parent, DDupeContains.child)
.join(c1, DDupeContains.parent == c1.sha1)
.join(clusters_association, clusters_association.c.child == c1.sha1)
.join(Models, clusters_association.c.parent == Models.sha1)
.where(Models.sha1.in_(model_tree))
.where(Models.sha1.in_(select(subquery.c.hash)))
)

lk_stmt = (
Expand All @@ -132,7 +146,7 @@ def _tree_to_reachable_stmt(model_tree: list[bytes]) -> Select:
.join(c2, LinkContains.child == c2.sha1)
.join(clusters_association, clusters_association.c.child == c1.sha1)
.join(Models, clusters_association.c.parent == Models.sha1)
.where(Models.sha1.in_(model_tree))
.where(Models.sha1.in_(select(subquery.c.hash)))
)

return dd_stmt.union(lk_stmt)
Expand Down Expand Up @@ -164,7 +178,7 @@ def _reachable_to_parent_data_stmt(
.join(Clusters, Clusters.sha1 == allowed.c.parent)
.join(clusters_association, clusters_association.c.child == Clusters.sha1)
.join(Models, clusters_association.c.parent == Models.sha1)
.where(Models.sha1 == parent_hash)
.where(Models.sha1 == hash_to_hex_decode(parent_hash))
.cte("root")
)

Expand Down Expand Up @@ -236,7 +250,7 @@ def _model_to_hashes(
the hash key of the SourceData
"""
parent, child = _parent_to_tree(model, engine=engine)
if len(parent) == 0:
if not parent:
raise ValueError(f"Model {model} not found")
tree = [parent] + child
reachable_stmt = _tree_to_reachable_stmt(tree)
Expand Down Expand Up @@ -316,17 +330,20 @@ def query(
fields=set([source.db_pk] + fields), pks=mb_hashes["id"].to_pylist()
)

# Tablename plus column SQLAlchemy label style
right_key = f"{source.db_schema}_{source.db_table}_{source.db_pk}"

joined_table = raw_data.join(
right_table=mb_hashes,
keys=right_key,
keys=key_to_sqlalchemy_label(key=source.db_pk, source=source),
right_keys="id",
join_type="inner",
)

tables.append(joined_table)
# Keep only the columns we want
keep_cols = ["cluster_hash", "data_hash"] + [
key_to_sqlalchemy_label(f, source) for f in fields
]
match_cols = [col for col in joined_table.column_names if col in keep_cols]

tables.append(joined_table.select(match_cols))

result = pa.concat_tables(tables, promote_options="default")

Expand Down
10 changes: 5 additions & 5 deletions test/server/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def test_query_with_dedupe_model(

select_crn = selector(
table=str(crn),
fields=["id", "crn"],
fields=["company_name", "crn"],
engine=crn.database.engine,
)

Expand All @@ -164,13 +164,13 @@ def test_query_with_dedupe_model(
assert isinstance(df_crn, DataFrame)
assert df_crn.shape[0] == 3000
assert set(df_crn.columns) == {
"cluster_sha1",
"data_sha1",
"cluster_hash",
"data_hash",
"test_crn_crn",
"test_crn_company_name",
}
assert df_crn.data_sha1.nunique() == 3000
assert df_crn.cluster_sha1.nunique() == 1000
assert df_crn.data_hash.nunique() == 3000
assert df_crn.cluster_hash.nunique() == 1000


def test_query_with_link_model():
Expand Down

0 comments on commit 9dd78b2

Please sign in to comment.