Skip to content

Commit c2ca7c3

Browse files
committed
chore: incomplete changes for type hint in numpyro.distribution.transforms
1 parent ddbd0b8 commit c2ca7c3

File tree

4 files changed

+257
-165
lines changed

4 files changed

+257
-165
lines changed

numpyro/_typing.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
from typing import Any, Protocol, runtime_checkable
88

99
try:
10-
from typing import ParamSpec, TypeAlias
10+
from typing import ParamSpec, TypeAlias, Union
1111
except ImportError:
12-
from typing_extensions import ParamSpec, TypeAlias
12+
from typing_extensions import ParamSpec, TypeAlias, Union
13+
14+
import numpy as np
1315

1416
import jax
1517
from jax.typing import ArrayLike
@@ -20,6 +22,9 @@
2022
Message: TypeAlias = dict[str, Any]
2123
TraceT: TypeAlias = OrderedDict[str, Message]
2224

25+
# ArrayLike type has StaticScalar, StrictArrayT has everything except StaticScalars
26+
StrictArrayT: TypeAlias = Union[np.ndarray, jax.Array]
27+
2328

2429
@runtime_checkable
2530
class ConstraintT(Protocol):
@@ -87,20 +92,20 @@ def is_discrete(self) -> bool: ...
8792

8893
@runtime_checkable
8994
class TransformT(Protocol):
90-
domain = ConstraintT
91-
codomain = ConstraintT
95+
domain: ConstraintT = ...
96+
codomain: ConstraintT = ...
9297
_inv: "TransformT" = None
9398

94-
def __call__(self, x: ArrayLike) -> ArrayLike: ...
95-
def _inverse(self, y: ArrayLike) -> ArrayLike: ...
99+
def __call__(self, x: jax.Array) -> jax.Array: ...
100+
def _inverse(self, y: jax.Array) -> jax.Array: ...
96101
def log_abs_det_jacobian(
97-
self, x: ArrayLike, y: ArrayLike, intermediates=None
98-
) -> ArrayLike: ...
99-
def call_with_intermediates(self, x: ArrayLike) -> tuple[ArrayLike, None]: ...
102+
self, x: jax.Array, y: jax.Array, intermediates=None
103+
) -> jax.Array: ...
104+
def call_with_intermediates(self, x: jax.Array) -> tuple[jax.Array, None]: ...
100105
def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ...
101106
def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ...
102107

103108
@property
104109
def inv(self) -> "TransformT": ...
105110
@property
106-
def sign(self) -> ArrayLike: ...
111+
def sign(self) -> jax.Array: ...

numpyro/distributions/constraints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -805,7 +805,7 @@ def tree_flatten(self):
805805
greater_than_eq: ConstraintT = _GreaterThanEq
806806
less_than: ConstraintT = _LessThan
807807
less_than_eq: ConstraintT = _LessThanEq
808-
independent: ConstraintT = _IndependentConstraint
808+
independent = _IndependentConstraint
809809
integer_interval: ConstraintT = _IntegerInterval
810810
integer_greater_than: ConstraintT = _IntegerGreaterThan
811811
interval: ConstraintT = _Interval

0 commit comments

Comments
 (0)