Skip to content

Commit

Permalink
Merged comments from Junpeng
Browse files Browse the repository at this point in the history
  • Loading branch information
AdrienCorenflos committed Sep 16, 2024
1 parent 1cbe9cf commit 35442a3
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions blackjax/mcmc/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def is_turning(
return turning_at_left | turning_at_right

def scale(
position: ArrayLikeTree, element: ArrayLikeTree, inv: bool
position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree
) -> ArrayLikeTree:
"""Scale elements by the mass matrix.
Expand Down Expand Up @@ -279,7 +279,7 @@ def is_turning(
# return turning_at_left | turning_at_right

def scale(
position: ArrayLikeTree, element: ArrayLikeTree, inv: bool
position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree
) -> ArrayLikeTree:
"""Scale elements by the mass matrix.
Expand Down
2 changes: 1 addition & 1 deletion blackjax/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,4 @@ class WelfordAlgorithmState(NamedTuple):

#: JAX Scalar types
Scalar = Union[float, int]
Numeric = Union[Array, Scalar]
Numeric = Union[jax.Array, Scalar]

0 comments on commit 35442a3

Please sign in to comment.