From 7736f67cfb4fa847bf53dbba2ce4bf73aace3adb Mon Sep 17 00:00:00 2001 From: Surya <116063290+SuryanarayanaY@users.noreply.github.com> Date: Wed, 31 Jan 2024 11:15:44 +0530 Subject: [PATCH] Fix bug in broadcast_to Op with Jax backend when keras variable as input (#19118) * Fix bug in broadcast_to Op with Jax backend when keras variable as input * Fix format error * convert_to_numpy chnaged to convert_to_tensor * Removed duplicate import of convert_to_tensor --- keras/backend/jax/numpy.py | 1 + 1 file changed, 1 insertion(+) diff --git a/keras/backend/jax/numpy.py b/keras/backend/jax/numpy.py index ac9d10ae367..0b062489065 100644 --- a/keras/backend/jax/numpy.py +++ b/keras/backend/jax/numpy.py @@ -322,6 +322,7 @@ def average(x, axis=None, weights=None): def broadcast_to(x, shape): + x = convert_to_tensor(x) return jnp.broadcast_to(x, shape)