diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 2ab0bc1a..433b385a 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -292,8 +292,6 @@ class AlignedSpin(Prior): """ amax: Float = 0.99 - xmax: Float = 0.99 - xmin: Float = -0.99 chi_axis: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) cdf_vals: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) @@ -310,8 +308,6 @@ def __init__( super().__init__(naming, transforms) assert self.n_dim == 1, "Alignedspin needs to be 1D distributions" self.amax = amax - self.xmax = amax - self.xmin = -amax # build the interpolation table for the ppf of the one-sided distribution chi_axis = jnp.linspace(1e-31, self.amax, num=1000) @@ -319,6 +315,14 @@ def __init__( self.chi_axis = chi_axis self.cdf_vals = cdf_vals + @property + def xmin(self): + return -self.amax + + @property + def xmax(self): + return self.amax + def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: