Skip to content

Commit

Permalink
feat: use astropy unit conversion API
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman committed Sep 27, 2024
1 parent 1e0bac0 commit 6865587
Showing 1 changed file with 1 addition and 44 deletions.
45 changes: 1 addition & 44 deletions src/unxt/_interop/unxt_interop_astropy/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
import astropy.units as u
from astropy.coordinates import Angle as AstropyAngle, Distance as AstropyDistance
from astropy.units import Quantity as AstropyQuantity
from jaxtyping import Array
from packaging.version import Version
from plum import conversion_method, dispatch

import quaxed.numpy as jnp
Expand All @@ -22,7 +20,6 @@
Quantity,
UncheckedQuantity,
)
from unxt._interop.optional_deps import OptDeps

# ============================================================================
# AbstractQuantity
Expand Down Expand Up @@ -243,46 +240,6 @@ def convert_astropy_quantity_to_unxt_distmod(q: AstropyQuantity, /) -> DistanceM
AstropyUnit: TypeAlias = u.UnitBase | u.Unit | u.FunctionUnitBase | u.StructuredUnit


if Version("7.0") <= OptDeps.ASTROPY.version:

def _apy7_unit_to(self: AstropyUnit, other: AstropyUnit, value: Array, /) -> Array:
return self.to(other, value)

else:

def _apy7_unit_to(self: AstropyUnit, other: AstropyUnit, value: Array, /) -> Array:
"""Convert the value to the other unit."""
# return self.get_converter(Unit(other), equivalencies)(value)
# First see if it is just a scaling.
try:
scale = self._to(other)
except u.UnitsError:
pass
else:
return scale * value

# if that doesn't work, maybe we can do it with equivalencies?
try:
return self._apply_equivalencies(
self, other, self._normalize_equivalencies([])
)(value)
except u.UnitsError as exc:
# Last hope: maybe other knows how to do it?
# We assume the equivalencies have the unit itself as first item.
# TODO: maybe better for other to have a `_back_converter` method?
if hasattr(other, "equivalencies"):
for funit, tunit, _, b in other.equivalencies:
if other is funit:
try:
converter = self.get_converter(tunit, [])
except Exception: # noqa: BLE001, S110 # pylint: disable=W0718
pass
else:
return b(converter(value))

raise exc # noqa: TRY201


@dispatch # type: ignore[misc]
def uconvert(unit: AstropyUnit, x: AbstractQuantity, /) -> AbstractQuantity:
"""Convert the quantity to the specified units.
Expand Down Expand Up @@ -312,4 +269,4 @@ def uconvert(unit: AstropyUnit, x: AbstractQuantity, /) -> AbstractQuantity:
# if isinstance(x.value, jax.core.Tracer) and not can_convert_unit(x.unit, u):
# return x.value

return replace(x, value=_apy7_unit_to(x.unit, unit, x.value), unit=unit)
return replace(x, value=x.unit.to(unit, x.value), unit=unit)

0 comments on commit 6865587

Please sign in to comment.