Skip to content

Commit

Permalink
Use only a single array for masking
Browse files Browse the repository at this point in the history
  • Loading branch information
JBorrow committed Sep 19, 2024
1 parent 19dc97e commit a0923e1
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 48 deletions.
163 changes: 116 additions & 47 deletions swiftsimio/masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,79 @@ def __init__(self, metadata: SWIFTMetadata, spatial_only=True):
if not spatial_only:
self._generate_empty_masks()

def _generate_empty_masks(self):

def _generate_mapping_dictionary(self) -> dict[str, str]:
"""
Generates the empty (i.e. all False) masks for all available particle
types.
Creates cross-links between 'group names' and their underlying cell metadata
names. Allows for pointers to be used instead of re-creating masks.
"""

for group_name in self.metadata.present_group_names:
if self.metadata.shared_cell_counts is None:
# Each and every particle type has its own cell counts, offsets,
# and hence masks.
self.group_mapping = {
group: f"_{group}" for group in self.metadata.present_group_names
}
else:
# We actually only have _one_ mask!
self.group_mapping = {
group: "_shared" for group in self.metadata.present_group_names
}

return self.group_mapping

def _generate_update_list(self) -> list[str]:
"""
Gets a list of internal mask variables that need to be updated when
we change the spatial mask.
"""

if self.metadata.shared_cell_counts is None:
# Each and every particle type has its own cell counts, offsets,
# and hence masks.
return [f"_{group}" for group in self.metadata.present_group_names]
else:
# We actually only have _one_ mask!
return ["_shared"]

def _create_pointers(self):
# Create pointers for every single particle type.
for group_name, data_name in self._generate_mapping_dictionary().items():
setattr(
self,
group_name,
np.ones(getattr(self.metadata, f"n_{group_name}"), dtype=bool),
getattr(self, data_name)
)

setattr(
self,
f"{group_name}_size",
getattr(self, f"{data_name}_size")
)


def _generate_empty_masks(self):
"""
Generates the empty (i.e. all False) masks for all available particle
types.
"""

mapping = self._generate_mapping_dictionary()

if self.metadata.shared_cell_counts is not None:
size = getattr(self.metadata, f"n_{self.metadata.shared_cell_counts.lower()}")
self._shared = np.ones(size, dtype=bool)
self._shared_size = size

else:
# Create empty masks for each and every particle type.
for group_name, data_name in mapping.items():
size = getattr(self.metadata, f"n_{group_name}")
setattr(self, data_name, np.ones(size, dtype=bool))
setattr(self, f"{data_name}_size", size)

self._create_pointers()

return

def _unpack_cell_metadata(self):
Expand All @@ -104,29 +164,39 @@ def _unpack_cell_metadata(self):
# file i/o implemented
offset_handle = cell_handle["Offsets"]


if self.metadata.shared_cell_counts is not None:
# Single - called _shared.
self.offsets["shared"] = offset_handle[self.metadata.shared_cell_counts][:]
self.counts["shared"] = count_handle[self.metadata.shared_cell_counts][:]
else:
for group, group_name in zip(
self.metadata.present_groups, self.metadata.present_group_names
):
counts = count_handle[group][:]
offsets = offset_handle[group][:]

self.offsets[group_name] = offset_handle[group][:]
self.counts[group_name] = count_handle[group][:]

# Only want to compute this once (even if it is fast, we do not
# have a reliable stable sort in the case where cells do not
# contain at least one of each type of particle).
sort = None

for group, group_name in zip(
self.metadata.present_groups, self.metadata.present_group_names
):
if self.metadata.shared_cell_counts is None:
counts = count_handle[group][:]
offsets = offset_handle[group][:]
else:
counts = count_handle[self.metadata.shared_cell_counts][:]
offsets = offset_handle[self.metadata.shared_cell_counts][:]
# Now perform sort:
for key in self.offsets.keys():
offsets = self.offsets[key]
counts = self.counts[key]

# When using MPI, we cannot assume that these are sorted.
if sort is None:
# Only compute once; not stable between particle
# types if some datasets do not have particles in a cell!
sort = np.argsort(offsets)

self.offsets[group_name] = offsets[sort]
self.counts[group_name] = counts[sort]
self.offsets[key] = offsets[sort]
self.counts[key] = counts[sort]

# Also need to sort centers in the same way
self.centers = unyt.unyt_array(centers_handle[:][sort], units=self.units.length)
Expand Down Expand Up @@ -180,8 +250,11 @@ def constrain_mask(
print("You cannot constrain a mask if spatial_only=True")
print("Please re-initialise the SWIFTMask object with spatial_only=False")
return

mapping = self._generate_mapping_dictionary()
data_name = mapping[group_name]

current_mask = getattr(self, group_name)
current_mask = getattr(self, data_name)

group_metadata = getattr(self.metadata, f"{group_name}_properties")
unit_dict = {
Expand Down Expand Up @@ -209,7 +282,7 @@ def constrain_mask(

current_mask[current_mask] = new_mask

setattr(self, group_name, current_mask)
setattr(self, data_name, current_mask)

return

Expand Down Expand Up @@ -288,7 +361,7 @@ def _generate_cell_mask(self, restrict):

return cell_mask

def _update_spatial_mask(self, restrict, group_name: str, cell_mask: np.array):
def _update_spatial_mask(self, restrict, data_name: str, cell_mask: np.array):
"""
Updates the particle mask using the cell mask.
Expand All @@ -302,28 +375,30 @@ def _update_spatial_mask(self, restrict, group_name: str, cell_mask: np.array):
restrict : list
currently unused
group_name : str
particle type to update
data_name : str
underlying data to update (e.g. _gas, _shared)
cell_mask : np.array
cell mask used to update the particle mask
"""

count_name = data_name[1:] # Remove the underscore

if self.spatial_only:
counts = self.counts[group_name][cell_mask]
offsets = self.offsets[group_name][cell_mask]
counts = self.counts[count_name][cell_mask]
offsets = self.offsets[count_name][cell_mask]

this_mask = [[o, c + o] for c, o in zip(counts, offsets)]

setattr(self, group_name, np.array(this_mask))
setattr(self, f"{group_name}_size", np.sum(counts))
setattr(self, data_name, np.array(this_mask))
setattr(self, f"{data_name}_size", np.sum(counts))

else:
counts = self.counts[group_name][~cell_mask]
offsets = self.offsets[group_name][~cell_mask]
counts = self.counts[count_name][~cell_mask]
offsets = self.offsets[count_name][~cell_mask]

# We must do the whole boolean mask business.
this_mask = getattr(self, group_name)
this_mask = getattr(self, data_name)

for count, offset in zip(counts, offsets):
this_mask[offset : count + offset] = False
Expand Down Expand Up @@ -373,8 +448,10 @@ def constrain_spatial(self, restrict, intersect: bool = False):
# we just make a new mask
self.cell_mask = self._generate_cell_mask(restrict)

for group_name in self.metadata.present_group_names:
self._update_spatial_mask(restrict, group_name, self.cell_mask)
for mask in self._generate_update_list():
self._update_spatial_mask(restrict, mask, self.cell_mask)

self._create_pointers()

return

Expand All @@ -388,27 +465,19 @@ def convert_masks_to_ranges(self):
If you don't know what you are doing please don't use this.
"""

if self.spatial_only:
# We are already done!
return
else:
# Spatial only already comes like this!
if not self.spatial_only:
# We must do the whole boolean mask stuff. To do that, we
# First, convert each boolean mask into an integer mask
# Use the accelerate.ranges_from_array function to convert
# This into a set of ranges.
for mask in self._generate_update_list():
where_array = np.where(getattr(self, mask))[0]
setattr(self, f"{mask}_size", where_array.size)
print(mask, where_array)
setattr(self, mask, ranges_from_array(where_array))

for group_name in self.metadata.present_group_names:
setattr(
self,
group_name,
# Because it nests things in a list for some reason.
np.where(getattr(self, group_name))[0],
)

setattr(self, f"{group_name}_size", getattr(self, group_name).size)

for group_name in self.metadata.present_group_names:
setattr(self, group_name, ranges_from_array(getattr(self, group_name)))
self._create_pointers()

return

Expand All @@ -431,7 +500,7 @@ def constrain_index(self, index: int):
setattr(self, f"{group_name}_size", 1)
return

def get_masked_counts_offsets(self) -> (Dict[str, np.array], Dict[str, np.array]):
def get_masked_counts_offsets(self) -> tuple[dict[str, np.array], dict[str, np.array]]:
"""
Returns the particle counts and offsets in cells selected by the mask
Expand Down
4 changes: 4 additions & 0 deletions swiftsimio/metadata/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1272,13 +1272,17 @@ def __init__(self, filename: str, units: SWIFTUnits):

self.get_metadata()
self.postprocess_header()
self.unpack_subhalo_number()

self.load_groups()

# After we've loaded all this metadata, we can safely release the file handle.
self.handle.close()

return

def unpack_subhalo_number(self):
self.n_subhalos = int(self.num_subhalo[0])

@property
def present_groups(self):
Expand Down
30 changes: 29 additions & 1 deletion tests/test_soap.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,39 @@

from tests.helper import requires

from swiftsimio import load
from swiftsimio import load, mask


@requires("soap_example.hdf5")
def test_soap_can_load(filename):
data = load(filename)

return


@requires("soap_example.hdf5")
def test_soap_can_mask_spatial(filename):
this_mask = mask(filename, spatial_only=True)

bs = this_mask.metadata.boxsize
this_mask.constrain_spatial(
[[0 * b, 0.5 * b] for b in bs]
)

data = load(filename, mask=this_mask)

data.spherical_overdensity_200_mean.total_mass[0]


@requires("soap_example.hdf5")
def test_soap_can_mask_non_spatial(filename):
this_mask = mask(filename, spatial_only=False)

bs = this_mask.metadata.boxsize
this_mask.constrain_spatial(
[[0 * b, 0.5 * b] for b in bs]
)

data = load(filename, mask=this_mask)

data.spherical_overdensity_200_mean.total_mass[0]

0 comments on commit a0923e1

Please sign in to comment.