Add efficient database I/O
daniel-thom committed Oct 14, 2024
1 parent 847c993 commit bcbaa7f
Showing 7 changed files with 157 additions and 25 deletions.
12 changes: 12 additions & 0 deletions src/chronify/duckdb/functions.py
@@ -1,5 +1,6 @@
 from collections.abc import Iterable
 from datetime import datetime, timedelta
+from pathlib import Path
 
 import duckdb
 from duckdb import DuckDBPyRelation
@@ -29,6 +30,17 @@ def add_datetime_column(
     # )
 
 
+def make_write_parquet_query(table_or_view: str, file_path: Path | str) -> str:
+    """Make an SQL string that can be used to write a Parquet file from a table or view."""
+    # TODO: Hive partitioning?
+    return f"""
+        COPY
+        (SELECT * FROM {table_or_view})
+        TO '{file_path}'
+        (FORMAT 'parquet');
+    """
+
+
 def unpivot(
     rel: DuckDBPyRelation,
     pivoted_columns: Iterable[str],
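A minimal usage sketch for the new helper (the connection, the demo table, and the output path are illustrative, not part of this commit). Note that the table or view name and the file path are interpolated directly into the SQL string, so callers must supply trusted values:

import duckdb

from chronify.duckdb.functions import make_write_parquet_query

con = duckdb.connect()  # in-memory database, for the example only
con.sql("CREATE TABLE demo AS SELECT 1 AS id, 'a' AS name")
# Produces: COPY (SELECT * FROM demo) TO 'demo.parquet' (FORMAT 'parquet');
con.sql(make_write_parquet_query("demo", "demo.parquet"))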
65 changes: 51 additions & 14 deletions src/chronify/store.py
@@ -4,12 +4,12 @@

 import pandas as pd
 import polars as pl
-from sqlalchemy import Column, Engine, MetaData, Table, create_engine, text
+from loguru import logger
+from sqlalchemy import Column, Engine, MetaData, Selectable, Table, create_engine, text
 
+import chronify.duckdb.functions as ddbf
 from chronify.exceptions import ConflictingInputsError, InvalidTable
 from chronify.csv_io import read_csv
-from chronify.duckdb.functions import unpivot as duckdb_unpivot
-from chronify.duckdb.functions import add_datetime_column
 from chronify.models import (
     CsvTableSchema,
     TableSchema,
@@ -18,6 +18,8 @@
 )
 from chronify.time_configs import DatetimeRange, IndexTimeRange
 from chronify.time_series_checker import TimeSeriesChecker
+from chronify.utils.sql import make_temp_view_name
+from chronify.utils.sqlalchemy_view import create_view
 
 
 class Store:
@@ -71,6 +73,11 @@ def engine(self) -> Engine:
"""Return the sqlalchemy engine."""
return self._engine

@property
def metadata(self) -> MetaData:
"""Return the sqlalchemy metadata."""
return self._metadata

def create_view_from_parquet(self, name: str, path: Path) -> None:
"""Create a view in the database from a Parquet file."""
with self._engine.begin() as conn:
@@ -84,12 +91,6 @@ def create_view_from_parquet(self, name: str, path: Path) -> None:
             conn.commit()
         self.update_table_schema()
 
-    # def export_csv(self, table: str, path: Path) -> None:
-    #     """Export a table or view to a CSV file."""
-
-    # def export_parquet(self, table: str, path: Path) -> None:
-    #     """Export a table or view to a Parquet file."""
-
     def ingest_from_csv(
         self,
         path: Path | str,
@@ -105,7 +106,7 @@ def ingest_from_csv(
         check_schema_compatibility(src_schema, dst_schema)
 
         if src_schema.pivoted_dimension_name is not None:
-            rel = duckdb_unpivot(
+            rel = ddbf.unpivot(
                 rel,
                 src_schema.value_columns,
                 src_schema.pivoted_dimension_name,
@@ -114,7 +115,7 @@

         if isinstance(src_schema.time_config, IndexTimeRange):
             if isinstance(dst_schema.time_config, DatetimeRange):
-                rel = add_datetime_column(
+                rel = ddbf.add_datetime_column(
                     rel=rel,
                     start=dst_schema.time_config.start,
                     resolution=dst_schema.time_config.resolution,
@@ -140,13 +141,49 @@
         with self._engine.begin() as conn:
             df.write_database(dst_schema.name, connection=conn, if_table_exists="append")
             conn.commit()
+        self.update_table_schema()
 
-    def read_table(self, name: str, query: Optional[str] = None) -> pd.DataFrame:
+    def read_table(self, name: str, query: Optional[Selectable | str] = None) -> pd.DataFrame:
         """Return the table as a pandas DataFrame, optionally applying a query."""
+        if query is None:
+            query_ = f"select * from {name}"
+        elif isinstance(query, Selectable) and self.engine.name == "duckdb":
+            # TODO: unsafe. Need duckdb_engine support.
+            # https://github.com/Mause/duckdb_engine/issues/1119
+            # https://github.com/pola-rs/polars/issues/19221
+            query_ = str(query.compile(compile_kwargs={"literal_binds": True}))
+        else:
+            query_ = query
+
         with self._engine.begin() as conn:
-            query_ = query or f"select * from {name}"
             return pl.read_database(query_, connection=conn).to_pandas()
 
+    def write_query_to_parquet(self, stmt: Selectable, file_path: Path | str) -> None:
+        """Write the result of a query to a Parquet file."""
+        view_name = make_temp_view_name()
+        create_view(view_name, stmt, self._engine, self._metadata)
+        try:
+            self.write_table_to_parquet(view_name, file_path)
+        finally:
+            with self._engine.connect() as conn:
+                conn.execute(text(f"DROP VIEW {view_name}"))
+
+    def write_table_to_parquet(self, name: str, file_path: Path | str) -> None:
+        """Write a table or view to a Parquet file."""
+        match self._engine.name:
+            case "duckdb":
+                cmd = ddbf.make_write_parquet_query(name, file_path)
+            # case "spark":
+            #     pass
+            case _:
+                msg = f"{self.engine.name=}"
+                raise NotImplementedError(msg)
+
+        with self._engine.connect() as conn:
+            conn.execute(text(cmd))
+
+        logger.info("Wrote table or view to {}", file_path)
 
     def has_table(self, name: str) -> bool:
         """Return True if the database has a table with the given name."""
         return name in self._metadata.tables
@@ -179,7 +216,7 @@ def _check_table_schema(self, schema: TableSchema) -> None:
         check_columns(columns, schema.list_columns())
 
     def _check_timestamps(self, schema: TableSchema) -> None:
-        checker = TimeSeriesChecker(self._engine)
+        checker = TimeSeriesChecker(self._engine, self._metadata)
         checker.check_timestamps(schema)


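Taken together, the new Store methods let a SQLAlchemy expression flow from query to Parquet. A sketch under the same setup the tests use (an in-memory duckdb Store; the CSV path, the schemas, and the filter value are assumptions):

from sqlalchemy import Table, select

store = Store()
store.ingest_from_csv("generators.csv", src_schema, dst_schema)

table = Table(dst_schema.name, store.metadata)
stmt = select(table).where(table.c.generator == "gen1")
df = store.read_table(dst_schema.name, stmt)  # compiled with literal_binds on duckdb
store.write_query_to_parquet(stmt, "gen1.parquet")  # temp view + COPY, then DROP VIEW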
23 changes: 13 additions & 10 deletions src/chronify/time_series_checker.py
@@ -1,29 +1,33 @@
-from sqlalchemy import Connection, Engine, text
+from sqlalchemy import Connection, Engine, MetaData, Table, text
 
 from chronify.exceptions import InvalidTable
 from chronify.models import TableSchema
+from chronify.utils.sql import make_temp_view_name
 
 
 class TimeSeriesChecker:
     """Performs checks on time series arrays in a table."""
 
-    def __init__(self, engine: Engine) -> None:
+    def __init__(self, engine: Engine, metadata: MetaData) -> None:
         self._engine = engine
+        self._metadata = metadata
 
     def check_timestamps(self, schema: TableSchema) -> None:
+        # TODO: the conn.execute calls here are slow but not vulnerable to sql injection.
+        # Data extraction should never be large.
+        # Consider changing when we have a better implementation.
         self._check_expected_timestamps(schema)
         self._check_expected_timestamps_by_time_array(schema)
 
     def _check_expected_timestamps(self, schema: TableSchema) -> None:
         expected = schema.time_config.list_timestamps()
         with self._engine.connect() as conn:
-            filters = []
+            table = Table(schema.name, self._metadata)
+            stmt = table.select().distinct()
             for col in schema.time_config.time_columns:
-                filters.append(f"{col} IS NOT NULL")
-            time_cols = ",".join(schema.time_config.time_columns)
-            where_clause = "AND ".join(filters)
-            query = f"SELECT DISTINCT {time_cols} FROM {schema.name} WHERE {where_clause}"
-            actual = set(schema.time_config.convert_database_timestamps(conn.execute(text(query))))
+                stmt = stmt.where(table.c[col].is_not(None))
+            # This is slow, but the size should never be large.
+            actual = set(schema.time_config.convert_database_timestamps(conn.execute(stmt)))
             diff = actual.symmetric_difference(expected)
             if diff:
                 msg = f"Actual timestamps do not match expected timestamps: {diff}"
@@ -32,7 +36,7 @@ def _check_expected_timestamps(self, schema: TableSchema) -> None:

     def _check_expected_timestamps_by_time_array(self, schema: TableSchema) -> None:
         with self._engine.connect() as conn:
-            tmp_name = "tss_tmp_table"
+            tmp_name = make_temp_view_name()
             self._run_timestamp_checks_on_tmp_table(schema, conn, tmp_name)
             conn.execute(text(f"DROP TABLE IF EXISTS {tmp_name}"))

@@ -41,7 +45,6 @@ def _run_timestamp_checks_on_tmp_table(schema: TableSchema, conn: Connection, ta
         id_cols = ",".join(schema.time_array_id_columns)
         filters = [f"{x} IS NOT NULL" for x in schema.time_config.time_columns]
         where_clause = " AND ".join(filters)
-        # TODO DT: change to expression API, try to avoid creating the temp table, delete if used.
         query = f"""
         CREATE TEMP TABLE {table_name} AS
         SELECT
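The substance of this change is replacing hand-built SQL strings with the SQLAlchemy expression API, which renders identifiers and NULL checks without manual string interpolation. A standalone sketch of the pattern now used in _check_expected_timestamps (the demo table and column are made up):

from sqlalchemy import Column, DateTime, MetaData, Table

metadata = MetaData()
demo = Table("demo", metadata, Column("timestamp", DateTime))
stmt = demo.select().distinct().where(demo.c["timestamp"].is_not(None))
# str(stmt) renders:
# SELECT DISTINCT demo.timestamp FROM demo WHERE demo.timestamp IS NOT NULL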
Empty file added src/chronify/utils/__init__.py
9 changes: 9 additions & 0 deletions src/chronify/utils/sql.py
@@ -0,0 +1,9 @@
from uuid import uuid4


TEMP_TABLE_PREFIX = "chronify"


def make_temp_view_name() -> str:
"""Make a random name to be used as a temporary view or table."""
return f"{TEMP_TABLE_PREFIX}_{uuid4().hex}"
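Because uuid4().hex contains only hexadecimal characters, the generated name is a valid unquoted SQL identifier. An illustrative call (the suffix shown is made up; every call returns a new random value):

from chronify.utils.sql import make_temp_view_name

name = make_temp_view_name()
print(name)  # e.g. chronify_9f0c2b4e6a1d4f08b7e3c5a2d4f6e8a0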
57 changes: 57 additions & 0 deletions src/chronify/utils/sqlalchemy_view.py
@@ -0,0 +1,57 @@
# Copied this code from https://github.com/sqlalchemy/sqlalchemy/wiki/Views

import sqlalchemy as sa
from sqlalchemy import Engine, MetaData, Selectable, TableClause
from sqlalchemy.ext import compiler
from sqlalchemy.schema import DDLElement
from sqlalchemy.sql import table


class CreateView(DDLElement):
def __init__(self, name, selectable):
self.name = name
self.selectable = selectable


class DropView(DDLElement):
def __init__(self, name):
self.name = name


@compiler.compiles(CreateView)
def _create_view(element, compiler, **kw):
return "CREATE VIEW %s AS %s" % (
element.name,
compiler.sql_compiler.process(element.selectable, literal_binds=True),
)


@compiler.compiles(DropView)
def _drop_view(element, compiler, **kw):
return "DROP VIEW %s" % (element.name)


def _view_exists(ddl, target, connection, **kw):
return ddl.name in sa.inspect(connection).get_view_names()


def _view_doesnt_exist(ddl, target, connection, **kw):
return not _view_exists(ddl, target, connection, **kw)


def create_view(
name: str, selectable: Selectable, engine: Engine, metadata: MetaData
) -> TableClause:
"""Create a view from a selectable."""
view = table(name)
view._columns._populate_separate_keys(
col._make_proxy(view) for col in selectable.selected_columns
)
sa.event.listen(
metadata,
"after_create",
CreateView(name, selectable).execute_if(callable_=_view_doesnt_exist),
)
sa.event.listen(metadata, "before_drop", DropView(name).execute_if(callable_=_view_exists))
metadata.create_all(engine)
return view
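A sketch of driving the helper directly, mirroring its use in Store.write_query_to_parquet (the duckdb_engine URL and the demo table are assumptions, not part of this commit):

from sqlalchemy import Column, Integer, MetaData, Table, create_engine, select

engine = create_engine("duckdb:///:memory:")  # assumes duckdb_engine is installed
metadata = MetaData()
data = Table("data", metadata, Column("id", Integer))

stmt = select(data).where(data.c.id > 10)
# create_view registers the DDL listeners and runs metadata.create_all(engine),
# which creates the table first and then emits CREATE VIEW.
view = create_view("data_filtered", stmt, engine, metadata)
with engine.connect() as conn:
    rows = conn.execute(select(view)).fetchall()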
16 changes: 15 additions & 1 deletion tests/test_store.py
@@ -1,8 +1,9 @@
 from datetime import datetime, timedelta
 
 import duckdb
+import pandas as pd
 import pytest
-from sqlalchemy import Double
+from sqlalchemy import Double, Table, select
 from chronify.csv_io import read_csv
 from chronify.duckdb.functions import unpivot
@@ -94,3 +95,16 @@ def test_load_parquet(tmp_path):
     rel3.to_parquet(str(out_file))
     store = Store()
     store.load_table(out_file, dst_schema)
+
+
+def test_to_parquet(tmp_path, generators_schema):
+    src_schema, dst_schema = generators_schema
+    store = Store()
+    store.ingest_from_csv(GENERATOR_TIME_SERIES_FILE, src_schema, dst_schema)
+    filename = tmp_path / "data.parquet"
+    table = Table(dst_schema.name, store.metadata)
+    stmt = select(table).where(table.c.generator == "gen2")
+    store.write_query_to_parquet(stmt, filename)
+    assert filename.exists()
+    df = pd.read_parquet(filename)
+    assert len(df) == 8784
