diff --git a/coffee/base.py b/coffee/base.py index d849097c..e10b14a7 100644 --- a/coffee/base.py +++ b/coffee/base.py @@ -788,8 +788,6 @@ def nonzero(self): def gencode(self, not_scope=False): pointers = " " + " ".join(['*' + ' '.join(i) for i in self.pointers]) prefix = "" - if self.pragma: - prefix = "\n".join(p for p in self.pragma) + "\n" if isinstance(self.init, EmptyStatement): return prefix + decl(spacer(self.qual), self.typ + pointers, self.sym.gencode(), spacer(self.attr)) + semicolon(not_scope) @@ -833,13 +831,7 @@ class For(Statement): for (int i = 0, j = 0; ...)""" def __init__(self, init, cond, incr, body, pragma=None): - # If the body is a plain list, cast it to a Block. - if not isinstance(body, Block): - if not isinstance(body, list): - body = [body] - body = Block(body, open_scope=True) - - super(For, self).__init__([body], pragma) + super(For, self).__init__([enforce_block(body)], pragma) self.init = init self.cond = cond self.incr = incr @@ -968,7 +960,7 @@ class FunDecl(Statement): static inline void foo(int a, int b) {return;};""" def __init__(self, ret, name, args, body, pred=[], headers=None): - super(FunDecl, self).__init__([body]) + super(FunDecl, self).__init__([enforce_block(body)]) self.pred = pred self.ret = ret self.name = name @@ -1244,6 +1236,15 @@ def c_flat_for(code, parent): return new_block +def enforce_block(body, open_scope=True): + """Wrap ``body`` in a Block if not already a Block.""" + if not isinstance(body, Block): + if not isinstance(body, list): + body = [body] + body = Block(body, open_scope=open_scope) + return body + + # Access modes for a symbol ## diff --git a/coffee/utils.py b/coffee/utils.py index b8994b9f..e0885b21 100644 --- a/coffee/utils.py +++ b/coffee/utils.py @@ -83,50 +83,6 @@ def _ast_replace(node, to_replace, n_replaced): return n_replaced -def ast_remove(node, to_remove, mode='all'): - """Remove the AST node ``to_remove`` from the tree rooted in ``node``. - - :param mode: either ``all``, in which case ``to_remove`` is turned into a - string (if not a string already) and all of its occurrences are removed - from the AST; or ``symbol``, in which case only (all of) the references - to the provided ``to_remove`` node are cut away. - """ - - def _is_removable(n, tr): - n, tr = (str(n), str(tr)) if mode == 'all' else (n, tr) - return True if n == tr else False - - def _ast_remove(node, parent, index, tr): - if _is_removable(node, tr): - return -1 - if not node.children: - return index - _may_remove = [_ast_remove(n, node, i, tr) for i, n in enumerate(node.children)] - if all([i > -1 for i in _may_remove]): - # No removals occurred, so just return - return index - if all([i == -1 for i in _may_remove]): - # Removed all of the children, so I'm also going to remove myself - return -1 - alive = [i for i in _may_remove if i > -1] - if len(alive) > 1: - # Some children were removed, but not all of them, so no surgery needed - return index - # One child left, need to reattach it as child of my parent - alive = alive[0] - parent.children[index] = node.children[alive] - return index - - if mode not in ['all', 'symbol']: - raise ValueError - - try: - if all(_ast_remove(node, None, None, tr) == -1 for tr in to_remove): - return -1 - except TypeError: - return _ast_remove(node, None, None, to_remove) - - def ast_update_ofs(node, ofs, **kwargs): """Change the offsets of the iteration space variables of the symbols rooted in ``node``. @@ -227,35 +183,16 @@ def _ast_make_bal_expr(nodes): return _ast_make_expr(nodes) -def ast_make_alias(node1, node2): - """Return an object in which the LHS is represented by ``node1`` and the RHS - by ``node2``, and ``node1`` is an alias for ``node2``; that is, ``node1`` - will point to the same memory region of ``node2``. - - :type node1: either a ``Decl`` or a ``Symbol``. If a ``Decl`` is provided, - the init field of the ``Decl`` is used to assign the alias. - :type node2: either a ``Decl`` or a ``Symbol``. If a ``Decl`` is provided, - the symbol is extracted and used for the assignment. +def ast_make_alias(node, alias_name): """ - if not isinstance(node1, (Decl, Symbol)): - raise RuntimeError("Cannot assign a pointer to %s type" % type(node1)) - if not isinstance(node2, (Decl, Symbol)): - raise RuntimeError("Cannot assign a pointer to %s type" % type(node1)) - - # Handle node2 - if isinstance(node2, Decl): - node2 = node2.sym - node2.symbol = node2.symbol.strip('*') - node2.rank, node2.offset, node2.loop_dep = (), (), () - - # Handle node1 - if isinstance(node1, Symbol): - node1.symbol = node1.symbol.strip('*') - node1.rank, node1.offset, node1.loop_dep = (), (), () - return Assign(node1, node2) - else: - node1.init = node2 - return node1 + Create an alias of ``node`` (must be of type Decl). The alias symbol is + given the name ``alias_name``. + """ + assert isinstance(node, Decl) + + pointers = node.pointers + ['' for i in node.size] + return Decl(node.typ, alias_name, node.lvalue.symbol, qualifiers=node.qual, + scope=node.scope, pointers=pointers) ########################################################### diff --git a/coffee/visitors/utilities.py b/coffee/visitors/utilities.py index eb75cbe9..13786438 100644 --- a/coffee/visitors/utilities.py +++ b/coffee/visitors/utilities.py @@ -9,7 +9,7 @@ from coffee.visitor import Visitor from coffee.base import Sum, Sub, Prod, Div, ArrayInit, SparseArrayInit -from coffee.utils import ItSpace, flatten +import coffee.utils __all__ = ["ReplaceSymbols", "CheckUniqueness", "Uniquify", "Evaluate", @@ -119,9 +119,11 @@ def __init__(self, decls, track_zeros): Prod: np.multiply, Div: np.divide } - from coffee.vectorizer import vect_roundup, vect_rounddown - self.up = vect_roundup - self.down = vect_rounddown + + import coffee.vectorizer + self.up = coffee.vectorizer.vect_roundup + self.down = coffee.vectorizer.vect_rounddown + self.make_itspace = coffee.utils.ItSpace super(Evaluate, self).__init__() def visit_object(self, o, *args, **kwargs): @@ -197,7 +199,7 @@ def visit_Writer(self, o, *args, **kwargs): for k, g in itertools.groupby(sorted(mapper[-1].keys()), grouper): group = list(g) ranges.append((group[-1]-group[0]+1, group[0])) - nonzero[-1] = ItSpace(mode=1).merge(ranges or nonzero[-1], within=-1) + nonzero[-1] = self.make_itspace(mode=1).merge(ranges or nonzero[-1], within=-1) return {lvalue: SparseArrayInit(values, precision, tuple(nonzero))} @@ -301,13 +303,13 @@ def visit_Prod(self, o, parent=None, *args, **kwargs): projection = self.default_retval() for n in o.children: projection.extend(self.visit(n, parent=o, *args, **kwargs)) - return [list(flatten(projection))] + return [list(coffee.utils.flatten(projection))] else: # Only the top level Prod, in a chain of Prods, should do the # tensor product projection = [self.visit(n, parent=o, *args, **kwargs) for n in o.children] product = itertools.product(*projection) - ret = [list(flatten(i)) for i in product] or projection + ret = [list(coffee.utils.flatten(i)) for i in product] or projection return ret def visit_Symbol(self, o, *args, **kwargs): @@ -434,7 +436,7 @@ def visit_Prod(self, o, ret=None, syms=None, parent=None, *args, **kwargs): if all(i for i in loc_syms): self._update_mapper(mapper, loc_syms) loc_syms = itertools.product(*loc_syms) - loc_syms = [tuple(flatten(i)) for i in loc_syms] + loc_syms = [tuple(coffee.utils.flatten(i)) for i in loc_syms] syms |= set(loc_syms) G.add_edges_from(loc_syms) else: