Revert to using pandas-based batched query, removing dask as a dependency
amrit110 committed Nov 2, 2023
1 parent 83188e5 commit 9b5cce8
Showing 11 changed files with 787 additions and 456 deletions.
45 changes: 27 additions & 18 deletions cycquery/interface.py
@@ -1,9 +1,8 @@
"""A query interface class to wrap database objects and queries."""

import logging
from typing import List, Literal, Optional, Tuple, Union
from typing import Generator, List, Literal, Optional, Tuple, Union

import dask.dataframe as dd
import pandas as pd
from sqlalchemy.sql.elements import BinaryExpression

@@ -46,7 +45,7 @@ def __init__(
self._data = None

@property
def data(self) -> Optional[Union[pd.DataFrame, dd.core.DataFrame]]:
def data(self) -> Union[pd.DataFrame, None]:
"""Get data."""
return self._data

@@ -176,10 +175,10 @@ def union_all(
def run(
self,
limit: Optional[int] = None,
backend: Literal["pandas", "dask", "datasets"] = "pandas",
index_col: Optional[str] = None,
n_partitions: Optional[int] = None,
) -> Union[pd.DataFrame, dd.core.DataFrame]:
batch_mode: bool = False,
batch_size: int = 1000000,
) -> Union[pd.DataFrame, Generator[pd.DataFrame, None, None]]:
"""Run the query, and fetch data.
Parameters
@@ -191,22 +190,29 @@ def run(
index_col
Column which becomes the index, and defines the partitioning.
Should be an indexed column in the SQL server, of any orderable type.
n_partitions
Number of partitions. Check dask documentation for additional details.
batch_mode
Whether to run the query in batch mode. A generator is returned if True.
batch_size
Batch size for the query, default 1 million rows.
Returns
-------
pandas.DataFrame or dask.DataFrame or datasets.Dataset
pandas.DataFrame or Generator[pandas.DataFrame, None, None]
Query result.
"""
self._data = self.database.run_query(
self.query,
limit=limit,
backend=backend,
index_col=index_col,
n_partitions=n_partitions,
)
if not batch_mode:
self._data = self.database.run_query(
self.query,
limit=limit,
index_col=index_col,
)
else:
self._data = self.database.run_query_batch(
self.query,
index_col=index_col,
batch_size=batch_size,
)

return self._data
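
A minimal usage sketch of the reworked `run()` (not part of this commit; `qi`, the column name, and the sizes below are illustrative):

```python
# Sketch: eager vs. batched execution, assuming `qi` is a built
# QueryInterface and "person_id" is an indexed, orderable column.
df = qi.run(limit=10_000)  # eager: a single pandas.DataFrame

# batch_mode=True returns a generator of DataFrames instead.
for batch in qi.run(index_col="person_id", batch_mode=True, batch_size=500_000):
    print(batch.shape)
```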

Expand All @@ -232,8 +238,11 @@ def save(
"""
# If the query was already run.
if self._data is not None:
return save_dataframe(self._data, path, file_format=file_format)

if isinstance(self._data, pd.DataFrame):
return save_dataframe(self._data, path, file_format=file_format)
if isinstance(self._data, Generator):
for i, df in enumerate(self._data):
save_dataframe(df, f"{path}/batch-{i:03d}", file_format=file_format)
# Save without running.
if file_format == "csv":
path = self.database.save_query_to_csv(self.query, path)
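
A hedged sketch of how `save()` behaves after a batched run (path and format are illustrative; per the diff above, each yielded batch is written to its own `batch-NNN` file under the given path):

```python
# Sketch: after a batched run, save() iterates the generator and
# writes one file per batch, e.g. output/encounters/batch-000, ...
qi.run(index_col="person_id", batch_mode=True)
qi.save("output/encounters", file_format="parquet")
```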
40 changes: 40 additions & 0 deletions cycquery/ops.py
@@ -3260,3 +3260,43 @@ def __call__(self, table: TableTypes) -> Subquery:
table = _process_checks(table, cols=cols)

return select(table).distinct(*get_columns(table, cols)).subquery()


class Count(QueryOp):
"""Count the number of rows.
Parameters
----------
col
Column to count.
Examples
--------
>>> Count("person_id")(table)
"""

def __init__(self, col: str):
super().__init__()
self.col = col

def __call__(self, table: TableTypes) -> Subquery:
"""Process the table.
Parameters
----------
table
Table on which to perform the operation.
Returns
-------
sqlalchemy.sql.selectable.Subquery
Processed table.
"""
table = _process_checks(table, cols=self.col)
count = func.count(get_column(table, self.col))

return select(count).subquery()
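
A hedged usage sketch for the new op (the import path, `visits` table, and `db` object are assumptions, not from this commit):

```python
# Sketch: Count builds a SELECT COUNT(col) subquery; running it
# returns a single-row, single-column DataFrame.
from cycquery.ops import Count  # assumed import path

count_query = Count("person_id")(visits)
n_rows = int(db.run_query(count_query).iloc[0, 0])
```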
170 changes: 140 additions & 30 deletions cycquery/orm.py
Expand Up @@ -5,24 +5,25 @@
import os
import socket
from dataclasses import dataclass
from typing import Dict, List, Literal, Optional, Union
from typing import Dict, Generator, List, Optional, Union
from urllib.parse import quote_plus

import dask.dataframe as dd
import pandas as pd
import pyarrow.csv as pv
import pyarrow.parquet as pq
from sqlalchemy import MetaData, create_engine, inspect
from sqlalchemy import MetaData, and_, create_engine, func, inspect, select
from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.selectable import Select
from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList
from sqlalchemy.sql.selectable import Select, Subquery

from cycquery.util import (
DBSchema,
DBTable,
TableTypes,
get_attr_name,
get_column,
table_params_to_type,
)
from cycquery.utils.file import exchange_extension, process_file_save_path
@@ -196,10 +197,8 @@ def run_query(
self,
query: Union[TableTypes, str],
limit: Optional[int] = None,
backend: Literal["pandas", "dask", "datasets"] = "pandas",
index_col: Optional[str] = None,
n_partitions: Optional[int] = None,
) -> Union[pd.DataFrame, dd.core.DataFrame]:
) -> pd.DataFrame:
"""Run query.
Parameters
@@ -208,48 +207,27 @@
Query to run.
limit
Limit query result to limit.
backend
Backend library to use, Pandas or Dask or HF datasets.
index_col
Column which becomes the index, and defines the partitioning.
Should be an indexed column in the SQL server, of any orderable type.
n_partitions
Number of partitions. Check dask documentation for additional details.
Returns
-------
pandas.DataFrame or dask.DataFrame
pandas.DataFrame
Extracted data from query.
"""
if isinstance(query, str) and limit is not None:
raise ValueError(
"Cannot use limit argument when running raw SQL string query!",
)
if backend in ["pandas", "datasets"] and n_partitions is not None:
raise ValueError(
"Partitions not applicable with pandas or datasets backend, use dask!",
)
# Limit the results returned.
if limit is not None:
query = query.limit(limit) # type: ignore

# Run the query and return the results.
with self.session.connection():
if backend == "pandas":
data = pd.read_sql_query(query, self.engine, index_col=index_col)
elif backend == "dask":
data = dd.read_sql_query( # type: ignore
query,
self.conn,
index_col=index_col,
npartitions=n_partitions,
)
data = data.reset_index(drop=False)
else:
raise ValueError(
"Invalid backend, can either be pandas or dask or datasets!",
)
data = pd.read_sql_query(query, self.engine, index_col=index_col)
LOGGER.info("Query returned successfully!")

return data
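
A short sketch of the simplified `run_query` (hedged; the query object and table name are illustrative):

```python
# Sketch: run_query now always goes through pandas.read_sql_query.
df = db.run_query(query, limit=10_000, index_col="person_id")

# Raw SQL strings are accepted too, but cannot be combined with limit;
# doing so raises the ValueError shown above.
df_raw = db.run_query("SELECT * FROM visit_occurrence")
```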
@@ -311,3 +289,135 @@ def save_query_to_parquet(self, query: TableTypes, path: str) -> str:
pq.write_table(table, path)

return path

def _query_batch_conditions(
self,
query: TableTypes,
index_col: str,
batch_size: int,
) -> List[Union[BinaryExpression, BooleanClauseList]]:
"""Return a list of WHERE conditions to segment a query into batches.
Batches are created via SQL windowing, based on segmenting the values in a
given column, such as an ID column, into intervals.
Requires a database that supports window functions.
Parameters
----------
query
Query to run.
index_col
Name of the sample ID column by which to batch.
batch_size
Batch size for the query.
Returns
-------
list of sqlalchemy.sql.elements.BinaryExpression or
sqlalchemy.sql.elements.BooleanClauseList
The window conditions on which to filter.
"""

def _compute_query_dividers(
query: Subquery,
index_col: str,
maximum: int,
) -> List[int]:
# Compute the row count for each unique value
col = get_column(query, index_col)
table = select(col, func.count(col).label("count")).group_by(col)
count_data = self.run_query(table)

# Check that all values can actually fit into the maximum batch size
max_count = count_data["count"].max()
if maximum < max_count:
raise ValueError(f"Maximum must be at least {max_count}.")

# Sort and create a cumulative sum of row counts
count_data = count_data.sort_values(index_col)
count_data["cumsum"] = count_data["count"].cumsum()

# Create query dividers
last_sum = 0

if len(count_data) == 0:
raise ValueError("Query is empty. Cannot return batched results.")

dividers = [int(count_data[index_col].iloc[0])]
for i, cumsum in enumerate(count_data["cumsum"].values[1:]):
# If adding the next value will put the sum over the max,
# then add another divider on the previous value
if cumsum - last_sum > maximum:
dividers.append(int(count_data[index_col].iloc[i]))
last_sum = count_data["cumsum"].iloc[i]

return dividers

def _range_condition(
start_id: int,
end_id: Optional[int] = None,
) -> Union[BinaryExpression, BooleanClauseList]:
if end_id is not None:
return and_(column >= start_id, column < end_id)

return column >= start_id

# Create interval dividers
dividers = _compute_query_dividers(query, index_col, batch_size)

# Create filtering conditions
column = get_column(query, index_col)
conditions = []
while dividers:
# Create interval ranges
start = dividers.pop(0)
end = dividers[0] if dividers else None

# Create condition
conditions.append(_range_condition(start, end))

return conditions
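
To make the windowing concrete, a self-contained toy illustration of how dividers become half-open intervals (not code from this commit; the numbers are made up, whereas the real dividers come from cumulative row counts):

```python
from typing import List, Optional, Tuple

def toy_intervals(dividers: List[int]) -> List[Tuple[int, Optional[int]]]:
    """Pair each batch-start divider with the next; the last is open-ended."""
    return [
        (start, dividers[i + 1] if i + 1 < len(dividers) else None)
        for i, start in enumerate(dividers)
    ]

# Dividers [1, 500, 900] yield conditions equivalent to:
#   id >= 1 AND id < 500, id >= 500 AND id < 900, id >= 900
print(toy_intervals([1, 500, 900]))  # [(1, 500), (500, 900), (900, None)]
```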

@table_params_to_type(Subquery)
def run_query_batch(
self,
query: TableTypes,
index_col: str,
batch_size: int,
) -> Generator[pd.DataFrame, None, None]:
"""Generate query batches with complete sets of IDs in a batch.
Queries are sorted and grouped such that the rows for a given sample ID are kept
together in a single batch.
Parameters
----------
query
Query to run.
index_col
Name of the sample ID column by which to batch.
batch_size
Batch size for the query. Since the partitioning happens on the index
column, the batch size is the approximate number of rows that will
be returned in a batch.
Yields
------
pandas.DataFrame
A query batch with complete sets of sample IDs.
"""
if "limit" in str(query).lower():
raise NotImplementedError(
"Currently not supporting batching for queries with a LIMIT.",
)

conditions = self._query_batch_conditions(query, index_col, batch_size)
sess_query = self.session.query(query)

# Opportunity for easy multi-processing/parallelization here!
for condition in conditions:
run = (sess_query.filter(condition)).subquery()
yield pd.read_sql_query(run, self.engine)
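
A hedged end-to-end sketch of consuming `run_query_batch` (paths and names illustrative):

```python
# Sketch: stream batches and persist each one; rows for a given
# person_id never straddle two batches.
for i, batch in enumerate(
    db.run_query_batch(query, index_col="person_id", batch_size=1_000_000),
):
    batch.to_parquet(f"out/batch-{i:03d}.parquet")
```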
