Add the ability to choose the columns we index on #22

Merged · 16 commits · Dec 12, 2024
1 change: 1 addition & 0 deletions pyproject.toml
@@ -39,6 +39,7 @@ dev = [
"pytest-env>=1.1.5",
"ruff>=0.6.8",
"docker>=7.1.0",
"tomli-w>=1.1.0",
"vcrpy>=6.0.2",
]
typing = [
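A quick note on the new dev dependency: tomli-w is the writer counterpart to the stdlib tomllib reader. The diff doesn't show where it's used, but presumably the tests round-trip dataset configs like the sample file below; a minimal sketch:

```python
import tomli_w

# Serialise a dataset config back to TOML. The dict shape mirrors
# sample.datasets.toml; its use in tests is an assumption.
print(tomli_w.dumps({"datasets": {"companies": {"db_pk": "id"}}}))
```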
24 changes: 23 additions & 1 deletion sample.datasets.toml
@@ -11,21 +11,43 @@ database = "pg_warehouse"
db_schema = "companieshouse"
db_table = "companies"
db_pk = "id"
index = [
{ literal = "crn", alias = "crn_id", type = "VARCHAR" },
{ literal = "company_name", alias = "name" },
{ literal = "*" },
{ literal = "postcode" }
]

[datasets.data_hub_companies]
database = "pg_warehouse"
db_schema = "dbt"
db_table = "data_hub__companies"
db_pk = "id"
index = [
{ literal = "cdms", alias = "cdms_id", type = "VARCHAR" },
{ literal = "company_name", alias = "name" },
{ literal = "postcode" },
{ literal = "*" }
]

[datasets.hmrc_exporters]
database = "pg_warehouse"
db_schema = "hmrc"
db_table = "trade__exporters"
db_pk = "id"
index = [
{ literal = "company_name", alias = "name" },
{ literal = "postcode" },
]

[datasets.export_wins]
database = "pg_warehouse"
db_schema = "dbt"
db_table = "export_wins__wins_dataset"
db_pk = "id"
db_pk = "id"
index = [
{ literal = "company_name" },
{ literal = "postcode" },
{ literal = "cdms", alias = "cdms_id", type = "VARCHAR" },
{ literal = "data_hub_company_id", alias = "dh_id", type = "VARCHAR" },
]
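For context, each index entry pairs a required literal (the column's name in the warehouse) with an optional alias (the name used when hashing) and an optional type to cast to before hashing, matching the SourceColumn fields introduced in db.py below. A sketch of reading the file (the path and print loop are illustrative):

```python
import tomllib  # stdlib TOML reader, Python 3.11+

with open("sample.datasets.toml", "rb") as f:
    config = tomllib.load(f)

for name, dataset in config["datasets"].items():
    for col in dataset.get("index", []):
        literal = col["literal"]            # required: column name in the warehouse
        alias = col.get("alias", literal)   # optional: name used when hashing
        cast = col.get("type")              # optional: type to cast to before hashing
        print(f"{name}: {literal} -> {alias} ({cast or 'native type'})")
```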
6 changes: 2 additions & 4 deletions src/matchbox/admin.py
@@ -7,9 +7,7 @@

from matchbox.common.db import SourceWarehouse
from matchbox.server import MatchboxDBAdapter, inject_backend
from matchbox.server.base import (
Source,
)
from matchbox.server.base import Source

logger = logging.getLogger("mb_logic")

@@ -34,7 +32,7 @@ def load_datasets_from_config(datasets: Path) -> dict[str, Source]:
for dataset_name, dataset_config in config["datasets"].items():
warehouse_alias = dataset_config.get("database")
dataset_config["database"] = warehouses[warehouse_alias]
datasets[dataset_name] = Source(**dataset_config)
datasets[dataset_name] = Source(alias=dataset_name, **dataset_config)

return datasets

200 changes: 186 additions & 14 deletions src/matchbox/common/db.py
@@ -3,10 +3,17 @@
import connectorx as cx
import pyarrow as pa
from matchbox.common.exceptions import MatchboxValidatonError
from matchbox.common.hash import HASH_FUNC
from matchbox.common.hash import HASH_FUNC, hash_to_base64
from pandas import DataFrame
from pyarrow import Table as ArrowTable
from pydantic import BaseModel, ConfigDict, Field
from pydantic import (
BaseModel,
ConfigDict,
Field,
SecretStr,
field_validator,
model_validator,
)
from sqlalchemy import (
LABEL_STYLE_TABLENAME_PLUS_COL,
ColumnElement,
@@ -69,7 +76,7 @@ class SourceWarehouse(BaseModel):
alias: str
db_type: str
user: str
password: str = Field(repr=False)
password: SecretStr
host: str
port: int
database: str
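The switch from str with Field(repr=False) to SecretStr means the password is masked everywhere by default, not just in repr, and the raw value has to be unwrapped explicitly, as the engine property now does. A standalone sketch of the pydantic behaviour (the model name is illustrative):

```python
from pydantic import BaseModel, SecretStr

class Creds(BaseModel):
    password: SecretStr

creds = Creds(password="hunter2")
print(creds)                              # password=SecretStr('**********')
print(creds.password.get_secret_value())  # hunter2 (unwrapping is an explicit opt-in)
```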
@@ -78,11 +85,24 @@
@property
def engine(self) -> Engine:
if self._engine is None:
connection_string = f"{self.db_type}://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"
connection_string = f"{self.db_type}://{self.user}:{self.password.get_secret_value()}@{self.host}:{self.port}/{self.database}"
self._engine = create_engine(connection_string)
self.test_connection()
return self._engine

def __eq__(self, other):
if not isinstance(other, SourceWarehouse):
return False
return (
self.alias == other.alias
and self.db_type == other.db_type
and self.user == other.user
and self.password == other.password
and self.host == other.host
and self.port == other.port
and self.database == other.database
)

def test_connection(self):
try:
with self.engine.connect() as connection:
@@ -91,12 +111,6 @@ def test_connection(self):
self._engine = None
raise

def __str__(self):
return (
f"SourceWarehouse(alias={self.alias}, type={self.db_type}, "
f"host={self.host}, port={self.port}, database={self.database})"
)

@classmethod
def from_engine(cls, engine: Engine, alias: str | None = None) -> "SourceWarehouse":
"""Create a SourceWarehouse instance from an SQLAlchemy Engine object."""
@@ -116,17 +130,96 @@ def from_engine(cls, engine: Engine, alias: str | None = None) -> "SourceWarehouse":
return warehouse


class SourceColumnName(BaseModel):
"""A column name in the Matchbox database."""

name: str

@property
def hash(self) -> bytes:
"""Generate a unique hash based on the column name."""
return HASH_FUNC(self.name.encode("utf-8")).digest()

@property
def base64(self) -> str:
"""Generate a base64 encoded hash based on the column name."""
return hash_to_base64(self.hash)
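A quick illustration of what the two properties compute, given that HASH_FUNC is sha256 in src/matchbox/common/hash.py:

```python
import base64
import hashlib

name = "company_name"
digest = hashlib.sha256(name.encode("utf-8")).digest()   # SourceColumnName.hash
print(base64.b64encode(digest).decode("utf-8"))          # SourceColumnName.base64
```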


class SourceColumn(BaseModel):
"""A column in a dataset that can be indexed in the Matchbox database."""

model_config = ConfigDict(arbitrary_types_allowed=True)

literal: SourceColumnName = Field(
description="The literal name of the column in the database."
)
alias: SourceColumnName = Field(
default_factory=lambda data: SourceColumnName(name=data["literal"].name),
description="The alias to use when hashing the dataset in Matchbox.",
)
type: str | None = Field(
default=None, description="The type to cast the column to before hashing data."
)
indexed: bool = Field(description="Whether the column is indexed in the database.")

def __eq__(self, other: object) -> bool:
"""Compare SourceColumn with another SourceColumn or bytes object.

Two SourceColumns are equal if:

* Their literal names match, or
* Their alias names match, or
* The hash of either their literal or alias matches the other object's
corresponding hash

A SourceColumn is equal to a bytes object if:

* The hash of either its literal or alias matches the bytes object

Args:
other: Another SourceColumn or a bytes object to compare against

Returns:
bool: True if the objects are considered equal, False otherwise
"""
if isinstance(other, SourceColumn):
if self.literal == other.literal or self.alias == other.alias:
return True

self_hashes = {self.literal.hash, self.alias.hash}
other_hashes = {other.literal.hash, other.alias.hash}

return bool(self_hashes & other_hashes)

if isinstance(other, bytes):
return other in {self.literal.hash, self.alias.hash}

return NotImplemented
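In practice this rule lets a column declared in TOML match a warehouse column whether it is referred to by literal, alias, or either one's hash. A sketch, assuming the classes in this module are importable (the import path is inferred from this diff):

```python
from matchbox.common.db import SourceColumn

crn = SourceColumn(literal="crn", alias="crn_id", indexed=True)
by_literal = SourceColumn(literal="crn", indexed=False)                # same literal
by_alias = SourceColumn(literal="company_ref", alias="crn_id", indexed=True)

assert crn == by_literal        # literal names match
assert crn == by_alias          # alias names match
assert crn == crn.literal.hash  # bytes compare against literal or alias hashes
```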

@field_validator("literal", "alias", mode="before")
def string_to_name(cls: "SourceColumn", value: str) -> SourceColumnName:
if isinstance(value, str):
return SourceColumnName(name=value)
else:
raise ValueError("Column name must be a string.")


class Source(BaseModel):
"""A dataset that can be indexed in the Matchbox database."""

model_config = ConfigDict(
populate_by_name=True,
)

database: SourceWarehouse | None = None
database: SourceWarehouse
db_pk: str
db_schema: str
db_table: str
db_columns: list[SourceColumn]
alias: str = Field(
default_factory=lambda data: f"{data['db_schema']}.{data['db_table']}"
)
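The alias default uses the newer pydantic v2 form of default_factory that receives the already-validated data (a fairly recent v2 addition, v2.10 if memory serves). A standalone sketch with an illustrative model name:

```python
from pydantic import BaseModel, Field

class Dataset(BaseModel):
    db_schema: str
    db_table: str
    alias: str = Field(
        default_factory=lambda data: f"{data['db_schema']}.{data['db_table']}"
    )

print(Dataset(db_schema="dbt", db_table="export_wins__wins_dataset").alias)
# dbt.export_wins__wins_dataset
print(Dataset(db_schema="dbt", db_table="x", alias="wins").alias)
# wins (an explicit alias overrides the default)
```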

def __str__(self) -> str:
return f"{self.db_schema}.{self.db_table}"
@@ -136,6 +229,86 @@ def __hash__(self) -> int:
(type(self), self.db_pk, self.db_schema, self.db_table, self.database.alias)
)

@model_validator(mode="before")
@classmethod
def hash_columns(cls, data: dict[str, Any]) -> "Source":
"""Shapes indices data from either the backend or TOML.

Handles three scenarios:
1. No columns specified - all columns except primary key are indexed
2. Indices from database - uses existing column hash information
3. Columns specified in TOML - specified columns are indexed
"""
# Initialise warehouse and get table metadata
warehouse = (
data["database"]
if isinstance(data["database"], SourceWarehouse)
else SourceWarehouse(**data["database"])
)

metadata = MetaData(schema=data["db_schema"])
table = Table(data["db_table"], metadata, autoload_with=warehouse.engine)

# Get all columns except primary key
remote_columns = [
SourceColumn(literal=col.name, type=str(col.type), indexed=False)
for col in table.columns
if col.name not in data["db_pk"]
]

index_data = data.get("index")

# Case 1: No columns specified - index everything
if not index_data:
data["db_columns"] = [
SourceColumn(literal=col.literal.name, type=col.type, indexed=True)
for col in remote_columns
]
return data

# Case 2: Columns from database
if isinstance(index_data, dict):
data["db_columns"] = [
SourceColumn(
literal=col.literal.name,
type=col.type,
indexed=col in index_data["literal"] + index_data["alias"],
)
for col in remote_columns
]
return data

# Case 3: Columns from TOML
local_columns = []

# Process TOML column specifications
for column in index_data:
local_columns.append(
SourceColumn(
literal=column["literal"],
alias=column.get("alias", column["literal"]),
indexed=True,
)
)

# Match remote columns with local specifications
indexed_columns = []
non_indexed_columns = []

for remote_col in remote_columns:
matched = False
for local_col in local_columns:
if remote_col.literal == local_col.literal:
indexed_columns.append(local_col)
matched = True
break
if not matched:
non_indexed_columns.append(remote_col)

data["db_columns"] = indexed_columns + non_indexed_columns

return data
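The Case 3 branch effectively partitions the warehouse columns into indexed-first, non-indexed-after. A standalone sketch of the same matching step, with plain strings standing in for SourceColumn:

```python
remote = ["crn", "company_name", "postcode"]  # columns found on the table (pk excluded)
toml_index = [{"literal": "crn", "alias": "crn_id"}, {"literal": "company_name"}]

specified = {spec["literal"] for spec in toml_index}
indexed = [col for col in remote if col in specified]          # hashed into Matchbox
non_indexed = [col for col in remote if col not in specified]  # kept but not indexed
print(indexed + non_indexed)  # indexed columns first, mirroring db_columns ordering
```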

def to_table(self) -> Table:
"""Returns the dataset as a SQLAlchemy Table object."""
metadata = MetaData(schema=self.db_schema)
@@ -180,9 +353,8 @@ def _get_column(col_name: str) -> ColumnElement:

def to_hash(self) -> bytes:
"""Generate a unique hash based on the table's columns and datatypes."""
table = self.to_table()
schema_representation = f"{str(self)}: " + ",".join(
f"{col.name}:{str(col.type)}" for col in table.columns
schema_representation = f"{self.alias}: " + ",".join(
f"{col.alias.name}:{col.type}" for col in self.db_columns if col.indexed
)
return HASH_FUNC(schema_representation.encode("utf-8")).digest()
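One design consequence worth flagging: to_hash previously reflected every live column on the table, whereas it now hashes only the indexed columns, under their aliases, so renaming a warehouse column no longer changes the dataset hash as long as its alias stays stable. A sketch of the string being hashed for the hmrc_exporters example (the column types here are assumptions about the live table):

```python
import hashlib

schema_representation = "hmrc_exporters: name:TEXT,postcode:TEXT"
digest = hashlib.sha256(schema_representation.encode("utf-8")).digest()
```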

8 changes: 6 additions & 2 deletions src/matchbox/common/graph.py
@@ -1,7 +1,7 @@
from enum import StrEnum

import rustworkx as rx
from matchbox.common.hash import hash_to_str
from matchbox.common.hash import hash_to_base64
from pydantic import BaseModel


@@ -36,7 +36,11 @@ def to_rx(self) -> rx.PyDiGraph:
nodes = {}
G = rx.PyDiGraph()
for n in self.nodes:
node_data = {"id": hash_to_str(n.hash), "name": n.name, "type": str(n.type)}
node_data = {
"id": hash_to_base64(n.hash),
"name": n.name,
"type": str(n.type),
}
nodes[n.hash] = G.add_node(node_data)
for e in self.edges:
G.add_edge(nodes[e.parent], nodes[e.child], {})
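A minimal sketch of the node payload shape to_rx now emits; the base64 ids are illustrative, and the rename to hash_to_base64 keeps them printable and JSON-safe rather than raw hash bytes:

```python
import rustworkx as rx

G = rx.PyDiGraph()
parent = G.add_node({"id": "3o5Z0A==", "name": "dedupe_naive", "type": "model"})
child = G.add_node({"id": "q1nM8g==", "name": "companieshouse.companies", "type": "dataset"})
G.add_edge(parent, child, {})  # edges carry no payload, matching the diff
```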
12 changes: 5 additions & 7 deletions src/matchbox/common/hash.py
@@ -18,7 +18,7 @@
HASH_FUNC = hashlib.sha256


def hash_to_str(hash: bytes) -> str:
def hash_to_base64(hash: bytes) -> str:
return base64.b64encode(hash).decode("utf-8")


@@ -28,16 +28,14 @@ def dataset_to_hashlist(
"""Retrieve and hash a dataset from its warehouse, ready to be inserted."""
with Session(dataset.database.engine) as warehouse_session:
source_table = dataset.to_table()

# Exclude the primary key from the columns to be hashed
cols = tuple(
[col for col in list(source_table.c.keys()) if col != dataset.db_pk]
cols_to_index = tuple(
[col.literal.name for col in dataset.db_columns if col.indexed]
)

slct_stmt = select(
func.concat(*source_table.c[cols]).label("raw"),
func.concat(*source_table.c[cols_to_index]).label("raw"),
func.array_agg(source_table.c[dataset.db_pk].cast(String)).label("id"),
).group_by(*source_table.c[cols])
).group_by(*source_table.c[cols_to_index])

raw_result = warehouse_session.execute(slct_stmt)
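To see what the reshaped query does, here is a standalone version against a made-up table: rows that are identical over the indexed columns only collapse to one concatenated raw value, with every matching primary key aggregated into an array (array_agg assumes PostgreSQL semantics):

```python
from sqlalchemy import Column, Integer, MetaData, String, Table, Text, func, select

metadata = MetaData(schema="hmrc")
table = Table(
    "trade__exporters",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("company_name", Text),
    Column("postcode", Text),
)
cols_to_index = ("company_name", "postcode")

stmt = select(
    func.concat(*table.c[cols_to_index]).label("raw"),
    func.array_agg(table.c["id"].cast(String)).label("id"),
).group_by(*table.c[cols_to_index])
print(stmt)  # one row per distinct (company_name, postcode) pair
```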
