Skip to content

Commit

Permalink
TSFC: Cache basis evaluation (#3921)
Browse files Browse the repository at this point in the history
* TSFC: cache basis evaluation

* TSFC: cache QuadratureRule
  • Loading branch information
pbrubeck authored Dec 13, 2024
1 parent 831757c commit fddc94d
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 15 deletions.
24 changes: 17 additions & 7 deletions tsfc/fem.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import collections
import itertools
from functools import singledispatch
from functools import cached_property, singledispatch

import gem
import numpy
Expand All @@ -18,7 +18,6 @@
from gem.node import traversal
from gem.optimise import constant_fold_zero, ffc_rounding
from gem.unconcatenate import unconcatenate
from gem.utils import cached_property
from ufl.classes import (Argument, CellCoordinate, CellEdgeVectors,
CellFacetJacobian, CellOrientation, CellOrigin,
CellVertices, CellVolume, Coefficient, FacetArea,
Expand All @@ -42,6 +41,8 @@
TSFCConstantMixin, entity_avg, one_times,
preprocess_expression, simplify_abs)

from pyop2.caching import serial_cache


class ContextBase(ProxyKernelInterface):
"""Common UFL -> GEM translation context."""
Expand Down Expand Up @@ -262,6 +263,18 @@ def needs_coordinate_mapping(element):
return isinstance(create_element(element), NeedsCoordinateMappingElement)


@serial_cache(hashkey=lambda *args: args)
def get_quadrature_rule(fiat_cell, integration_dim, quadrature_degree, scheme):
integration_cell = fiat_cell.construct_subcomplex(integration_dim)
return make_quadrature(integration_cell, quadrature_degree, scheme=scheme)


def make_basis_evaluation_key(ctx, finat_element, mt, entity_id):
domain = extract_unique_domain(mt.terminal)
coordinate_element = domain.ufl_coordinate_element()
return (finat_element, mt.local_derivatives, ctx.point_set, ctx.integration_dim, entity_id, coordinate_element, mt.restriction)


class PointSetContext(ContextBase):
"""Context for compile-time known evaluation points."""

Expand All @@ -272,13 +285,9 @@ class PointSetContext(ContextBase):
'weight_expr',
)

@cached_property
def integration_cell(self):
return self.fiat_cell.construct_subelement(self.integration_dim)

@cached_property
def quadrature_rule(self):
return make_quadrature(self.integration_cell, self.quadrature_degree)
return get_quadrature_rule(self.fiat_cell, self.integration_dim, self.quadrature_degree, "default")

@cached_property
def point_set(self):
Expand All @@ -296,6 +305,7 @@ def point_expr(self):
def weight_expr(self):
return self.quadrature_rule.weight_expression

@serial_cache(hashkey=make_basis_evaluation_key)
def basis_evaluation(self, finat_element, mt, entity_id):
return finat_element.basis_evaluation(mt.local_derivatives,
self.point_set,
Expand Down
14 changes: 6 additions & 8 deletions tsfc/kernel_interface/common.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
import collections
import operator
import string
from functools import reduce
from functools import cached_property, reduce
from itertools import chain, product

import gem
import gem.impero_utils as impero_utils
import numpy
from FIAT.reference_element import TensorProductCell
from finat.cell_tools import max_complex
from finat.quadrature import AbstractQuadratureRule, make_quadrature
from finat.quadrature import AbstractQuadratureRule
from gem.node import traversal
from gem.optimise import constant_fold_zero
from gem.optimise import remove_componenttensors as prune
from gem.utils import cached_property
from numpy import asarray
from tsfc import fem, ufl_utils
from tsfc.finatinterface import as_fiat_cell, create_element
Expand Down Expand Up @@ -137,7 +136,7 @@ def compile_integrand(self, integrand, params, ctx):
integrand = ufl_utils.split_coefficients(integrand, self.coefficient_split)
# Compile: ufl -> gem
info = self.integral_data_info
functions = list(info.arguments) + [self.coordinate(info.domain)] + list(info.coefficients)
functions = [*info.arguments, self.coordinate(info.domain), *info.coefficients]
set_quad_rule(params, info.domain.ufl_cell(), info.integral_type, functions)
quad_rule = params["quadrature_rule"]
config = self.fem_config()
Expand Down Expand Up @@ -320,8 +319,7 @@ def set_quad_rule(params, cell, integral_type, functions):
fiat_cell = max_complex(fiat_cells)

integration_dim, _ = lower_integral_type(fiat_cell, integral_type)
integration_cell = fiat_cell.construct_subcomplex(integration_dim)
quad_rule = make_quadrature(integration_cell, quadrature_degree, scheme=scheme)
quad_rule = fem.get_quadrature_rule(fiat_cell, integration_dim, quadrature_degree, scheme)
params["quadrature_rule"] = quad_rule

if not isinstance(quad_rule, AbstractQuadratureRule):
Expand All @@ -330,8 +328,8 @@ def set_quad_rule(params, cell, integral_type, functions):


def get_index_ordering(quadrature_indices, return_variables):
split_argument_indices = tuple(chain(*[var.index_ordering()
for var in return_variables]))
split_argument_indices = tuple(chain(*(var.index_ordering()
for var in return_variables)))
return tuple(quadrature_indices) + split_argument_indices


Expand Down

0 comments on commit fddc94d

Please sign in to comment.