diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc index 16d2ace4ac0d7..6b4c05a3b1dd5 100644 --- a/cpp/src/arrow/engine/substrait/serde.cc +++ b/cpp/src/arrow/engine/substrait/serde.cc @@ -56,7 +56,7 @@ Status ParseFromBufferImpl(const Buffer& buf, const std::string& full_name, if (message->ParseFromZeroCopyStream(&buf_stream)) { return Status::OK(); } - return Status::IOError("ParseFromZeroCopyStream failed for ", full_name); + return Status::Invalid("ParseFromZeroCopyStream failed for ", full_name); } template diff --git a/docs/source/python/api/substrait.rst b/docs/source/python/api/substrait.rst index 1556be9dbd011..26c70216a8af2 100644 --- a/docs/source/python/api/substrait.rst +++ b/docs/source/python/api/substrait.rst @@ -43,6 +43,9 @@ compute expressions. BoundExpressions deserialize_expressions serialize_expressions + serialize_schema + deserialize_schema + SubstraitSchema Utility ------- diff --git a/docs/source/python/integration.rst b/docs/source/python/integration.rst index 1cafc3dbded37..95c912c187d52 100644 --- a/docs/source/python/integration.rst +++ b/docs/source/python/integration.rst @@ -34,6 +34,7 @@ This allows to easily integrate PyArrow with other languages and technologies. .. toctree:: :maxdepth: 2 + integration/substrait integration/python_r integration/python_java integration/extending diff --git a/docs/source/python/integration/substrait.rst b/docs/source/python/integration/substrait.rst new file mode 100644 index 0000000000000..eaa6151e4d32e --- /dev/null +++ b/docs/source/python/integration/substrait.rst @@ -0,0 +1,249 @@ +.. Licensed to the Apache Software Foundation (ASF) under one +.. or more contributor license agreements. See the NOTICE file +.. distributed with this work for additional information +.. regarding copyright ownership. The ASF licenses this file +.. to you under the Apache License, Version 2.0 (the +.. "License"); you may not use this file except in compliance +.. 
with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, +.. software distributed under the License is distributed on an +.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +.. KIND, either express or implied. See the License for the +.. specific language governing permissions and limitations +.. under the License. + +========= +Substrait +========= + +The ``arrow-substrait`` module implements support for the Substrait_ format, +enabling conversion to and from Arrow objects. + +The ``arrow-dataset`` module can execute Substrait_ plans via the +:doc:`Acero <../cpp/streaming_execution>` query engine. + +.. contents:: + +Working with Schemas +==================== + +Arrow schemas can be encoded and decoded using the :meth:`pyarrow.substrait.serialize_schema` and +:meth:`pyarrow.substrait.deserialize_schema` functions. + +.. code-block:: python + + import pyarrow as pa + import pyarrow.substrait as pa_substrait + + arrow_schema = pa.schema([ + pa.field("x", pa.int32()), + pa.field("y", pa.string()) + ]) + substrait_schema = pa_substrait.serialize_schema(arrow_schema) + +The schema marshalled as a Substrait ``NamedStruct`` is directly +available as ``substrait_schema.schema``:: + + >>> print(bytes(substrait_schema.schema)) + b'\n\x01x\n\x01y\x12\x0c\n\x04*\x02\x10\x01\n\x04b\x02\x10\x01' + +If custom Arrow types are used, the schema will require +extensions for those types to be actually usable. For this reason, +the schema is also available as an `Extended Expression`_ including +all the extension types:: + + >>> print(bytes(substrait_schema.expression)) + b'"\x14\n\x01x\n\x01y\x12\x0c\n\x04*\x02\x10\x01\n\x04b\x02\x10\x01:\x19\x10,*\x15Acero 17.0.0' + +If ``Substrait Python`` is installed, the schema can also be converted to +a ``substrait-python`` object:: + + >>> print(substrait_schema.to_pysubstrait()) + version { + minor_number: 44 + producer: "Acero
17.0.0" + } + base_schema { + names: "x" + names: "y" + struct { + types { + i32 { + nullability: NULLABILITY_NULLABLE + } + } + types { + string { + nullability: NULLABILITY_NULLABLE + } + } + } + } + +Working with Expressions +======================== + +Arrow compute expressions can be encoded and decoded using the +:meth:`pyarrow.substrait.serialize_expressions` and +:meth:`pyarrow.substrait.deserialize_expressions` functions. + +.. code-block:: python + + import pyarrow as pa + import pyarrow.compute as pc + import pyarrow.substrait as pa_substrait + + arrow_schema = pa.schema([ + pa.field("x", pa.int32()), + pa.field("y", pa.int32()) + ]) + + substrait_expr = pa_substrait.serialize_expressions( + exprs=[pc.field("x") + pc.field("y")], + names=["total"], + schema=arrow_schema + ) + +The result of encoding an expression to Substrait is the +protobuf ``ExtendedExpression`` message data itself:: + + >>> print(bytes(substrait_expr)) + b'\nZ\x12Xhttps://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml\x12\x07\x1a\x05\x1a\x03add\x1a>\n5\x1a3\x1a\x04*\x02\x10\x01"\n\x1a\x08\x12\x06\n\x02\x12\x00"\x00"\x0c\x1a\n\x12\x08\n\x04\x12\x02\x08\x01"\x00*\x11\n\x08overflow\x12\x05ERROR\x1a\x05total"\x14\n\x01x\n\x01y\x12\x0c\n\x04*\x02\x10\x01\n\x04*\x02\x10\x01:\x19\x10,*\x15Acero 17.0.0' + +So in case a ``Substrait Python`` object is required, the expression +has to be decoded from ``substrait-python`` itself:: + + >>> import substrait + >>> pysubstrait_expr = substrait.proto.ExtendedExpression.FromString(substrait_expr) + >>> print(pysubstrait_expr) + version { + minor_number: 44 + producer: "Acero 17.0.0" + } + extension_uris { + uri: "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" + } + extensions { + extension_function { + name: "add" + } + } + referred_expr { + expression { + scalar_function { + arguments { + value { + selection { + direct_reference { + struct_field { + } + } +
root_reference { + } + } + } + } + arguments { + value { + selection { + direct_reference { + struct_field { + field: 1 + } + } + root_reference { + } + } + } + } + options { + name: "overflow" + preference: "ERROR" + } + output_type { + i32 { + nullability: NULLABILITY_NULLABLE + } + } + } + } + output_names: "total" + } + base_schema { + names: "x" + names: "y" + struct { + types { + i32 { + nullability: NULLABILITY_NULLABLE + } + } + types { + i32 { + nullability: NULLABILITY_NULLABLE + } + } + } + } + +Executing Queries Using Substrait Extended Expressions +====================================================== + +Dataset supports executing queries using Substrait's `Extended Expression`_, +the expressions can be passed to the dataset scanner in the form of +:class:`pyarrow.substrait.BoundExpressions` + +.. code-block:: python + + import pyarrow.dataset as ds + import pyarrow.substrait as pa_substrait + + # Use substrait-python to create the queries + from substrait import proto + + dataset = ds.dataset("./data/index-0.parquet") + substrait_schema = pa_substrait.serialize_schema(dataset.schema).to_pysubstrait() + + # SELECT project_name FROM dataset WHERE project_name = 'pyarrow' + + projection = proto.ExtendedExpression(referred_expr=[ + {"expression": {"selection": {"direct_reference": {"struct_field": {"field": 0}}}}, + "output_names": ["project_name"]} + ]) + projection.MergeFrom(substrait_schema) + + filtering = proto.ExtendedExpression( + extension_uris=[{"extension_uri_anchor": 99, "uri": "/functions_comparison.yaml"}], + extensions=[{"extension_function": {"extension_uri_reference": 99, "function_anchor": 199, "name": "equal:any1_any1"}}], + referred_expr=[ + {"expression": {"scalar_function": {"function_reference": 199, "arguments": [ + {"value": {"selection": {"direct_reference": {"struct_field": {"field": 0}}}}}, + {"value": {"literal": {"string": "pyarrow"}}} + ], "output_type": {"bool": {"nullability": False}}}}} + ] + ) + 
filtering.MergeFrom(substrait_schema) + + results = dataset.scanner( + columns=pa_substrait.BoundExpressions.from_substrait(projection), + filter=pa_substrait.BoundExpressions.from_substrait(filtering) + ).head(5) + + +.. code-block:: text + + project_name + 0 pyarrow + 1 pyarrow + 2 pyarrow + 3 pyarrow + 4 pyarrow + + +.. _`Substrait`: https://substrait.io/ +.. _`Substrait Python`: https://github.com/substrait-io/substrait-python +.. _`Acero`: https://arrow.apache.org/docs/cpp/streaming_execution.html +.. _`Extended Expression`: https://github.com/substrait-io/substrait/blob/main/site/docs/expressions/extended_expression.md diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index d39120934d5fd..658f6b6cac4b5 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2441,7 +2441,7 @@ cdef class Expression(_Weakrefable): ) @staticmethod - def from_substrait(object buffer not None): + def from_substrait(object message not None): """ Deserialize an expression from Substrait @@ -2453,7 +2453,7 @@ cdef class Expression(_Weakrefable): Parameters ---------- - buffer : bytes or Buffer + message : bytes or Buffer or a protobuf Message The Substrait message to deserialize Returns @@ -2461,7 +2461,7 @@ cdef class Expression(_Weakrefable): Expression The deserialized expression """ - expressions = _pas().deserialize_expressions(buffer).expressions + expressions = _pas().BoundExpressions.from_substrait(message).expressions if len(expressions) == 0: raise ValueError("Substrait message did not contain any expressions") if len(expressions) > 1: diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx index 3583a3213ccbc..b67c583c32d67 100644 --- a/python/pyarrow/_dataset.pyx +++ b/python/pyarrow/_dataset.pyx @@ -39,6 +39,11 @@ from pyarrow.util import _is_iterable, _is_path_like, _stringify_path from pyarrow._json cimport ParseOptions as JsonParseOptions from pyarrow._json cimport ReadOptions as JsonReadOptions +try: + import
pyarrow.substrait as pa_substrait +except ImportError: + pa_substrait = None + _DEFAULT_BATCH_SIZE = 2**17 _DEFAULT_BATCH_READAHEAD = 16 @@ -272,6 +277,13 @@ cdef class Dataset(_Weakrefable): # at the moment only support filter requested_filter = options.get("filter") + if pa_substrait and isinstance(requested_filter, pa_substrait.BoundExpressions): + expressions = list(requested_filter.expressions.values()) + if len(expressions) != 1: + raise ValueError( + "Only one BoundExpressions with a single expression are supported") + new_options["filter"] = requested_filter = expressions[0] + current_filter = self._scan_options.get("filter") if requested_filter is not None and current_filter is not None: new_options["filter"] = current_filter & requested_filter @@ -282,7 +294,7 @@ cdef class Dataset(_Weakrefable): def scanner(self, object columns=None, - Expression filter=None, + object filter=None, int batch_size=_DEFAULT_BATCH_SIZE, int batch_readahead=_DEFAULT_BATCH_READAHEAD, int fragment_readahead=_DEFAULT_FRAGMENT_READAHEAD, @@ -3410,6 +3422,9 @@ cdef void _populate_builder(const shared_ptr[CScannerBuilder]& ptr, filter, pyarrow_wrap_schema(builder.schema())))) if columns is not None: + if pa_substrait and isinstance(columns, pa_substrait.BoundExpressions): + columns = columns.expressions + if isinstance(columns, dict): for expr in columns.values(): if not isinstance(expr, Expression): @@ -3490,7 +3505,7 @@ cdef class Scanner(_Weakrefable): @staticmethod def from_dataset(Dataset dataset not None, *, object columns=None, - Expression filter=None, + object filter=None, int batch_size=_DEFAULT_BATCH_SIZE, int batch_readahead=_DEFAULT_BATCH_READAHEAD, int fragment_readahead=_DEFAULT_FRAGMENT_READAHEAD, diff --git a/python/pyarrow/_substrait.pyx b/python/pyarrow/_substrait.pyx index 067cb5f91681b..d9359c8e77d00 100644 --- a/python/pyarrow/_substrait.pyx +++ b/python/pyarrow/_substrait.pyx @@ -26,6 +26,13 @@ from pyarrow.lib cimport * from pyarrow.includes.libarrow cimport 
* from pyarrow.includes.libarrow_substrait cimport * +try: + import substrait as py_substrait +except ImportError: + py_substrait = None +else: + import substrait.proto # no-cython-lint + # TODO GH-37235: Fix exception handling cdef CDeclaration _create_named_table_provider( @@ -133,7 +140,7 @@ def run_query(plan, *, table_provider=None, use_threads=True): c_bool c_use_threads c_use_threads = use_threads - if isinstance(plan, bytes): + if isinstance(plan, (bytes, memoryview)): c_buf_plan = pyarrow_unwrap_buffer(py_buffer(plan)) elif isinstance(plan, Buffer): c_buf_plan = pyarrow_unwrap_buffer(plan) @@ -187,6 +194,105 @@ def _parse_json_plan(plan): return pyarrow_wrap_buffer(c_buf_plan) +class SubstraitSchema: + """A Schema encoded for Substrait usage. + + The SubstraitSchema contains a schema represented + both as a substrait ``NamedStruct`` and as an + ``ExtendedExpression``. + + The ``ExtendedExpression`` is available for cases where types + used by the schema require extensions to decode them. + In such case the schema will be the ``base_schema`` of the + ``ExtendedExpression`` and all extensions will be provided. + """ + + def __init__(self, schema, expression): + self.schema = schema + self.expression = expression + + def to_pysubstrait(self): + """Convert the schema to a substrait-python ExtendedExpression object.""" + if py_substrait is None: + raise ImportError("The 'substrait' package is required.") + return py_substrait.proto.ExtendedExpression.FromString(self.expression) + + +def serialize_schema(schema): + """ + Serialize a schema into a SubstraitSchema object. + + Parameters + ---------- + schema : Schema + The schema to serialize + + Returns + ------- + SubstraitSchema + The schema stored in a SubstraitSchema object. 
+ """ + return SubstraitSchema( + schema=_serialize_namedstruct_schema(schema), + expression=serialize_expressions([], [], schema, allow_arrow_extensions=True) + ) + + +def _serialize_namedstruct_schema(schema): + cdef: + CResult[shared_ptr[CBuffer]] c_res_buffer + shared_ptr[CBuffer] c_buffer + CConversionOptions c_conversion_options + CExtensionSet c_extensions + + with nogil: + c_res_buffer = SerializeSchema(deref(( schema).sp_schema), &c_extensions, c_conversion_options) + c_buffer = GetResultValue(c_res_buffer) + + return memoryview(pyarrow_wrap_buffer(c_buffer)) + + +def deserialize_schema(buf): + """ + Deserialize a ``NamedStruct`` Substrait message + or a SubstraitSchema object into an Arrow Schema object + + Parameters + ---------- + buf : Buffer or bytes or SubstraitSchema + The message to deserialize + + Returns + ------- + Schema + The deserialized schema + """ + cdef: + shared_ptr[CBuffer] c_buffer + CResult[shared_ptr[CSchema]] c_res_schema + shared_ptr[CSchema] c_schema + CConversionOptions c_conversion_options + CExtensionSet c_extensions + + if isinstance(buf, SubstraitSchema): + return deserialize_expressions(buf.expression).schema + + if isinstance(buf, (bytes, memoryview)): + c_buffer = pyarrow_unwrap_buffer(py_buffer(buf)) + elif isinstance(buf, Buffer): + c_buffer = pyarrow_unwrap_buffer(buf) + else: + raise TypeError( + f"Expected 'pyarrow.Buffer' or bytes, got '{type(buf)}'") + + with nogil: + c_res_schema = DeserializeSchema( + deref(c_buffer), c_extensions, c_conversion_options) + c_schema = GetResultValue(c_res_schema) + + return pyarrow_wrap_schema(c_schema) + + def serialize_expressions(exprs, names, schema, *, allow_arrow_extensions=False): """ Serialize a collection of expressions into Substrait @@ -245,7 +351,7 @@ def serialize_expressions(exprs, names, schema, *, allow_arrow_extensions=False) with nogil: c_res_buffer = SerializeExpressions(c_bound_exprs, c_conversion_options) c_buffer = GetResultValue(c_res_buffer) - return 
pyarrow_wrap_buffer(c_buffer) + return memoryview(pyarrow_wrap_buffer(c_buffer)) cdef class BoundExpressions(_Weakrefable): @@ -290,6 +396,32 @@ cdef class BoundExpressions(_Weakrefable): self.init(bound_expressions) return self + @classmethod + def from_substrait(cls, message): + """ + Convert a Substrait message into a BoundExpressions object + + Parameters + ---------- + message : Buffer or bytes or protobuf Message + The message to convert to a BoundExpressions object + + Returns + ------- + BoundExpressions + The converted expressions, their names, and the bound schema + """ + if isinstance(message, (bytes, memoryview)): + return deserialize_expressions(message) + elif isinstance(message, Buffer): + return deserialize_expressions(message) + else: + try: + return deserialize_expressions(message.SerializeToString()) + except AttributeError: + raise TypeError( + f"Expected 'pyarrow.Buffer' or bytes or protobuf Message, got '{type(message)}'") + def deserialize_expressions(buf): """ @@ -310,7 +442,7 @@ def deserialize_expressions(buf): CResult[CBoundExpressions] c_res_bound_exprs CBoundExpressions c_bound_exprs - if isinstance(buf, bytes): + if isinstance(buf, (bytes, memoryview)): c_buffer = pyarrow_unwrap_buffer(py_buffer(buf)) elif isinstance(buf, Buffer): c_buffer = pyarrow_unwrap_buffer(buf) diff --git a/python/pyarrow/includes/common.pxd b/python/pyarrow/includes/common.pxd index 044dd0333f323..9297436c1cf8c 100644 --- a/python/pyarrow/includes/common.pxd +++ b/python/pyarrow/includes/common.pxd @@ -173,3 +173,12 @@ cdef inline object PyObject_to_object(PyObject* o): cdef object result = o cpython.Py_DECREF(result) return result + + +cdef extern from "" namespace "std" nogil: + cdef cppclass cpp_string_view "std::string_view": + string_view() + string_view(const char*) + size_t size() + bint empty() + const char* data() diff --git a/python/pyarrow/includes/libarrow_substrait.pxd b/python/pyarrow/includes/libarrow_substrait.pxd index 
c41f4c05d3a77..865568e2ba6f1 100644 --- a/python/pyarrow/includes/libarrow_substrait.pxd +++ b/python/pyarrow/includes/libarrow_substrait.pxd @@ -45,6 +45,20 @@ cdef extern from "arrow/engine/substrait/options.h" namespace "arrow::engine" no cdef extern from "arrow/engine/substrait/extension_set.h" \ namespace "arrow::engine" nogil: + cdef struct CSubstraitId "arrow::engine::Id": + cpp_string_view uri + cpp_string_view name + + cdef struct CExtensionSetTypeRecord "arrow::engine::ExtensionSet::TypeRecord": + CSubstraitId id + shared_ptr[CDataType] type + + cdef cppclass CExtensionSet "arrow::engine::ExtensionSet": + CExtensionSet() + unordered_map[uint32_t, cpp_string_view]& uris() + CResult[uint32_t] EncodeType(const CDataType&) + CResult[CExtensionSetTypeRecord] DecodeType(uint32_t) + cdef cppclass ExtensionIdRegistry: std_vector[c_string] GetSupportedSubstraitFunctions() @@ -68,6 +82,15 @@ cdef extern from "arrow/engine/substrait/serde.h" namespace "arrow::engine" nogi CResult[CBoundExpressions] DeserializeExpressions( const CBuffer& serialized_expressions) + CResult[shared_ptr[CBuffer]] SerializeSchema( + const CSchema &schema, CExtensionSet* extension_set, + const CConversionOptions& conversion_options) + + CResult[shared_ptr[CSchema]] DeserializeSchema( + const CBuffer& serialized_schema, const CExtensionSet& extension_set, + const CConversionOptions& conversion_options) + + cdef extern from "arrow/engine/substrait/util.h" namespace "arrow::engine" nogil: CResult[shared_ptr[CRecordBatchReader]] ExecuteSerializedPlan( const CBuffer& substrait_buffer, const ExtensionIdRegistry* registry, diff --git a/python/pyarrow/substrait.py b/python/pyarrow/substrait.py index a2b217f4936c5..db2c3a96a1955 100644 --- a/python/pyarrow/substrait.py +++ b/python/pyarrow/substrait.py @@ -21,7 +21,10 @@ get_supported_functions, run_query, deserialize_expressions, - serialize_expressions + serialize_expressions, + deserialize_schema, + serialize_schema, + SubstraitSchema ) except 
ImportError as exc: raise ImportError( diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py index ab181590914d3..ce03da22962c3 100644 --- a/python/pyarrow/tests/test_dataset.py +++ b/python/pyarrow/tests/test_dataset.py @@ -5663,3 +5663,37 @@ def test_make_write_options_error(): msg = "make_write_options\\(\\) takes exactly 0 positional arguments" with pytest.raises(TypeError, match=msg): pformat.make_write_options(43) + + +def test_scanner_from_substrait(dataset): + try: + import pyarrow.substrait as ps + except ImportError: + pytest.skip("substrait NOT enabled") + + # SELECT str WHERE i64 = 4 + projection = (b'\nS\x08\x0c\x12Ohttps://github.com/apache/arrow/blob/main/format' + b'/substrait/extension_types.yaml\x12\t\n\x07\x08\x0c\x1a\x03u64' + b'\x12\x0b\n\t\x08\x0c\x10\x01\x1a\x03u32\x1a\x0f\n\x08\x12\x06' + b'\n\x04\x12\x02\x08\x02\x1a\x03str"i\n\x03i64\n\x03f64\n\x03str' + b'\n\x05const\n\x06struct\n\x01a\n\x01b\n\x05group\n\x03key' + b'\x127\n\x04:\x02\x10\x01\n\x04Z\x02\x10\x01\n\x04b\x02\x10' + b'\x01\n\x04:\x02\x10\x01\n\x11\xca\x01\x0e\n\x04:\x02\x10\x01' + b'\n\x04b\x02\x10\x01\x18\x01\n\x04*\x02\x10\x01\n\x04b\x02\x10\x01') + filtering = (b'\n\x1e\x08\x06\x12\x1a/functions_comparison.yaml\nS\x08\x0c\x12' + b'Ohttps://github.com/apache/arrow/blob/main/format' + b'/substrait/extension_types.yaml\x12\x18\x1a\x16\x08\x06\x10\xc5' + b'\x01\x1a\x0fequal:any1_any1\x12\t\n\x07\x08\x0c\x1a\x03u64\x12' + b'\x0b\n\t\x08\x0c\x10\x01\x1a\x03u32\x1a\x1f\n\x1d\x1a\x1b\x08' + b'\xc5\x01\x1a\x04\n\x02\x10\x02"\x08\x1a\x06\x12\x04\n\x02\x12\x00' + b'"\x06\x1a\x04\n\x02(\x04"i\n\x03i64\n\x03f64\n\x03str\n\x05const' + b'\n\x06struct\n\x01a\n\x01b\n\x05group\n\x03key\x127\n\x04:\x02' + b'\x10\x01\n\x04Z\x02\x10\x01\n\x04b\x02\x10\x01\n\x04:\x02\x10' + b'\x01\n\x11\xca\x01\x0e\n\x04:\x02\x10\x01\n\x04b\x02\x10\x01' + b'\x18\x01\n\x04*\x02\x10\x01\n\x04b\x02\x10\x01') + + result = dataset.scanner( + 
columns=ps.BoundExpressions.from_substrait(projection), + filter=ps.BoundExpressions.from_substrait(filtering) + ).to_table() + assert result.to_pydict() == {'str': ['4', '4']} diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py index 01d468cd9e9cc..fcd1c8d48c5fc 100644 --- a/python/pyarrow/tests/test_substrait.py +++ b/python/pyarrow/tests/test_substrait.py @@ -105,7 +105,7 @@ def test_run_query_input_types(tmpdir, query): # Otherwise error for invalid query msg = "ParseFromZeroCopyStream failed for substrait.Plan" - with pytest.raises(OSError, match=msg): + with pytest.raises(ArrowInvalid, match=msg): substrait.run_query(query) @@ -1077,3 +1077,44 @@ def test_serializing_udfs(): assert schema == returned.schema assert len(returned.expressions) == 1 assert str(returned.expressions["expr"]) == str(exprs[0]) + + +def test_serializing_schema(): + substrait_schema = b'\n\x01x\n\x01y\x12\x0c\n\x04*\x02\x10\x01\n\x04b\x02\x10\x01' + expected_schema = pa.schema([ + pa.field("x", pa.int32()), + pa.field("y", pa.string()) + ]) + returned = pa.substrait.deserialize_schema(substrait_schema) + assert expected_schema == returned + + arrow_substrait_schema = pa.substrait.serialize_schema(returned) + assert arrow_substrait_schema.schema == substrait_schema + + returned = pa.substrait.deserialize_schema(arrow_substrait_schema) + assert expected_schema == returned + + returned = pa.substrait.deserialize_schema(arrow_substrait_schema.schema) + assert expected_schema == returned + + returned = pa.substrait.deserialize_expressions(arrow_substrait_schema.expression) + assert returned.schema == expected_schema + + +def test_bound_expression_from_Message(): + class FakeMessage: + def __init__(self, expr): + self.expr = expr + + def SerializeToString(self): + return self.expr + + # SELECT project_release, project_version + message = (b'\x1a\x1b\n\x08\x12\x06\n\x04\x12\x02\x08\x01\x1a\x0fproject_release' + 
b'\x1a\x19\n\x06\x12\x04\n\x02\x12\x00\x1a\x0fproject_version' + b'"0\n\x0fproject_version\n\x0fproject_release' + b'\x12\x0c\n\x04:\x02\x10\x01\n\x04b\x02\x10\x01') + exprs = pa.substrait.BoundExpressions.from_substrait(FakeMessage(message)) + assert len(exprs.expressions) == 2 + assert 'project_release' in exprs.expressions + assert 'project_version' in exprs.expressions