From e8ff24baee0711c9fe6cc286e3dd82c854996d39 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 21 Jan 2025 17:56:37 +0000 Subject: [PATCH] Optimize zany-mapping matvec --- finat/physically_mapped.py | 9 ++++++--- gem/gem.py | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/finat/physically_mapped.py b/finat/physically_mapped.py index 4b6c6089..80abc775 100644 --- a/finat/physically_mapped.py +++ b/finat/physically_mapped.py @@ -270,15 +270,18 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): M = self.basis_transformation(coordinate_mapping) # we expect M to be sparse with O(1) nonzeros per row # for each row, get the column index of each nonzero entry - csr = [[j for j in range(M.shape[1]) if not isinstance(M.array[i, j], gem.Zero)] + csr = [[j for j in range(M.shape[1]) if not isinstance(M[i, j], gem.Zero)] for i in range(M.shape[0])] def matvec(table): # basis recombination using hand-rolled sparse-dense matrix multiplication - table = [gem.partial_indexed(table, (j,)) for j in range(M.shape[1])] + ii = gem.indices(len(table.shape)-1) + phi = [gem.Indexed(table, (j, *ii)) for j in range(M.shape[1])] # the sum approach is faster than calling numpy.dot or gem.IndexSum - expressions = [sum(M.array[i, j] * table[j] for j in js) for i, js in enumerate(csr)] + expressions = [gem.ComponentTensor(sum(M[i, j] * phi[j] for j in js), ii) + for i, js in enumerate(csr)] val = gem.ListTensor(expressions) + # val = M @ table return gem.optimise.aggressive_unroll(val) result = super().basis_evaluation(order, ps, entity=entity) diff --git a/gem/gem.py b/gem/gem.py index 01bf2c8d..8371803a 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -689,7 +689,7 @@ def __new__(cls, aggregate, multiindex): ll = tuple(rep.get(k, k) for k in kk) return Indexed(C, ll) - # Simplify Constant and ListTensor + # Simplify Literal and ListTensor if isinstance(aggregate, (Constant, ListTensor)): if all(isinstance(i, int) for i in multiindex): # All indices fixed