Expose Velox aggregate function signatures through pyvelox.

Summary of what this patch changes:
  * CMakeLists.txt: sets VELOX_ENABLE_AGGREGATES to ON inside the
    VELOX_BUILD_PYTHON_PACKAGE branch.
  * pyvelox/CMakeLists.txt: adds velox_aggregates and
    velox_functions_spark_aggregates to the pyvelox link list.
  * pyvelox/signatures.cpp: adds helpers that register the Presto / Spark
    aggregate functions, return the aggregate signature map, and clear the
    aggregate registry under its write lock; binds
    exec::AggregateFunctionSignature and the four new module functions.
  * pyvelox/test/test_signatures.py: adds test_aggregate_signatures.
  * scripts/signature.py: adds export_aggregates() and the matching
    "export_aggregates" argparse sub-command.

NOTE(review): this text was recovered from a copy in which newlines were
collapsed, "&register..." had been mis-decoded as a registered-trademark sign,
and two template argument lists had been stripped. The restored arguments
(py::class_<exec::TypeSignature> and
std::unique_ptr<exec::AggregateFunctionSignature, py::nodelete>) and the
leading whitespace of the CMake/C++ lines are reconstructions; confirm them
against the tree (or apply with --ignore-whitespace) before relying on them.

diff --git a/CMakeLists.txt b/CMakeLists.txt
index d3eb572c90a35..597352c662301 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -173,7 +173,7 @@ if(${VELOX_BUILD_PYTHON_PACKAGE})
   set(VELOX_ENABLE_EXPRESSION ON)
   set(VELOX_ENABLE_PARSE ON)
   set(VELOX_ENABLE_EXEC ON)
-  set(VELOX_ENABLE_AGGREGATES OFF)
+  set(VELOX_ENABLE_AGGREGATES ON)
   set(VELOX_ENABLE_HIVE_CONNECTOR OFF)
   set(VELOX_ENABLE_TPCH_CONNECTOR OFF)
   set(VELOX_ENABLE_SPARK_FUNCTIONS ON)
diff --git a/pyvelox/CMakeLists.txt b/pyvelox/CMakeLists.txt
index 4bffa203b7d49..d0c2d04c932f4 100644
--- a/pyvelox/CMakeLists.txt
+++ b/pyvelox/CMakeLists.txt
@@ -35,7 +35,9 @@ if(VELOX_BUILD_PYTHON_PACKAGE)
     velox_functions_prestosql
     velox_parse_parser
     velox_functions_prestosql
-    velox_functions_spark)
+    velox_functions_spark
+    velox_aggregates
+    velox_functions_spark_aggregates)
 
   install(TARGETS pyvelox LIBRARY DESTINATION .)
 else()
diff --git a/pyvelox/signatures.cpp b/pyvelox/signatures.cpp
index c12912f759040..27b5674f6557f 100644
--- a/pyvelox/signatures.cpp
+++ b/pyvelox/signatures.cpp
@@ -15,9 +15,12 @@
  */
 
 #include "signatures.h" // @manual
+#include "velox/exec/Aggregate.h"
 #include "velox/functions/FunctionRegistry.h"
+#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h"
 #include "velox/functions/prestosql/registration/RegistrationFunctions.h"
 #include "velox/functions/sparksql/Register.h"
+#include "velox/functions/sparksql/aggregates/Register.h"
 
 namespace facebook::velox::py {
 
@@ -31,6 +34,24 @@ void registerSparkFunctions(const std::string& prefix) {
   facebook::velox::functions::sparksql::registerFunctions(prefix);
 }
 
+void registerPrestoAggregateFunctions(const std::string& prefix) {
+  facebook::velox::aggregate::prestosql::registerAllAggregateFunctions(prefix);
+}
+
+void registerSparkAggregateFunctions(const std::string& prefix) {
+  facebook::velox::functions::aggregate::sparksql::registerAggregateFunctions(
+      prefix);
+}
+
+exec::AggregateFunctionSignatureMap getAggregateSignatures() {
+  return exec::getAggregateFunctionSignatures();
+}
+
+void clearAggregateSignatures() {
+  exec::aggregateFunctions().withWLock(
+      [&](auto& aggregateFunctions) { aggregateFunctions.clear(); });
+}
+
 void addSignatureBindings(py::module& m, bool asModuleLocalDefinitions) {
   // TypeSignature
   py::class_<exec::TypeSignature> typeSignature(
@@ -53,6 +74,19 @@ void addSignatureBindings(py::module& m, bool asModuleLocalDefinitions) {
   functionSignature.def(
       "constant_arguments", &exec::FunctionSignature::constantArguments);
 
+  // AggregateFunctionSignature
+  py::class_<
+      exec::AggregateFunctionSignature,
+      std::unique_ptr<exec::AggregateFunctionSignature, py::nodelete>>
+      aggregateFunctionSignature(
+          m,
+          "AggregateFunctionSignature",
+          py::module_local(asModuleLocalDefinitions));
+  aggregateFunctionSignature.def(
+      "__str__", &exec::AggregateFunctionSignature::toString);
+  aggregateFunctionSignature.def(
+      "intermediate_type", &exec::AggregateFunctionSignature::intermediateType);
+
   m.def(
       "clear_signatures",
       &clearFunctionRegistry,
@@ -75,5 +109,28 @@ void addSignatureBindings(py::module& m, bool asModuleLocalDefinitions) {
       &getFunctionSignatures,
       py::return_value_policy::reference,
       "Returns a dictionary of the current signatures.");
+
+  m.def(
+      "register_presto_aggregate_signatures",
+      &registerPrestoAggregateFunctions,
+      "Adds Presto Aggregate signatures to the function registry.",
+      py::arg("prefix") = "");
+
+  m.def(
+      "register_spark_aggregate_signatures",
+      &registerSparkAggregateFunctions,
+      "Adds Spark Aggregate signatures to the function registry.",
+      py::arg("prefix") = "");
+
+  m.def(
+      "get_aggregate_function_signatures",
+      &getAggregateSignatures,
+      py::return_value_policy::reference,
+      "Returns a dictionary of the current aggregate signatures.");
+
+  m.def(
+      "clear_aggregate_signatures",
+      &clearAggregateSignatures,
+      "Clears the Aggregate function registry.");
 }
 } // namespace facebook::velox::py
diff --git a/pyvelox/test/test_signatures.py b/pyvelox/test/test_signatures.py
index c8ce8354e0d07..dddb2869537ae 100644
--- a/pyvelox/test/test_signatures.py
+++ b/pyvelox/test/test_signatures.py
@@ -64,3 +64,21 @@ def test_function_prefix(self):
 
         concat_signatures = spark_signatures["barconcat"]
         self.assertTrue(len(concat_signatures) > 0)
+
+    def test_aggregate_signatures(self):
+        pv.clear_aggregate_signatures()
+
+        pv.register_presto_aggregate_signatures()
+        presto_agg_signatures = pv.get_aggregate_function_signatures()
+
+        min_signatures = presto_agg_signatures["min"]
+        self.assertTrue(len(min_signatures) > 0)
+
+        max_signatures = presto_agg_signatures["max"]
+        self.assertTrue(len(max_signatures) > 0)
+
+        pv.clear_aggregate_signatures()
+
+        pv.register_spark_aggregate_signatures()
+        spark_agg_signatures = pv.get_aggregate_function_signatures()
+        self.assertTrue(len(spark_agg_signatures) > 0)
diff --git a/scripts/signature.py b/scripts/signature.py
index 17aa0b33c87e5..190a527d631ed 100644
--- a/scripts/signature.py
+++ b/scripts/signature.py
@@ -63,6 +63,28 @@ def export(args):
     return 0
 
 
+def export_aggregates(args):
+    """Exports Velox Aggregate function signatures."""
+    pv.clear_aggregate_signatures()
+
+    if args.spark:
+        pv.register_spark_aggregate_signatures()
+
+    if args.presto:
+        pv.register_presto_aggregate_signatures()
+
+    signatures = pv.get_aggregate_function_signatures()
+
+    # Convert signatures to json
+    jsoned_signatures = {}
+    for key in signatures.keys():
+        jsoned_signatures[key] = [str(value) for value in signatures[key]]
+
+    # Persist to file
+    json.dump(jsoned_signatures, args.output_file)
+    return 0
+
+
 def diff_signatures(base_signatures, contender_signatures):
     """Diffs Velox function signatures.
     Returns a tuple of the delta diff and exit status"""
@@ -177,6 +199,13 @@ def parse_args(args):
     export_command_parser.add_argument("--presto", action="store_true")
     export_command_parser.add_argument("output_file", type=argparse.FileType("w"))
 
+    export_aggregates_command_parser = command.add_parser("export_aggregates")
+    export_aggregates_command_parser.add_argument("--spark", action="store_true")
+    export_aggregates_command_parser.add_argument("--presto", action="store_true")
+    export_aggregates_command_parser.add_argument(
+        "output_file", type=argparse.FileType("w")
+    )
+
     diff_command_parser = command.add_parser("diff")
     diff_command_parser.add_argument("base", type=argparse.FileType("r"))
     diff_command_parser.add_argument("contender", type=argparse.FileType("r"))