Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache basis evaluation #3921

Merged
merged 4 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions tsfc/fem.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import collections
import itertools
from functools import singledispatch
from functools import cached_property, singledispatch

import gem
import numpy
Expand All @@ -18,7 +18,6 @@
from gem.node import traversal
from gem.optimise import constant_fold_zero, ffc_rounding
from gem.unconcatenate import unconcatenate
from gem.utils import cached_property
from ufl.classes import (Argument, CellCoordinate, CellEdgeVectors,
CellFacetJacobian, CellOrientation, CellOrigin,
CellVertices, CellVolume, Coefficient, FacetArea,
Expand All @@ -42,6 +41,8 @@
TSFCConstantMixin, entity_avg, one_times,
preprocess_expression, simplify_abs)

from pyop2.caching import serial_cache


class ContextBase(ProxyKernelInterface):
"""Common UFL -> GEM translation context."""
Expand Down Expand Up @@ -262,6 +263,18 @@ def needs_coordinate_mapping(element):
return isinstance(create_element(element), NeedsCoordinateMappingElement)


@serial_cache(hashkey=lambda *args: args)
def get_quadrature_rule(fiat_cell, integration_dim, quadrature_degree, scheme):
integration_cell = fiat_cell.construct_subcomplex(integration_dim)
return make_quadrature(integration_cell, quadrature_degree, scheme=scheme)


def make_basis_evaluation_key(ctx, finat_element, mt, entity_id):
domain = extract_unique_domain(mt.terminal)
coordinate_element = domain.ufl_coordinate_element()
return (finat_element, mt.local_derivatives, ctx.point_set, ctx.integration_dim, entity_id, coordinate_element, mt.restriction)


class PointSetContext(ContextBase):
"""Context for compile-time known evaluation points."""

Expand All @@ -272,13 +285,9 @@ class PointSetContext(ContextBase):
'weight_expr',
)

@cached_property
def integration_cell(self):
return self.fiat_cell.construct_subelement(self.integration_dim)

@cached_property
def quadrature_rule(self):
return make_quadrature(self.integration_cell, self.quadrature_degree)
return get_quadrature_rule(self.fiat_cell, self.integration_dim, self.quadrature_degree, "default")

@cached_property
def point_set(self):
Expand All @@ -296,6 +305,7 @@ def point_expr(self):
def weight_expr(self):
return self.quadrature_rule.weight_expression

@serial_cache(hashkey=make_basis_evaluation_key)
def basis_evaluation(self, finat_element, mt, entity_id):
return finat_element.basis_evaluation(mt.local_derivatives,
self.point_set,
Expand Down
14 changes: 6 additions & 8 deletions tsfc/kernel_interface/common.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
import collections
import operator
import string
from functools import reduce
from functools import cached_property, reduce
from itertools import chain, product

import gem
import gem.impero_utils as impero_utils
import numpy
from FIAT.reference_element import TensorProductCell
from finat.cell_tools import max_complex
from finat.quadrature import AbstractQuadratureRule, make_quadrature
from finat.quadrature import AbstractQuadratureRule
from gem.node import traversal
from gem.optimise import constant_fold_zero
from gem.optimise import remove_componenttensors as prune
from gem.utils import cached_property
from numpy import asarray
from tsfc import fem, ufl_utils
from tsfc.finatinterface import as_fiat_cell, create_element
Expand Down Expand Up @@ -137,7 +136,7 @@ def compile_integrand(self, integrand, params, ctx):
integrand = ufl_utils.split_coefficients(integrand, self.coefficient_split)
# Compile: ufl -> gem
info = self.integral_data_info
functions = list(info.arguments) + [self.coordinate(info.domain)] + list(info.coefficients)
functions = [*info.arguments, self.coordinate(info.domain), *info.coefficients]
set_quad_rule(params, info.domain.ufl_cell(), info.integral_type, functions)
quad_rule = params["quadrature_rule"]
config = self.fem_config()
Expand Down Expand Up @@ -320,8 +319,7 @@ def set_quad_rule(params, cell, integral_type, functions):
fiat_cell = max_complex(fiat_cells)

integration_dim, _ = lower_integral_type(fiat_cell, integral_type)
integration_cell = fiat_cell.construct_subcomplex(integration_dim)
quad_rule = make_quadrature(integration_cell, quadrature_degree, scheme=scheme)
quad_rule = fem.get_quadrature_rule(fiat_cell, integration_dim, quadrature_degree, scheme)
params["quadrature_rule"] = quad_rule

if not isinstance(quad_rule, AbstractQuadratureRule):
Expand All @@ -330,8 +328,8 @@ def set_quad_rule(params, cell, integral_type, functions):


def get_index_ordering(quadrature_indices, return_variables):
split_argument_indices = tuple(chain(*[var.index_ordering()
for var in return_variables]))
split_argument_indices = tuple(chain(*(var.index_ordering()
for var in return_variables)))
return tuple(quadrature_indices) + split_argument_indices


Expand Down
Loading