Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache basis evaluation #3921

Merged
merged 4 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions tsfc/fem.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import collections
import itertools
from functools import singledispatch
from functools import cached_property, singledispatch

import gem
import numpy
Expand All @@ -18,7 +18,6 @@
from gem.node import traversal
from gem.optimise import constant_fold_zero, ffc_rounding
from gem.unconcatenate import unconcatenate
from gem.utils import cached_property
from ufl.classes import (Argument, CellCoordinate, CellEdgeVectors,
CellFacetJacobian, CellOrientation, CellOrigin,
CellVertices, CellVolume, Coefficient, FacetArea,
Expand All @@ -42,6 +41,8 @@
TSFCConstantMixin, entity_avg, one_times,
preprocess_expression, simplify_abs)

from pyop2.caching import serial_cache


class ContextBase(ProxyKernelInterface):
"""Common UFL -> GEM translation context."""
Expand Down Expand Up @@ -296,6 +297,12 @@ def point_expr(self):
def weight_expr(self):
return self.quadrature_rule.weight_expression

def make_basis_evaluation_key(self, finat_element, mt, entity_id):
domain = extract_unique_domain(mt.terminal)
restriction = mt.restriction
return (self, finat_element, mt.local_derivatives, domain, restriction, entity_id)

@serial_cache(hashkey=make_basis_evaluation_key)
def basis_evaluation(self, finat_element, mt, entity_id):
return finat_element.basis_evaluation(mt.local_derivatives,
self.point_set,
Expand Down
3 changes: 1 addition & 2 deletions tsfc/kernel_interface/common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import collections
import operator
import string
from functools import reduce
from functools import cached_property, reduce
from itertools import chain, product

import gem
Expand All @@ -13,7 +13,6 @@
from gem.node import traversal
from gem.optimise import constant_fold_zero
from gem.optimise import remove_componenttensors as prune
from gem.utils import cached_property
from numpy import asarray
from tsfc import fem, ufl_utils
from tsfc.finatinterface import as_fiat_cell, create_element
Expand Down
Loading