From aab3455164de6d0c0ed59144afe99426b58d5522 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 11 Dec 2024 15:18:22 +0000 Subject: [PATCH 1/3] Cache basis evaluation --- tsfc/fem.py | 11 +++++++++-- tsfc/kernel_interface/common.py | 3 +-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/tsfc/fem.py b/tsfc/fem.py index abc8bc7cb8..c66c0614ad 100644 --- a/tsfc/fem.py +++ b/tsfc/fem.py @@ -3,7 +3,7 @@ import collections import itertools -from functools import singledispatch +from functools import cached_property, singledispatch import gem import numpy @@ -18,7 +18,6 @@ from gem.node import traversal from gem.optimise import constant_fold_zero, ffc_rounding from gem.unconcatenate import unconcatenate -from gem.utils import cached_property from ufl.classes import (Argument, CellCoordinate, CellEdgeVectors, CellFacetJacobian, CellOrientation, CellOrigin, CellVertices, CellVolume, Coefficient, FacetArea, @@ -42,6 +41,8 @@ TSFCConstantMixin, entity_avg, one_times, preprocess_expression, simplify_abs) +from pyop2.caching import serial_cache + class ContextBase(ProxyKernelInterface): """Common UFL -> GEM translation context.""" @@ -296,6 +297,12 @@ def point_expr(self): def weight_expr(self): return self.quadrature_rule.weight_expression + def make_basis_evaluation_key(self, finat_element, mt, entity_id): + domain = extract_unique_domain(mt.terminal) + restriction = mt.restriction + return (self, finat_element, mt.local_derivatives, domain, restriction, entity_id) + + @serial_cache(hashkey=make_basis_evaluation_key) def basis_evaluation(self, finat_element, mt, entity_id): return finat_element.basis_evaluation(mt.local_derivatives, self.point_set, diff --git a/tsfc/kernel_interface/common.py b/tsfc/kernel_interface/common.py index df7e879f09..5cb2096961 100644 --- a/tsfc/kernel_interface/common.py +++ b/tsfc/kernel_interface/common.py @@ -1,7 +1,7 @@ import collections import operator import string -from functools import reduce +from functools import cached_property, reduce from itertools import chain, product import gem @@ -13,7 +13,6 @@ from gem.node import traversal from gem.optimise import constant_fold_zero from gem.optimise import remove_componenttensors as prune -from gem.utils import cached_property from numpy import asarray from tsfc import fem, ufl_utils from tsfc.finatinterface import as_fiat_cell, create_element From 9de44437c32b4e9783e1c10f0848cbcd385e7ca6 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 12 Dec 2024 22:18:52 +0000 Subject: [PATCH 2/3] Hash ufl_coordinate_element --- tsfc/fem.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tsfc/fem.py b/tsfc/fem.py index c66c0614ad..1d5b1ac392 100644 --- a/tsfc/fem.py +++ b/tsfc/fem.py @@ -299,8 +299,8 @@ def weight_expr(self): def make_basis_evaluation_key(self, finat_element, mt, entity_id): domain = extract_unique_domain(mt.terminal) - restriction = mt.restriction - return (self, finat_element, mt.local_derivatives, domain, restriction, entity_id) + coordinate_element = domain.ufl_coordinate_element() + return (self, finat_element, mt.local_derivatives, coordinate_element, mt.restriction, entity_id) @serial_cache(hashkey=make_basis_evaluation_key) def basis_evaluation(self, finat_element, mt, entity_id): From dcb322d74da7d9741207167ba07fc06c387981be Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 13 Dec 2024 10:56:54 +0000 Subject: [PATCH 3/3] TSFC: cache QuadratureRule --- tsfc/fem.py | 23 +++++++++++++---------- tsfc/kernel_interface/common.py | 11 +++++------ 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/tsfc/fem.py b/tsfc/fem.py index 1d5b1ac392..99251ed0c6 100644 --- a/tsfc/fem.py +++ b/tsfc/fem.py @@ -263,6 +263,18 @@ def needs_coordinate_mapping(element): return isinstance(create_element(element), NeedsCoordinateMappingElement) +@serial_cache(hashkey=lambda *args: args) +def get_quadrature_rule(fiat_cell, integration_dim, quadrature_degree, scheme): + integration_cell = fiat_cell.construct_subcomplex(integration_dim) + return make_quadrature(integration_cell, quadrature_degree, scheme=scheme) + + +def make_basis_evaluation_key(ctx, finat_element, mt, entity_id): + domain = extract_unique_domain(mt.terminal) + coordinate_element = domain.ufl_coordinate_element() + return (finat_element, mt.local_derivatives, ctx.point_set, ctx.integration_dim, entity_id, coordinate_element, mt.restriction) + + class PointSetContext(ContextBase): """Context for compile-time known evaluation points.""" @@ -273,13 +285,9 @@ class PointSetContext(ContextBase): 'weight_expr', ) - @cached_property - def integration_cell(self): - return self.fiat_cell.construct_subelement(self.integration_dim) - @cached_property def quadrature_rule(self): - return make_quadrature(self.integration_cell, self.quadrature_degree) + return get_quadrature_rule(self.fiat_cell, self.integration_dim, self.quadrature_degree, "default") @cached_property def point_set(self): @@ -297,11 +305,6 @@ def point_expr(self): def weight_expr(self): return self.quadrature_rule.weight_expression - def make_basis_evaluation_key(self, finat_element, mt, entity_id): - domain = extract_unique_domain(mt.terminal) - coordinate_element = domain.ufl_coordinate_element() - return (self, finat_element, mt.local_derivatives, coordinate_element, mt.restriction, entity_id) - @serial_cache(hashkey=make_basis_evaluation_key) def basis_evaluation(self, finat_element, mt, entity_id): return finat_element.basis_evaluation(mt.local_derivatives, diff --git a/tsfc/kernel_interface/common.py b/tsfc/kernel_interface/common.py index 5cb2096961..f757af951f 100644 --- a/tsfc/kernel_interface/common.py +++ b/tsfc/kernel_interface/common.py @@ -9,7 +9,7 @@ import numpy from FIAT.reference_element import TensorProductCell from finat.cell_tools import max_complex -from finat.quadrature import AbstractQuadratureRule, make_quadrature +from finat.quadrature import AbstractQuadratureRule from gem.node import traversal from gem.optimise import constant_fold_zero from gem.optimise import remove_componenttensors as prune @@ -136,7 +136,7 @@ def compile_integrand(self, integrand, params, ctx): integrand = ufl_utils.split_coefficients(integrand, self.coefficient_split) # Compile: ufl -> gem info = self.integral_data_info - functions = list(info.arguments) + [self.coordinate(info.domain)] + list(info.coefficients) + functions = [*info.arguments, self.coordinate(info.domain), *info.coefficients] set_quad_rule(params, info.domain.ufl_cell(), info.integral_type, functions) quad_rule = params["quadrature_rule"] config = self.fem_config() @@ -319,8 +319,7 @@ def set_quad_rule(params, cell, integral_type, functions): fiat_cell = max_complex(fiat_cells) integration_dim, _ = lower_integral_type(fiat_cell, integral_type) - integration_cell = fiat_cell.construct_subcomplex(integration_dim) - quad_rule = make_quadrature(integration_cell, quadrature_degree, scheme=scheme) + quad_rule = fem.get_quadrature_rule(fiat_cell, integration_dim, quadrature_degree, scheme) params["quadrature_rule"] = quad_rule if not isinstance(quad_rule, AbstractQuadratureRule): @@ -329,8 +328,8 @@ def set_quad_rule(params, cell, integral_type, functions): def get_index_ordering(quadrature_indices, return_variables): - split_argument_indices = tuple(chain(*[var.index_ordering() - for var in return_variables])) + split_argument_indices = tuple(chain(*(var.index_ordering() + for var in return_variables))) return tuple(quadrature_indices) + split_argument_indices