Skip to content
This repository has been archived by the owner on Dec 6, 2024. It is now read-only.

wence/better type inference #215

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions gem/impero_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from functools import singledispatch
from itertools import chain, groupby

from numpy import find_common_type

from gem.node import traversal, collect_refcount
from gem import gem, impero as imp, optimise, scheduling

Expand All @@ -21,7 +23,9 @@
# temporaries - List of GEM expressions which have assigned temporaries
# declare - Where to declare temporaries to get correct C code
# indices - Indices for declarations and referencing values
ImperoC = collections.namedtuple('ImperoC', ['tree', 'temporaries', 'declare', 'indices'])
# return_variable - 2-tuple of gem return variable and inferred numpy dtype
ImperoC = collections.namedtuple('ImperoC', ['tree', 'temporaries', 'declare', 'indices',
'return_variable'])


class NoopError(Exception):
Expand All @@ -38,11 +42,12 @@ def preprocess_gem(expressions, replace_delta=True, remove_componenttensors=True
return expressions


def compile_gem(assignments, prefix_ordering, remove_zeros=False):
def compile_gem(assignments, prefix_ordering, scalar_type, remove_zeros=False):
"""Compiles GEM to Impero.

:arg assignments: list of (return variable, expression DAG root) pairs
:arg prefix_ordering: outermost loop indices
:arg scalar_type: default scalar type
:arg remove_zeros: remove zero assignment to return variables
"""
# Remove zeros
Expand All @@ -52,6 +57,9 @@ def nonzero(assignment):
return not isinstance(expression, gem.Zero)
assignments = list(filter(nonzero, assignments))

# Type inference for return value
return_variable = infer_dtype(assignments, scalar_type)

# Just the expressions
expressions = [expression for variable, expression in assignments]

Expand Down Expand Up @@ -88,7 +96,26 @@ def nonzero(assignment):
declare, indices = place_declarations(tree, temporaries, get_indices)

# Prepare ImperoC (Impero AST + other data for code generation)
return ImperoC(tree, temporaries, declare, indices)
return ImperoC(tree, temporaries, declare, indices, return_variable)


def infer_dtype(assignments, scalar_type):
from tsfc.loopy import assign_dtypes
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this is just a draft pull request, but this line would create a circularity in package dependency, by making gem depend on tsfc. This in undesirable: gem should not depend on anything other than the Python standard library and quasi-standard packages.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that's easy to fix by doing the more general gem typing approach.

from gem.node import traversal

def extract_variable(expr):
x, = set(v for v in traversal([expr]) if isinstance(v, gem.Variable))
return x

vars = set()
dtypes = set()
for var, expression in assignments:
var = extract_variable(var)
((_, dtype), ) = assign_dtypes([expression], scalar_type)
vars.add(var)
dtypes.add(dtype)
var, = vars
return var, find_common_type([], dtypes)


def make_prefix_ordering(indices, prefix_ordering):
Expand Down
3 changes: 2 additions & 1 deletion tests/test_codegen.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
import numpy

from gem import impero_utils
from gem.gem import Index, Indexed, IndexSum, Product, Variable
Expand All @@ -18,7 +19,7 @@ def make_expression(i, j):
e2 = make_expression(i, i)

def gencode(expr):
impero_c = impero_utils.compile_gem([(Ri, expr)], (i, j))
impero_c = impero_utils.compile_gem([(Ri, expr)], (i, j), numpy.dtype(numpy.float64))
return impero_c.tree

assert len(gencode(e1).children) == len(gencode(e2).children)
Expand Down
6 changes: 4 additions & 2 deletions tsfc/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,8 @@ def compile_integral(integral_data, form_data, prefix, parameters, interface, co
for var in return_variables]))
index_ordering = tuple(quadrature_indices) + split_argument_indices
try:
impero_c = impero_utils.compile_gem(assignments, index_ordering, remove_zeros=True)
impero_c = impero_utils.compile_gem(assignments, index_ordering,
parameters["scalar_type"], remove_zeros=True)
except impero_utils.NoopError:
# No operations, construct empty kernel
return builder.construct_empty_kernel(kernel_name)
Expand Down Expand Up @@ -421,7 +422,8 @@ def compile_expression_dual_evaluation(expression, to_element, coordinates, inte
# TODO: one should apply some GEM optimisations as in assembly,
# but we don't for now.
ir, = impero_utils.preprocess_gem([ir])
impero_c = impero_utils.compile_gem([(return_expr, ir)], return_indices)
impero_c = impero_utils.compile_gem([(return_expr, ir)], return_indices,
parameters["scalar_type"])
index_names = dict((idx, "p%d" % i) for (i, idx) in enumerate(basis_indices))
# Handle kernel interface requirements
builder.register_requirements([ir])
Expand Down
12 changes: 10 additions & 2 deletions tsfc/fem.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,14 @@ def __init__(self, interface, **kwargs):
raise ValueError("unexpected keyword argument '{0}'".format(invalid_keywords.pop()))
self.__dict__.update(kwargs)

def reify(self, expr):
if self.complex_mode:
indices = gem.indices(len(expr.shape))
return gem.ComponentTensor(gem.MathFunction("real", gem.Indexed(expr, indices)),
indices)
else:
return expr

@cached_property
def fiat_cell(self):
return as_fiat_cell(self.ufl_cell)
Expand Down Expand Up @@ -136,7 +144,7 @@ def config(self):
return config

def cell_size(self):
return self.interface.cell_size(self.mt.restriction)
return self.interface.reify(self.interface.cell_size(self.mt.restriction))

def jacobian_at(self, point):
expr = Jacobian(self.mt.terminal.ufl_domain())
Expand Down Expand Up @@ -427,7 +435,7 @@ def translate_spatialcoordinate(terminal, mt, ctx):
# Rebuild modified terminal
expr = construct_modified_terminal(mt, terminal)
# Translate replaced UFL snippet
return ctx.translator(expr)
return ctx.reify(ctx.translator(expr))


class CellVolumeKernelInterface(ProxyKernelInterface):
Expand Down
5 changes: 4 additions & 1 deletion tsfc/kernel_interface/firedrake.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def make_builder(*args, **kwargs):
class Kernel(object):
__slots__ = ("ast", "integral_type", "oriented", "subdomain_id",
"domain_number", "needs_cell_sizes", "tabulations", "quadrature_rule",
"coefficient_numbers", "__weakref__")
"return_dtype", "coefficient_numbers", "__weakref__")
"""A compiled Kernel object.

:kwarg ast: The COFFEE ast for the kernel.
Expand All @@ -40,12 +40,14 @@ class Kernel(object):
:kwarg coefficient_numbers: A list of which coefficients from the
form the kernel needs.
:kwarg quadrature_rule: The finat quadrature rule used to generate this kernel
:kwarg return_dtype: numpy dtype of the return value.
:kwarg tabulations: The runtime tabulations this kernel requires
:kwarg needs_cell_sizes: Does the kernel require cell sizes.
"""
def __init__(self, ast=None, integral_type=None, oriented=False,
subdomain_id=None, domain_number=None, quadrature_rule=None,
coefficient_numbers=(),
return_dtype=None,
needs_cell_sizes=False):
# Defaults
self.ast = ast
Expand All @@ -55,6 +57,7 @@ def __init__(self, ast=None, integral_type=None, oriented=False,
self.subdomain_id = subdomain_id
self.coefficient_numbers = coefficient_numbers
self.needs_cell_sizes = needs_cell_sizes
self.return_dtype = return_dtype
super(Kernel, self).__init__()


Expand Down
23 changes: 17 additions & 6 deletions tsfc/kernel_interface/firedrake_loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def make_builder(*args, **kwargs):
class Kernel(object):
__slots__ = ("ast", "integral_type", "oriented", "subdomain_id",
"domain_number", "needs_cell_sizes", "tabulations", "quadrature_rule",
"coefficient_numbers", "__weakref__")
"return_dtype", "coefficient_numbers", "__weakref__")
"""A compiled Kernel object.

:kwarg ast: The loopy kernel object.
Expand All @@ -40,12 +40,14 @@ class Kernel(object):
:kwarg coefficient_numbers: A list of which coefficients from the
form the kernel needs.
:kwarg quadrature_rule: The finat quadrature rule used to generate this kernel
:kwarg return_dtype: numpy dtype of the return value.
:kwarg tabulations: The runtime tabulations this kernel requires
:kwarg needs_cell_sizes: Does the kernel require cell sizes.
"""
def __init__(self, ast=None, integral_type=None, oriented=False,
subdomain_id=None, domain_number=None, quadrature_rule=None,
coefficient_numbers=(),
return_dtype=None,
needs_cell_sizes=False):
# Defaults
self.ast = ast
Expand All @@ -55,6 +57,7 @@ def __init__(self, ast=None, integral_type=None, oriented=False,
self.subdomain_id = subdomain_id
self.coefficient_numbers = coefficient_numbers
self.needs_cell_sizes = needs_cell_sizes
self.return_dtype = return_dtype
super(Kernel, self).__init__()


Expand Down Expand Up @@ -164,8 +167,8 @@ def construct_kernel(self, return_arg, impero_c, precision, index_names):
for name_, shape in self.tabulations:
args.append(lp.GlobalArg(name_, dtype=self.scalar_type, shape=shape))

loopy_kernel = generate_loopy(impero_c, args, precision, self.scalar_type,
"expression_kernel", index_names)
loopy_kernel, _ = generate_loopy(impero_c, args, precision, self.scalar_type,
"expression_kernel", index_names, ignore_return_type=True)
return ExpressionKernel(loopy_kernel, self.oriented, self.cell_sizes,
self.coefficients, self.tabulations)

Expand Down Expand Up @@ -207,6 +210,7 @@ def set_arguments(self, arguments, multiindices):
:arg multiindices: GEM argument multiindices
:returns: GEM expression representing the return variable
"""
self.rank = len(arguments)
self.local_tensor, expressions = prepare_arguments(
arguments, multiindices, self.scalar_type, interior_facet=self.interior_facet,
diagonal=self.diagonal)
Expand Down Expand Up @@ -277,7 +281,11 @@ def construct_kernel(self, name, impero_c, precision, index_names, quadrature_ru
:returns: :class:`Kernel` object
"""

args = [self.local_tensor, self.coordinates_arg]
ignore_return_type = self.rank > 0
if ignore_return_type:
args = [self.local_tensor, self.coordinates_arg]
else:
args = [self.coordinates_arg]
if self.kernel.oriented:
args.append(self.cell_orientations_loopy_arg)
if self.kernel.needs_cell_sizes:
Expand All @@ -292,8 +300,11 @@ def construct_kernel(self, name, impero_c, precision, index_names, quadrature_ru
args.append(lp.GlobalArg(name_, dtype=self.scalar_type, shape=shape))

self.kernel.quadrature_rule = quadrature_rule
self.kernel.ast = generate_loopy(impero_c, args, precision,
self.scalar_type, name, index_names)
ast, dtype = generate_loopy(impero_c, args, precision,
self.scalar_type, name, index_names,
ignore_return_type=ignore_return_type)
self.kernel.ast = ast
self.kernel.return_dtype = dtype
return self.kernel

def construct_empty_kernel(self, name):
Expand Down
15 changes: 11 additions & 4 deletions tsfc/loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,15 +186,17 @@ def active_indices(mapping, ctx):
ctx.active_indices.pop(key)


def generate(impero_c, args, precision, scalar_type, kernel_name="loopy_kernel", index_names=[]):
def generate(impero_c, args, precision, scalar_type, kernel_name="loopy_kernel", index_names=[],
ignore_return_type=True):
"""Generates loopy code.

:arg impero_c: ImperoC tuple with Impero AST and other data
:arg args: list of loopy.GlobalArgs
:arg precision: floating-point precision for printing
:arg scalar_type: type of scalars as C typename string
:arg scalar_type: type of scalars as numpy dtype
:arg kernel_name: function name of the kernel
:arg index_names: pre-assigned index names
:arg ignore_return_type: Ignore inferred return type from impero_c?
:returns: loopy kernel
"""
ctx = LoopyContext()
Expand All @@ -205,7 +207,12 @@ def generate(impero_c, args, precision, scalar_type, kernel_name="loopy_kernel",
ctx.epsilon = 10.0 ** (-precision)

# Create arguments
data = list(args)
if ignore_return_type:
return_dtype = scalar_type
data = list(args)
else:
A, return_dtype = impero_c.return_variable
data = [lp.GlobalArg(A.name, shape=A.shape, dtype=return_dtype)] + list(args)
for i, (temp, dtype) in enumerate(assign_dtypes(impero_c.temporaries, scalar_type)):
name = "t%d" % i
if isinstance(temp, gem.Constant):
Expand Down Expand Up @@ -240,7 +247,7 @@ def generate(impero_c, args, precision, scalar_type, kernel_name="loopy_kernel",
insn_new.append(insn.copy(priority=len(knl.instructions) - i))
knl = knl.copy(instructions=insn_new)

return knl
return knl, return_dtype


@singledispatch
Expand Down