diff --git a/keras/src/backend/jax/export.py b/keras/src/backend/jax/export.py
index 963648460dcd..dd754c144184 100644
--- a/keras/src/backend/jax/export.py
+++ b/keras/src/backend/jax/export.py
@@ -119,7 +119,7 @@ def stateful_fn(*args, **kwargs):
                 self._tf_trackable.non_trainable_variables,
                 non_trainable_variables,
             ):
-                var.assign(new_value)
+                var.assign(tf.cast(new_value, var.dtype))
             return output

         stateful_fn.__signature__ = inspect.Signature(
diff --git a/keras/src/utils/jax_layer.py b/keras/src/utils/jax_layer.py
index 8fd69d1f5bf8..86093125ed17 100644
--- a/keras/src/utils/jax_layer.py
+++ b/keras/src/utils/jax_layer.py
@@ -5,6 +5,8 @@
 from keras.src import backend
 from keras.src import tree
 from keras.src.api_export import keras_export
+from keras.src.backend.common.variables import is_float_dtype
+from keras.src.backend.common.variables import standardize_dtype
 from keras.src.layers.layer import Layer
 from keras.src.saving import serialization_lib
 from keras.src.utils import jax_utils
@@ -291,18 +293,28 @@ def _create_variables(self, values, trainable):
         """

         def create_variable(value):
-            if backend.is_tensor(value) or isinstance(value, np.ndarray):
-                variable = self.add_weight(
-                    value.shape, initializer="zeros", trainable=trainable
+            if backend.is_tensor(value) or isinstance(
+                value, (np.ndarray, np.generic)
+            ):
+                dtype = value.dtype
+                if is_float_dtype(dtype):
+                    dtype = None  # Use the layer dtype policy
+                return self.add_weight(
+                    value.shape,
+                    initializer=value,
+                    dtype=dtype,
+                    trainable=trainable,
                 )
-                variable.assign(value)
-                return variable
-            elif isinstance(value, (np.generic, int, float)):
-                variable = self.add_weight(
-                    (), initializer="zeros", trainable=trainable
+            elif isinstance(value, (bool, int, float)):
+                dtype = standardize_dtype(type(value))
+                if is_float_dtype(dtype):
+                    dtype = None  # Use the layer dtype policy
+                return self.add_weight(
+                    (),
+                    initializer=backend.convert_to_tensor(value),
+                    dtype=dtype,
+                    trainable=trainable,
                 )
-                variable.assign(value)
-                return variable
             else:
                 return value

diff --git a/keras/src/utils/jax_layer_test.py b/keras/src/utils/jax_layer_test.py
index 359bdca41c9c..306c930660f6 100644
--- a/keras/src/utils/jax_layer_test.py
+++ b/keras/src/utils/jax_layer_test.py
@@ -15,6 +15,7 @@
 from keras.src import testing
 from keras.src import tree
 from keras.src import utils
+from keras.src.dtype_policies.dtype_policy import DTypePolicy
 from keras.src.saving import object_registration
 from keras.src.utils.jax_layer import FlaxLayer
 from keras.src.utils.jax_layer import JaxLayer
@@ -362,6 +363,18 @@ def call(self, inputs):
             "non_trainable_weights": 1,
             "non_trainable_params": 1,
         },
+        {
+            "testcase_name": "training_state_dtype_policy",
+            "init_kwargs": {
+                "call_fn": jax_stateful_apply,
+                "init_fn": jax_stateful_init,
+                "dtype": DTypePolicy("mixed_float16"),
+            },
+            "trainable_weights": 6,
+            "trainable_params": 266610,
+            "non_trainable_weights": 1,
+            "non_trainable_params": 1,
+        },
     )
     def test_jax_layer(
         self,
@@ -414,6 +427,19 @@
             "non_trainable_weights": 8,
             "non_trainable_params": 536,
         },
+        {
+            "testcase_name": "training_rng_unbound_method_dtype_policy",
+            "flax_model_class": "FlaxDropoutModel",
+            "flax_model_method": None,
+            "init_kwargs": {
+                "method": "flax_dropout_wrapper",
+                "dtype": DTypePolicy("mixed_float16"),
+            },
+            "trainable_weights": 8,
+            "trainable_params": 648226,
+            "non_trainable_weights": 0,
+            "non_trainable_params": 0,
+        },
     )
     @pytest.mark.skipif(flax is None, reason="Flax library is not available.")
     def test_flax_layer(
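
Usage note (not part of the patch): with these changes, float params and state created by JaxLayer/FlaxLayer follow the layer's dtype policy (add_weight receives dtype=None for float values), while bool/int values keep an explicit standardized dtype, and the exported stateful_fn casts updated non-trainable values back to each variable's dtype before assigning. Below is a minimal sketch assuming the JAX backend; my_init_fn and my_call_fn are hypothetical stand-ins, not functions from this patch.

import jax.numpy as jnp
import keras

def my_init_fn(rng, inputs):
    # Returns the params pytree. Float leaves are now stored using the
    # policy's variable dtype (float32 under "mixed_float16").
    return {"w": jnp.zeros((4, 2), dtype=jnp.float32)}

def my_call_fn(params, inputs):
    return inputs @ params["w"]

layer = keras.layers.JaxLayer(
    call_fn=my_call_fn,
    init_fn=my_init_fn,
    # The dtype policy now propagates to the variables created from params.
    dtype=keras.dtype_policies.DTypePolicy("mixed_float16"),
)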