From d9a72a132f5e4fd8e422e08f1234838e4459eaf3 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 11 Dec 2024 15:18:22 +0000 Subject: [PATCH] Cache basis evaluation --- tsfc/fem.py | 11 +++++++++-- tsfc/kernel_interface/common.py | 3 +-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/tsfc/fem.py b/tsfc/fem.py index abc8bc7cb8..76b56140f6 100644 --- a/tsfc/fem.py +++ b/tsfc/fem.py @@ -3,7 +3,7 @@ import collections import itertools -from functools import singledispatch +from functools import cached_property, singledispatch import gem import numpy @@ -18,7 +18,6 @@ from gem.node import traversal from gem.optimise import constant_fold_zero, ffc_rounding from gem.unconcatenate import unconcatenate -from gem.utils import cached_property from ufl.classes import (Argument, CellCoordinate, CellEdgeVectors, CellFacetJacobian, CellOrientation, CellOrigin, CellVertices, CellVolume, Coefficient, FacetArea, @@ -42,6 +41,8 @@ TSFCConstantMixin, entity_avg, one_times, preprocess_expression, simplify_abs) +from pyop2.caching import serial_cache + class ContextBase(ProxyKernelInterface): """Common UFL -> GEM translation context.""" @@ -296,6 +297,12 @@ def point_expr(self): def weight_expr(self): return self.quadrature_rule.weight_expression + def make_basis_evaluation_key(self, finat_element, mt, entity_id): + domain = extract_unique_domain(mt.terminal) + restriction = mt.restriction + return (finat_element, mt.local_derivatives, domain, restriction, entity_id) + + @serial_cache(hashkey=make_basis_evaluation_key) def basis_evaluation(self, finat_element, mt, entity_id): return finat_element.basis_evaluation(mt.local_derivatives, self.point_set, diff --git a/tsfc/kernel_interface/common.py b/tsfc/kernel_interface/common.py index df7e879f09..5cb2096961 100644 --- a/tsfc/kernel_interface/common.py +++ b/tsfc/kernel_interface/common.py @@ -1,7 +1,7 @@ import collections import operator import string -from functools import reduce +from functools import cached_property, reduce from itertools import chain, product import gem @@ -13,7 +13,6 @@ from gem.node import traversal from gem.optimise import constant_fold_zero from gem.optimise import remove_componenttensors as prune -from gem.utils import cached_property from numpy import asarray from tsfc import fem, ufl_utils from tsfc.finatinterface import as_fiat_cell, create_element