Merge pull request #22 from uktrade/feature/index-columns
Add the ability to choose the columns we index on
wpfl-dbt authored Dec 12, 2024
2 parents 2f825f8 + 3e75fb5 commit c426882
Showing 14 changed files with 451 additions and 75 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -39,6 +39,7 @@ dev = [
"pytest-env>=1.1.5",
"ruff>=0.6.8",
"docker>=7.1.0",
"tomli-w>=1.1.0",
"vcrpy>=6.0.2",
]
typing = [
24 changes: 23 additions & 1 deletion sample.datasets.toml
@@ -11,21 +11,43 @@ database = "pg_warehouse"
db_schema = "companieshouse"
db_table = "companies"
db_pk = "id"
index = [
{ literal = "crn", alias = "crn_id", type = "VARCHAR" },
{ literal = "company_name", alias = "name" },
{ literal = "*" },
{ literal = "postcode" }
]

[datasets.data_hub_companies]
database = "pg_warehouse"
db_schema = "dbt"
db_table = "data_hub__companies"
db_pk = "id"
index = [
{ literal = "cdms", alias = "cdms_id", type = "VARCHAR" },
{ literal = "company_name", alias = "name" },
{ literal = "postcode" },
{ literal = "*" }
]

[datasets.hmrc_exporters]
database = "pg_warehouse"
db_schema = "hmrc"
db_table = "trade__exporters"
db_pk = "id"
index = [
{ literal = "company_name", alias = "name" },
{ literal = "postcode" },
]

[datasets.export_wins]
database = "pg_warehouse"
db_schema = "dbt"
db_table = "export_wins__wins_dataset"
db_pk = "id"
index = [
{ literal = "company_name" },
{ literal = "postcode" },
{ literal = "cdms", alias = "cdms_id", type = "VARCHAR" },
{ literal = "data_hub_company_id", alias = "dh_id", type = "VARCHAR" },
]
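Each index entry above maps onto the new SourceColumn model introduced in src/matchbox/common/db.py below. A minimal sketch of that mapping, assuming the model behaves as defined in this diff (alias falls back to literal when omitted):

from matchbox.common.db import SourceColumn

# { literal = "cdms", alias = "cdms_id", type = "VARCHAR" }
cdms = SourceColumn(literal="cdms", alias="cdms_id", type="VARCHAR", indexed=True)

# { literal = "postcode" } -- the alias defaults to the literal name
postcode = SourceColumn(literal="postcode", indexed=True)
assert postcode.alias.name == "postcode"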
6 changes: 2 additions & 4 deletions src/matchbox/admin.py
@@ -7,9 +7,7 @@

from matchbox.common.db import SourceWarehouse
from matchbox.server import MatchboxDBAdapter, inject_backend
-from matchbox.server.base import (
-Source,
-)
+from matchbox.server.base import Source

logger = logging.getLogger("mb_logic")

@@ -34,7 +32,7 @@ def load_datasets_from_config(datasets: Path) -> dict[str, Source]:
for dataset_name, dataset_config in config["datasets"].items():
warehouse_alias = dataset_config.get("database")
dataset_config["database"] = warehouses[warehouse_alias]
-datasets[dataset_name] = Source(**dataset_config)
+datasets[dataset_name] = Source(alias=dataset_name, **dataset_config)

return datasets

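A hedged usage sketch of the updated loader; the TOML path is illustrative:

from pathlib import Path

from matchbox.admin import load_datasets_from_config

# Each Source now carries its TOML table name as its alias, e.g.
# datasets["export_wins"].alias == "export_wins" rather than the
# default "dbt.export_wins__wins_dataset".
datasets = load_datasets_from_config(Path("sample.datasets.toml"))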
200 changes: 186 additions & 14 deletions src/matchbox/common/db.py
@@ -3,10 +3,17 @@
import connectorx as cx
import pyarrow as pa
from matchbox.common.exceptions import MatchboxValidatonError
-from matchbox.common.hash import HASH_FUNC
+from matchbox.common.hash import HASH_FUNC, hash_to_base64
from pandas import DataFrame
from pyarrow import Table as ArrowTable
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import (
+BaseModel,
+ConfigDict,
+Field,
+SecretStr,
+field_validator,
+model_validator,
+)
from sqlalchemy import (
LABEL_STYLE_TABLENAME_PLUS_COL,
ColumnElement,
@@ -69,7 +76,7 @@ class SourceWarehouse(BaseModel):
alias: str
db_type: str
user: str
-password: str = Field(repr=False)
+password: SecretStr
host: str
port: int
database: str
@@ -78,11 +85,24 @@
@property
def engine(self) -> Engine:
if self._engine is None:
connection_string = f"{self.db_type}://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"
connection_string = f"{self.db_type}://{self.user}:{self.password.get_secret_value()}@{self.host}:{self.port}/{self.database}"
self._engine = create_engine(connection_string)
self.test_connection()
return self._engine

def __eq__(self, other):
if not isinstance(other, SourceWarehouse):
return False
return (
self.alias == other.alias
and self.db_type == other.db_type
and self.user == other.user
and self.password == other.password
and self.host == other.host
and self.port == other.port
and self.database == other.database
)

def test_connection(self):
try:
with self.engine.connect() as connection:
@@ -91,12 +111,6 @@ def test_connection(self):
self._engine = None
raise

-def __str__(self):
-return (
-f"SourceWarehouse(alias={self.alias}, type={self.db_type}, "
-f"host={self.host}, port={self.port}, database={self.database})"
-)

@classmethod
def from_engine(cls, engine: Engine, alias: str | None = None) -> "SourceWarehouse":
"""Create a SourceWarehouse instance from an SQLAlchemy Engine object."""
@@ -116,17 +130,96 @@ def from_engine(cls, engine: Engine, alias: str | None = None) -> "SourceWarehouse":
return warehouse


class SourceColumnName(BaseModel):
"""A column name in the Matchbox database."""

name: str

@property
def hash(self) -> bytes:
"""Generate a unique hash based on the column name."""
return HASH_FUNC(self.name.encode("utf-8")).digest()

@property
def base64(self) -> str:
"""Generate a base64 encoded hash based on the column name."""
return hash_to_base64(self.hash)
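A small sketch of what the two properties return, given that HASH_FUNC is hashlib.sha256 (see src/matchbox/common/hash.py):

name = SourceColumnName(name="postcode")
name.hash    # 32-byte SHA-256 digest of b"postcode"
name.base64  # the same digest as base64 text, via hash_to_base64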


class SourceColumn(BaseModel):
"""A column in a dataset that can be indexed in the Matchbox database."""

model_config = ConfigDict(arbitrary_types_allowed=True)

literal: SourceColumnName = Field(
description="The literal name of the column in the database."
)
alias: SourceColumnName = Field(
default_factory=lambda data: SourceColumnName(name=data["literal"].name),
description="The alias to use when hashing the dataset in Matchbox.",
)
type: str | None = Field(
default=None, description="The type to cast the column to before hashing data."
)
indexed: bool = Field(description="Whether the column is indexed in the database.")

def __eq__(self, other: object) -> bool:
"""Compare SourceColumn with another SourceColumn or bytes object.
Two SourceColumns are equal if:
* Their literal names match, or
* Their alias names match, or
* The hash of either their literal or alias matches the other object's
corresponding hash
A SourceColumn is equal to a bytes object if:
* The hash of either its literal or alias matches the bytes object
Args:
other: Another SourceColumn or a bytes object to compare against
Returns:
bool: True if the objects are considered equal, False otherwise
"""
if isinstance(other, SourceColumn):
if self.literal == other.literal or self.alias == other.alias:
return True

self_hashes = {self.literal.hash, self.alias.hash}
other_hashes = {other.literal.hash, other.alias.hash}

return bool(self_hashes & other_hashes)

if isinstance(other, bytes):
return other in {self.literal.hash, self.alias.hash}

return NotImplemented

@field_validator("literal", "alias", mode="before")
def string_to_name(cls: "SourceColumn", value: str) -> SourceColumnName:
if isinstance(value, str):
return SourceColumnName(name=value)
else:
raise ValueError("Column name must be a string.")
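A short sketch of the equality semantics described in the docstring above:

a = SourceColumn(literal="cdms", alias="cdms_id", type="VARCHAR", indexed=True)
b = SourceColumn(literal="cdms", indexed=False)

assert a == b               # literal names match; type and indexed are ignored
assert a == b.literal.hash  # bytes compare against either hash
assert a != SourceColumn(literal="postcode", indexed=True)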


class Source(BaseModel):
"""A dataset that can be indexed in the Matchbox database."""

model_config = ConfigDict(
populate_by_name=True,
)

-database: SourceWarehouse | None = None
+database: SourceWarehouse
db_pk: str
db_schema: str
db_table: str
db_columns: list[SourceColumn]
alias: str = Field(
default_factory=lambda data: f"{data['db_schema']}.{data['db_table']}"
)

def __str__(self) -> str:
return f"{self.db_schema}.{self.db_table}"
Expand All @@ -136,6 +229,86 @@ def __hash__(self) -> int:
(type(self), self.db_pk, self.db_schema, self.db_table, self.database.alias)
)

@model_validator(mode="before")
@classmethod
def hash_columns(cls, data: dict[str, Any]) -> "Source":
"""Shapes indices data from either the backend or TOML.
Handles three scenarios:
1. No columns specified - all columns except primary key are indexed
2. Indices from database - uses existing column hash information
3. Columns specified in TOML - specified columns are indexed
"""
# Initialise warehouse and get table metadata
warehouse = (
data["database"]
if isinstance(data["database"], SourceWarehouse)
else SourceWarehouse(**data["database"])
)

metadata = MetaData(schema=data["db_schema"])
table = Table(data["db_table"], metadata, autoload_with=warehouse.engine)

# Get all columns except primary key
remote_columns = [
SourceColumn(literal=col.name, type=str(col.type), indexed=False)
for col in table.columns
if col.name != data["db_pk"]
]

index_data = data.get("index")

# Case 1: No columns specified - index everything
if not index_data:
data["db_columns"] = [
SourceColumn(literal=col.literal.name, type=col.type, indexed=True)
for col in remote_columns
]
return data

# Case 2: Columns from database
if isinstance(index_data, dict):
data["db_columns"] = [
SourceColumn(
literal=col.literal.name,
type=col.type,
indexed=col in index_data["literal"] + index_data["alias"],
)
for col in remote_columns
]
return data

# Case 3: Columns from TOML
local_columns = []

# Process TOML column specifications
for column in index_data:
local_columns.append(
SourceColumn(
literal=column["literal"],
alias=column.get("alias", column["literal"]),
type=column.get("type"),
indexed=True,
)
)

# Match remote columns with local specifications
indexed_columns = []
non_indexed_columns = []

for remote_col in remote_columns:
matched = False
for local_col in local_columns:
if remote_col.literal == local_col.literal:
indexed_columns.append(local_col)
matched = True
break
if not matched:
non_indexed_columns.append(remote_col)

data["db_columns"] = indexed_columns + non_indexed_columns

return data

def to_table(self) -> Table:
"""Returns the dataset as a SQLAlchemy Table object."""
metadata = MetaData(schema=self.db_schema)
@@ -180,9 +353,8 @@ def _get_column(col_name: str) -> ColumnElement:

def to_hash(self) -> bytes:
"""Generate a unique hash based on the table's columns and datatypes."""
-table = self.to_table()
-schema_representation = f"{str(self)}: " + ",".join(
-f"{col.name}:{str(col.type)}" for col in table.columns
+schema_representation = f"{self.alias}: " + ",".join(
+f"{col.alias.name}:{col.type}" for col in self.db_columns if col.indexed
)
return HASH_FUNC(schema_representation.encode("utf-8")).digest()

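To make the hash_columns validator's cases concrete, a hedged sketch of constructing a Source under cases 1 and 3, assuming warehouse is a connected SourceWarehouse and the remote table matches the sample config above (case 2 follows the same shape with a dict from the backend):

# Case 1: no index given -- every column except the primary key is indexed.
exporters = Source(
    database=warehouse, db_schema="hmrc", db_table="trade__exporters", db_pk="id"
)

# Case 3: a TOML-style list -- the named columns are indexed (optionally
# aliased and cast); remaining remote columns are kept with indexed=False.
exporters = Source(
    database=warehouse,
    db_schema="hmrc",
    db_table="trade__exporters",
    db_pk="id",
    index=[{"literal": "company_name", "alias": "name"}, {"literal": "postcode"}],
)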
8 changes: 6 additions & 2 deletions src/matchbox/common/graph.py
@@ -1,7 +1,7 @@
from enum import StrEnum

import rustworkx as rx
-from matchbox.common.hash import hash_to_str
+from matchbox.common.hash import hash_to_base64
from pydantic import BaseModel


@@ -36,7 +36,11 @@ def to_rx(self) -> rx.PyDiGraph:
nodes = {}
G = rx.PyDiGraph()
for n in self.nodes:
node_data = {"id": hash_to_str(n.hash), "name": n.name, "type": str(n.type)}
node_data = {
"id": hash_to_base64(n.hash),
"name": n.name,
"type": str(n.type),
}
nodes[n.hash] = G.add_node(node_data)
for e in self.edges:
G.add_edge(nodes[e.parent], nodes[e.child], {})
12 changes: 5 additions & 7 deletions src/matchbox/common/hash.py
@@ -18,7 +18,7 @@
HASH_FUNC = hashlib.sha256


-def hash_to_str(hash: bytes) -> str:
+def hash_to_base64(hash: bytes) -> str:
return base64.b64encode(hash).decode("utf-8")
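The rename from hash_to_str makes the encoding explicit; behaviour is unchanged. A tiny sketch:

digest = HASH_FUNC(b"postcode").digest()
encoded = hash_to_base64(digest)  # base64 text of the 32-byte digest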


@@ -28,16 +28,14 @@ def dataset_to_hashlist(
"""Retrieve and hash a dataset from its warehouse, ready to be inserted."""
with Session(dataset.database.engine) as warehouse_session:
source_table = dataset.to_table()

-# Exclude the primary key from the columns to be hashed
-cols = tuple(
-[col for col in list(source_table.c.keys()) if col != dataset.db_pk]
+cols_to_index = tuple(
+[col.literal.name for col in dataset.db_columns if col.indexed]
)

slct_stmt = select(
-func.concat(*source_table.c[cols]).label("raw"),
+func.concat(*source_table.c[cols_to_index]).label("raw"),
func.array_agg(source_table.c[dataset.db_pk].cast(String)).label("id"),
-).group_by(*source_table.c[cols])
+).group_by(*source_table.c[cols_to_index])

raw_result = warehouse_session.execute(slct_stmt)

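In SQL terms, the amended statement now concatenates and groups by only the indexed columns; roughly, for hmrc_exporters with company_name and postcode indexed (a sketch, not the literal SQL emitted):

# SELECT concat(company_name, postcode) AS raw,
#        array_agg(CAST(id AS VARCHAR)) AS id
# FROM hmrc.trade__exporters
# GROUP BY company_name, postcode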