Skip to content

Commit

Permalink
Merge pull request #201 from SWIFTSIM/masking_in_9.0
Browse files Browse the repository at this point in the history
Use only a single array for masking
  • Loading branch information
JBorrow authored Sep 20, 2024
2 parents b0b251a + 61b328c commit 22e595c
Show file tree
Hide file tree
Showing 7 changed files with 212 additions and 39 deletions.
25 changes: 25 additions & 0 deletions docs/source/masking/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,31 @@ ask for the temperature of particles, it will receive an array containing
temperatures of particles that lie in the region [0.2, 0.7] and have a
density between 0.4 and 0.8 g/cm^3.

Row Masking
-----------

For certain scenarios, in particular halo catalogues, all arrays are of the
same length (you can check this through the ``metadata.homogeneous_arrays``
attribute). Often, you are interested in a handful of, or a single, row,
corresponding to the properties of a particular object. You can use the
methods ``constrain_index`` and ``constrain_indices`` to do this, which
return ``swiftsimio`` data objects containing arrays with only those
rows.

.. code-block:: python
import swiftsimio as sw
mask = sw.mask(filename)
mask.constrain_indices([1, 99, 23421])
data = sw.load(filename, mask=mask)
Here, the length of all the arrays will be 3. A quick performance note: if you
are using many indices (over 1000), you will want to set ``spatial_only=False``
to potentially benefit from range reading of overlapping rows in a single chunk.

Writing subset of snapshot
--------------------------
In some cases it may be useful to write a subset of an existing snapshot to its
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ packages = [

[project]
name = "swiftsimio"
version="9.0.0"
version="9.0.1"
authors = [
{ name="Josh Borrow", email="[email protected]" },
]
Expand Down
129 changes: 106 additions & 23 deletions swiftsimio/masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import numpy as np

from swiftsimio import SWIFTMetadata
from swiftsimio.metadata.objects import SWIFTMetadata

from swiftsimio.objects import InvalidSnapshot

Expand All @@ -24,6 +24,9 @@ class SWIFTMask(object):
Pass in the SWIFTMetadata.
"""

group_mapping: dict | None = None
group_size_mapping: dict | None = None

def __init__(self, metadata: SWIFTMetadata, spatial_only=True):
"""
SWIFTMask constructor
Expand Down Expand Up @@ -72,6 +75,9 @@ def _generate_mapping_dictionary(self) -> dict[str, str]:
names. Allows for pointers to be used instead of re-creating masks.
"""

if self.group_mapping is not None:
return self.group_mapping

if self.metadata.shared_cell_counts is None:
# Each and every particle type has its own cell counts, offsets,
# and hence masks.
Expand All @@ -86,6 +92,31 @@ def _generate_mapping_dictionary(self) -> dict[str, str]:

return self.group_mapping

def _generate_size_mapping_dictionary(self) -> dict[str, str]:
    """
    Builds (once) and returns the lookup from the public ``{group}_size``
    attribute names to the internal attribute names that hold the mask sizes.

    The result is stored on ``self.group_size_mapping`` and re-used on every
    subsequent call.

    Returns
    -------
    dict[str, str]
        Keys are ``{group}_size`` for every entry of
        ``self.metadata.present_group_names``. Values are ``_{group}_size``
        when ``self.metadata.shared_cell_counts`` is ``None``, and
        ``_shared_size`` for every key otherwise.
    """

    if self.group_size_mapping is None:
        if self.metadata.shared_cell_counts is None:
            # One independent size attribute per group.
            size_lookup = {}
            for name in self.metadata.present_group_names:
                size_lookup[f"{name}_size"] = f"_{name}_size"
        else:
            # All groups point at the one shared size attribute.
            size_lookup = dict.fromkeys(
                (f"{name}_size" for name in self.metadata.present_group_names),
                "_shared_size",
            )

        self.group_size_mapping = size_lookup

    return self.group_size_mapping

def _generate_update_list(self) -> list[str]:
"""
Gets a list of internal mask variables that need to be updated when
Expand All @@ -100,12 +131,22 @@ def _generate_update_list(self) -> list[str]:
# We actually only have _one_ mask!
return ["_shared"]

def _create_pointers(self):
# Create pointers for every single particle type.
for group_name, data_name in self._generate_mapping_dictionary().items():
setattr(self, group_name, getattr(self, data_name))
def __getattr__(self, name):
"""
Overloads the getattr method to allow for direct access to the masks
for each particle type.
"""
mappings = {
**self._generate_mapping_dictionary(),
**self._generate_size_mapping_dictionary(),
}

underlying_name = mappings.get(name, None)

setattr(self, f"{group_name}_size", getattr(self, f"{data_name}_size"))
if underlying_name is not None:
return getattr(self, underlying_name)

raise AttributeError(f"Attribute {name} not found in SWIFTMask")

def _generate_empty_masks(self):
"""
Expand All @@ -129,8 +170,6 @@ def _generate_empty_masks(self):
setattr(self, data_name, np.ones(size, dtype=bool))
setattr(self, f"{data_name}_size", size)

self._create_pointers()

return

def _unpack_cell_metadata(self):
Expand Down Expand Up @@ -238,9 +277,10 @@ def constrain_mask(
"""

if self.spatial_only:
print("You cannot constrain a mask if spatial_only=True")
print("Please re-initialise the SWIFTMask object with spatial_only=False")
return
raise ValueError(
"You cannot constrain a mask if spatial_only=True. "
"Please re-initialise the SWIFTMask object with spatial_only=False"
)

mapping = self._generate_mapping_dictionary()
data_name = mapping[group_name]
Expand Down Expand Up @@ -442,8 +482,6 @@ def constrain_spatial(self, restrict, intersect: bool = False):
for mask in self._generate_update_list():
self._update_spatial_mask(restrict, mask, self.cell_mask)

self._create_pointers()

return

def convert_masks_to_ranges(self):
Expand All @@ -465,11 +503,8 @@ def convert_masks_to_ranges(self):
for mask in self._generate_update_list():
where_array = np.where(getattr(self, mask))[0]
setattr(self, f"{mask}_size", where_array.size)
print(mask, where_array)
setattr(self, mask, ranges_from_array(where_array))

self._create_pointers()

return

def constrain_index(self, index: int):
Expand All @@ -483,16 +518,64 @@ def constrain_index(self, index: int):
index : int
The index of the row to select.
"""
if not self.metadata.filetype == "SOAP":
warnings.warn("Not masking a SOAP catalogue, nothing constrained.")
return
for group_name in self.metadata.present_group_names:
setattr(self, group_name, np.array([[index, index + 1]]))
setattr(self, f"{group_name}_size", 1)

if not self.metadata.homogeneous_arrays:
raise RuntimeError(
"Cannot constrain to a single row in a non-homogeneous array; you currently "
f"are using a {self.metadata.output_type} file"
)

if not self.spatial_only:
raise RuntimeError(
"Cannot constrain to a single row in a non-spatial mask; you currently "
"are using a non-spatial mask"
)

for mask in self._generate_update_list():
setattr(self, mask, np.array([[index, index + 1]]))
setattr(self, f"{mask}_size", 1)

return

def constrain_indices(self, indices: list[int]):
    """
    Constrain the mask to a list of rows.

    For a spatial-only mask, every internal mask is replaced by an array of
    ``[index, index + 1]`` ranges (one row per requested index) and the
    matching ``*_size`` attribute is set to the number of indices. Otherwise
    the existing boolean mask is combined (logical and) with a boolean array
    that is ``True`` only at the requested indices.

    Parameters
    ----------
    indices : list[int]
        A list of the indices of the rows to select.

    Raises
    ------
    RuntimeError
        If ``self.metadata.homogeneous_arrays`` is false.
    """

    if not self.metadata.homogeneous_arrays:
        raise RuntimeError(
            "Cannot constrain to a list of rows in a non-homogeneous array; you "
            f"currently are using a {self.metadata.output_type} file"
        )

    if self.spatial_only:
        if len(indices) > 1000:
            warnings.warn(
                "You are constraining a large number of indices with a spatial "
                "mask, potentially leading to lots of overlap. You should "
                "use a non-spatial mask (i.e. spatial_only=False)"
            )

        # Flatten to 1-D so that an empty selection still yields a
        # well-formed (0, 2) array of ranges rather than a (0,) array.
        rows = np.asarray(indices, dtype=int).reshape(-1)

        for mask in self._generate_update_list():
            # A fresh array per mask, so the masks never alias each other.
            setattr(self, mask, np.stack((rows, rows + 1), axis=1))
            setattr(self, f"{mask}_size", rows.size)

    else:
        for mask in self._generate_update_list():
            current = getattr(self, mask)
            comparison_array = np.zeros(current.size, dtype=bool)
            comparison_array[indices] = True
            setattr(self, mask, np.logical_and(current, comparison_array))

    return

def get_masked_counts_offsets(
self
self,
) -> tuple[dict[str, np.array], dict[str, np.array]]:
"""
Returns the particle counts and offsets in cells selected by the mask
Expand Down
8 changes: 8 additions & 0 deletions swiftsimio/metadata/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ class SWIFTMetadata(ABC):
# (as is the case in SOAP) or whether each type (e.g. Gas, Dark Matter, etc.)
# has its own top-level cell grid counts.
shared_cell_counts: str | None = None
# Whether all the arrays in this file have the same length and order (as is
# the case for SOAP, all arrays correspond to subhalos) or whether there are
# multiple types (e.g. Gas, Dark Matter, etc.). Allows you to use constrain_index
# in masking as everyone uses the same _shared mask!
homogeneous_arrays: bool = False

@abstractmethod
def __init__(self, filename, units: "SWIFTUnits"):
Expand Down Expand Up @@ -1223,6 +1228,8 @@ class SWIFTFOFMetadata(SWIFTMetadata):
class.
"""

homogeneous_arrays: bool = True

def __init__(self, filename: str, units: SWIFTUnits):
self.filename = filename
self.units = units
Expand Down Expand Up @@ -1265,6 +1272,7 @@ class SWIFTSOAPMetadata(SWIFTMetadata):

masking_valid: bool = True
shared_cell_counts: str = "Subhalos"
homogeneous_arrays: bool = True

def __init__(self, filename: str, units: SWIFTUnits):
self.filename = filename
Expand Down
21 changes: 8 additions & 13 deletions swiftsimio/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,20 @@
+ SWIFTDataset, a container class for all of the above.
"""


from swiftsimio.accelerated import read_ranges_from_file
from swiftsimio.objects import cosmo_array, cosmo_factor, a
from swiftsimio.objects import cosmo_array, cosmo_factor

from swiftsimio.metadata.objects import (
metadata_discriminator,
SWIFTUnits,
SWIFTGroupMetadata,
SWIFTMetadata,
)

import re
import h5py
import unyt
import numpy as np
import warnings

from datetime import datetime
from pathlib import Path

from typing import Union, Callable, List, Optional
from typing import Union, List


def generate_getter(
Expand Down Expand Up @@ -181,9 +174,11 @@ def getter(self):
cosmo_array(
# Only use column data if array is multidimensional, otherwise
# we will crash here
handle[field][:, columns]
if handle[field].ndim > 1
else handle[field][:],
(
handle[field][:, columns]
if handle[field].ndim > 1
else handle[field][:]
),
unit,
cosmo_factor=cosmo_factor,
name=description,
Expand Down Expand Up @@ -316,7 +311,7 @@ def generate_empty_properties(self):
class __SWIFTNamedColumnDataset(object):
"""
Holder class for individual named datasets. Very similar to
__SWIFTGroupsDatasets but much simpler.
__SWIFTGroupDatasets but much simpler.
"""

def __init__(self, field_path: str, named_columns: List[str], name: str):
Expand Down
2 changes: 0 additions & 2 deletions swiftsimio/subset_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@
it to a new file.
"""

from swiftsimio.reader import SWIFTUnits, SWIFTMetadata
from swiftsimio.masks import SWIFTMask
from swiftsimio.accelerated import read_ranges_from_file
import swiftsimio.metadata as metadata

import unyt
import h5py
import numpy as np
from typing import Optional, List
Expand Down
Loading

0 comments on commit 22e595c

Please sign in to comment.