Unxt provides unitful quantities and calculations in JAX, built on Equinox and Quax.
Unxt supports JAX's compelling features:
- JIT compilation (jit)
- vectorization (vmap, etc.)
- auto-differentiation (grad, jacobian, hessian)
- GPU/TPU/multi-host acceleration
And best of all, unxt doesn't force you to use special unit-compatible
re-exports of JAX libraries. You can use unxt with existing JAX code, and with
quax's simple decorator, JAX will work with unxt.Quantity.
pip install unxt

Using uv:
uv add unxt

From source, using pip:
pip install git+https://github.com/GalacticDynamics/unxt.git

Building from source:
cd /path/to/parent
git clone https://github.com/GalacticDynamics/unxt.git
cd unxt
pip install -e .  # editable mode

import unxt as u
import jax.numpy as jnp
x = u.Quantity(jnp.arange(1, 5, dtype=float), "km")
print(x)
# Quantity['length']([1., 2., 3., 4.], unit='km')

The constituent value and unit are accessible as attributes:
print(repr(x.value))
# Array([1., 2., 3., 4.], dtype=float64)  (float32 unless jax_enable_x64 is set)
print(repr(x.unit))
# Unit("km")

Quantity objects obey the rules of unitful arithmetic.
# Addition / Subtraction
print(x + x)
# Quantity['length']([2., 4., 6., 8.], unit='km')
# Multiplication / Division
print(2 * x)
# Quantity['length']([2., 4., 6., 8.], unit='km')
y = u.Quantity(jnp.arange(4, 8, dtype=float), "yr")
print(x / y)
# Quantity['speed']([0.25, 0.4 , 0.5 , 0.57142857], unit='km / yr')
# Exponentiation
print(x**2)
# Quantity['area']([ 1., 4., 9., 16.], unit='km2')
# Unit checking on operations
try:
x + y
except Exception as e:
print(e)
# 'yr' (time) and 'km' (length) are not convertible

Quantities can be converted to different units:
print(u.uconvert("m", x)) # via function
# Quantity['length']([1000., 2000., 3000., 4000.], unit='m')
print(x.uconvert("m")) # via method
# Quantity['length']([1000., 2000., 3000., 4000.], unit='m')

Since Quantity is parametric, it can do runtime dimension checking!
LengthQuantity = u.Quantity["length"]
print(LengthQuantity(2, "km"))
# Quantity['length'](2, unit='km')
try:
LengthQuantity(2, "s")
except ValueError as e:
print(e)
# Physical type mismatch.

unxt is built on quax, which enables custom array-ish objects in JAX. For convenience we use the quaxed library, which is just a quax.quaxify wrapper around jax to avoid boilerplate code.
Note
Using quaxed is optional. You can directly use quaxify, and even
apply it to the top-level function instead of individual functions.
from quaxed import grad, vmap
import quaxed.numpy as jnp
print(jnp.square(x))
# Quantity['area']([ 1., 4., 9., 16.], unit='km2')
print(jnp.power(x, 3))
# Quantity['volume']([ 1., 8., 27., 64.], unit='km3')
print(vmap(grad(lambda x: x**3))(x))
# Quantity['area']([ 3., 12., 27., 48.], unit='km2')

See the documentation for more examples and details of JIT and AD.
If you found this library to be useful and want to support the development and maintenance of lower-level code libraries for the scientific community, please consider citing this work.
We welcome contributions! Contributions are how open source projects improve and grow.
To contribute to unxt, please fork the repository, make a development branch, develop on that branch, then open a pull request from the branch in your fork to main.
To report bugs, request features, or suggest other ideas, please open an issue.
For more information, see CONTRIBUTING.md.