Skip to content

Commit d57d9a6

Browse files
committed
fix errors
1 parent 3b01836 commit d57d9a6

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

numpyro/distributions/transforms.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44

55
import math
6-
from typing import Any, Optional, Sequence, Tuple, Union
6+
from typing import Any, Optional, Sequence, Tuple, Union, cast
77
import warnings
88
import weakref
99

@@ -79,10 +79,10 @@ def inv(self: TransformT) -> TransformT:
7979
# TODO: can not understand the implementation (type wise)
8080
inv = None
8181
if self._inv is not None:
82-
inv = self._inv()
82+
inv = self._inv
8383
if inv is None:
84-
inv: TransformT = _InverseTransform(self)
85-
self._inv: TransformT = weakref.ref(inv)
84+
inv = cast(TransformT, _InverseTransform(self))
85+
self._inv = cast(TransformT, weakref.ref(inv))
8686
return inv
8787

8888
def __call__(self, x: Union[Array, Any]) -> Union[Array, Any]:
@@ -1547,7 +1547,7 @@ def __call__(self, x: StrictArrayT) -> StrictArrayT:
15471547
n_imag = n - n_real
15481548
complex_dtype = jnp.result_type(x.dtype, jnp.complex64)
15491549
return (
1550-
x[..., :n_real]
1550+
jnp.asarray(x)[..., :n_real]
15511551
.astype(complex_dtype)
15521552
.at[..., 1 : 1 + n_imag]
15531553
.add(1j * x[..., n_real:])

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ doctest_optionflags = [
116116
[tool.mypy]
117117
ignore_errors = true
118118
ignore_missing_imports = true
119-
plugins = ["numpy.typing.mypy_plugin"]
120119

121120
[[tool.mypy.overrides]]
122121
module = [

0 commit comments

Comments
 (0)