diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b655081e0..4aac6b32f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -84,7 +84,7 @@ jobs: run: | # explicitly install docker, fugue and sqlalchemy package conda install sqlalchemy psycopg2 -c conda-forge - pip install docker fugue + pip install docker "fugue<=0.5.0" if: matrix.os == 'ubuntu-latest' - name: Install Java (again) and test with pytest shell: bash -l {0} diff --git a/dask_sql/context.py b/dask_sql/context.py index 77ae10b6f..cced80548 100644 --- a/dask_sql/context.py +++ b/dask_sql/context.py @@ -68,6 +68,8 @@ def __init__(self): """ # Storage for the registered tables self.tables = {} + # Storage for the registered views + self.views = {} # Storage for the registered functions self.functions: Dict[str, Callable] = {} self.function_list: List[FunctionDescription] = [] @@ -86,6 +88,8 @@ def __init__(self): RelConverter.add_plugin_class(logical.LogicalSortPlugin, replace=False) RelConverter.add_plugin_class(logical.LogicalTableScanPlugin, replace=False) RelConverter.add_plugin_class(logical.LogicalUnionPlugin, replace=False) + RelConverter.add_plugin_class(logical.LogicalIntersectPlugin, replace=False) + RelConverter.add_plugin_class(logical.LogicalMinusPlugin, replace=False) RelConverter.add_plugin_class(logical.LogicalValuesPlugin, replace=False) RelConverter.add_plugin_class(logical.SamplePlugin, replace=False) RelConverter.add_plugin_class(custom.AnalyzeTablePlugin, replace=False) @@ -116,6 +120,7 @@ def create_table( input_table: InputType, format: str = None, persist: bool = True, + sql: str = None, **kwargs, ): """ @@ -191,6 +196,8 @@ def create_table( **kwargs, ) self.tables[table_name.lower()] = dc + if sql is not None: + self.views[table_name.lower()] = sql def register_dask_table(self, df: dd.DataFrame, name: str): """ @@ -448,11 +455,17 @@ def _prepare_schema(self): logger.warning("No tables are registered.") for name, dc in self.tables.items(): - table = DaskTable(name) df = dc.df - logger.debug( - f"Adding table '{name}' to schema with columns: {list(df.columns)}" - ) + if name in self.views: + table = DaskTable(name, self.views[name]) + logger.debug( + f"Adding materialied table '{name}' to schema with columns: {list(df.columns)}" + ) + else: + table = DaskTable(name) + logger.debug( + f"Adding table '{name}' to schema with columns: {list(df.columns)}" + ) for column in df.columns: data_type = df[column].dtype sql_data_type = python_to_sql_type(data_type) @@ -513,8 +526,14 @@ def _get_ral(self, sql): else: validatedSqlNode = generator.getValidatedNode(sqlNode) nonOptimizedRelNode = generator.getRelationalAlgebra(validatedSqlNode) + rel_string_non_op = str(generator.getRelationalAlgebraString(nonOptimizedRelNode)) + rel_non_op_count = rel_string_non_op.count('\n') rel = generator.getOptimizedRelationalAlgebra(nonOptimizedRelNode) rel_string = str(generator.getRelationalAlgebraString(rel)) + logger.debug( + f"Non optimised query plan: {rel_non_op_count} ops\n " + f"{rel_string_non_op}" + ) except (ValidationException, SqlParseException) as e: logger.debug(f"Original exception raised by Java:\n {e}") # We do not want to re-raise an exception here @@ -544,8 +563,8 @@ def _get_ral(self, sql): "Not extracting output column names as the SQL is not a SELECT call" ) select_names = None - - logger.debug(f"Extracted relational algebra:\n {rel_string}") + br = '\n' + logger.debug(f"Extracted relational algebra {rel_string.count(br)} ops:\n {rel_string}") return rel, 
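The ``views`` dictionary and the new ``sql`` argument of ``create_table`` added above can be exercised from Python roughly as in the following sketch (the table names and the query are invented for illustration): the pre-computed result lands in ``Context.tables`` as usual, while the defining query is remembered in ``Context.views`` and later handed to ``DaskTable`` in ``_prepare_schema``, so Calcite can treat the table as a materialized view.

.. code-block:: python

    import pandas as pd
    import dask.dataframe as dd
    from dask_sql import Context

    c = Context()
    df = dd.from_pandas(
        pd.DataFrame({"a": [1, 1, 2], "b": [10, 20, 30]}), npartitions=1
    )
    c.create_table("data", df)

    # Hypothetical view registration: materialize the query result once and
    # keep its SQL text around, so the planner knows how "data_agg" was built.
    view_sql = "SELECT a, SUM(b) AS b_sum FROM data GROUP BY a"
    c.create_table("data_agg", c.sql(view_sql), sql=view_sql)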
select_names, rel_string def _to_sql_string(self, s: "org.apache.calcite.sql.SqlNode", default_dialect=None): diff --git a/dask_sql/datacontainer.py b/dask_sql/datacontainer.py index bddfbcad4..c471f9173 100644 --- a/dask_sql/datacontainer.py +++ b/dask_sql/datacontainer.py @@ -42,7 +42,8 @@ def _copy(self) -> ColumnContainer: Internal function to copy this container """ return ColumnContainer( - self._frontend_columns.copy(), self._frontend_backend_mapping.copy() + self._frontend_columns.copy(), + self._frontend_backend_mapping.copy(), ) def limit_to(self, fields: List[str]) -> ColumnContainer: @@ -137,7 +138,9 @@ def make_unique(self, prefix="col"): where is the column index. """ return self.rename( - columns={str(col): f"{prefix}_{i}" for i, col in enumerate(self.columns)} + columns={ + str(col): f"{prefix}_{i}" for i, col in enumerate(self.columns) + } ) @@ -166,11 +169,22 @@ def assign(self) -> dd.DataFrame: a dataframe which has the the columns specified in the stored ColumnContainer. """ - df = self.df.assign( - **{ - col_from: self.df[col_to] - for col_from, col_to in self.column_container.mapping() - if col_from in self.column_container.columns - } - ) + # We rename as many cols as possible because renaming is much more + # efficient than assigning. + + renames = {} + assigns = {} + for col_from, col_to in self.column_container.mapping(): + if col_from in self.column_container.columns: + if ( + len(renames) < len(self.df.columns) + and col_to not in renames + and (col_from not in self.df.columns or col_from == col_to) + ): + renames[col_to] = col_from + else: + assigns[col_from] = self.df[col_to] + df = self.df.rename(columns=renames) + if len(assigns) > 0: + df = df.assign(**assigns) return df[self.column_container.columns] diff --git a/dask_sql/physical/rel/convert.py b/dask_sql/physical/rel/convert.py index 77f8a3ef7..d7ef0a263 100644 --- a/dask_sql/physical/rel/convert.py +++ b/dask_sql/physical/rel/convert.py @@ -1,4 +1,5 @@ import logging +import time import dask.dataframe as dd @@ -53,6 +54,8 @@ def convert( logger.debug( f"Processing REL {rel} using {plugin_instance.__class__.__name__}..." 
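The reworked ``assign`` above prefers ``rename`` wherever possible because, as its comment notes, renaming is much more efficient than assigning in Dask. A minimal standalone sketch of the same split, with made-up column names and without the extra guards of the real implementation:

.. code-block:: python

    import pandas as pd
    import dask.dataframe as dd

    df = dd.from_pandas(pd.DataFrame({"col_0": [1, 2], "col_1": [3, 4]}), npartitions=1)

    # frontend (output) column -> backend column it should expose
    mapping = {"a": "col_0", "b": "col_0", "c": "col_1"}

    renames, assigns = {}, {}
    for frontend, backend in mapping.items():
        # each backend column can be renamed at most once, and only to a name
        # that does not clash with an existing column; everything else is assigned
        if backend not in renames and frontend not in df.columns:
            renames[backend] = frontend
        else:
            assigns[frontend] = df[backend]

    out = df.rename(columns=renames)
    if assigns:
        out = out.assign(**assigns)
    print(out[list(mapping)].compute())  # columns a, b, c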
) + start_time = time.perf_counter() df = plugin_instance.convert(rel, context=context) - logger.debug(f"Processed REL {rel} into {LoggableDataFrame(df)}") + elapsed_time = time.perf_counter() - start_time + logger.debug(f"Processed REL {rel} into {LoggableDataFrame(df)} ({elapsed_time}s)") return df diff --git a/dask_sql/physical/rel/logical/__init__.py b/dask_sql/physical/rel/logical/__init__.py index 99698157c..eac53d05a 100644 --- a/dask_sql/physical/rel/logical/__init__.py +++ b/dask_sql/physical/rel/logical/__init__.py @@ -6,6 +6,8 @@ from .sort import LogicalSortPlugin from .table_scan import LogicalTableScanPlugin from .union import LogicalUnionPlugin +from .intersect import LogicalIntersectPlugin +from .minus import LogicalMinusPlugin from .values import LogicalValuesPlugin __all__ = [ @@ -16,6 +18,8 @@ LogicalSortPlugin, LogicalTableScanPlugin, LogicalUnionPlugin, + LogicalIntersectPlugin, + LogicalMinusPlugin, LogicalValuesPlugin, SamplePlugin, ] diff --git a/dask_sql/physical/rel/logical/aggregate.py b/dask_sql/physical/rel/logical/aggregate.py index 3d1394876..55256d5dd 100644 --- a/dask_sql/physical/rel/logical/aggregate.py +++ b/dask_sql/physical/rel/logical/aggregate.py @@ -1,9 +1,10 @@ import operator from collections import defaultdict from functools import reduce -from typing import Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Dict, List, Tuple, Type import logging +import pandas as pd import dask.dataframe as dd from dask_sql.utils import new_temporary_column @@ -13,45 +14,52 @@ logger = logging.getLogger(__name__) -class GroupDatasetDescription: +class ReduceAggregation(dd.Aggregation): """ - Helper class to put dataframes which are filtered according to a specific column - into a dictionary. - Applying the same filter twice on the same dataframe does not give different - dataframes. Therefore we only hash these dataframes according to the column - they are filtered by. + A special form of an aggregation, that applies a given operation + on all elements in a group with "reduce". """ - def __init__(self, df: dd.DataFrame, filtered_column: str = ""): - self.df = df - self.filtered_column = filtered_column + def __init__(self, name: str, operation: Callable): + series_aggregate = lambda s: s.aggregate( + lambda x: reduce(operation, x) + ) - def __eq__(self, rhs: "GroupDatasetDescription") -> bool: - """They are equal of they are filtered by the same column""" - return self.filtered_column == rhs.filtered_column + super().__init__(name, series_aggregate, series_aggregate) - def __hash__(self) -> str: - return hash(self.filtered_column) - def __repr__(self) -> str: - return f"GroupDatasetDescription({self.filtered_column})" +class AggregationOnPandas(dd.Aggregation): + """ + A special form of an aggregation, which does not apply the given function + (given as attribute name) directly to the dask groupby, but + via the groupby().apply() method. This is needed to call + functions directly on the pandas dataframes, but should be done + very carefully (as it is a performance bottleneck). 
+ """ + def __init__(self, function_name: str): + def _f(s): + return s.apply(lambda s0: getattr(s0.dropna(), function_name)()) -# Description of an aggregation in the form of a mapping -# input column -> output column -> aggregation -AggregationDescription = Dict[str, Dict[str, Union[str, dd.Aggregation]]] + super().__init__(function_name, _f, _f) -class ReduceAggregation(dd.Aggregation): +class AggregationSpecification: """ - A special form of an aggregation, that applies a given operation - on all elements in a group with "reduce". + Most of the aggregations in SQL are already + implemented 1:1 in dask and can just be called via their name + (e.g. AVG is the mean). However sometimes those already + implemented functions only work well for numerical + functions. This small container class therefore + can have an additional aggregation function, which is + valid for non-numerical types. """ - def __init__(self, name: str, operation: Callable): - series_aggregate = lambda s: s.aggregate(lambda x: reduce(operation, x)) - - super().__init__(name, series_aggregate, series_aggregate) + def __init__(self, numerical_aggregation, non_numerical_aggregation=None): + self.numerical_aggregation = numerical_aggregation + self.non_numerical_aggregation = ( + non_numerical_aggregation or numerical_aggregation + ) class LogicalAggregatePlugin(BaseRelPlugin): @@ -63,35 +71,56 @@ class LogicalAggregatePlugin(BaseRelPlugin): group over, in the second case we "cheat" and add a 1-column to the dataframe, which allows us to reuse every aggregation function we already know of. + As NULLs are not groupable in dask, we handle them special + by adding a temporary column which is True for all NULL values + and False otherwise (and also group by it). The rest is just a lot of column-name-bookkeeping. Fortunately calcite will already make sure, that each aggregation function will only every be called with a single input column (by splitting the inner calculation to a step before). + + Open TODO: So far we are following the dask default + to only have a single partition after the group by (which is usual + a reasonable assumption). It would be nice to control + these things via HINTs. 
""" class_name = "org.apache.calcite.rel.logical.LogicalAggregate" AGGREGATION_MAPPING = { - "$sum0": "sum", - "any_value": dd.Aggregation( - "any_value", - lambda s: s.sample(n=1).values, - lambda s0: s0.sample(n=1).values, + "$sum0": AggregationSpecification("sum", AggregationOnPandas("sum")), + "sum": AggregationSpecification("sum", AggregationOnPandas("sum")), + "any_value": AggregationSpecification( + dd.Aggregation( + "any_value", + lambda s: s.sample(n=1).values, + lambda s0: s0.sample(n=1).values, + ) + ), + "avg": AggregationSpecification("mean", AggregationOnPandas("mean")), + "bit_and": AggregationSpecification( + ReduceAggregation("bit_and", operator.and_) + ), + "bit_or": AggregationSpecification( + ReduceAggregation("bit_or", operator.or_) ), - "avg": "mean", - "bit_and": ReduceAggregation("bit_and", operator.and_), - "bit_or": ReduceAggregation("bit_or", operator.or_), - "bit_xor": ReduceAggregation("bit_xor", operator.xor), - "count": "count", - "every": dd.Aggregation("every", lambda s: s.all(), lambda s0: s0.all()), - "max": "max", - "min": "min", - "single_value": "first", + "bit_xor": AggregationSpecification( + ReduceAggregation("bit_xor", operator.xor) + ), + "count": AggregationSpecification("count"), + "every": AggregationSpecification( + dd.Aggregation("every", lambda s: s.all(), lambda s0: s0.all()) + ), + "max": AggregationSpecification("max", AggregationOnPandas("max")), + "min": AggregationSpecification("min", AggregationOnPandas("min")), + "single_value": AggregationSpecification("first"), } def convert( - self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context" + self, + rel: "org.apache.calcite.rel.RelNode", + context: "dask_sql.Context", ) -> DataContainer: (dc,) = self.assert_inputs(rel, 1, context) @@ -102,7 +131,9 @@ def convert( cc = cc.make_unique() # I have no idea what that is, but so far it was always of length 1 - assert len(rel.getGroupSets()) == 1, "Do not know how to handle this case!" + assert ( + len(rel.getGroupSets()) == 1 + ), "Do not know how to handle this case!" # Extract the information, which columns we need to group for group_column_indices = [int(i) for i in rel.getGroupSet()] @@ -110,121 +141,196 @@ def convert( cc.get_backend_by_frontend_index(i) for i in group_column_indices ] - # Always keep an additional column around for empty groups and aggregates - additional_column_name = new_temporary_column(df) - - # NOTE: it might be the case that - # we do not need this additional - # column, but hopefully adding a single - # column of 1 is not so problematic... - df = df.assign(**{additional_column_name: 1}) - cc = cc.add(additional_column_name) - dc = DataContainer(df, cc) - - # Collect all aggregates - filtered_aggregations, output_column_order = self._collect_aggregations( - rel, dc, group_columns, additional_column_name, context - ) - if not group_columns: # There was actually no GROUP BY specified in the SQL # Still, this plan can also be used if we need to aggregate something over the full # data sample # To reuse the code, we just create a new column at the end with a single value - # It is important to do this after creating the aggregations, - # as we do not want this additional column to be used anywhere - group_columns = [additional_column_name] - logger.debug("Performing full-table aggregation") - # Now we can perform the aggregates - # We iterate through all pairs of (possible pre-filtered) - # dataframes and the aggregations to perform in this data... 
- df_agg = None - for filtered_df_desc, aggregation in filtered_aggregations.items(): - filtered_column = filtered_df_desc.filtered_column - if filtered_column: - logger.debug( - f"Aggregating {dict(aggregation)} on the data filtered by {filtered_column}" - ) - else: - logger.debug(f"Aggregating {dict(aggregation)} on the data") - - # ... we perform the aggregations ... - filtered_df = filtered_df_desc.df - # TODO: we could use the type information for - # pre-calculating the meta information - filtered_df_agg = filtered_df.groupby(by=group_columns).agg(aggregation) - - # ... fix the column names to a single level ... - filtered_df_agg.columns = filtered_df_agg.columns.get_level_values(-1) - - # ... and finally concat the new data with the already present columns - if df_agg is None: - df_agg = filtered_df_agg - else: - df_agg = df_agg.assign( - **{col: filtered_df_agg[col] for col in filtered_df_agg.columns} - ) + # Add an entry for every grouped column, as SQL wants them first + output_column_order = group_columns.copy() + additional_column_name = new_temporary_column(df) - # SQL does not care about the index, but we do not want to have any multiindices - df_agg = df_agg.reset_index(drop=True) + # Collect all aggregations we need to do + ( + collected_aggregations, + output_column_order, + ) = self._collect_aggregations( + rel, df, cc, context, additional_column_name, output_column_order + ) - # Fix the column names and the order of them, as this was messed with during the aggregations - df_agg.columns = df_agg.columns.get_level_values(-1) - cc = ColumnContainer(df_agg.columns).limit_to(output_column_order) + # Check if we're doing a real aggregation or just droping duplicates + if ( + len(group_columns) == len(cc.columns) + and len(collected_aggregations) == 0 + ): + # Just drop duplicates + df_agg = df.drop_duplicates() + else: + # Do the aggregations + df_result = self._do_aggregations( + df, + group_columns, + collected_aggregations, + additional_column_name, + ) + + # SQL does not care about the index, but we do not want to have any multiindices + df_agg = df_result.reset_index(drop=True) + + # Fix the column names and the order of them, as this was messed with during the aggregations + df_agg.columns = df_agg.columns.get_level_values(-1) + cc = ColumnContainer(df_agg.columns).limit_to(output_column_order) cc = self.fix_column_to_row_type(cc, rel.getRowType()) dc = DataContainer(df_agg, cc) dc = self.fix_dtype_to_row_type(dc, rel.getRowType()) return dc - def _collect_aggregations( + def _do_aggregations( self, - rel: "org.apache.calcite.rel.RelNode", - dc: DataContainer, + df: dd.DataFrame, group_columns: List[str], + collected_aggregations: Dict[ + Tuple[str, str], List[Tuple[str, str, Any, Type]] + ], additional_column_name: str, + ) -> Tuple[dd.DataFrame, List[str]]: + """ + Main functionality: return the result dataframe + and the output column order + """ + # We might need it later. + # If not, lets hope that adding a single column should not + # be a huge problem... + df = df.assign(**{additional_column_name: 1}) + + # SQL needs to have a column with the grouped values as the first + # output column. As the values of the group columns + # are the same for a single group anyways, we just use the first row. 
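The new ``drop_duplicates`` shortcut above applies when every output column is a grouping column and there are no aggregate calls left, i.e. a plain ``SELECT DISTINCT``. A rough illustration of that equivalence (table and column names invented):

.. code-block:: python

    import pandas as pd
    import dask.dataframe as dd

    df = dd.from_pandas(
        pd.DataFrame({"a": [1, 1, 2], "b": ["x", "x", "y"]}), npartitions=2
    )

    # SELECT DISTINCT a, b FROM data
    # == GROUP BY over all selected columns without any aggregates,
    # which the plugin now answers with a simple drop_duplicates()
    print(df.drop_duplicates().compute())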
+ df_result = None + if group_columns: + default_column_aggregations = [] + for col in group_columns: + default_column_aggregations.append((col, col, "first", None)) + df_result = self._apply_and_assign_aggregations( + df, + df_result, + None, + default_column_aggregations, + additional_column_name, + group_columns, + ) + + # Now we can go ahead and use these grouped aggregations + # to perform the actual aggregation + # It is very important to start with the non-filtered entry. + # Otherwise we might loose some entries in the grouped columns + if None in collected_aggregations: + key = None + aggregations = collected_aggregations.pop(key) + df_result = self._apply_and_assign_aggregations( + df, + df_result, + None, + aggregations, + additional_column_name, + group_columns, + ) + + # Now we can also add the rest + for filter_column, aggregations in collected_aggregations.items(): + df_result = self._apply_and_assign_aggregations( + df, + df_result, + filter_column, + aggregations, + additional_column_name, + group_columns, + ) + + return df_result + + def _apply_and_assign_aggregations( + self, + df: dd.DataFrame, + df_result: dd.DataFrame, + filter_column: str, + aggregations: List[Tuple[str, str, Any, Type]], + additional_column_name: str, + group_columns: List[str], + ): + agg_result = self._perform_aggregation( + df, + filter_column, + aggregations, + additional_column_name, + group_columns, + ) + if df_result is None: + df_result = agg_result + else: + df_result = df_result.assign( + **{col: agg_result[col] for col in agg_result.columns} + ) + return df_result + + def _collect_aggregations( + self, + rel: "org.apache.calcite.rel.RelNode", + df: dd.DataFrame, + cc: ColumnContainer, context: "dask_sql.Context", + additional_column_name: str, + output_column_order: List[str], ) -> Tuple[ - Dict[GroupDatasetDescription, AggregationDescription], List[int], + Dict[Tuple[str, str], List[Tuple[str, str, Any, Type]]], List[str] ]: """ - Create a mapping of dataframe -> aggregations (in the form input colum, output column, aggregation) - and the expected order of output columns. - """ - aggregations = defaultdict(lambda: defaultdict(dict)) - output_column_order = [] - df = dc.df - cc = dc.column_container + Collect all aggregations together, which have the same filter column + so that the aggregations only need to be done once. - # SQL needs to copy the old content also. 
As the values of the group columns - # are the same for a single group anyways, we just use the first row - for col in group_columns: - aggregations[GroupDatasetDescription(df)][col][col] = "first" - output_column_order.append(col) + Returns the aggregations as mapping filter_column -> List of Aggregations + where the aggregations are in the form (input_col, output_col, aggregation function (or string), return_type) + """ + collected_aggregations = defaultdict(list) - # Now collect all aggregations for agg_call in rel.getNamedAggCalls(): - output_col = str(agg_call.getValue()) expr = agg_call.getKey() - if expr.hasFilter(): - filter_column = cc.get_backend_by_frontend_index(expr.filterArg) - filter_expression = df[filter_column] - filtered_df = df[filter_expression] - - grouped_df = GroupDatasetDescription(filtered_df, filter_column) + # Find out about the input column + inputs = expr.getArgList() + if len(inputs) == 1: + input_col = cc.get_backend_by_frontend_index(inputs[0]) + elif len(inputs) == 0: + input_col = additional_column_name else: - grouped_df = GroupDatasetDescription(df) + input_col = tuple( + cc.get_backend_by_frontend_index(i) for i in inputs + ) - if expr.isDistinct(): - raise NotImplementedError("DISTINCT is not implemented (yet)") + # Extract flags (filtering/distinct) + if expr.isDistinct(): # pragma: no cover + raise ValueError("Apache Calcite should optimize them away!") + filter_column = None + if expr.hasFilter(): + filter_column = cc.get_backend_by_frontend_index( + expr.filterArg + ) + + # Find out which aggregation function to use aggregation_name = str(expr.getAggregation().getName()) aggregation_name = aggregation_name.lower() + return_type = None + for function_description in context.function_list: + if function_description.name == aggregation_name: + return_type = function_description.return_type try: - aggregation_function = self.AGGREGATION_MAPPING[aggregation_name] + aggregation_function = self.AGGREGATION_MAPPING[ + aggregation_name + ] except KeyError: try: aggregation_function = context.functions[aggregation_name] @@ -232,16 +338,123 @@ def _collect_aggregations( raise NotImplementedError( f"Aggregation function {aggregation_name} not implemented (yet)." 
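``_collect_aggregations`` now groups the aggregate calls by their ``FILTER`` column, so each distinct filter only triggers a single grouped aggregation. The effect of a filtered aggregate can be sketched like this (hypothetical column names, not the plugin code itself); computing the unfiltered aggregation first mirrors the ordering in ``_do_aggregations``, since otherwise groups whose rows are all filtered away would disappear from the result:

.. code-block:: python

    import pandas as pd
    import dask.dataframe as dd

    df = dd.from_pandas(
        pd.DataFrame(
            {"g": [1, 1, 2, 2], "x": [1.0, 2.0, 3.0, 4.0], "is_big": [False, True, True, True]}
        ),
        npartitions=2,
    )

    # SELECT g, SUM(x) AS full_sum, SUM(x) FILTER (WHERE is_big) AS big_sum
    # FROM data GROUP BY g
    unfiltered = df.groupby("g").agg({"x": "sum"}).rename(columns={"x": "full_sum"})
    filtered = df[df["is_big"]].groupby("g").agg({"x": "sum"}).rename(columns={"x": "big_sum"})
    result = unfiltered.assign(big_sum=filtered["big_sum"])
    print(result.compute())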
) + if isinstance(aggregation_function, AggregationSpecification): + if isinstance(input_col, tuple): + dtype = df[input_col[0]].dtype + else: + dtype = df[input_col].dtype + if pd.api.types.is_numeric_dtype(dtype): + aggregation_function = ( + aggregation_function.numerical_aggregation + ) + else: + aggregation_function = ( + aggregation_function.non_numerical_aggregation + ) - inputs = expr.getArgList() - if len(inputs) == 1: - input_col = cc.get_backend_by_frontend_index(inputs[0]) - elif len(inputs) == 0: - input_col = additional_column_name - else: - raise NotImplementedError("Can not cope with more than one input") + # Finally, extract the output column name + output_col = str(agg_call.getValue()) - aggregations[grouped_df][input_col][output_col] = aggregation_function + # Store the aggregation + key = filter_column + value = (input_col, output_col, aggregation_function, return_type) + collected_aggregations[key].append(value) output_column_order.append(output_col) - return aggregations, output_column_order + return collected_aggregations, output_column_order + + def _perform_aggregation( + self, + df: dd.DataFrame, + filter_column: str, + aggregations: List[Tuple[str, str, Any, Type]], + additional_column_name: str, + group_columns: List[str], + ): + tmp_df = df + + if filter_column: + filter_expression = tmp_df[filter_column] + tmp_df = tmp_df[filter_expression] + + logger.debug(f"Filtered by {filter_column} before aggregation.") + + # Jonas : we don't really care to have the exact same behaviour as SQL + # and grouping by series instead of column names is messing up the + # multi column aggregations so i'm just assuming this will work + # instead of the commented part below. + group_columns_and_nulls = group_columns + + # # SQL and dask are treating null columns a bit different: + # # SQL will put them to the front, dask will just ignore them + # # Therefore we use the same trick as fugue does: + # # we will group by both the NaN and the real column value + # group_columns_and_nulls = [] + # for group_column in group_columns: + # # the ~ makes NaN come first + # is_null_column = ~(tmp_df[group_column].isnull()) + # non_nan_group_column = tmp_df[group_column].fillna(0) + + # group_columns_and_nulls += [is_null_column, non_nan_group_column] + + if not group_columns_and_nulls: + # This can happen in statements like + # SELECT SUM(x) FROM data + # without any groupby statement + group_columns_and_nulls = [additional_column_name] + + grouped_df = tmp_df.groupby(by=group_columns_and_nulls, sort=False) + + # Dask supports two types of group-by aggregations: by calling .agg + # or .apply on a GroupBy dataframe. We want to call .agg if possible, + # as it's supposed to be faster and cleaner. But it doesn't work for + # all cases, in which case we use .apply instead. We start by + # preparing the aggregate calls in a format dask understands. + aggregate_aggregations = defaultdict(dict) + apply_aggregations = dict() + for aggregation in aggregations: + input_col, output_col, aggregation_f, return_type = aggregation + if isinstance( + aggregation_f, (AggregationSpecification, dd.Aggregation, str) + ): + aggregate_aggregations[input_col][output_col] = aggregation_f + else: + apply_aggregations[output_col] = ( + input_col, + aggregation_f, + return_type, + ) + + # Now we apply the aggregations + agg_result = None + if len(aggregate_aggregations) > 0: + logger.debug( + f"Performing aggregation {dict(aggregate_aggregations)}" + ) + agg_result = grouped_df.agg(aggregate_aggregations) + + # ... 
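Custom aggregation functions that are neither a string nor a ``dd.Aggregation`` go through the ``groupby().apply()`` branch of ``_perform_aggregation``. Something along these lines, with an invented two-column aggregate, is what that path boils down to; the ``meta`` argument corresponds to the ``return_type`` collected per aggregation.

.. code-block:: python

    import pandas as pd
    import dask.dataframe as dd

    df = dd.from_pandas(
        pd.DataFrame({"g": [1, 1, 2], "x": [1.0, 2.0, 3.0], "w": [1.0, 3.0, 2.0]}),
        npartitions=2,
    )

    # A hypothetical two-column aggregate (weighted mean) that no string or
    # dd.Aggregation covers -- it has to go through groupby().apply() with an
    # explicit meta, just like the apply_aggregations branch does.
    def weighted_mean(x, w):
        return (x * w).sum() / w.sum()

    grouped = df.groupby("g", sort=False)
    result = grouped.apply(
        lambda part: weighted_mean(part["x"], part["w"]),
        meta=("wavg", "f8"),
    )
    print(result.compute())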
fix the column names to a single level ... + agg_result.columns = agg_result.columns.get_level_values(-1) + + # apply aggregations with .apply. The .persist() calls on agg_result are + # important otherwise when the dataframe gets computed, the lambda function + # to apply will be the last one of the list. + for output_col, ( + input_col, + aggregation_f, + return_type, + ) in apply_aggregations.items(): + if not isinstance(input_col, tuple): + input_col = (input_col,) + new_col = grouped_df.apply( + lambda x: aggregation_f( + *[getattr(x, col) for col in input_col] + ), + meta=(output_col, return_type), + ) + if agg_result is None: + agg_result = new_col.rename(output_col).to_frame().persist() + else: + agg_result = agg_result.assign(**{output_col: new_col}).persist() + + return agg_result diff --git a/dask_sql/physical/rel/logical/intersect.py b/dask_sql/physical/rel/logical/intersect.py new file mode 100644 index 000000000..8c5b6210b --- /dev/null +++ b/dask_sql/physical/rel/logical/intersect.py @@ -0,0 +1,73 @@ +import dask.dataframe as dd + +from dask_sql.physical.rex import RexConverter +from dask_sql.physical.rel.base import BaseRelPlugin +from dask_sql.datacontainer import DataContainer, ColumnContainer + + +class LogicalIntersectPlugin(BaseRelPlugin): + """ + LogicalIntersect is used on INTERSECT clauses. + It just concatonates the two data frames. + """ + + class_name = "org.apache.calcite.rel.logical.LogicalIntersect" + + def convert( + self, + rel: "org.apache.calcite.rel.RelNode", + context: "dask_sql.Context", + ) -> DataContainer: + first_dc, second_dc = self.assert_inputs(rel, 2, context) + + first_df = first_dc.df + first_cc = first_dc.column_container + + second_df = second_dc.df + second_cc = second_dc.column_container + + # For concatenating, they should have exactly the same fields + output_field_names = [str(x) for x in rel.getRowType().getFieldNames()] + assert len(first_cc.columns) == len(output_field_names) + first_cc = first_cc.rename( + columns={ + col: output_col + for col, output_col in zip( + first_cc.columns, output_field_names + ) + } + ) + first_dc = DataContainer(first_df, first_cc) + + assert len(second_cc.columns) == len(output_field_names) + second_cc = second_cc.rename( + columns={ + col: output_col + for col, output_col in zip( + second_cc.columns, output_field_names + ) + } + ) + second_dc = DataContainer(second_df, second_cc) + + # To concat the to dataframes, we need to make sure the + # columns actually have the specified names in the + # column containers + # Otherwise the concat won't work + first_df = first_dc.assign() + second_df = second_dc.assign() + + self.check_columns_from_row_type( + first_df, rel.getExpectedInputRowType(0) + ) + self.check_columns_from_row_type( + second_df, rel.getExpectedInputRowType(1) + ) + + df = first_df.merge(second_df, how="inner") + + cc = ColumnContainer(df.columns) + cc = self.fix_column_to_row_type(cc, rel.getRowType()) + dc = DataContainer(df, cc) + dc = self.fix_dtype_to_row_type(dc, rel.getRowType()) + return dc diff --git a/dask_sql/physical/rel/logical/join.py b/dask_sql/physical/rel/logical/join.py index 8c2dd8e0f..e85d07349 100644 --- a/dask_sql/physical/rel/logical/join.py +++ b/dask_sql/physical/rel/logical/join.py @@ -42,7 +42,9 @@ class LogicalJoinPlugin(BaseRelPlugin): } def convert( - self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context" + self, + rel: "org.apache.calcite.rel.RelNode", + context: "dask_sql.Context", ) -> DataContainer: # Joining is a bit more complicated, so 
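``LogicalIntersectPlugin`` above implements ``INTERSECT`` as an inner merge over all (identically named) columns. Stripped of the column-container bookkeeping, the core idea is just the following (illustrative frames only):

.. code-block:: python

    import pandas as pd
    import dask.dataframe as dd

    first = dd.from_pandas(pd.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]}), npartitions=1)
    second = dd.from_pandas(pd.DataFrame({"a": [2, 3, 4], "b": ["y", "z", "w"]}), npartitions=1)

    # SELECT a, b FROM first INTERSECT SELECT a, b FROM second:
    # with no "on" argument, merge joins on all common columns, so an inner
    # merge keeps exactly the rows that appear in both inputs
    print(first.merge(second, how="inner").compute())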
lets do it in steps: @@ -78,9 +80,13 @@ def convert( # As this is probably non-sense for large tables, but there is no other # known solution so far. join_condition = rel.getCondition() - lhs_on, rhs_on, filter_condition = self._split_join_condition(join_condition) + lhs_on, rhs_on, filter_condition = self._split_join_condition( + join_condition + ) - logger.debug(f"Joining with type {join_type} on columns {lhs_on}, {rhs_on}.") + logger.debug( + f"Joining with type {join_type} on columns {lhs_on}, {rhs_on}." + ) # lhs_on and rhs_on are the indices of the columns to merge on. # The given column indices are for the full, merged table which consists @@ -100,6 +106,29 @@ def convert( f"common_{i}": df_rhs_renamed.iloc[:, index] for i, index in enumerate(rhs_on) } + + # SQL compatibility: when joining on columns that + # contain NULLs, pandas will actually happily + # keep those NULLs. That is however not compatible with + # SQL, so we get rid of them here + if join_type in ["inner", "right"]: + df_lhs_filter = reduce( + operator.and_, + [ + ~df_lhs_renamed.iloc[:, index].isna() + for index in lhs_on + ], + ) + df_lhs_renamed = df_lhs_renamed[df_lhs_filter] + if join_type in ["inner", "left"]: + df_rhs_filter = reduce( + operator.and_, + [ + ~df_rhs_renamed.iloc[:, index].isna() + for index in rhs_on + ], + ) + df_rhs_renamed = df_rhs_renamed[df_rhs_filter] else: # We are in the complex join case # where we have no column to merge on @@ -121,7 +150,9 @@ def convert( # 5. Now we can finally merge on these columns # The resulting dataframe will contain all (renamed) columns from the lhs and rhs # plus the added columns - df = dd.merge(df_lhs_with_tmp, df_rhs_with_tmp, on=added_columns, how=join_type) + df = dd.merge( + df_lhs_with_tmp, df_rhs_with_tmp, on=added_columns, how=join_type + ) # 6. So the next step is to make sure # we have the correct column order (and to remove the temporary join columns) @@ -207,24 +238,29 @@ def _extract_lhs_rhs(self, rex): operands = rex.getOperands() assert len(operands) == 2 - operand_lhs = operands[0] - operand_rhs = operands[1] - - if isinstance(operand_lhs, org.apache.calcite.rex.RexInputRef) and isinstance( - operand_rhs, org.apache.calcite.rex.RexInputRef - ): - lhs_index = operand_lhs.getIndex() - rhs_index = operand_rhs.getIndex() - - # The rhs table always comes after the lhs - # table. Therefore we have a very simple - # way of checking, which index comes from which - # input - if lhs_index > rhs_index: - lhs_index, rhs_index = rhs_index, lhs_index - - return lhs_index, rhs_index - - raise TypeError( - "Invalid join condition" - ) # pragma: no cover. Do not how how it could be triggered. + indices = [] + for operand in operands: + if isinstance(operand, org.apache.calcite.rex.RexInputRef): + indices.append(operand.getIndex()) + elif isinstance( + operand, org.apache.calcite.rex.RexCall + ) and isinstance( + operand.getOperator(), + org.apache.calcite.sql.fun.SqlCastFunction, + ): + indices.append(operand.operands[0].getIndex()) + elif isinstance(operand, org.apache.calcite.rex.RexLiteral): + # i.e. join condition is col.id == constant + # raising an AssertionError means that the RexExpression will be added + # as a filter condition to be applied after the join. + raise AssertionError("This is actually a filter condition") + else: + raise TypeError( + "Invalid join condition" + ) # pragma: no cover. Do not how how it could be triggered. 
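The extra ``isna`` filters added to the join above exist because pandas/dask will happily match ``NaN`` join keys with each other, while SQL treats ``NULL = NULL`` as unknown and drops such rows from an inner join. A condensed sketch of the difference, on made-up frames:

.. code-block:: python

    import numpy as np
    import pandas as pd
    import dask.dataframe as dd

    lhs = dd.from_pandas(pd.DataFrame({"key": [1.0, np.nan], "l": ["a", "b"]}), npartitions=1)
    rhs = dd.from_pandas(pd.DataFrame({"key": [1.0, np.nan], "r": ["c", "d"]}), npartitions=1)

    # pandas/dask would pair the two NaN keys with each other ...
    naive = lhs.merge(rhs, on="key", how="inner")

    # ... so for SQL semantics the NULL keys are removed up front on both sides
    sql_like = lhs[~lhs["key"].isna()].merge(rhs[~rhs["key"].isna()], on="key", how="inner")

    print(len(naive.compute()), len(sql_like.compute()))  # 2 vs. 1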
+ lhs_index, rhs_index = indices + + if lhs_index > rhs_index: + lhs_index, rhs_index = rhs_index, lhs_index + + return lhs_index, rhs_index diff --git a/dask_sql/physical/rel/logical/minus.py b/dask_sql/physical/rel/logical/minus.py new file mode 100644 index 000000000..57352c1d5 --- /dev/null +++ b/dask_sql/physical/rel/logical/minus.py @@ -0,0 +1,70 @@ +import dask.dataframe as dd + +from dask_sql.physical.rel.base import BaseRelPlugin +from dask_sql.datacontainer import DataContainer, ColumnContainer + + +class LogicalMinusPlugin(BaseRelPlugin): + """ + LogicalUnion is used on EXCEPT clauses. + It just concatonates the two data frames. + """ + + class_name = "org.apache.calcite.rel.logical.LogicalMinus" + + def convert( + self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context" + ) -> DataContainer: + first_dc, second_dc = self.assert_inputs(rel, 2, context) + + first_df = first_dc.df + first_cc = first_dc.column_container + + second_df = second_dc.df + second_cc = second_dc.column_container + + # For concatenating, they should have exactly the same fields + output_field_names = [str(x) for x in rel.getRowType().getFieldNames()] + assert len(first_cc.columns) == len(output_field_names) + first_cc = first_cc.rename( + columns={ + col: output_col + for col, output_col in zip(first_cc.columns, output_field_names) + } + ) + first_dc = DataContainer(first_df, first_cc) + + assert len(second_cc.columns) == len(output_field_names) + second_cc = second_cc.rename( + columns={ + col: output_col + for col, output_col in zip(second_cc.columns, output_field_names) + } + ) + second_dc = DataContainer(second_df, second_cc) + + # To concat the to dataframes, we need to make sure the + # columns actually have the specified names in the + # column containers + # Otherwise the concat won't work + first_df = first_dc.assign() + second_df = second_dc.assign() + + self.check_columns_from_row_type(first_df, rel.getExpectedInputRowType(0)) + self.check_columns_from_row_type(second_df, rel.getExpectedInputRowType(1)) + + df = first_df.merge( + second_df, + how='left', + indicator=True, + ) + + df = df[ + df.iloc[:, -1] == "left_only" + ].iloc[:, :-1] + + cc = ColumnContainer(df.columns) + cc = self.fix_column_to_row_type(cc, rel.getRowType()) + dc = DataContainer(df, cc) + dc = self.fix_dtype_to_row_type(dc, rel.getRowType()) + return dc diff --git a/dask_sql/physical/rel/logical/sort.py b/dask_sql/physical/rel/logical/sort.py index 3038e445e..4c4a22ad3 100644 --- a/dask_sql/physical/rel/logical/sort.py +++ b/dask_sql/physical/rel/logical/sort.py @@ -135,7 +135,7 @@ def _sort_first_column( col = df[first_sort_column] is_na = col.isna().persist() if is_na.any().compute(): - df_is_na = df[is_na].reset_index(drop=True) + df_is_na = df[is_na].reset_index(drop=True).repartition(1) df_not_is_na = ( df[~is_na] .set_index(first_sort_column, drop=False) diff --git a/dask_sql/physical/rex/core/call.py b/dask_sql/physical/rex/core/call.py index ffd82bd57..8d5666a2f 100644 --- a/dask_sql/physical/rex/core/call.py +++ b/dask_sql/physical/rex/core/call.py @@ -240,6 +240,21 @@ def null(self, df: SeriesOrScalar,) -> SeriesOrScalar: return pd.isna(df) or df is None or np.isnan(df) +class IsNotDistinctOperation(Operation): + """The is not distinct operator""" + + def __init__(self): + super().__init__(self.not_distinct) + + def not_distinct(self, lhs: SeriesOrScalar, rhs: SeriesOrScalar) -> SeriesOrScalar: + """ + Returns true where `lhs` is not distinct from `rhs` (or both are null). 
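``LogicalMinusPlugin`` above realizes ``EXCEPT`` with a left merge plus the merge indicator, keeping only the rows that exist solely on the left side. In isolation (toy frames, duplicate handling ignored) this is:

.. code-block:: python

    import pandas as pd
    import dask.dataframe as dd

    first = dd.from_pandas(pd.DataFrame({"a": [1, 2, 3]}), npartitions=1)
    second = dd.from_pandas(pd.DataFrame({"a": [2, 3, 4]}), npartitions=1)

    # SELECT a FROM first EXCEPT SELECT a FROM second
    merged = first.merge(second, how="left", indicator=True)
    only_left = merged[merged["_merge"] == "left_only"].drop(columns="_merge")
    print(only_left.compute())  # just the row with a == 1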
+ """ + is_null = IsNullOperation() + + return (is_null(lhs) & is_null(rhs)) | (lhs == rhs) + + class RegexOperation(Operation): """An abstract regex operation, which transforms the SQL regex into something python can understand""" @@ -627,6 +642,8 @@ class RexCallPlugin(BaseRexPlugin): "-": ReduceOperation(operation=operator.sub, unary_operation=lambda x: -x), "/": ReduceOperation(operation=SQLDivisionOperator()), "*": ReduceOperation(operation=operator.mul), + "is distinct from": NotOperation().of(IsNotDistinctOperation()), + "is not distinct from": IsNotDistinctOperation(), # special operations "cast": lambda x: x, "case": CaseOperation(), diff --git a/dask_sql/utils.py b/dask_sql/utils.py index 51ad06f51..a3b8350e3 100644 --- a/dask_sql/utils.py +++ b/dask_sql/utils.py @@ -212,7 +212,7 @@ def get_table_from_compound_identifier( try: return context.tables[tableName] except KeyError: - raise AttributeError(f"Table {tableName} is not defined.") + raise AttributeError(f"Table '{tableName}' does not exist.") def convert_sql_kwargs( diff --git a/docs/pages/cmd.rst b/docs/pages/cmd.rst index 8d3926632..0dcac519f 100644 --- a/docs/pages/cmd.rst +++ b/docs/pages/cmd.rst @@ -21,7 +21,7 @@ or by running these lines of code cmd_loop() Some options can be set, e.g. to preload some testdata. -Have a look into :func:`dask_sql.cmd_loop` or call +Have a look into :func:`~dask_sql.cmd_loop` or call .. code-block:: bash diff --git a/docs/pages/custom.rst b/docs/pages/custom.rst index c0e3d0876..3f2b73527 100644 --- a/docs/pages/custom.rst +++ b/docs/pages/custom.rst @@ -11,7 +11,7 @@ Scalar Functions ---------------- A scalar function (such as :math:`x \to x^2`) turns a given column into another column of the same length. -It can be registered for usage in SQL with the :func:`dask_sql.Context.register_function` method. +It can be registered for usage in SQL with the :func:`~dask_sql.Context.register_function` method. Example: @@ -38,7 +38,7 @@ Aggregation Functions Aggregation functions run on a single column and turn them into a single value. This means they can only be used in ``GROUP BY`` aggregations. -They can be registered with the :func:`dask_sql.Context.register_aggregation` method. +They can be registered with the :func:`~dask_sql.Context.register_aggregation` method. This time however, an instance of a :class:`dask.dataframe.Aggregation` needs to be passed instead of a plain function. More information on dask aggregations can be found in the diff --git a/docs/pages/data_input.rst b/docs/pages/data_input.rst index 9a924f215..9b3413c24 100644 --- a/docs/pages/data_input.rst +++ b/docs/pages/data_input.rst @@ -3,14 +3,14 @@ Data Loading and Input ====================== -Before data can be queried with ``dask-sql``, it needs to be loaded into the dask cluster (or local instance) and registered with the :class:`dask_sql.Context`. +Before data can be queried with ``dask-sql``, it needs to be loaded into the dask cluster (or local instance) and registered with the :class:`~dask_sql.Context`. For this, ``dask-sql`` uses the wide field of possible `input formats `_ of ``dask``, plus some additional formats only suitable for `dask-sql`. You have multiple possibilities to load input data in ``dask-sql``: 1. Load it via python ------------------------------- -You can either use already created dask dataframes or create one by using the :func:`create_table` function. +You can either use already created dask dataframes or create one by using the :func:`~dask_sql.Context.create_table` function. 
Chances are high, there exists already a function to load your favorite format or location (e.g. s3 or hdfs). See below for all formats understood by ``dask-sql``. Make sure to install required libraries both on the driver and worker machines. @@ -58,7 +58,7 @@ In ``dask``, you can publish datasets with names into the cluster memory. This allows to reuse the same data from multiple clients/users in multiple sessions. For example, you can publish your data using the ``client.publish_dataset`` function of the ``distributed.Client``, -and then later register it in the :class:`dask_sql.Context` via SQL: +and then later register it in the :class:`~dask_sql.Context` via SQL: .. code-block:: python @@ -93,7 +93,7 @@ Input Formats * All formats and locations mentioned in `the Dask docu `_, including csv, parquet, json. Just pass in the location as string (and possibly the format, e.g. "csv" if it is not clear from the file extension). The data can be from local disc or many remote locations (S3, hdfs, Azure Filesystem, http, Google Filesystem, ...) - just prefix the path with the matching protocol. - Additional arguments passed to :func:`create_table` or ``CREATE TABLE`` are given to the ``read_`` calls. + Additional arguments passed to :func:`~dask_sql.Context.create_table` or ``CREATE TABLE`` are given to the ``read_`` calls. Example: @@ -113,7 +113,7 @@ Input Formats ) * If your data is already in Pandas (or Dask) DataFrames format, you can just use it as it is via the Python API - by giving it to :ref:`create_table` directly. + by giving it to :func:`~dask_sql.Context.create_table` directly. * You can connect ``dask-sql`` to an `intake `_ catalog and use the data registered there. Assuming you have an intake catalog stored in "catalog.yaml" (can also be the URL of an intake server), you can read in a stored table "data_table" either via Python @@ -161,7 +161,7 @@ Input Formats c.create_table("my_data", cursor, hive_table_name="the_name_in_hive") Again, ``hive_table_name`` is optional and defaults to the table name in ``dask-sql``. - You can also control the database used in Hive via the ``hive_schema_name```parameter. + You can also control the database used in Hive via the ``hive_schema_name`` parameter. Additional arguments are pushed to the internally called ``read_`` functions. .. note:: diff --git a/docs/pages/how_does_it_work.rst b/docs/pages/how_does_it_work.rst index 568605573..61f6442a7 100644 --- a/docs/pages/how_does_it_work.rst +++ b/docs/pages/how_does_it_work.rst @@ -7,8 +7,116 @@ At the core, ``dask-sql`` does two things: which is specified as a tree of java objects - similar to many other SQL engines (Hive, Flink, ...) - convert this description of the query from java objects into dask API calls (and execute them) - returning a dask dataframe. -For the first step, Apache Calcite needs to know about the columns and types of the dask dataframes, -therefore some java classes to store this information for dask dataframes are defined in ``planner``. -After the translation to a relational algebra is done (using ``RelationalAlgebraGenerator.getRelationalAlgebra``), -the python methods defined in ``dask_sql.physical`` turn this into a physical dask execution plan by converting -each piece of the relational algebra one-by-one. +Th following example explains this in quite some technical details. +For most of the users, this level of technical understanding is not needed. + +1. 
SQL enters the library +------------------------- + +No matter of via the Python API (:ref:`api`), the command line client (:ref:`cmd`) or the server (:ref:`server`), eventually the SQL statement by the user will end up as a string in the function :func:`~dask_sql.Context.sql`. + +2. SQL is parsed +---------------- + +This function will first give the SQL string to the implemented Java classes (especially :class:`RelationalAlgebraGenerator`) via the ``jpype`` library. +Inside this class, Apache Calcite is used to first parse the SQL string and then turn it into a relational algebra. +For this, Apache Calcite uses the SQL language description specified in the Calcite library itself and the additional definitions in the ``.ftl```files in the ``dask-sql`` repository. +They specify custom language features, such as the ``CREATE MODEL`` statement. + +.. note:: + + ``.ftl`` stands for FreeMarker Template Language and is one of the standard templating languages used in the Java ecosystem. + Each of the "functions" defined in the documents defines a part of the (extended) SQL language in ``javacc`` format. + FreeMarker is used to combine these parser definitions with the ones from Apache Calcite. Have a look into the ``config.fmpp`` file for more information. + + For example the following ``javacc`` code + + .. code-block:: + + SqlNode SqlShowTables() : + { + final Span s; + final SqlIdentifier schema; + } + { + { s = span(); } + schema = CompoundIdentifier() + { + return new SqlShowTables(s.end(this), schema); + } + } + + describes a parser line, which understands SQL statements such as + + .. code-block:: sql + + SHOW TABLES FROM "schema" + + While parsing the SQL, they are turned into an instance of the Java class :class:`SqlShowTables` (which is also defined in this project). + The :class:`Span` is used internally in Apache Calcite to store the position in the parsed SQL statement (e.g. for better error output). + The ``SqlShowTables`` javacc function (not the Java class SqlShowTables) is listed in ``config.fmpp`` as a ``statementParserMethods``, which makes it parsable as main SQL statement (similar to any normal ``SELECT ...`` statement). + All Java classes used as parser return values inherit from the Calcite class :class:`SqlNode` or any derived subclass (if it makes sense). Those classes are barely containers to store the information from the parsed SQL statements (such as the schema name in the example above) and do not have any business logic by themselves. + +3. SQL is (maybe) optimized +--------------------------- + +Once the SQL string is parsed into an instance of a :class:`SqlNode` (or a subclass of it), Apache Calcite can convert it into a relational algebra and optimize it. As this is only implemented for Calcite-own classes (and not for the custom classes such as :class:`SqlCreateModel`) this conversion and optimization is not triggered for all SQL statements (have a look into :func:`Context._get_ral`). + +After optimization, the resulting Java instance will be a class of any of the :class:`Logical*` classes in Apache Calcite (such as :class:`LogicalJoin`). Each of those can contain other instances as "inputs" creating a tree of different steps in the SQL statement (see below for an example). + +So after all, the result is either an optimized tree of steps in the relational algebra (represented by instances of the :class:`Logical*` classes) or an instance of a :class:`SqlNode` (sub)class. + +4. 
Translation to Dask API calls +-------------------------------- + +Depending on which type the resulting java class has, they are converted into calls to python functions using different python "converters". For each Java class, there exist a converter class in the ``dask_sql.physical.rel`` folder, which are registered at the :class:`dask_sql.physical.rel.convert.RelConverter` class. +Their job is to use the information stored in the java class instances and turn it into calls to python functions (see the example below for more information). + +As many SQL statements contain calculations using literals and/or columns, these are split into their own functionality (``dask_sql.physical.rex``) following a similar plugin-based converter system. +Have a look into the specific classes to understand how the conversion of a specific SQL language feature is implemented. + +5. Result +--------- + +The result of each of the conversions is a :class:`dask.DataFrame`, which is given to the user. In case of the command line tool or the SQL server, it is evaluated immediately - otherwise it can be used for further calculations by the user. + +Example +------- + +Let's walk through the steps above using the example SQL statement + +.. code-block:: sql + + SELECT x + y FROM timeseries WHERE x > 0 + +assuming the table "timeseries" is already registered. +If you want to follow along with the steps outlined in the following, start the command line tool in debug mode + +.. code-block:: bash + + dask-sql --load-test-data --startup --log-level DEBUG + +and enter the SQL statement above. + +First, the SQL is parsed by Apache Calcite and (as it is not a custom statement) transformed into a tree of relational algebra objects. + +.. code-block:: none + + LogicalProject(EXPR$0=[+($3, $4)]) + LogicalFilter(condition=[>($3, 0)]) + LogicalTableScan(table=[[schema, timeseries]]) + +The tree output above means, that the outer instance (:class:`LogicalProject`) needs as input the output of the previous instance (:class:`LogicalFilter`) etc. + +Therefore the conversion to python API calls is called recursively (depth-first). First, the :class:`LogicalTableScan` is converted using the :class:`rel.logical.table_scan.LogicalTableScanPlugin` plugin. It will just get the correct :class:`dask.DataFrame` from the dictionary of already registered tables of the context. +Next, the :class:`LogicalFilter` (having the dataframe as input), is converted via the :class:`rel.logical.filter.LogicalFilterPlugin`. +The filter expression ``>($3, 0)`` is converted into ``df["x"] > 0`` using a combination of REX plugins (have a look into the debug output to learn more) and applied to the dataframe. +The resulting dataframe is then passed to the converter :class:`rel.logical.project.LogicalProjectPlugin` for the :class:`LogicalProject`. +This will calculate the expression ``df["x"] + df["y"]`` (after having converted it via the class:`RexCallPlugin` plugin) and return the final result to the user. + +.. 
code-block:: python + + df_table_scan = context.tables["timeseries"] + df_filter = df_table_scan[df_table_scan["x"] > 0] + df_project = df_filter.assign(col=df_filter["x"] + df_filter["y"]) + return df_project[["col"]] \ No newline at end of file diff --git a/docs/pages/machine_learning.rst b/docs/pages/machine_learning.rst index a55412a0b..5abcb0f5f 100644 --- a/docs/pages/machine_learning.rst +++ b/docs/pages/machine_learning.rst @@ -19,7 +19,7 @@ Please also see :ref:`ml` for more information on the SQL statements used on thi ------------------------------------------------------------- If you are familiar with Python and the ML ecosystem in Python, this one is probably -the simplest possibility. You can use the :func:`Context.sql` call as described +the simplest possibility. You can use the :func:`~dask_sql.Context.sql` call as described before to extract the data for your training or ML prediction. The result will be a Dask dataframe, which you can either directly feed into your model or convert to a pandas dataframe with `.compute()` before. @@ -49,7 +49,7 @@ automatically. The syntax is similar to the `BigQuery Predict Syntax `_ or @@ -68,7 +68,7 @@ commands. Preregister your own data sources --------------------------------- -The python function :func:`dask_sql.run_server` accepts an already created :class:`dask_sql.Context`. +The python function :func:`~dask_sql.run_server` accepts an already created :class:`~dask_sql.Context`. This means you can preload your data sources and register them with a context before starting your server. By this, your server will already have data to query: diff --git a/docs/pages/sql.rst b/docs/pages/sql.rst index ace1297ab..c371084af 100644 --- a/docs/pages/sql.rst +++ b/docs/pages/sql.rst @@ -199,14 +199,16 @@ Limitatons ``dask-sql`` is still in early development, therefore exist some limitations: -* Not all operations and aggregations are implemented already, most prominently: ``WINDOW`` is not implemented so far. -* ``GROUP BY`` aggregations can not use ``DISTINCT`` +Not all operations and aggregations are implemented already, most prominently: ``WINDOW`` is not implemented so far. .. note:: Whenever you find a not already implemented operation, keyword or functionality, please raise an issue at our `issue tracker `_ with your use-case. +Dask/pandas and SQL treat null-values (or nan) differently on sorting, grouping and joining. +``dask-sql`` tries to follow the SQL standard as much as possible, so results might be different to what you expect from Dask/pandas. + Apart from those functional limitations, there is a operation which need special care: ``ORDER BY```. Normally, ``dask-sql`` calls create a ``dask`` data frame, which gets only computed when you call the ``.compute()`` member. Due to internal constraints, this is currently not the case for ``ORDER BY``. @@ -218,4 +220,5 @@ Including this operation will trigger a calculation of the full data frame alrea The data inside ``dask`` is partitioned, to distribute it over the cluster. ``head`` will only return the first N elements from the first partition - even if N is larger than the partition size. As a benefit, calling ``.head(N)`` is typically faster than calculating the full data sample with ``.compute()``. - ``LIMIT`` on the other hand will always return the first N elements - no matter on how many partitions they are scattered - but will also need to precalculate the first partition to find out, if it needs to have a look into all data or not. 
+ ``LIMIT`` on the other hand will always return the first N elements - no matter on how many partitions they are scattered - + but will also need to precalculate the first partition to find out, if it needs to have a look into all data or not. diff --git a/docs/pages/sql/ml.rst b/docs/pages/sql/ml.rst index bab3837f5..9100145d6 100644 --- a/docs/pages/sql/ml.rst +++ b/docs/pages/sql/ml.rst @@ -13,7 +13,7 @@ As all SQL statements in ``dask-sql`` are eventually converted to Python calls, any custom Python function and library, e.g. Machine Learning libraries. Although it would be possible to register custom functions (see :ref:`custom`) for this and use them, it is much more convenient if this functionality is already included in the core SQL language. -These three statements help in training and using models. Every :class:`Context` has a registry for models, which +These three statements help in training and using models. Every :class:`~dask_sql.Context` has a registry for models, which can be used for training or prediction. For a full example, see :ref:`machine_learning`. @@ -128,7 +128,7 @@ Predict the target using the given model and dataframe from the ``SELECT`` query The return value is the input dataframe with an additional column named "target", which contains the predicted values. The model needs to be registered at the context before using it in this function, -either by calling :func:`Context.register_model` explicitly or by training +either by calling :func:`~dask_sql.Context.register_model` explicitly or by training a model using the ``CREATE MODEL`` SQL statement above. A model can be anything which has a ``predict`` function. diff --git a/planner/pom.xml b/planner/pom.xml index a24f5a1e0..c10fdca06 100755 --- a/planner/pom.xml +++ b/planner/pom.xml @@ -17,7 +17,7 @@ 1.7.29 ${java.version} ${java.version} - 1.26.0 + 1.27.0 @@ -47,6 +47,11 @@ javacc 4.0 + + com.google.guava + guava + 30.1-jre + diff --git a/planner/src/main/java/com/dask/sql/application/DaskCalciteMaterializer.java b/planner/src/main/java/com/dask/sql/application/DaskCalciteMaterializer.java new file mode 100644 index 000000000..b1f210df0 --- /dev/null +++ b/planner/src/main/java/com/dask/sql/application/DaskCalciteMaterializer.java @@ -0,0 +1,286 @@ +package com.dask.sql.application; + +import java.util.ArrayList; +import java.util.List; +import java.util.Properties; + +import javax.annotation.Nullable; + +import com.dask.sql.schema.DaskTable; +import com.google.common.collect.ImmutableList; + +import org.apache.calcite.adapter.java.JavaTypeFactory; +import org.apache.calcite.config.CalciteConnectionConfig; +import org.apache.calcite.config.CalciteConnectionConfigImpl; +import org.apache.calcite.jdbc.CalciteSchema; +import org.apache.calcite.jdbc.JavaTypeFactoryImpl; +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelOptMaterialization; +import org.apache.calcite.plan.RelOptPlanner; +import org.apache.calcite.plan.RelOptTable; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.plan.volcano.VolcanoPlanner; +import org.apache.calcite.prepare.CalciteCatalogReader; +import org.apache.calcite.prepare.PlannerImpl; +import org.apache.calcite.prepare.RelOptTableImpl; +import org.apache.calcite.prepare.Prepare.CatalogReader; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelRoot; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.schema.SchemaPlus; +import org.apache.calcite.schema.Schemas; +import 
org.apache.calcite.schema.Table; +import org.apache.calcite.schema.impl.StarTable; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlOperatorTable; +import org.apache.calcite.sql.fun.SqlLibrary; +import org.apache.calcite.sql.fun.SqlLibraryOperatorTableFactory; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.parser.SqlParseException; +import org.apache.calcite.sql.parser.SqlParser; +import org.apache.calcite.sql.util.SqlOperatorTables; +import org.apache.calcite.sql.validate.SqlValidator; +import org.apache.calcite.sql.validate.SqlValidatorUtil; +import org.apache.calcite.sql2rel.SqlToRelConverter; +import org.apache.calcite.sql2rel.StandardConvertletTable; +import org.apache.calcite.tools.FrameworkConfig; + +/** + * Utility class for preparing a list of RelOptMaterializations which can then + * be added to a RelOptPlanner (hep or volcano) to be used during optimization. + * + * Create the class using the constructor with the rootSchema, defaultSchemaName, + * and FrameworkConfig used to create a Planner. Then call getMaterializations() + * to get the list of RelOptMaterializations, containing a RelOptMaterialization + * for each view in the schema. A view is any DaskTable which has an sql query + * associated. + * + * A lot of the code for this class is taken / adapted from : + * https://github.com/apache/calcite/blob/master/core/src/main/java/org/apache/calcite/prepare/Prepare.java + * https://github.com/apache/calcite/blob/master/core/src/main/java/org/apache/calcite/prepare/CalciteMaterializer.java + * https://github.com/apache/calcite/blob/master/core/src/main/java/org/apache/calcite/prepare/CalcitePrepareImpl.java + * + */ +public class DaskCalciteMaterializer { + + private final CatalogReader catalogReader; + private final CalciteSchema schema; + private final SqlValidator sqlValidator; + private final JavaTypeFactory typeFactory; + private final FrameworkConfig config; + private final PlannerImpl planner; + + DaskCalciteMaterializer(final SchemaPlus rootSchema, final String schemaName, final FrameworkConfig config) { + final SchemaPlus schemaPlus = rootSchema.getSubSchema(schemaName); + schema = CalciteSchema.from(schemaPlus); + this.config = config; + planner = new PlannerImpl(config); + + final List schemaPath = new ArrayList(); + schemaPath.add(schema.getName()); + final Properties props = new Properties(); + props.setProperty("defaultSchema", schema.getName()); + catalogReader = new CalciteCatalogReader(schema.root(), schemaPath, + new JavaTypeFactoryImpl(DaskSqlDialect.DASKSQL_TYPE_SYSTEM), new CalciteConnectionConfigImpl(props)); + + final List sqlOperatorTables = new ArrayList<>(); + sqlOperatorTables.add(SqlStdOperatorTable.instance()); + sqlOperatorTables.add(SqlLibraryOperatorTableFactory.INSTANCE.getOperatorTable(SqlLibrary.POSTGRESQL)); + sqlOperatorTables.add(catalogReader); + SqlOperatorTable operatorTable = SqlOperatorTables.chain(sqlOperatorTables); + + typeFactory = new JavaTypeFactoryImpl(DaskSqlDialect.DASKSQL_TYPE_SYSTEM); + final CalciteConnectionConfig connectionConfig = new CalciteConnectionConfigImpl(props); + final SqlValidator.Config validatorConfig = SqlValidator.Config.DEFAULT + .withLenientOperatorLookup(connectionConfig.lenientOperatorLookup()) + .withSqlConformance(connectionConfig.conformance()) + .withDefaultNullCollation(connectionConfig.defaultNullCollation()).withIdentifierExpansion(true); + + sqlValidator = 
SqlValidatorUtil.newValidator(operatorTable, catalogReader, typeFactory, validatorConfig); + } + + /** + * Prepare a list of RelOptMaterialization to be added to the planner before optimizing + */ + public List getMaterializations() { + List materializations = new ArrayList(); + for (String tableName : schema.getTableNames()) { + CalciteSchema.TableEntry tableEntry = schema.getTable(tableName, true); + DaskTable table = (DaskTable) tableEntry.getTable(); + if (table.isMaterializedView()) { + List qualifiedTableName = tableEntry.path(); + // Create a materialization with the Table and SQL query + final Materialization materialization = new Materialization(tableEntry, table.getSql(), qualifiedTableName); + // Populate this materialization's tableRel and queryRel with the corresponding + // RelNode representation for the query and table + populate(materialization); + // Create a RelOptMaterialization to add to the list of materializations + materializations.add( + new RelOptMaterialization( + materialization.tableRel, + materialization.queryRel, + materialization.starRelOptTable, + qualifiedTableName)); + } + } + return materializations; + } + + protected SqlToRelConverter getSqlToRelConverter(SqlValidator validator, CatalogReader catalogReader, + SqlToRelConverter.Config relConfig) { + final RexBuilder rexBuilder = new RexBuilder(typeFactory); + final RelOptPlanner optPlanner = new VolcanoPlanner(this.config.getCostFactory(), this.config.getContext()); + final RelOptCluster cluster = RelOptCluster.create(optPlanner, rexBuilder); + return new SqlToRelConverter(planner, validator, catalogReader, cluster, StandardConvertletTable.INSTANCE, + relConfig); + } + + /** + * Populates a materialization record, converting an sql query string and + * table path (essentially a list of strings, like ["hr", "sales"]) into + * RelNodes which can be used during the relational algebra planning process. + */ + protected void populate(final Materialization materialization) { + SqlParser parser = SqlParser.create(materialization.sql, config.getParserConfig()); + SqlNode node; + try { + node = parser.parseStmt(); + } catch (SqlParseException e) { + throw new RuntimeException("parse failed", e); + } + final SqlToRelConverter.Config relConfig = SqlToRelConverter.config().withTrimUnusedFields(true); + SqlToRelConverter sqlToRelConverter2 = getSqlToRelConverter(sqlValidator, catalogReader, relConfig); + + RelRoot root = sqlToRelConverter2.convertQuery(node, true, true); + materialization.queryRel = trimUnusedFields(root).rel; + + // Identify and substitute a StarTable in queryRel. + // + // It is possible that no StarTables match. That is OK, but the + // materialization patterns that are recognized will not be as rich. + // + // It is possible that more than one StarTable matches. TBD: should we + // take the best (whatever that means), or all of them? + useStar(schema, materialization); + + List tableName = materialization.materializedTable.path(); + RelOptTable table = this.catalogReader.getTable(tableName); + materialization.tableRel = sqlToRelConverter2.toRel(table, ImmutableList.of()); + } + + /** + * Walks over a tree of relational expressions, replacing each + * {@link org.apache.calcite.rel.RelNode} with a 'slimmed down' relational + * expression that projects only the columns required by its consumer. 
+ * + * @param root Root of relational expression tree + * @return Trimmed relational expression + */ + protected RelRoot trimUnusedFields(RelRoot root) { + final SqlToRelConverter.Config config = SqlToRelConverter.config().withTrimUnusedFields(shouldTrim(root.rel)) + .withExpand(false); + final SqlToRelConverter converter = getSqlToRelConverter(sqlValidator, catalogReader, config); + final boolean ordered = !root.collation.getFieldCollations().isEmpty(); + final boolean dml = SqlKind.DML.contains(root.kind); + return root.withRel(converter.trimUnusedFields(dml || ordered, root.rel)); + } + + private static boolean shouldTrim(RelNode rootRel) { + // For now, don't trim if there are more than 3 joins. The projects + // near the leaves created by trim migrate past joins and seem to + // prevent join-reordering. + return RelOptUtil.countJoins(rootRel) < 2; + } + + /** + * Converts a relational expression to use a {@link StarTable} defined in + * {@code schema}. Uses the first star table that fits. + */ + private void useStar(CalciteSchema schema, Materialization materialization) { + RelNode queryRel = materialization.queryRel; + for (Callback x : useStar(schema, queryRel)) { + // Success -- we found a star table that matches. + materialization.materialize(x.rel, x.starRelOptTable); + System.out.println("Materialization " + materialization.materializedTable + " matched star table " + + x.starTable + "; query after re-write: " + RelOptUtil.toString(queryRel)); + } + } + + /** + * Converts a relational expression to use a + * {@link org.apache.calcite.schema.impl.StarTable} defined in {@code schema}. + * Uses the first star table that fits. + */ + private Iterable useStar(CalciteSchema schema, RelNode queryRel) { + List starTables = Schemas.getStarTables(schema.root()); + if (starTables.isEmpty()) { + // Don't waste effort converting to leaf-join form. + return ImmutableList.of(); + } + final List list = new ArrayList<>(); + final RelNode rel2 = RelOptMaterialization.toLeafJoinForm(queryRel); + for (CalciteSchema.TableEntry starTable : starTables) { + final Table table = starTable.getTable(); + assert table instanceof StarTable; + RelOptTableImpl starRelOptTable = RelOptTableImpl.create(catalogReader, table.getRowType(typeFactory), + starTable, null); + final RelNode rel3 = RelOptMaterialization.tryUseStar(rel2, starRelOptTable); + if (rel3 != null) { + list.add(new Callback(rel3, starTable, starRelOptTable)); + } + } + return list; + } + + /** Called when we discover a star table that matches. */ + static class Callback { + public final RelNode rel; + public final CalciteSchema.TableEntry starTable; + public final RelOptTableImpl starRelOptTable; + + Callback(RelNode rel, CalciteSchema.TableEntry starTable, RelOptTableImpl starRelOptTable) { + this.rel = rel; + this.starTable = starTable; + this.starRelOptTable = starRelOptTable; + } + } + + /** + * Describes that a given SQL query is materialized by a given table. The + * materialization is currently valid, and can be used in the planning process. + */ + public static class Materialization { + /** The table that holds the materialized data. */ + final CalciteSchema.TableEntry materializedTable; + /** The query that derives the data. */ + final String sql; + /** The schema path for the query. */ + final List viewSchemaPath; + /** + * Relational expression for the table. Usually a + * {@link org.apache.calcite.rel.logical.LogicalTableScan}. + */ + @Nullable + RelNode tableRel; + /** Relational expression for the query to populate the table. 
*/ + @Nullable + RelNode queryRel; + /** Star table identified. */ + private @Nullable RelOptTable starRelOptTable; + + public Materialization(CalciteSchema.TableEntry materializedTable, String sql, List viewSchemaPath) { + assert materializedTable != null; + assert sql != null; + this.materializedTable = materializedTable; + this.sql = sql; + this.viewSchemaPath = viewSchemaPath; + } + + public void materialize(RelNode queryRel, RelOptTable starRelOptTable) { + this.queryRel = queryRel; + this.starRelOptTable = starRelOptTable; + // assert starRelOptTable.maybeUnwrap(StarTable.class).isPresent(); + } + } +} diff --git a/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java b/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java new file mode 100644 index 000000000..d15236606 --- /dev/null +++ b/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java @@ -0,0 +1,182 @@ +package com.dask.sql.application; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.hep.HepMatchOrder; +import org.apache.calcite.plan.hep.HepProgram; +import org.apache.calcite.plan.hep.HepProgramBuilder; +import org.apache.calcite.rel.rules.CoreRules; +import org.apache.calcite.rel.rules.materialize.MaterializedViewRules; +import org.apache.calcite.tools.RuleSet; +import org.apache.calcite.tools.RuleSets; + +/** + * RuleSets and utilities for creating Programs to use with Calcite's query + * planners. This is inspired both from Apache Calcite's default optimization + * programs + * (https://github.com/apache/calcite/blob/master/core/src/main/java/org/apache/calcite/tools/Programs.java) + * and Apache Flink's multi-phase query optimization + * (https://github.com/apache/flink/blob/master/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkStreamProgram.scala) + */ +public class DaskRuleSets { + + // private constructor + private DaskRuleSets() { + } + + /** + * RuleSet to reduce expressions + */ + static final RuleSet REDUCE_EXPRESSION_RULES = RuleSets.ofList(CoreRules.FILTER_REDUCE_EXPRESSIONS, + CoreRules.PROJECT_REDUCE_EXPRESSIONS, CoreRules.CALC_REDUCE_EXPRESSIONS, + CoreRules.JOIN_REDUCE_EXPRESSIONS, CoreRules.AGGREGATE_REDUCE_FUNCTIONS); + + /** + * RuleSet about filter + */ + static final RuleSet FILTER_RULES = RuleSets.ofList( + // push a filter into a join + CoreRules.FILTER_INTO_JOIN, + // Jonas : We need JOIN_PUSH_EXPRESSIONS rule to work + // with the FILTER_INTO_JOIN rule, + // otherwise we end up with filter expressions on join conditions + // i.e. (emp join dept on emp.deptno * 2 = dept.deptno) which + // LogicalJoinPlugin can't handle. + // CoreRules.JOIN_PUSH_TRANSITIVE_PREDICATES, + CoreRules.JOIN_PUSH_EXPRESSIONS, + // push filter into the children of a join + CoreRules.JOIN_CONDITION_PUSH, + // push filter through an aggregation + CoreRules.FILTER_AGGREGATE_TRANSPOSE, + // Jonas : the FILTER_PROJECT_TRANSPOSE rule causes Calcite to push filter conditions (like x = y) + // into a projection (i.e. project(x, y, z)) which (sometimes) causes errors as the project will + // select column y instead of x for instance and loose the reference to y. + // It's the reason for several test failures in neurolang. 
+ // CoreRules.FILTER_PROJECT_TRANSPOSE, + // push a filter past a setop + CoreRules.FILTER_SET_OP_TRANSPOSE, CoreRules.FILTER_MERGE); + + /** + * RuleSet about project Dont' add CoreRules.PROJECT_REMOVE + */ + static final RuleSet PROJECT_RULES = RuleSets.ofList( + CoreRules.AGGREGATE_PROJECT_MERGE, + // push a projection past a filter + CoreRules.PROJECT_FILTER_TRANSPOSE, + // merge projections + CoreRules.PROJECT_MERGE, + // removes constant keys from an Agg + CoreRules.AGGREGATE_PROJECT_PULL_UP_CONSTANTS, + // push project through a Union + CoreRules.PROJECT_SET_OP_TRANSPOSE + ); + + /** + * RuleSet about aggregate + */ + static final RuleSet AGGREGATE_RULES = RuleSets.ofList(CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN, + CoreRules.AGGREGATE_JOIN_TRANSPOSE, CoreRules.AGGREGATE_PROJECT_MERGE, + CoreRules.PROJECT_AGGREGATE_MERGE, CoreRules.AGGREGATE_MERGE, + // Important. Removes unecessary distinct calls + CoreRules.AGGREGATE_REMOVE, CoreRules.AGGREGATE_JOIN_REMOVE); + + /** + * RuleSet for merging joins. All joins are merged into a large multi-join, + * which is then optimised by one of the JOIN_REORDER_RULES. + */ + static final RuleSet JOIN_REORDER_PREPARE_RULES = RuleSets.ofList( + // merge project to MultiJoin + CoreRules.PROJECT_MULTI_JOIN_MERGE, + // merge filter to MultiJoin + CoreRules.FILTER_MULTI_JOIN_MERGE, + // merge join to MultiJoin + CoreRules.JOIN_TO_MULTI_JOIN); + + /** + * Rules to reorder joins + */ + static final RuleSet JOIN_REORDER_RULES = RuleSets.ofList( + // optimize multi joins + CoreRules.MULTI_JOIN_OPTIMIZE, + // optmize bushy multi joins + CoreRules.MULTI_JOIN_OPTIMIZE_BUSHY); + + /** + * Rules to reorder joins using associate and commute rules. See + * https://www.querifylabs.com/blog/rule-based-query-optimization for an + * explanation. JoinCommuteRule causes exhaustive search and should probably not + * be used. + */ + static final RuleSet JOIN_COMMUTE_ASSOCIATE_RULES = RuleSets.ofList( + // changes a join based on associativity rule. + CoreRules.JOIN_ASSOCIATE, CoreRules.JOIN_COMMUTE); + + /** + * RuleSet to do logical optimize. + */ + static final RuleSet LOGICAL_RULES = RuleSets.ofList( + // remove union with only a single child + CoreRules.UNION_REMOVE, + // convert non-all union into all-union + distinct + CoreRules.UNION_TO_DISTINCT, CoreRules.MINUS_MERGE, + // aggregation and projection rules + // CoreRules.AGGREGATE_PROJECT_MERGE, + // CoreRules.AGGREGATE_PROJECT_PULL_UP_CONSTANTS, + // CoreRules.AGGREGATE_REMOVE, CoreRules.AGGREGATE_JOIN_TRANSPOSE_EXTENDED, + CoreRules.AGGREGATE_UNION_AGGREGATE_FIRST, CoreRules.AGGREGATE_UNION_AGGREGATE_SECOND); + + /** + * RuleSet for MaterializedViews. + */ + static final RuleSet MATERIALIZATION_RULES = RuleSets.ofList(MaterializedViewRules.FILTER_SCAN, + MaterializedViewRules.PROJECT_FILTER, MaterializedViewRules.FILTER, + MaterializedViewRules.PROJECT_JOIN, MaterializedViewRules.JOIN, + MaterializedViewRules.PROJECT_AGGREGATE, MaterializedViewRules.AGGREGATE); + + /** + * Initial rule set from dask_sql with a couple rules added by Demian. Not used + * but kept for reference. 
+ */ + static final RuleSet DASK_DEFAULT_CORE_RULES = RuleSets.ofList( + CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN, CoreRules.FILTER_SET_OP_TRANSPOSE, + CoreRules.FILTER_AGGREGATE_TRANSPOSE, CoreRules.FILTER_INTO_JOIN, CoreRules.JOIN_CONDITION_PUSH, + CoreRules.PROJECT_JOIN_TRANSPOSE, CoreRules.PROJECT_MULTI_JOIN_MERGE, + CoreRules.JOIN_TO_MULTI_JOIN, CoreRules.MULTI_JOIN_OPTIMIZE, + CoreRules.MULTI_JOIN_OPTIMIZE_BUSHY, CoreRules.AGGREGATE_JOIN_TRANSPOSE, + CoreRules.AGGREGATE_PROJECT_MERGE, CoreRules.PROJECT_AGGREGATE_MERGE, CoreRules.AGGREGATE_MERGE, + CoreRules.PROJECT_MERGE, CoreRules.FILTER_MERGE, + // Don't add this rule as it removes projections which are used to rename colums + // CoreRules.PROJECT_REMOVE, + CoreRules.PROJECT_REDUCE_EXPRESSIONS, CoreRules.FILTER_REDUCE_EXPRESSIONS, + CoreRules.FILTER_EXPAND_IS_NOT_DISTINCT_FROM, CoreRules.AGGREGATE_REDUCE_FUNCTIONS); + + /** + * Builds a HepProgram for the given set of rules and with the given order. If + * type is COLLECTION, rules are added as collection. Otherwise, rules are added + * sequentially. + */ + public static HepProgram hepProgram(final RuleSet rules, final HepMatchOrder order, + final HepExecutionType type) { + final HepProgramBuilder builder = new HepProgramBuilder().addMatchOrder(order); + switch (type) { + case SEQUENCE: + for (RelOptRule rule : rules) { + builder.addRuleInstance(rule); + } + break; + case COLLECTION: + List rulesCollection = new ArrayList(); + rules.iterator().forEachRemaining(rulesCollection::add); + builder.addRuleCollection(rulesCollection); + break; + } + return builder.build(); + } + + public enum HepExecutionType { + SEQUENCE, COLLECTION + } + +} \ No newline at end of file diff --git a/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java b/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java index 21bd319b6..798e4001f 100644 --- a/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java +++ b/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java @@ -7,10 +7,9 @@ import java.util.List; import java.util.Properties; +import com.dask.sql.application.DaskRuleSets.HepExecutionType; import com.dask.sql.schema.DaskSchema; import org.apache.calcite.sql.SqlDialect; -import org.apache.calcite.sql.dialect.PostgresqlSqlDialect; -import org.apache.calcite.avatica.util.Casing; import org.apache.calcite.config.CalciteConnectionConfig; import org.apache.calcite.config.CalciteConnectionConfigImpl; import org.apache.calcite.config.CalciteConnectionProperty; @@ -19,23 +18,13 @@ import org.apache.calcite.jdbc.JavaTypeFactoryImpl; import org.apache.calcite.plan.Context; import org.apache.calcite.plan.Contexts; +import org.apache.calcite.plan.RelOptMaterialization; import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.plan.hep.HepMatchOrder; import org.apache.calcite.plan.hep.HepPlanner; -import org.apache.calcite.plan.hep.HepProgram; import org.apache.calcite.plan.hep.HepProgramBuilder; import org.apache.calcite.prepare.CalciteCatalogReader; import org.apache.calcite.rel.RelNode; -import org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule; -import org.apache.calcite.rel.rules.AggregateReduceFunctionsRule; -import org.apache.calcite.rel.rules.FilterAggregateTransposeRule; -import org.apache.calcite.rel.rules.FilterJoinRule; -import org.apache.calcite.rel.rules.FilterMergeRule; -import org.apache.calcite.rel.rules.FilterRemoveIsNotDistinctFromRule; -import 
org.apache.calcite.rel.rules.ProjectJoinTransposeRule; -import org.apache.calcite.rel.rules.ProjectMergeRule; -import org.apache.calcite.rel.rules.ProjectRemoveRule; -import org.apache.calcite.rel.rules.ReduceExpressionsRule; -import org.apache.calcite.rel.type.RelDataTypeSystem; import org.apache.calcite.rex.RexExecutorImpl; import org.apache.calcite.schema.SchemaPlus; import org.apache.calcite.sql.SqlNode; @@ -61,10 +50,13 @@ * This class is taken (in parts) from the blazingSQL project. */ public class RelationalAlgebraGenerator { + /// The created planner private Planner planner; /// The planner for optimized queries private HepPlanner hepPlanner; + /// The planner to optimise queries with materialized views + private HepPlanner viewBasedPlanner; /// Create a new relational algebra generator from a schema public RelationalAlgebraGenerator(final DaskSchema schema) throws ClassNotFoundException, SQLException { @@ -80,6 +72,13 @@ public RelationalAlgebraGenerator(final DaskSchema schema) throws ClassNotFoundE planner = Frameworks.getPlanner(config); hepPlanner = getHepPlanner(config); + viewBasedPlanner = getViewBasedPlanner(config); + final DaskCalciteMaterializer materializer = new DaskCalciteMaterializer(rootSchema, schema.getName(), config); + for (RelOptMaterialization materialization : materializer.getMaterializations()) { + // System.out.println("Adding materialized view for \n" + getRelationalAlgebraString(materialization.tableRel) + // + "\nwith sql query plan " + getRelationalAlgebraString(materialization.queryRel)); + viewBasedPlanner.addMaterialization(materialization); + } } /// Create the framework config, e.g. containing with SQL dialect we speak @@ -131,29 +130,35 @@ private CalciteConnection getCalciteConnection() throws SQLException { /// get an optimizer hep planner private HepPlanner getHepPlanner(final FrameworkConfig config) { - // TODO: check if these rules are sensible - // Taken from blazingSQL - final HepProgram program = new HepProgramBuilder() - .addRuleInstance(AggregateExpandDistinctAggregatesRule.Config.JOIN.toRule()) - .addRuleInstance(FilterAggregateTransposeRule.Config.DEFAULT.toRule()) - .addRuleInstance(FilterJoinRule.JoinConditionPushRule.Config.DEFAULT.toRule()) - .addRuleInstance(FilterJoinRule.FilterIntoJoinRule.Config.DEFAULT.toRule()) - .addRuleInstance(ProjectMergeRule.Config.DEFAULT.toRule()) - .addRuleInstance(FilterMergeRule.Config.DEFAULT.toRule()) - .addRuleInstance(ProjectJoinTransposeRule.Config.DEFAULT.toRule()) - // In principle, not a bad idea. 
But we need to keep the most - // outer project - because otherwise the column name information is lost - // in cases such as SELECT x AS a, y AS B FROM df - // .addRuleInstance(ProjectRemoveRule.Config.DEFAULT.toRule()) - .addRuleInstance(ReduceExpressionsRule.ProjectReduceExpressionsRule.Config.DEFAULT.toRule()) - // this rule might make sense, but turns a < 1 into a SEARCH expression - // which is currently not supported by dask-sql - // .addRuleInstance(ReduceExpressionsRule.FilterReduceExpressionsRule.Config.DEFAULT.toRule()) - .addRuleInstance(FilterRemoveIsNotDistinctFromRule.Config.DEFAULT.toRule()) - // TODO: remove AVG - .addRuleInstance(AggregateReduceFunctionsRule.Config.DEFAULT.toRule()).build(); - - return new HepPlanner(program, config.getContext()); + final HepProgramBuilder builder = new HepProgramBuilder(); + builder.addMatchOrder(HepMatchOrder.ARBITRARY).addMatchLimit(Integer.MAX_VALUE); + // Legacy rule set + // for (RelOptRule rule : DaskRuleSets.DASK_DEFAULT_CORE_RULES){ + // builder.addRuleInstance(rule); + // } + + builder.addSubprogram(DaskRuleSets.hepProgram(DaskRuleSets.AGGREGATE_RULES, HepMatchOrder.BOTTOM_UP, + HepExecutionType.COLLECTION)); + builder.addSubprogram(DaskRuleSets.hepProgram(DaskRuleSets.PROJECT_RULES, HepMatchOrder.BOTTOM_UP, + HepExecutionType.COLLECTION)); + builder.addSubprogram(DaskRuleSets.hepProgram(DaskRuleSets.FILTER_RULES, HepMatchOrder.BOTTOM_UP, + HepExecutionType.COLLECTION)); + builder.addSubprogram(DaskRuleSets.hepProgram(DaskRuleSets.REDUCE_EXPRESSION_RULES, HepMatchOrder.BOTTOM_UP, + HepExecutionType.COLLECTION)); + // join reorder. The first set of rules transforms joins into a large multijoin. + // the second set of rules splits the multijoins by applying a heuristic to + // determine the best join order. + builder.addSubprogram(DaskRuleSets.hepProgram(DaskRuleSets.JOIN_REORDER_PREPARE_RULES, HepMatchOrder.BOTTOM_UP, + HepExecutionType.COLLECTION)); + builder.addSubprogram(DaskRuleSets.hepProgram(DaskRuleSets.JOIN_REORDER_RULES, HepMatchOrder.BOTTOM_UP, + HepExecutionType.SEQUENCE)); + + // optimize logical plan. Be careful not to introduce rules in this set which + // mess up the join order from the step before. 
+ builder.addSubprogram(DaskRuleSets.hepProgram(DaskRuleSets.LOGICAL_RULES, HepMatchOrder.BOTTOM_UP, + HepExecutionType.SEQUENCE)); + + return new HepPlanner(builder.build(), config.getContext()); } /// Parse a sql string into a sql tree @@ -198,4 +203,17 @@ public RelNode getOptimizedRelationalAlgebra(final RelNode nonOptimizedPlan) { public String getRelationalAlgebraString(final RelNode relNode) { return RelOptUtil.toString(relNode); } + + public RelNode getMaterializedViewsOptimizedRelationalAlgebra(final RelNode relPlan) { + viewBasedPlanner.setRoot(relPlan); + return viewBasedPlanner.findBestExp(); + } + + private HepPlanner getViewBasedPlanner(final FrameworkConfig config) { + final HepProgramBuilder builder = new HepProgramBuilder(); + builder.addMatchOrder(HepMatchOrder.ARBITRARY).addMatchLimit(Integer.MAX_VALUE); + builder.addSubprogram(DaskRuleSets.hepProgram(DaskRuleSets.MATERIALIZATION_RULES, HepMatchOrder.BOTTOM_UP, + HepExecutionType.COLLECTION)); + return new HepPlanner(builder.build(), config.getContext()); + } } diff --git a/planner/src/main/java/com/dask/sql/schema/DaskTable.java b/planner/src/main/java/com/dask/sql/schema/DaskTable.java index 31681f9fe..c5c730728 100644 --- a/planner/src/main/java/com/dask/sql/schema/DaskTable.java +++ b/planner/src/main/java/com/dask/sql/schema/DaskTable.java @@ -28,13 +28,22 @@ public class DaskTable implements ProjectableFilterableTable { private final ArrayList> tableColumns; // Name of this table private final String name; + // Optional sql query. If given, the table is considered a materialized view + // and added to the planner for view-based optimization + private final String sql; - /// Construct a new table with the given name - public DaskTable(final String name) { + /// Construct a new table with the given name and sql + public DaskTable(final String name, final String sql) { this.name = name; + this.sql = sql; this.tableColumns = new ArrayList>(); } + /// Construct a new table with the given name + public DaskTable(final String name) { + this(name, null); + } + /// Add a column with the given type public void addColumn(final String columnName, final SqlTypeName columnType) { this.tableColumns.add(new Pair<>(columnName, columnType)); @@ -45,6 +54,14 @@ public String getTableName() { return this.name; } + public String getSql() { + return this.sql; + } + + public boolean isMaterializedView() { + return this.sql != null; + } + /// calcite method: Get the type of a row of this table (using the type factory) @Override public RelDataType getRowType(final RelDataTypeFactory relDataTypeFactory) { diff --git a/setup.py b/setup.py index 86e0e424f..e816cc3a6 100755 --- a/setup.py +++ b/setup.py @@ -75,10 +75,8 @@ def run(self): setup_requires=["setuptools_scm"] + sphinx_requirements, install_requires=[ "dask[dataframe]>=2.19.0", - "pandas<1.2.0,>=1.0.0", # pandas 1.2.0 introduced float NaN dtype, - # which is currently not working with dask, - # so the test is failing, see https://github.com/dask/dask/issues/7156 - # below 1.0, there were no nullable ext. types + "distributed", + "pandas>=1.0.0", # below 1.0, there were no nullable ext. 
types "jpype1>=1.0.2", "fastapi>=0.61.1", "uvicorn>=0.11.3", diff --git a/tests/integration/test_compatibility.py b/tests/integration/test_compatibility.py new file mode 100644 index 000000000..b3f1dab1b --- /dev/null +++ b/tests/integration/test_compatibility.py @@ -0,0 +1,884 @@ +""" +The tests in this module are taken from +the fugue-sql module to test the compatibility +with their "understanding" of SQL +They run randomized tests and compare with sqlite. + +There are some changes compared to the fugueSQL +tests, especially when it comes to sort order: +dask-sql does not enforce a specific order after groupby +""" + +import sqlite3 +from datetime import datetime, timedelta + +import pandas as pd +import numpy as np +from pandas.testing import assert_frame_equal +from dask_sql import Context + + +def eq_sqlite(sql, **dfs): + c = Context() + engine = sqlite3.connect(":memory:") + + for name, df in dfs.items(): + c.create_table(name, df) + df.to_sql(name, engine, index=False) + + dask_result = c.sql(sql).compute().reset_index(drop=True) + sqlite_result = pd.read_sql(sql, engine).reset_index(drop=True) + + assert_frame_equal(dask_result, sqlite_result, check_dtype=False) + + +def make_rand_df(size: int, **kwargs): + np.random.seed(0) + data = {} + for k, v in kwargs.items(): + if not isinstance(v, tuple): + v = (v, 0.0) + dt, null_ct = v[0], v[1] + if dt is int: + s = np.random.randint(10, size=size) + elif dt is bool: + s = np.where(np.random.randint(2, size=size), True, False) + elif dt is float: + s = np.random.rand(size) + elif dt is str: + r = [f"ssssss{x}" for x in range(10)] + c = np.random.randint(10, size=size) + s = np.array([r[x] for x in c]) + elif dt is datetime: + rt = [datetime(2020, 1, 1) + timedelta(days=x) for x in range(10)] + c = np.random.randint(10, size=size) + s = np.array([rt[x] for x in c]) + else: + raise NotImplementedError + ps = pd.Series(s) + if null_ct > 0: + idx = np.random.choice(size, null_ct, replace=False).tolist() + ps[idx] = None + data[k] = ps + return pd.DataFrame(data) + + +def test_basic_select_from(): + df = make_rand_df(5, a=(int, 2), b=(str, 3), c=(float, 4)) + eq_sqlite("SELECT 1 AS a, 1.5 AS b, 'x' AS c") + eq_sqlite("SELECT 1+2 AS a, 1.5*3 AS b, 'x' AS c") + eq_sqlite("SELECT * FROM a", a=df) + eq_sqlite("SELECT * FROM a AS x", a=df) + eq_sqlite("SELECT b AS bb, a+1-2*3.0/4 AS cc, x.* FROM a AS x", a=df) + eq_sqlite("SELECT *, 1 AS x, 2.5 AS y, 'z' AS z FROM a AS x", a=df) + eq_sqlite("SELECT *, -(1.0+a)/3 AS x, +(2.5) AS y FROM a AS x", a=df) + + +def test_case_when(): + a = make_rand_df(100, a=(int, 20), b=(str, 30), c=(float, 40)) + eq_sqlite( + """ + SELECT a,b,c, + CASE + WHEN a<10 THEN a+3 + WHEN c<0.5 THEN a+5 + ELSE (1+2)*3 + a + END AS d + FROM a + """, + a=a, + ) + + +def test_drop_duplicates(): + # simplest + a = make_rand_df(100, a=int, b=int) + eq_sqlite( + """ + SELECT DISTINCT b, a FROM a + ORDER BY a NULLS LAST, b NULLS FIRST + """, + a=a, + ) + # mix of number and nan + a = make_rand_df(100, a=(int, 50), b=(int, 50)) + eq_sqlite( + """ + SELECT DISTINCT b, a FROM a + ORDER BY a NULLS LAST, b NULLS FIRST + """, + a=a, + ) + # mix of number and string and nulls + a = make_rand_df(100, a=(int, 50), b=(str, 50), c=float) + eq_sqlite( + """ + SELECT DISTINCT b, a FROM a + ORDER BY a NULLS LAST, b NULLS FIRST + """, + a=a, + ) + + +def test_order_by_no_limit(): + a = make_rand_df(100, a=(int, 50), b=(str, 50), c=float) + eq_sqlite( + """ + SELECT DISTINCT b, a FROM a + ORDER BY a NULLS LAST, b NULLS FIRST + """, + a=a, + ) + + +def 
test_order_by_limit(): + a = make_rand_df(100, a=(int, 50), b=(str, 50), c=float) + eq_sqlite( + """ + SELECT DISTINCT b, a FROM a LIMIT 0 + """, + a=a, + ) + eq_sqlite( + """ + SELECT DISTINCT b, a FROM a ORDER BY a NULLS FIRST, b NULLS FIRST LIMIT 2 + """, + a=a, + ) + eq_sqlite( + """ + SELECT b, a FROM a + ORDER BY a NULLS LAST, b NULLS FIRST LIMIT 10 + """, + a=a, + ) + + +def test_where(): + df = make_rand_df(100, a=(int, 30), b=(str, 30), c=(float, 30)) + eq_sqlite("SELECT * FROM a WHERE TRUE OR TRUE", a=df) + eq_sqlite("SELECT * FROM a WHERE TRUE AND TRUE", a=df) + eq_sqlite("SELECT * FROM a WHERE FALSE OR FALSE", a=df) + eq_sqlite("SELECT * FROM a WHERE FALSE AND FALSE", a=df) + + eq_sqlite("SELECT * FROM a WHERE TRUE OR b<='ssssss8'", a=df) + eq_sqlite("SELECT * FROM a WHERE TRUE AND b<='ssssss8'", a=df) + eq_sqlite("SELECT * FROM a WHERE FALSE OR b<='ssssss8'", a=df) + eq_sqlite("SELECT * FROM a WHERE FALSE AND b<='ssssss8'", a=df) + eq_sqlite("SELECT * FROM a WHERE a=10 OR b<='ssssss8'", a=df) + eq_sqlite("SELECT * FROM a WHERE c IS NOT NULL OR (a<5 AND b IS NOT NULL)", a=df) + + df = make_rand_df(100, a=(float, 30), b=(float, 30), c=(float, 30)) + eq_sqlite("SELECT * FROM a WHERE a<0.5 AND b<0.5 AND c<0.5", a=df) + eq_sqlite("SELECT * FROM a WHERE a<0.5 OR b<0.5 AND c<0.5", a=df) + eq_sqlite("SELECT * FROM a WHERE a IS NULL OR (b<0.5 AND c<0.5)", a=df) + eq_sqlite("SELECT * FROM a WHERE a*b IS NULL OR (b*c<0.5 AND c*a<0.5)", a=df) + + +def test_in_between(): + df = make_rand_df(10, a=(int, 3), b=(str, 3)) + eq_sqlite("SELECT * FROM a WHERE a IN (2,4,6)", a=df) + eq_sqlite("SELECT * FROM a WHERE a BETWEEN 2 AND 4+1", a=df) + eq_sqlite("SELECT * FROM a WHERE a NOT IN (2,4,6) AND a IS NOT NULL", a=df) + eq_sqlite("SELECT * FROM a WHERE a NOT BETWEEN 2 AND 4+1 AND a IS NOT NULL", a=df) + + +def test_join_inner(): + a = make_rand_df(100, a=(int, 40), b=(str, 40), c=(float, 40)) + b = make_rand_df(80, d=(float, 10), a=(int, 10), b=(str, 10)) + eq_sqlite( + """ + SELECT + a.*, d, d*c AS x + FROM a + INNER JOIN b ON a.a=b.a AND a.b=b.b + ORDER BY a.a NULLS FIRST, a.b NULLS FIRST, a.c NULLS FIRST, d NULLS FIRST + """, + a=a, + b=b, + ) + + +def test_join_left(): + a = make_rand_df(100, a=(int, 40), b=(str, 40), c=(float, 40)) + b = make_rand_df(80, d=(float, 10), a=(int, 10), b=(str, 10)) + eq_sqlite( + """ + SELECT + a.*, d, d*c AS x + FROM a LEFT JOIN b ON a.a=b.a AND a.b=b.b + ORDER BY a.a NULLS FIRST, a.b NULLS FIRST, a.c NULLS FIRST, d NULLS FIRST + """, + a=a, + b=b, + ) + + +def test_join_cross(): + a = make_rand_df(10, a=(int, 4), b=(str, 4), c=(float, 4)) + b = make_rand_df(20, dd=(float, 1), aa=(int, 1), bb=(str, 1)) + eq_sqlite("SELECT * FROM a CROSS JOIN b", a=a, b=b) + + +def test_join_multi(): + a = make_rand_df(100, a=(int, 40), b=(str, 40), c=(float, 40)) + b = make_rand_df(80, d=(float, 10), a=(int, 10), b=(str, 10)) + c = make_rand_df(80, dd=(float, 10), a=(int, 10), b=(str, 10)) + eq_sqlite( + """ + SELECT a.*,d,dd FROM a + INNER JOIN b ON a.a=b.a AND a.b=b.b + INNER JOIN c ON a.a=c.a AND c.b=b.b + ORDER BY a.a NULLS FIRST, a.b NULLS FIRST, a.c NULLS FIRST, dd NULLS FIRST, d NULLS FIRST + """, + a=a, + b=b, + c=c, + ) + + +def test_agg_count_no_group_by(): + a = make_rand_df( + 100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=(float, 40) + ) + eq_sqlite( + """ + SELECT + COUNT(a) AS c_a, + COUNT(DISTINCT a) AS cd_a, + COUNT(b) AS c_b, + COUNT(DISTINCT b) AS cd_b, + COUNT(c) AS c_c, + COUNT(DISTINCT c) AS cd_c, + COUNT(d) AS c_d, + COUNT(DISTINCT d) AS cd_d, 
+ COUNT(e) AS c_e, + COUNT(DISTINCT a) AS cd_e + FROM a + """, + a=a, + ) + + +def test_agg_count(): + a = make_rand_df( + 100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=(float, 40) + ) + eq_sqlite( + """ + SELECT + a, b, a+1 AS c, + COUNT(c) AS c_c, + COUNT(DISTINCT c) AS cd_c, + COUNT(d) AS c_d, + COUNT(DISTINCT d) AS cd_d, + COUNT(e) AS c_e, + COUNT(DISTINCT a) AS cd_e + FROM a GROUP BY a, b + """, + a=a, + ) + + +def test_agg_sum_avg_no_group_by(): + eq_sqlite( + """ + SELECT + SUM(a) AS sum_a, + AVG(a) AS avg_a + FROM a + """, + a=pd.DataFrame({"a": [float("nan")]}), + ) + a = make_rand_df( + 100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=(float, 40) + ) + eq_sqlite( + """ + SELECT + SUM(a) AS sum_a, + AVG(a) AS avg_a, + SUM(c) AS sum_c, + AVG(c) AS avg_c, + SUM(e) AS sum_e, + AVG(e) AS avg_e, + SUM(a)+AVG(e) AS mix_1, + SUM(a+e) AS mix_2 + FROM a + """, + a=a, + ) + + +def test_agg_sum_avg(): + a = make_rand_df( + 100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=(float, 40) + ) + eq_sqlite( + """ + SELECT + a,b, a+1 AS c, + SUM(c) AS sum_c, + AVG(c) AS avg_c, + SUM(e) AS sum_e, + AVG(e) AS avg_e, + SUM(a)+AVG(e) AS mix_1, + SUM(a+e) AS mix_2 + FROM a GROUP BY a,b + """, + a=a, + ) + + +def test_agg_min_max_no_group_by(): + a = make_rand_df( + 100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=(float, 40) + ) + eq_sqlite( + """ + SELECT + MIN(a) AS min_a, + MAX(a) AS max_a, + MIN(b) AS min_b, + MAX(b) AS max_b, + MIN(c) AS min_c, + MAX(c) AS max_c, + MIN(d) AS min_d, + MAX(d) AS max_d, + MIN(e) AS min_e, + MAX(e) AS max_e, + MIN(a+e) AS mix_1, + MIN(a)+MIN(e) AS mix_2 + FROM a + """, + a=a, + ) + + +def test_agg_min_max(): + a = make_rand_df( + 100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=(float, 40) + ) + eq_sqlite( + """ + SELECT + a, b, a+1 AS c, + MIN(c) AS min_c, + MAX(c) AS max_c, + MIN(d) AS min_d, + MAX(d) AS max_d, + MIN(e) AS min_e, + MAX(e) AS max_e, + MIN(a+e) AS mix_1, + MIN(a)+MIN(e) AS mix_2 + FROM a GROUP BY a, b + """, + a=a, + ) + + +# TODO: Except not implemented so far +# def test_window_row_number(): +# a = make_rand_df(100, a=int, b=(float, 50)) +# eq_sqlite( +# """ +# SELECT *, +# ROW_NUMBER() OVER (ORDER BY a ASC, b DESC NULLS FIRST) AS a1, +# ROW_NUMBER() OVER (ORDER BY a ASC, b DESC NULLS LAST) AS a2, +# ROW_NUMBER() OVER (ORDER BY a ASC, b ASC NULLS FIRST) AS a3, +# ROW_NUMBER() OVER (ORDER BY a ASC, b ASC NULLS LAST) AS a4, +# ROW_NUMBER() OVER (PARTITION BY a ORDER BY a,b DESC) AS a5 +# FROM a +# """, +# a=a, +# ) + +# a = make_rand_df( +# 100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=float +# ) +# eq_sqlite( +# """ +# SELECT *, +# ROW_NUMBER() OVER (ORDER BY a ASC, b DESC NULLS FIRST, e) AS a1, +# ROW_NUMBER() OVER (ORDER BY a ASC, b DESC NULLS LAST, e) AS a2, +# ROW_NUMBER() OVER (PARTITION BY a ORDER BY a,b DESC, e) AS a3, +# ROW_NUMBER() OVER (PARTITION BY a,c ORDER BY a,b DESC, e) AS a4 +# FROM a +# """, +# a=a, +# ) + +# TODO: Except not implemented so far +# def test_window_row_number_partition_by(): +# a = make_rand_df(100, a=int, b=(float, 50)) +# eq_sqlite( +# """ +# SELECT *, +# ROW_NUMBER() OVER (PARTITION BY a ORDER BY a,b DESC) AS a5 +# FROM a +# """, +# a=a, +# ) + +# a = make_rand_df( +# 100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=float +# ) +# eq_sqlite( +# """ +# SELECT *, +# ROW_NUMBER() OVER (PARTITION BY a ORDER BY a,b DESC, e) AS a3, +# ROW_NUMBER() OVER (PARTITION BY a,c ORDER BY a,b DESC, e) AS a4 +# FROM a +# """, +# a=a, +# ) + +# TODO: Except not 
implemented so far +# def test_window_ranks(): +# a = make_rand_df(100, a=int, b=(float, 50), c=(str, 50)) +# eq_sqlite( +# """ +# SELECT *, +# RANK() OVER (PARTITION BY a ORDER BY b DESC NULLS FIRST, c) AS a1, +# DENSE_RANK() OVER (ORDER BY a ASC, b DESC NULLS LAST, c DESC) AS a2, +# PERCENT_RANK() OVER (ORDER BY a ASC, b ASC NULLS LAST, c) AS a4 +# FROM a +# """, +# a=a, +# ) + +# TODO: Except not implemented so far +# def test_window_ranks_partition_by(): +# a = make_rand_df(100, a=int, b=(float, 50), c=(str, 50)) +# eq_sqlite( +# """ +# SELECT *, +# RANK() OVER (PARTITION BY a ORDER BY b DESC NULLS FIRST, c) AS a1, +# DENSE_RANK() OVER +# (PARTITION BY a ORDER BY a ASC, b DESC NULLS LAST, c DESC) +# AS a2, +# PERCENT_RANK() OVER +# (PARTITION BY a ORDER BY a ASC, b ASC NULLS LAST, c) AS a4 +# FROM a +# """, +# a=a, +# ) + +# TODO: Except not implemented so far +# def test_window_lead_lag(): +# a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50)) +# eq_sqlite( +# """ +# SELECT +# LEAD(b,1) OVER (ORDER BY a) AS a1, +# LEAD(b,2,10) OVER (ORDER BY a) AS a2, +# LEAD(b,1) OVER (PARTITION BY c ORDER BY a) AS a3, +# LEAD(b,1) OVER (PARTITION BY c ORDER BY b, a ASC NULLS LAST) AS a5, + +# LAG(b,1) OVER (ORDER BY a) AS b1, +# LAG(b,2,10) OVER (ORDER BY a) AS b2, +# LAG(b,1) OVER (PARTITION BY c ORDER BY a) AS b3, +# LAG(b,1) OVER (PARTITION BY c ORDER BY b, a ASC NULLS LAST) AS b5 +# FROM a +# """, +# a=a, +# ) + +# TODO: Except not implemented so far +# def test_window_lead_lag_partition_by(): +# a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50)) +# eq_sqlite( +# """ +# SELECT +# LEAD(b,1,10) OVER (PARTITION BY c ORDER BY a) AS a3, +# LEAD(b,1) OVER (PARTITION BY c ORDER BY b, a ASC NULLS LAST) AS a5, + +# LAG(b,1) OVER (PARTITION BY c ORDER BY a) AS b3, +# LAG(b,1) OVER (PARTITION BY c ORDER BY b, a ASC NULLS LAST) AS b5 +# FROM a +# """, +# a=a, +# ) + +# TODO: Except not implemented so far +# def test_window_sum_avg(): +# a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50)) +# for func in ["SUM", "AVG"]: +# eq_sqlite( +# f""" +# SELECT a,b, +# {func}(b) OVER () AS a1, +# {func}(b) OVER (PARTITION BY c) AS a2, +# {func}(b+a) OVER (PARTITION BY c,b) AS a3, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS a4, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS a5, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) +# AS a6 +# FROM a +# """, +# a=a, +# ) +# # >= 1.1.0 has bug on these agg function with groupby+rolloing +# # https://github.com/pandas-dev/pandas/issues/35557 +# if pd.__version__ < "1.1": +# # irregular windows +# eq_sqlite( +# f""" +# SELECT a,b, +# {func}(b) OVER (PARTITION BY b ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING) AS a6, +# {func}(b) OVER (PARTITION BY b ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING) AS a7, +# {func}(b) OVER (PARTITION BY b ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND UNBOUNDED FOLLOWING) AS a8 +# FROM a +# """, +# a=a, +# ) + +# TODO: Except not implemented so far +# def test_window_sum_avg_partition_by(): +# a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50)) +# for func in ["SUM", "AVG"]: +# eq_sqlite( +# f""" +# SELECT a,b, +# {func}(b+a) OVER (PARTITION BY c,b) AS a3, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS a4, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a DESC +# ROWS BETWEEN 2 
PRECEDING AND CURRENT ROW) AS a5, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) +# AS a6 +# FROM a +# """, +# a=a, +# ) +# # 1.1.0 has bug on these agg function with groupby+rolloing +# # https://github.com/pandas-dev/pandas/issues/35557 +# if pd.__version__ < "1.1": +# # irregular windows +# eq_sqlite( +# f""" +# SELECT a,b, +# {func}(b) OVER (PARTITION BY b ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING) AS a6, +# {func}(b) OVER (PARTITION BY b ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING) AS a7, +# {func}(b) OVER (PARTITION BY b ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND UNBOUNDED FOLLOWING) AS a8 +# FROM a +# """, +# a=a, +# ) + +# TODO: Except not implemented so far +# def test_window_min_max(): +# for func in ["MIN", "MAX"]: +# a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50)) +# eq_sqlite( +# f""" +# SELECT a,b, +# {func}(b) OVER () AS a1, +# {func}(b) OVER (PARTITION BY c) AS a2, +# {func}(b+a) OVER (PARTITION BY c,b) AS a3, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS a4, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS a5, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) +# AS a6 +# FROM a +# """, +# a=a, +# ) +# # < 1.1.0 has bugs on these agg function with rolloing (no group by) +# if pd.__version__ >= "1.1": +# # irregular windows +# eq_sqlite( +# f""" +# SELECT a,b, +# {func}(b) OVER (ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING) AS a6, +# {func}(b) OVER (ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING) AS a7, +# {func}(b) OVER (ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND UNBOUNDED FOLLOWING) AS a8 +# FROM a +# """, +# a=a, +# ) +# # == 1.1.0 has bugs on these agg function with rolloing (with group by) +# # https://github.com/pandas-dev/pandas/issues/35557 +# # < 1.1.0 has bugs on nulls when rolling with forward looking +# if pd.__version__ < "1.1": +# b = make_rand_df(10, a=float, b=(int, 0), c=(str, 0)) +# eq_sqlite( +# f""" +# SELECT a,b, +# {func}(b) OVER (PARTITION BY b ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING) AS a6 +# FROM a +# """, +# a=b, +# ) + +# TODO: Except not implemented so far +# def test_window_min_max_partition_by(): +# for func in ["MIN", "MAX"]: +# a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50)) +# eq_sqlite( +# f""" +# SELECT a,b, +# {func}(b) OVER (PARTITION BY c) AS a2, +# {func}(b+a) OVER (PARTITION BY c,b) AS a3, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS a4, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS a5, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) +# AS a6 +# FROM a +# """, +# a=a, +# ) +# # >= 1.1.0 has bugs on these agg function with rolloing (with group by) +# # https://github.com/pandas-dev/pandas/issues/35557 +# # < 1.1.0 has bugs on nulls when rolling with forward looking +# if pd.__version__ < "1.1": +# b = make_rand_df(10, a=float, b=(int, 0), c=(str, 0)) +# eq_sqlite( +# f""" +# SELECT a,b, +# {func}(b) OVER (PARTITION BY b ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING) AS a6 +# FROM a +# """, +# a=b, +# ) + +# TODO: Except not implemented so far +# def test_window_count(): +# for func in ["COUNT"]: +# a = make_rand_df(100, 
a=float, b=(int, 50), c=(str, 50)) +# eq_sqlite( +# f""" +# SELECT a,b, +# {func}(b) OVER () AS a1, +# {func}(b) OVER (PARTITION BY c) AS a2, +# {func}(b+a) OVER (PARTITION BY c,b) AS a3, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS a4, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS a5, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) +# AS a6, + +# {func}(c) OVER () AS b1, +# {func}(c) OVER (PARTITION BY c) AS b2, +# {func}(c) OVER (PARTITION BY c,b) AS b3, +# {func}(c) OVER (PARTITION BY b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS b4, +# {func}(c) OVER (PARTITION BY b ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS b5, +# {func}(c) OVER (PARTITION BY b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) +# AS b6 +# FROM a +# """, +# a=a, +# ) +# # < 1.1.0 has bugs on these agg function with rolloing (no group by) +# # == 1.1.0 has this bug +# # https://github.com/pandas-dev/pandas/issues/35579 +# if pd.__version__ >= "1.1": +# # irregular windows +# eq_sqlite( +# f""" +# SELECT a,b, +# {func}(b) OVER (ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND 0 PRECEDING) AS a6, +# {func}(b) OVER (PARTITION BY c ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND 0 PRECEDING) AS a9, + +# {func}(c) OVER (ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND 0 PRECEDING) AS b6, +# {func}(c) OVER (PARTITION BY c ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND 0 PRECEDING) AS b9 +# FROM a +# """, +# a=a, +# ) + +# TODO: Except not implemented so far +# def test_window_count_partition_by(): +# for func in ["COUNT"]: +# a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50)) +# eq_sqlite( +# f""" +# SELECT a,b, +# {func}(b) OVER (PARTITION BY c) AS a2, +# {func}(b+a) OVER (PARTITION BY c,b) AS a3, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS a4, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS a5, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) +# AS a6, + +# {func}(c) OVER (PARTITION BY c) AS b2, +# {func}(c) OVER (PARTITION BY c,b) AS b3, +# {func}(c) OVER (PARTITION BY b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS b4, +# {func}(c) OVER (PARTITION BY b ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS b5, +# {func}(c) OVER (PARTITION BY b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) +# AS b6 +# FROM a +# """, +# a=a, +# ) +# # < 1.1.0 has bugs on these agg function with rolloing (no group by) +# # == 1.1.0 has this bug +# # https://github.com/pandas-dev/pandas/issues/35579 +# if pd.__version__ >= "1.1": +# # irregular windows +# eq_sqlite( +# f""" +# SELECT a,b, +# {func}(b) OVER (PARTITION BY c ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND 0 PRECEDING) AS a9, + +# {func}(c) OVER (PARTITION BY c ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND 0 PRECEDING) AS b9 +# FROM a +# """, +# a=a, +# ) + +# TODO: Windowing not implemented so far +# def test_nested_query(): +# a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50)) +# eq_sqlite( +# """ +# SELECT * FROM ( +# SELECT *, +# ROW_NUMBER() OVER (PARTITION BY c ORDER BY b, a ASC NULLS LAST) AS r +# FROM a) +# WHERE r=1 +# """, +# a=a, +# ) + + +def test_union(): + a = make_rand_df(30, b=(int, 10), c=(str, 
10)) + b = make_rand_df(80, b=(int, 50), c=(str, 50)) + c = make_rand_df(100, b=(int, 50), c=(str, 50)) + eq_sqlite( + """ + SELECT * FROM a + UNION SELECT * FROM b + UNION SELECT * FROM c + ORDER BY b NULLS FIRST, c NULLS FIRST + """, + a=a, + b=b, + c=c, + ) + eq_sqlite( + """ + SELECT * FROM a + UNION ALL SELECT * FROM b + UNION ALL SELECT * FROM c + ORDER BY b NULLS FIRST, c NULLS FIRST + """, + a=a, + b=b, + c=c, + ) + + +# TODO: Except not implemented so far +# def test_except(): +# a = make_rand_df(30, b=(int, 10), c=(str, 10)) +# b = make_rand_df(80, b=(int, 50), c=(str, 50)) +# c = make_rand_df(100, b=(int, 50), c=(str, 50)) +# eq_sqlite( +# """ +# SELECT * FROM c +# EXCEPT SELECT * FROM b +# EXCEPT SELECT * FROM c +# """, +# a=a, +# b=b, +# c=c, +# ) + +# TODO: Intersect not implemented so far +# def test_intersect(): +# a = make_rand_df(30, b=(int, 10), c=(str, 10)) +# b = make_rand_df(80, b=(int, 50), c=(str, 50)) +# c = make_rand_df(100, b=(int, 50), c=(str, 50)) +# eq_sqlite( +# """ +# SELECT * FROM c +# INTERSECT SELECT * FROM b +# INTERSECT SELECT * FROM c +# """, +# a=a, +# b=b, +# c=c, +# ) + + +def test_with(): + a = make_rand_df(30, a=(int, 10), b=(str, 10)) + b = make_rand_df(80, ax=(int, 10), bx=(str, 10)) + eq_sqlite( + """ + WITH + aa AS ( + SELECT a AS aa, b AS bb FROM a + ), + c AS ( + SELECT aa-1 AS aa, bb FROM aa + ) + SELECT * FROM c UNION SELECT * FROM b + ORDER BY aa NULLS FIRST, bb NULLS FIRST + """, + a=a, + b=b, + ) + + +def test_integration_1(): + a = make_rand_df(100, a=int, b=str, c=float, d=int, e=bool, f=str, g=str, h=float) + eq_sqlite( + """ + WITH + a1 AS ( + SELECT a+1 AS a, b, c FROM a + ), + a2 AS ( + SELECT a,MAX(b) AS b_max, AVG(c) AS c_avg FROM a GROUP BY a + ), + a3 AS ( + SELECT d+2 AS d, f, g, h FROM a WHERE e + ) + SELECT a1.a,b,c,b_max,c_avg,f,g,h FROM a1 + INNER JOIN a2 ON a1.a=a2.a + LEFT JOIN a3 ON a1.a=a3.d + ORDER BY a1.a NULLS FIRST, b NULLS FIRST, c NULLS FIRST, f NULLS FIRST, g NULLS FIRST, h NULLS FIRST + """, + a=a, + ) diff --git a/tests/integration/test_except.py b/tests/integration/test_except.py new file mode 100644 index 000000000..9fe98131a --- /dev/null +++ b/tests/integration/test_except.py @@ -0,0 +1,29 @@ +def test_except_empty(c, df): + result_df = c.sql( + """ + SELECT * FROM df + EXCEPT + SELECT * FROM df + """ + ) + result_df = result_df.compute() + assert len(result_df) == 0 + + +def test_except_non_empty(c, df): + result_df = c.sql( + """ + ( + SELECT 1 as "a" + UNION + SELECT 2 as "a" + UNION + SELECT 3 as "a" + ) + EXCEPT + SELECT 2 as "a" + """ + ) + result_df = result_df.compute() + assert result_df.columns == "a" + assert set(result_df["a"]) == set([1, 3]) diff --git a/tests/integration/test_groupby.py b/tests/integration/test_groupby.py index 76d35939d..d305d6b06 100644 --- a/tests/integration/test_groupby.py +++ b/tests/integration/test_groupby.py @@ -127,7 +127,7 @@ def test_group_by_nan(c): ) df = df.compute() - expected_df = pd.DataFrame({"c": [3, 1]}) + expected_df = pd.DataFrame({"c": [3, float("nan"), 1]}) # The dtype in pandas 1.0.5 and pandas 1.1.0 are different, so # we can not check here assert_frame_equal(df, expected_df, check_dtype=False) @@ -206,3 +206,16 @@ def test_aggregations(c): } ) assert_frame_equal(df.sort_values("user_id").reset_index(drop=True), expected_df) + + df = c.sql( + """ + SELECT + MAX(a) AS "max", + MIN(a) AS "min" + FROM string_table + """ + ) + df = df.compute() + + expected_df = pd.DataFrame({"max": ["a normal string"], "min": ["%_%"]}) + 
assert_frame_equal(df.reset_index(drop=True), expected_df) diff --git a/tests/integration/test_sort.py b/tests/integration/test_sort.py index 5d8c83807..ffec77a84 100644 --- a/tests/integration/test_sort.py +++ b/tests/integration/test_sort.py @@ -228,6 +228,29 @@ def test_sort_with_nan_more_columns(): ) +def test_sort_with_nan_many_partitions(): + c = Context() + df = pd.DataFrame({"a": [float("nan"), 1] * 30, "b": [1, 2, 3] * 20,}) + c.create_table("df", dd.from_pandas(df, npartitions=10)) + + df_result = ( + c.sql("SELECT * FROM df ORDER BY a NULLS FIRST, b ASC NULLS FIRST") + .compute() + .reset_index(drop=True) + ) + + assert_frame_equal( + df_result, + pd.DataFrame( + { + "a": [float("nan")] * 30 + [1] * 30, + "b": [1] * 10 + [2] * 10 + [3] * 10 + [1] * 10 + [2] * 10 + [3] * 10, + } + ), + check_names=False, + ) + + def test_sort_strings(c): string_table = pd.DataFrame({"a": ["zzhsd", "öfjdf", "baba"]}) c.create_table("string_table", string_table)
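
The ``docs/pages/sql.rst`` change above describes the difference between ``.head(N)`` and ``LIMIT``. A minimal sketch of that behaviour follows; it is an illustration only (the table name, data and partitioning are made up) and not part of the patch:

.. code-block:: python

    import dask.dataframe as dd
    import pandas as pd

    from dask_sql import Context

    c = Context()
    # an example table split over several partitions
    ddf = dd.from_pandas(pd.DataFrame({"a": list(range(100))}), npartitions=10)
    c.create_table("df", ddf)

    # .head(N) only looks at the first partition: it is cheap, but it can
    # return fewer than N rows if that partition is smaller than N
    lazy_result = c.sql("SELECT * FROM df")
    first_rows = lazy_result.head(10)

    # LIMIT always returns the first N rows, no matter how many partitions
    # they are scattered over, but it already computes the first partition
    # to decide whether the remaining partitions are needed at all
    limited = c.sql("SELECT * FROM df LIMIT 10").compute()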
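
The updated expectation in ``test_group_by_nan`` matches the new documentation note on null handling: ``dask-sql`` keeps a group for NULL/NaN keys, while a plain pandas/Dask ``groupby`` drops them by default. A small, hypothetical illustration (column names invented):

.. code-block:: python

    import pandas as pd

    from dask_sql import Context

    df = pd.DataFrame({"c": [3, float("nan"), 1], "d": [1, 2, 3]})

    # pandas drops the NaN key by default -> two groups
    print(df.groupby("c")["d"].count())

    # dask-sql follows SQL semantics and keeps a NULL group -> three rows
    c = Context()
    c.create_table("df", df)
    print(c.sql("SELECT c, COUNT(d) AS cnt FROM df GROUP BY c").compute())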
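
The new ``DaskTable`` constructor carrying an SQL string, the ``DaskCalciteMaterializer`` and the ``MATERIALIZATION_RULES`` based HepPlanner together enable view-based query rewriting. The sketch below shows how registering a materialized view might look from Python; it assumes that ``Context.create_table`` forwards an ``sql=`` keyword to the Java ``DaskTable`` and that the optimizer actually invokes the view-based planner, so treat it as a sketch of the intent rather than the final API:

.. code-block:: python

    import dask.dataframe as dd
    import pandas as pd

    from dask_sql import Context

    c = Context()

    # base table
    base = pd.DataFrame({"x": list(range(100)), "y": list(range(100))})
    c.create_table("timeseries", dd.from_pandas(base, npartitions=4))

    # a dataframe holding the precomputed result of the aggregation below
    agg = base.groupby("x", as_index=False).agg(sum_y=("y", "sum"))

    # assumption: the ``sql=`` keyword marks the table as a materialized
    # view, so the view-based planner may rewrite matching queries to a
    # scan of "agg" instead of re-running the aggregation
    c.create_table(
        "agg",
        dd.from_pandas(agg, npartitions=1),
        sql="SELECT x, SUM(y) AS sum_y FROM timeseries GROUP BY x",
    )

    # a query that could then be answered from the materialization
    result = c.sql("SELECT x, SUM(y) AS sum_y FROM timeseries GROUP BY x").compute()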