From 5efabf539a4eac8288e563c1a3c00854b8d5c8fb Mon Sep 17 00:00:00 2001 From: Sayed Qaiser Ali Date: Thu, 21 Sep 2023 16:35:27 +0530 Subject: [PATCH] corrected jax expression, numpy function --- keras_core/backend/jax/math.py | 2 +- keras_core/backend/numpy/math.py | 2 +- keras_core/ops/math_test.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/keras_core/backend/jax/math.py b/keras_core/backend/jax/math.py index 7ff0a729e..efafb6e91 100644 --- a/keras_core/backend/jax/math.py +++ b/keras_core/backend/jax/math.py @@ -251,4 +251,4 @@ def rsqrt(x): def erf(x): - return jnp.erf(x) + return jax.lax.erf(x) diff --git a/keras_core/backend/numpy/math.py b/keras_core/backend/numpy/math.py index d7e6fd3c4..ce8cdb4d2 100644 --- a/keras_core/backend/numpy/math.py +++ b/keras_core/backend/numpy/math.py @@ -305,4 +305,4 @@ def rsqrt(x): def erf(x): - return scipy.special.erf(x) + return np.array(scipy.special.erf(x)) diff --git a/keras_core/ops/math_test.py b/keras_core/ops/math_test.py index 14324cbe3..1b5ffcb2c 100644 --- a/keras_core/ops/math_test.py +++ b/keras_core/ops/math_test.py @@ -846,7 +846,7 @@ def test_erf_operation_basic(self): ) # Output from the erf operation in keras_core - output_from_erf_op = kmath.erf(sample_values).numpy() + output_from_erf_op = kmath.erf(sample_values) # Assert that the outputs are close self.assertAllClose(expected_output, output_from_erf_op, atol=1e-5) @@ -860,7 +860,7 @@ def test_erf_operation_dtype(self): expected_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)( sample_values ) - output_from_erf_op = kmath.erf(sample_values).numpy() + output_from_erf_op = kmath.erf(sample_values) self.assertAllClose(expected_output, output_from_erf_op, atol=1e-5) def test_erf_operation_edge_cases(self): @@ -869,7 +869,7 @@ def test_erf_operation_edge_cases(self): expected_edge_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)( edge_values ) - output_from_edge_erf_op = kmath.erf(edge_values).numpy() + output_from_edge_erf_op = kmath.erf(edge_values) self.assertAllClose( expected_edge_output, output_from_edge_erf_op, atol=1e-5 )