Skip to content

Commit

Permalink
Torch cannot test mixed precision on CPU
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 committed Sep 21, 2023
1 parent 77fca18 commit 5775439
Showing 1 changed file with 24 additions and 16 deletions.
40 changes: 24 additions & 16 deletions keras_core/layers/rnn/dropout_rnn_cell_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,27 @@ def test_basics(self):
run_mixed_precision_check=False,
)

# custom mixed_float16 check
self.run_layer_test(
layers.RNN,
init_kwargs={
"cell": RNNCellWithDropout(5, seed=1337, dtype="mixed_float16"),
"dtype": "mixed_float16",
},
input_shape=(3, 2, 4),
call_kwargs={"training": True},
expected_output_shape=(3, 5),
expected_num_trainable_weights=2,
expected_num_non_trainable_weights=0,
expected_num_non_trainable_variables=1,
supports_masking=True,
run_mixed_precision_check=False,
)
# Custom mixed_float16 check
# Never test mixed precision on torch CPU. Torch lacks support.
if backend.backend() == "torch":
import torch

run_mixed_precision_check = torch.cuda.is_available()
if run_mixed_precision_check:
self.run_layer_test(
layers.RNN,
init_kwargs={
"cell": RNNCellWithDropout(
5, seed=1337, dtype="mixed_float16"
),
"dtype": "mixed_float16",
},
input_shape=(3, 2, 4),
call_kwargs={"training": True},
expected_output_shape=(3, 5),
expected_num_trainable_weights=2,
expected_num_non_trainable_weights=0,
expected_num_non_trainable_variables=1,
supports_masking=True,
run_mixed_precision_check=False,
)

0 comments on commit 5775439

Please sign in to comment.