From e578450803778c6d79b948165bac8ab9d104f5a9 Mon Sep 17 00:00:00 2001 From: Jason Lubken Date: Thu, 23 Dec 2021 12:59:23 -0500 Subject: [PATCH] Add schema to persister configuration --- assets/postgres/schema.sql | 1 - src/dsdk/asset.py | 21 +++++++++++++---- src/dsdk/mssql.py | 13 +++++++++-- src/dsdk/persistor.py | 48 ++++++++++++++------------------------ src/dsdk/postgres.py | 24 ++++++++++++------- test/test_dsdk.py | 40 +++++++++++++++++++++++-------- test/test_postgres.py | 3 ++- 7 files changed, 92 insertions(+), 58 deletions(-) delete mode 100644 assets/postgres/schema.sql diff --git a/assets/postgres/schema.sql b/assets/postgres/schema.sql deleted file mode 100644 index d2650f7..0000000 --- a/assets/postgres/schema.sql +++ /dev/null @@ -1 +0,0 @@ -set search_path = example; diff --git a/src/dsdk/asset.py b/src/dsdk/asset.py index b5e09dc..33c3275 100644 --- a/src/dsdk/asset.py +++ b/src/dsdk/asset.py @@ -67,13 +67,26 @@ def __init__( **kwargs: Asset, ): """__init__.""" - self.path = path - self.ext = ext + self._path = path + self._ext = ext super().__init__(**kwargs) def as_yaml(self) -> Dict[str, Any]: """As yaml.""" return { - "ext": self.ext, - "path": self.path, + "ext": self._ext, + "path": self._path, } + + def __call__(self, *args): + """__call__. + + Yield (path, values). 
+ """ + for key, value in vars(self).items(): + if key.startswith("_"): + continue + if value.__class__ == Asset: + yield from value(*args, key) + continue + yield ".".join((*args, key)), value diff --git a/src/dsdk/mssql.py b/src/dsdk/mssql.py index 3827457..c8c41d6 100644 --- a/src/dsdk/mssql.py +++ b/src/dsdk/mssql.py @@ -64,6 +64,16 @@ def mogrify( """Safely mogrify parameters into query or fragment.""" return _mssql.substitute_params(query, parameters) + def __init__( + self, + *, + port: int = 1433, + schema: str = "dbo", + **kwargs, + ): + """__init__.""" + super().__init__(port=port, schema=schema, **kwargs) + @contextmanager def connect(self) -> Generator[Any, None, None]: """Connect.""" @@ -86,11 +96,10 @@ def connect(self) -> Generator[Any, None, None]: def dry_run( self, query_parameters, - skip=(), exceptions=(DatabaseError, InterfaceError), ): """Dry run.""" - super().dry_run(query_parameters, skip, exceptions) + super().dry_run(query_parameters, exceptions) class Mixin(BaseMixin): diff --git a/src/dsdk/persistor.py b/src/dsdk/persistor.py index 36dda0b..c9c349e 100644 --- a/src/dsdk/persistor.py +++ b/src/dsdk/persistor.py @@ -9,7 +9,7 @@ from re import compile as re_compile from string import Formatter from tempfile import NamedTemporaryFile -from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple +from typing import Any, Dict, Generator, Optional, Sequence, Tuple from cfgenvy import yaml_type from pandas import DataFrame, concat @@ -31,8 +31,9 @@ class AbstractPersistor: CLOSE = dumps({"key": f"{KEY}.close"}) COMMIT = dumps({"key": f"{KEY}.commit"}) END = dumps({"key": f"{KEY}.end"}) - ERROR = dumps({"key": f"{KEY}.table.error", "query": "%s"}) - ERRORS = dumps({"key": f"{KEY}.dry_run.error", "query": "%s"}) + DRY_RUN = dumps({"key": f"{KEY}.dry_run.try", "path": "%s"}) + ERROR = dumps({"key": f"{KEY}.dry_run.error", "path": "%s"}) + ERRORS = dumps({"key": f"{KEY}.dry_run.errors", "path": "%s"}) ON = dumps({"key": f"{KEY}.on"}) 
OPEN = dumps({"key": f"{KEY}.open"}) ROLLBACK = dumps({"key": f"{KEY}.rollback"}) @@ -165,23 +166,18 @@ def __init__(self, sql: Asset): """__init__.""" self.sql = sql - def on_dry_run( + def dry_run( self, - sql: Asset, query_parameters: Dict[str, Any], - skip: Tuple, - exceptions: Tuple, + exceptions: Tuple = (), ): - """On dry run.""" - errors: List[Exception] = [] - for key, value in vars(sql).items(): - if value.__class__ == Asset: - errors += self.on_dry_run( - value, query_parameters, skip, exceptions - ) - continue - if key in skip: - continue + """Execute sql found in asset with dry_run parameter set to 1.""" + logger.info(self.ON) + query_parameters = query_parameters.copy() + query_parameters["dry_run"] = 1 + errors = [] + for path, value in self.sql(): + logger.info(self.DRY_RUN, path) with self.rollback() as cur: rendered = self.render_without_keys( cur, @@ -195,21 +191,8 @@ try: cur.execute(rendered) except exceptions as e: - logger.warning(self.ERROR, key) + logger.warning(self.ERROR, path) errors.append(e) - return errors - - def dry_run( - self, - query_parameters: Dict[str, Any], - skip: Tuple = (), - exceptions: Tuple = (), - ): - """Execute sql found in asse with dry_run parameter set to 1.""" - logger.info(self.ON) - query_parameters = query_parameters.copy() - query_parameters["dry_run"] = 1 - errors = self.on_dry_run(self.sql, query_parameters, skip, exceptions) if bool(errors): raise RuntimeError(self.ERRORS, errors) logger.info(self.END) @@ -294,6 +277,7 @@ def __init__( # pylint: disable=too-many-arguments host: str, password: str, port: int, + schema: str, sql: Asset, username: str, ): @@ -302,6 +286,7 @@ def __init__( # pylint: disable=too-many-arguments self.host = host self.password = password self.port = port + self.schema = schema self.username = username super().__init__(sql) @@ -312,6 +297,7 @@ def as_yaml(self) -> Dict[str, Any]: "host": self.host, "password": self.password, "port": self.port, + "schema": 
self.schema, "sql": self.sql, "username": self.username, } diff --git a/src/dsdk/postgres.py b/src/dsdk/postgres.py index 4cee8c0..24f51a4 100644 --- a/src/dsdk/postgres.py +++ b/src/dsdk/postgres.py @@ -131,14 +131,23 @@ def mogrify( """Safely mogrify parameters into query or fragment.""" return cur.mogrify(query, parameters) + def __init__( + self, + *, + port: int = 5432, + schema: str = "public", + **kwargs, + ): + """__init__.""" + super().__init__(port=port, schema=schema, **kwargs) + def dry_run( self, query_parameters, - skip=("schema",), exceptions=(DatabaseError, InterfaceError), ): """Dry run.""" - super().dry_run(query_parameters, skip, exceptions) + super().dry_run(query_parameters, exceptions) @contextmanager def listen(self, *listens: str) -> Generator[Any, None, None]: @@ -180,7 +189,7 @@ def open_run(self, parent: Any) -> Generator[Run, None, None]: sql = self.sql columns = parent.as_insert_sql() with self.commit() as cur: - cur.execute(sql.schema) + cur.execute(f"set search_path={self.schema}") cur.execute(sql.runs.open, columns) for row in cur: ( @@ -208,7 +217,7 @@ def open_run(self, parent: Any) -> Generator[Run, None, None]: yield run with self.commit() as cur: - cur.execute(sql.schema) + cur.execute(f"set search_path={self.schema}") predictions = run.predictions if predictions is not None: # pylint: disable=unsupported-assignment-operation @@ -239,7 +248,7 @@ def scores(self, run_id) -> Series: """Return scores series.""" sql = self.sql with self.rollback() as cur: - cur.execute(sql.schema) + cur.execute(f"set search_path={self.schema}") return self.df_from_query( cur, sql.predictions.gold, @@ -249,7 +258,6 @@ def scores(self, run_id) -> Series: def store_evidence(self, run: Any, *args, **kwargs) -> None: """Store evidence.""" sql = self.sql - schema = sql.schema run_id = run.id evidence = run.evidence exclude = set(kwargs.get("exclude", ())) @@ -266,7 +274,6 @@ def store_evidence(self, run: Any, *args, **kwargs) -> None: f"Missing 
sql/postgres/{key}/insert.sql" ) from e self._store_df( - schema, insert, run_id, df[list(set(df.columns) - exclude)], @@ -274,7 +281,6 @@ def store_evidence(self, run: Any, *args, **kwargs) -> None: def _store_df( self, - schema: str, insert: str, run_id: int, df: DataFrame, @@ -283,7 +289,7 @@ def _store_df( out = df.to_dict("records") try: with self.commit() as cur: - cur.execute(schema) + cur.execute(f"set search_path={self.schema}") execute_batch( cur, insert, diff --git a/test/test_dsdk.py b/test/test_dsdk.py index 07da1fc..1ac72b2 100644 --- a/test/test_dsdk.py +++ b/test/test_dsdk.py @@ -77,7 +77,6 @@ def __init__(self, **kwargs): database: test host: 0.0.0.0 password: ${MSSQL_PASSWORD} - port: 1433 sql: !asset ext: .sql path: ./assets/mssql @@ -87,7 +86,7 @@ def __init__(self, **kwargs): database: test host: 0.0.0.0 password: ${POSTGRES_PASSWORD} - port: 5432 + schema: example sql: !asset ext: .sql path: ./assets/postgres @@ -110,6 +109,7 @@ def __init__(self, **kwargs): host: 0.0.0.0 password: password port: 1433 + schema: dbo sql: !asset ext: .sql path: ./assets/mssql @@ -119,6 +119,7 @@ def __init__(self, **kwargs): host: 0.0.0.0 password: password port: 5432 + schema: example sql: !asset ext: .sql path: ./assets/postgres @@ -135,20 +136,19 @@ def build( cls.yaml_types() model = Model(name="test", path="./test/model.pkl", version="0.0.1-rc.1") mssql = Mssql( - username="mssql", - password="password", - host="0.0.0.0", - port=1433, database="test", + host="0.0.0.0", + password="password", sql=Asset.build(path="./assets/mssql", ext=".sql"), + username="mssql", ) postgres = Postgres( - username="postgres", - password="password", - host="0.0.0.0", - port=5432, database="test", + host="0.0.0.0", + password="password", + schema="example", sql=Asset.build(path="./assets/postgres", ext=".sql"), + username="postgres", ) return ( cls, @@ -271,3 +271,23 @@ def explode(): except NotImplementedError as exception: assert actual == expected assert str(exception) == 
"when?" + + +def test_asset(): + """Test asset traversal.""" + asset = Asset( + ext=".sql", + path="./predict/sql/mssql", + run=Asset( + ext=".sql", + path="./predict/sql/mssql/run", + select="select * from runs", + ), + cohort="select * from patients", + ) + actual = tuple(each for each in asset()) + expected = ( + ("run.select", "select * from runs"), + ("cohort", "select * from patients"), + ) + assert actual == expected diff --git a/test/test_postgres.py b/test/test_postgres.py index 63ec54b..59d13fb 100644 --- a/test/test_postgres.py +++ b/test/test_postgres.py @@ -37,6 +37,7 @@ def __init__( database=kwargs.get( "database", env.get("POSTGRES_DATABASE", "test") ), + schema=kwargs.get("schema", env.get("POSTGRES_SCHEMA", "example")), sql=Asset.build( path=kwargs.get( "sql", env.get("POSTGRES_SQL", "./assets/postgres") @@ -120,7 +121,7 @@ def test_open_run( run.predictions = df with persistor.rollback() as cur: - cur.execute(persistor.sql.schema) + cur.execute(f"set search_path={persistor.schema}") df = read_sql_query( sql=check, con=cur.connection, params={"run_id": run.id} )