Skip to content

Commit

Permalink
Merge latest from main.
Browse files Browse the repository at this point in the history
  • Loading branch information
kgpai committed Apr 1, 2024
1 parent 73a7f71 commit 3a461c0
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 1 deletion.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ if(${VELOX_BUILD_PYTHON_PACKAGE})
set(VELOX_ENABLE_EXPRESSION ON)
set(VELOX_ENABLE_PARSE ON)
set(VELOX_ENABLE_EXEC ON)
set(VELOX_ENABLE_AGGREGATES ON)
set(VELOX_ENABLE_SPARK_FUNCTIONS ON)
endif()

Expand Down
4 changes: 3 additions & 1 deletion pyvelox/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ if(VELOX_BUILD_PYTHON_PACKAGE)
velox_functions_prestosql
velox_parse_parser
velox_functions_prestosql
velox_functions_spark)
velox_functions_spark
velox_aggregates
velox_functions_spark_aggregates)

target_include_directories(pyvelox SYSTEM
PRIVATE ${CMAKE_CURRENT_LIST_DIR}/..)
Expand Down
57 changes: 57 additions & 0 deletions pyvelox/signatures.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
*/

#include "signatures.h" // @manual
#include "velox/exec/Aggregate.h"
#include "velox/functions/FunctionRegistry.h"
#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h"
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"
#include "velox/functions/sparksql/Register.h"
#include "velox/functions/sparksql/aggregates/Register.h"

namespace facebook::velox::py {

Expand All @@ -31,6 +34,24 @@ void registerSparkFunctions(const std::string& prefix) {
facebook::velox::functions::sparksql::registerFunctions(prefix);
}

void registerPrestoAggregateFunctions(const std::string& prefix) {
facebook::velox::aggregate::prestosql::registerAllAggregateFunctions(prefix);
}

void registerSparkAggregateFunctions(const std::string& prefix) {
facebook::velox::functions::aggregate::sparksql::registerAggregateFunctions(
prefix);
}

exec::AggregateFunctionSignatureMap getAggregateSignatures() {
return exec::getAggregateFunctionSignatures();
}

void clearAggregateSignatures() {
exec::aggregateFunctions().withWLock(
[&](auto& aggregateFunctions) { aggregateFunctions.clear(); });
}

void addSignatureBindings(py::module& m, bool asModuleLocalDefinitions) {
// TypeSignature
py::class_<exec::TypeSignature> typeSignature(
Expand All @@ -53,6 +74,19 @@ void addSignatureBindings(py::module& m, bool asModuleLocalDefinitions) {
functionSignature.def(
"constant_arguments", &exec::FunctionSignature::constantArguments);

// AggregateFunctionSignature
py::class_<
exec::AggregateFunctionSignature,
std::unique_ptr<exec::AggregateFunctionSignature, py::nodelete>>
aggregateFunctionSignature(
m,
"AggregateFunctionSignature",
py::module_local(asModuleLocalDefinitions));
aggregateFunctionSignature.def(
"__str__", &exec::AggregateFunctionSignature::toString);
aggregateFunctionSignature.def(
"intermediate_type", &exec::AggregateFunctionSignature::intermediateType);

m.def(
"clear_signatures",
&clearFunctionRegistry,
Expand All @@ -75,5 +109,28 @@ void addSignatureBindings(py::module& m, bool asModuleLocalDefinitions) {
&getFunctionSignatures,
py::return_value_policy::reference,
"Returns a dictionary of the current signatures.");

m.def(
"register_presto_aggregate_signatures",
&registerPrestoAggregateFunctions,
"Adds Presto Aggregate signatures to the function registry.",
py::arg("prefix") = "");

m.def(
"register_spark_aggregate_signatures",
&registerSparkAggregateFunctions,
"Adds Spark Aggregate signatures to the function registry.",
py::arg("prefix") = "");

m.def(
"get_aggregate_function_signatures",
&getAggregateSignatures,
py::return_value_policy::reference,
"Returns a dictionary of the current aggregate signatures.");

m.def(
"clear_aggregate_signatures",
&clearAggregateSignatures,
"Clears the Aggregate function registry.");
}
} // namespace facebook::velox::py
18 changes: 18 additions & 0 deletions pyvelox/test/test_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,21 @@ def test_function_prefix(self):

concat_signatures = spark_signatures["barconcat"]
self.assertTrue(len(concat_signatures) > 0)

def test_aggregate_signatures(self):
pv.clear_aggregate_signatures()

pv.register_presto_aggregate_signatures()
presto_agg_signatures = pv.get_aggregate_function_signatures()

min_signatures = presto_agg_signatures["min"]
self.assertTrue(len(min_signatures) > 0)

max_signatures = presto_agg_signatures["max"]
self.assertTrue(len(max_signatures) > 0)

pv.clear_aggregate_signatures()

pv.register_spark_aggregate_signatures()
spark_agg_signatures = pv.get_aggregate_function_signatures()
self.assertTrue(len(spark_agg_signatures) > 0)
29 changes: 29 additions & 0 deletions scripts/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,28 @@ def export(args):
return 0


def export_aggregates(args):
"""Exports Velox Aggregate function signatures."""
pv.clear_aggregate_signatures()

if args.spark:
pv.register_spark_aggregate_signatures()

if args.presto:
pv.register_presto_aggregate_signatures()

signatures = pv.get_aggregate_function_signatures()

# Convert signatures to json
jsoned_signatures = {}
for key in signatures.keys():
jsoned_signatures[key] = [str(value) for value in signatures[key]]

# Persist to file
json.dump(jsoned_signatures, args.output_file)
return 0


def diff_signatures(base_signatures, contender_signatures, error_path=""):
"""Diffs Velox function signatures. Returns a tuple of the delta diff and exit status"""

Expand Down Expand Up @@ -253,6 +275,13 @@ def parse_args(args):
export_command_parser.add_argument("--presto", action="store_true")
export_command_parser.add_argument("output_file", type=str)

export_aggregates_command_parser = command.add_parser("export_aggregates")
export_aggregates_command_parser.add_argument("--spark", action="store_true")
export_aggregates_command_parser.add_argument("--presto", action="store_true")
export_aggregates_command_parser.add_argument(
"output_file", type=argparse.FileType("w")
)

diff_command_parser = command.add_parser("diff")
diff_command_parser.add_argument("base", type=str)
diff_command_parser.add_argument("contender", type=str)
Expand Down

0 comments on commit 3a461c0

Please sign in to comment.