From 702e9cae138e8d372643e3545375d8fc4c6212a2 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Tue, 22 Aug 2023 09:27:38 -0700 Subject: [PATCH] GH-33985: [C++] Add substrait serialization/deserialization for expressions (#34834) ### Rationale for this change Substrait provides a library-independent way to represent compute expressions. By serializing and deserializing pyarrow compute expression to substrait we can allow interoperability with other libraries. Originally it was thought this would not be needed because users would be sending entire query plans (which contain expressions) back and forth and so there was no need to work with expressions by themselves. However, as more and more APIs and integration points emerge it turns out there are situations where serializing expressions by themselves is useful. For example, the proposed datasets protocol, or for the Java JNI datasets implementation (which uses Arrow-C++'s datasets) ### What changes are included in this PR? In Arrow-C++ we add two new methods to serialize and deserialize a collection of named, bound expressions to Substrait's ExtendedExpression message. In pyarrow we expose these two methods and also add utility methods to pyarrow.compute.Expression to convert a single expression to/from substrait (these will be encoded as an ExtendedExpression message with one expression named "expression") In addition, this PR exposed that we do not have very many bindings for arrow-functions to substrait-functions (previous work has mostly focused on the reverse direction). This PR adds many (though not all) new bindings. In addition, this PR adds ToProto for cast and both FromProto and ToProto support for the SingularOrList expression type (we convert is_in to SingularOrList and convert SingularOrList to an or list). This should provide support for all the sargable operators except between (there is no Arrow-C++ between function) and like (we still don't have arrow->substrait bindings for the string functions) which should be a sufficient set of expressions for a first release. ### Are these changes tested? Yes. ### Are there any user-facing changes? There are new features, as described above, but no backwards incompatible changes. ### Caveats There are a fair number of minor inconsistencies or surprises, many of which can be smoothed over by follow-up work. #### Bound Expressions Arrow-C++ has long had a distinction between "unbound expressions" (e.g. `a + b`) and "bound expressions" (e.g. `a:i32 + b:i32`). A bound expression is an expression that has been bound to a schema of some kind. Field references are resolved and the output type is known for every node of the AST. Pyarrow has hidden this complexity and most pyarrow compute expressions that the user encounters will be unbound expressions. Substrait is only capable (currently) of representing bound expressions. As a result, in order to serialize expressions, the user will need to provide an input schema. This can be an inconvenience for some workflows. To resolve this, I would like to eventually add support for unbound expressions to Substrait (https://github.com/substrait-io/substrait/issues/515) Another minor annoyance of bound expressions is that an unbound pyarrow.compute.Expression object will not be equal to a bound pyarrow.compute.Expression object. It would make testing easier if we had a `pyarrow.compute.Expression.equals` variant that did not examine bound fields. #### Named field references Pyarrow datasets users are used to working with named field references. For example, one can set a filter `pc.equal(ds.field("x"), 7)`. Substrait, since it requires everything to be bound, considers named references to be superfluous and does everything in terms of numeric indices into the base schema. So the above expression, after round tripping, would become something like `pc.equal(ds.field(3), 7)` (assuming `"x"` is at index `3` in the schema used for serialization). This is something that can be overcome in the future if Substrait adds support for unbound expressions. Or, if that doesn't happen, it could still be implemented as a Substrait expression hint (this would allow named references to be used even if the user wants to work with bound expressions). #### UDFs UDFs ARE supported by this PR. This covers both "builtin arrow functions that do not exist in substrait (e.g. shift_left)" and "custom UDFs added with `register_scalar_function`". By default, UDFs will not be allowed when converting to Substrait because the resulting message would not be portable (e.g. you can't expect an external system to know about your custom UDFs). However, you can set the `allow_udfs` flag to True and these will be allowed. The Substrait representation will have the URI `urn:arrow:substrait_simple_extension_function`. **Options**: Although UDFs are allowed we do not yet support UDFs that take function options. These are trickier to convert to Substrait (though it should be possible in the future if someone is motivated enough). #### Rough Edges There are a few corner cases: * The function `is_in` converts to Substrait's `SingularOrList`. On conversion back to Arrow this becomes an or list. In other words, the function `is_in(5, [1, 2, 5])` converts to `5 == 1 || 5 == 2 || 5 == 5`. This is because Substrait's or list is more expression and allows things like `5 == field_ref(0) || 5 == 7` which cannot be expressed as an `is_in` function. * Arrow functions can either be converted to Substrait or are considered UDFs. However, there are a small number of functions which can "sometimes" be converted to Substrait depending on the function options. At the moment I think this is only the `is_null` function. The `is_null` function has an option `nan_is_null` which will allow you to consider `NaN` as a null value. Substrait has no single function that evaluates both `NULL` and `NaN` as true. In the meantime you can use `is_null || is_nan`. In the future, should someone want to, they could add special logic to convert this case. * Closes: #33985 Lead-authored-by: Weston Pace Co-authored-by: Joris Van den Bossche Signed-off-by: Weston Pace --- cpp/cmake_modules/ThirdpartyToolchain.cmake | 7 +- cpp/src/arrow/compute/cast.cc | 10 + cpp/src/arrow/compute/cast.h | 9 + cpp/src/arrow/engine/CMakeLists.txt | 1 + .../engine/substrait/expression_internal.cc | 162 +++++++++++++- .../substrait/extended_expression_internal.cc | 210 ++++++++++++++++++ .../substrait/extended_expression_internal.h | 51 +++++ .../arrow/engine/substrait/extension_set.cc | 112 ++++++++-- .../arrow/engine/substrait/extension_set.h | 12 + cpp/src/arrow/engine/substrait/options.h | 8 +- .../arrow/engine/substrait/plan_internal.cc | 114 +--------- cpp/src/arrow/engine/substrait/relation.h | 17 ++ cpp/src/arrow/engine/substrait/serde.cc | 42 ++-- cpp/src/arrow/engine/substrait/serde.h | 28 +++ cpp/src/arrow/engine/substrait/serde_test.cc | 95 ++++++++ cpp/src/arrow/engine/substrait/util.cc | 10 + cpp/src/arrow/engine/substrait/util.h | 10 + .../arrow/engine/substrait/util_internal.cc | 12 + .../arrow/engine/substrait/util_internal.h | 114 ++++++++++ cpp/thirdparty/versions.txt | 4 +- docs/source/python/api/substrait.rst | 13 ++ python/pyarrow/_compute.pyx | 69 ++++++ python/pyarrow/_substrait.pyx | 141 +++++++++++- .../pyarrow/includes/libarrow_substrait.pxd | 18 ++ python/pyarrow/substrait.py | 3 + python/pyarrow/tests/test_compute.py | 142 +++++++++++- python/pyarrow/tests/test_substrait.py | 107 ++++++++- 27 files changed, 1363 insertions(+), 158 deletions(-) create mode 100644 cpp/src/arrow/engine/substrait/extended_expression_internal.cc create mode 100644 cpp/src/arrow/engine/substrait/extended_expression_internal.h diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index 4422e17e85cb6..1767c05b5ee3a 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -1785,7 +1785,12 @@ macro(build_substrait) # Note: not all protos in Substrait actually matter to plan # consumption. No need to build the ones we don't need. - set(SUBSTRAIT_PROTOS algebra extensions/extensions plan type) + set(SUBSTRAIT_PROTOS + algebra + extended_expression + extensions/extensions + plan + type) set(ARROW_SUBSTRAIT_PROTOS extension_rels) set(ARROW_SUBSTRAIT_PROTOS_DIR "${CMAKE_SOURCE_DIR}/proto") diff --git a/cpp/src/arrow/compute/cast.cc b/cpp/src/arrow/compute/cast.cc index 13bf6f85a4874..232638b7fc738 100644 --- a/cpp/src/arrow/compute/cast.cc +++ b/cpp/src/arrow/compute/cast.cc @@ -223,6 +223,16 @@ CastOptions::CastOptions(bool safe) allow_float_truncate(!safe), allow_invalid_utf8(!safe) {} +bool CastOptions::is_safe() const { + return !allow_int_overflow && !allow_time_truncate && !allow_time_overflow && + !allow_decimal_truncate && !allow_float_truncate && !allow_invalid_utf8; +} + +bool CastOptions::is_unsafe() const { + return allow_int_overflow && allow_time_truncate && allow_time_overflow && + allow_decimal_truncate && allow_float_truncate && allow_invalid_utf8; +} + constexpr char CastOptions::kTypeName[]; Result Cast(const Datum& value, const CastOptions& options, ExecContext* ctx) { diff --git a/cpp/src/arrow/compute/cast.h b/cpp/src/arrow/compute/cast.h index 7432933a12469..613e8a55addd2 100644 --- a/cpp/src/arrow/compute/cast.h +++ b/cpp/src/arrow/compute/cast.h @@ -69,6 +69,15 @@ class ARROW_EXPORT CastOptions : public FunctionOptions { // Indicate if conversions from Binary/FixedSizeBinary to string must // validate the utf8 payload. bool allow_invalid_utf8; + + /// true if the safety options all match CastOptions::Safe + /// + /// Note, if this returns false it does not mean is_unsafe will return true + bool is_safe() const; + /// true if the safety options all match CastOptions::Unsafe + /// + /// Note, if this returns false it does not mean is_safe will return true + bool is_unsafe() const; }; /// @} diff --git a/cpp/src/arrow/engine/CMakeLists.txt b/cpp/src/arrow/engine/CMakeLists.txt index 7494be8ebb19c..fcaa242b11487 100644 --- a/cpp/src/arrow/engine/CMakeLists.txt +++ b/cpp/src/arrow/engine/CMakeLists.txt @@ -21,6 +21,7 @@ arrow_install_all_headers("arrow/engine") set(ARROW_SUBSTRAIT_SRCS substrait/expression_internal.cc + substrait/extended_expression_internal.cc substrait/extension_set.cc substrait/extension_types.cc substrait/options.cc diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc index 5e214bdda4d4a..0df8425609ff1 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.cc +++ b/cpp/src/arrow/engine/substrait/expression_internal.cc @@ -41,6 +41,7 @@ #include "arrow/buffer.h" #include "arrow/builder.h" #include "arrow/compute/api_scalar.h" +#include "arrow/compute/cast.h" #include "arrow/compute/expression.h" #include "arrow/compute/expression_internal.h" #include "arrow/engine/substrait/extension_set.h" @@ -338,6 +339,22 @@ Result FromProto(const substrait::Expression& expr, return compute::call("case_when", std::move(args)); } + case substrait::Expression::kSingularOrList: { + const auto& or_list = expr.singular_or_list(); + + ARROW_ASSIGN_OR_RAISE(compute::Expression value, + FromProto(or_list.value(), ext_set, conversion_options)); + + std::vector option_eqs; + for (const auto& option : or_list.options()) { + ARROW_ASSIGN_OR_RAISE(compute::Expression arrow_option, + FromProto(option, ext_set, conversion_options)); + option_eqs.push_back(compute::call("equal", {value, arrow_option})); + } + + return compute::or_(option_eqs); + } + case substrait::Expression::kScalarFunction: { const auto& scalar_fn = expr.scalar_function(); @@ -1055,9 +1072,68 @@ Result> EncodeSubstraitCa " arguments but no argument could be found at index ", i); } } + + for (const auto& option : call.options()) { + substrait::FunctionOption* fn_option = scalar_fn->add_options(); + fn_option->set_name(option.first); + for (const auto& opt_val : option.second) { + std::string* pref = fn_option->add_preference(); + *pref = opt_val; + } + } + return std::move(scalar_fn); } +Result>> DatumToLiterals( + const Datum& datum, ExtensionSet* ext_set, + const ConversionOptions& conversion_options) { + std::vector> literals; + + auto ScalarToLiteralExpr = [&](const std::shared_ptr& scalar) + -> Result> { + ARROW_ASSIGN_OR_RAISE(std::unique_ptr literal, + ToProto(scalar, ext_set, conversion_options)); + auto literal_expr = std::make_unique(); + literal_expr->set_allocated_literal(literal.release()); + return literal_expr; + }; + + switch (datum.kind()) { + case Datum::Kind::SCALAR: { + ARROW_ASSIGN_OR_RAISE(auto literal_expr, ScalarToLiteralExpr(datum.scalar())); + literals.push_back(std::move(literal_expr)); + break; + } + case Datum::Kind::ARRAY: { + std::shared_ptr values = datum.make_array(); + for (int64_t i = 0; i < values->length(); i++) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr scalar, values->GetScalar(i)); + ARROW_ASSIGN_OR_RAISE(auto literal_expr, ScalarToLiteralExpr(scalar)); + literals.push_back(std::move(literal_expr)); + } + break; + } + case Datum::Kind::CHUNKED_ARRAY: { + std::shared_ptr values = datum.chunked_array(); + for (const auto& chunk : values->chunks()) { + for (int64_t i = 0; i < chunk->length(); i++) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr scalar, chunk->GetScalar(i)); + ARROW_ASSIGN_OR_RAISE(auto literal_expr, ScalarToLiteralExpr(scalar)); + literals.push_back(std::move(literal_expr)); + } + } + break; + } + case Datum::Kind::RECORD_BATCH: + case Datum::Kind::TABLE: + case Datum::Kind::NONE: + return Status::Invalid("Expected a literal or an array of literals, got ", + datum.ToString()); + } + return literals; +} + Result> ToProto( const compute::Expression& expr, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { @@ -1164,15 +1240,89 @@ Result> ToProto( out->set_allocated_if_then(if_then.release()); return std::move(out); + } else if (call->function_name == "cast") { + auto cast = std::make_unique(); + + // Arrow's cast function does not have a "return null" option and so throw exception + // is the only behavior we can support. + cast->set_failure_behavior( + substrait::Expression::Cast::FAILURE_BEHAVIOR_THROW_EXCEPTION); + + std::shared_ptr cast_options = + internal::checked_pointer_cast(call->options); + if (!cast_options->is_unsafe()) { + return Status::Invalid("Substrait is only capable of representing unsafe casts"); + } + + if (arguments.size() != 1) { + return Status::Invalid( + "A call to the cast function must have exactly one argument"); + } + + cast->set_allocated_input(arguments[0].release()); + + ARROW_ASSIGN_OR_RAISE(std::unique_ptr to_type, + ToProto(*cast_options->to_type.type, /*nullable=*/true, ext_set, + conversion_options)); + + cast->set_allocated_type(to_type.release()); + + out->set_allocated_cast(cast.release()); + return std::move(out); + } else if (call->function_name == "is_in") { + auto or_list = std::make_unique(); + + if (arguments.size() != 1) { + return Status::Invalid( + "A call to the is_in function must have exactly one argument"); + } + + or_list->set_allocated_value(arguments[0].release()); + std::shared_ptr is_in_options = + internal::checked_pointer_cast(call->options); + + // TODO(GH-36420) Acero does not currently handle nulls correctly + ARROW_ASSIGN_OR_RAISE( + std::vector> options, + DatumToLiterals(is_in_options->value_set, ext_set, conversion_options)); + for (auto& option : options) { + or_list->mutable_options()->AddAllocated(option.release()); + } + out->set_allocated_singular_or_list(or_list.release()); + return std::move(out); } // other expression types dive into extensions immediately - ARROW_ASSIGN_OR_RAISE( - ExtensionIdRegistry::ArrowToSubstraitCall converter, - ext_set->registry()->GetArrowToSubstraitCall(call->function_name)); - ARROW_ASSIGN_OR_RAISE(SubstraitCall substrait_call, converter(*call)); - ARROW_ASSIGN_OR_RAISE(std::unique_ptr scalar_fn, - EncodeSubstraitCall(substrait_call, ext_set, conversion_options)); + Result maybe_converter = + ext_set->registry()->GetArrowToSubstraitCall(call->function_name); + + ExtensionIdRegistry::ArrowToSubstraitCall converter; + std::unique_ptr scalar_fn; + if (maybe_converter.ok()) { + converter = *maybe_converter; + ARROW_ASSIGN_OR_RAISE(SubstraitCall substrait_call, converter(*call)); + ARROW_ASSIGN_OR_RAISE( + scalar_fn, EncodeSubstraitCall(substrait_call, ext_set, conversion_options)); + } else if (maybe_converter.status().IsNotImplemented() && + conversion_options.allow_arrow_extensions) { + if (call->options) { + return Status::NotImplemented( + "The function ", call->function_name, + " has no Substrait mapping. Arrow extensions are enabled but the call " + "contains function options and there is no current mechanism to encode those."); + } + Id persistent_id = ext_set->RegisterPlanSpecificId( + {kArrowSimpleExtensionFunctionsUri, call->function_name}); + SubstraitCall substrait_call(persistent_id, call->type.GetSharedPtr(), + /*nullable=*/true); + for (int i = 0; i < static_cast(call->arguments.size()); i++) { + substrait_call.SetValueArg(i, call->arguments[i]); + } + ARROW_ASSIGN_OR_RAISE( + scalar_fn, EncodeSubstraitCall(substrait_call, ext_set, conversion_options)); + } else { + return maybe_converter.status(); + } out->set_allocated_scalar_function(scalar_fn.release()); return std::move(out); } diff --git a/cpp/src/arrow/engine/substrait/extended_expression_internal.cc b/cpp/src/arrow/engine/substrait/extended_expression_internal.cc new file mode 100644 index 0000000000000..a6401e1d0b36d --- /dev/null +++ b/cpp/src/arrow/engine/substrait/extended_expression_internal.cc @@ -0,0 +1,210 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#include "arrow/engine/substrait/extended_expression_internal.h" + +#include "arrow/engine/substrait/expression_internal.h" +#include "arrow/engine/substrait/relation_internal.h" +#include "arrow/engine/substrait/type_internal.h" +#include "arrow/engine/substrait/util.h" +#include "arrow/engine/substrait/util_internal.h" +#include "arrow/status.h" +#include "arrow/util/iterator.h" +#include "arrow/util/string.h" + +namespace arrow { +namespace engine { + +namespace { +Result GetExtensionSetFromExtendedExpression( + const substrait::ExtendedExpression& expr, + const ConversionOptions& conversion_options, const ExtensionIdRegistry* registry) { + return GetExtensionSetFromMessage(expr, conversion_options, registry); +} + +Status AddExtensionSetToExtendedExpression(const ExtensionSet& ext_set, + substrait::ExtendedExpression* expr) { + return AddExtensionSetToMessage(ext_set, expr); +} + +Status VisitNestedFields(const DataType& type, + std::function visitor) { + if (!is_nested(type.id())) { + return Status::OK(); + } + for (const auto& field : type.fields()) { + ARROW_RETURN_NOT_OK(VisitNestedFields(*field->type(), visitor)); + ARROW_RETURN_NOT_OK(visitor(*field)); + } + return Status::OK(); +} + +Result ExpressionFromProto( + const substrait::ExpressionReference& expression, const Schema& input_schema, + const ExtensionSet& ext_set, const ConversionOptions& conversion_options, + const ExtensionIdRegistry* registry) { + NamedExpression named_expr; + switch (expression.expr_type_case()) { + case substrait::ExpressionReference::ExprTypeCase::kExpression: { + ARROW_ASSIGN_OR_RAISE( + named_expr.expression, + FromProto(expression.expression(), ext_set, conversion_options)); + break; + } + case substrait::ExpressionReference::ExprTypeCase::kMeasure: { + return Status::NotImplemented("ExtendedExpression containing aggregate functions"); + } + default: { + return Status::Invalid( + "Unrecognized substrait::ExpressionReference::ExprTypeCase: ", + expression.expr_type_case()); + } + } + + ARROW_ASSIGN_OR_RAISE(named_expr.expression, named_expr.expression.Bind(input_schema)); + const DataType& output_type = *named_expr.expression.type(); + + // An expression reference has the entire DFS tree of field names for the output type + // which is usually redundant. Then it has one extra name for the name of the + // expression which is not redundant. + // + // For example, if the base schema is [struct, i32] and the expression is + // field(0) the the extended expression output names might be ["foo", "my_expression"]. + // The "foo" is redundant but we can verify it matches and reject if it does not. + // + // The one exception is struct literals which have no field names. For example, if + // the base schema is [i32, i64] and the expression is {7, 3}_struct then the + // output type is struct and we do not know the names of the output type. + // + // TODO(weston) we could patch the names back in at this point using the output + // names field but this is rather complex and it might be easier to give names to + // struct literals in Substrait. + int output_name_idx = 0; + ARROW_RETURN_NOT_OK(VisitNestedFields(output_type, [&](const Field& field) { + if (output_name_idx >= expression.output_names_size()) { + return Status::Invalid("Ambiguous plan. Expression had ", + expression.output_names_size(), + " output names but the field in base_schema had type ", + output_type.ToString(), " which needs more output names"); + } + if (!field.name().empty() && + field.name() != expression.output_names(output_name_idx)) { + return Status::Invalid("Ambiguous plan. Expression had output type ", + output_type.ToString(), + " which contains a nested field named ", field.name(), + " but the output_names in the Substrait message contained ", + expression.output_names(output_name_idx)); + } + output_name_idx++; + return Status::OK(); + })); + // The last name is the actual field name that we can't verify but there should only + // be one extra name. + if (output_name_idx < expression.output_names_size() - 1) { + return Status::Invalid("Ambiguous plan. Expression had ", + expression.output_names_size(), + " output names but the field in base_schema had type ", + output_type.ToString(), " which doesn't have enough fields"); + } + if (expression.output_names_size() == 0) { + // This is potentially invalid substrait but we can handle it + named_expr.name = ""; + } else { + named_expr.name = expression.output_names(expression.output_names_size() - 1); + } + return named_expr; +} + +Result> CreateExpressionReference( + const std::string& name, const Expression& expr, ExtensionSet* ext_set, + const ConversionOptions& conversion_options) { + auto expr_ref = std::make_unique(); + ARROW_RETURN_NOT_OK(VisitNestedFields(*expr.type(), [&](const Field& field) { + expr_ref->add_output_names(field.name()); + return Status::OK(); + })); + expr_ref->add_output_names(name); + ARROW_ASSIGN_OR_RAISE(std::unique_ptr expression, + ToProto(expr, ext_set, conversion_options)); + expr_ref->set_allocated_expression(expression.release()); + return std::move(expr_ref); +} + +} // namespace + +Result FromProto(const substrait::ExtendedExpression& expression, + ExtensionSet* ext_set_out, + const ConversionOptions& conversion_options, + const ExtensionIdRegistry* registry) { + BoundExpressions bound_expressions; + ARROW_RETURN_NOT_OK(CheckVersion(expression.version().major_number(), + expression.version().minor_number())); + if (expression.has_advanced_extensions()) { + return Status::NotImplemented("Advanced extensions in ExtendedExpression"); + } + ARROW_ASSIGN_OR_RAISE( + ExtensionSet ext_set, + GetExtensionSetFromExtendedExpression(expression, conversion_options, registry)); + + ARROW_ASSIGN_OR_RAISE(bound_expressions.schema, + FromProto(expression.base_schema(), ext_set, conversion_options)); + + bound_expressions.named_expressions.reserve(expression.referred_expr_size()); + + for (const auto& referred_expr : expression.referred_expr()) { + ARROW_ASSIGN_OR_RAISE(NamedExpression named_expr, + ExpressionFromProto(referred_expr, *bound_expressions.schema, + ext_set, conversion_options, registry)); + bound_expressions.named_expressions.push_back(std::move(named_expr)); + } + + if (ext_set_out) { + *ext_set_out = std::move(ext_set); + } + + return std::move(bound_expressions); +} + +Result> ToProto( + const BoundExpressions& bound_expressions, ExtensionSet* ext_set, + const ConversionOptions& conversion_options) { + auto expression = std::make_unique(); + expression->set_allocated_version(CreateVersion().release()); + ARROW_ASSIGN_OR_RAISE(std::unique_ptr base_schema, + ToProto(*bound_expressions.schema, ext_set, conversion_options)); + expression->set_allocated_base_schema(base_schema.release()); + for (const auto& named_expression : bound_expressions.named_expressions) { + Expression bound_expr = named_expression.expression; + if (!bound_expr.IsBound()) { + // This will use the default function registry. Most of the time that will be fine. + // In the cases where this is not what the user wants then the user should make sure + // to pass in bound expressions. + ARROW_ASSIGN_OR_RAISE(bound_expr, bound_expr.Bind(*bound_expressions.schema)); + } + ARROW_ASSIGN_OR_RAISE(std::unique_ptr expr_ref, + CreateExpressionReference(named_expression.name, bound_expr, + ext_set, conversion_options)); + expression->mutable_referred_expr()->AddAllocated(expr_ref.release()); + } + RETURN_NOT_OK(AddExtensionSetToExtendedExpression(*ext_set, expression.get())); + return std::move(expression); +} + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/extended_expression_internal.h b/cpp/src/arrow/engine/substrait/extended_expression_internal.h new file mode 100644 index 0000000000000..81bc4b8745186 --- /dev/null +++ b/cpp/src/arrow/engine/substrait/extended_expression_internal.h @@ -0,0 +1,51 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#pragma once + +#include + +#include "arrow/compute/type_fwd.h" +#include "arrow/engine/substrait/extension_set.h" +#include "arrow/engine/substrait/options.h" +#include "arrow/engine/substrait/relation.h" +#include "arrow/engine/substrait/visibility.h" +#include "arrow/result.h" +#include "arrow/status.h" + +#include "substrait/extended_expression.pb.h" // IWYU pragma: export + +namespace arrow { +namespace engine { + +/// Convert a Substrait ExtendedExpression to a vector of expressions and output names +ARROW_ENGINE_EXPORT +Result FromProto(const substrait::ExtendedExpression& expression, + ExtensionSet* ext_set_out, + const ConversionOptions& conversion_options, + const ExtensionIdRegistry* extension_id_registry); + +/// Convert a vector of expressions to a Substrait ExtendedExpression +ARROW_ENGINE_EXPORT +Result> ToProto( + const BoundExpressions& bound_expressions, ExtensionSet* ext_set, + const ConversionOptions& conversion_options); + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index d89248383b722..b0dd6aeffbcfa 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -28,12 +28,15 @@ #include "arrow/engine/substrait/options.h" #include "arrow/type.h" #include "arrow/type_fwd.h" +#include "arrow/util/checked_cast.h" #include "arrow/util/hash_util.h" #include "arrow/util/hashing.h" #include "arrow/util/logging.h" #include "arrow/util/string.h" namespace arrow { + +using internal::checked_pointer_cast; namespace engine { namespace { @@ -229,6 +232,8 @@ Status ExtensionSet::AddUri(Id id) { return Status::OK(); } +Id ExtensionSet::RegisterPlanSpecificId(Id id) { return plan_specific_ids_->Emplace(id); } + // Creates an extension set from the Substrait plan's top-level extensions block Result ExtensionSet::Make( std::unordered_map uris, @@ -873,11 +878,10 @@ ExtensionIdRegistry::ArrowToSubstraitCall EncodeOptionlessOverflowableArithmetic }; } -ExtensionIdRegistry::ArrowToSubstraitCall EncodeOptionlessComparison(Id substrait_fn_id) { +ExtensionIdRegistry::ArrowToSubstraitCall EncodeBasic(Id substrait_fn_id) { return [substrait_fn_id](const compute::Expression::Call& call) -> Result { - // nullable=true isn't quite correct but we don't know the nullability of - // the inputs + // nullable=true errs on the side of caution SubstraitCall substrait_call(substrait_fn_id, call.type.GetSharedPtr(), /*nullable=*/true); for (std::size_t i = 0; i < call.arguments.size(); i++) { @@ -887,11 +891,31 @@ ExtensionIdRegistry::ArrowToSubstraitCall EncodeOptionlessComparison(Id substrai }; } +ExtensionIdRegistry::ArrowToSubstraitCall EncodeIsNull(Id substrait_fn_id) { + return + [substrait_fn_id](const compute::Expression::Call& call) -> Result { + if (call.options != nullptr) { + auto null_opts = checked_pointer_cast(call.options); + if (null_opts->nan_is_null) { + return Status::Invalid( + "Substrait does not support is_null with nan_is_null=true. You can use " + "is_null || is_nan instead"); + } + } + SubstraitCall substrait_call(substrait_fn_id, call.type.GetSharedPtr(), + /*nullable=*/false); + for (std::size_t i = 0; i < call.arguments.size(); i++) { + substrait_call.SetValueArg(static_cast(i), call.arguments[i]); + } + return std::move(substrait_call); + }; +} + ExtensionIdRegistry::SubstraitCallToArrow DecodeOptionlessBasicMapping( const std::string& function_name, int max_args) { return [function_name, max_args](const SubstraitCall& call) -> Result { - if (call.size() > max_args) { + if (max_args >= 0 && call.size() > max_args) { return Status::NotImplemented("Acero does not have a kernel for ", function_name, " that receives ", call.size(), " arguments"); } @@ -1033,14 +1057,15 @@ struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl { // -------------- Substrait -> Arrow Functions ----------------- // Mappings with a _checked variant for (const auto& function_name : - {"add", "subtract", "multiply", "divide", "power", "sqrt", "abs"}) { + {"add", "subtract", "multiply", "divide", "negate", "power", "sqrt", "abs"}) { DCHECK_OK( AddSubstraitCallToArrow({kSubstraitArithmeticFunctionsUri, function_name}, DecodeOptionlessOverflowableArithmetic(function_name))); } - // Mappings without a _checked variant - for (const auto& function_name : {"exp", "sign"}) { + // Mappings either without a _checked variant or substrait has no overflow option + for (const auto& function_name : + {"exp", "sign", "cos", "sin", "tan", "acos", "asin", "atan", "atan2"}) { DCHECK_OK( AddSubstraitCallToArrow({kSubstraitArithmeticFunctionsUri, function_name}, DecodeOptionlessUncheckedArithmetic(function_name))); @@ -1096,6 +1121,21 @@ struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl { DCHECK_OK( AddSubstraitCallToArrow({kSubstraitBooleanFunctionsUri, "not"}, DecodeOptionlessBasicMapping("invert", /*max_args=*/1))); + DCHECK_OK(AddSubstraitCallToArrow( + {kSubstraitArithmeticFunctionsUri, "bitwise_not"}, + DecodeOptionlessBasicMapping("bit_wise_not", /*max_args=*/1))); + DCHECK_OK(AddSubstraitCallToArrow( + {kSubstraitArithmeticFunctionsUri, "bitwise_or"}, + DecodeOptionlessBasicMapping("bit_wise_or", /*max_args=*/2))); + DCHECK_OK(AddSubstraitCallToArrow( + {kSubstraitArithmeticFunctionsUri, "bitwise_and"}, + DecodeOptionlessBasicMapping("bit_wise_and", /*max_args=*/2))); + DCHECK_OK(AddSubstraitCallToArrow( + {kSubstraitArithmeticFunctionsUri, "bitwise_xor"}, + DecodeOptionlessBasicMapping("bit_wise_xor", /*max_args=*/2))); + DCHECK_OK(AddSubstraitCallToArrow( + {kSubstraitComparisonFunctionsUri, "coalesce"}, + DecodeOptionlessBasicMapping("coalesce", /*max_args=*/-1))); DCHECK_OK(AddSubstraitCallToArrow({kSubstraitDatetimeFunctionsUri, "extract"}, DecodeTemporalExtractionMapping())); DCHECK_OK(AddSubstraitCallToArrow({kSubstraitStringFunctionsUri, "concat"}, @@ -1103,6 +1143,12 @@ struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl { DCHECK_OK( AddSubstraitCallToArrow({kSubstraitComparisonFunctionsUri, "is_null"}, DecodeOptionlessBasicMapping("is_null", /*max_args=*/1))); + DCHECK_OK( + AddSubstraitCallToArrow({kSubstraitComparisonFunctionsUri, "is_nan"}, + DecodeOptionlessBasicMapping("is_nan", /*max_args=*/1))); + DCHECK_OK(AddSubstraitCallToArrow( + {kSubstraitComparisonFunctionsUri, "is_finite"}, + DecodeOptionlessBasicMapping("is_finite", /*max_args=*/1))); DCHECK_OK(AddSubstraitCallToArrow( {kSubstraitComparisonFunctionsUri, "is_not_null"}, DecodeOptionlessBasicMapping("is_valid", /*max_args=*/1))); @@ -1127,7 +1173,9 @@ struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl { } // --------------- Arrow -> Substrait Functions --------------- - for (const auto& fn_name : {"add", "subtract", "multiply", "divide"}) { + // Functions with a _checked variant + for (const auto& fn_name : + {"add", "subtract", "multiply", "divide", "negate", "power", "abs"}) { Id fn_id{kSubstraitArithmeticFunctionsUri, fn_name}; DCHECK_OK(AddArrowToSubstraitCall( fn_name, EncodeOptionlessOverflowableArithmetic(fn_id))); @@ -1135,11 +1183,49 @@ struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl { AddArrowToSubstraitCall(std::string(fn_name) + "_checked", EncodeOptionlessOverflowableArithmetic(fn_id))); } - // Comparison operators - for (const auto& fn_name : {"equal", "is_not_distinct_from"}) { - Id fn_id{kSubstraitComparisonFunctionsUri, fn_name}; - DCHECK_OK(AddArrowToSubstraitCall(fn_name, EncodeOptionlessComparison(fn_id))); - } + // Functions with no options... + // ...and the same name + for (const auto& fn_pair : std::vector>{ + {kSubstraitComparisonFunctionsUri, "equal"}, + {kSubstraitComparisonFunctionsUri, "not_equal"}, + {kSubstraitComparisonFunctionsUri, "is_not_distinct_from"}, + {kSubstraitComparisonFunctionsUri, "is_nan"}, + {kSubstraitComparisonFunctionsUri, "is_finite"}, + {kSubstraitComparisonFunctionsUri, "coalesce"}, + {kSubstraitArithmeticFunctionsUri, "sqrt"}, + {kSubstraitArithmeticFunctionsUri, "sign"}, + {kSubstraitArithmeticFunctionsUri, "exp"}, + {kSubstraitArithmeticFunctionsUri, "cos"}, + {kSubstraitArithmeticFunctionsUri, "sin"}, + {kSubstraitArithmeticFunctionsUri, "tan"}, + {kSubstraitArithmeticFunctionsUri, "acos"}, + {kSubstraitArithmeticFunctionsUri, "asin"}, + {kSubstraitArithmeticFunctionsUri, "atan"}, + {kSubstraitArithmeticFunctionsUri, "atan2"}}) { + Id fn_id{fn_pair.first, fn_pair.second}; + DCHECK_OK(AddArrowToSubstraitCall(std::string(fn_pair.second), EncodeBasic(fn_id))); + } + // ...and different names + for (const auto& fn_triple : + std::vector>{ + {kSubstraitComparisonFunctionsUri, "lt", "less"}, + {kSubstraitComparisonFunctionsUri, "gt", "greater"}, + {kSubstraitComparisonFunctionsUri, "lte", "less_equal"}, + {kSubstraitComparisonFunctionsUri, "gte", "greater_equal"}, + {kSubstraitComparisonFunctionsUri, "is_not_null", "is_valid"}, + {kSubstraitArithmeticFunctionsUri, "bitwise_and", "bit_wise_and"}, + {kSubstraitArithmeticFunctionsUri, "bitwise_not", "bit_wise_not"}, + {kSubstraitArithmeticFunctionsUri, "bitwise_or", "bit_wise_or"}, + {kSubstraitArithmeticFunctionsUri, "bitwise_xor", "bit_wise_xor"}, + {kSubstraitBooleanFunctionsUri, "and", "and_kleene"}, + {kSubstraitBooleanFunctionsUri, "or", "or_kleene"}, + {kSubstraitBooleanFunctionsUri, "not", "invert"}}) { + Id fn_id{std::get<0>(fn_triple), std::get<1>(fn_triple)}; + DCHECK_OK(AddArrowToSubstraitCall(std::get<2>(fn_triple), EncodeBasic(fn_id))); + } + + DCHECK_OK(AddArrowToSubstraitCall( + "is_null", EncodeIsNull({kSubstraitComparisonFunctionsUri, "is_null"}))); } }; diff --git a/cpp/src/arrow/engine/substrait/extension_set.h b/cpp/src/arrow/engine/substrait/extension_set.h index 50e0b11943f61..d9c0af081a546 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.h +++ b/cpp/src/arrow/engine/substrait/extension_set.h @@ -131,6 +131,9 @@ class ARROW_ENGINE_EXPORT SubstraitCall { const std::shared_ptr& output_type() const { return output_type_; } bool output_nullable() const { return output_nullable_; } bool is_hash() const { return is_hash_; } + const std::unordered_map>& options() const { + return options_; + } bool HasEnumArg(int index) const; Result GetEnumArg(int index) const; @@ -427,6 +430,15 @@ class ARROW_ENGINE_EXPORT ExtensionSet { /// \return An anchor that can be used to refer to the function within a plan Result EncodeFunction(Id function_id); + /// \brief Stores a plan-specific id that is not known to the registry + /// + /// This is used when converting an Arrow execution plan to a Substrait plan. + /// + /// If the function is a UDF, something that wasn't known to the registry, + /// then we need long term storage of the function name (the ids are just + /// views) + Id RegisterPlanSpecificId(Id id); + /// \brief Return the number of custom functions in this extension set std::size_t num_functions() const { return functions_.size(); } diff --git a/cpp/src/arrow/engine/substrait/options.h b/cpp/src/arrow/engine/substrait/options.h index 0d66c5eea43e5..1e6f6efb2c751 100644 --- a/cpp/src/arrow/engine/substrait/options.h +++ b/cpp/src/arrow/engine/substrait/options.h @@ -106,7 +106,8 @@ struct ARROW_ENGINE_EXPORT ConversionOptions { : strictness(ConversionStrictness::BEST_EFFORT), named_table_provider(kDefaultNamedTableProvider), named_tap_provider(default_named_tap_provider()), - extension_provider(default_extension_provider()) {} + extension_provider(default_extension_provider()), + allow_arrow_extensions(false) {} /// \brief How strictly the converter should adhere to the structure of the input. ConversionStrictness strictness; @@ -123,6 +124,11 @@ struct ARROW_ENGINE_EXPORT ConversionOptions { /// /// The default behavior will provide for relations known to Arrow. std::shared_ptr extension_provider; + /// \brief If true then Arrow-specific types and functions will be allowed + /// + /// Set to false to create plans that are more likely to be compatible with non-Arrow + /// engines + bool allow_arrow_extensions; }; } // namespace engine diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc index ecee81e25ff53..cc4806878c404 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.cc +++ b/cpp/src/arrow/engine/substrait/plan_internal.cc @@ -25,9 +25,10 @@ #include #include "arrow/compute/type_fwd.h" -#include "arrow/config.h" #include "arrow/engine/substrait/relation_internal.h" #include "arrow/engine/substrait/type_fwd.h" +#include "arrow/engine/substrait/util.h" +#include "arrow/engine/substrait/util_internal.h" #include "arrow/result.h" #include "arrow/util/checked_cast.h" #include "arrow/util/hashing.h" @@ -43,122 +44,15 @@ using internal::checked_cast; namespace engine { Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan) { - plan->clear_extension_uris(); - - std::unordered_map map; - - auto uris = plan->mutable_extension_uris(); - uris->Reserve(static_cast(ext_set.uris().size())); - for (uint32_t anchor = 0; anchor < ext_set.uris().size(); ++anchor) { - auto uri = ext_set.uris().at(anchor); - if (uri.empty()) continue; - - auto ext_uri = std::make_unique(); - ext_uri->set_uri(std::string(uri)); - ext_uri->set_extension_uri_anchor(anchor); - uris->AddAllocated(ext_uri.release()); - - map[uri] = anchor; - } - - auto extensions = plan->mutable_extensions(); - extensions->Reserve(static_cast(ext_set.num_types() + ext_set.num_functions())); - - using ExtDecl = substrait::extensions::SimpleExtensionDeclaration; - - for (uint32_t anchor = 0; anchor < ext_set.num_types(); ++anchor) { - ARROW_ASSIGN_OR_RAISE(auto type_record, ext_set.DecodeType(anchor)); - if (type_record.id.empty()) continue; - - auto ext_decl = std::make_unique(); - - auto type = std::make_unique(); - type->set_extension_uri_reference(map[type_record.id.uri]); - type->set_type_anchor(anchor); - type->set_name(std::string(type_record.id.name)); - ext_decl->set_allocated_extension_type(type.release()); - extensions->AddAllocated(ext_decl.release()); - } - - for (uint32_t anchor = 0; anchor < ext_set.num_functions(); ++anchor) { - ARROW_ASSIGN_OR_RAISE(Id function_id, ext_set.DecodeFunction(anchor)); - - auto fn = std::make_unique(); - fn->set_extension_uri_reference(map[function_id.uri]); - fn->set_function_anchor(anchor); - fn->set_name(std::string(function_id.name)); - - auto ext_decl = std::make_unique(); - ext_decl->set_allocated_extension_function(fn.release()); - extensions->AddAllocated(ext_decl.release()); - } - - return Status::OK(); + return AddExtensionSetToMessage(ext_set, plan); } Result GetExtensionSetFromPlan(const substrait::Plan& plan, const ConversionOptions& conversion_options, const ExtensionIdRegistry* registry) { - if (registry == NULLPTR) { - registry = default_extension_id_registry(); - } - std::unordered_map uris; - uris.reserve(plan.extension_uris_size()); - for (const auto& uri : plan.extension_uris()) { - uris[uri.extension_uri_anchor()] = uri.uri(); - } - - // NOTE: it's acceptable to use views to memory owned by plan; ExtensionSet::Make - // will only store views to memory owned by registry. - - std::unordered_map type_ids, function_ids; - for (const auto& ext : plan.extensions()) { - switch (ext.mapping_type_case()) { - case substrait::extensions::SimpleExtensionDeclaration::kExtensionTypeVariation: { - return Status::NotImplemented("Type Variations are not yet implemented"); - } - - case substrait::extensions::SimpleExtensionDeclaration::kExtensionType: { - const auto& type = ext.extension_type(); - std::string_view uri = uris[type.extension_uri_reference()]; - type_ids[type.type_anchor()] = Id{uri, type.name()}; - break; - } - - case substrait::extensions::SimpleExtensionDeclaration::kExtensionFunction: { - const auto& fn = ext.extension_function(); - std::string_view uri = uris[fn.extension_uri_reference()]; - function_ids[fn.function_anchor()] = Id{uri, fn.name()}; - break; - } - - default: - Unreachable(); - } - } - - return ExtensionSet::Make(std::move(uris), std::move(type_ids), std::move(function_ids), - conversion_options, registry); -} - -namespace { - -// TODO(ARROW-18145) Populate these from cmake files -constexpr uint32_t kSubstraitMajorVersion = 0; -constexpr uint32_t kSubstraitMinorVersion = 20; -constexpr uint32_t kSubstraitPatchVersion = 0; - -std::unique_ptr CreateVersion() { - auto version = std::make_unique(); - version->set_major_number(kSubstraitMajorVersion); - version->set_minor_number(kSubstraitMinorVersion); - version->set_patch_number(kSubstraitPatchVersion); - version->set_producer("Acero " + GetBuildInfo().version_string); - return version; + return GetExtensionSetFromMessage(plan, conversion_options, registry); } -} // namespace - Result> PlanToProto( const acero::Declaration& declr, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { diff --git a/cpp/src/arrow/engine/substrait/relation.h b/cpp/src/arrow/engine/substrait/relation.h index 0be4e03bb3871..d0913b9ae029b 100644 --- a/cpp/src/arrow/engine/substrait/relation.h +++ b/cpp/src/arrow/engine/substrait/relation.h @@ -20,6 +20,7 @@ #include #include "arrow/acero/exec_plan.h" +#include "arrow/compute/api_aggregate.h" #include "arrow/engine/substrait/visibility.h" #include "arrow/type_fwd.h" @@ -50,5 +51,21 @@ struct ARROW_ENGINE_EXPORT PlanInfo { std::vector names; }; +/// An expression whose output has a name +struct ARROW_ENGINE_EXPORT NamedExpression { + /// An expression + compute::Expression expression; + // An optional name to assign to the output, may be the empty string + std::string name; +}; + +/// A collection of expressions bound to a common schema +struct ARROW_ENGINE_EXPORT BoundExpressions { + /// The expressions + std::vector named_expressions; + /// The schema that all the expressions are bound to + std::shared_ptr schema; +}; + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc index b5effd7852451..9e670f121778e 100644 --- a/cpp/src/arrow/engine/substrait/serde.cc +++ b/cpp/src/arrow/engine/substrait/serde.cc @@ -35,12 +35,14 @@ #include "arrow/compute/expression.h" #include "arrow/dataset/file_base.h" #include "arrow/engine/substrait/expression_internal.h" +#include "arrow/engine/substrait/extended_expression_internal.h" #include "arrow/engine/substrait/extension_set.h" #include "arrow/engine/substrait/plan_internal.h" #include "arrow/engine/substrait/relation.h" #include "arrow/engine/substrait/relation_internal.h" #include "arrow/engine/substrait/type_fwd.h" #include "arrow/engine/substrait/type_internal.h" +#include "arrow/engine/substrait/util.h" #include "arrow/type.h" namespace arrow { @@ -74,6 +76,20 @@ Result> SerializePlan( return Buffer::FromString(std::move(serialized)); } +Result> SerializeExpressions( + const BoundExpressions& bound_expressions, + const ConversionOptions& conversion_options, ExtensionSet* ext_set) { + ExtensionSet throwaway_ext_set; + if (ext_set == nullptr) { + ext_set = &throwaway_ext_set; + } + ARROW_ASSIGN_OR_RAISE( + std::unique_ptr extended_expression, + ToProto(bound_expressions, ext_set, conversion_options)); + std::string serialized = extended_expression->SerializeAsString(); + return Buffer::FromString(std::move(serialized)); +} + Result> SerializeRelation( const acero::Declaration& declaration, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { @@ -125,20 +141,14 @@ DeclarationFactory MakeWriteDeclarationFactory( }; } -constexpr uint32_t kMinimumMajorVersion = 0; -constexpr uint32_t kMinimumMinorVersion = 20; - Result> DeserializePlans( const Buffer& buf, DeclarationFactory declaration_factory, const ExtensionIdRegistry* registry, ExtensionSet* ext_set_out, const ConversionOptions& conversion_options) { ARROW_ASSIGN_OR_RAISE(auto plan, ParseFromBuffer(buf)); - if (plan.version().major_number() < kMinimumMajorVersion && - plan.version().minor_number() < kMinimumMinorVersion) { - return Status::Invalid("Can only parse plans with a version >= ", - kMinimumMajorVersion, ".", kMinimumMinorVersion); - } + ARROW_RETURN_NOT_OK( + CheckVersion(plan.version().major_number(), plan.version().minor_number())); ARROW_ASSIGN_OR_RAISE(auto ext_set, GetExtensionSetFromPlan(plan, conversion_options, registry)); @@ -196,12 +206,8 @@ ARROW_ENGINE_EXPORT Result DeserializePlan( const Buffer& buf, const ExtensionIdRegistry* registry, ExtensionSet* ext_set_out, const ConversionOptions& conversion_options) { ARROW_ASSIGN_OR_RAISE(auto plan, ParseFromBuffer(buf)); - - if (plan.version().major_number() < kMinimumMajorVersion && - plan.version().minor_number() < kMinimumMinorVersion) { - return Status::Invalid("Can only parse plans with a version >= ", - kMinimumMajorVersion, ".", kMinimumMinorVersion); - } + ARROW_RETURN_NOT_OK( + CheckVersion(plan.version().major_number(), plan.version().minor_number())); ARROW_ASSIGN_OR_RAISE(auto ext_set, GetExtensionSetFromPlan(plan, conversion_options, registry)); @@ -233,6 +239,14 @@ ARROW_ENGINE_EXPORT Result DeserializePlan( return PlanInfo{std::move(decl_info), std::move(names)}; } +Result DeserializeExpressions( + const Buffer& buf, const ExtensionIdRegistry* registry, + const ConversionOptions& conversion_options, ExtensionSet* ext_set_out) { + ARROW_ASSIGN_OR_RAISE(auto extended_expression, + ParseFromBuffer(buf)); + return FromProto(extended_expression, ext_set_out, conversion_options, registry); +} + namespace { Result> MakeSingleDeclarationPlan( diff --git a/cpp/src/arrow/engine/substrait/serde.h b/cpp/src/arrow/engine/substrait/serde.h index ebbefb176e2a9..ab749f4a64b05 100644 --- a/cpp/src/arrow/engine/substrait/serde.h +++ b/cpp/src/arrow/engine/substrait/serde.h @@ -52,6 +52,19 @@ Result> SerializePlan( const acero::Declaration& declaration, ExtensionSet* ext_set, const ConversionOptions& conversion_options = {}); +/// \brief Serialize expressions to a Substrait message +/// +/// \param[in] bound_expressions the expressions to serialize. +/// \param[in] conversion_options options to control how the conversion is done +/// \param[in,out] ext_set the extension mapping to use, optional, only needed +/// if you want to control the value of function anchors +/// to mirror a previous serialization / deserialization. +/// Will be updated if new functions are encountered +ARROW_ENGINE_EXPORT +Result> SerializeExpressions( + const BoundExpressions& bound_expressions, + const ConversionOptions& conversion_options = {}, ExtensionSet* ext_set = NULLPTR); + /// Factory function type for generating the node that consumes the batches produced by /// each toplevel Substrait relation when deserializing a Substrait Plan. using ConsumerFactory = std::function()>; @@ -155,6 +168,21 @@ ARROW_ENGINE_EXPORT Result DeserializePlan( ExtensionSet* ext_set_out = NULLPTR, const ConversionOptions& conversion_options = {}); +/// \brief Deserialize a Substrait ExtendedExpression message to the corresponding Arrow +/// type +/// +/// \param[in] buf a buffer containing the protobuf serialization of a collection of bound +/// expressions +/// \param[in] registry an extension-id-registry to use, or null for the default one +/// \param[in] conversion_options options to control how the conversion is done +/// \param[out] ext_set_out if non-null, the extension mapping used by the Substrait +/// message is returned here. +/// \return A collection of expressions and a common input schema they are bound to +ARROW_ENGINE_EXPORT Result DeserializeExpressions( + const Buffer& buf, const ExtensionIdRegistry* registry = NULLPTR, + const ConversionOptions& conversion_options = {}, + ExtensionSet* ext_set_out = NULLPTR); + /// \brief Deserializes a Substrait Type message to the corresponding Arrow type /// /// \param[in] buf a buffer containing the protobuf serialization of a Substrait Type diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index efe1f702b4868..2e72ae70edd88 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -6070,5 +6070,100 @@ TEST(Substrait, PlanWithSegmentedAggregateExtension) { CheckRoundTripResult(std::move(expected_table), buf, {}, conversion_options); } +void CheckExpressionRoundTrip(const Schema& schema, + const compute::Expression& expression) { + ASSERT_OK_AND_ASSIGN(compute::Expression bound_expression, expression.Bind(schema)); + BoundExpressions bound_expressions; + bound_expressions.schema = std::make_shared(schema); + bound_expressions.named_expressions = {{std::move(bound_expression), "some_name"}}; + + ASSERT_OK_AND_ASSIGN(std::shared_ptr buf, + SerializeExpressions(bound_expressions)); + + ASSERT_OK_AND_ASSIGN(BoundExpressions round_tripped, DeserializeExpressions(*buf)); + + AssertSchemaEqual(schema, *round_tripped.schema); + ASSERT_EQ(1, round_tripped.named_expressions.size()); + ASSERT_EQ("some_name", round_tripped.named_expressions[0].name); + ASSERT_EQ(bound_expressions.named_expressions[0].expression, + round_tripped.named_expressions[0].expression); +} + +TEST(Substrait, ExtendedExpressionSerialization) { + std::shared_ptr test_schema = + schema({field("a", int32()), field("b", int32()), field("c", float32()), + field("nested", struct_({field("x", float32()), field("y", float32())}))}); + // Basic a + b + CheckExpressionRoundTrip( + *test_schema, compute::call("add", {compute::field_ref(0), compute::field_ref(1)})); + // Nested struct reference + CheckExpressionRoundTrip(*test_schema, compute::field_ref(FieldPath{3, 0})); + // Struct return type + CheckExpressionRoundTrip(*test_schema, compute::field_ref(3)); + // c + nested.y + CheckExpressionRoundTrip( + *test_schema, + compute::call("add", {compute::field_ref(2), compute::field_ref(FieldPath{3, 1})})); +} + +TEST(Substrait, ExtendedExpressionInvalidPlans) { + // The schema defines the type as {"x", "y"} but output_names has {"a", "y"} + constexpr std::string_view kBadOuptutNames = R"( + { + "referredExpr":[ + { + "expression":{ + "selection":{ + "directReference":{ + "structField":{ + "field":3 + } + }, + "rootReference":{} + } + }, + "outputNames":["a", "y", "some_name"] + } + ], + "baseSchema":{ + "names":["a","b","c","nested","x","y"], + "struct":{ + "types":[ + { + "i32":{"nullability":"NULLABILITY_NULLABLE"} + }, + { + "i32":{"nullability":"NULLABILITY_NULLABLE"} + }, + { + "fp32":{"nullability":"NULLABILITY_NULLABLE"} + }, + { + "struct":{ + "types":[ + { + "fp32":{"nullability":"NULLABILITY_NULLABLE"} + }, + { + "fp32":{"nullability":"NULLABILITY_NULLABLE"} + } + ], + "nullability":"NULLABILITY_NULLABLE" + } + } + ] + } + }, + "version":{"majorNumber":9999} + } + )"; + + ASSERT_OK_AND_ASSIGN( + auto buf, internal::SubstraitFromJSON("ExtendedExpression", kBadOuptutNames)); + + ASSERT_THAT(DeserializeExpressions(*buf), + Raises(StatusCode::Invalid, testing::HasSubstr("Ambiguous plan"))); +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/util.cc b/cpp/src/arrow/engine/substrait/util.cc index b74e333fd97f2..d842d0ef9d73b 100644 --- a/cpp/src/arrow/engine/substrait/util.cc +++ b/cpp/src/arrow/engine/substrait/util.cc @@ -70,6 +70,16 @@ const std::string& default_extension_types_uri() { return uri; } +Status CheckVersion(uint32_t major_version, uint32_t minor_version) { + if (major_version < kSubstraitMinimumMajorVersion && + minor_version < kSubstraitMinimumMinorVersion) { + return Status::Invalid("Can only parse Substrait messages with a version >= ", + kSubstraitMinimumMajorVersion, ".", + kSubstraitMinimumMinorVersion); + } + return Status::OK(); +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/util.h b/cpp/src/arrow/engine/substrait/util.h index 9f8bd8048899a..5128ec44bff77 100644 --- a/cpp/src/arrow/engine/substrait/util.h +++ b/cpp/src/arrow/engine/substrait/util.h @@ -68,6 +68,16 @@ ARROW_ENGINE_EXPORT std::shared_ptr MakeExtensionIdRegistry ARROW_ENGINE_EXPORT const std::string& default_extension_types_uri(); +// TODO(ARROW-18145) Populate these from cmake files +constexpr uint32_t kSubstraitMajorVersion = 0; +constexpr uint32_t kSubstraitMinorVersion = 27; +constexpr uint32_t kSubstraitPatchVersion = 0; + +constexpr uint32_t kSubstraitMinimumMajorVersion = 0; +constexpr uint32_t kSubstraitMinimumMinorVersion = 20; + +Status CheckVersion(uint32_t major_version, uint32_t minor_version); + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/util_internal.cc b/cpp/src/arrow/engine/substrait/util_internal.cc index 4e6cacf4f672b..89034784ab5bd 100644 --- a/cpp/src/arrow/engine/substrait/util_internal.cc +++ b/cpp/src/arrow/engine/substrait/util_internal.cc @@ -17,6 +17,9 @@ #include "arrow/engine/substrait/util_internal.h" +#include "arrow/config.h" +#include "arrow/engine/substrait/util.h" + namespace arrow { namespace engine { @@ -30,6 +33,15 @@ std::string EnumToString(int value, const google::protobuf::EnumDescriptor& desc return value_desc->name(); } +std::unique_ptr CreateVersion() { + auto version = std::make_unique(); + version->set_major_number(kSubstraitMajorVersion); + version->set_minor_number(kSubstraitMinorVersion); + version->set_patch_number(kSubstraitPatchVersion); + version->set_producer("Acero " + GetBuildInfo().version_string); + return version; +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/util_internal.h b/cpp/src/arrow/engine/substrait/util_internal.h index efc3145543dd2..627ad1126df6e 100644 --- a/cpp/src/arrow/engine/substrait/util_internal.h +++ b/cpp/src/arrow/engine/substrait/util_internal.h @@ -17,8 +17,17 @@ #pragma once +#include + +#include "arrow/engine/substrait/extension_set.h" +#include "arrow/engine/substrait/options.h" #include "arrow/engine/substrait/visibility.h" +#include "arrow/result.h" +#include "arrow/util/hashing.h" +#include "arrow/util/unreachable.h" + #include "substrait/algebra.pb.h" // IWYU pragma: export +#include "substrait/plan.pb.h" // IWYU pragma: export namespace arrow { namespace engine { @@ -26,5 +35,110 @@ namespace engine { ARROW_ENGINE_EXPORT std::string EnumToString( int value, const google::protobuf::EnumDescriptor& descriptor); +// Extension sets can be present in both substrait::Plan and substrait::ExtendedExpression +// and so this utility is templated to support both. +template +Result GetExtensionSetFromMessage( + const MessageType& message, const ConversionOptions& conversion_options, + const ExtensionIdRegistry* registry) { + if (registry == NULLPTR) { + registry = default_extension_id_registry(); + } + std::unordered_map uris; + uris.reserve(message.extension_uris_size()); + for (const auto& uri : message.extension_uris()) { + uris[uri.extension_uri_anchor()] = uri.uri(); + } + + // NOTE: it's acceptable to use views to memory owned by message; ExtensionSet::Make + // will only store views to memory owned by registry. + + std::unordered_map type_ids, function_ids; + for (const auto& ext : message.extensions()) { + switch (ext.mapping_type_case()) { + case substrait::extensions::SimpleExtensionDeclaration::kExtensionTypeVariation: { + return Status::NotImplemented("Type Variations are not yet implemented"); + } + + case substrait::extensions::SimpleExtensionDeclaration::kExtensionType: { + const auto& type = ext.extension_type(); + std::string_view uri = uris[type.extension_uri_reference()]; + type_ids[type.type_anchor()] = Id{uri, type.name()}; + break; + } + + case substrait::extensions::SimpleExtensionDeclaration::kExtensionFunction: { + const auto& fn = ext.extension_function(); + std::string_view uri = uris[fn.extension_uri_reference()]; + function_ids[fn.function_anchor()] = Id{uri, fn.name()}; + break; + } + + default: + Unreachable(); + } + } + + return ExtensionSet::Make(std::move(uris), std::move(type_ids), std::move(function_ids), + conversion_options, registry); +} + +template +Status AddExtensionSetToMessage(const ExtensionSet& ext_set, Message* message) { + message->clear_extension_uris(); + + std::unordered_map map; + + auto uris = message->mutable_extension_uris(); + uris->Reserve(static_cast(ext_set.uris().size())); + for (uint32_t anchor = 0; anchor < ext_set.uris().size(); ++anchor) { + auto uri = ext_set.uris().at(anchor); + if (uri.empty()) continue; + + auto ext_uri = std::make_unique(); + ext_uri->set_uri(std::string(uri)); + ext_uri->set_extension_uri_anchor(anchor); + uris->AddAllocated(ext_uri.release()); + + map[uri] = anchor; + } + + auto extensions = message->mutable_extensions(); + extensions->Reserve(static_cast(ext_set.num_types() + ext_set.num_functions())); + + using ExtDecl = substrait::extensions::SimpleExtensionDeclaration; + + for (uint32_t anchor = 0; anchor < ext_set.num_types(); ++anchor) { + ARROW_ASSIGN_OR_RAISE(auto type_record, ext_set.DecodeType(anchor)); + if (type_record.id.empty()) continue; + + auto ext_decl = std::make_unique(); + + auto type = std::make_unique(); + type->set_extension_uri_reference(map[type_record.id.uri]); + type->set_type_anchor(anchor); + type->set_name(std::string(type_record.id.name)); + ext_decl->set_allocated_extension_type(type.release()); + extensions->AddAllocated(ext_decl.release()); + } + + for (uint32_t anchor = 0; anchor < ext_set.num_functions(); ++anchor) { + ARROW_ASSIGN_OR_RAISE(Id function_id, ext_set.DecodeFunction(anchor)); + + auto fn = std::make_unique(); + fn->set_extension_uri_reference(map[function_id.uri]); + fn->set_function_anchor(anchor); + fn->set_name(std::string(function_id.name)); + + auto ext_decl = std::make_unique(); + ext_decl->set_allocated_extension_function(fn.release()); + extensions->AddAllocated(ext_decl.release()); + } + + return Status::OK(); +} + +std::unique_ptr CreateVersion(); + } // namespace engine } // namespace arrow diff --git a/cpp/thirdparty/versions.txt b/cpp/thirdparty/versions.txt index c05ff4228462c..8edaa422b3dcf 100644 --- a/cpp/thirdparty/versions.txt +++ b/cpp/thirdparty/versions.txt @@ -101,8 +101,8 @@ ARROW_RE2_BUILD_SHA256_CHECKSUM=f89c61410a072e5cbcf8c27e3a778da7d6fd2f2b5b1445cd # 1.1.9 is patched to implement https://github.com/google/snappy/pull/148 if this is bumped, remove the patch ARROW_SNAPPY_BUILD_VERSION=1.1.9 ARROW_SNAPPY_BUILD_SHA256_CHECKSUM=75c1fbb3d618dd3a0483bff0e26d0a92b495bbe5059c8b4f1c962b478b6e06e7 -ARROW_SUBSTRAIT_BUILD_VERSION=v0.20.0 -ARROW_SUBSTRAIT_BUILD_SHA256_CHECKSUM=5ceaa559ccef29a7825b5e5d4b5e7eed384830294f08bec913feecdd903a94cf +ARROW_SUBSTRAIT_BUILD_VERSION=v0.27.0 +ARROW_SUBSTRAIT_BUILD_SHA256_CHECKSUM=4ed375f69d972a57fdc5ec406c17003a111831d8640d3f1733eccd4b3ff45628 ARROW_S2N_TLS_BUILD_VERSION=v1.3.35 ARROW_S2N_TLS_BUILD_SHA256_CHECKSUM=9d32b26e6bfcc058d98248bf8fc231537e347395dd89cf62bb432b55c5da990d ARROW_THRIFT_BUILD_VERSION=0.16.0 diff --git a/docs/source/python/api/substrait.rst b/docs/source/python/api/substrait.rst index 207b2d9cdbcea..66e88fcd279ae 100644 --- a/docs/source/python/api/substrait.rst +++ b/docs/source/python/api/substrait.rst @@ -31,6 +31,19 @@ Query Execution run_query +Expression Serialization +------------------------ + +These functions allow for serialization and deserialization of pyarrow +compute expressions. + +.. autosummary:: + :toctree: ../generated/ + + BoundExpressions + deserialize_expressions + serialize_expressions + Utility ------- diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 74b9c2ec8c78b..0c1744febbe1e 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -37,6 +37,23 @@ import numpy as np import warnings +__pas = None +_substrait_msg = ( + "The pyarrow installation is not built with support for Substrait." +) + + +def _pas(): + global __pas + if __pas is None: + try: + import pyarrow.substrait as pas + __pas = pas + except ImportError: + raise ImportError(_substrait_msg) + return __pas + + def _forbid_instantiation(klass, subclasses_instead=True): msg = '{} is an abstract class thus cannot be initialized.'.format( klass.__name__ @@ -2386,6 +2403,58 @@ cdef class Expression(_Weakrefable): self.__class__.__name__, str(self) ) + @staticmethod + def from_substrait(object buffer not None): + """ + Deserialize an expression from Substrait + + The serialized message must be an ExtendedExpression message that has + only a single expression. The name of the expression and the schema + the expression was bound to will be ignored. Use + pyarrow.substrait.deserialize_expressions if this information is needed + or if the message might contain multiple expressions. + + Parameters + ---------- + buffer : bytes or Buffer + The Substrait message to deserialize + + Returns + ------- + Expression + The deserialized expression + """ + expressions = _pas().deserialize_expressions(buffer).expressions + if len(expressions) == 0: + raise ValueError("Substrait message did not contain any expressions") + if len(expressions) > 1: + raise ValueError( + "Substrait message contained multiple expressions. Use pyarrow.substrait.deserialize_expressions instead") + return next(iter(expressions.values())) + + def to_substrait(self, Schema schema not None, c_bool allow_arrow_extensions=False): + """ + Serialize the expression using Substrait + + The expression will be serialized as an ExtendedExpression message that has a + single expression named "expression" + + Parameters + ---------- + schema : Schema + The input schema the expression will be bound to + allow_arrow_extensions : bool, default False + If False then only functions that are part of the core Substrait function + definitions will be allowed. Set this to True to allow pyarrow-specific functions + but the result may not be accepted by other compute libraries. + + Returns + ------- + Buffer + A buffer containing the serialized Protobuf plan. + """ + return _pas().serialize_expressions([self], ["expression"], schema, allow_arrow_extensions=allow_arrow_extensions) + @staticmethod def _deserialize(Buffer buffer not None): return Expression.wrap(GetResultValue(CDeserializeExpression( diff --git a/python/pyarrow/_substrait.pyx b/python/pyarrow/_substrait.pyx index 1be2e6330aba5..4efad2c4d1bc5 100644 --- a/python/pyarrow/_substrait.pyx +++ b/python/pyarrow/_substrait.pyx @@ -20,6 +20,7 @@ from cython.operator cimport dereference as deref from libcpp.vector cimport vector as std_vector from pyarrow import Buffer, py_buffer +from pyarrow._compute cimport Expression from pyarrow.lib import frombytes, tobytes from pyarrow.lib cimport * from pyarrow.includes.libarrow cimport * @@ -164,7 +165,7 @@ def _parse_json_plan(plan): Parameters ---------- - plan: bytes + plan : bytes Substrait plan in JSON. Returns @@ -185,6 +186,144 @@ def _parse_json_plan(plan): return pyarrow_wrap_buffer(c_buf_plan) +def serialize_expressions(exprs, names, schema, *, allow_arrow_extensions=False): + """ + Serialize a collection of expressions into Substrait + + Substrait expressions must be bound to a schema. For example, + the Substrait expression ``a:i32 + b:i32`` is different from the + Substrait expression ``a:i64 + b:i64``. Pyarrow expressions are + typically unbound. For example, both of the above expressions + would be represented as ``a + b`` in pyarrow. + + This means a schema must be provided when serializing an expression. + It also means that the serialization may fail if a matching function + call cannot be found for the expression. + + Parameters + ---------- + exprs : list of Expression + The expressions to serialize + names : list of str + Names for the expressions + schema : Schema + The schema the expressions will be bound to + allow_arrow_extensions : bool, default False + If False then only functions that are part of the core Substrait function + definitions will be allowed. Set this to True to allow pyarrow-specific functions + and user defined functions but the result may not be accepted by other + compute libraries. + + Returns + ------- + Buffer + An ExtendedExpression message containing the serialized expressions + """ + cdef: + CResult[shared_ptr[CBuffer]] c_res_buffer + shared_ptr[CBuffer] c_buffer + CNamedExpression c_named_expr + CBoundExpressions c_bound_exprs + CConversionOptions c_conversion_options + + if len(exprs) != len(names): + raise ValueError("exprs and names need to have the same length") + for expr, name in zip(exprs, names): + if not isinstance(expr, Expression): + raise TypeError(f"Expected Expression, got '{type(expr)}' in exprs") + if not isinstance(name, str): + raise TypeError(f"Expected str, got '{type(name)}' in names") + c_named_expr.expression = ( expr).unwrap() + c_named_expr.name = tobytes( name) + c_bound_exprs.named_expressions.push_back(c_named_expr) + + c_bound_exprs.schema = ( schema).sp_schema + + c_conversion_options.allow_arrow_extensions = allow_arrow_extensions + + with nogil: + c_res_buffer = SerializeExpressions(c_bound_exprs, c_conversion_options) + c_buffer = GetResultValue(c_res_buffer) + return pyarrow_wrap_buffer(c_buffer) + + +cdef class BoundExpressions(_Weakrefable): + """ + A collection of named expressions and the schema they are bound to + + This is equivalent to the Substrait ExtendedExpression message + """ + + cdef: + CBoundExpressions c_bound_exprs + + def __init__(self): + msg = 'BoundExpressions is an abstract class thus cannot be initialized.' + raise TypeError(msg) + + cdef void init(self, CBoundExpressions bound_expressions): + self.c_bound_exprs = bound_expressions + + @property + def schema(self): + """ + The common schema that all expressions are bound to + """ + return pyarrow_wrap_schema(self.c_bound_exprs.schema) + + @property + def expressions(self): + """ + A dict from expression name to expression + """ + expr_dict = {} + for named_expr in self.c_bound_exprs.named_expressions: + name = frombytes(named_expr.name) + expr = Expression.wrap(named_expr.expression) + expr_dict[name] = expr + return expr_dict + + @staticmethod + cdef wrap(const CBoundExpressions& bound_expressions): + cdef BoundExpressions self = BoundExpressions.__new__(BoundExpressions) + self.init(bound_expressions) + return self + + +def deserialize_expressions(buf): + """ + Deserialize an ExtendedExpression Substrait message into a BoundExpressions object + + Parameters + ---------- + buf : Buffer or bytes + The message to deserialize + + Returns + ------- + BoundExpressions + The deserialized expressions, their names, and the bound schema + """ + cdef: + shared_ptr[CBuffer] c_buffer + CResult[CBoundExpressions] c_res_bound_exprs + CBoundExpressions c_bound_exprs + + if isinstance(buf, bytes): + c_buffer = pyarrow_unwrap_buffer(py_buffer(buf)) + elif isinstance(buf, Buffer): + c_buffer = pyarrow_unwrap_buffer(buf) + else: + raise TypeError( + f"Expected 'pyarrow.Buffer' or bytes, got '{type(buf)}'") + + with nogil: + c_res_bound_exprs = DeserializeExpressions(deref(c_buffer)) + c_bound_exprs = GetResultValue(c_res_bound_exprs) + + return BoundExpressions.wrap(c_bound_exprs) + + def get_supported_functions(): """ Get a list of Substrait functions that the underlying diff --git a/python/pyarrow/includes/libarrow_substrait.pxd b/python/pyarrow/includes/libarrow_substrait.pxd index eabccb2b4a3cb..c41f4c05d3a77 100644 --- a/python/pyarrow/includes/libarrow_substrait.pxd +++ b/python/pyarrow/includes/libarrow_substrait.pxd @@ -40,6 +40,7 @@ cdef extern from "arrow/engine/substrait/options.h" namespace "arrow::engine" no CConversionOptions() ConversionStrictness strictness function[CNamedTableProvider] named_table_provider + c_bool allow_arrow_extensions cdef extern from "arrow/engine/substrait/extension_set.h" \ namespace "arrow::engine" nogil: @@ -49,6 +50,23 @@ cdef extern from "arrow/engine/substrait/extension_set.h" \ ExtensionIdRegistry* default_extension_id_registry() +cdef extern from "arrow/engine/substrait/relation.h" namespace "arrow::engine" nogil: + + cdef cppclass CNamedExpression "arrow::engine::NamedExpression": + CExpression expression + c_string name + + cdef cppclass CBoundExpressions "arrow::engine::BoundExpressions": + std_vector[CNamedExpression] named_expressions + shared_ptr[CSchema] schema + +cdef extern from "arrow/engine/substrait/serde.h" namespace "arrow::engine" nogil: + + CResult[shared_ptr[CBuffer]] SerializeExpressions( + const CBoundExpressions& bound_expressions, const CConversionOptions& conversion_options) + + CResult[CBoundExpressions] DeserializeExpressions( + const CBuffer& serialized_expressions) cdef extern from "arrow/engine/substrait/util.h" namespace "arrow::engine" nogil: CResult[shared_ptr[CRecordBatchReader]] ExecuteSerializedPlan( diff --git a/python/pyarrow/substrait.py b/python/pyarrow/substrait.py index ea7e19142cdd3..a2b217f4936c5 100644 --- a/python/pyarrow/substrait.py +++ b/python/pyarrow/substrait.py @@ -17,8 +17,11 @@ try: from pyarrow._substrait import ( # noqa + BoundExpressions, get_supported_functions, run_query, + deserialize_expressions, + serialize_expressions ) except ImportError as exc: raise ImportError( diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 6fbca5209975c..03d2688e65705 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -41,6 +41,10 @@ from pyarrow.lib import ArrowNotImplementedError from pyarrow.tests import util +try: + import pyarrow.substrait as pas +except ImportError: + pas = None all_array_types = [ ('bool', [True, False, False, True, True]), @@ -3285,7 +3289,14 @@ def test_rank_options(): tiebreaker="NonExisting") -def test_expression_serialization(): +def create_sample_expressions(): + # We need a schema for substrait conversion + schema = pa.schema([pa.field("i64", pa.int64()), pa.field( + "foo", pa.struct([pa.field("bar", pa.string())]))]) + + # Creates a bunch of sample expressions for testing + # serialization and deserialization. The expressions are categorized + # to reflect certain nuances in Substrait conversion. a = pc.scalar(1) b = pc.scalar(1.1) c = pc.scalar(True) @@ -3294,20 +3305,133 @@ def test_expression_serialization(): f = pc.scalar({'a': 1}) g = pc.scalar(pa.scalar(1)) h = pc.scalar(np.int64(2)) + j = pc.scalar(False) + + # These expression consist entirely of literals + literal_exprs = [a, b, c, d, e, g, h, j] + + # These expressions include at least one function call + exprs_with_call = [a == b, a != b, a > b, c & j, c | j, ~c, d.is_valid(), + a + b, a - b, a * b, a / b, pc.negate(a), + pc.add(a, b), pc.subtract(a, b), pc.divide(a, b), + pc.multiply(a, b), pc.power(a, a), pc.sqrt(a), + pc.exp(b), pc.cos(b), pc.sin(b), pc.tan(b), + pc.acos(b), pc.atan(b), pc.asin(b), pc.atan2(b, b), + pc.abs(b), pc.sign(a), pc.bit_wise_not(a), + pc.bit_wise_and(a, a), pc.bit_wise_or(a, a), + pc.bit_wise_xor(a, a), pc.is_nan(b), pc.is_finite(b), + pc.coalesce(a, b), + a.cast(pa.int32(), safe=False)] + + # These expressions test out various reference styles and may include function + # calls. Named references are used here. + exprs_with_ref = [pc.field('i64') > 5, pc.field('i64') == 5, + pc.field('i64') == 7, + pc.field(('foo', 'bar')) == 'value', + pc.field('foo', 'bar') == 'value'] + + # Similar to above but these use numeric references instead of string refs + exprs_with_numeric_refs = [pc.field(0) > 5, pc.field(0) == 5, + pc.field(0) == 7, + pc.field((1, 0)) == 'value', + pc.field(1, 0) == 'value'] + + # Expressions that behave uniquely when converting to/from substrait + special_cases = [ + f, # Struct literals lose their field names + a.isin([1, 2, 3]), # isin converts to an or list + pc.field('i64').is_null() # pyarrow always specifies a FunctionOptions + # for is_null which, being the default, is + # dropped on serialization + ] + + all_exprs = literal_exprs.copy() + all_exprs += exprs_with_call + all_exprs += exprs_with_ref + all_exprs += special_cases + + return { + "all": all_exprs, + "literals": literal_exprs, + "calls": exprs_with_call, + "refs": exprs_with_ref, + "numeric_refs": exprs_with_numeric_refs, + "special": special_cases, + "schema": schema + } - all_exprs = [a, b, c, d, e, f, g, h, a == b, a > b, a & b, a | b, ~c, - d.is_valid(), a.cast(pa.int32(), safe=False), - a.cast(pa.int32(), safe=False), a.isin([1, 2, 3]), - pc.field('i64') > 5, pc.field('i64') == 5, - pc.field('i64') == 7, pc.field('i64').is_null(), - pc.field(('foo', 'bar')) == 'value', - pc.field('foo', 'bar') == 'value'] - for expr in all_exprs: +# Tests the Arrow-specific serialization mechanism + + +def test_expression_serialization_arrow(): + for expr in create_sample_expressions()["all"]: assert isinstance(expr, pc.Expression) restored = pickle.loads(pickle.dumps(expr)) assert expr.equals(restored) +@pytest.mark.substrait +def test_expression_serialization_substrait(): + + exprs = create_sample_expressions() + schema = exprs["schema"] + + # Basic literals don't change on binding and so they will round + # trip without any change + for expr in exprs["literals"]: + serialized = expr.to_substrait(schema) + deserialized = pc.Expression.from_substrait(serialized) + assert expr.equals(deserialized) + + # Expressions are bound when they get serialized. Since bound + # expressions are not equal to their unbound variants we cannot + # compare the round tripped with the original + for expr in exprs["calls"]: + serialized = expr.to_substrait(schema) + deserialized = pc.Expression.from_substrait(serialized) + # We can't compare the expressions themselves because of the bound + # unbound difference. But we can compare the string representation + assert str(deserialized) == str(expr) + serialized_again = deserialized.to_substrait(schema) + deserialized_again = pc.Expression.from_substrait(serialized_again) + assert deserialized.equals(deserialized_again) + + for expr, expr_norm in zip(exprs["refs"], exprs["numeric_refs"]): + serialized = expr.to_substrait(schema) + deserialized = pc.Expression.from_substrait(serialized) + assert str(deserialized) == str(expr_norm) + serialized_again = deserialized.to_substrait(schema) + deserialized_again = pc.Expression.from_substrait(serialized_again) + assert deserialized.equals(deserialized_again) + + # For the special cases we get various wrinkles in serialization but we + # should always get the same thing from round tripping twice + for expr in exprs["special"]: + serialized = expr.to_substrait(schema) + deserialized = pc.Expression.from_substrait(serialized) + serialized_again = deserialized.to_substrait(schema) + deserialized_again = pc.Expression.from_substrait(serialized_again) + assert deserialized.equals(deserialized_again) + + # Special case, we lose the field names of struct literals + f = exprs["special"][0] + serialized = f.to_substrait(schema) + deserialized = pc.Expression.from_substrait(serialized) + assert deserialized.equals(pc.scalar({'': 1})) + + # Special case, is_in converts to a == opt[0] || a == opt[1] ... + a = pc.scalar(1) + expr = a.isin([1, 2, 3]) + target = (a == 1) | (a == 2) | (a == 3) + serialized = expr.to_substrait(schema) + deserialized = pc.Expression.from_substrait(serialized) + # Compare str's here to bypass the bound/unbound difference + assert str(target) == str(deserialized) + serialized_again = deserialized.to_substrait(schema) + deserialized_again = pc.Expression.from_substrait(serialized_again) + assert deserialized.equals(deserialized_again) + + def test_expression_construction(): zero = pc.scalar(0) one = pc.scalar(1) diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py index 93ecae7bfa10e..be35a21a02411 100644 --- a/python/pyarrow/tests/test_substrait.py +++ b/python/pyarrow/tests/test_substrait.py @@ -21,8 +21,10 @@ import pytest import pyarrow as pa +import pyarrow.compute as pc +import pyarrow.dataset as ds from pyarrow.lib import tobytes -from pyarrow.lib import ArrowInvalid +from pyarrow.lib import ArrowInvalid, ArrowNotImplementedError try: import pyarrow.substrait as substrait @@ -923,3 +925,106 @@ def table_provider(names, _): # Ordering of k is deterministic because this is running with serial execution assert res_tb == expected_tb + + +@pytest.mark.parametrize("expr", [ + pc.equal(ds.field("x"), 7), + pc.equal(ds.field("x"), ds.field("y")), + ds.field("x") > 50 +]) +def test_serializing_expressions(expr): + schema = pa.schema([ + pa.field("x", pa.int32()), + pa.field("y", pa.int32()) + ]) + + buf = pa.substrait.serialize_expressions([expr], ["test_expr"], schema) + returned = pa.substrait.deserialize_expressions(buf) + assert schema == returned.schema + assert len(returned.expressions) == 1 + assert "test_expr" in returned.expressions + + +def test_invalid_expression_ser_des(): + schema = pa.schema([ + pa.field("x", pa.int32()), + pa.field("y", pa.int32()) + ]) + expr = pc.equal(ds.field("x"), 7) + bad_expr = pc.equal(ds.field("z"), 7) + # Invalid number of names + with pytest.raises(ValueError) as excinfo: + pa.substrait.serialize_expressions([expr], [], schema) + assert 'need to have the same length' in str(excinfo.value) + with pytest.raises(ValueError) as excinfo: + pa.substrait.serialize_expressions([expr], ["foo", "bar"], schema) + assert 'need to have the same length' in str(excinfo.value) + # Expression doesn't match schema + with pytest.raises(ValueError) as excinfo: + pa.substrait.serialize_expressions([bad_expr], ["expr"], schema) + assert 'No match for FieldRef' in str(excinfo.value) + + +def test_serializing_multiple_expressions(): + schema = pa.schema([ + pa.field("x", pa.int32()), + pa.field("y", pa.int32()) + ]) + exprs = [pc.equal(ds.field("x"), 7), pc.equal(ds.field("x"), ds.field("y"))] + buf = pa.substrait.serialize_expressions(exprs, ["first", "second"], schema) + returned = pa.substrait.deserialize_expressions(buf) + assert schema == returned.schema + assert len(returned.expressions) == 2 + + norm_exprs = [pc.equal(ds.field(0), 7), pc.equal(ds.field(0), ds.field(1))] + assert str(returned.expressions["first"]) == str(norm_exprs[0]) + assert str(returned.expressions["second"]) == str(norm_exprs[1]) + + +def test_serializing_with_compute(): + schema = pa.schema([ + pa.field("x", pa.int32()), + pa.field("y", pa.int32()) + ]) + expr = pc.equal(ds.field("x"), 7) + expr_norm = pc.equal(ds.field(0), 7) + buf = expr.to_substrait(schema) + returned = pa.substrait.deserialize_expressions(buf) + + assert schema == returned.schema + assert len(returned.expressions) == 1 + + assert str(returned.expressions["expression"]) == str(expr_norm) + + # Compute can't deserialize messages with multiple expressions + buf = pa.substrait.serialize_expressions([expr, expr], ["first", "second"], schema) + with pytest.raises(ValueError) as excinfo: + pc.Expression.from_substrait(buf) + assert 'contained multiple expressions' in str(excinfo.value) + + # Deserialization should be possible regardless of the expression name + buf = pa.substrait.serialize_expressions([expr], ["weirdname"], schema) + expr2 = pc.Expression.from_substrait(buf) + assert str(expr2) == str(expr_norm) + + +def test_serializing_udfs(): + # Note, UDF in this context means a function that is not + # recognized by Substrait. It might still be a builtin pyarrow + # function. + schema = pa.schema([ + pa.field("x", pa.uint32()) + ]) + a = pc.scalar(10) + b = pc.scalar(4) + exprs = [pc.shift_left(a, b)] + + with pytest.raises(ArrowNotImplementedError): + pa.substrait.serialize_expressions(exprs, ["expr"], schema) + + buf = pa.substrait.serialize_expressions( + exprs, ["expr"], schema, allow_arrow_extensions=True) + returned = pa.substrait.deserialize_expressions(buf) + assert schema == returned.schema + assert len(returned.expressions) == 1 + assert str(returned.expressions["expr"]) == str(exprs[0])