Skip to content

Commit

Permalink
Simplify contiguity computation (#315)
Browse files Browse the repository at this point in the history
This PR uses sympy CSE and integral assumptions
to simplify complex expressions.

---------

Signed-off-by: Harsh Menon <[email protected]>
  • Loading branch information
harsh-nod authored Dec 12, 2024
1 parent 71eb1c8 commit a1dc616
Showing 1 changed file with 51 additions and 32 deletions.
83 changes: 51 additions & 32 deletions iree/turbine/kernel/wave/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1211,6 +1211,47 @@ def transform_floor(expr):
return sympy.simplify(expr)


def approximate_difference(
expr: IndexExpr, vars: list[IndexSymbol], elements_per_thread: int
) -> bool:
"""
During the contiguity check, we take a unit step in the fastest changing
dimension (j -> j + 1) and we compute f(j + 1) - f(j) to see if it is 1.
In general, we will end up with expressions of the form
g(x + eps) - g(x) where x = h(j) and eps is a rational of the form 1/q.
We can use q to determine if the mapping is contiguous as follows
if q is divisible by elements_per_thread (dimensions where we have not applied the unit step), or
if eps is 1 (corresponds to the dimension where we have applied the unit step)
then the mapping is contiguous.
The mapping function f(j) will be non-linear in general, and so the difference
of 1 will be transformed to different constant values based on the function.
But, if we recover a value of 1, we can assume that the function preserves
the difference.
In this function we do a pre-order traversal of the expression to obtain
the value of the constant eps.
"""
if expr.is_number:
return expr
new_vars, new_exprs = sympy.cse(expr)
new_expr = new_exprs[0] if new_vars else expr
new_vars = [x[0] for x in new_vars] if new_vars else vars
for arg in sympy.preorder_traversal(new_expr):
if isinstance(arg, sympy.Add):
if all([x in arg.args for x in new_vars]):
constant = [x for x in arg.args if x not in new_vars][0]
if not isinstance(constant, sympy.Rational):
return expr
if constant.p != 1:
return expr
if constant.q == 1:
return 1
return 0 if constant.q % elements_per_thread == 0 else expr
return expr


def check_is_mapping_contiguous(
mapping: IndexMapping,
symbolc_shape: tuple[IndexExpr, ...],
Expand Down Expand Up @@ -1239,41 +1280,19 @@ def check_is_mapping_contiguous(
index_mapping = mapping.map_output_indices(symbolc_shape)

index_mapping = tuple(subs_idxc(i) for i in index_mapping)

iters = mapping.iters

subs = [(sym, expr.start) for sym, expr in zip(iters.keys(), index.values())]
subs = [(sym, sym + int(i == len(iters) - 1)) for i, sym in enumerate(iters)]
diff = [
approximate_difference(
index_mapping[i].subs(subs) - index_mapping[i],
list(iters.keys())[-1:],
elements_per_thread,
)
for i in range(len(index_mapping))
]

# Iterate over elements_per_thread end check if every subsequent read have
# diff 1 in fastest changing dim and 0s in every other.
expected_diff = [0] * len(index_mapping)
expected_diff[-1] = 1

# Assume fastest changing dim increments in elements_per_thread between individual ops,
# This is tranform exressions floor(x/a + 1/b) into floor(floor(x/ept)*ept/a + 1/b)
# which is required for further sympy simplifications.
subs[-1] = (subs[-1][0], (subs[-1][1] // elements_per_thread) * elements_per_thread)

# Construct indices for vector element 0
prev_indices = _get_start_indices(
{key: m.subs(subs) for key, m in zip(symbolc_shape, index_mapping)}
)

# Construct indices for vector elements [1, 2, ..., elements_per_thread - 1]
# and compare with previous ones.
for i in range(1, elements_per_thread, 1):
# Increment fastest changing dim in unmapped index by one and apply mapping.
subs[-1] = (subs[-1][0], subs[-1][1] + 1)
next_result_index = {
key: m.subs(subs) for key, m in zip(symbolc_shape, index_mapping)
}
next_indices = _get_start_indices(next_result_index)

# Compute diff for every mapped dim.
diff = [_simplify_sympy_expr(a - b) for a, b in zip(next_indices, prev_indices)]
if diff != expected_diff:
return False

prev_indices = next_indices

return True
return diff == expected_diff

0 comments on commit a1dc616

Please sign in to comment.