Skip to content

Commit 3b01836

Browse files
committed
refactor: update type hints and improve type safety across transforms and constraints
1 parent c2ca7c3 commit 3b01836

File tree

3 files changed

+144
-123
lines changed

3 files changed

+144
-123
lines changed

numpyro/_typing.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44

55
from collections import OrderedDict
66
from collections.abc import Callable
7-
from typing import Any, Protocol, runtime_checkable
7+
from typing import Any, Optional, Protocol, Union, runtime_checkable
88

99
try:
10-
from typing import ParamSpec, TypeAlias, Union
10+
from typing import ParamSpec, TypeAlias
1111
except ImportError:
12-
from typing_extensions import ParamSpec, TypeAlias, Union
12+
from typing_extensions import ParamSpec, TypeAlias
1313

1414
import numpy as np
1515

@@ -23,7 +23,7 @@
2323
TraceT: TypeAlias = OrderedDict[str, Message]
2424

2525
# ArrayLike type has StaticScalar, StrictArrayT has everything except StaticScalars
26-
StrictArrayT: TypeAlias = Union[np.ndarray, jax.Array]
26+
StrictArrayT = Union[np.ndarray, jax.Array]
2727

2828

2929
@runtime_checkable
@@ -94,18 +94,28 @@ def is_discrete(self) -> bool: ...
9494
class TransformT(Protocol):
9595
domain: ConstraintT = ...
9696
codomain: ConstraintT = ...
97-
_inv: "TransformT" = None
97+
_inv: Optional["TransformT"] = ...
9898

99-
def __call__(self, x: jax.Array) -> jax.Array: ...
100-
def _inverse(self, y: jax.Array) -> jax.Array: ...
99+
def __call__(self, x: Union[jax.Array, Any]) -> Union[jax.Array, Any]: ...
100+
def _inverse(self, y: Union[jax.Array, Any]) -> Union[jax.Array, Any]: ...
101101
def log_abs_det_jacobian(
102-
self, x: jax.Array, y: jax.Array, intermediates=None
103-
) -> jax.Array: ...
104-
def call_with_intermediates(self, x: jax.Array) -> tuple[jax.Array, None]: ...
102+
self,
103+
x: Union[jax.Array, Any],
104+
y: Union[jax.Array, Any],
105+
intermediates: Optional[Any] = None,
106+
) -> Union[jax.Array, Any]: ...
107+
def call_with_intermediates(
108+
self, x: Union[jax.Array, Optional[Any]]
109+
) -> tuple[Union[jax.Array, Any], Any]: ...
105110
def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ...
106111
def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ...
107112

108113
@property
109114
def inv(self) -> "TransformT": ...
110115
@property
111-
def sign(self) -> jax.Array: ...
116+
def sign(self) -> Union[ArrayLike, Any]: ...
117+
118+
119+
class UnusedParam(object):
120+
def __repr__(self):
121+
return "UnusedParam"

numpyro/distributions/constraints.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -801,18 +801,18 @@ def tree_flatten(self):
801801
corr_cholesky: ConstraintT = _CorrCholesky()
802802
corr_matrix: ConstraintT = _CorrMatrix()
803803
dependent: ConstraintT = _Dependent()
804-
greater_than: ConstraintT = _GreaterThan
805-
greater_than_eq: ConstraintT = _GreaterThanEq
806-
less_than: ConstraintT = _LessThan
807-
less_than_eq: ConstraintT = _LessThanEq
804+
greater_than = _GreaterThan
805+
greater_than_eq = _GreaterThanEq
806+
less_than = _LessThan
807+
less_than_eq = _LessThanEq
808808
independent = _IndependentConstraint
809-
integer_interval: ConstraintT = _IntegerInterval
810-
integer_greater_than: ConstraintT = _IntegerGreaterThan
811-
interval: ConstraintT = _Interval
809+
integer_interval = _IntegerInterval
810+
integer_greater_than = _IntegerGreaterThan
811+
interval = _Interval
812812
l1_ball: ConstraintT = _L1Ball()
813813
lower_cholesky: ConstraintT = _LowerCholesky()
814814
scaled_unit_lower_cholesky: ConstraintT = _ScaledUnitLowerCholesky()
815-
multinomial: ConstraintT = _Multinomial
815+
multinomial = _Multinomial
816816
nonnegative: ConstraintT = _Nonnegative()
817817
nonnegative_integer: ConstraintT = _IntegerNonnegative()
818818
ordered_vector: ConstraintT = _OrderedVector()
@@ -830,5 +830,5 @@ def tree_flatten(self):
830830
softplus_positive: ConstraintT = _SoftplusPositive()
831831
sphere: ConstraintT = _Sphere()
832832
unit_interval: ConstraintT = _UnitInterval()
833-
open_interval: ConstraintT = _OpenInterval
834-
zero_sum: ConstraintT = _ZeroSum
833+
open_interval = _OpenInterval
834+
zero_sum = _ZeroSum

0 commit comments

Comments
 (0)