diff --git a/docker-compose.yml b/docker-compose.yml index 0dcda60..044c536 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,15 +1,27 @@ services: - db: + warehouse: image: postgres:14 restart: always environment: - POSTGRES_USER: testuser - POSTGRES_PASSWORD: testpassword - POSTGRES_DB: testdb + POSTGRES_USER: warehouse_user + POSTGRES_PASSWORD: warehouse_password + POSTGRES_DB: warehouse + ports: + - "7654:5432" + volumes: + - warehouse_data:/var/lib/postgresql/data + matchbox-postgres: + image: postgres:14 + restart: always + environment: + POSTGRES_USER: matchbox_user + POSTGRES_PASSWORD: matchbox_password + POSTGRES_DB: matchbox ports: - "5432:5432" volumes: - - pgdata:/var/lib/postgresql/data + - matchbox_data:/var/lib/postgresql/data volumes: - pgdata: \ No newline at end of file + warehouse_data: + matchbox_data: \ No newline at end of file diff --git a/sample.env b/sample.env index 6c57485..999b098 100644 --- a/sample.env +++ b/sample.env @@ -6,4 +6,5 @@ MB__POSTGRES__HOST= MB__POSTGRES__PORT= MB__POSTGRES__USER= MB__POSTGRES__PASSWORD= +MB__POSTGRES__DATABASE= MB__POSTGRES__DB_SCHEMA=mb diff --git a/src/matchbox/server/base.py b/src/matchbox/server/base.py index 5d38df5..a522d86 100644 --- a/src/matchbox/server/base.py +++ b/src/matchbox/server/base.py @@ -4,7 +4,7 @@ from typing import Literal, Protocol import pandas as pd -from pydantic import AnyUrl, BaseModel, Field +from pydantic import BaseModel, Field from pydantic_settings import BaseSettings, SettingsConfigDict from rustworkx import PyDiGraph from sqlalchemy import create_engine @@ -53,10 +53,11 @@ class SourceWarehouse(BaseModel): alias: str db_type: str - username: str + user: str password: str = Field(repr=False) - host: AnyUrl + host: str port: int + database: str _engine: Engine | None = None class Config: @@ -67,7 +68,7 @@ class Config: @property def engine(self) -> Engine: if self._engine is None: - connection_string = f"{self.db_type}://{self.username}:{self.password}@{self.host}:{self.port}" + connection_string = f"{self.db_type}://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}" self._engine = create_engine(connection_string) self.test_connection() return self._engine @@ -83,7 +84,7 @@ def test_connection(self): def __str__(self): return ( f"SourceWarehouse(alias={self.alias}, type={self.db_type}, " - f"host={self.host}, port={self.port})" + f"host={self.host}, port={self.port}, database={self.database})" ) @@ -183,7 +184,10 @@ def get_model_subgraph(self) -> PyDiGraph: ... def get_model(self, model: str) -> MatchboxModelAdapter: ... @abstractmethod - def delete_model(self, model: str) -> None: ... + def delete_model(self, model: str, certain: bool) -> None: ... @abstractmethod def insert_model(self, model: str) -> None: ... + + @abstractmethod + def clear(self, certain: bool) -> None: ... diff --git a/src/matchbox/server/postgresql/adapter.py b/src/matchbox/server/postgresql/adapter.py index 77c7211..5bc1200 100644 --- a/src/matchbox/server/postgresql/adapter.py +++ b/src/matchbox/server/postgresql/adapter.py @@ -277,3 +277,11 @@ def insert_model( description=description, engine=MBDB.get_engine(), ) + + def clear(self, certain: bool = False) -> None: + """Clears all data from the database. + + Args: + certain: Whether to clear the database without confirmation. + """ + MBDB.clear_database() diff --git a/src/matchbox/server/postgresql/db.py b/src/matchbox/server/postgresql/db.py index 11272fd..2a3725e 100644 --- a/src/matchbox/server/postgresql/db.py +++ b/src/matchbox/server/postgresql/db.py @@ -17,17 +17,21 @@ class MatchboxPostgresCoreSettings(BaseModel): - """Settings for Matchbox's PostgreSQL backend.""" + """PostgreSQL-specific settings for Matchbox.""" host: str port: int user: str password: str + database: str db_schema: str class MatchboxPostgresSettings(MatchboxSettings): - """Settings for the Matchbox PostgreSQL backend.""" + """Settings for the Matchbox PostgreSQL backend. + + Inherits the core settings and adds the PostgreSQL-specific settings. + """ backend_type: MatchboxBackends = MatchboxBackends.POSTGRES @@ -37,6 +41,8 @@ class MatchboxPostgresSettings(MatchboxSettings): class MatchboxDatabase: + """Matchbox PostgreSQL database connection.""" + def __init__(self, settings: MatchboxPostgresSettings): self.settings = settings self.engine: Engine | None = None @@ -44,10 +50,13 @@ def __init__(self, settings: MatchboxPostgresSettings): self.MatchboxBase = declarative_base() def connect(self): + """Connect to the database.""" + if not self.engine: connection_string = ( f"postgresql://{self.settings.postgres.user}:{self.settings.postgres.password}" - f"@{self.settings.postgres.host}:{self.settings.postgres.port}" + f"@{self.settings.postgres.host}:{self.settings.postgres.port}/" + f"{self.settings.postgres.database}" ) self.engine = create_engine(connection_string, logging_name="mb_pg_db") self.SessionLocal = sessionmaker( @@ -56,16 +65,22 @@ def connect(self): self.MatchboxBase.metadata.schema = self.settings.postgres.db_schema def get_engine(self) -> Engine: + """Get the database engine.""" + if not self.engine: self.connect() return self.engine def get_session(self): + """Get a new session.""" + if not self.SessionLocal: self.connect() return self.SessionLocal() def create_database(self): + """Create the database.""" + self.connect() with self.engine.connect() as conn: conn.execute( @@ -77,6 +92,20 @@ def create_database(self): self.MatchboxBase.metadata.create_all(self.engine) + def clear_database(self): + """Clear the database.""" + + self.connect() + with self.engine.connect() as conn: + conn.execute( + text( + f"DROP SCHEMA IF EXISTS {self.settings.postgres.db_schema} CASCADE;" + ) + ) + conn.commit() + + self.create_database() + # Global database instance -- everything should use this diff --git a/test/fixtures/db.py b/test/fixtures/db.py index e7effa6..d41e149 100644 --- a/test/fixtures/db.py +++ b/test/fixtures/db.py @@ -168,17 +168,18 @@ def _db_add_link_models_and_data( @pytest.fixture(scope="session") -def warehouse_engine() -> SourceWarehouse: +def warehouse() -> SourceWarehouse: """Create a connection to the test warehouse database.""" warehouse = SourceWarehouse( alias="test_warehouse", db_type="postgresql", - username="test_user", - password="test_password", + user="warehouse_user", + password="warehouse_password", host="localhost", - port=5432, + database="warehouse", + port=7654, ) - _ = warehouse.engine() + assert warehouse.engine return warehouse @@ -190,7 +191,9 @@ def warehouse_data( cdms_companies: DataFrame, ) -> Generator[list[IndexableDataset], None, None]: """Inserts data into the warehouse database for testing.""" - with warehouse.engine().connect() as conn: + with warehouse.engine.connect() as conn: + conn.execute(text("drop schema if exists test cascade;")) + conn.execute(text("create schema test;")) crn_companies.to_sql( "crn", con=conn, @@ -226,7 +229,7 @@ def warehouse_data( ] # Clean up the warehouse data - with warehouse.engine().connect() as conn: + with warehouse.engine.connect() as conn: conn.execute(text("drop table if exists test.crn;")) conn.execute(text("drop table if exists test.duns;")) conn.execute(text("drop table if exists test.cdms;")) @@ -241,11 +244,14 @@ def matchbox_settings() -> MatchboxPostgresSettings: """Settings for the Matchbox database.""" return MatchboxPostgresSettings( batch_size=250_000, - host="localhost", - port=5432, - user="test_user", - password="test_password", - schema="test_matchbox", + postgres={ + "host": "localhost", + "port": 5432, + "user": "matchbox_user", + "password": "matchbox_password", + "database": "matchbox", + "db_schema": "matchbox", + }, ) @@ -260,7 +266,4 @@ def matchbox_postgres( yield adapter # Clean up the Matchbox database after each test - with adapter.engine.connect() as conn: - conn.execute(text(f"drop schema if exists {matchbox_settings.schema} cascade;")) - conn.execute(text(f"create schema {matchbox_settings.schema};")) - conn.commit() + adapter.clear(certain=True)