Skip to content

Commit

Permalink
Fix circular import: utils, visitors, vectorizer
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Jul 6, 2016
1 parent 6f34c69 commit 4d90f8a
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions coffee/visitors/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from coffee.visitor import Visitor
from coffee.base import Sum, Sub, Prod, Div, ArrayInit, SparseArrayInit
from coffee.utils import ItSpace, flatten
import coffee.utils


__all__ = ["ReplaceSymbols", "CheckUniqueness", "Uniquify", "Evaluate",
Expand Down Expand Up @@ -119,9 +119,11 @@ def __init__(self, decls, track_zeros):
Prod: np.multiply,
Div: np.divide
}
from coffee.vectorizer import vect_roundup, vect_rounddown
self.up = vect_roundup
self.down = vect_rounddown

import coffee.vectorizer
self.up = coffee.vectorizer.vect_roundup
self.down = coffee.vectorizer.vect_rounddown
self.make_itspace = coffee.utils.ItSpace
super(Evaluate, self).__init__()

def visit_object(self, o, *args, **kwargs):
Expand Down Expand Up @@ -197,7 +199,7 @@ def visit_Writer(self, o, *args, **kwargs):
for k, g in itertools.groupby(sorted(mapper[-1].keys()), grouper):
group = list(g)
ranges.append((group[-1]-group[0]+1, group[0]))
nonzero[-1] = ItSpace(mode=1).merge(ranges or nonzero[-1], within=-1)
nonzero[-1] = self.make_itspace(mode=1).merge(ranges or nonzero[-1], within=-1)

return {lvalue: SparseArrayInit(values, precision, tuple(nonzero))}

Expand Down Expand Up @@ -301,13 +303,13 @@ def visit_Prod(self, o, parent=None, *args, **kwargs):
projection = self.default_retval()
for n in o.children:
projection.extend(self.visit(n, parent=o, *args, **kwargs))
return [list(flatten(projection))]
return [list(coffee.utils.flatten(projection))]
else:
# Only the top level Prod, in a chain of Prods, should do the
# tensor product
projection = [self.visit(n, parent=o, *args, **kwargs) for n in o.children]
product = itertools.product(*projection)
ret = [list(flatten(i)) for i in product] or projection
ret = [list(coffee.utils.flatten(i)) for i in product] or projection
return ret

def visit_Symbol(self, o, *args, **kwargs):
Expand Down Expand Up @@ -434,7 +436,7 @@ def visit_Prod(self, o, ret=None, syms=None, parent=None, *args, **kwargs):
if all(i for i in loc_syms):
self._update_mapper(mapper, loc_syms)
loc_syms = itertools.product(*loc_syms)
loc_syms = [tuple(flatten(i)) for i in loc_syms]
loc_syms = [tuple(coffee.utils.flatten(i)) for i in loc_syms]
syms |= set(loc_syms)
G.add_edges_from(loc_syms)
else:
Expand Down

0 comments on commit 4d90f8a

Please sign in to comment.