diff --git a/ivy/functional/frontends/tensorflow/keras/activations.py b/ivy/functional/frontends/tensorflow/keras/activations.py index f80cb02763ea5..4c64d0b04eb5c 100644 --- a/ivy/functional/frontends/tensorflow/keras/activations.py +++ b/ivy/functional/frontends/tensorflow/keras/activations.py @@ -111,6 +111,44 @@ def selu(x): return ivy.selu(x) +@with_supported_dtypes( + {"2.13.0 and below": ("float16", "float32", "float64")}, + "tensorflow", +) +def serialize(activation, use_legacy_format=False): + # If the activation function is None, return None + if activation is None: + return None + + # If the activation function is already a string, return it + elif isinstance(activation, str): + return activation + + # If the activation function is callable (a function), get its name + elif callable(activation): + # Check if the function is in the custom_objects dictionary + if custom_objects: + for name, custom_func in custom_objects.items(): + if custom_func == activation: + return name + + # Check if the function is in the ACTIVATION_FUNCTIONS list + if activation.__name__ in ACTIVATION_FUNCTIONS: + return activation.__name__ + + # Check if the function is in the TensorFlow frontend activations + elif activation in tf_frontend.keras.activations.__dict__.values(): + for name, tf_func in tf_frontend.keras.activations.__dict__.items(): + if tf_func == activation: + return name + + else: + raise ValueError(f"Unknown activation function: {activation}.") + + else: + raise ValueError(f"Could not interpret activation function: {activation}") + + @to_ivy_arrays_and_back def sigmoid(x): return ivy.sigmoid(x) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_keras/test_activations.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_keras/test_activations.py index da5def824a5c7..8782a884a2ed9 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_keras/test_activations.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_keras/test_activations.py @@ -342,6 +342,50 @@ def test_tensorflow_selu( ) +# serialize +@handle_frontend_test( + fn_tree="tensorflow.keras.activations.serialize", + fn_name=st.sampled_from(get_callable_functions("keras.activations")).filter( + lambda x: not x[0].isupper() + and x + not in [ + "deserialize", + "get", + "keras_export", + "serialize", + "deserialize_keras_object", + "serialize_keras_object", + "get_globals", + ] + ), + dtype_and_data=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_value=0, + max_value=10, + shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])), + ), +) +def test_tensorflow_serialize( + *, + dtype_and_data, + fn_name, + fn_tree, + frontend, +): + dtype_data, data = dtype_and_data + simple_test_two_function( + fn_name=fn_name, + x=data[0], + frontend=frontend, + fn_str="serialize", + dtype_data=dtype_data[0], + rtol_=1e-01, + atol_=1e-01, + ivy_submodules=["keras", "activations"], + framework_submodules=["keras", "activations"], + ) + + # sigmoid @handle_frontend_test( fn_tree="tensorflow.keras.activations.sigmoid",