diff --git a/dace/sdfg/work_depth_analysis/extrapolation.py b/dace/sdfg/work_depth_analysis/extrapolation.py
new file mode 100644
index 0000000000..0a38805bae
--- /dev/null
+++ b/dace/sdfg/work_depth_analysis/extrapolation.py
@@ -0,0 +1,233 @@
+from scipy.optimize import curve_fit
+import numpy as np
+import matplotlib.pyplot as plt
+
+
+def print_scores(scores):
+    for k, v in scores.items():
+        print(k.name, v)
+
+
+class Logistic:
+
+    def __init__(self, name):
+        self.x_name = name
+        self.name = 'Logistic'
+
+    def f(x, a, b, c):
+        return b / (c + np.exp(-a * x))
+
+    def fit(self, x, y):
+        param, _ = curve_fit(Logistic.f, x, y, maxfev=10000)
+        self.a, self.b, self.c = param
+
+    def predict(self, x):
+        return Logistic.f(x, self.a, self.b, self.c)
+
+    def to_string(self):
+        return f'{self.b} / ({self.c} + exp({-self.a} * {self.x_name}))'
+
+
+class Log:
+
+    def __init__(self, name):
+        self.x_name = name
+        self.name = 'Log'
+
+    def f(x, a, b):
+        return a * np.log(x) + b
+
+    def fit(self, x, y):
+        param, _ = curve_fit(Log.f, x, y, maxfev=2500)
+        self.a, self.b = param
+
+    def predict(self, x):
+        return Log.f(x, self.a, self.b)
+
+    def to_string(self):
+        return f'{self.a} * log({self.x_name}) + {self.b}'
+
+
+class Plateau:
+
+    def __init__(self, name):
+        self.x_name = name
+        self.name = 'Plateau'
+
+    def f(x, a, b):
+        return (a * x) / (x + b)
+
+    def fit(self, x, y):
+        param, _ = curve_fit(Plateau.f, x, y, maxfev=2500)
+        self.a, self.b = param
+
+    def predict(self, x):
+        return Plateau.f(x, self.a, self.b)
+
+    def to_string(self):
+        return f'({self.a} * {self.x_name}) / ({self.x_name} + {self.b})'
+
+
+class Poly:
+
+    def __init__(self, name):
+        self.x_name = name
+        self.name = 'Poly'
+
+    def f(x, a, b):
+        return a * x + b
+
+    def fit(self, x, y):
+        param, _ = curve_fit(Poly.f, x, y, maxfev=2500)
+        self.a, self.b = param
+
+    def predict(self, x):
+        return Poly.f(x, self.a, self.b)
+
+    def to_string(self):
+        return f'{self.a} * {self.x_name} + {self.b}'
+
+
+class Sqrt:
+
+    def __init__(self, name):
+        self.x_name = name
+        self.name = 'Sqrt'
+
+    def f(x, a, b):
+        return a * np.sqrt(x) + b
+
+    def fit(self, x, y):
+        param, _ = curve_fit(Sqrt.f, x, y, maxfev=2500)
+        self.a, self.b = param
+
+    def predict(self, x):
+        return Sqrt.f(x, self.a, self.b)
+
+    def to_string(self):
+        return f'{self.a} * sqrt({self.x_name}) + {self.b}'
+
+
+class Exponential:
+
+    def __init__(self, name):
+        self.x_name = name
+        self.name = 'Exponential'
+
+    def f(x, a, b):
+        return a * np.exp(x) + b
+
+    def fit(self, x, y):
+        param, _ = curve_fit(Exponential.f, x, y, maxfev=2500)
+        self.a, self.b = param
+
+    def predict(self, x):
+        return Exponential.f(x, self.a, self.b)
+
+    def to_string(self):
+        # use 'exp' (not 'np.exp') so the string stays consistent with the other models
+        return f'{self.a} * exp({self.x_name}) + {self.b}'
+
+
+class Sin:
+
+    def __init__(self, name):
+        self.x_name = name
+        self.name = 'Sin'
+
+    def f(x, a, b, c, d):
+        return a * np.sin(b * x + c) + d
+
+    def fit(self, x, y):
+        param, _ = curve_fit(Sin.f, x, y, maxfev=2500)
+        self.a, self.b, self.c, self.d = param
+
+    def predict(self, x):
+        return Sin.f(x, self.a, self.b, self.c, self.d)
+
+    def to_string(self):
+        return f'{self.a} * sin({self.b}*{self.x_name} + {self.c}) + {self.d}'
+
+
+class Constant:
+
+    def __init__(self, name):
+        self.x_name = name
+        self.name = 'Constant'
+
+    def f(x, a):
+        return np.ones_like(x) * a
+
+    def fit(self, x, y):
+        param, _ = curve_fit(Constant.f, x, y, maxfev=2500)
+        self.a = param[0]
+
+    def predict(self, x):
+        return Constant.f(x, self.a)
+
+    def to_string(self):
+        return f'{self.a}'
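+
+# Usage sketch (illustrative only): every model above exposes the same
+# fit / predict / to_string interface, which is what lets extrapolate() below
+# treat them uniformly. For example:
+#
+#   model = Plateau('N')
+#   model.fit([4, 6, 8, 10], [2.1, 2.5, 2.7, 2.8])   # least-squares fit via curve_fit
+#   model.predict(12)                                # extrapolated value at N == 12
+#   model.to_string()                                # '(a * N) / (N + b)' with fitted a, b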
+
+
+def extrapolate(op_in_map, range_symbol):
+    """
+    For each key in op_in_map (i.e. for each SDFG element), we have a list of measured data points y,
+    one for each value of the symbol in range_symbol.
+    Now we fit a curve and return the best function found via leave-one-out cross validation.
+    """
+
+    if len(range_symbol) == 1:
+        # only 1 independent variable
+        symbol_name = list(range_symbol.keys())[0]
+        x = range_symbol[symbol_name].to_list()
+
+        models = [Logistic(symbol_name), Log(symbol_name), Plateau(symbol_name), Poly(symbol_name),
+                  Sqrt(symbol_name), Exponential(symbol_name), Sin(symbol_name), Constant(symbol_name)]
+
+        for element, y in op_in_map.items():
+            all_zero = True
+            for q in y:
+                if q != 0.0:
+                    all_zero = False
+                    break
+            if all_zero:
+                op_in_map[element] = '0'
+                continue
+            scores = {}
+            for model in models:
+                error_sum = 0
+                for left_out in range(len(x)):
+                    xx = list(x)
+                    test_x = xx.pop(left_out)
+                    yy = list(y)
+                    test_y = yy.pop(left_out)
+                    try:
+                        model.fit(xx, yy)
+                    except RuntimeError:
+                        # no fit was found --> give a huge error and skip the prediction,
+                        # since the model parameters are unset or stale
+                        error_sum += 999999999
+                        continue
+                    # predict on the left-out sample
+                    pred = model.predict(test_x)
+                    root_error = np.sqrt(np.abs(float(pred - test_y)))
+                    error_sum += root_error
+
+                mean_error = error_sum / len(x)
+                scores[model] = mean_error
+
+            # find the model with the least error
+            min_model = model
+            min_error = mean_error
+            for model, error in scores.items():
+                if error < min_error:
+                    min_error = error
+                    min_model = model
+
+            # fit the best model to all points and plot
+            min_model.fit(x, y)
+            fig, ax = plt.subplots()  # Create a figure containing a single axes.
+            ax.scatter(x, y)
+            s = 1
+            t = x[-1] + 3
+            q = np.linspace(s, t, num=(t - s) * 5)
+            r = min_model.predict(q)
+            ax.plot(q, r, label=min_model.to_string())
+
+            fig.tight_layout()
+            plt.show()
+
+            op_in_map[element] = min_model.to_string()
+
+    else:
+        print('2 independent variables not implemented yet')
\ No newline at end of file
diff --git a/dace/sdfg/work_depth_analysis/op_in_helpers.py b/dace/sdfg/work_depth_analysis/op_in_helpers.py
index 88768f5fdb..f5bb637e1a 100644
--- a/dace/sdfg/work_depth_analysis/op_in_helpers.py
+++ b/dace/sdfg/work_depth_analysis/op_in_helpers.py
@@ -3,6 +3,8 @@
-Further, contains class AccessStack which which corresponds to the stack used to compute the stack distance.
+Further, contains class AccessStack which corresponds to the stack used to compute the stack distance.
 """
 from dace.data import Array
+import sympy as sp
+from collections import deque
 
 
 class CacheLineTracker:
 
@@ -25,11 +27,17 @@ def cache_line_id(self, name: str, access: [int], mapping):
         one_d_index = 0
         for dim in range(len(access)):
             i = access[dim]
-            one_d_index += (i + arr.offset[dim].subs(mapping)) * arr.strides[dim].subs(mapping)
+            one_d_index += (i + sp.sympify(arr.offset[dim]).subs(mapping)) * sp.sympify(arr.strides[dim]).subs(mapping)
 
         # divide by L to get the cache line id
         return self.start_lines[name] + (one_d_index * arr.dtype.bytes) // self.L
 
+    def copy(self):
+        new_clt = CacheLineTracker(self.L)
+        new_clt.array_info = dict(self.array_info)
+        new_clt.start_lines = dict(self.start_lines)
+        new_clt.next_free_line = self.next_free_line
+        return new_clt
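+
+    # Illustration (assumed row-major layout): for a 100x100 float64 array registered at
+    # start line s with L == 64, the access A[2, 5] gives one_d_index == 2*100 + 5 == 205,
+    # so its cache line id is s + (205 * 8) // 64 == s + 25.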
 
 
 class Node:
 
@@ -46,9 +54,11 @@ class AccessStack:
 
     # TODO: this can be optimised such that the stack is never larger than C, since all elements deeper than C are misses
     # anyway. (then we cannot distinguish compulsory misses from capacity misses though)
 
-    def __init__(self) -> None:
+    def __init__(self, C) -> None:
         self.top = None
         self.num_calls = 0
+        self.length = 0
+        self.C = C
 
     def touch(self, id):
         self.num_calls += 1
@@ -73,9 +83,65 @@ def touch(self, id):
                 curr = curr.next
                 distance += 1
 
+        # shorten the stack if distance >= C
+        # if distance >= self.C and curr is not None:
+        #     curr.next = None
+
         if not found:
             # we accessed this cache line for the first time ever
             self.top = Node(id, self.top)
+            self.length += 1
             distance = -1
 
-        return distance
\ No newline at end of file
+        return distance
+
+    def compare_cache(self, other):
+        """Returns True if the same data resides in cache with the same LRU order."""
+        s = self.top
+        o = other.top
+        dist = 0
+        while s is not None and o is not None and dist < self.C:
+            dist += 1
+            # compare the cache line ids; Node objects are never shared between stacks
+            if s.v != o.v:
+                return False
+            s = s.next
+            o = o.next
+        if s is None and o is not None:
+            return False
+        if s is not None and o is None:
+            return False
+
+        return True
+
+    def in_cache_as_list(self):
+        """
+        Returns a list of cache ids currently in cache. Index 0 is the most recently used.
+        """
+        res = deque()
+        curr = self.top
+        dist = 0
+        while curr is not None and dist < self.C:
+            res.append(curr.v)
+            curr = curr.next
+            dist += 1
+        return res
+
+    def debug_print(self):
+        # prints the whole stack
+        print('\n')
+        curr = self.top
+        while curr is not None:
+            print(curr.v, end=', ')
+            curr = curr.next
+        print('\n')
+
+    def copy(self):
+        new_stack = AccessStack(self.C)
+        cache_content = self.in_cache_as_list()
+        if len(cache_content) > 0:
+            new_top_value = cache_content.popleft()
+            new_stack.top = Node(new_top_value)
+            curr = new_stack.top
+            for x in cache_content:
+                # advance curr, otherwise only the last element would survive
+                curr.next = Node(x)
+                curr = curr.next
+        return new_stack
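+
+# Stack distance example (illustrative): with stack = AccessStack(C=2),
+#   stack.touch(1)  # -> -1  (compulsory miss, first touch)
+#   stack.touch(2)  # -> -1  (compulsory miss)
+#   stack.touch(1)  # ->  1  (one distinct line touched since the last access to 1 -> hit)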
diff --git a/dace/sdfg/work_depth_analysis/operational_intensity.py b/dace/sdfg/work_depth_analysis/operational_intensity.py
index f752bf937c..141b281680 100644
--- a/dace/sdfg/work_depth_analysis/operational_intensity.py
+++ b/dace/sdfg/work_depth_analysis/operational_intensity.py
@@ -2,39 +2,12 @@
 """ Analyses the operational intensity of an input SDFG. Can be used as a Python script or from the VS Code extension. """
-"""
-Plan:
-- For each memory access, we need to figure out its cache line and then we compute its stack distance.
-- For that we model the actual stack, where we push all the memory acesses (What do we push exactly?
-Cache line ids?? check typescript implementation for that information.)
-- How do we know which array maps to which cache line?
-    Idea: for each new array encountered, just assume that it is cache line aligned and starts
-    at the next free cache line. TODO: check if this is how it usually behaves. Or are arrays
-    aligned further, like base address % x == 0 for some x bigger than cache line size?
-- It is also important that we take data types into account for each array.
-- For each mem access we increase the miss counter if stack distance > C(apacity) or it it is a
-compulsory miss. Then, in the end we know how many bytes are transferred to cache. It is:
-    num_misses * L(ine size in bytes)
-
-- Parameters to our analysis are
-    - input SDFG
-    - C(ache capacity)
-    - L(ine size)
-"""
+# set to True to let the user interactively pick branches that the analysis cannot decide
+ask_user = False
 
 import argparse
 from collections import deque
 from dace.sdfg import nodes as nd, propagation, InterstateEdge
-from dace import SDFG, SDFGState, dtypes, int64
+from dace import SDFG, SDFGState, dtypes
 from dace.subsets import Range
 from typing import Tuple, Dict
 import os
@@ -42,7 +15,7 @@
 from copy import deepcopy
 from dace.libraries.blas import MatMul
 from dace.libraries.standard import Reduce, Transpose
-from dace.symbolic import pystr_to_symbolic
+from dace.symbolic import pystr_to_symbolic, SymExpr
 import ast
 import astunparse
 import warnings
@@ -55,27 +28,181 @@
 from dace.data import Array
 from dace.sdfg.work_depth_analysis.op_in_helpers import CacheLineTracker, AccessStack
 from dace.sdfg.work_depth_analysis.work_depth import analyze_sdfg, get_tasklet_work
+from dace.sdfg.work_depth_analysis.extrapolation import extrapolate
+# needed for the SSA pass applied in analyze_sdfg_op_in below
+from dace.transformation.pass_pipeline import FixedPointPipeline
+from dace.transformation.passes.symbol_ssa import StrictSymbolSSA
+
+
+class SymbolRange():
+
+    def __init__(self, start_stop_step) -> None:
+        self.r = range(*start_stop_step)
+        self.i = iter(self.r)
+
+    def next(self):
+        try:
+            r = next(self.i)
+        except StopIteration:
+            r = -1
+        return r
+
+    def to_list(self):
+        return list(self.r)
+
+
+def update_map(op_in_map, uuid, new_misses):
+    if uuid in op_in_map:
+        misses, encounters = op_in_map[uuid]
+        op_in_map[uuid] = (misses + new_misses, encounters + 1)
+    else:
+        op_in_map[uuid] = (new_misses, 1)
+
+
+def calculate_op_in(op_in_map, work_map, assumptions, stringify=False):
+    for uuid in op_in_map:
+        try:
+            work = work_map[uuid][0].subs(assumptions)
+            if work == 0 and op_in_map[uuid] == 0:
+                op_in_map[uuid] = 0
+            elif work != 0 and op_in_map[uuid] == 0:
+                # everything was read from cache --> infinite op_in
+                op_in_map[uuid] = sp.oo
+            else:
+                # bytes > 0 --> divide normally
+                op_in_map[uuid] = sp.N(work / op_in_map[uuid])
+            if stringify:
+                op_in_map[uuid] = str(op_in_map[uuid])
+        except Exception as e:
+            work = work_map[uuid][0].subs(assumptions)
+            print('calculate_op_in failed for', uuid, 'with work', work, 'and bytes', op_in_map[uuid])
+            raise e
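+
+# Worked example (illustrative): if a tasklet was encountered 3 times and update_map
+# accumulated op_in_map[uuid] == (6, 3), the caller turns this into 6 * L / 3 == 2 * L
+# bytes on average; with work_map[uuid][0] == 48 this yields an op_in of 48 / (2 * L).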
+
+
+def mem_accesses_on_path(states):
+    mem_accesses = 0
+    for state in states:
+        # read_and_write_sets() returns a (reads, writes) tuple of sets
+        reads, writes = state.read_and_write_sets()
+        mem_accesses += len(reads) + len(writes)
+    return mem_accesses
+
+
+def find_states_between(sdfg: SDFG, start_state: SDFGState, end_state: SDFGState):
+    traversal_q = deque()
+    traversal_q.append(start_state)
+    visited = set()
+    states = []
+    while traversal_q:
+        curr_state = traversal_q.popleft()
+        if curr_state == end_state:
+            continue
+        if curr_state not in visited:
+            visited.add(curr_state)
+            states.append(curr_state)
+            for e in sdfg.out_edges(curr_state):
+                traversal_q.append(e.dst)
+    return states
+
+
+def find_merge_state(sdfg: SDFG, state: SDFGState):
+    """
+    Adapted from ``cfg.stateorder_topological_sort``.
+    """
+    from dace.sdfg.analysis import cfg
+
+    # Get parent states
+    ptree = cfg.state_parent_tree(sdfg)
+
+    # Annotate branches
+    adf = cfg.acyclic_dominance_frontier(sdfg)
+    oedges = sdfg.out_edges(state)
+    # Skip if not branch
+    if len(oedges) <= 1:
+        return
+    # Skip if natural loop
+    if len(oedges) == 2 and ((ptree[oedges[0].dst] == state and ptree[oedges[1].dst] != state) or
+                             (ptree[oedges[1].dst] == state and ptree[oedges[0].dst] != state)):
+        return
+
+    # If branch without else (adf of one successor is equal to the other)
+    if len(oedges) == 2:
+        if {oedges[0].dst} & adf[oedges[1].dst]:
+            return oedges[0].dst
+        elif {oedges[1].dst} & adf[oedges[0].dst]:
+            return oedges[1].dst
+
+    # Try to obtain common DF to find merge state
+    common_frontier = set()
+    for oedge in oedges:
+        frontier = adf[oedge.dst]
+        if not frontier:
+            frontier = {oedge.dst}
+        common_frontier |= frontier
+    if len(common_frontier) == 1:
+        return next(iter(common_frontier))
+    print(f'WARNING: No merge state could be detected for branch state "{state.name}".')
+
+
+def symeval(val, symbols):
+    """
+    Takes a sympy expression and substitutes its symbols according to a dict { old_symbol: new_symbol}.
+
+    :param val: The expression we are updating.
+    :param symbols: Dictionary of key value pairs { old_symbol: new_symbol}.
+    """
+    first_replacement = {pystr_to_symbolic(k): pystr_to_symbolic('__REPLSYM_' + k) for k in symbols.keys()}
+    second_replacement = {pystr_to_symbolic('__REPLSYM_' + k): v for k, v in symbols.items()}
+    return sp.simplify(val.subs(first_replacement).subs(second_replacement))
+
+
+def evaluate_symbols(base, new):
+    result = {}
+    for k, v in new.items():
+        result[k] = symeval(v, base)
+    return result
+
+
+def update_mapping(mapping, e):
+    update = {}
+    for k, v in e.data.assignments.items():
+        # ignore indexed assignments such as tmp = x[0]
+        if '[' not in k and '[' not in v:
+            update[pystr_to_symbolic(k)] = pystr_to_symbolic(v).subs(mapping)
+    mapping.update(update)
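+
+# Illustration: for an interstate edge with assignments {'i': 'i + 1'} and a current
+# mapping {i: 5}, the right-hand side is first evaluated under the mapping (-> 6) and
+# only then stored, so the mapping becomes {i: 6}.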
+
+
 def update_map_iterators(map, mapping):
     # update the map params and return False
     # if all iterations exhausted, return True
     # always increase the last one. If it is exhausted, increase the next one and so forth
     map_exhausted = True
-    for p, range in zip(map.params[::-1], map.range[::-1]):
+    for p, range in zip(map.params[::-1], map.range[::-1]):  # reversed order
         curr_value = mapping[p]
-        if curr_value.subs(mapping) < range[1].subs(mapping):
-            # update this value and we done
-            mapping[p] = curr_value + range[2].subs(mapping)
-            map_exhausted = False
-            break
-        else:
-            # set current param to start again and continue
-            mapping[p] = range[0].subs(mapping)
+        try:
+            # a SymExpr wraps the actual expression in its .expr attribute
+            upper = range[1].expr if isinstance(range[1], SymExpr) else range[1]
+            if curr_value.subs(mapping) + range[2].subs(mapping) <= upper.subs(mapping):
+                # update this value and we are done
+                mapping[p] = curr_value.subs(mapping) + range[2].subs(mapping)
+                map_exhausted = False
+                break
+            else:
+                # set the current param back to its start value and continue with the next one
+                mapping[p] = range[0].subs(mapping)
+        except Exception as e:
+            print('exception in update_map_iterators:')
+            print(curr_value)
+            print(range[1])
+            print(mapping, '\n\n')
+            raise e
     return map_exhausted
 
 
-def map_op_in(state: SDFGState, op_in_map: Dict[str, sp.Expr], entry, mapping, stack, clt, C):
+def map_op_in(state: SDFGState, op_in_map: Dict[str, sp.Expr], entry, mapping, stack, clt, C, symbols, array_names, w_d_map, decided_branches):
     # we are inside a map --> we need to iterate over the map range and check each memory access.
     for p, range in zip(entry.map.params, entry.map.range):
         # map each map iteration variable to its start
@@ -83,102 +210,137 @@
     map_misses = 0
     while True:
         # do analysis of map contents
-        map_misses += scope_op_in(state, op_in_map, mapping, stack, clt, C, entry)
+        map_misses += scope_op_in(state, op_in_map, mapping, stack, clt, C, symbols, array_names, w_d_map, decided_branches, entry)
         if update_map_iterators(entry.map, mapping):
             break
     return map_misses
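+
+# Odometer illustration: for a 2x2 map with params (i, j), successive calls to
+# update_map_iterators advance the mapping (i, j) as (0,0) -> (0,1) -> (1,0) -> (1,1);
+# the call after (1,1) resets both params and returns True (all iterations exhausted).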
 
 
-def scope_op_in(state: SDFGState, op_in_map: Dict[str, sp.Expr], mapping, stack: AccessStack, clt: CacheLineTracker, C, entry=None):
-    # find the work and depth of each node
-    # for maps and nested SDFG, we do it recursively
+def scope_op_in(state: SDFGState, op_in_map: Dict[str, sp.Expr], mapping, stack: AccessStack, clt: CacheLineTracker, C, symbols, array_names, w_d_map, decided_branches, entry=None):
+    # find the number of cache misses for each node.
+    # for maps and nested SDFG, we do it recursively.
    scope_misses = 0
    scope_nodes = state.scope_children()[entry]
    for node in scope_nodes:
         # add node to map
-        op_in_map[get_uuid(node, state)] = 0
         if isinstance(node, nd.EntryNode):
             # If the scope contains an entry node, we need to recursively analyze the sub-scope of the entry node first.
             # The resulting work/depth are summarized into the entry node
-            map_misses = map_op_in(state, op_in_map, node, mapping, stack, clt, C)
+            map_misses = map_op_in(state, op_in_map, node, mapping, stack, clt, C, symbols, array_names, w_d_map, decided_branches)
+            # add up the misses for the whole state, but also save the misses of this sub-scope in op_in_map
-            op_in_map[get_uuid(node, state)] = map_misses
+            update_map(op_in_map, get_uuid(node, state), map_misses)
             scope_misses += map_misses
         elif isinstance(node, nd.Tasklet):
             # add up work for whole state, but also save work for this node in op_in_map
             tasklet_misses = 0
             # analyze the memory accesses of this tasklet and whether they hit in cache or not
-            for e in state.in_edges(node):
-                if e.data.data in clt.array_info:
-                    line_id = clt.cache_line_id(e.data.data, [x[0].subs(mapping) for x in e.data.subset.ranges], mapping)
-                    dist = stack.touch(line_id)
-                    tasklet_misses += 1 if dist > C or dist == -1 else 0
-            for e in state.out_edges(node):
-                if e.data.data in clt.array_info:
-                    line_id = clt.cache_line_id(e.data.data, [x[0].subs(mapping) for x in e.data.subset.ranges], mapping)
-                    dist = stack.touch(line_id)
-                    tasklet_misses += 1 if dist > C or dist == -1 else 0
+            for e in state.in_edges(node) + state.out_edges(node):
+                if e.data.data in clt.array_info or (e.data.data in array_names and array_names[e.data.data] in clt.array_info):
+                    line_id = clt.cache_line_id(e.data.data if e.data.data not in array_names else array_names[e.data.data],
+                                                [x[0].subs(mapping) for x in e.data.subset.ranges], mapping)
+                    try:
+                        line_id = int(line_id.subs(mapping))
+                    except TypeError as exc:
+                        # use a distinct name here so the edge variable e is not shadowed
+                        print(line_id.subs(mapping).free_symbols)
+                        print(mapping)
+                        print(state.name)
+                        raise exc
+                    dist = stack.touch(line_id)
+                    tasklet_misses += 1 if dist >= C or dist == -1 else 0
 
-            # TODO: wcr edges. Do they work out of the box??
             scope_misses += tasklet_misses
+            # a tasklet can be visited multiple times; we report the average misses in the end.
+            # op_in_map holds a (num_total_misses, num_encounters) tuple per element;
+            # num_total_misses / num_encounters then gives the average misses.
+            update_map(op_in_map, get_uuid(node, state), tasklet_misses)
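+            # Hit/miss rule (illustrative): with C == 4 cache lines, a stack distance of -1
+            # (first touch) or >= 4 counts as a miss above; distances 0..3 count as hits.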
         elif isinstance(node, nd.NestedSDFG):
-            # TODO: handle nested arrays properly.
             # keep track of nested symbols: "symbols" maps local nested SDFG symbols to global symbols.
-            # We only want global symbols in our final work depth expressions.
-            # nested_syms = {}
-            # nested_syms.update(symbols)
-            # nested_syms.update(evaluate_symbols(symbols, node.symbol_mapping))
+            # We only want global symbols in our final expressions.
+            nested_syms = {}
+            nested_syms.update(symbols)
+            nested_syms.update(evaluate_symbols(symbols, node.symbol_mapping))
+
+            # Handle nested arrays: inside the nested SDFG, the same array can be referenced
+            # under a different name
+            nested_array_names = {}
+            nested_array_names.update(array_names)
+            # for each connector to the nested SDFG, add a pair (connector_name, incoming array name) to the dict
+            for e in state.in_edges(node):
+                nested_array_names[e.dst_conn] = e.data.data
+            for e in state.out_edges(node):
+                nested_array_names[e.src_conn] = e.data.data
             # Nested SDFGs are recursively analyzed first.
-            nsdfg_misses = sdfg_op_in(node.sdfg, op_in_map, mapping, stack, clt, C)
+            nsdfg_misses = sdfg_op_in(node.sdfg, op_in_map, mapping, stack, clt, C, nested_syms, nested_array_names, w_d_map, decided_branches)
 
-            # nsdfg_work, nsdfg_depth = do_initial_subs(nsdfg_work, nsdfg_depth, equality_subs, subs1)
-            # add up work for whole state, but also save work for this nested SDFG in op_in_map
+            # add up the misses for the whole state, but also save the misses of this nested SDFG in op_in_map
             scope_misses += nsdfg_misses
-            op_in_map[get_uuid(node, state)] = nsdfg_misses
+            update_map(op_in_map, get_uuid(node, state), nsdfg_misses)
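+            # Example (illustrative): if the outer array 'A' enters the nested SDFG through
+            # connector '_in_A', then nested_array_names['_in_A'] == 'A' and all nested
+            # accesses to '_in_A' are tracked against the cache lines of 'A'.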
         elif isinstance(node, nd.LibraryNode):
-            pass
             # TODO: implement librarynodes. Note: When encountering some libNode, we can add a symbol
             # "libnode_name_bytes". Then we have "libnode_name_work / libnode_name_bytes" in the final
             # expression. Better to just have "libnode_name_opin" in final expr. Either dont spawn the work
             # symbol and put the "op_in" symbol here
             # or replace the division in the end with the "op_in" symbol
-            # try:
-            #     lib_node_work = LIBNODES_TO_WORK[type(node)](node, symbols, state)
-            # except KeyError:
-            #     # add a symbol to the top level sdfg, such that the user can define it in the extension
-            #     top_level_sdfg = state.parent
-            #     # TODO: This symbol should now appear in the VS code extension in the SDFG analysis tab,
-            #     # such that the user can define its value. But it doesn't...
-            #     # How to achieve this?
-            #     top_level_sdfg.add_symbol(f'{node.name}_work', int64)
-            #     lib_node_work = sp.Symbol(f'{node.name}_work', positive=True)
-            # lib_node_depth = sp.sympify(-1)  # not analyzed
-            # if analyze_tasklet != get_tasklet_work:
-            #     # we are analyzing depth
-            #     try:
-            #         lib_node_depth = LIBNODES_TO_DEPTH[type(node)](node, symbols, state)
-            #     except KeyError:
-            #         top_level_sdfg = state.parent
-            #         top_level_sdfg.add_symbol(f'{node.name}_depth', int64)
-            #         lib_node_depth = sp.Symbol(f'{node.name}_depth', positive=True)
-            # lib_node_work, lib_node_depth = do_initial_subs(lib_node_work, lib_node_depth, equality_subs, subs1)
-            # work += lib_node_work
-            # op_in_map[get_uuid(node, state)] = (lib_node_work, lib_node_depth)
+            # add a symbol to the top level sdfg, such that the user can define it in the extension
+            top_level_sdfg = state.parent
+            try:
+                top_level_sdfg.add_symbol(f'{node.name}_misses', dtypes.int64)
+            except FileExistsError:
+                # the symbol was already added for an earlier encounter of this node
+                pass
+            lib_node_misses = sp.Symbol(f'{node.name}_misses', positive=True)
+            lib_node_misses = lib_node_misses.subs(mapping)
+            scope_misses += lib_node_misses
+            update_map(op_in_map, get_uuid(node, state), lib_node_misses)
-    op_in_map[get_uuid(state)] = scope_misses
+    if entry is None:
+        update_map(op_in_map, get_uuid(state), scope_misses)
     return scope_misses
 
 
-def sdfg_op_in(sdfg: SDFG, op_in_map: Dict[str, Tuple[sp.Expr, sp.Expr]], mapping, stack, clt, C):
+def sdfg_op_in(sdfg: SDFG, op_in_map: Dict[str, Tuple[sp.Expr, sp.Expr]], mapping, stack, clt: CacheLineTracker, C, symbols, array_names, w_d_map, decided_branches, start=None, end=None):
+    # add this SDFG's arrays to the cache line tracker
+    for name, arr in sdfg.arrays.items():
+        if isinstance(arr, Array):
+            if name in array_names:
+                name = array_names[name]
+            clt.add_array(name, arr, mapping)
+
     # traverse this SDFG's states
-    curr_state = sdfg.start_state
+    curr_state = start or sdfg.start_state
     total_misses = 0
     while True:
-        total_misses += scope_op_in(curr_state, op_in_map, mapping, stack, clt, C)
+        total_misses += scope_op_in(curr_state, op_in_map, mapping, stack, clt, C, symbols, array_names, w_d_map, decided_branches)
         if len(sdfg.out_edges(curr_state)) == 0:
-            # we reached the end state --> stop
+            # we reached an end state --> stop
             break
         else:
             # take first edge with True condition
             # save e's assignments in mapping and update curr_state
             # replace values first with mapping, then update mapping
             try:
-                mapping.update({k: sp.sympify(v).subs(mapping) for k, v in e.data.assignments.items()
-                                if '[' not in k and '[' not in v})
+                update_mapping(mapping, e)
             except:
-                print('WARNING: Strange assignment detected on InterstateEdge (e.g. bitwise operators).'
+                print('\nWARNING: Strange assignment detected on InterstateEdge (e.g. bitwise operators). '
                       'Analysis may give wrong results.')
+                print(e.data.assignments, 'was the edge\'s assignments.')
             curr_state = e.dst
             found = True
             break
             if not found:
-                # TODO: maybe print out the free symbols which may cause this warning.
-                print('WARNING: state has outgoing edges, but no condition of them can be'
-                      'evaluated as True. Analysis may give wrong results.')
-                free_syms_detected = {}
-                for e in sdfg.out_edges(curr_state):
-                    free_syms_detected |= e.data.condition_sympy().free_symbols
-                print('Following free symbols detected in the condition of the outgoing edges:', free_syms_detected)
-                print(curr_state)
-                # continue with first edge
-                e = sdfg.out_edges(curr_state)[1]
-                mapping.update({k: sp.sympify(v).subs(mapping) for k, v in e.data.assignments.items()
-                                if '[' not in k and '[' not in v})
-                curr_state = e.dst
+                if curr_state in decided_branches:
+                    # this branch was already decided in a previous iteration --> take the same branch again
+                    e = decided_branches[curr_state]
+                    update_mapping(mapping, e)
+                    curr_state = e.dst
+                else:
+                    # we cannot determine which branch to take --> check which branches contain work
+                    merge_state = find_merge_state(sdfg, curr_state)
+                    next_edge_candidates = []
+                    for e in sdfg.out_edges(curr_state):
+                        states = find_states_between(sdfg, e.dst, merge_state)
+                        curr_work = mem_accesses_on_path(states)
+                        if sp.sympify(curr_work).subs(mapping) > 0:
+                            next_edge_candidates.append(e)
+
+                    if len(next_edge_candidates) == 1:
+                        e = next_edge_candidates[0]
+                        update_mapping(mapping, e)
+                        decided_branches[curr_state] = e
+                        curr_state = e.dst
+                    else:
+                        if ask_user:
+                            edges = sdfg.out_edges(curr_state)
+                            print(f'\n\nWhich branch to take at {curr_state.name}')
+                            for i in range(len(edges)):
+                                print(f'({i}) for edge to state {edges[i].dst.name}')
+                                print(edges[i].dst._read_and_write_sets())
+                            print('merge state is named ', merge_state)
+                            chosen = int(input('Choose an option from above: '))
+                            e = edges[chosen]
+                            update_mapping(mapping, e)
+                            decided_branches[curr_state] = e
+                            curr_state = e.dst
+                            print('we continue with state', e.dst.name)
+                            print(3 * '\n')
+                        else:
+                            final_e = next_edge_candidates.pop()
+                            for e in next_edge_candidates:
+                                # copy the state of the analysis
+                                curr_mapping = dict(mapping)
+                                update_mapping(curr_mapping, e)
+                                curr_stack = stack.copy()
+                                curr_clt = clt.copy()
+                                curr_symbols = dict(symbols)
+                                curr_array_names = dict(array_names)
+
+                                curr_state = e.dst
+                                # walk down this branch until merge_state
+                                # TODO: can we use the return value (misses of the different branches) for something?
+                                sdfg_op_in(sdfg, op_in_map, curr_mapping, curr_stack, curr_clt, C, curr_symbols, curr_array_names, w_d_map, decided_branches, curr_state, merge_state)
+
+                            update_mapping(mapping, final_e)
+                            curr_state = final_e.dst
+        if curr_state == end:
+            break
 
-    op_in_map[get_uuid(sdfg)] = total_misses
+    if end is None:
+        # only update if we were analyzing the whole SDFG (not just a start-to-end sub-walk)
+        update_map(op_in_map, get_uuid(sdfg), total_misses)
     return total_misses
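+
+# Branch handling note (illustrative): when several branches contain work and ask_user is
+# False, each non-final candidate branch is walked with its own copies of mapping, stack
+# and cache line tracker up to the merge state, so the cache state of one branch cannot
+# pollute the analysis of the others.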
 
 
 def analyze_sdfg_op_in(sdfg: SDFG, op_in_map: Dict[str, sp.Expr], C, L, assumptions):
+    sdfg = deepcopy(sdfg)
-    assumptions = {'N': 100}
-    print(C, L, assumptions)
+    # apply SSA pass
+    pipeline = FixedPointPipeline([StrictSymbolSSA()])
+    pipeline.apply_pass(sdfg, {})
 
     # TODO: insert some checks on whether this sdfg is analyzable, like
     # - data-dependent loop bounds (i.e. unbounded executions)
     # - indirect accesses (e.g. A[B[i]])
 
-    # TODO: use assumptions to concretize symbols
-    sdfg.specialize(assumptions)
-    mapping = {}
-    mapping.update(assumptions)
-
-    stack = AccessStack()
-    clt = CacheLineTracker(L)
-    for _, name, arr in sdfg.arrays_recursive():
-        if isinstance(arr, Array):
-            if name in clt.array_info:
-                # TODO: this can get triggered by nested sdfgs with same array names... needs to be looked at
-                print(f'WARNING: This array name ({name}) was already seen. Two arrays with the same name in the SDFG.')
-            clt.add_array(name, arr, mapping)
-    sdfg_op_in(sdfg, op_in_map, mapping, stack, clt, C)
-    # now we have number of misses --> multiply each by L
-    for k, v in op_in_map.items():
-        op_in_map[k] = v * L
-
-    print('Bytes: ', op_in_map[get_uuid(sdfg)])
-    # get work
+    # check if all symbols are concretized
+    standard_range = (4, 16, 2)
+    num_undefined = 0
+    range_symbol = {}
+    for sym in sdfg.free_symbols:
+        if sym not in assumptions:
+            num_undefined += 1
+            range_symbol[sym] = SymbolRange(standard_range)
+        elif isinstance(assumptions[sym], str):
+            # a string value is interpreted as a 'start,stop,step' sweep range
+            num_undefined += 1
+            range_symbol[sym] = SymbolRange(int(x) for x in assumptions[sym].split(','))
+            del assumptions[sym]
+
     work_map = {}
-    analyze_sdfg(sdfg, work_map, get_tasklet_work, [], False)  # TODO: assumptions
-    for uuid in op_in_map:
-        op_in_map[uuid] = str(work_map[uuid][0] / op_in_map[uuid] if op_in_map[uuid] != 0 else 0)
-
-    print('Work: ', work_map[get_uuid(sdfg)][0])
-
-    print(3 * '\n')
-    print('num memory accesses:', stack.num_calls)
+    assumptions_list = [f'{x}=={y}' for x, y in assumptions.items()]
+    analyze_sdfg(sdfg, work_map, get_tasklet_work, assumptions_list, False)
+
+    if num_undefined == 0:
+        # all symbols concretized, do the normal analysis
+        sdfg.specialize(assumptions)
+        mapping = {}
+        mapping.update(assumptions)
+
+        stack = AccessStack(C)
+        clt = CacheLineTracker(L)
+        # keeps track of the decisions on which branches to analyze
+        decided_branches: Dict[SDFGState, InterstateEdge] = {}
+        sdfg_op_in(sdfg, op_in_map, mapping, stack, clt, C, {}, {}, work_map, decided_branches)
+        # now we have the number of misses --> multiply each by L to get bytes
+        for k, v in op_in_map.items():
+            op_in_map[k] = v[0] * L / v[1]
+        # divide work by bytes to get the operational intensity
+        calculate_op_in(op_in_map, work_map, assumptions, stringify=True)
+    elif num_undefined > 1:
+        raise Exception('Too many undefined symbols')
+    else:
+        assert len(range_symbol) == 1
+        op_in_measurements = {}
+        # keeps track of the decisions on which branches to analyze
+        decided_branches: Dict[SDFGState, InterstateEdge] = {}
+        while True:
+            new_val = False
+            for sym, r in range_symbol.items():
+                val = r.next()
+                if val > -1:
+                    new_val = True
+                    assumptions[sym] = val
+            if not new_val:
+                break
+
+            print(assumptions)
+            curr_op_in_map = {}
+            mapping = {}
+            mapping.update(assumptions)
+            stack = AccessStack(C)
+            clt = CacheLineTracker(L)
+            sdfg_op_in(sdfg, curr_op_in_map, mapping, stack, clt, C, {}, {}, work_map, decided_branches)
+            # now we have the number of misses --> multiply each by L to get bytes
+            for k, v in curr_op_in_map.items():
+                curr_op_in_map[k] = v[0] * L / v[1]
+            # divide work by bytes to get the operational intensity
+            calculate_op_in(curr_op_in_map, work_map, assumptions)
+
+            # put the current values into op_in_measurements
+            for k, v in curr_op_in_map.items():
+                if k in op_in_measurements:
+                    op_in_measurements[k].append(v)
+                else:
+                    op_in_measurements[k] = [v]
+
+        # fit a curve through the measured op_in values of each element
+        extrapolate(op_in_measurements, range_symbol)
+        op_in_map.update(op_in_measurements)
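+        # Sweep example (illustrative): with one undefined symbol N and the default
+        # standard_range (4, 16, 2), the analysis above runs for N in 4, 6, 8, 10, 12, 14
+        # and extrapolate() then fits, e.g., '(a * N) / (N + b)' through those samples.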
+
+        # TODO: extrapolate the number of cache misses instead of the op_in values;
+        # that might give better fits.
 
 
@@ -275,20 +569,31 @@ def main() -> None:
                                      description='Analyze the operational_intensity of an SDFG.')
 
     parser.add_argument('filename', type=str, help='The SDFG file to analyze.')
-    parser.add_argument('C', type=str, help='Cache size in bytes')
-    parser.add_argument('L', type=str, help='Cache line size in bytes')
-
-    # TODO: add assumptions argument
+    parser.add_argument('--C', type=str, help='Cache size in bytes')
+    parser.add_argument('--L', type=str, help='Cache line size in bytes')
+    parser.add_argument('--assume', nargs='*', help='Collect assumptions about symbols, e.g. N==512 K==4,16,2')
 
     args = parser.parse_args()
 
     if not os.path.exists(args.filename):
         print(args.filename, 'does not exist.')
         exit()
 
     sdfg = SDFG.from_file(args.filename)
     op_in_map = {}
-    analyze_sdfg_op_in(sdfg, op_in_map, int(args.C), int(args.L), {})
+    if args.assume is None:
+        args.assume = []
+
+    assumptions = {}
+    for x in args.assume:
+        a, b = x.split('==')
+        if b.isdigit():
+            assumptions[a] = int(b)
+        else:
+            assumptions[a] = b
+    print(assumptions)
+    analyze_sdfg_op_in(sdfg, op_in_map, int(args.C), int(args.L), assumptions)
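+    # Example (illustrative): `--assume N==512 K==4,16,2` yields assumptions
+    # {'N': 512, 'K': '4,16,2'} above; the non-integer value marks K as a swept symbol
+    # whose string is later split into a start,stop,step range by analyze_sdfg_op_in.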
 
     result_whole_sdfg = op_in_map[get_uuid(sdfg)]
diff --git a/dace/sdfg/work_depth_analysis/work_depth.py b/dace/sdfg/work_depth_analysis/work_depth.py
index 9a8d7ee50b..a1193ec8e7 100644
--- a/dace/sdfg/work_depth_analysis/work_depth.py
+++ b/dace/sdfg/work_depth_analysis/work_depth.py
@@ -131,7 +131,7 @@ def count_depth_reduce(node, symbols, state):
     'tanh': 1,
     'math.sqrt': 1,
     'sqrt': 1,
-    'atan2:': 1,
+    'atan2': 1,
     'min': 0,
     'max': 0,
     'ceiling': 0,
@@ -241,9 +241,9 @@ def tasklet_work(tasklet_node, state):
     if tasklet_node.code.language == dtypes.Language.CPP:
         # simplified work analysis for CPP tasklets.
         for oedge in state.out_edges(tasklet_node):
-            return oedge.data.num_accesses
+            # on Lulesh this was None for some tasklet(s)
+            return oedge.data.num_accesses if oedge.data.num_accesses is not None else 0
     elif tasklet_node.code.language == dtypes.Language.Python:
-        return count_arithmetic_ops_code(tasklet_node.code.code)
+        # on Lulesh this was None for some tasklet(s)
+        return count_arithmetic_ops_code(tasklet_node.code.code) or 0
     else:
         # other languages not implemented, count whole tasklet as work of 1
         warnings.warn('Work of tasklets only properly analyzed for Python or CPP. For all other '
@@ -289,9 +289,15 @@ def update_value_map(old, new):
 
 def do_initial_subs(w, d, eq, subs1):
     """
-    Calls subs three times for the give (w)ork and (d)epth values.
+    Calls subs three times for the given (w)ork and (d)epth values.
     """
-    return sp.simplify(w.subs(eq[0]).subs(eq[1]).subs(subs1)), sp.simplify(d.subs(eq[0]).subs(eq[1]).subs(subs1))
+    try:
+        result = sp.simplify(sp.sympify(w).subs(eq[0]).subs(eq[1]).subs(subs1)), sp.simplify(
+            sp.sympify(d).subs(eq[0]).subs(eq[1]).subs(subs1))
+    except Exception as e:
+        print('w:', w)
+        print('d:', d)
+        raise e
+    return result
 
 
 def sdfg_work_depth(sdfg: SDFG,
@@ -382,9 +388,20 @@
     traversal_q.append((sdfg.start_state, sp.sympify(0), sp.sympify(0), None, [], [], {}))
     visited = set()
 
     while traversal_q:
         state, depth, work, ie, condition_stack, common_subexpr_stack, value_map = traversal_q.popleft()
 
         if ie is not None:
             visited.add(ie)
@@ -395,7 +412,10 @@
             state_value_map[state] = value_map
 
         # ignore assignments such as tmp=x[0], as those do not give much information.
-        value_map = {k: v for k, v in state_value_map[state].items() if '[' not in k and '[' not in v}
+        try:
+            value_map = {pystr_to_symbolic(k): pystr_to_symbolic(v) for k, v in state_value_map[state].items()}
+        except Exception:
+            # fall back to an empty map if an assignment cannot be parsed symbolically
+            value_map = {}
 
         n_depth = sp.simplify((depth + state_depths[state]).subs(value_map))
         n_work = sp.simplify((work + state_works[state]).subs(value_map))
@@ -543,6 +563,8 @@ def scope_work_depth(
             for e in state.out_edges(node):
                 if e.data.wcr is not None:
                     t_work += count_arithmetic_ops_code(e.data.wcr)
+            if t_work is None:
+                t_work = 0
             t_work, t_depth = do_initial_subs(t_work, t_depth, equality_subs, subs1)
             work += t_work
             w_d_map[get_uuid(node, state)] = (t_work, t_depth)
diff --git a/tests/sdfg/operational_intensity_test.py b/tests/sdfg/operational_intensity_test.py
new file mode 100644
index 0000000000..0dc4f6c7be
--- /dev/null
+++ b/tests/sdfg/operational_intensity_test.py
@@ -0,0 +1,302 @@
+# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
+""" Contains test cases for the operational intensity analysis. """
+import dace as dc
+from dace.sdfg.work_depth_analysis.operational_intensity import analyze_sdfg_op_in
+from dace.sdfg.work_depth_analysis.helpers import get_uuid
+import sympy as sp
+
+from dace.transformation.interstate import NestSDFG
+from dace.transformation.dataflow import MapExpansion
+from math import isclose
+from numpy import sum
+
+# TODO: maybe include tests for column-major memory layout, i.e. test that strides are
+# taken into account correctly.
+# TODO: add tests for library nodes
+
+N = dc.symbol('N')
+M = dc.symbol('M')
+K = dc.symbol('K')
+
+TILE_SIZE = dc.symbol('TILE_SIZE')
+
+
+@dc.program
+def single_map64(x: dc.float64[N], y: dc.float64[N], z: dc.float64[N]):
+    z[:] = x + y
+    # does N work, loads 3*N elements of 8 bytes
+    # --> op_in should be N / (3*8*N) = 1/24 (no reuse), assuming L divides N
+
+
+@dc.program
+def single_map16(x: dc.float16[N], y: dc.float16[N], z: dc.float16[N]):
+    z[:] = x + y
+    # does N work, loads 3*N elements of 2 bytes
+    # --> op_in should be N / (3*2*N) = 1/6 (no reuse), assuming L divides N
+
+
+@dc.program
+def single_for_loop(x: dc.float64[N], y: dc.float64[N]):
+    for i in range(N):
+        x[i] += y[i]
+    # N work, 2*N*8 bytes loaded
+    # --> 1/16 op_in
+
+
+@dc.program
+def if_else(x: dc.int64[100], sum: dc.int64[1]):
+    if x[10] > 50:
+        for i in range(100):
+            sum += x[i]
+    if x[0] > 3:
+        for i in range(100):
+            sum += x[i]
+    # no else --> simply analyze the ifs. if the cache is big enough, everything is reused
+
+
+@dc.program
+def unaligned_for_loop(x: dc.float32[100], sum: dc.int64[1]):
+    for i in range(17, 53):
+        sum += x[i]
+
+
+@dc.program
+def sequential_maps(x: dc.float64[N], y: dc.float64[N], z: dc.float64[N]):
+    z[:] = x + y
+    z[:] *= 2
+    z[:] += x
+    # does 3*N work; with a large enough cache, x, y and z are each loaded only once
+
+
+@dc.program
+def nested_reuse(x: dc.float64[N], y: dc.float64[N], z: dc.float64[N], result: dc.float64[1]):
+    # load x, y and z
+    z[:] = x + y
+    result[0] = sum(z)
+    # tests whether the accesses to z from the nested SDFG correspond with the prior accesses
+    # to z outside of the nested SDFG.
+
+
+@dc.program
+def mmm(x: dc.float64[N, N], y: dc.float64[N, N], z: dc.float64[N, N]):
+    for n, k, m in dc.map[0:N, 0:N, 0:N]:
+        z[n, k] += x[n, m] * y[m, k]
+
+
+@dc.program
+def tiled_mmm(x: dc.float64[N, N], y: dc.float64[N, N], z: dc.float64[N, N]):
+    for n_TILE, k_TILE, m_TILE in dc.map[0:N:TILE_SIZE, 0:N:TILE_SIZE, 0:N:TILE_SIZE]:
+        for n, k, m in dc.map[n_TILE:n_TILE + TILE_SIZE, k_TILE:k_TILE + TILE_SIZE, m_TILE:m_TILE + TILE_SIZE]:
+            z[n, k] += x[n, m] * y[m, k]
+
+
+@dc.program
+def tiled_mmm_32(x: dc.float32[N, N], y: dc.float32[N, N], z: dc.float32[N, N]):
+    for n_TILE, k_TILE, m_TILE in dc.map[0:N:TILE_SIZE, 0:N:TILE_SIZE, 0:N:TILE_SIZE]:
+        for n, k, m in dc.map[n_TILE:n_TILE + TILE_SIZE, k_TILE:k_TILE + TILE_SIZE, m_TILE:m_TILE + TILE_SIZE]:
+            z[n, k] += x[n, m] * y[m, k]
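+
+# Worked example (illustrative) for the first test case below: single_map64 with N == 512,
+# C == 64*64 and L == 64. Each 64-byte line holds 8 float64 elements, so the 3*512 element
+# accesses touch 3*512/8 == 192 distinct lines --> 192*64 bytes moved, and
+# op_in == 512 / (192*64) == 1/24, matching the expected result.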
+
+
+# @dc.program
+# def if_else_sym(x: dc.int64[N], y: dc.int64[N], z: dc.int64[N], sum: dc.int64[1]):
+#     if x[10] > 50:
+#         z[:] = x + y  # N work, 1 depth
+#     else:
+#         for i in range(K):  # K work, K depth
+#             sum += x[i]
+
+
+# @dc.program
+# def nested_sdfg(x: dc.float64[N], y: dc.float64[N], z: dc.float64[N]):
+#     single_map64(x, y, z)
+#     single_for_loop(x, y)
+
+
+# @dc.program
+# def nested_maps(x: dc.float64[N, M], y: dc.float64[N, M], z: dc.float64[N, M]):
+#     z[:, :] = x + y
+
+
+# @dc.program
+# def nested_for_loops(x: dc.float64[N], y: dc.float64[K]):
+#     for i in range(N):
+#         for j in range(K):
+#             x[i] += y[j]
+
+
+# @dc.program
+# def nested_if_else(x: dc.int64[N], y: dc.int64[N], z: dc.int64[N], sum: dc.int64[1]):
+#     if x[10] > 50:
+#         if x[9] > 40:
+#             z[:] = x + y  # N work, 1 depth
+#         z[:] += 2 * x  # 2*N work, 2 depth --> total outer if: 3*N work, 3 depth
+#     else:
+#         if y[9] > 30:
+#             for i in range(K):
+#                 sum += x[i]  # K work, K depth
+#         else:
+#             for j in range(M):
+#                 sum += x[j]  # M work, M depth
+#             z[:] = x + y  # N work, depth 1 --> total inner else: M+N work, M+1 depth
+#     # --> total outer else: Max(K, M+N) work, Max(K, M+1) depth
+#     # --> total over both branches: Max(K, M+N, 3*N) work, Max(K, M+1, 3) depth
+
+
+# @dc.program
+# def max_of_positive_symbol(x: dc.float64[N]):
+#     if x[0] > 0:
+#         for i in range(2 * N):  # work 2*N^2, depth 2*N
+#             x += 1
+#     else:
+#         for j in range(3 * N):  # work 3*N^2, depth 3*N
+#             x += 1
+#     # total is work 3*N^2, depth 3*N without any max
+
+
+# @dc.program
+# def multiple_array_sizes(x: dc.int64[N], y: dc.int64[N], z: dc.int64[N], x2: dc.int64[M], y2: dc.int64[M],
+#                          z2: dc.int64[M], x3: dc.int64[K], y3: dc.int64[K], z3: dc.int64[K]):
+#     if x[0] > 0:
+#         z[:] = 2 * x + y  # work 2*N, depth 2
+#     elif x[1] > 0:
+#         z2[:] = 2 * x2 + y2  # work 2*M + 3, depth 5
+#         z2[0] += 3 + z[1] + z[2]
+#     elif x[2] > 0:
+#         z3[:] = 2 * x3 + y3  # work 2*K, depth 2
+#     elif x[3] > 0:
+#         z[:] = 3 * x + y + 1  # work 3*N, depth 3
+#     # --> work = Max(3*N, 2*M, 2*K) and depth = 5
+
+
+# @dc.program
+# def unbounded_while_do(x: dc.float64[N]):
+#     while x[0] < 100:
+#         x += 1
+
+
+# @dc.program
+# def unbounded_do_while(x: dc.float64[N]):
+#     while True:
+#         x += 1
+#         if x[0] >= 100:
+#             break
+
+
+# @dc.program
+# def unbounded_nonnegify(x: dc.float64[N]):
+#     while x[0] < 100:
+#         if x[1] < 42:
+#             x += 3 * x
+#         else:
+#             x += x
+
+
+# @dc.program
+# def continue_for_loop(x: dc.float64[N]):
+#     for i in range(N):
+#         if x[i] > 100:
+#             continue
+#         x += 1
+
+
+# @dc.program
+# def break_for_loop(x: dc.float64[N]):
+#     for i in range(N):
+#         if x[i] > 100:
+#             break
+#         x += 1
+
+
+# @dc.program
+# def break_while_loop(x: dc.float64[N]):
+#     while x[0] > 10:
+#         if x[1] > 100:
+#             break
+#         x += 1
+
+
+# @dc.program
+# def sequntial_ifs(x: dc.float64[N + 1], y: dc.float64[M + 1]):  # --> cannot assume N, M to be positive
+#     if x[0] > 5:
+#         x[:] += 1  # N+1 work, 1 depth
+#     else:
+#         for i in range(M):  # M work, M depth
+#             y[i + 1] += y[i]
+#     if M > N:
+#         y[:N + 1] += x[:]  # N+1 work, 1 depth
+#     else:
+#         x[:M + 1] += y[:]  # M+1 work, 1 depth
+#     # --> Work: Max(N+1, M) + Max(N+1, M+1)
+#     #     Depth: Max(1, M) + 1
+
+
+# (sdfg, c, l, assumptions, expected_result)
+tests_cases = [
+    (single_map64, 64*64, 64, {'N': 512}, 1/24),
+    (single_map16, 64*64, 64, {'N': 512}, 1/6),
+    # now the number of elements on a single cache line does not divide N anymore
+    # --> 513 work, 520 elements loaded --> 513 / (520*8*3)
+    (single_map64, 64*64, 64, {'N': 513}, 513 / (3*8*520)),
+
+    # these ones fail, but the issue is more broad than the op_in analysis --> skip for now
+    # (single_for_loop, 64, 64, {'N': 1024}, 1/16),
+    # (if_else, 1000, 800, {}, 200 / 1600),
+    # (unaligned_for_loop, -1, -1, {}, -1),
+
+    (sequential_maps, 1024, 3*8, {'N': 29}, 87 / (90*8)),
+    # smaller cache --> only two arrays fit --> x is loaded twice now
+    (sequential_maps, 6, 3*8, {'N': 7}, 21 / (13*3*8)),
+
+    (nested_reuse, 1024, 64, {'N': 1024}, 2048 / (3*1024*8 + 128)),
+    (mmm, 20, 16, {'N': 24}, (2*24**3) / ((36*24**2 + 24*12) * 16)),
+    (tiled_mmm, 20, 16, {'N': 24, 'TILE_SIZE': 4}, (2*24**3) / (16*24*6**3)),
+    (tiled_mmm_32, 10, 16, {'N': 24, 'TILE_SIZE': 4}, (2*24**3) / (16*12*6**3)),
+
+    # work/depth-style cases, kept for future extension:
+    # (nested_sdfg, (2 * N, N + 1)),
+    # (nested_maps, (M * N, 1)),
+    # (nested_for_loops, (K * N, K * N)),
+    # (nested_if_else, (sp.Max(K, 3 * N, M + N), sp.Max(3, K, M + 1))),
+    # (multiple_array_sizes, (sp.Max(2 * K, 3 * N, 2 * M + 3), 5)),
+    # (sequntial_ifs, (sp.Max(N + 1, M) + sp.Max(N + 1, M + 1), sp.Max(1, M) + 1))
+]
+
+
+def test_operational_intensity():
+    errors = 0
+    for test, c, l, assumptions, correct in tests_cases:
+        op_in_map = {}
+        sdfg = test.to_sdfg()
+        sdfg.expand_library_nodes()
+        if 'nested_sdfg' in test.name:
+            sdfg.apply_transformations(NestSDFG)
+        if 'nested_maps' in test.name:
+            sdfg.apply_transformations(MapExpansion)
+        analyze_sdfg_op_in(sdfg, op_in_map, c, l, assumptions)
+        res = float(op_in_map[get_uuid(sdfg)])
+        # check the result
+        if not isclose(correct, res):
+            print(sdfg.name)
+            print(c, l, assumptions, correct, res)
+            print('ERROR DETECTED')
+            errors += 1
+
+    print(f'Encountered {errors} failing tests out of {len(tests_cases)} tests')
+    # fail the test run if any case deviates from its expected operational intensity
+    assert errors == 0
+
+
+if __name__ == '__main__':
+    test_operational_intensity()
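+
+# Example CLI invocation of the analysis exercised above (hypothetical file path):
+#   python dace/sdfg/work_depth_analysis/operational_intensity.py my.sdfg --C 4096 --L 64 --assume N==512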