From e11e833dcd7d9eaaa1a683ac6b629d604c57496f Mon Sep 17 00:00:00 2001 From: Connor Ward Date: Wed, 8 Jan 2025 16:50:16 +0000 Subject: [PATCH] Invalidate spatial index when coordinates updated (#3956) --- firedrake/mesh.py | 21 ++++++++++++------- .../regression/test_point_eval_fs.py | 9 ++++++++ 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/firedrake/mesh.py b/firedrake/mesh.py index 4e209340c7..9936efcefc 100644 --- a/firedrake/mesh.py +++ b/firedrake/mesh.py @@ -2269,6 +2269,9 @@ def __init__(self, coordinates): # submesh self.submesh_parent = None + self._spatial_index = None + self._saved_coordinate_dat_version = coordinates.dat.dat_version + def _ufl_signature_data_(self, *args, **kwargs): return (type(self), self.extruded, self.variable_layers, super()._ufl_signature_data_(*args, **kwargs)) @@ -2448,12 +2451,9 @@ def clear_spatial_index(self): Use this if you move the mesh (for example by reassigning to the coordinate field).""" - try: - del self.spatial_index - except AttributeError: - pass + self._spatial_index = None - @utils.cached_property + @property def spatial_index(self): """Spatial index to quickly find which cell contains a given point. @@ -2466,10 +2466,15 @@ def spatial_index(self): can be found. """ - from firedrake import function, functionspace from firedrake.parloops import par_loop, READ, MIN, MAX + if ( + self._spatial_index + and self.coordinates.dat.dat_version == self._saved_coordinate_dat_version + ): + return self._spatial_index + gdim = self.geometric_dimension() if gdim <= 1: info_red("libspatialindex does not support 1-dimension, falling back on brute force.") @@ -2531,7 +2536,9 @@ def spatial_index(self): coords_max = coords_mid + (tolerance + 0.5)*d # Build spatial index - return spatialindex.from_regions(coords_min, coords_max) + self._spatial_index = spatialindex.from_regions(coords_min, coords_max) + self._saved_coordinate_dat_version = self.coordinates.dat.dat_version + return self._spatial_index @PETSc.Log.EventDecorator() def locate_cell(self, x, tolerance=None, cell_ignore=None): diff --git a/tests/firedrake/regression/test_point_eval_fs.py b/tests/firedrake/regression/test_point_eval_fs.py index 0ac1a0439a..91a751a826 100644 --- a/tests/firedrake/regression/test_point_eval_fs.py +++ b/tests/firedrake/regression/test_point_eval_fs.py @@ -213,3 +213,12 @@ def test_point_reset_works(): f.assign(1) m.clear_spatial_index() assert np.allclose([1.0], f.at((0.3, 0.3))) + + +def test_changing_coordinates_invalidates_spatial_index(): + mesh = UnitSquareMesh(2, 2) + mesh.init() + + saved_spatial_index = mesh.spatial_index + mesh.coordinates.assign(mesh.coordinates * 2) + assert mesh.spatial_index != saved_spatial_index