Skip to content

Commit

Permalink
Add schema to persister configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
jlubken committed Dec 23, 2021
1 parent 3446089 commit e578450
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 58 deletions.
1 change: 0 additions & 1 deletion assets/postgres/schema.sql

This file was deleted.

21 changes: 17 additions & 4 deletions src/dsdk/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,26 @@ def __init__(
**kwargs: Asset,
):
"""__init__."""
self.path = path
self.ext = ext
self._path = path
self._ext = ext
super().__init__(**kwargs)

def as_yaml(self) -> Dict[str, Any]:
"""As yaml."""
return {
"ext": self.ext,
"path": self.path,
"ext": self._ext,
"path": self._path,
}

def __call__(self, *args):
    """Walk the asset and yield (path, value) pairs.

    The positional args are the path segments accumulated so far.
    Each yielded path is those segments plus the attribute name,
    joined with dots. Nested assets are descended into rather than
    yielded themselves.
    """
    for name, item in vars(self).items():
        if name.startswith("_"):
            # Underscore-prefixed attributes are not part of the content.
            continue
        if item.__class__ == Asset:
            # Recurse with this attribute name appended to the prefix.
            yield from item(*args, name)
        else:
            yield ".".join((*args, name)), item
13 changes: 11 additions & 2 deletions src/dsdk/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,16 @@ def mogrify(
"""Safely mogrify parameters into query or fragment."""
return _mssql.substitute_params(query, parameters)

def __init__(
    self,
    *,
    port: int = 1433,
    schema: str = "dbo",
    **kwargs,
):
    """Initialize, defaulting port to 1433 and schema to "dbo".

    Every other keyword argument is forwarded unchanged to the base
    class initializer.
    """
    defaults = {"port": port, "schema": schema}
    super().__init__(**defaults, **kwargs)

@contextmanager
def connect(self) -> Generator[Any, None, None]:
"""Connect."""
Expand All @@ -86,11 +96,10 @@ def connect(self) -> Generator[Any, None, None]:
def dry_run(
self,
query_parameters,
skip=(),
exceptions=(DatabaseError, InterfaceError),
):
"""Dry run."""
super().dry_run(query_parameters, skip, exceptions)
super().dry_run(query_parameters, exceptions)


class Mixin(BaseMixin):
Expand Down
48 changes: 17 additions & 31 deletions src/dsdk/persistor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from re import compile as re_compile
from string import Formatter
from tempfile import NamedTemporaryFile
from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple
from typing import Any, Dict, Generator, Optional, Sequence, Tuple

from cfgenvy import yaml_type
from pandas import DataFrame, concat
Expand All @@ -31,8 +31,9 @@ class AbstractPersistor:
CLOSE = dumps({"key": f"{KEY}.close"})
COMMIT = dumps({"key": f"{KEY}.commit"})
END = dumps({"key": f"{KEY}.end"})
ERROR = dumps({"key": f"{KEY}.table.error", "query": "%s"})
ERRORS = dumps({"key": f"{KEY}.dry_run.error", "query": "%s"})
DRY_RUN = dumps({"key": f"{KEY}.dry_run.try", "path": "%s"})
ERROR = dumps({"key": f"{KEY}.dry_run.error", "path": "%s"})
ERRORS = dumps({"key": f"{KEY}.dry_run.errors", "path": "%s"})
ON = dumps({"key": f"{KEY}.on"})
OPEN = dumps({"key": f"{KEY}.open"})
ROLLBACK = dumps({"key": f"{KEY}.rollback"})
Expand Down Expand Up @@ -165,23 +166,18 @@ def __init__(self, sql: Asset):
"""__init__."""
self.sql = sql

def on_dry_run(
def dry_run(
self,
sql: Asset,
query_parameters: Dict[str, Any],
skip: Tuple,
exceptions: Tuple,
exceptions: Tuple = (),
):
"""On dry run."""
errors: List[Exception] = []
for key, value in vars(sql).items():
if value.__class__ == Asset:
errors += self.on_dry_run(
value, query_parameters, skip, exceptions
)
continue
if key in skip:
continue
"""Execute sql found in asse with dry_run parameter set to 1."""
logger.info(self.ON)
query_parameters = query_parameters.copy()
query_parameters["dry_run"] = 1
errors = []
for path, value in self.sql():
logger.info(self.DRY_RUN, path)
with self.rollback() as cur:
rendered = self.render_without_keys(
cur,
Expand All @@ -195,21 +191,8 @@ def on_dry_run(
try:
cur.execute(rendered)
except exceptions as e:
logger.warning(self.ERROR, key)
logger.warning(self.ERROR, path)
errors.append(e)
return errors

def dry_run(
self,
query_parameters: Dict[str, Any],
skip: Tuple = (),
exceptions: Tuple = (),
):
"""Execute sql found in asse with dry_run parameter set to 1."""
logger.info(self.ON)
query_parameters = query_parameters.copy()
query_parameters["dry_run"] = 1
errors = self.on_dry_run(self.sql, query_parameters, skip, exceptions)
if bool(errors):
raise RuntimeError(self.ERRORS, errors)
logger.info(self.END)
Expand Down Expand Up @@ -294,6 +277,7 @@ def __init__( # pylint: disable=too-many-arguments
host: str,
password: str,
port: int,
schema: str,
sql: Asset,
username: str,
):
Expand All @@ -302,6 +286,7 @@ def __init__( # pylint: disable=too-many-arguments
self.host = host
self.password = password
self.port = port
self.schema = schema
self.username = username
super().__init__(sql)

Expand All @@ -312,6 +297,7 @@ def as_yaml(self) -> Dict[str, Any]:
"host": self.host,
"password": self.password,
"port": self.port,
"schema": self.schema,
"sql": self.sql,
"username": self.username,
}
Expand Down
24 changes: 15 additions & 9 deletions src/dsdk/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,14 +131,23 @@ def mogrify(
"""Safely mogrify parameters into query or fragment."""
return cur.mogrify(query, parameters)

def __init__(
    self,
    *,
    port: int = 5432,
    schema: str = "public",
    **kwargs,
):
    """Initialize, defaulting port to 5432 and schema to "public".

    Every other keyword argument is forwarded unchanged to the base
    class initializer.
    """
    defaults = {"port": port, "schema": schema}
    super().__init__(**defaults, **kwargs)

def dry_run(
self,
query_parameters,
skip=("schema",),
exceptions=(DatabaseError, InterfaceError),
):
"""Dry run."""
super().dry_run(query_parameters, skip, exceptions)
super().dry_run(query_parameters, exceptions)

@contextmanager
def listen(self, *listens: str) -> Generator[Any, None, None]:
Expand Down Expand Up @@ -180,7 +189,7 @@ def open_run(self, parent: Any) -> Generator[Run, None, None]:
sql = self.sql
columns = parent.as_insert_sql()
with self.commit() as cur:
cur.execute(sql.schema)
cur.execute(f"set search_path={self.schema}")
cur.execute(sql.runs.open, columns)
for row in cur:
(
Expand Down Expand Up @@ -208,7 +217,7 @@ def open_run(self, parent: Any) -> Generator[Run, None, None]:
yield run

with self.commit() as cur:
cur.execute(sql.schema)
cur.execute(f"set search_path={self.schema}")
predictions = run.predictions
if predictions is not None:
# pylint: disable=unsupported-assignment-operation
Expand Down Expand Up @@ -239,7 +248,7 @@ def scores(self, run_id) -> Series:
"""Return scores series."""
sql = self.sql
with self.rollback() as cur:
cur.execute(sql.schema)
cur.execute(f"set search_path={self.schema}")
return self.df_from_query(
cur,
sql.predictions.gold,
Expand All @@ -249,7 +258,6 @@ def scores(self, run_id) -> Series:
def store_evidence(self, run: Any, *args, **kwargs) -> None:
"""Store evidence."""
sql = self.sql
schema = sql.schema
run_id = run.id
evidence = run.evidence
exclude = set(kwargs.get("exclude", ()))
Expand All @@ -266,15 +274,13 @@ def store_evidence(self, run: Any, *args, **kwargs) -> None:
f"Missing sql/postgres/{key}/insert.sql"
) from e
self._store_df(
schema,
insert,
run_id,
df[list(set(df.columns) - exclude)],
)

def _store_df(
self,
schema: str,
insert: str,
run_id: int,
df: DataFrame,
Expand All @@ -283,7 +289,7 @@ def _store_df(
out = df.to_dict("records")
try:
with self.commit() as cur:
cur.execute(schema)
cur.execute(f"set search_path={self.schema}")
execute_batch(
cur,
insert,
Expand Down
40 changes: 30 additions & 10 deletions test/test_dsdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ def __init__(self, **kwargs):
database: test
host: 0.0.0.0
password: ${MSSQL_PASSWORD}
port: 1433
sql: !asset
ext: .sql
path: ./assets/mssql
Expand All @@ -87,7 +86,7 @@ def __init__(self, **kwargs):
database: test
host: 0.0.0.0
password: ${POSTGRES_PASSWORD}
port: 5432
schema: example
sql: !asset
ext: .sql
path: ./assets/postgres
Expand All @@ -110,6 +109,7 @@ def __init__(self, **kwargs):
host: 0.0.0.0
password: password
port: 1433
schema: dbo
sql: !asset
ext: .sql
path: ./assets/mssql
Expand All @@ -119,6 +119,7 @@ def __init__(self, **kwargs):
host: 0.0.0.0
password: password
port: 5432
schema: example
sql: !asset
ext: .sql
path: ./assets/postgres
Expand All @@ -135,20 +136,19 @@ def build(
cls.yaml_types()
model = Model(name="test", path="./test/model.pkl", version="0.0.1-rc.1")
mssql = Mssql(
username="mssql",
password="password",
host="0.0.0.0",
port=1433,
database="test",
host="0.0.0.0",
password="password",
sql=Asset.build(path="./assets/mssql", ext=".sql"),
username="mssql",
)
postgres = Postgres(
username="postgres",
password="password",
host="0.0.0.0",
port=5432,
database="test",
host="0.0.0.0",
password="password",
schema="example",
sql=Asset.build(path="./assets/postgres", ext=".sql"),
username="postgres",
)
return (
cls,
Expand Down Expand Up @@ -271,3 +271,23 @@ def explode():
except NotImplementedError as exception:
assert actual == expected
assert str(exception) == "when?"


def test_asset():
    """Check that calling an asset yields dotted paths for its leaves."""
    nested = Asset(
        ext=".sql",
        path="./predict/sql/mssql/run",
        select="select * from runs",
    )
    root = Asset(
        ext=".sql",
        path="./predict/sql/mssql",
        run=nested,
        cohort="select * from patients",
    )
    # Only the sql leaves are expected; ext and path must not be yielded.
    expected = (
        ("run.select", "select * from runs"),
        ("cohort", "select * from patients"),
    )
    actual = tuple(root())
    assert actual == expected
3 changes: 2 additions & 1 deletion test/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(
database=kwargs.get(
"database", env.get("POSTGRES_DATABASE", "test")
),
schema=kwargs.get("schema", env.get("POSTGRES_SCHEMA", "example")),
sql=Asset.build(
path=kwargs.get(
"sql", env.get("POSTGRES_SQL", "./assets/postgres")
Expand Down Expand Up @@ -120,7 +121,7 @@ def test_open_run(
run.predictions = df

with persistor.rollback() as cur:
cur.execute(persistor.sql.schema)
cur.execute(f"set search_path={persistor.schema}")
df = read_sql_query(
sql=check, con=cur.connection, params={"run_id": run.id}
)
Expand Down

0 comments on commit e578450

Please sign in to comment.