Skip to content

Commit

Permalink
Merge pull request #212 from yuyttenhove/master
Browse files Browse the repository at this point in the history
Add support for np.cbrt on cosmo arrays
  • Loading branch information
kyleaoman authored Nov 30, 2024
2 parents 2477c50 + a928d21 commit c24ac52
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 25 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ classifiers = [
dependencies = [
"numpy",
"h5py",
"unyt>=2.9.0",
"unyt>=3.0.2",
"numba>=0.50.0",
]

Expand Down
4 changes: 2 additions & 2 deletions swiftsimio/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
expm1,
log1p,
sqrt,
# cbrt, # TODO: Needs upstream (unyt) support first
cbrt,
square,
reciprocal,
sin,
Expand Down Expand Up @@ -656,7 +656,7 @@ class cosmo_array(unyt_array):
log1p: _return_without_cosmo_factor,
sqrt: _sqrt_cosmo_factor,
square: _square_cosmo_factor,
# cbrt: _cbrt_cosmo_factor, # TODO: Needs upstream (unyt) support first
cbrt: _cbrt_cosmo_factor,
reciprocal: _reciprocal_cosmo_factor,
sin: _return_without_cosmo_factor,
cos: _return_without_cosmo_factor,
Expand Down
23 changes: 2 additions & 21 deletions swiftsimio/visualisation/smoothing_length/nearest_neighbours.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,11 @@ def get_hsml(data: SWIFTDataset) -> cosmo_array:
The extracted "smoothing lengths".
"""
try:
# TODO remove this hack once np.cbrt is supported by unyt
volumes = data.gas.volumes
units = (hasattr(volumes, "units"), getattr(volumes, "units", None))
comoving = getattr(volumes, "comoving", None)
cosmo_factor = (
hasattr(volumes, "cosmo_factor"),
getattr(volumes, "cosmo_factor", None),
)
if units[0]:
units_hsml = units[1] ** (1.0 / 3.0)
else:
units_hsml = None
hsml = cosmo_array(
cbrt(volumes.value),
units=units_hsml,
comoving=comoving,
cosmo_factor=_cbrt_cosmo_factor(cosmo_factor),
)
hsml = cbrt(data.gas.volumes)
except AttributeError:
try:
# Try computing the volumes explicitly?
masses = data.gas.masses
densities = data.gas.densities
hsml = cbrt(masses / densities)
hsml = cbrt(data.gas.masses / data.gas.densities)
except AttributeError:
# Fall back to SPH behavior if above didn't work...
hsml = get_hsml_sph(data)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,4 @@ def test_nonzero_tcmb(filename):

output_cosmology = swift_cosmology_to_astropy(cosmo=cosmo, units=units)

assert isclose(output_cosmology._Ogamma0, 0.1)
assert isclose(output_cosmology.Ogamma0, 0.1)
12 changes: 12 additions & 0 deletions tests/test_cosmo_array_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,18 @@ def test_square_ufunc(self):
assert res.comoving is False
assert res.cosmo_factor == inp.cosmo_factor ** 2

def test_cbrt_ufunc(self):
inp = cosmo_array(
[8],
u.kpc,
comoving=False,
cosmo_factor=cosmo_factor(a ** 1, scale_factor=1.0),
)
res = np.cbrt(inp)
assert res.to_value(u.kpc ** (1.0 / 3.0)) == 2 # also ensures units ok
assert res.comoving is False
assert res.cosmo_factor == inp.cosmo_factor ** (1.0 / 3.0)

def test_reciprocal_ufunc(self):
inp = cosmo_array(
[2.0],
Expand Down

0 comments on commit c24ac52

Please sign in to comment.