Skip to content

Commit

Permalink
Merge pull request #84 from coneoproject/minor-fixes
Browse files Browse the repository at this point in the history
* minor-fixes:
  Fix circular import: utils, visitors, vectorizer
  Remove now unused ast_remove function
  Enforce Block as FunDecl direct child
  Simplify make alias utils function
  Do not print pragmas on Decls
  • Loading branch information
miklos1 committed Jul 7, 2016
2 parents 98377cf + 4d90f8a commit 8d53e61
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 90 deletions.
21 changes: 11 additions & 10 deletions coffee/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,8 +788,6 @@ def nonzero(self):
def gencode(self, not_scope=False):
pointers = " " + " ".join(['*' + ' '.join(i) for i in self.pointers])
prefix = ""
if self.pragma:
prefix = "\n".join(p for p in self.pragma) + "\n"
if isinstance(self.init, EmptyStatement):
return prefix + decl(spacer(self.qual), self.typ + pointers, self.sym.gencode(),
spacer(self.attr)) + semicolon(not_scope)
Expand Down Expand Up @@ -833,13 +831,7 @@ class For(Statement):
for (int i = 0, j = 0; ...)"""

def __init__(self, init, cond, incr, body, pragma=None):
# If the body is a plain list, cast it to a Block.
if not isinstance(body, Block):
if not isinstance(body, list):
body = [body]
body = Block(body, open_scope=True)

super(For, self).__init__([body], pragma)
super(For, self).__init__([enforce_block(body)], pragma)
self.init = init
self.cond = cond
self.incr = incr
Expand Down Expand Up @@ -968,7 +960,7 @@ class FunDecl(Statement):
static inline void foo(int a, int b) {return;};"""

def __init__(self, ret, name, args, body, pred=[], headers=None):
super(FunDecl, self).__init__([body])
super(FunDecl, self).__init__([enforce_block(body)])
self.pred = pred
self.ret = ret
self.name = name
Expand Down Expand Up @@ -1244,6 +1236,15 @@ def c_flat_for(code, parent):
return new_block


def enforce_block(body, open_scope=True):
"""Wrap ``body`` in a Block if not already a Block."""
if not isinstance(body, Block):
if not isinstance(body, list):
body = [body]
body = Block(body, open_scope=open_scope)
return body


# Access modes for a symbol ##


Expand Down
81 changes: 9 additions & 72 deletions coffee/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,50 +83,6 @@ def _ast_replace(node, to_replace, n_replaced):
return n_replaced


def ast_remove(node, to_remove, mode='all'):
"""Remove the AST node ``to_remove`` from the tree rooted in ``node``.
:param mode: either ``all``, in which case ``to_remove`` is turned into a
string (if not a string already) and all of its occurrences are removed
from the AST; or ``symbol``, in which case only (all of) the references
to the provided ``to_remove`` node are cut away.
"""

def _is_removable(n, tr):
n, tr = (str(n), str(tr)) if mode == 'all' else (n, tr)
return True if n == tr else False

def _ast_remove(node, parent, index, tr):
if _is_removable(node, tr):
return -1
if not node.children:
return index
_may_remove = [_ast_remove(n, node, i, tr) for i, n in enumerate(node.children)]
if all([i > -1 for i in _may_remove]):
# No removals occurred, so just return
return index
if all([i == -1 for i in _may_remove]):
# Removed all of the children, so I'm also going to remove myself
return -1
alive = [i for i in _may_remove if i > -1]
if len(alive) > 1:
# Some children were removed, but not all of them, so no surgery needed
return index
# One child left, need to reattach it as child of my parent
alive = alive[0]
parent.children[index] = node.children[alive]
return index

if mode not in ['all', 'symbol']:
raise ValueError

try:
if all(_ast_remove(node, None, None, tr) == -1 for tr in to_remove):
return -1
except TypeError:
return _ast_remove(node, None, None, to_remove)


def ast_update_ofs(node, ofs, **kwargs):
"""Change the offsets of the iteration space variables of the symbols rooted
in ``node``.
Expand Down Expand Up @@ -227,35 +183,16 @@ def _ast_make_bal_expr(nodes):
return _ast_make_expr(nodes)


def ast_make_alias(node1, node2):
"""Return an object in which the LHS is represented by ``node1`` and the RHS
by ``node2``, and ``node1`` is an alias for ``node2``; that is, ``node1``
will point to the same memory region of ``node2``.
:type node1: either a ``Decl`` or a ``Symbol``. If a ``Decl`` is provided,
the init field of the ``Decl`` is used to assign the alias.
:type node2: either a ``Decl`` or a ``Symbol``. If a ``Decl`` is provided,
the symbol is extracted and used for the assignment.
def ast_make_alias(node, alias_name):
"""
if not isinstance(node1, (Decl, Symbol)):
raise RuntimeError("Cannot assign a pointer to %s type" % type(node1))
if not isinstance(node2, (Decl, Symbol)):
raise RuntimeError("Cannot assign a pointer to %s type" % type(node1))

# Handle node2
if isinstance(node2, Decl):
node2 = node2.sym
node2.symbol = node2.symbol.strip('*')
node2.rank, node2.offset, node2.loop_dep = (), (), ()

# Handle node1
if isinstance(node1, Symbol):
node1.symbol = node1.symbol.strip('*')
node1.rank, node1.offset, node1.loop_dep = (), (), ()
return Assign(node1, node2)
else:
node1.init = node2
return node1
Create an alias of ``node`` (must be of type Decl). The alias symbol is
given the name ``alias_name``.
"""
assert isinstance(node, Decl)

pointers = node.pointers + ['' for i in node.size]
return Decl(node.typ, alias_name, node.lvalue.symbol, qualifiers=node.qual,
scope=node.scope, pointers=pointers)


###########################################################
Expand Down
18 changes: 10 additions & 8 deletions coffee/visitors/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from coffee.visitor import Visitor
from coffee.base import Sum, Sub, Prod, Div, ArrayInit, SparseArrayInit
from coffee.utils import ItSpace, flatten
import coffee.utils


__all__ = ["ReplaceSymbols", "CheckUniqueness", "Uniquify", "Evaluate",
Expand Down Expand Up @@ -119,9 +119,11 @@ def __init__(self, decls, track_zeros):
Prod: np.multiply,
Div: np.divide
}
from coffee.vectorizer import vect_roundup, vect_rounddown
self.up = vect_roundup
self.down = vect_rounddown

import coffee.vectorizer
self.up = coffee.vectorizer.vect_roundup
self.down = coffee.vectorizer.vect_rounddown
self.make_itspace = coffee.utils.ItSpace
super(Evaluate, self).__init__()

def visit_object(self, o, *args, **kwargs):
Expand Down Expand Up @@ -197,7 +199,7 @@ def visit_Writer(self, o, *args, **kwargs):
for k, g in itertools.groupby(sorted(mapper[-1].keys()), grouper):
group = list(g)
ranges.append((group[-1]-group[0]+1, group[0]))
nonzero[-1] = ItSpace(mode=1).merge(ranges or nonzero[-1], within=-1)
nonzero[-1] = self.make_itspace(mode=1).merge(ranges or nonzero[-1], within=-1)

return {lvalue: SparseArrayInit(values, precision, tuple(nonzero))}

Expand Down Expand Up @@ -301,13 +303,13 @@ def visit_Prod(self, o, parent=None, *args, **kwargs):
projection = self.default_retval()
for n in o.children:
projection.extend(self.visit(n, parent=o, *args, **kwargs))
return [list(flatten(projection))]
return [list(coffee.utils.flatten(projection))]
else:
# Only the top level Prod, in a chain of Prods, should do the
# tensor product
projection = [self.visit(n, parent=o, *args, **kwargs) for n in o.children]
product = itertools.product(*projection)
ret = [list(flatten(i)) for i in product] or projection
ret = [list(coffee.utils.flatten(i)) for i in product] or projection
return ret

def visit_Symbol(self, o, *args, **kwargs):
Expand Down Expand Up @@ -434,7 +436,7 @@ def visit_Prod(self, o, ret=None, syms=None, parent=None, *args, **kwargs):
if all(i for i in loc_syms):
self._update_mapper(mapper, loc_syms)
loc_syms = itertools.product(*loc_syms)
loc_syms = [tuple(flatten(i)) for i in loc_syms]
loc_syms = [tuple(coffee.utils.flatten(i)) for i in loc_syms]
syms |= set(loc_syms)
G.add_edges_from(loc_syms)
else:
Expand Down

0 comments on commit 8d53e61

Please sign in to comment.