diff --git a/keras_nlp/src/models/falcon/falcon_transformer_decoder.py b/keras_nlp/src/models/falcon/falcon_transformer_decoder.py
index 301549b8d3..c11f9fdee2 100644
--- a/keras_nlp/src/models/falcon/falcon_transformer_decoder.py
+++ b/keras_nlp/src/models/falcon/falcon_transformer_decoder.py
@@ -133,7 +133,7 @@ def call(
         mask = decoder_padding_mask
         if mask is None:
             batch_size, seq_length = ops.shape(inputs)[:2]
-            mask = ops.ones((batch_size, seq_length), dtype="int32")
+            mask = ops.ones((batch_size, seq_length), dtype="int")
         alibi = self._build_alibi_tensor(self.num_attention_heads, mask)

         # Attention block.
@@ -225,7 +225,7 @@ def _build_alibi_tensor(self, num_heads, attention_mask):
             self._get_slopes(num_heads),
             dtype=self.compute_dtype,
         )  # num_heads
-        attention_mask = ops.cast(attention_mask, dtype="int32")
+        attention_mask = ops.cast(attention_mask, dtype="int")
         arange_tensor = (
             ((ops.cumsum(attention_mask, axis=-1) - 1) * attention_mask)
         )[:, None, :]
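
For context (not part of the patch): a minimal sketch of how the padding mask drives the ALiBi position indices computed in _build_alibi_tensor, using keras.ops. The mask values below are made up for illustration, and the snippet mirrors the arange_tensor expression in the second hunk rather than the library's full implementation.

from keras import ops

# Hypothetical (batch_size=1, seq_length=5) padding mask: last two tokens are padding.
attention_mask = ops.convert_to_tensor([[1, 1, 1, 0, 0]], dtype="int32")

# Cumulative sum minus one yields 0-based token positions; multiplying by the
# mask zeroes out the padded positions, matching the arange_tensor in the diff.
arange_tensor = ((ops.cumsum(attention_mask, axis=-1) - 1) * attention_mask)[
    :, None, :
]
print(ops.convert_to_numpy(arange_tensor))  # [[[0 1 2 0 0]]]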