Use polars for database I/O
daniel-thom committed Oct 14, 2024
1 parent 321c5a5 commit 87ac971
Showing 10 changed files with 294 additions and 35 deletions.
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -29,9 +29,12 @@ dependencies = [
"duckdb ~= 1.1.0",
"duckdb_engine",
"loguru",
"rich",
"pandas >= 2.2, < 3",
"polars >= 1.9, < 2",
"pyarrow",
"pydantic >= 2.7, < 3",
"pytz",
"rich",
"sqlalchemy",
"tzdata",
]
90 changes: 90 additions & 0 deletions scripts/perf_tests.py
@@ -0,0 +1,90 @@
from datetime import datetime, timedelta

import duckdb
import pandas as pd
import polars as pl
from IPython import get_ipython

from sqlalchemy import Double, text
from chronify.models import ColumnDType, CsvTableSchema, TableSchema
from chronify.store import Store
from chronify.time import TimeIntervalType, TimeZone
from chronify.time_configs import DatetimeRange


GENERATOR_TIME_SERIES_FILE = "tests/data/gen.csv"


def read_duckdb(conn, name: str):
return conn.sql(f"SELECT * FROM {name}").df()


def read_pandas(store: Store, name: str):
with store.engine.begin() as conn:
query = f"select * from {name}"
return pd.read_sql(query, conn)


def read_polars(store: Store, name: str):
with store.engine.begin() as conn:
query = f"select * from {name}"
return pl.read_database(query, connection=conn).to_pandas()


def read_sqlalchemy(store: Store, name: str):
with store.engine.begin() as conn:
query = f"select * from {name}"
res = conn.execute(text(query)).fetchall()
return pd.DataFrame.from_records(res, columns=["timestamp", "generator", "value"])


def setup():
time_config = DatetimeRange(
start=datetime(year=2020, month=1, day=1),
resolution=timedelta(hours=1),
length=8784,
time_interval_type=TimeIntervalType.PERIOD_BEGINNING,
time_columns=["timestamp"],
time_zone=TimeZone.UTC,
)

src_schema = CsvTableSchema(
time_config=time_config,
column_dtypes=[
ColumnDType(name="gen1", dtype=Double),
ColumnDType(name="gen2", dtype=Double),
ColumnDType(name="gen3", dtype=Double),
],
value_columns=["gen1", "gen2", "gen3"],
pivoted_dimension_name="generator",
time_array_id_columns=[],
)
dst_schema = TableSchema(
name="generators",
time_config=time_config,
time_array_id_columns=["generator"],
value_column="value",
)
return src_schema, dst_schema


def run_test(engine_name: str):
store = Store(engine_name=engine_name)
src_schema, dst_schema = setup()
store.ingest_from_csv(GENERATOR_TIME_SERIES_FILE, src_schema, dst_schema)
ipython = get_ipython()
df = read_polars(store, dst_schema.name) # noqa: F841
conn = duckdb.connect(":memory:")
conn.sql("CREATE OR REPLACE TABLE perf_test AS SELECT * from df")
print(f"Run {engine_name} database with read_duckdb.")
ipython.run_line_magic("timeit", "read_duckdb(conn, 'perf_test')")
print(f"Run {engine_name} database with read_pandas.")
ipython.run_line_magic("timeit", "read_pandas(store, dst_schema.name)")
print(f"Run {engine_name} database with read_polars.")
ipython.run_line_magic("timeit", "read_polars(store, dst_schema.name)")
print(f"Run {engine_name} database with read_sqlalchemy.")
ipython.run_line_magic("timeit", "read_sqlalchemy(store, dst_schema.name)")


run_test("duckdb")
run_test("sqlite")
12 changes: 12 additions & 0 deletions src/chronify/duckdb/functions.py
@@ -1,5 +1,6 @@
from collections.abc import Iterable
from datetime import datetime, timedelta
from pathlib import Path

import duckdb
from duckdb import DuckDBPyRelation
@@ -29,6 +30,17 @@ def add_datetime_column(
# )


def make_write_parquet_query(table_or_view: str, file_path: Path | str) -> str:
"""Make an SQL string that can be used to write a Parquet file from a table or view."""
# TODO: Hive partitioning?
return f"""
COPY
(SELECT * FROM {table_or_view})
TO '{file_path}'
(FORMAT 'parquet');
"""


def unpivot(
rel: DuckDBPyRelation,
pivoted_columns: Iterable[str],
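
A quick usage sketch for the new helper (the connection, table name, and output path here are hypothetical, not part of the commit):

import duckdb

from chronify.duckdb.functions import make_write_parquet_query

con = duckdb.connect()
con.sql("CREATE TABLE demo AS SELECT 1 AS a, 2 AS b")
# The helper returns a COPY ... TO '<path>' (FORMAT 'parquet') statement for the table or view.
con.execute(make_write_parquet_query("demo", "demo.parquet"))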
4 changes: 4 additions & 0 deletions src/chronify/exceptions.py
@@ -2,6 +2,10 @@ class ChronifyExceptionBase(Exception):
"""Base class for exceptions in this package"""


class ConflictingInputsError(ChronifyExceptionBase):
"""Raised when user inputs conflict with each other."""


class InvalidTable(ChronifyExceptionBase):
"""Raised when a table does not match its schema."""

112 changes: 88 additions & 24 deletions src/chronify/store.py
@@ -1,14 +1,15 @@
import itertools
from collections.abc import Iterable
from pathlib import Path
from typing import Optional

from sqlalchemy import Column, Engine, MetaData, Table, create_engine, text
import pandas as pd
import polars as pl
from loguru import logger
from sqlalchemy import Column, Engine, MetaData, Selectable, Table, create_engine, text

from chronify.exceptions import InvalidTable
import chronify.duckdb.functions as ddbf
from chronify.exceptions import ConflictingInputsError, InvalidTable
from chronify.csv_io import read_csv
from chronify.duckdb.functions import unpivot as duckdb_unpivot
from chronify.duckdb.functions import add_datetime_column
from chronify.models import (
CsvTableSchema,
TableSchema,
@@ -17,22 +18,44 @@
)
from chronify.time_configs import DatetimeRange, IndexTimeRange
from chronify.time_series_checker import TimeSeriesChecker
from chronify.utils.sql import make_temp_view_name
from chronify.utils.sqlalchemy_view import create_view


class Store:
"""Data store for time series data"""

def __init__(self, engine: Optional[Engine] = None, **connect_kwargs) -> None:
def __init__(
self, engine: Optional[Engine] = None, engine_name: Optional[str] = None, **connect_kwargs
) -> None:
"""Construct the Store.
Parameters
----------
engine: sqlalchemy.Engine
Optional, defaults to a DuckDB engine.
Optional, defaults to an engine connected to an in-memory DuckDB database.
engine_name: str
Optional, name of the in-memory engine to create ("duckdb" or "sqlite") when engine is not passed.
Examples
--------
>>> from sqlalchemy import create_engine
>>> store1 = Store()
>>> store2 = Store(engine=create_engine("duckdb:///time_series.db"))
>>> store3 = Store(engine_name="sqlite")
>>> store4 = Store(engine=create_engine("sqlite:///time_series.db"))
"""
self._metadata = MetaData()
if engine and engine_name:
msg = f"{engine=} and {engine_name=} cannot both be set"
raise ConflictingInputsError(msg)
if engine is None:
self._engine = create_engine("duckdb:///:memory:", **connect_kwargs)
name = engine_name or "duckdb"
match name:
case "duckdb" | "sqlite":
engine_path = f"{name}:///:memory:"
case _:
msg = f"{engine_name=}"
raise NotImplementedError(msg)
self._engine = create_engine(engine_path, **connect_kwargs)
else:
self._engine = engine

@@ -45,6 +68,16 @@ def __init__(self, engine: Optional[Engine] = None, **connect_kwargs) -> None:
# def add_time_series(self, data: np.ndarray) -> None:
# """Add a time series array to the store."""

@property
def engine(self) -> Engine:
"""Return the sqlalchemy engine."""
return self._engine

@property
def metadata(self) -> MetaData:
"""Return the sqlalchemy metadata."""
return self._metadata

def create_view_from_parquet(self, name: str, path: Path) -> None:
"""Create a view in the database from a Parquet file."""
with self._engine.begin() as conn:
@@ -58,12 +91,6 @@ def create_view_from_parquet(self, name: str, path: Path) -> None:
conn.commit()
self.update_table_schema()

# def export_csv(self, table: str, path: Path) -> None:
# """Export a table or view to a CSV file."""

# def export_parquet(self, table: str, path: Path) -> None:
# """Export a table or view to a Parquet file."""

def ingest_from_csv(
self,
path: Path | str,
@@ -79,7 +106,7 @@
check_schema_compatibility(src_schema, dst_schema)

if src_schema.pivoted_dimension_name is not None:
rel = duckdb_unpivot(
rel = ddbf.unpivot(
rel,
src_schema.value_columns,
src_schema.pivoted_dimension_name,
@@ -88,7 +115,7 @@

if isinstance(src_schema.time_config, IndexTimeRange):
if isinstance(dst_schema.time_config, DatetimeRange):
rel = add_datetime_column(
rel = ddbf.add_datetime_column(
rel=rel,
start=dst_schema.time_config.start,
resolution=dst_schema.time_config.resolution,
@@ -110,15 +137,52 @@
table = Table(dst_schema.name, self._metadata, *columns)
table.create(self._engine)

values = rel.fetchall()
columns = table.columns.keys()
placeholder = ",".join(itertools.repeat("?", len(columns)))
cols = ",".join(columns)
df = rel.pl()
with self._engine.begin() as conn:
query = f"INSERT INTO {dst_schema.name} ({cols}) VALUES ({placeholder})"
conn.exec_driver_sql(query, values)
query = f"select * from {dst_schema.name}"
df.write_database(dst_schema.name, connection=conn, if_table_exists="append")
conn.commit()
self.update_table_schema()

def read_table(self, name: str, query: Optional[Selectable | str] = None) -> pd.DataFrame:
"""Return the table as a pandas DataFrame, optionally applying a query."""
if query is None:
query_ = f"select * from {name}"
elif isinstance(query, Selectable) and self.engine.name == "duckdb":
# TODO: unsafe. Need duckdb_engine support.
# https://github.com/Mause/duckdb_engine/issues/1119
# https://github.com/pola-rs/polars/issues/19221
query_ = str(query.compile(compile_kwargs={"literal_binds": True}))
else:
query_ = query

with self._engine.begin() as conn:
return pl.read_database(query_, connection=conn).to_pandas()

def write_query_to_parquet(self, stmt: Selectable, file_path: Path | str) -> None:
"""Write the result of a query to a Parquet file."""
view_name = make_temp_view_name()
create_view(view_name, stmt, self._engine, self._metadata)
try:
self.write_table_to_parquet(view_name, file_path)
finally:
with self._engine.connect() as conn:
conn.execute(text(f"DROP VIEW {view_name}"))

def write_table_to_parquet(self, name: str, file_path: Path | str) -> None:
"""Write a table or view to a Parquet file."""
match self._engine.name:
case "duckdb":
cmd = ddbf.make_write_parquet_query(name, file_path)
# case "spark":
# pass
case _:
msg = f"{self.engine.name=}"
raise NotImplementedError(msg)

with self._engine.connect() as conn:
conn.execute(text(cmd))

logger.info("Wrote table or view to {}", file_path)

def has_table(self, name: str) -> bool:
"""Return True if the database has a table with the given name."""
@@ -152,7 +216,7 @@ def _check_table_schema(self, schema: TableSchema) -> None:
check_columns(columns, schema.list_columns())

def _check_timestamps(self, schema: TableSchema) -> None:
checker = TimeSeriesChecker(self._engine)
checker = TimeSeriesChecker(self._engine, self._metadata)
checker.check_timestamps(schema)


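
Taken together, the new Store surface looks roughly like the sketch below; src_schema and dst_schema are hypothetical and would be built as in scripts/perf_tests.py.

store = Store(engine_name="duckdb")  # or pass engine=create_engine("duckdb:///time_series.db")
store.ingest_from_csv("tests/data/gen.csv", src_schema, dst_schema)  # bulk insert now goes through polars write_database
df = store.read_table(dst_schema.name)  # polars read_database, returned as a pandas DataFrame
store.write_table_to_parquet(dst_schema.name, "gen.parquet")  # DuckDB COPY ... TO (FORMAT 'parquet')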
22 changes: 13 additions & 9 deletions src/chronify/time_series_checker.py
@@ -1,29 +1,33 @@
from sqlalchemy import Connection, Engine, text
from sqlalchemy import Connection, Engine, MetaData, Table, text

from chronify.exceptions import InvalidTable
from chronify.models import TableSchema
from chronify.utils.sql import make_temp_view_name


class TimeSeriesChecker:
"""Performs checks on time series arrays in a table."""

def __init__(self, engine: Engine) -> None:
def __init__(self, engine: Engine, metadata: MetaData) -> None:
self._engine = engine
self._metadata = metadata

def check_timestamps(self, schema: TableSchema) -> None:
# TODO: the conn.execute calls here are slow but not vulnerable to sql injection.
# Data extraction should never be large.
# Consider changing when we have a better implementation.
self._check_expected_timestamps(schema)
self._check_expected_timestamps_by_time_array(schema)

def _check_expected_timestamps(self, schema: TableSchema) -> None:
expected = schema.time_config.list_timestamps()
with self._engine.connect() as conn:
filters = []
table = Table(schema.name, self._metadata)
stmt = table.select().distinct()
for col in schema.time_config.time_columns:
filters.append(f"{col} IS NOT NULL")
time_cols = ",".join(schema.time_config.time_columns)
where_clause = "AND ".join(filters)
query = f"SELECT DISTINCT {time_cols} FROM {schema.name} WHERE {where_clause}"
actual = set(schema.time_config.convert_database_timestamps(conn.execute(text(query))))
stmt = stmt.where(table.c[col].is_not(None))
# This is slow, but the size should never be large.
actual = set(schema.time_config.convert_database_timestamps(conn.execute(stmt)))
diff = actual.symmetric_difference(expected)
if diff:
msg = f"Actual timestamps do not match expected timestamps: {diff}"
@@ -32,7 +36,7 @@ def _check_expected_timestamps(self, schema: TableSchema) -> None:

def _check_expected_timestamps_by_time_array(self, schema: TableSchema) -> None:
with self._engine.connect() as conn:
tmp_name = "tss_tmp_table"
tmp_name = make_temp_view_name()
self._run_timestamp_checks_on_tmp_table(schema, conn, tmp_name)
conn.execute(text(f"DROP TABLE IF EXISTS {tmp_name}"))

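
The refactor swaps the string-built DISTINCT query for a SQLAlchemy Core expression. A standalone sketch of the pattern, using a hypothetical one-column table:

from sqlalchemy import Column, DateTime, MetaData, Table

metadata = MetaData()
table = Table("generators", metadata, Column("timestamp", DateTime))
stmt = table.select().distinct().where(table.c["timestamp"].is_not(None))
print(stmt)
# SELECT DISTINCT generators.timestamp
# FROM generators
# WHERE generators.timestamp IS NOT NULL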
Empty file added src/chronify/utils/__init__.py
Empty file.
9 changes: 9 additions & 0 deletions src/chronify/utils/sql.py
@@ -0,0 +1,9 @@
from uuid import uuid4


TEMP_TABLE_PREFIX = "chronify"


def make_temp_view_name() -> str:
"""Make a random name to be used as a temporary view or table."""
return f"{TEMP_TABLE_PREFIX}_{uuid4().hex}"
