Skip to content

Commit

Permalink
Add missing convert_to_tensor in broadcast_to for NumPy/JAX backends.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 607793016
  • Loading branch information
jburnim authored and tensorflower-gardener committed Feb 16, 2024
1 parent b24cdb2 commit 6467548
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tensorflow_probability/python/internal/backend/numpy/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,8 @@ def __init__(self, *args, **kwargs):

broadcast_to = utils.copy_docstring(
'tf.broadcast_to',
lambda input, shape, name=None: np.broadcast_to(input, shape))
lambda input, shape, name=None: np.broadcast_to(
_convert_to_tensor(input), shape))


def _cast(x, dtype):
Expand Down

0 comments on commit 6467548

Please sign in to comment.