Skip to content

Commit

Permalink
Minor refactor of query so it's easier to read
Browse files Browse the repository at this point in the history
  • Loading branch information
wpfl-dbt committed Dec 8, 2024
1 parent 6a936d7 commit 79ee5b5
Showing 1 changed file with 40 additions and 61 deletions.
101 changes: 40 additions & 61 deletions src/matchbox/server/postgresql/utils/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,31 @@ def key_to_sqlalchemy_label(key: str, source: Source) -> str:
return f"{source.db_schema}_{source.db_table}_{key}"


def source_to_dataset_model(source: Source | str, session: Session) -> Models:
"""Converts a Source object to a Sources ORM object."""
if isinstance(source, str):
source_schema, source_table = get_schema_table_names(source, validate=True)
else:
source_schema, source_table = source.db_schema, source.db_table

source_dataset = (
session.query(Models)
.join(Sources, Sources.model == Models.hash)
.filter(
Sources.schema == source_schema,
Sources.table == source_table,
)
.first()
)
if source_dataset is None:
raise MatchboxDatasetError(
db_schema=source_schema,
db_table=source_table,
)

return source_dataset


def _resolve_thresholds(
lineage_truths: dict[str, float],
model: Models,
Expand Down Expand Up @@ -294,26 +319,10 @@ def query(

# Process each source dataset
for source, fields in selector.items():
# Get the dataset model
dataset = (
session.query(Models)
.join(Sources, Sources.model == Models.hash)
.filter(
Sources.schema == source.db_schema,
Sources.table == source.db_table,
Sources.id == source.db_pk,
)
.first()
)

if dataset is None:
raise MatchboxDatasetError(
db_schema=source.db_schema, db_table=source.db_table
)

dataset_model = source_to_dataset_model(source, session)
hash_query = _resolve_cluster_hierarchy(
dataset_hash=dataset.hash,
model=truth_model if truth_model else dataset,
dataset_hash=dataset_model.hash,
model=truth_model if truth_model else dataset_model,
threshold=threshold,
engine=engine,
)
Expand Down Expand Up @@ -537,48 +546,19 @@ def match(
* Returns the results as Match objects, one per target
"""
# Split source and target into schema/table
source_schema, source_table = get_schema_table_names(source, validate=True)
targets = [target] if isinstance(target, str) else target
target_pairs = [get_schema_table_names(t, validate=True) for t in targets]

with Session(engine) as session:
# Get truth model
# Get source, target and truth models
source_model = source_to_dataset_model(source, session)
target_models = []
for target in targets:
target_models.append(source_to_dataset_model(target, session))
truth_model = session.query(Models).filter(Models.name == model).first()
if truth_model is None:
raise MatchboxModelError(f"Model {model} not found")

# Get source dataset
source_dataset = (
session.query(Models)
.join(Sources, Sources.model == Models.hash)
.filter(
Sources.schema == source_schema,
Sources.table == source_table,
)
.first()
)
if source_dataset is None:
raise MatchboxDatasetError(
db_schema=source_schema,
db_table=source_table,
)

# Get target datasets
target_datasets = []
for schema, table in target_pairs:
dataset = (
session.query(Models)
.join(Sources, Sources.model == Models.hash)
.filter(
Sources.schema == schema,
Sources.table == table,
)
.first()
)
if dataset is None:
raise MatchboxDatasetError(db_schema=schema, db_table=table)
target_datasets.append(dataset)

# Get model lineage and resolve thresholds
lineage_truths = truth_model.get_lineage()
thresholds = _resolve_thresholds(
Expand All @@ -593,7 +573,7 @@ def match(

# Build the query components
unnested = _build_unnested_clusters()
source_cluster = _find_source_cluster(unnested, source_dataset.hash, source_id)
source_cluster = _find_source_cluster(unnested, source_model.hash, source_id)
hierarchy_up = _build_hierarchy_up(source_cluster, valid_clusters)
highest = _find_highest_parent(hierarchy_up)
hierarchy_down = _build_hierarchy_down(highest, unnested, valid_clusters)
Expand Down Expand Up @@ -622,32 +602,31 @@ def match(

# Create Match objects for each target
result = []
for target_dataset in target_datasets:
for target_model in target_models:
# Get source/target table names
source_name = f"{source_schema}.{source_table}"
target_schema, target_table = next(
(schema, table)
for schema, table in target_pairs
if session.get(Sources, target_dataset.hash).schema == schema
and session.get(Sources, target_dataset.hash).table == table
if session.get(Sources, target_model.hash).schema == schema
and session.get(Sources, target_model.hash).table == table
)
target_name = f"{target_schema}.{target_table}"

# Get source and target IDs
source_ids = {
id
for _, dataset_hash, id in matches
if dataset_hash == source_dataset.hash
if dataset_hash == source_model.hash
}
target_ids = {
id
for _, dataset_hash, id in matches
if dataset_hash == target_dataset.hash
if dataset_hash == target_model.hash
}

match_obj = Match(
cluster=cluster,
source=source_name,
source=source,
source_id=source_ids,
target=target_name,
target_id=target_ids,
Expand Down

0 comments on commit 79ee5b5

Please sign in to comment.