Skip to content

Commit

Permalink
fix(field_index): get field index w.r.t. pre-join table schemata (#1078)
Browse files Browse the repository at this point in the history
* fix(field_index): get field index w.r.t. pre-join table schemata

JoinChains provide the schema of the joined table (which is great for Ibis)
but for substrait we need the Field index computed with respect to
the original table schemata.  In practice, this means rolling through
the tables in a JoinChain and computing the field index _without_
removing the join key

Given
Table 1
  a: int
  b: int

Table 2
  a: int
  c: int

JoinChain[r0]
 JoinLink[inner, r1]
   r0.a == r1.a
 values:
   a: r0.a
   b: r0.b
   c: r1.c

If we ask for the field index of `c`, the JoinChain schema will give
us an index of `2`, but it should be `3` because

 0: table 1 a
 1: table 1 b
 2: table 2 a
 3: table 2 c

So now we pull out the correct JoinReference object and use that to
index into the tables in the JoinChain and offset by the length of the
schema of those preceding tables.

* test(snapshots): update snapshots for fixed join indexing

* fix: apply suggestions from review

Co-authored-by: Phillip Cloud <[email protected]>

---------

Co-authored-by: Phillip Cloud <[email protected]>
  • Loading branch information
gforsyth and cpcloud authored Jul 29, 2024
1 parent 3a8b0db commit 7095b19
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 19 deletions.
54 changes: 52 additions & 2 deletions ibis_substrait/compiler/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,8 +674,58 @@ def table_column(
else:
base_offset = 0

schema = op.rel.schema
relative_offset = schema._name_locs[op.name]
if isinstance(op.rel, ops.JoinChain):
# JoinChains provide the schema of the joined table (which is great for Ibis)
# but for substrait we need the Field index computed with respect to
# the original table schemas. In practice, this means rolling through
# the tables in a JoinChain and computing the field index _without_
# removing the join key
#
# Given
# Table 1
# a: int
# b: int
#
# Table 2
# a: int
# c: int
#
# JoinChain[r0]
# JoinLink[inner, r1]
# r0.a == r1.a
# values:
# a: r0.a
# b: r0.b
# c: r1.c
#
# If we ask for the field index of `c`, the JoinChain schema will give
# us an index of `2`, but it should be `3` because
#
# 0: table 1 a
# 1: table 1 b
# 2: table 2 a
# 3: table 2 c
#

# List of join reference objects
join_tables = op.rel.tables
# Join reference containing the field we care about
field_table = op.rel.values.get(op.name).rel
# Index of that join reference in the list of join references
field_table_index = join_tables.index(field_table)

# Offset by the number of columns in each preceding table
join_table_offset = sum(
len(join_tables[i].schema) for i in range(field_table_index)
)
# Then add on the index of the column in the table
# Also in the event of renaming due to join collisions, resolve
# the renamed column to the original name so we can pull it off the parent table
orig_name = op.rel.values[op.name].name
relative_offset = join_table_offset + field_table.schema._name_locs[orig_name]
else:
schema = op.rel.schema
relative_offset = schema._name_locs[op.name]
absolute_offset = base_offset + relative_offset
return stalg.Expression(
selection=stalg.Expression.FieldReference(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -732,7 +732,7 @@
"selection": {
"directReference": {
"structField": {
"field": 1
"field": 45
}
},
"rootReference": {}
Expand Down Expand Up @@ -764,7 +764,9 @@
"value": {
"selection": {
"directReference": {
"structField": {}
"structField": {
"field": 41
}
},
"rootReference": {}
}
Expand Down Expand Up @@ -810,7 +812,7 @@
"selection": {
"directReference": {
"structField": {
"field": 1
"field": 45
}
},
"rootReference": {}
Expand Down Expand Up @@ -842,7 +844,9 @@
"value": {
"selection": {
"directReference": {
"structField": {}
"structField": {
"field": 41
}
},
"rootReference": {}
}
Expand Down Expand Up @@ -882,7 +886,7 @@
"selection": {
"directReference": {
"structField": {
"field": 2
"field": 17
}
},
"rootReference": {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,7 @@
"selection": {
"directReference": {
"structField": {
"field": 3
"field": 54
}
},
"rootReference": {}
Expand Down Expand Up @@ -990,7 +990,7 @@
"selection": {
"directReference": {
"structField": {
"field": 4
"field": 36
}
},
"rootReference": {}
Expand Down Expand Up @@ -1034,7 +1034,7 @@
"selection": {
"directReference": {
"structField": {
"field": 5
"field": 4
}
},
"rootReference": {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,7 @@
"selection": {
"directReference": {
"structField": {
"field": 3
"field": 29
}
},
"rootReference": {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@
"selection": {
"directReference": {
"structField": {
"field": 1
"field": 25
}
},
"rootReference": {}
Expand Down Expand Up @@ -588,7 +588,7 @@
"selection": {
"directReference": {
"structField": {
"field": 2
"field": 19
}
},
"rootReference": {}
Expand All @@ -600,7 +600,7 @@
"selection": {
"directReference": {
"structField": {
"field": 3
"field": 18
}
},
"rootReference": {}
Expand Down Expand Up @@ -630,7 +630,7 @@
"selection": {
"directReference": {
"structField": {
"field": 6
"field": 33
}
},
"rootReference": {}
Expand Down Expand Up @@ -817,7 +817,9 @@
"value": {
"selection": {
"directReference": {
"structField": {}
"structField": {
"field": 7
}
},
"rootReference": {}
}
Expand Down Expand Up @@ -854,7 +856,7 @@
"selection": {
"directReference": {
"structField": {
"field": 4
"field": 9
}
},
"rootReference": {}
Expand Down Expand Up @@ -1063,7 +1065,9 @@
"value": {
"selection": {
"directReference": {
"structField": {}
"structField": {
"field": 7
}
},
"rootReference": {}
}
Expand Down Expand Up @@ -1100,7 +1104,7 @@
"selection": {
"directReference": {
"structField": {
"field": 4
"field": 9
}
},
"rootReference": {}
Expand Down
52 changes: 52 additions & 0 deletions ibis_substrait/tests/compiler/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,3 +532,55 @@ def test_groupby_multiple_keys(compiler):
# There should be one grouping with two separate expressions inside
assert len(plan.aggregate.groupings) == 1
assert len(plan.aggregate.groupings[0].grouping_expressions) == 2


def test_join_chain_indexing_in_group_by(compiler):
t1 = ibis.table([("a", int), ("b", int)], name="t1")
t2 = ibis.table([("a", int), ("c", int)], name="t2")
t3 = ibis.table([("a", int), ("d", int)], name="t3")
t4 = ibis.table([("a", int), ("c", int)], name="t4")

join_chain = t1.join(t2, "a").join(t3, "a").join(t4, "a")
# Indexing for chained join
# t1: a: 0
# t1: b: 1
# t2: a: 2
# t2: c: 3
# t3: a: 4
# t3: d: 5
# t4: a: 6
# t4: c: 7

expr = join_chain.group_by("d").count().select("d")
plan = compiler.compile(expr)
# Check that the field index for the group_by key is correctly indexed
assert (
plan.relations[0]
.root.input.project.input.aggregate.groupings[0]
.grouping_expressions[0]
.selection.direct_reference.struct_field.field
== 5
)

expr = join_chain.group_by("c").count().select("c")
plan = compiler.compile(expr)
# Check that the field index for the group_by key is correctly indexed
assert (
plan.relations[0]
.root.input.project.input.aggregate.groupings[0]
.grouping_expressions[0]
.selection.direct_reference.struct_field.field
== 3
)

# Group-by on a column that will be renamed by the joinchain
expr = join_chain.group_by(t4.c).count().select("c")
plan = compiler.compile(expr)
# Check that the field index for the group_by key is correctly indexed
assert (
plan.relations[0]
.root.input.project.input.aggregate.groupings[0]
.grouping_expressions[0]
.selection.direct_reference.struct_field.field
== 7
)

0 comments on commit 7095b19

Please sign in to comment.