Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add ibis-substrait join operator example #2

Merged
merged 4 commits into from
Jun 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
build/
.cache/
__pycache__/
venv
*.venv
10 changes: 10 additions & 0 deletions python/test/ibis/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from __future__ import annotations

import pytest

from ibis_substrait.compiler.core import SubstraitCompiler


@pytest.fixture
def compiler():
    """Build a new ``SubstraitCompiler`` for the requesting test."""
    substrait_compiler = SubstraitCompiler()
    return substrait_compiler
111 changes: 111 additions & 0 deletions python/test/ibis/test_ibis_substrait.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from collections import OrderedDict

import ibis
import pytest
from google.protobuf import json_format

from ibis_substrait.compiler.translate import translate
from ibis_substrait.proto.substrait import type_pb2 as stt
from ibis_substrait.proto.substrait.algebra_pb2 import Expression, Rel

NULLABILITY_NULLABLE = stt.Type.Nullability.NULLABILITY_NULLABLE
NULLABILITY_REQUIRED = stt.Type.Nullability.NULLABILITY_REQUIRED

# Test adapted from the upstream ibis-substrait unit test(s):
# https://github.com/ibis-project/ibis-substrait/blob/main/ibis_substrait/tests/compiler/test_compiler.py

@pytest.fixture
def t0():
    """Return an ibis table expression built from a four-column schema."""
    # Schema is given as ordered (column name, ibis dtype string) pairs.
    schema = [
        ("full_name", "string"),
        ("age", "int64"),
        ("ts", "timestamp('UTC')"),
        ("delta", "interval"),
    ]
    return ibis.table(schema)

@pytest.fixture
def t1():
    """Return a second ibis table expression with the same schema as ``t0``."""
    # Column order matters: field references in the compiled plan are ordinal.
    columns = [
        ("full_name", "string"),
        ("age", "int64"),
        ("ts", "timestamp('UTC')"),
        ("delta", "interval"),
    ]
    table = ibis.table(columns)
    return table


def to_dict(message):
    """Convert a Protobuf message into a plain Python dictionary.

    The conversion is delegated to ``json_format.MessageToDict``; nothing is
    printed, the resulting dictionary is returned to the caller.
    """
    as_dict = json_format.MessageToDict(message)
    return as_dict


def test_join(t0, t1, compiler):
    """A walkthrough of a join expression in Substrait.

    Builds a left join of ``t0`` and ``t1`` on their ``age`` columns, translates
    it with ``compiler`` and then inspects the resulting message layer by layer:
    the top-level projection, the join beneath it, the join condition and finally
    the two join inputs.
    """
    expr = t0.left_join(t1, t0.age == t1.age)
    result = translate(expr, compiler)

    # This plan is a "volcano" style plan meant for bottom-up execution.
    # As a result, the top-level operation in the relation is the final projection.
    # https://github.com/substrait-io/substrait/blob/main/proto/substrait/algebra.proto
    #
    # TODO(bsarden): Find out which logical plan optimizers are used.
    assert result.WhichOneof("rel_type") == "project"
    input_: Rel = result.project.input
    assert input_.WhichOneof("rel_type") == "join"

    # `input_` is the generic `Rel` wrapper; `.join` is the join-specific payload
    # selected by its `rel_type` oneof (checked just above).
    join = input_.join

    # The `Expression` message type describes functions / arguments to run on
    # the given operator. Each `Expression` defines a Relational Expression Type
    # (`rex_type`), which maps to a broad categorization of the underlying
    # function category (e.g., `ScalarFunction`, `WindowFunction`, `IfThen`,
    # etc.).
    join_expr: Expression = join.expression
    assert join_expr.WhichOneof("rex_type") == "scalar_function"
    scalar_func = join_expr.scalar_function

    # Each `rex_type` function breaks down into its own `protobuf.Message` type,
    # but we will study the `ScalarFunction` as an example, since they are all pretty
    # similar. Each `ScalarFunction` maps to:
    #   1. A `function_reference`: which represents a "pointer" to a uniquely identifiable
    #      operator ID that has been [registered][0] with the corresponding `Plan` type.
    #      These functions are serialized / registered with the `Plan` object through the
    #      definition of an `Extension` (see link above) and are often referred to as `*_anchor`
    #      in the specification.
    #   2. A list of `FunctionArguments`: these include input type specifications, and can
    #      also be the result of another `Expression`.
    #   3. A `Type` definition for the output: Currently there is only one `output_type` per
    #      operator that is supported.
    #
    # [0]: https://github.com/ibis-project/ibis-substrait/blob/main/ibis_substrait/compiler/core.py#L53-L80
    assert len(scalar_func.args) == 2, "a join should always have two operands"

    # We verify that the inputs are `FieldReference` expressions, because we are grabbing
    # the `age` column on both tables. To check this, we look at the selection ordinal of
    # each operand.
    arg0, arg1 = scalar_func.args
    assert arg0.WhichOneof("rex_type") == "selection"
    assert arg1.WhichOneof("rex_type") == "selection"
    sel0, sel1 = arg0.selection, arg1.selection
    sel0_col_id = sel0.direct_reference.struct_field.field
    sel1_col_id = sel1.direct_reference.struct_field.field
    # NOTE(review): these only assert that each ordinal is non-zero, not that the two
    # ordinals match. Whether they should be equal depends on how the compiler offsets
    # right-hand field references into the join's combined schema -- confirm before
    # tightening this to an equality check.
    assert sel0_col_id
    assert sel1_col_id

    # The left / right sides of the join equate to Table Scan operations
    #
    # Which contains a struct about field names, dtypes, and whether a field
    # is nullable.
    left, right = join.left, join.right
    assert left.WhichOneof("rel_type") == "read"
    assert right.WhichOneof("rel_type") == "read"

    # To dump the serialized plan for offline inspection:
    # with open("test_join.pb", "wb") as f:
    #     f.write(result.SerializeToString())
    js = to_dict(result)
    assert js
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@ python-dateutil==2.8.2
pytz==2021.3
six==1.16.0
tomli==2.0.1
ibis-framework==3.0.2
ibis-substrait==2.7.0