diff --git a/keras_core/layers/rnn/dropout_rnn_cell_test.py b/keras_core/layers/rnn/dropout_rnn_cell_test.py
index a73512f75..4062f0aec 100644
--- a/keras_core/layers/rnn/dropout_rnn_cell_test.py
+++ b/keras_core/layers/rnn/dropout_rnn_cell_test.py
@@ -67,19 +67,28 @@ def test_basics(self):
             run_mixed_precision_check=False,
         )
 
-        # custom mixed_float16 check
-        self.run_layer_test(
-            layers.RNN,
-            init_kwargs={
-                "cell": RNNCellWithDropout(5, seed=1337, dtype="mixed_float16"),
-                "dtype": "mixed_float16",
-            },
-            input_shape=(3, 2, 4),
-            call_kwargs={"training": True},
-            expected_output_shape=(3, 5),
-            expected_num_trainable_weights=2,
-            expected_num_non_trainable_weights=0,
-            expected_num_non_trainable_variables=1,
-            supports_masking=True,
-            run_mixed_precision_check=False,
-        )
+        # Custom mixed_float16 check.
+        # Never test mixed precision on torch CPU. Torch lacks support.
+        run_mixed_precision_check = True
+        if backend.backend() == "torch":
+            import torch
+
+            run_mixed_precision_check = torch.cuda.is_available()
+        if run_mixed_precision_check:
+            self.run_layer_test(
+                layers.RNN,
+                init_kwargs={
+                    "cell": RNNCellWithDropout(
+                        5, seed=1337, dtype="mixed_float16"
+                    ),
+                    "dtype": "mixed_float16",
+                },
+                input_shape=(3, 2, 4),
+                call_kwargs={"training": True},
+                expected_output_shape=(3, 5),
+                expected_num_trainable_weights=2,
+                expected_num_non_trainable_weights=0,
+                expected_num_non_trainable_variables=1,
+                supports_masking=True,
+                run_mixed_precision_check=False,
+            )
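
Note: the guard added above keeps the mixed_float16 variant of the test from running on a CPU-only torch install, where float16 compute is unsupported. Below is a minimal standalone sketch of the same pattern, factored into a helper so it can also drive a pytest skip marker; the helper name `mixed_precision_available` and the test name are illustrative, not part of keras_core.

```python
import pytest

from keras_core import backend


def mixed_precision_available():
    """Best-effort check for mixed_float16 support on the current backend."""
    if backend.backend() == "torch":
        # Imported lazily so other backends don't require torch.
        import torch

        # torch only runs float16 compute on CUDA devices, so a CPU-only
        # install cannot exercise a mixed_float16 layer test.
        return torch.cuda.is_available()
    # The diff above only guards the torch backend; other backends are
    # assumed to handle mixed_float16 tests on CPU.
    return True


# Example usage: skip a whole test at collection time instead of
# branching inside its body, as the diff does.
@pytest.mark.skipif(
    not mixed_precision_available(),
    reason="mixed_float16 requires CUDA on the torch backend",
)
def test_rnn_cell_mixed_float16():
    ...
```

Branching inside the test (as in the diff) silently passes on torch CPU, while the skip-marker variant reports the case as skipped; which is preferable depends on how visible the gap in coverage should be.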