
Add support for dtype / DTypePolicy to JaxLayer and FlaxLayer.
hertschuh committed Jan 6, 2025
1 parent 1adaaec · commit 0a6a980
Showing 3 changed files with 49 additions and 11 deletions.
keras/src/backend/jax/export.py — 2 changes: 1 addition & 1 deletion
@@ -119,7 +119,7 @@ def stateful_fn(*args, **kwargs):
                 self._tf_trackable.non_trainable_variables,
                 non_trainable_variables,
             ):
-                var.assign(new_value)
+                var.assign(tf.cast(new_value, var.dtype))
             return output
 
         stateful_fn.__signature__ = inspect.Signature(
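Why the cast matters: under a mixed-precision dtype policy, the state returned by the JAX computation can come back as float16 while the tracked tf.Variable keeps float32 storage, and a plain assign rejects the dtype mismatch. A minimal standalone sketch of the failure mode and the fix (illustrative TensorFlow only, not this module's code):

    import tensorflow as tf

    # Storage dtype stays float32; the computed update arrives as float16.
    var = tf.Variable(tf.zeros((2,), dtype=tf.float32))
    new_value = tf.constant([0.5, 1.5], dtype=tf.float16)

    # var.assign(new_value) would fail with a dtype mismatch;
    # casting to the variable's own dtype makes the assignment valid.
    var.assign(tf.cast(new_value, var.dtype))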
keras/src/utils/jax_layer.py — 32 changes: 22 additions & 10 deletions
@@ -5,6 +5,8 @@
 from keras.src import backend
 from keras.src import tree
 from keras.src.api_export import keras_export
+from keras.src.backend.common.variables import is_float_dtype
+from keras.src.backend.common.variables import standardize_dtype
 from keras.src.layers.layer import Layer
 from keras.src.saving import serialization_lib
 from keras.src.utils import jax_utils
@@ -291,18 +293,28 @@ def _create_variables(self, values, trainable):
         """
 
         def create_variable(value):
-            if backend.is_tensor(value) or isinstance(value, np.ndarray):
-                variable = self.add_weight(
-                    value.shape, initializer="zeros", trainable=trainable
+            if backend.is_tensor(value) or isinstance(
+                value, (np.ndarray, np.generic)
+            ):
+                dtype = value.dtype
+                if is_float_dtype(dtype):
+                    dtype = None  # Use the layer dtype policy
+                return self.add_weight(
+                    value.shape,
+                    initializer=value,
+                    dtype=dtype,
+                    trainable=trainable,
                 )
-                variable.assign(value)
-                return variable
-            elif isinstance(value, (np.generic, int, float)):
-                variable = self.add_weight(
-                    (), initializer="zeros", trainable=trainable
+            elif isinstance(value, (bool, int, float)):
+                dtype = standardize_dtype(type(value))
+                if is_float_dtype(dtype):
+                    dtype = None  # Use the layer dtype policy
+                return self.add_weight(
+                    (),
+                    initializer=backend.convert_to_tensor(value),
+                    dtype=dtype,
+                    trainable=trainable,
                 )
-                variable.assign(value)
-                return variable
             else:
                 return value
 
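What the new dtype handling buys, sketched with a plain Keras layer (a hypothetical Demo layer, not code from this commit): float weights defer to the layer's dtype policy (dtype=None resolves to the policy's variable dtype), while non-float state such as integer counters keeps its concrete dtype.

    import keras
    from keras import layers

    class Demo(layers.Layer):
        def build(self, input_shape):
            # dtype=None -> defer to the layer dtype policy (float path)
            self.w = self.add_weight(
                shape=(3,), initializer="zeros", dtype=None
            )
            # Non-float state keeps its own dtype (integer path)
            self.step = self.add_weight(
                shape=(), initializer="zeros", dtype="int32", trainable=False
            )

    layer = Demo(dtype=keras.dtype_policies.DTypePolicy("mixed_float16"))
    layer.build(None)
    print(layer.w.dtype)     # float32, the policy's variable dtype
    print(layer.step.dtype)  # int32, unaffected by the policy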
keras/src/utils/jax_layer_test.py — 26 changes: 26 additions & 0 deletions
@@ -15,6 +15,7 @@
 from keras.src import testing
 from keras.src import tree
 from keras.src import utils
+from keras.src.dtype_policies.dtype_policy import DTypePolicy
 from keras.src.saving import object_registration
 from keras.src.utils.jax_layer import FlaxLayer
 from keras.src.utils.jax_layer import JaxLayer
@@ -362,6 +363,18 @@ def call(self, inputs):
             "non_trainable_weights": 1,
             "non_trainable_params": 1,
         },
+        {
+            "testcase_name": "training_state_dtype_policy",
+            "init_kwargs": {
+                "call_fn": jax_stateful_apply,
+                "init_fn": jax_stateful_init,
+                "dtype": DTypePolicy("mixed_float16"),
+            },
+            "trainable_weights": 6,
+            "trainable_params": 266610,
+            "non_trainable_weights": 1,
+            "non_trainable_params": 1,
+        },
     )
     def test_jax_layer(
         self,
@@ -414,6 +427,19 @@ def test_jax_layer(
             "non_trainable_weights": 8,
             "non_trainable_params": 536,
         },
+        {
+            "testcase_name": "training_rng_unbound_method_dtype_policy",
+            "flax_model_class": "FlaxDropoutModel",
+            "flax_model_method": None,
+            "init_kwargs": {
+                "method": "flax_dropout_wrapper",
+                "dtype": DTypePolicy("mixed_float16"),
+            },
+            "trainable_weights": 8,
+            "trainable_params": 648226,
+            "non_trainable_weights": 0,
+            "non_trainable_params": 0,
+        },
     )
     @pytest.mark.skipif(flax is None, reason="Flax library is not available.")
     def test_flax_layer(
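The new test cases exercise passing a DTypePolicy straight through the layer constructor. A minimal usage sketch (hypothetical stateless functions and shapes, assumes the JAX backend; the tests use larger real models):

    import jax.numpy as jnp
    import keras
    from keras.layers import JaxLayer

    def my_init(rng, inputs):
        # Hypothetical: a single dense kernel, created in float32.
        return [jnp.zeros((inputs.shape[-1], 4), dtype="float32")]

    def my_call(params, inputs):
        return inputs @ params[0]

    # With this change, the policy reaches the variables created from
    # init_fn's output: float params follow the policy's variable dtype.
    layer = JaxLayer(
        call_fn=my_call,
        init_fn=my_init,
        dtype=keras.dtype_policies.DTypePolicy("mixed_float16"),
    )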
