apacheGH-33985: [C++] Add substrait serialization/deserialization for expressions (apache#34834)

### Rationale for this change

Substrait provides a library-independent way to represent compute expressions.  By serializing pyarrow compute expressions to Substrait, and deserializing them back, we can allow interoperability with other libraries.

Originally it was thought this would not be needed because users would be sending entire query plans (which contain expressions) back and forth and so there was no need to work with expressions by themselves.

However, as more and more APIs and integration points emerge, it turns out there are situations where serializing expressions by themselves is useful.  Examples include the proposed datasets protocol and the Java JNI datasets implementation (which uses Arrow-C++'s datasets).

### What changes are included in this PR?

In Arrow-C++ we add two new methods to serialize and deserialize a collection of named, bound expressions to Substrait's ExtendedExpression message.

In pyarrow we expose these two methods and also add utility methods to pyarrow.compute.Expression to convert a single expression to/from Substrait (these will be encoded as an ExtendedExpression message with one expression named "expression").

In addition, this PR exposed that we do not have very many bindings for arrow-functions to substrait-functions (previous work has mostly focused on the reverse direction).  This PR adds many (though not all) new bindings.

In addition, this PR adds ToProto for cast and both FromProto and ToProto support for the SingularOrList expression type (we convert is_in to SingularOrList and convert SingularOrList to an or list).
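
To make the `is_in` round trip concrete, here is a minimal, self-contained sketch. It does not use Arrow or Substrait at all: `SingularOrList`, `IsInToOrList`, and `OrListToExpression` are hypothetical stand-ins that work on strings, and returning `"false"` for an empty list is an assumption of this sketch rather than something this description specifies.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for an or-list message: one value and many options.
struct SingularOrList {
  std::string value;
  std::vector<std::string> options;
};

// Models the "to Substrait" direction: is_in(value, value_set) becomes an
// or-list holding the value and every member of the value set.
SingularOrList IsInToOrList(const std::string& value,
                            const std::vector<std::string>& value_set) {
  return SingularOrList{value, value_set};
}

// Models the "from Substrait" direction: the or-list is expanded into one
// equality check per option, joined with "||".
std::string OrListToExpression(const SingularOrList& or_list) {
  // Assumption of this sketch: an empty list can never match.
  if (or_list.options.empty()) return "false";
  std::string out;
  for (std::size_t i = 0; i < or_list.options.size(); ++i) {
    if (i > 0) out += " || ";
    out += or_list.value + " == " + or_list.options[i];
  }
  return out;
}
```

Calling `OrListToExpression(IsInToOrList("5", {"1", "2", "5"}))` yields a chain of three equality checks, so the expression that comes back is not the `is_in` call that went in.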

This should provide support for all the sargable operators except between (there is no Arrow-C++ between function) and like (we still don't have arrow->substrait bindings for the string functions) which should be a sufficient set of expressions for a first release.

### Are these changes tested?

Yes.

### Are there any user-facing changes?

There are new features, as described above, but no backwards incompatible changes.

### Caveats

There are a fair number of minor inconsistencies or surprises, many of which can be smoothed over by follow-up work.

#### Bound Expressions

Arrow-C++ has long had a distinction between "unbound expressions" (e.g. `a + b`) and "bound expressions" (e.g. `a:i32 + b:i32`).  A bound expression is an expression that has been bound to a schema of some kind.  Field references are resolved and the output type is known for every node of the AST.

Pyarrow has hidden this complexity and most pyarrow compute expressions that the user encounters will be unbound expressions.  Substrait is only capable (currently) of representing bound expressions.  As a result, in order to serialize expressions, the user will need to provide an input schema.  This can be an inconvenience for some workflows.  To resolve this, I would like to eventually add support for unbound expressions to Substrait (substrait-io/substrait#515).

Another minor annoyance of bound expressions is that an unbound pyarrow.compute.Expression object will not be equal to a bound pyarrow.compute.Expression object.  It would make testing easier if we had a `pyarrow.compute.Expression.equals` variant that did not examine bound fields.

#### Named field references

Pyarrow datasets users are used to working with named field references.  For example, one can set a filter `pc.equal(ds.field("x"), 7)`.  Substrait, since it requires everything to be bound, considers named references to be superfluous and does everything in terms of numeric indices into the base schema.  So the above expression, after round tripping, would become something like `pc.equal(ds.field(3), 7)` (assuming `"x"` is at index `3` in the schema used for serialization).  This is something that can be overcome in the future if Substrait adds support for unbound expressions.  Or, if that doesn't happen, it could still be implemented as a Substrait expression hint (this would allow named references to be used even if the user wants to work with bound expressions).
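
The loss of names can be illustrated with a small sketch that is independent of Arrow. `Schema`, `ResolveFieldIndex`, and `RoundTripFieldRef` are hypothetical names, and the schema is reduced to an ordered list of column names.

```cpp
#include <optional>
#include <string>
#include <vector>

// Hypothetical, simplified schema: an ordered list of column names.
using Schema = std::vector<std::string>;

// Resolves a named reference to its position in the schema. Only this
// position is kept once the expression is bound and serialized.
std::optional<int> ResolveFieldIndex(const Schema& schema,
                                     const std::string& name) {
  for (int i = 0; i < static_cast<int>(schema.size()); ++i) {
    if (schema[i] == name) return i;
  }
  return std::nullopt;  // The name is not in the schema, so binding fails.
}

// Renders the reference as it would read after a round trip: by index only.
std::string RoundTripFieldRef(const Schema& schema, const std::string& name) {
  std::optional<int> index = ResolveFieldIndex(schema, name);
  if (!index) return "<unresolved>";
  return "field(" + std::to_string(*index) + ")";
}
```

With a schema of `{"a", "b", "c", "x"}`, the reference `"x"` comes back as `field(3)`. The name can only be recovered by looking it up again in the same schema.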

#### UDFs

UDFs ARE supported by this PR.  This covers both "builtin arrow functions that do not exist in substrait (e.g. shift_left)" and "custom UDFs added with `register_scalar_function`".  By default, UDFs will not be allowed when converting to Substrait because the resulting message would not be portable (e.g. you can't expect an external system to know about your custom UDFs).  However, you can set the `allow_udfs` flag to True and these will be allowed.  The Substrait representation will have the URI `urn:arrow:substrait_simple_extension_function`.

**Options**: Although UDFs are allowed we do not yet support UDFs that take function options.  These are trickier to convert to Substrait (though it should be possible in the future if someone is motivated enough).
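
The decision described in this section can be sketched as a small classifier. This is only an approximation of the logic, not the code in the PR: `Outcome`, `Call`, and `ClassifyCall` are hypothetical names, and the flag is called `allow_udfs` to match the text above (the C++ diff in this commit reads it from `conversion_options.allow_arrow_extensions`).

```cpp
#include <set>
#include <string>

// Possible results of trying to convert one Arrow function call.
enum class Outcome {
  kMapped,              // A Substrait binding exists for the function.
  kArrowExtension,      // No binding; emitted as an Arrow-specific extension.
  kRejected,            // No binding and UDFs are not allowed.
  kOptionsUnsupported,  // No binding, UDFs allowed, but the call has options.
};

// Hypothetical, simplified view of a call: its name and whether it carries
// function options.
struct Call {
  std::string function_name;
  bool has_options;
};

Outcome ClassifyCall(const Call& call,
                     const std::set<std::string>& mapped_functions,
                     bool allow_udfs) {
  // Functions with a known binding are always converted.
  if (mapped_functions.count(call.function_name) > 0) return Outcome::kMapped;
  // Everything else counts as a UDF and is rejected by default.
  if (!allow_udfs) return Outcome::kRejected;
  // UDFs are allowed, but options cannot be encoded yet.
  if (call.has_options) return Outcome::kOptionsUnsupported;
  return Outcome::kArrowExtension;
}
```

In this sketch the options check only applies on the UDF path; calls with a known binding are passed through and their options are left to the binding itself.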

#### Rough Edges

There are a few corner cases:

 * The function `is_in` converts to Substrait's `SingularOrList`.  On conversion back to Arrow this becomes an or list.  In other words, the function `is_in(5, [1, 2, 5])` converts to `5 == 1 || 5 == 2 || 5 == 5`.  This is because Substrait's or list is more expressive and allows things like `5 == field_ref(0) || 5 == 7` which cannot be expressed as an `is_in` function.
 * Arrow functions can either be converted to Substrait or are considered UDFs.  However, there are a small number of functions which can "sometimes" be converted to Substrait depending on the function options.  At the moment I think this is only the `is_null` function.  The `is_null` function has an option `nan_is_null` which will allow you to consider `NaN` as a null value.  Substrait has no single function that evaluates both `NULL` and `NaN` as true.  In the meantime you can use `is_null || is_nan`.  In the future, should someone want to, they could add special logic to convert this case.
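
The `is_null || is_nan` workaround can be checked with a small sketch that models a nullable double as `std::optional<double>`. All names here are hypothetical, and `IsNan` is simplified to return `false` for a null input, which may differ from how a real engine propagates nulls.

```cpp
#include <cmath>
#include <optional>

// A nullable double: std::nullopt plays the role of NULL.
using NullableDouble = std::optional<double>;

// Plain is_null: true only for NULL.
bool IsNull(const NullableDouble& v) { return !v.has_value(); }

// Simplified is_nan: true only for a non-null NaN.
bool IsNan(const NullableDouble& v) {
  return v.has_value() && std::isnan(*v);
}

// Models is_null with the nan_is_null option.
bool IsNullWithOption(const NullableDouble& v, bool nan_is_null) {
  if (!v.has_value()) return true;
  return nan_is_null && std::isnan(*v);
}

// The suggested workaround, built from two functions that need no options.
bool IsNullOrNan(const NullableDouble& v) { return IsNull(v) || IsNan(v); }
```

Under these assumptions, `IsNullWithOption(v, true)` and `IsNullOrNan(v)` agree for null, NaN, and ordinary values.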
* Closes: apache#33985

Lead-authored-by: Weston Pace <[email protected]>
Co-authored-by: Joris Van den Bossche <[email protected]>
Signed-off-by: Weston Pace <[email protected]>
westonpace and jorisvandenbossche authored Aug 22, 2023
1 parent 9ddd8d5 commit 702e9ca
Showing 27 changed files with 1,363 additions and 158 deletions.
7 changes: 6 additions & 1 deletion cpp/cmake_modules/ThirdpartyToolchain.cmake
@@ -1785,7 +1785,12 @@ macro(build_substrait)

# Note: not all protos in Substrait actually matter to plan
# consumption. No need to build the ones we don't need.
set(SUBSTRAIT_PROTOS algebra extensions/extensions plan type)
set(SUBSTRAIT_PROTOS
algebra
extended_expression
extensions/extensions
plan
type)
set(ARROW_SUBSTRAIT_PROTOS extension_rels)
set(ARROW_SUBSTRAIT_PROTOS_DIR "${CMAKE_SOURCE_DIR}/proto")

10 changes: 10 additions & 0 deletions cpp/src/arrow/compute/cast.cc
@@ -223,6 +223,16 @@ CastOptions::CastOptions(bool safe)
allow_float_truncate(!safe),
allow_invalid_utf8(!safe) {}

bool CastOptions::is_safe() const {
return !allow_int_overflow && !allow_time_truncate && !allow_time_overflow &&
!allow_decimal_truncate && !allow_float_truncate && !allow_invalid_utf8;
}

bool CastOptions::is_unsafe() const {
return allow_int_overflow && allow_time_truncate && allow_time_overflow &&
allow_decimal_truncate && allow_float_truncate && allow_invalid_utf8;
}

constexpr char CastOptions::kTypeName[];

Result<Datum> Cast(const Datum& value, const CastOptions& options, ExecContext* ctx) {
9 changes: 9 additions & 0 deletions cpp/src/arrow/compute/cast.h
@@ -69,6 +69,15 @@ class ARROW_EXPORT CastOptions : public FunctionOptions {
// Indicate if conversions from Binary/FixedSizeBinary to string must
// validate the utf8 payload.
bool allow_invalid_utf8;

/// true if the safety options all match CastOptions::Safe
///
/// Note, if this returns false it does not mean is_unsafe will return true
bool is_safe() const;
/// true if the safety options all match CastOptions::Unsafe
///
/// Note, if this returns false it does not mean is_safe will return true
bool is_unsafe() const;
};

/// @}
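
The notes on `is_safe` and `is_unsafe` say that neither predicate is the negation of the other. A reduced sketch of the same idea, using a hypothetical `CastFlags` struct that keeps only the six flags tested in `cast.cc`, might look like this:

```cpp
// Hypothetical reduction of CastOptions to the six safety flags.
struct CastFlags {
  bool allow_int_overflow;
  bool allow_time_truncate;
  bool allow_time_overflow;
  bool allow_decimal_truncate;
  bool allow_float_truncate;
  bool allow_invalid_utf8;

  // Sets every flag to the same value: false models a fully safe cast,
  // true models a fully unsafe one.
  static CastFlags Uniform(bool allow) {
    return CastFlags{allow, allow, allow, allow, allow, allow};
  }

  // True only when no relaxation is enabled.
  bool is_safe() const {
    return !allow_int_overflow && !allow_time_truncate && !allow_time_overflow &&
           !allow_decimal_truncate && !allow_float_truncate && !allow_invalid_utf8;
  }

  // True only when every relaxation is enabled.
  bool is_unsafe() const {
    return allow_int_overflow && allow_time_truncate && allow_time_overflow &&
           allow_decimal_truncate && allow_float_truncate && allow_invalid_utf8;
  }
};
```

A value with only `allow_float_truncate` set is neither safe nor unsafe, which is why callers have to test the predicate they actually need instead of negating the other one.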
1 change: 1 addition & 0 deletions cpp/src/arrow/engine/CMakeLists.txt
@@ -21,6 +21,7 @@ arrow_install_all_headers("arrow/engine")

set(ARROW_SUBSTRAIT_SRCS
substrait/expression_internal.cc
substrait/extended_expression_internal.cc
substrait/extension_set.cc
substrait/extension_types.cc
substrait/options.cc
162 changes: 156 additions & 6 deletions cpp/src/arrow/engine/substrait/expression_internal.cc
@@ -41,6 +41,7 @@
#include "arrow/buffer.h"
#include "arrow/builder.h"
#include "arrow/compute/api_scalar.h"
#include "arrow/compute/cast.h"
#include "arrow/compute/expression.h"
#include "arrow/compute/expression_internal.h"
#include "arrow/engine/substrait/extension_set.h"
@@ -338,6 +339,22 @@ Result<compute::Expression> FromProto(const substrait::Expression& expr,
return compute::call("case_when", std::move(args));
}

case substrait::Expression::kSingularOrList: {
const auto& or_list = expr.singular_or_list();

ARROW_ASSIGN_OR_RAISE(compute::Expression value,
FromProto(or_list.value(), ext_set, conversion_options));

std::vector<compute::Expression> option_eqs;
for (const auto& option : or_list.options()) {
ARROW_ASSIGN_OR_RAISE(compute::Expression arrow_option,
FromProto(option, ext_set, conversion_options));
option_eqs.push_back(compute::call("equal", {value, arrow_option}));
}

return compute::or_(option_eqs);
}

case substrait::Expression::kScalarFunction: {
const auto& scalar_fn = expr.scalar_function();

@@ -1055,9 +1072,68 @@ Result<std::unique_ptr<substrait::Expression::ScalarFunction>> EncodeSubstraitCa
" arguments but no argument could be found at index ", i);
}
}

for (const auto& option : call.options()) {
substrait::FunctionOption* fn_option = scalar_fn->add_options();
fn_option->set_name(option.first);
for (const auto& opt_val : option.second) {
std::string* pref = fn_option->add_preference();
*pref = opt_val;
}
}

return std::move(scalar_fn);
}

Result<std::vector<std::unique_ptr<substrait::Expression>>> DatumToLiterals(
const Datum& datum, ExtensionSet* ext_set,
const ConversionOptions& conversion_options) {
std::vector<std::unique_ptr<substrait::Expression>> literals;

auto ScalarToLiteralExpr = [&](const std::shared_ptr<Scalar>& scalar)
-> Result<std::unique_ptr<substrait::Expression>> {
ARROW_ASSIGN_OR_RAISE(std::unique_ptr<substrait::Expression::Literal> literal,
ToProto(scalar, ext_set, conversion_options));
auto literal_expr = std::make_unique<substrait::Expression>();
literal_expr->set_allocated_literal(literal.release());
return literal_expr;
};

switch (datum.kind()) {
case Datum::Kind::SCALAR: {
ARROW_ASSIGN_OR_RAISE(auto literal_expr, ScalarToLiteralExpr(datum.scalar()));
literals.push_back(std::move(literal_expr));
break;
}
case Datum::Kind::ARRAY: {
std::shared_ptr<Array> values = datum.make_array();
for (int64_t i = 0; i < values->length(); i++) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar, values->GetScalar(i));
ARROW_ASSIGN_OR_RAISE(auto literal_expr, ScalarToLiteralExpr(scalar));
literals.push_back(std::move(literal_expr));
}
break;
}
case Datum::Kind::CHUNKED_ARRAY: {
std::shared_ptr<ChunkedArray> values = datum.chunked_array();
for (const auto& chunk : values->chunks()) {
for (int64_t i = 0; i < chunk->length(); i++) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar, chunk->GetScalar(i));
ARROW_ASSIGN_OR_RAISE(auto literal_expr, ScalarToLiteralExpr(scalar));
literals.push_back(std::move(literal_expr));
}
}
break;
}
case Datum::Kind::RECORD_BATCH:
case Datum::Kind::TABLE:
case Datum::Kind::NONE:
return Status::Invalid("Expected a literal or an array of literals, got ",
datum.ToString());
}
return literals;
}

Result<std::unique_ptr<substrait::Expression>> ToProto(
const compute::Expression& expr, ExtensionSet* ext_set,
const ConversionOptions& conversion_options) {
@@ -1164,15 +1240,89 @@ Result<std::unique_ptr<substrait::Expression>> ToProto(

out->set_allocated_if_then(if_then.release());
return std::move(out);
} else if (call->function_name == "cast") {
auto cast = std::make_unique<substrait::Expression::Cast>();

// Arrow's cast function does not have a "return null" option and so throw exception
// is the only behavior we can support.
cast->set_failure_behavior(
substrait::Expression::Cast::FAILURE_BEHAVIOR_THROW_EXCEPTION);

std::shared_ptr<compute::CastOptions> cast_options =
internal::checked_pointer_cast<compute::CastOptions>(call->options);
if (!cast_options->is_unsafe()) {
return Status::Invalid("Substrait is only capable of representing unsafe casts");
}

if (arguments.size() != 1) {
return Status::Invalid(
"A call to the cast function must have exactly one argument");
}

cast->set_allocated_input(arguments[0].release());

ARROW_ASSIGN_OR_RAISE(std::unique_ptr<substrait::Type> to_type,
ToProto(*cast_options->to_type.type, /*nullable=*/true, ext_set,
conversion_options));

cast->set_allocated_type(to_type.release());

out->set_allocated_cast(cast.release());
return std::move(out);
} else if (call->function_name == "is_in") {
auto or_list = std::make_unique<substrait::Expression::SingularOrList>();

if (arguments.size() != 1) {
return Status::Invalid(
"A call to the is_in function must have exactly one argument");
}

or_list->set_allocated_value(arguments[0].release());
std::shared_ptr<compute::SetLookupOptions> is_in_options =
internal::checked_pointer_cast<compute::SetLookupOptions>(call->options);

// TODO(GH-36420) Acero does not currently handle nulls correctly
ARROW_ASSIGN_OR_RAISE(
std::vector<std::unique_ptr<substrait::Expression>> options,
DatumToLiterals(is_in_options->value_set, ext_set, conversion_options));
for (auto& option : options) {
or_list->mutable_options()->AddAllocated(option.release());
}
out->set_allocated_singular_or_list(or_list.release());
return std::move(out);
}

// other expression types dive into extensions immediately
ARROW_ASSIGN_OR_RAISE(
ExtensionIdRegistry::ArrowToSubstraitCall converter,
ext_set->registry()->GetArrowToSubstraitCall(call->function_name));
ARROW_ASSIGN_OR_RAISE(SubstraitCall substrait_call, converter(*call));
ARROW_ASSIGN_OR_RAISE(std::unique_ptr<substrait::Expression::ScalarFunction> scalar_fn,
EncodeSubstraitCall(substrait_call, ext_set, conversion_options));
Result<ExtensionIdRegistry::ArrowToSubstraitCall> maybe_converter =
ext_set->registry()->GetArrowToSubstraitCall(call->function_name);

ExtensionIdRegistry::ArrowToSubstraitCall converter;
std::unique_ptr<substrait::Expression::ScalarFunction> scalar_fn;
if (maybe_converter.ok()) {
converter = *maybe_converter;
ARROW_ASSIGN_OR_RAISE(SubstraitCall substrait_call, converter(*call));
ARROW_ASSIGN_OR_RAISE(
scalar_fn, EncodeSubstraitCall(substrait_call, ext_set, conversion_options));
} else if (maybe_converter.status().IsNotImplemented() &&
conversion_options.allow_arrow_extensions) {
if (call->options) {
return Status::NotImplemented(
"The function ", call->function_name,
" has no Substrait mapping. Arrow extensions are enabled but the call "
"contains function options and there is no current mechanism to encode those.");
}
Id persistent_id = ext_set->RegisterPlanSpecificId(
{kArrowSimpleExtensionFunctionsUri, call->function_name});
SubstraitCall substrait_call(persistent_id, call->type.GetSharedPtr(),
/*nullable=*/true);
for (int i = 0; i < static_cast<int>(call->arguments.size()); i++) {
substrait_call.SetValueArg(i, call->arguments[i]);
}
ARROW_ASSIGN_OR_RAISE(
scalar_fn, EncodeSubstraitCall(substrait_call, ext_set, conversion_options));
} else {
return maybe_converter.status();
}
out->set_allocated_scalar_function(scalar_fn.release());
return std::move(out);
}