|
2 | 2 | import jax.numpy as jnp
|
3 | 3 | from flowMC.nfmodel.base import Distribution
|
4 | 4 | from jaxtyping import Array, Float
|
5 |
| -from typing import Callable |
| 5 | +from typing import Callable, Union |
6 | 6 | from dataclasses import field
|
7 | 7 |
|
8 | 8 |
|
@@ -352,16 +352,16 @@ def __init__(
|
352 | 352 | self,
|
353 | 353 | xmin: float,
|
354 | 354 | xmax: float,
|
355 |
| - alpha: float, |
| 355 | + alpha: Union[int, float], |
356 | 356 | naming: list[str],
|
357 | 357 | transforms: dict[tuple[str, Callable]] = {},
|
358 | 358 | ):
|
359 | 359 | super().__init__(naming, transforms)
|
360 | 360 | assert isinstance(xmin, float), "xmin must be a float"
|
361 | 361 | assert isinstance(xmax, float), "xmax must be a float"
|
362 |
| - assert isinstance(alpha, (float)), "alpha must be a float" |
| 362 | + assert isinstance(alpha, (int, float)), "alpha must be a int or a float" |
363 | 363 | if alpha < 0.0:
|
364 |
| - assert alpha < 0.0 or xmin > 0.0, "With negative alpha, xmin must > 0" |
| 364 | + assert xmin > 0.0, "With negative alpha, xmin must > 0" |
365 | 365 | assert self.n_dim == 1, "Powerlaw needs to be 1D distributions"
|
366 | 366 | self.xmax = xmax
|
367 | 367 | self.xmin = xmin
|
|
0 commit comments