From a1dc616e1e53fb81b45e5fc1df4e52a331841a81 Mon Sep 17 00:00:00 2001
From: harsh-nod
Date: Wed, 11 Dec 2024 23:02:30 -0800
Subject: [PATCH] Simplify contiguity computation (#315)

This PR uses sympy CSE and integral assumptions to simplify complex expressions.

---------

Signed-off-by: Harsh Menon
---
 iree/turbine/kernel/wave/utils.py | 83 +++++++++++++++++++------------
 1 file changed, 51 insertions(+), 32 deletions(-)

diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py
index 4ba7a7bf..aa190671 100644
--- a/iree/turbine/kernel/wave/utils.py
+++ b/iree/turbine/kernel/wave/utils.py
@@ -1211,6 +1211,47 @@ def transform_floor(expr):
     return sympy.simplify(expr)
 
 
+def approximate_difference(
+    expr: IndexExpr, vars: list[IndexSymbol], elements_per_thread: int
+) -> "IndexExpr | int":
+    """
+    Approximate f(j + 1) - f(j) for one dimension of an index mapping.
+
+    During the contiguity check, we take a unit step in the fastest changing
+    dimension (j -> j + 1) and compute f(j + 1) - f(j). In general this has the
+    form g(x + eps) - g(x), where x = h(j) and eps is a rational of the form 1/q.
+    A pre-order traversal finds the first Add containing all of the CSE symbols
+    (or `vars` when CSE extracts nothing); its first remaining term is eps.
+
+    Returns:
+        `expr` itself when it is already a number;
+        1 when eps == 1 (the dimension where the unit step was applied);
+        0 when q is divisible by elements_per_thread (the other dimensions);
+        `expr` unchanged when no such Add exists, or eps is missing or is not
+        of the form 1/q with q as above, i.e. contiguity could not be shown.
+
+    The mapping function f(j) is non-linear in general, so recovering the step
+    unchanged is taken as evidence that the function preserves the difference.
+    """
+    if expr.is_number:
+        return expr
+    new_vars, new_exprs = sympy.cse(expr)
+    new_expr = new_exprs[0] if new_vars else expr
+    new_vars = [x[0] for x in new_vars] if new_vars else vars
+    for arg in sympy.preorder_traversal(new_expr):
+        if isinstance(arg, sympy.Add):
+            if all(x in arg.args for x in new_vars):
+                constant = next((x for x in arg.args if x not in new_vars), None)
+                if not isinstance(constant, sympy.Rational):
+                    return expr
+                if constant.p != 1:
+                    return expr
+                if constant.q == 1:
+                    return 1
+                return 0 if constant.q % elements_per_thread == 0 else expr
+    return expr
+
+
 def check_is_mapping_contiguous(
     mapping: IndexMapping,
     symbolc_shape: tuple[IndexExpr, ...],
@@ -1239,41 +1280,19 @@ def check_is_mapping_contiguous(
 
     index_mapping = mapping.map_output_indices(symbolc_shape)
     index_mapping = tuple(subs_idxc(i) for i in index_mapping)
-
     iters = mapping.iters
 
-    subs = [(sym, expr.start) for sym, expr in zip(iters.keys(), index.values())]
+    subs = [(sym, sym + int(i == len(iters) - 1)) for i, sym in enumerate(iters)]
+    diff = [
+        approximate_difference(
+            index_mapping[i].subs(subs) - index_mapping[i],
+            list(iters.keys())[-1:],
+            elements_per_thread,
+        )
+        for i in range(len(index_mapping))
+    ]
 
-    # Iterate over elements_per_thread end check if every subsequent read have
-    # diff 1 in fastest changing dim and 0s in every other.
     expected_diff = [0] * len(index_mapping)
     expected_diff[-1] = 1
 
-    # Assume fastest changing dim increments in elements_per_thread between individual ops,
-    # This is tranform exressions floor(x/a + 1/b) into floor(floor(x/ept)*ept/a + 1/b)
-    # which is required for further sympy simplifications.
-    subs[-1] = (subs[-1][0], (subs[-1][1] // elements_per_thread) * elements_per_thread)
-
-    # Construct indices for vector element 0
-    prev_indices = _get_start_indices(
-        {key: m.subs(subs) for key, m in zip(symbolc_shape, index_mapping)}
-    )
-
-    # Construct indices for vector elements [1, 2, ..., elements_per_thread - 1]
-    # and compare with previous ones.
-    for i in range(1, elements_per_thread, 1):
-        # Increment fastest changing dim in unmapped index by one and apply mapping.
-        subs[-1] = (subs[-1][0], subs[-1][1] + 1)
-        next_result_index = {
-            key: m.subs(subs) for key, m in zip(symbolc_shape, index_mapping)
-        }
-        next_indices = _get_start_indices(next_result_index)
-
-        # Compute diff for every mapped dim.
-        diff = [_simplify_sympy_expr(a - b) for a, b in zip(next_indices, prev_indices)]
-        if diff != expected_diff:
-            return False
-
-        prev_indices = next_indices
-
-    return True
+    return diff == expected_diff