From 35442a3c08e53907481658e7912deb145c79dd0f Mon Sep 17 00:00:00 2001 From: Adrien Corenflos Date: Mon, 16 Sep 2024 20:19:31 +0100 Subject: [PATCH] Merged comments from Junpeng --- blackjax/mcmc/metrics.py | 4 ++-- blackjax/types.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/blackjax/mcmc/metrics.py b/blackjax/mcmc/metrics.py index 678b5cc37..4e079714b 100644 --- a/blackjax/mcmc/metrics.py +++ b/blackjax/mcmc/metrics.py @@ -187,7 +187,7 @@ def is_turning( return turning_at_left | turning_at_right def scale( - position: ArrayLikeTree, element: ArrayLikeTree, inv: bool + position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree ) -> ArrayLikeTree: """Scale elements by the mass matrix. @@ -279,7 +279,7 @@ def is_turning( # return turning_at_left | turning_at_right def scale( - position: ArrayLikeTree, element: ArrayLikeTree, inv: bool + position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree ) -> ArrayLikeTree: """Scale elements by the mass matrix. diff --git a/blackjax/types.py b/blackjax/types.py index be73b0d29..4b23fcd22 100644 --- a/blackjax/types.py +++ b/blackjax/types.py @@ -46,4 +46,4 @@ class WelfordAlgorithmState(NamedTuple): #: JAX Scalar types Scalar = Union[float, int] -Numeric = Union[Array, Scalar] +Numeric = Union[jax.Array, Scalar]