From f8a2a9dd4e1cbdc7653830f29339239bb2c9ce38 Mon Sep 17 00:00:00 2001 From: Cliff Hodel Date: Wed, 20 Dec 2023 23:44:21 +0100 Subject: [PATCH] clean up code and write docstrings and comments --- dace/sdfg/work_depth_analysis/assumptions.py | 4 +- .../sdfg/work_depth_analysis/extrapolation.py | 233 -------- .../sdfg/work_depth_analysis/op_in_helpers.py | 185 +++++- .../operational_intensity.py | 545 +++++++++--------- dace/sdfg/work_depth_analysis/work_depth.py | 80 ++- tests/sdfg/operational_intensity_test.py | 275 +++------ tests/sdfg/work_depth_tests.py | 63 +- 7 files changed, 607 insertions(+), 778 deletions(-) delete mode 100644 dace/sdfg/work_depth_analysis/extrapolation.py diff --git a/dace/sdfg/work_depth_analysis/assumptions.py b/dace/sdfg/work_depth_analysis/assumptions.py index 6e311cde0c..ec8c61ef73 100644 --- a/dace/sdfg/work_depth_analysis/assumptions.py +++ b/dace/sdfg/work_depth_analysis/assumptions.py @@ -153,7 +153,7 @@ def propagate_assumptions_equal_symbols(condensed_assumptions): equality_subs1.update({sym: sp.Symbol(uf.find(sym))}) equality_subs2 = {} - # In a second step, each symbol gets replace with its equal number (if present) + # In a second step, each symbol gets replaced with its equal number (if present) # using equality_subs2. for sym, assum in condensed_assumptions.items(): for e in assum.equal: @@ -182,7 +182,7 @@ def parse_assumptions(assumptions, array_symbols): Parses a list of assumptions into substitution dictionaries. Firstly, it gathers all assumptions and keeps only the strongest ones. Afterwards it constructs two substitution dicts for the equality assumptions: First dict for symbol==symbol assumptions; second dict for symbol==number assumptions. - The other assumptions get handles by N tuples of substitution dicts (N = max number of concurrent + The other assumptions get handled by N tuples of substitution dicts (N = max number of concurrent assumptions for a single symbol). Each tuple is responsible for at most one assumption for each symbol. First dict in the tuple substitutes the symbol with the assumption; second dict restores the initial symbol. diff --git a/dace/sdfg/work_depth_analysis/extrapolation.py b/dace/sdfg/work_depth_analysis/extrapolation.py deleted file mode 100644 index 0a38805bae..0000000000 --- a/dace/sdfg/work_depth_analysis/extrapolation.py +++ /dev/null @@ -1,233 +0,0 @@ -from scipy.optimize import curve_fit -import numpy as np -import matplotlib.pyplot as plt - -def print_scores(scores): - for k, v in scores.items(): - print(k.name, v) - -class Logistic: - - def __init__(self, name): - self.x_name = name - self.name = 'Logistic' - - def f(x, a, b, c): - return b / (c + np.exp(-a * x)) - - def fit(self, x, y): - param, _ = curve_fit(Logistic.f, x, y, maxfev=10000) - self.a, self.b, self.c = param - - def predict(self, x): - return Logistic.f(x, self.a, self.b, self.c) - - def to_string(self): - return f'{self.b} / ({self.c} + exp({-self.a} * {self.x_name}))' - -class Log: - def __init__(self, name): - self.x_name = name - self.name = 'Log' - - - def f(x, a, b): - return a * np.log(x) + b - - def fit(self, x, y): - param, _ = curve_fit(Log.f, x, y, maxfev=2500) - self.a, self.b = param - - def predict(self, x): - return Log.f(x, self.a, self.b) - - def to_string(self): - return f'{self.a} * log({self.x_name}) + {self.b}' - -class Plateau: - def __init__(self, name): - self.x_name = name - self.name = 'Plateau' - - - def f(x, a, b): - return (a * x) / (x + b) - - def fit(self, x, y): - param, _ = curve_fit(Plateau.f, x, y, maxfev=2500) - self.a, self.b = param - - def predict(self, x): - return Plateau.f(x, self.a, self.b) - - def to_string(self): - return f'({self.a} * {self.x_name}) / ({self.x_name} + {self.b})' - - -class Poly: - def __init__(self, name): - self.x_name = name - self.name = 'Poly' - - - def f(x, a, b): - return a * x + b - - def fit(self, x, y): - param, _ = curve_fit(Poly.f, x, y, maxfev=2500) - self.a, self.b = param - - def predict(self, x): - return Poly.f(x, self.a, self.b) - - def to_string(self): - return f'{self.a} * {self.x_name} + {self.b}' - -class Sqrt: - def __init__(self, name): - self.x_name = name - self.name = 'Sqrt' - - def f(x, a, b): - return a * np.sqrt(x) + b - - def fit(self, x, y): - param, _ = curve_fit(Sqrt.f, x, y, maxfev=2500) - self.a, self.b = param - - def predict(self, x): - return Sqrt.f(x, self.a, self.b) - - def to_string(self): - return f'{self.a} * sqrt({self.x_name}) + {self.b}' - -class Exponential: - def __init__(self, name): - self.x_name = name - self.name = 'Exponential' - - def f(x, a, b): - return a * np.exp(x) + b - - def fit(self, x, y): - param, _ = curve_fit(Exponential.f, x, y, maxfev=2500) - self.a, self.b = param - - def predict(self, x): - return Exponential.f(x, self.a, self.b) - - def to_string(self): - return f'{self.a} * np.exp({self.x_name}) + {self.b}' - -class Sin: - def __init__(self, name): - self.x_name = name - self.name = 'Sin' - - def f(x, a, b, c, d): - return a * np.sin(b*x + c) + d - - def fit(self, x, y): - param, _ = curve_fit(Sin.f, x, y, maxfev=2500) - self.a, self.b, self.c, self.d = param - - def predict(self, x): - return Sin.f(x, self.a, self.b, self.c, self.d) - - def to_string(self): - return f'{self.a} * sin({self.b}*{self.x_name} + {self.c}) + {self.d}' - -class Constant: - def __init__(self, name): - self.x_name = name - self.name = 'Sin' - - def f(x, a): - return np.ones_like(x) * a - - def fit(self, x, y): - param, _ = curve_fit(Constant.f, x, y, maxfev=2500) - self.a = param - - def predict(self, x): - return Constant.f(x, self.a) - - def to_string(self): - return f'{self.a}' - - - -def extrapolate(op_in_map, range_symbol): - """ - For each key in op_in_map (aka for each SDFG element), we have a list of measured data points y - for the values in x_values. - Now we fit a curve and return the best function found via leave-one-out cross validation. - """ - - if len(range_symbol) == 1: - # only 1 independent variable - symbol_name = list(range_symbol.keys())[0] - x = range_symbol[symbol_name].to_list() - - models = [Logistic(symbol_name), Log(symbol_name), Plateau(symbol_name), Poly(symbol_name), Sqrt(symbol_name), - Exponential(symbol_name), Sin(symbol_name), Constant(symbol_name)] - - for element, y in op_in_map.items(): - all_zero = True - for q in y: - if q != 0.0: - all_zero = False - break - if all_zero: - op_in_map[element] = str(0) - continue - scores = {} - for model in models: - error_sum = 0 - for left_out in range(len(x)): - xx = list(x) - test_x = xx.pop(left_out) - yy = list(y) - test_y = yy.pop(left_out) - try: - model.fit(xx, yy) - except RuntimeError: - # triggered if no fit was found --> give huge error - error_sum += 999999999 - # predict on left out sample - pred = model.predict(test_x) - # squared_error = np.square(pred - test_y) - # error_sum += squared_error - root_error = np.sqrt(np.abs(float(pred - test_y))) - error_sum += root_error - - mean_error = error_sum / len(x) - scores[model] = mean_error - - - - # find model with least error - min_model = model - min_error = mean_error - for model, error in scores.items(): - if error < min_error: - min_error = error - min_model = model - - # fit best model to all points and plot - min_model.fit(x, y) - fig, ax = plt.subplots() # Create a figure containing a single axes. - ax.scatter(x, y) - s = 1 - t = x[-1] + 3 - q = np.linspace(s, t, num=(t-s)*5) - r = min_model.predict(q) - ax.plot(q, r, label=min_model.to_string()) - - fig.tight_layout() - plt.show() - - op_in_map[element] = min_model.to_string() - - else: - print('2 independent variables not implemented yet') \ No newline at end of file diff --git a/dace/sdfg/work_depth_analysis/op_in_helpers.py b/dace/sdfg/work_depth_analysis/op_in_helpers.py index f5bb637e1a..6e84e64129 100644 --- a/dace/sdfg/work_depth_analysis/op_in_helpers.py +++ b/dace/sdfg/work_depth_analysis/op_in_helpers.py @@ -1,12 +1,19 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. -""" Contains class CacheLineTracker which keeps track of all arrays of an SDFG and their cache line position. -Further, contains class AccessStack which which corresponds to the stack used to compute the stack distance. """ +""" Contains class CacheLineTracker which keeps track of all arrays of an SDFG and their cache line position +and class AccessStack which which corresponds to the stack used to compute the stack distance. +Further, provides a curve fitting method and plotting function. """ from dace.data import Array import sympy as sp from collections import deque +from scipy.optimize import curve_fit +import numpy as np +import matplotlib.pyplot as plt +from dace import symbol + class CacheLineTracker: + """ A CacheLineTracker maps data container accesses to the corresponding accessed cache line. """ def __init__(self, L) -> None: self.array_info = {} @@ -20,7 +27,7 @@ def add_array(self, name: str, a: Array, mapping): self.array_info[name] = a self.start_lines[name] = self.next_free_line # increase next_free_line - self.next_free_line += (a.total_size.subs(mapping) * a.dtype.bytes + self.L - 1) // self.L # ceil division + self.next_free_line += (a.total_size.subs(mapping) * a.dtype.bytes + self.L - 1) // self.L # ceil division def cache_line_id(self, name: str, access: [int], mapping): arr = self.array_info[name] @@ -39,6 +46,7 @@ def copy(self): new_clt.next_free_line = self.next_free_line return new_clt + class Node: def __init__(self, val: int, n=None) -> None: @@ -51,9 +59,6 @@ class AccessStack: in the stack, report its distance and move it to the top of the stack. If the id was not found, we report a distance of -1. """ - # TODO: this can be optimised such that the stack is never larger than C, since all elements deeper than C are misses - # anyway. (then we cannot distinguish compulsory misses from capacity misses though) - def __init__(self, C) -> None: self.top = None self.num_calls = 0 @@ -83,10 +88,6 @@ def touch(self, id): curr = curr.next distance += 1 - # shorten the stack if distance >= C - # if distance >= self.C and curr is not None: - # curr.next = None - if not found: # we accessed this cache line for the first time ever self.top = Node(id, self.top) @@ -94,25 +95,7 @@ def touch(self, id): distance = -1 return distance - - def compare_cache(self, other): - "Returns True if the same data resides in cache with the same LRU order" - s = self.top - o = other.top - dist = 0 - while s is not None and o is not None and dist < self.C: - dist += 1 - if s != o: - return False - s = s.next - o = o.next - if s is None and o is not None: - return False - if s is not None and o is None: - return False - - return True - + def in_cache_as_list(self): """ Returns a list of cache ids currently in cache. Index 0 is the most recently used. @@ -125,7 +108,7 @@ def in_cache_as_list(self): curr = curr.next dist += 1 return res - + def debug_print(self): # prints the whole stack print('\n') @@ -144,4 +127,146 @@ def copy(self): curr = new_stack.top for x in cache_content: curr.next = Node(x) + curr = curr.next return new_stack + + +def plot(x, work_map, cache_misses, op_in_map, symbol_name, C, L, sympy_f, element, name): + work_map = work_map[element] + cache_misses = cache_misses[element] + op_in_map = op_in_map[element] + sympy_f = sympy_f[element] + + a = np.linspace(1, max(x) + 5, max(x) * 4) + + fig, ax = plt.subplots(1, 2, figsize=(12, 5)) # Create a figure containing a single axes. + ax[0].scatter(x, cache_misses, label=f'C={C*L}, L={L}') + b = [] + for curr in a: + b.append(sp.N(sp.sympify(sympy_f).subs(symbol_name, curr))) + ax[0].plot(a, b) + + c = [] + for i, curr in enumerate(x): + if work_map[0].subs(symbol_name, curr) == 0: + c.append(0) + elif (cache_misses[i] * L) == 0: + c.append(9999) + else: + c.append(work_map[0].subs(symbol_name, curr) / (cache_misses[i] * L)) + c = np.array(c).astype(np.float64) + + ax[1].scatter(x, c, label=f'C={C*L}, L={L}') + b = [] + for curr in a: + b.append(sp.N(sp.sympify(op_in_map).subs(symbol_name, curr))) + ax[1].plot(a, b) + + ax[0].set_ylim(bottom=0, top=max(cache_misses) + max(cache_misses) / 10) + ax[0].set_xlim(left=0, right=max(x) + 1) + ax[0].set_xlabel(symbol_name) + ax[0].set_ylabel('Number of Cache Misses') + ax[0].set_title(name) + ax[0].legend(fancybox=True, framealpha=0.5) + + ax[1].set_ylim(bottom=0, top=max(c) + max(c) / 10) + ax[1].set_xlim(left=0, right=max(x) + 1) + ax[1].set_xlabel(symbol_name) + ax[1].set_ylabel('Operational Intensity') + ax[1].set_title(name) + + fig.show() + + +def compute_mape(f, test_x, test_y, test_set_size): + total_error = 0 + for i in range(test_set_size): + pred = f(test_x[i]) + err = abs(test_y[i] - pred) + total_error += err / test_y[i] + return total_error / test_set_size + + +def r_squared(pred, y): + if np.sum(np.square(y - y.mean())) <= 0.0001: + return 1 + return 1 - np.sum(np.square(y - pred)) / np.sum(np.square(y - y.mean())) + + +def find_best_model(x, y, I, J, symbol_name): + """ Find the best model out of all combinations of (i, j) from I and J via leave-one-out cross validation. """ + min_error = None + for i in I: + for j in J: + # current model + if i == 0 and j == 0: + + def f(x, b): + return b * np.ones_like(x) + else: + + def f(x, c, b): + return c * np.power(x, i) * np.power(np.log2(x), j) + b + + error_sum = 0 + for left_out in range(len(x)): + xx = np.delete(x, left_out) + yy = np.delete(y, left_out) + try: + param, _ = curve_fit(f, xx, yy) + + # predict on left out sample + pred = f(x[left_out], *param) + squared_error = np.square(pred - y[left_out]) + error_sum += squared_error + except RuntimeError: + # triggered if no fit was found --> give huge error + error_sum += 999999 + + mean_error = error_sum / len(x) + if min_error is None or mean_error < min_error: + # new best model found + min_error = mean_error + best_i_j = (i, j) + if best_i_j[0] == 0 and best_i_j[1] == 0: + + def f_best(x, b): + return b * np.ones_like(x) + else: + + def f_best(x, c, b): + return c * np.power(x, best_i_j[0]) * np.power(np.log2(x), best_i_j[1]) + b + + # fit best model to all data points + final_p, _ = curve_fit(f_best, x, y) + + def final_f(x): + return f_best(x, *final_p) + + if best_i_j[0] == 0 and best_i_j[1] == 0: + sympy_f = final_p[0] + else: + sympy_f = sp.simplify(final_p[0] * symbol(symbol_name)**best_i_j[0] * + sp.log(symbol(symbol_name), 2)**best_i_j[1] + final_p[1]) + # compute r^2 + r_s = r_squared(final_f(x), y) + return final_f, sympy_f, r_s + + +def fit_curve(x, y, symbol_name): + """ + Fits a function throught the data set. + + :param x: The independent values. + :param y: The dependent values. + :param symbol_name: The name of the SDFG symbol. + """ + x = np.array(x).astype(np.int32) + y = np.array(y).astype(np.float64) + + # model search space + I = [x / 4 for x in range(13)] + J = [0, 1, 2] + final_f, sympy_final_f, r_s = find_best_model(x, y, I, J, symbol_name) + + return final_f, sympy_final_f, r_s diff --git a/dace/sdfg/work_depth_analysis/operational_intensity.py b/dace/sdfg/work_depth_analysis/operational_intensity.py index 141b281680..f9c3836e40 100644 --- a/dace/sdfg/work_depth_analysis/operational_intensity.py +++ b/dace/sdfg/work_depth_analysis/operational_intensity.py @@ -2,35 +2,28 @@ """ Analyses the operational intensity of an input SDFG. Can be used as a Python script or from the VS Code extension. """ -ask_user = False - import argparse from collections import deque -from dace.sdfg import nodes as nd, propagation, InterstateEdge +from dace.sdfg import nodes as nd from dace import SDFG, SDFGState, dtypes -from dace.subsets import Range from typing import Tuple, Dict import os import sympy as sp from copy import deepcopy -from dace.libraries.blas import MatMul -from dace.libraries.standard import Reduce, Transpose from dace.symbolic import pystr_to_symbolic, SymExpr -import ast -import astunparse -import warnings -from dace.sdfg.work_depth_analysis.helpers import get_uuid, find_loop_guards_tails_exits -from dace.sdfg.work_depth_analysis.assumptions import parse_assumptions +from dace.sdfg.work_depth_analysis.helpers import get_uuid from dace.transformation.passes.symbol_ssa import StrictSymbolSSA from dace.transformation.pass_pipeline import FixedPointPipeline from dace.data import Array from dace.sdfg.work_depth_analysis.op_in_helpers import CacheLineTracker, AccessStack from dace.sdfg.work_depth_analysis.work_depth import analyze_sdfg, get_tasklet_work -from dace.sdfg.work_depth_analysis.extrapolation import extrapolate +from dace.sdfg.work_depth_analysis.extrapolation import fit_curve, plot, compute_mape + class SymbolRange(): + """ Used to describe an SDFG symbol associated with a range (start, stop, step) of values. """ def __init__(self, start_stop_step) -> None: self.r = range(*start_stop_step) @@ -42,47 +35,52 @@ def next(self): except StopIteration: r = -1 return r - + def to_list(self): return list(self.r) + def max_value(self): + return max(self.to_list()) -def update_map(op_in_map, uuid, new_misses): - if uuid in op_in_map: - misses, encounters = op_in_map[uuid] - op_in_map[uuid] = (misses + new_misses, encounters + 1) - else: - op_in_map[uuid] = (new_misses, 1) +def update_map(op_in_map, uuid, new_misses, average=True): + if average: + if uuid in op_in_map: + misses, encounters = op_in_map[uuid] + op_in_map[uuid] = (misses + new_misses, encounters + 1) + else: + op_in_map[uuid] = (new_misses, 1) + else: + if uuid in op_in_map: + misses, encounters = op_in_map[uuid] + op_in_map[uuid] = (misses + new_misses, encounters) + else: + op_in_map[uuid] = (new_misses, 1) -def calculate_op_in(op_in_map, work_map, assumptions, stringify=False): +def calculate_op_in(op_in_map, work_map, stringify=False, assumptions={}): + """ Calculates the operational intensity for each SDFG element from work and bytes loaded. """ for uuid in op_in_map: - try: - work = work_map[uuid][0].subs(assumptions) - if work == 0 and op_in_map[uuid] == 0: - op_in_map[uuid] = 0 - elif work != 0 and op_in_map[uuid] == 0: - # everything was read from cache --> infinite op_in - op_in_map[uuid] = sp.oo - else: - # op_in > 0 --> divide normally - op_in_map[uuid] = sp.N(work / op_in_map[uuid]) - # from random import random - # op_in_map[uuid] = round(random(), 2) - if stringify: - op_in_map[uuid] = str(op_in_map[uuid]) - except Exception as e: - work = work_map[uuid][0].subs(assumptions) - print(work / op_in_map[uuid] if op_in_map[uuid] != 0 and work == 0 else sp.oo) - raise e - + work = work_map[uuid][0].subs(assumptions) + if work == 0 and op_in_map[uuid] == 0: + op_in_map[uuid] = 0 + elif work != 0 and op_in_map[uuid] == 0: + # everything was read from cache --> infinite op_in + op_in_map[uuid] = sp.oo + else: + # op_in > 0 --> divide normally + op_in_map[uuid] = sp.N(work / op_in_map[uuid]) + if stringify: + op_in_map[uuid] = str(op_in_map[uuid]) + + def mem_accesses_on_path(states): mem_accesses = 0 for state in states: mem_accesses += len(state.read_and_write_sets()) return mem_accesses + def find_states_between(sdfg: SDFG, start_state: SDFGState, end_state: SDFGState): traversal_q = deque() traversal_q.append(start_state) @@ -117,7 +115,7 @@ def find_merge_state(sdfg: SDFG, state: SDFGState): return # Skip if natural loop if len(oedges) == 2 and ((ptree[oedges[0].dst] == state and ptree[oedges[1].dst] != state) or - (ptree[oedges[1].dst] == state and ptree[oedges[0].dst] != state)): + (ptree[oedges[1].dst] == state and ptree[oedges[0].dst] != state)): return # If branch without else (adf of one successor is equal to the other) @@ -162,7 +160,7 @@ def update_mapping(mapping, e): update = {} for k, v in e.data.assignments.items(): if '[' not in k and '[' not in v: - update[pystr_to_symbolic(k)] = pystr_to_symbolic(v).subs(mapping) + update[k] = pystr_to_symbolic(v).subs(mapping) mapping.update(update) @@ -171,38 +169,31 @@ def update_map_iterators(map, mapping): # if all iterations exhausted, return True # always increase the last one. If it is exhausted, increase the next one and so forth map_exhausted = True - for p, range in zip(map.params[::-1], map.range[::-1]): # reversed order + for p, range in zip(map.params[::-1], map.range[::-1]): # reversed order curr_value = mapping[p] - try: - if not isinstance(range[1], SymExpr): - if curr_value.subs(mapping) + range[2].subs(mapping) <= range[1].subs(mapping): - # update this value and we done - mapping[p] = curr_value.subs(mapping) + range[2].subs(mapping) - map_exhausted = False - break - else: - # set current param to start again and continue - mapping[p] = range[0].subs(mapping) + if not isinstance(range[1], SymExpr): + if curr_value.subs(mapping) + range[2].subs(mapping) <= range[1].subs(mapping): + # update this value and we done + mapping[p] = curr_value.subs(mapping) + range[2].subs(mapping) + map_exhausted = False + break else: - if curr_value.subs(mapping) + range[2].subs(mapping) <= range[1].expr.subs(mapping): - # update this value and we done - mapping[p] = curr_value.subs(mapping) + range[2].subs(mapping) - map_exhausted = False - break - else: - # set current param to start again and continue - mapping[p] = range[0].subs(mapping) - except Exception as e: - print('exception in update_map_iterators:') - print(curr_value) - print(range[1]) - print(mapping, '\n\n') - raise(e) + # set current param to start again and continue + mapping[p] = range[0].subs(mapping) + else: + if curr_value.subs(mapping) + range[2].subs(mapping) <= range[1].expr.subs(mapping): + # update this value and we done + mapping[p] = curr_value.subs(mapping) + range[2].subs(mapping) + map_exhausted = False + break + else: + # set current param to start again and continue + mapping[p] = range[0].subs(mapping) return map_exhausted - -def map_op_in(state: SDFGState, op_in_map: Dict[str, sp.Expr], entry, mapping, stack, clt, C, symbols, array_names, w_d_map, decided_branches): +def map_op_in(state: SDFGState, op_in_map: Dict[str, sp.Expr], entry, mapping, stack, clt, C, symbols, array_names, + decided_branches, ask_user): # we are inside a map --> we need to iterate over the map range and check each memory access. for p, range in zip(entry.map.params, entry.map.range): # map each map iteration variable to its start @@ -210,57 +201,67 @@ def map_op_in(state: SDFGState, op_in_map: Dict[str, sp.Expr], entry, mapping, s map_misses = 0 while True: # do analysis of map contents - map_misses += scope_op_in(state, op_in_map, mapping, stack, clt, C, symbols, array_names, w_d_map, decided_branches, entry) + map_misses += scope_op_in(state, op_in_map, mapping, stack, clt, C, symbols, array_names, decided_branches, + ask_user, entry) if update_map_iterators(entry.map, mapping): break return map_misses - -def scope_op_in(state: SDFGState, op_in_map: Dict[str, sp.Expr], mapping, stack: AccessStack, clt: CacheLineTracker, C, symbols, array_names, w_d_map, decided_branches, entry=None): + +def scope_op_in(state: SDFGState, + op_in_map: Dict[str, sp.Expr], + mapping, + stack: AccessStack, + clt: CacheLineTracker, + C, + symbols, + array_names, + decided_branches, + ask_user, + entry=None): + """ + Computes the operational intensity of a single scope (scope is either an SDFG state or a map scope). + + :param sdfg: The SDFG to analyze. + :param op_in_map: Dictionary storing the resulting operational intensity for each SDFG element. + :param mapping: Mapping of SDFG symbols to their current values. + :param stack: The stack used to track the stack distances. + :param clt: The current CacheLineTracker object mapping data container accesses to cache line ids. + :param C: Cache size in bytes. + :param symbols: A dictionary mapping local nested SDFG symbols to global symbols. + :param array_names: A dictionary mapping local nested SDFG array names to global array names. + :param decided_branches: Dictionary keeping track of user's decisions on which branches to analyze (if ask_user is True). + :param ask_user: If True, the user has to decide which branch to analyze in case it cannot be determined automatically. If False, + all branches get analyzed. + :param entry: If None, the whole state gets analyzed. Else, only the scope starting at this entry node is analyzed. + """ + # find the number of cache misses for each node. # for maps and nested SDFG, we do it recursively. scope_misses = 0 scope_nodes = state.scope_children()[entry] for node in scope_nodes: - # add node to map - # op_in_map[get_uuid(node, state)] = 0 if isinstance(node, nd.EntryNode): # If the scope contains an entry node, we need to recursively analyze the sub-scope of the entry node first. - # The resulting work/depth are summarized into the entry node - map_misses = map_op_in(state, op_in_map, node, mapping, stack, clt, C, symbols, array_names, w_d_map, decided_branches) - - # add up work for whole state, but also save work for this sub-scope scope in op_in_map + map_misses = map_op_in(state, op_in_map, node, mapping, stack, clt, C, symbols, array_names, + decided_branches, ask_user) + update_map(op_in_map, get_uuid(node, state), map_misses) - # op_in_map[get_uuid(node, state)] = map_misses scope_misses += map_misses elif isinstance(node, nd.Tasklet): - # add up work for whole state, but also save work for this node in op_in_map tasklet_misses = 0 # analyze the memory accesses of this tasklet and whether they hit in cache or not for e in state.in_edges(node) + state.out_edges(node): - if e.data.data in clt.array_info or (e.data.data in array_names and array_names[e.data.data] in clt.array_info): - line_id = clt.cache_line_id(e.data.data if e.data.data not in array_names else array_names[e.data.data], - [x[0].subs(mapping) for x in e.data.subset.ranges], mapping) - try: - line_id = int(line_id.subs(mapping)) - except TypeError as e: - print(line_id.subs(mapping).free_symbols) - print(mapping) - print(state.name) - try: - print(mapping[line_id.subs(mapping).free_symbols.pop()]) - except: - pass - raise(e) + if e.data.data in clt.array_info or (e.data.data in array_names + and array_names[e.data.data] in clt.array_info): + line_id = clt.cache_line_id( + e.data.data if e.data.data not in array_names else array_names[e.data.data], + [x[0].subs(mapping) for x in e.data.subset.ranges], mapping) + + line_id = int(line_id.subs(mapping)) dist = stack.touch(line_id) tasklet_misses += 1 if dist >= C or dist == -1 else 0 - # for e in state.out_edges(node): - # if e.data.data in clt.array_info: - # line_id = clt.cache_line_id(e.data.data if e.data.data not in array_names else array_names[e.data.data], - # [x[0].subs(mapping) for x in e.data.subset.ranges], mapping) - # dist = stack.touch(line_id) - # tasklet_misses += 1 if dist > C or dist == -1 else 0 scope_misses += tasklet_misses # a tasklet can get passed multiple times... we report the average misses in the end @@ -283,21 +284,14 @@ def scope_op_in(state: SDFGState, op_in_map: Dict[str, sp.Expr], mapping, stack: for e in state.in_edges(node): nested_array_names[e.dst_conn] = e.data.data for e in state.out_edges(node): - nested_array_names[e.src_conn] = e.data.data + nested_array_names[e.src_conn] = e.data.data # Nested SDFGs are recursively analyzed first. - nsdfg_misses = sdfg_op_in(node.sdfg, op_in_map, mapping, stack, clt, C, nested_syms, nested_array_names, w_d_map, decided_branches) + nsdfg_misses = sdfg_op_in(node.sdfg, op_in_map, mapping, stack, clt, C, nested_syms, nested_array_names, + decided_branches, ask_user) - # add up misses for whole state, but also save misses for this nested SDFG in op_in_map scope_misses += nsdfg_misses - # op_in_map[get_uuid(node, state)] = nsdfg_misses update_map(op_in_map, get_uuid(node, state), nsdfg_misses) elif isinstance(node, nd.LibraryNode): - # TODO: implement librarynodes. Note: When encountering some libNode, we can add a symbol - # "libnode_name_bytes". Then we have "libnode_name_work / libnode_name_bytes" in the final - # expression. Better to just have "libnode_name_opin" in final expr. Either dont spawn the work - # symbol and put the "op_in" symbol here - # or replace the division in the end with the "op_in" symbol - # add a symbol to the top level sdfg, such that the user can define it in the extension top_level_sdfg = state.parent try: @@ -309,35 +303,58 @@ def scope_op_in(state: SDFGState, op_in_map: Dict[str, sp.Expr], mapping, stack: scope_misses += lib_node_misses update_map(op_in_map, get_uuid(node, state), lib_node_misses) if entry is None: - # op_in_map[get_uuid(state)] = scope_misses - update_map(op_in_map, get_uuid(state), scope_misses) + # if entry is none this means that we are analyzing the whole state --> save number of misses in get_uuid(state) + update_map(op_in_map, get_uuid(state), scope_misses, average=False) return scope_misses -def sdfg_op_in(sdfg: SDFG, op_in_map: Dict[str, Tuple[sp.Expr, sp.Expr]], mapping, stack, clt: CacheLineTracker, C, symbols, array_names, w_d_map, decided_branches, start=None, end=None): - - # add this SDFG's arrays to the cache line tracker - for name, arr in sdfg.arrays.items(): - if isinstance(arr, Array): - if name in array_names: - name = array_names[name] - clt.add_array(name, arr, mapping) - - # traverse this SDFG's states - curr_state = start or sdfg.start_state - total_misses = 0 +def sdfg_op_in(sdfg: SDFG, + op_in_map: Dict[str, Tuple[sp.Expr, sp.Expr]], + mapping, + stack: AccessStack, + clt: CacheLineTracker, + C, + symbols, + array_names, + decided_branches, + ask_user, + start=None, + end=None): + """ + Computes the operational intensity of the input SDFG. + + :param sdfg: The SDFG to analyze. + :param op_in_map: Dictionary storing the resulting operational intensity for each SDFG element. + :param mapping: Mapping of SDFG symbols to their current values. + :param stack: The stack used to track the stack distances. + :param clt: The current CacheLineTracker object mapping data container accesses to cache line ids. + :param C: Cache size in bytes. + :param symbols: A dictionary mapping local nested SDFG symbols to global symbols. + :param array_names: A dictionary mapping local nested SDFG array names to global array names. + :param decided_branches: Dictionary keeping track of user's decisions on which branches to analyze (if ask_user is True). + :param ask_user: If True, the user has to decide which branch to analyze in case it cannot be determined automatically. If False, + all branches get analyzed. + :param start: The start state of the SDFG traversal. If None, the SDFG's normal start state is used. + :param end: The end state of the SDFG traversal. If None, the whole SDFG is traversed. + """ + + if start is None: + # add this SDFG's arrays to the cache line tracker + for name, arr in sdfg.arrays.items(): + if isinstance(arr, Array): + if name in array_names: + name = array_names[name] + clt.add_array(name, arr, mapping) + # start traversal at SDFG's start state + curr_state = sdfg.start_state + else: + curr_state = start - num_states = 0 + total_misses = 0 + # traverse this SDFG's states while True: - # print(curr_state.name) - # print(mapping) - # print() - num_states += 1 - # if num_states % 100 == 0: - # print(curr_state.name) - # print(mapping) - - total_misses += scope_op_in(curr_state, op_in_map, mapping, stack, clt, C, symbols, array_names, w_d_map, decided_branches) + total_misses += scope_op_in(curr_state, op_in_map, mapping, stack, clt, C, symbols, array_names, + decided_branches, ask_user) if len(sdfg.out_edges(curr_state)) == 0: # we reached an end state --> stop @@ -353,16 +370,24 @@ def sdfg_op_in(sdfg: SDFG, op_in_map: Dict[str, Tuple[sp.Expr, sp.Expr]], mappin update_mapping(mapping, e) except: print('\nWARNING: Strange assignment detected on InterstateEdge (e.g. bitwise operators).' - 'Analysis may give wrong results.') + 'Analysis may give wrong results.') print(e.data.assignments, 'was the edge\'s assignments.') curr_state = e.dst found = True break if not found: + # We need to check if we are in an implicit end state (i.e. all outgoing edge conditions evaluate to False) + all_false = True + for e in sdfg.out_edges(curr_state): + if e.data.condition_sympy().subs(mapping) != False: + all_false = False + if all_false: + break + if curr_state in decided_branches: # if the user already decided this branch in a previous iteration, take the same branch again. e = decided_branches[curr_state] - + update_mapping(mapping, e) curr_state = e.dst else: @@ -388,13 +413,12 @@ def sdfg_op_in(sdfg: SDFG, op_in_map: Dict[str, Tuple[sp.Expr, sp.Expr]], mappin print(f'({i}) for edge to state {edges[i].dst.name}') print(edges[i].dst._read_and_write_sets()) print('merge state is named ', merge_state) - chosen = 1 #int(input('Choose an option from above: ')) + chosen = int(input('Choose an option from above: ')) e = edges[chosen] update_mapping(mapping, e) decided_branches[curr_state] = e curr_state = e.dst - print('we continue with state', e.dst.name) - print(3*'\n') + print(2 * '\n') else: final_e = next_edge_candidates.pop() for e in next_edge_candidates: @@ -409,152 +433,162 @@ def sdfg_op_in(sdfg: SDFG, op_in_map: Dict[str, Tuple[sp.Expr, sp.Expr]], mappin curr_state = e.dst # walk down this branch until merge_state - # TODO: can we use the return value (misses of different branches) for something? - sdfg_op_in(sdfg, op_in_map, curr_mapping, curr_stack, curr_clt, C, curr_symbols, curr_array_names, w_d_map, decided_branches, curr_state, merge_state) + sdfg_op_in(sdfg, op_in_map, curr_mapping, curr_stack, curr_clt, C, curr_symbols, + curr_array_names, decided_branches, ask_user, curr_state, merge_state) update_mapping(mapping, final_e) curr_state = final_e.dst if curr_state == end: break - - # if sdfg.name == 'CLOUDSC': - # print('NUM STATES IS: ', num_states) - - # op_in_map[get_uuid(sdfg)] = total_misses if end is None: # only update if we were actually analyzing a whole sdfg (not just start to end state) - update_map(op_in_map, get_uuid(sdfg), total_misses) + update_map(op_in_map, get_uuid(sdfg), total_misses, average=False) return total_misses -def analyze_sdfg_op_in(sdfg: SDFG, op_in_map: Dict[str, sp.Expr], C, L, assumptions): + +def analyze_sdfg_op_in(sdfg: SDFG, + op_in_map: Dict[str, sp.Expr], + C, + L, + assumptions, + generate_plots=False, + stringify=False, + test_set_size=3, + ask_user=False): + """ + Computes the operational intensity of the input SDFG. + + :param sdfg: The SDFG to analyze. + :param op_in_map: Dictionary storing the resulting operational intensity for each SDFG element. + :param C: Cache size in bytes. + :param L: Cache line size in bytes. + :param assumptions: Dictionary mapping SDFG symbols to concrete values, e.g. {'N': 8}. At most one symbol might be associated + with a range of (start, stop, step), e.g. {'M' : '2,10,1'}. + :param generate_plots: If True (and there is a range symbol N), a plot showing the operational intensity as a function of N + for the whole SDFG. + :param stringify: If True, the final operational intensity values will the converted to strings. + :param test_set_size: The size of the test set when testing the goodness of fit. + :param ask_user: If True, the user has to decide which branch to analyze in case it cannot be determined automatically. If False, + all branches get analyzed. + """ + + # from now on we take C as the number of lines that fit into cache + C = C // L sdfg = deepcopy(sdfg) # apply SSA pass pipeline = FixedPointPipeline([StrictSymbolSSA()]) pipeline.apply_pass(sdfg, {}) - # print('C as num lines:', C, L, assumptions) - # TODO: insert some checks on whether this sdfg is analyzable, like - # - data-dependent loop bounds (i.e. unbounded executions) - # - indirect accesses (e.g. A[B[i]]) - - - - - - - - # check if all symbols are concretized - standard_range = (4, 16, 2) - num_undefined = 0 + # check if all symbols are concretized (at most one can be associated with a range) + undefined_symbols = set() range_symbol = {} for sym in sdfg.free_symbols: if sym not in assumptions: - num_undefined += 1 - range_symbol[sym] = SymbolRange(standard_range) + undefined_symbols.add(sym) elif isinstance(assumptions[sym], str): - num_undefined += 1 range_symbol[sym] = SymbolRange(int(x) for x in assumptions[sym].split(',')) del assumptions[sym] work_map = {} assumptions_list = [f'{x}=={y}' for x, y in assumptions.items()] - analyze_sdfg(sdfg, work_map, get_tasklet_work, assumptions_list, False) - - - - if num_undefined == 0: - sdfg.specialize(assumptions) - mapping = {} - mapping.update(assumptions) - - stack = AccessStack(C) - clt = CacheLineTracker(L) - # keeps track of user's input on which branches to analyze - decided_branches: Dict[SDFGState, InterstateEdge] = {} - # all symbols concretized, do normal analysis - sdfg_op_in(sdfg, op_in_map, mapping, stack, clt, C, {}, {}, work_map, decided_branches) - # now we have number of misses --> multiply each by L to get bytes - for k, v in op_in_map.items(): - op_in_map[k] = v[0] * L / v[1] - # divide work by bytes to get operational intensity - calculate_op_in(op_in_map, work_map, assumptions, stringify=True) - - print('bla') - elif num_undefined > 1: - raise Exception('Too many undefined symbols') - else: - assert len(range_symbol) <= 2 - op_in_measurements = {} - - # keeps track of user's input on which branches to analyze - decided_branches: Dict[SDFGState, InterstateEdge] = {} - while True: - new_val = False - for sym, r in range_symbol.items(): - val = r.next() - if val > -1: - new_val = True - assumptions[sym] = val - if not new_val: - break + analyze_sdfg(sdfg, work_map, get_tasklet_work, assumptions_list) - print(assumptions) - curr_op_in_map = {} + if len(undefined_symbols) > 0: + raise Exception( + f'Undefined symbols detected: {undefined_symbols}. Please specify a value for all free symbols of the SDFG.' + ) + else: + # all symbols defined + if len(range_symbol) > 1: + raise Exception('More than one range symbol detected! Only one range symbol allowed.') + elif len(range_symbol) == 0: + # all symbols are concretized --> run normal op_in analysis with concretized symbols + sdfg.specialize(assumptions) mapping = {} mapping.update(assumptions) + stack = AccessStack(C) clt = CacheLineTracker(L) - sdfg_op_in(sdfg, curr_op_in_map, mapping, stack, clt, C, {}, {}, work_map, decided_branches) - # now we have number of misses --> multiply each by L to get bytes - for k, v in curr_op_in_map.items(): - curr_op_in_map[k] = v[0] * L / v[1] - # divide work by bytes to get operational intensity - calculate_op_in(curr_op_in_map, work_map, assumptions) - - # put curr values in op_in_measurements - for k, v in curr_op_in_map.items(): - if k in op_in_measurements: - op_in_measurements[k].append(v) - else: - op_in_measurements[k] = [v] - - extrapolate(op_in_measurements, range_symbol) - op_in_map.update(op_in_measurements) - - - # TODO: extrapolate not the op_in, but the number of cache misses!!!! Maybe its better?? - - # sdfg_op_in(sdfg, op_in_map, mapping, stack, clt, C, {}, {}, work_map, decided_branches) - - # # print('Misses: ', op_in_map[get_uuid(sdfg)]) - - - # # now we have number of misses --> multiply each by L to get bytes - # for k, v in op_in_map.items(): - # op_in_map[k] = v * L - # # print('Bytes: ', op_in_map[get_uuid(sdfg)]) - # # print('Work: ', work_map[get_uuid(sdfg)][0]) - - - # # divide work by bytes to get operational intensity - # for uuid in op_in_map: - # try: - # op_in_map[uuid] = str(sp.N(work_map[uuid][0].subs(assumptions) / op_in_map[uuid] if op_in_map[uuid] != 0 else 0)) - # except Exception as e: - # print(work_map[uuid][0] / op_in_map[uuid] if op_in_map[uuid] != 0 else 0) - # raise e - - # print('num memory accesses:', stack.num_calls) - # print('total op_in:', op_in_map[get_uuid(sdfg)]) - # print() + sdfg_op_in(sdfg, op_in_map, mapping, stack, clt, C, {}, {}, {}, ask_user) + # compute bytes + for k, v in op_in_map.items(): + op_in_map[k] = v[0] / v[1] * L + calculate_op_in(op_in_map, work_map, stringify) + else: + # we have one variable symbol + + # decided_branches: Dict[SDFGState, InterstateEdge] = {} + cache_miss_measurements = {} + work_measurements = [] + t = 0 + while True: + new_val = False + for sym, r in range_symbol.items(): + val = r.next() + if val > -1: + new_val = True + assumptions[sym] = val + elif t < 3: + # now we sample test set + t += 1 + assumptions[sym] = r.max_value() + t * 3 + new_val = True + if not new_val: + break - # for s in decided_branches: - # print(f'\'{s.name}\', ', end='') - # print('\n\n') + curr_op_in_map = {} + mapping = {} + mapping.update(assumptions) + stack = AccessStack(C) + clt = CacheLineTracker(L) + sdfg_op_in(sdfg, curr_op_in_map, mapping, stack, clt, C, {}, {}, {}, ask_user) + + # compute average cache misses + for k, v in curr_op_in_map.items(): + curr_op_in_map[k] = v[0] / v[1] + + # save cache misses + curr_cache_misses = dict(curr_op_in_map) + + work_measurements.append(work_map[get_uuid(sdfg)][0].subs(assumptions)) + # put curr values in cache_miss_measurements + for k, v in curr_cache_misses.items(): + if k in cache_miss_measurements: + cache_miss_measurements[k].append(v) + else: + cache_miss_measurements[k] = [v] + + symbol_name = next(iter(range_symbol.keys())) + x_values = range_symbol[symbol_name].to_list() + x_values.extend([r.max_value() + t * 3 for t in range(1, test_set_size + 1)]) + + sympy_fs = {} + for k, v in cache_miss_measurements.items(): + final_f, sympy_f, r_s = fit_curve(x_values[:-test_set_size], v[:-test_set_size], symbol_name) + op_in_map[k] = sp.simplify(sympy_f * L) + sympy_fs[k] = sympy_f + if k == get_uuid(sdfg): + # compute MAPE on total SDFG + mape = compute_mape(final_f, x_values[-test_set_size:], v[-test_set_size:], test_set_size) + if mape > 0.2: + print('High MAPE detected:', mape) + print('It is suggested to generate plots and analyze those.') + print('R^2 is:', r_s) + print('A hight R^2 (i.e. close to 1) suggests that we are fitting the test data well.') + print('This combined with high MAPE tells us that our test data does not generalize.') + calculate_op_in(op_in_map, work_map, not generate_plots) + + if generate_plots: + # plot results for the whole SDFG + plot(x_values, work_map, cache_miss_measurements, op_in_map, symbol_name, C, L, sympy_fs, + get_uuid(sdfg), sdfg.name) + if stringify: + for k, v in op_in_map.items(): + op_in_map[k] = str(v) ################################################################################ @@ -584,7 +618,7 @@ def main() -> None: op_in_map = {} if args.assume is None: args.assume = [] - + assumptions = {} for x in args.assume: a, b = x.split('==') @@ -595,7 +629,6 @@ def main() -> None: print(assumptions) analyze_sdfg_op_in(sdfg, op_in_map, int(args.C), int(args.L), assumptions) - result_whole_sdfg = op_in_map[get_uuid(sdfg)] print(80 * '-') @@ -605,9 +638,3 @@ def main() -> None: if __name__ == '__main__': main() - - - - - - diff --git a/dace/sdfg/work_depth_analysis/work_depth.py b/dace/sdfg/work_depth_analysis/work_depth.py index a1193ec8e7..0b257fdbaa 100644 --- a/dace/sdfg/work_depth_analysis/work_depth.py +++ b/dace/sdfg/work_depth_analysis/work_depth.py @@ -70,8 +70,8 @@ def count_work_matmul(node, symbols, state): if len(C_memlet.data.subset) == 3: result *= symeval(C_memlet.data.subset.size()[0], symbols) # M*N - # TODO: line below gives index out of range if we compute matrix vector product (as in e.g. atax from npbench) - result *= symeval(C_memlet.data.subset.size()[-2], symbols) + # we need the if else, since C_memlet is one dimensional in case of matrix vector product + result *= 1 if len(C_memlet.data.subset.size()) < 2 else symeval(C_memlet.data.subset.size()[-2], symbols) result *= symeval(C_memlet.data.subset.size()[-1], symbols) # K result *= symeval(A_memlet.data.subset.size()[-1], symbols) @@ -82,7 +82,7 @@ def count_depth_matmul(node, symbols, state): # optimal depth of a matrix multiplication is O(log(size of shared dimension)): A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_a') size_shared_dimension = symeval(A_memlet.data.subset.size()[-1], symbols) - return bigo(sp.log(size_shared_dimension)) + return sp.log(size_shared_dimension) def count_work_reduce(node, symbols, state): @@ -102,7 +102,7 @@ def count_work_reduce(node, symbols, state): def count_depth_reduce(node, symbols, state): # optimal depth of reduction is log of the work - return bigo(sp.log(count_work_reduce(node, symbols, state))) + return sp.log(count_work_reduce(node, symbols, state)) LIBNODES_TO_WORK = { @@ -117,7 +117,6 @@ def count_depth_reduce(node, symbols, state): Reduce: count_depth_reduce, } -bigo = sp.Function('bigo') PYFUNC_TO_ARITHMETICS = { 'float': 0, 'dace.float64': 0, @@ -225,7 +224,6 @@ def visit_While(self, node): def count_depth_code(code): - # so far this is the same as the work counter, since work = depth for each tasklet, as we can't assume any parallelism ctr = ArithmeticCounter() if isinstance(code, (tuple, list)): for stmt in code: @@ -241,9 +239,9 @@ def tasklet_work(tasklet_node, state): if tasklet_node.code.language == dtypes.Language.CPP: # simplified work analysis for CPP tasklets. for oedge in state.out_edges(tasklet_node): - return oedge.data.num_accesses or 0 # on Lulesh this was None for some tasklet(s) + return oedge.data.num_accesses elif tasklet_node.code.language == dtypes.Language.Python: - return count_arithmetic_ops_code(tasklet_node.code.code) or 0 # on Lulesh this was None for some tasklet(s) + return count_arithmetic_ops_code(tasklet_node.code.code) else: # other languages not implemented, count whole tasklet as work of 1 warnings.warn('Work of tasklets only properly analyzed for Python or CPP. For all other ' @@ -291,12 +289,8 @@ def do_initial_subs(w, d, eq, subs1): """ Calls subs three times for the given (w)ork and (d)epth values. """ - try: - result = sp.simplify(sp.sympify(w).subs(eq[0]).subs(eq[1]).subs(subs1)), sp.simplify(sp.sympify(d).subs(eq[0]).subs(eq[1]).subs(subs1)) - except Exception as e: - print('w:', w) - print('d:', d) - raise(e) + result = sp.simplify(sp.sympify(w).subs(eq[0]).subs(eq[1]).subs(subs1)), sp.simplify( + sp.sympify(d).subs(eq[0]).subs(eq[1]).subs(subs1)) return result @@ -334,10 +328,12 @@ def sdfg_work_depth(sdfg: SDFG, detailed_analysis) # Substitutions for state_work and state_depth already performed, but state.executions needs to be subs'd now. - state_work = sp.simplify(state_work * - state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) - state_depth = sp.simplify(state_depth * - state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + state_work = sp.simplify( + state_work.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1) * + state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + state_depth = sp.simplify( + state_depth.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1) * + state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) state_works[state], state_depths[state] = state_work, state_depth w_d_map[get_uuid(state)] = (state_works[state], state_depths[state]) @@ -388,20 +384,9 @@ def sdfg_work_depth(sdfg: SDFG, traversal_q.append((sdfg.start_state, sp.sympify(0), sp.sympify(0), None, [], [], {})) visited = set() - # print('number of states in this sdfg: ', len(sdfg.states())) - # num_states = 0 - while traversal_q: state, depth, work, ie, condition_stack, common_subexpr_stack, value_map = traversal_q.popleft() - # num_states += 1 - # if num_states % 50 == 0: - # print(state.name) - # print('work:', work) - # print() - # print() - - if ie is not None: visited.add(ie) @@ -411,11 +396,7 @@ def sdfg_work_depth(sdfg: SDFG, else: state_value_map[state] = value_map - # ignore assignments such as tmp=x[0], as those do not give much information. - try: - value_map = {pystr_to_symbolic(k): pystr_to_symbolic(v) for k, v in state_value_map[state].items()} - except: - print('gg') + value_map = {pystr_to_symbolic(k): pystr_to_symbolic(v) for k, v in state_value_map[state].items()} n_depth = sp.simplify((depth + state_depths[state]).subs(value_map)) n_work = sp.simplify((work + state_works[state]).subs(value_map)) @@ -480,10 +461,19 @@ def sdfg_work_depth(sdfg: SDFG, new_cse_stack.append((work_map[state], depth_map[state])) # same for value_map new_value_map = dict(state_value_map[state]) - new_value_map.update({sp.Symbol(k): sp.Symbol(v) for k, v in oedge.data.assignments.items()}) + new_value_map.update({ + pystr_to_symbolic(k): + pystr_to_symbolic(v).subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1) + for k, v in oedge.data.assignments.items() + }) traversal_q.append((oedge.dst, 0, 0, oedge, new_cond_stack, new_cse_stack, new_value_map)) else: - value_map.update(oedge.data.assignments) + # value_map.update(oedge.data.assignments) + value_map.update({ + pystr_to_symbolic(k): + pystr_to_symbolic(v).subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1) + for k, v in oedge.data.assignments.items() + }) traversal_q.append((oedge.dst, depth_map[state], work_map[state], oedge, condition_stack, common_subexpr_stack, value_map)) @@ -498,6 +488,17 @@ def sdfg_work_depth(sdfg: SDFG, sdfg_result = (max_work, max_depth) w_d_map[get_uuid(sdfg)] = sdfg_result + # TODO: + # for k, v in w_d_map.items(): + # w_d_map[k] = (v[0].subs(value_map), v[1].subs(value_map)) + + # TODO: is this needed + for k, (v_w, v_d) in w_d_map.items(): + # The symeval replaces nested SDFG symbols with their global counterparts. + # v_w, v_d = do_subs(v_w, v_d, all_subs) + v_w = symeval(v_w, symbols) + v_d = symeval(v_d, symbols) + w_d_map[k] = (v_w, v_d) return sdfg_result @@ -553,9 +554,6 @@ def scope_work_depth( # add up work for whole state, but also save work for this sub-scope scope in w_d_map work += s_work w_d_map[get_uuid(node, state)] = (s_work, s_depth) - elif node == scope_exit: - # don't do anything for exit nodes, everthing handled already in the corresponding entry node. - pass elif isinstance(node, nd.Tasklet): # add up work for whole state, but also save work for this node in w_d_map t_work, t_depth = analyze_tasklet(node, state) @@ -598,7 +596,7 @@ def scope_work_depth( # Hence, we don't need to add anyting. pass lib_node_work = sp.Symbol(f'{node.name}_work', positive=True) - lib_node_depth = sp.sympify(-1) # not analyzed + lib_node_depth = sp.sympify(-1) if analyze_tasklet != get_tasklet_work: # we are analyzing depth try: @@ -852,7 +850,7 @@ def main() -> None: elif args.analyze == 'work': print("Work:\t", result_whole_sdfg) elif args.analyze == 'avgPar': - print("Average Parallelism:\t", result_whole_sdfg) + print("Average Parallelism:\t", sp.N(result_whole_sdfg)) print(80 * '-') diff --git a/tests/sdfg/operational_intensity_test.py b/tests/sdfg/operational_intensity_test.py index 0dc4f6c7be..fdc2c89a2d 100644 --- a/tests/sdfg/operational_intensity_test.py +++ b/tests/sdfg/operational_intensity_test.py @@ -1,17 +1,12 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. """ Contains test cases for the operational intensity analysis. """ import dace as dc -from dace.sdfg.work_depth_analysis.operational_intensity import analyze_sdfg_op_in -from dace.sdfg.work_depth_analysis.helpers import get_uuid import sympy as sp +import numpy as np +from dace.sdfg.work_depth_analysis.operational_intensity import analyze_sdfg_op_in +from dace.sdfg.work_depth_analysis.helpers import get_uuid -from dace.transformation.interstate import NestSDFG -from dace.transformation.dataflow import MapExpansion from math import isclose -from numpy import sum - -# TODO: maybe include tests for column major memory layout. AKA test that strides are taken into account correctly. -# TODO: add tests for library nodes N = dc.symbol('N') M = dc.symbol('M') @@ -20,13 +15,13 @@ TILE_SIZE = dc.symbol('TILE_SIZE') - @dc.program def single_map64(x: dc.float64[N], y: dc.float64[N], z: dc.float64[N]): z[:] = x + y # does N work, loads 3*N elements of 8 bytes # --> op_in should be N / 3*8*N = 1/24 (no reuse) assuming L divides N + @dc.program def single_map16(x: dc.float16[N], y: dc.float16[N], z: dc.float16[N]): z[:] = x + y @@ -42,15 +37,13 @@ def single_for_loop(x: dc.float64[N], y: dc.float64[N]): # --> 1/16 op in - - @dc.program def if_else(x: dc.int64[100], sum: dc.int64[1]): if x[10] > 50: - for i in range(100): + for i in range(100): sum += x[i] if x[0] > 3: - for i in range(100): + for i in range(100): sum += x[i] # no else --> simply analyze the ifs. if cache big enough, everything is reused @@ -61,7 +54,6 @@ def unaligned_for_loop(x: dc.float32[100], sum: dc.int64[1]): sum += x[i] - @dc.program def sequential_maps(x: dc.float64[N], y: dc.float64[N], z: dc.float64[N]): z[:] = x + y @@ -70,233 +62,100 @@ def sequential_maps(x: dc.float64[N], y: dc.float64[N], z: dc.float64[N]): # does N work, loads 3*N elements of 8 bytes # --> op_in should be N / 3*8*N = 1/24 (no reuse) assuming L divides N + @dc.program def nested_reuse(x: dc.float64[N], y: dc.float64[N], z: dc.float64[N], result: dc.float64[1]): # load x, y and z z[:] = x + y - result[0] = sum(z) + result[0] = np.sum(z) # tests whether the access to z from the nested SDFG correspond with the prior accesses # to z outside of the nested SDFG. + @dc.program -def mmm(x: dc.float64[N, N], y: dc.float64[N, N], z: dc.float64[N,N]): +def mmm(x: dc.float64[N, N], y: dc.float64[N, N], z: dc.float64[N, N]): for n, k, m in dc.map[0:N, 0:N, 0:N]: - z[n,k] += x[n,m] * y[m,k] + z[n, k] += x[n, m] * y[m, k] @dc.program -def tiled_mmm(x: dc.float64[N, N], y: dc.float64[N, N], z: dc.float64[N,N]): +def tiled_mmm(x: dc.float64[N, N], y: dc.float64[N, N], z: dc.float64[N, N]): for n_TILE, k_TILE, m_TILE in dc.map[0:N:TILE_SIZE, 0:N:TILE_SIZE, 0:N:TILE_SIZE]: - for n, k, m in dc.map[n_TILE:n_TILE+TILE_SIZE, k_TILE:k_TILE+TILE_SIZE, m_TILE:m_TILE+TILE_SIZE]: - z[n,k] += x[n,m] * y[m,k] + for n, k, m in dc.map[n_TILE:n_TILE + TILE_SIZE, k_TILE:k_TILE + TILE_SIZE, m_TILE:m_TILE + TILE_SIZE]: + z[n, k] += x[n, m] * y[m, k] + @dc.program -def tiled_mmm_32(x: dc.float32[N, N], y: dc.float32[N, N], z: dc.float32[N,N]): +def tiled_mmm_32(x: dc.float32[N, N], y: dc.float32[N, N], z: dc.float32[N, N]): for n_TILE, k_TILE, m_TILE in dc.map[0:N:TILE_SIZE, 0:N:TILE_SIZE, 0:N:TILE_SIZE]: - for n, k, m in dc.map[n_TILE:n_TILE+TILE_SIZE, k_TILE:k_TILE+TILE_SIZE, m_TILE:m_TILE+TILE_SIZE]: - z[n,k] += x[n,m] * y[m,k] - - -# @dc.program -# def if_else_sym(x: dc.int64[N], y: dc.int64[N], z: dc.int64[N], sum: dc.int64[1]): -# if x[10] > 50: -# z[:] = x + y # N work, 1 depth -# else: -# for i in range(K): # K work, K depth -# sum += x[i] - - -# @dc.program -# def nested_sdfg(x: dc.float64[N], y: dc.float64[N], z: dc.float64[N]): -# single_map64(x, y, z) -# single_for_loop(x, y) - - -# @dc.program -# def nested_maps(x: dc.float64[N, M], y: dc.float64[N, M], z: dc.float64[N, M]): -# z[:, :] = x + y - - -# @dc.program -# def nested_for_loops(x: dc.float64[N], y: dc.float64[K]): -# for i in range(N): -# for j in range(K): -# x[i] += y[j] - - -# @dc.program -# def nested_if_else(x: dc.int64[N], y: dc.int64[N], z: dc.int64[N], sum: dc.int64[1]): -# if x[10] > 50: -# if x[9] > 40: -# z[:] = x + y # N work, 1 depth -# z[:] += 2 * x # 2*N work, 2 depth --> total outer if: 3*N work, 3 depth -# else: -# if y[9] > 30: -# for i in range(K): -# sum += x[i] # K work, K depth -# else: -# for j in range(M): -# sum += x[j] # M work, M depth -# z[:] = x + y # N work, depth 1 --> total inner else: M+N work, M+1 depth -# # --> total outer else: Max(K, M+N) work, Max(K, M+1) depth -# # --> total over both branches: Max(K, M+N, 3*N) work, Max(K, M+1, 3) depth - - -# @dc.program -# def max_of_positive_symbol(x: dc.float64[N]): -# if x[0] > 0: -# for i in range(2 * N): # work 2*N^2, depth 2*N -# x += 1 -# else: -# for j in range(3 * N): # work 3*N^2, depth 3*N -# x += 1 -# # total is work 3*N^2, depth 3*N without any max - + for n, k, m in dc.map[n_TILE:n_TILE + TILE_SIZE, k_TILE:k_TILE + TILE_SIZE, m_TILE:m_TILE + TILE_SIZE]: + z[n, k] += x[n, m] * y[m, k] -# @dc.program -# def multiple_array_sizes(x: dc.int64[N], y: dc.int64[N], z: dc.int64[N], x2: dc.int64[M], y2: dc.int64[M], -# z2: dc.int64[M], x3: dc.int64[K], y3: dc.int64[K], z3: dc.int64[K]): -# if x[0] > 0: -# z[:] = 2 * x + y # work 2*N, depth 2 -# elif x[1] > 0: -# z2[:] = 2 * x2 + y2 # work 2*M + 3, depth 5 -# z2[0] += 3 + z[1] + z[2] -# elif x[2] > 0: -# z3[:] = 2 * x3 + y3 # work 2*K, depth 2 -# elif x[3] > 0: -# z[:] = 3 * x + y + 1 # work 3*N, depth 3 -# # --> work= Max(3*N, 2*M, 2*K) and depth = 5 - -# @dc.program -# def unbounded_while_do(x: dc.float64[N]): -# while x[0] < 100: -# x += 1 - - -# @dc.program -# def unbounded_do_while(x: dc.float64[N]): -# while True: -# x += 1 -# if x[0] >= 100: -# break - - -# @dc.program -# def unbounded_nonnegify(x: dc.float64[N]): -# while x[0] < 100: -# if x[1] < 42: -# x += 3 * x -# else: -# x += x - - -# @dc.program -# def continue_for_loop(x: dc.float64[N]): -# for i in range(N): -# if x[i] > 100: -# continue -# x += 1 - - -# @dc.program -# def break_for_loop(x: dc.float64[N]): -# for i in range(N): -# if x[i] > 100: -# break -# x += 1 - - -# @dc.program -# def break_while_loop(x: dc.float64[N]): -# while x[0] > 10: -# if x[1] > 100: -# break -# x += 1 - - -# @dc.program -# def sequntial_ifs(x: dc.float64[N + 1], y: dc.float64[M + 1]): # --> cannot assume N, M to be positive -# if x[0] > 5: -# x[:] += 1 # N+1 work, 1 depth -# else: -# for i in range(M): # M work, M depth -# y[i + 1] += y[i] -# if M > N: -# y[:N + 1] += x[:] # N+1 work, 1 depth -# else: -# x[:M + 1] += y[:] # M+1 work, 1 depth -# # --> Work: Max(N+1, M) + Max(N+1, M+1) -# # Depth: Max(1, M) + 1 +@dc.program +def reduction_library_node(x: dc.float64[N]): + return np.sum(x) #(sdfg, c, l, assumptions, expected_result) tests_cases = [ - (single_map64, 64*64, 64, {'N' : 512}, 1/24), - (single_map16, 64*64, 64, {'N' : 512}, 1/6), + (single_map64, 64 * 64, 64, { + 'N': 512 + }, 1 / 24), + (single_map16, 64 * 64, 64, { + 'N': 512 + }, 1 / 6), # now num_elements_on_single_cache_line does not divie N anymore # -->513 work, 520 elements loaded --> 513 / (520*8*3) - (single_map64, 64*64, 64, {'N' : 513}, 513 / (3*8*520)), - - - - # # this one fails, but the issue is more broad than the op_in analysis --> skip for now - # (single_for_loop, 64, 64, {'N': 1024}, 1/16) - # # this one fails, but the issue is more broad than the op_in analysis --> skip for now - # (if_else, 1000, 800, {}, 200 / 1600), - # # this one fails, but the issue is more broad than the op_in analysis --> skip for now - # (unaligned_for_loop, -1, -1, {}, -1) - - - (sequential_maps, 1024, 3*8, {'N' : 29}, 87 / (90*8)), + (single_map64, 64 * 64, 64, { + 'N': 513 + }, 513 / (3 * 8 * 520)), + (sequential_maps, 1024, 3 * 8, { + 'N': 29 + }, 87 / (90 * 8)), # smaller cache --> only two arrays fit --> x loaded twice now - (sequential_maps, 6, 3*8, {'N' : 7}, 21 / (13*3*8)), - - - (nested_reuse, 1024, 64, {'N' : 1024}, 2048 / (3*1024*8 + 128)), - (mmm, 20, 16, {'N': 24}, (2*24**3) / ((36*24**2 + 24*12) * 16)), - (tiled_mmm, 20, 16, {'N': 24, 'TILE_SIZE' : 4}, (2*24**3) / (16*24*6**3)), - (tiled_mmm_32, 10, 16, {'N': 24, 'TILE_SIZE' : 4}, (2*24**3) / (16*12*6**3)), - - - # (nested_sdfg, (2 * N, N + 1)), - # (nested_maps, (M * N, 1)), - # (nested_for_loops, (K * N, K * N)), - # (nested_if_else, (sp.Max(K, 3 * N, M + N), sp.Max(3, K, M + 1))), - # (multiple_array_sizes, (sp.Max(2 * K, 3 * N, 2 * M + 3), 5)), - # (sequntial_ifs, (sp.Max(N + 1, M) + sp.Max(N + 1, M + 1), sp.Max(1, M) + 1)) + (sequential_maps, 6, 3 * 8, { + 'N': 7 + }, 21 / (13 * 3 * 8)), + (nested_reuse, 1024, 64, { + 'N': 1024 + }, 2048 / (3 * 1024 * 8 + 128)), + (mmm, 20, 16, { + 'N': 24 + }, (2 * 24**3) / ((36 * 24**2 + 24 * 12) * 16)), + (tiled_mmm, 20, 16, { + 'N': 24, + 'TILE_SIZE': 4 + }, (2 * 24**3) / (16 * 24 * 6**3)), + (tiled_mmm_32, 10, 16, { + 'N': 24, + 'TILE_SIZE': 4 + }, (2 * 24**3) / (16 * 12 * 6**3)), + (reduction_library_node, 1024, 64, { + 'N': 128 + }, 128.0 / (dc.symbol('Reduce_misses') * 64.0 + 64.0)), ] -# tests_cases = [ -# (nested_reuse, 1024, 64, {'N' : 1024}, 2048 / (3*1024*8 + 128)) -# ] - def test_operational_intensity(): - errors = 0 for test, c, l, assumptions, correct in tests_cases: op_in_map = {} sdfg = test.to_sdfg() - sdfg.expand_library_nodes() - if test.name == 'mmm': - sdfg.save('mmm.sdfg') - if 'nested_sdfg' in test.name: - sdfg.apply_transformations(NestSDFG) - if 'nested_maps' in test.name: - sdfg.apply_transformations(MapExpansion) - analyze_sdfg_op_in(sdfg, op_in_map, c, l, assumptions) - res = float(op_in_map[get_uuid(sdfg)]) - # substitue each symbol without assumptions. - # We do this since sp.Symbol('N') == Sp.Symbol('N', positive=True) --> False. - # check result - # assert correct == res - if not isclose(correct, res): - print(sdfg.name) - print(c, l, assumptions, correct, res) - print('ERROR DETECTED') - errors += 1 + if test.name == 'nested_reuse': + sdfg.expand_library_nodes() + analyze_sdfg_op_in(sdfg, op_in_map, c * l, l, assumptions) + res = (op_in_map[get_uuid(sdfg)]) + if test.name == 'reduction_library_node': + # substitue each symbol without assumptions. + # We do this since sp.Symbol('N') == Sp.Symbol('N', positive=True) --> False. + reps = {s: sp.Symbol(s.name) for s in res.free_symbols} + res = res.subs(reps) + reps = {s: sp.Symbol(s.name) for s in sp.sympify(correct).free_symbols} + correct = sp.sympify(correct).subs(reps) + assert correct == res + else: + assert isclose(correct, res) - print(f'Encountered {errors} failing tests out of {len(tests_cases)} tests') if __name__ == '__main__': test_operational_intensity() diff --git a/tests/sdfg/work_depth_tests.py b/tests/sdfg/work_depth_tests.py index 05375007df..9f79359927 100644 --- a/tests/sdfg/work_depth_tests.py +++ b/tests/sdfg/work_depth_tests.py @@ -1,19 +1,17 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. """ Contains test cases for the work depth analysis. """ import dace as dc -from dace.sdfg.work_depth_analysis.work_depth import analyze_sdfg, get_tasklet_work_depth, parse_assumptions +from dace.sdfg.work_depth_analysis.work_depth import analyze_sdfg, get_tasklet_work_depth, get_tasklet_avg_par, parse_assumptions from dace.sdfg.work_depth_analysis.helpers import get_uuid from dace.sdfg.work_depth_analysis.assumptions import ContradictingAssumptions import sympy as sp +import numpy as np from dace.transformation.interstate import NestSDFG from dace.transformation.dataflow import MapExpansion from pytest import raises -# TODO: add tests for library nodes (e.g. reduce, matMul) -# TODO: add tests for average parallelism - N = dc.symbol('N') M = dc.symbol('M') K = dc.symbol('K') @@ -172,6 +170,26 @@ def sequntial_ifs(x: dc.float64[N + 1], y: dc.float64[M + 1]): # --> cannot ass # Depth: Max(1, M) + 1 +@dc.program +def reduction_library_node(x: dc.float64[456]): + return np.sum(x) + + +@dc.program +def reduction_library_node_symbolic(x: dc.float64[N]): + return np.sum(x) + + +@dc.program +def gemm_library_node(x: dc.float64[456, 200], y: dc.float64[200, 111], z: dc.float64[456, 111]): + z[:] = x @ y + + +@dc.program +def gemm_library_node_symbolic(x: dc.float64[M, K], y: dc.float64[K, N], z: dc.float64[M, N]): + z[:] = x @ y + + #(sdfg, (expected_work, expected_depth)) tests_cases = [ (single_map, (N, 1)), @@ -191,7 +209,11 @@ def sequntial_ifs(x: dc.float64[N + 1], y: dc.float64[M + 1]): # --> cannot ass (continue_for_loop, (sp.Symbol('num_execs_0_6') * N, sp.Symbol('num_execs_0_6'))), (break_for_loop, (N**2, N)), (break_while_loop, (sp.Symbol('num_execs_0_5') * N, sp.Symbol('num_execs_0_5'))), - (sequntial_ifs, (sp.Max(N + 1, M) + sp.Max(N + 1, M + 1), sp.Max(1, M) + 1)) + (sequntial_ifs, (sp.Max(N + 1, M) + sp.Max(N + 1, M + 1), sp.Max(1, M) + 1)), + (reduction_library_node, (456, sp.log(456))), + (reduction_library_node_symbolic, (N, sp.log(N))), + (gemm_library_node, (2 * 456 * 200 * 111, sp.log(200))), + (gemm_library_node_symbolic, (2 * M * K * N, sp.log(K))) ] @@ -218,6 +240,36 @@ def test_work_depth(): assert correct == res +#(sdfg, expected_avg_par) +tests_cases_avg_par = [(single_map, N), (single_for_loop, 1), (if_else, 1), (nested_sdfg, 2 * N / (N + 1)), + (nested_maps, N * M), (nested_for_loops, 1), + (max_of_positive_symbol, N), (unbounded_while_do, N), (unbounded_do_while, N), + (unbounded_nonnegify, N), (continue_for_loop, N), (break_for_loop, N), (break_while_loop, N), + (reduction_library_node, 456 / sp.log(456)), (reduction_library_node_symbolic, N / sp.log(N)), + (gemm_library_node, 2 * 456 * 200 * 111 / sp.log(200)), + (gemm_library_node_symbolic, 2 * M * K * N / sp.log(K))] + + +def test_avg_par(): + for test, correct in tests_cases_avg_par: + w_d_map = {} + sdfg = test.to_sdfg() + if 'nested_sdfg' in test.name: + sdfg.apply_transformations(NestSDFG) + if 'nested_maps' in test.name: + sdfg.apply_transformations(MapExpansion) + analyze_sdfg(sdfg, w_d_map, get_tasklet_avg_par, [], False) + res = w_d_map[get_uuid(sdfg)][0] / w_d_map[get_uuid(sdfg)][1] + # substitue each symbol without assumptions. + # We do this since sp.Symbol('N') == Sp.Symbol('N', positive=True) --> False. + reps = {s: sp.Symbol(s.name) for s in res.free_symbols} + res = res.subs(reps) + reps = {s: sp.Symbol(s.name) for s in sp.sympify(correct).free_symbols} + correct = sp.sympify(correct).subs(reps) + # check result + assert correct == res + + x, y, z, a = sp.symbols('x y z a') # (expr, assumptions, result) @@ -259,4 +311,5 @@ def test_assumption_system(): if __name__ == '__main__': test_work_depth() + test_avg_par() test_assumption_system()