diff --git a/MaxText/layers.py b/MaxText/layers.py
index c5e5c97e5..5bb411898 100644
--- a/MaxText/layers.py
+++ b/MaxText/layers.py
@@ -368,8 +368,8 @@ def wrap_flash_attention(query, key, value, decoder_segment_ids):
   def __call__(self,
                inputs_q: Array,
                inputs_kv: Array,
+               enable_flash_attention,
                decoder_segment_ids = None,
-               enable_flash_attention = True,
                inputs_positions:Optional[Array] = None,
                mask: Optional[Array] = None,
                bias: Optional[Array] = None,
@@ -1046,6 +1046,7 @@ def __call__(self,
       mesh = mesh)(
         lnx,
         lnx,
+        enable_flash_attention=cfg.enable_flash_attention,
         decoder_segment_ids=decoder_segment_ids,
         inputs_positions=decoder_positions,
         mask=decoder_mask,
diff --git a/MaxText/tests/attention_test.py b/MaxText/tests/attention_test.py
index 330218770..2d01706bf 100644
--- a/MaxText/tests/attention_test.py
+++ b/MaxText/tests/attention_test.py
@@ -54,7 +54,7 @@ def setUp(self):
                               config=self.cfg,
                               mesh = self.mesh)
     self.variable = self.attention.init({'params': self.rng, 'aqt': self.rng},
                                         jnp.ones((self.BS, self.MAX_TARGET_LENGTH, self.BASE_EMB_DIM)),
-                                        jnp.ones((self.BS, self.MAX_TARGET_LENGTH, self.BASE_EMB_DIM)), jnp.ones((self.BS, self.MAX_TARGET_LENGTH)))
+                                        jnp.ones((self.BS, self.MAX_TARGET_LENGTH, self.BASE_EMB_DIM)), False)

   def get_decoder_mask(self):
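
The patch above makes enable_flash_attention a required argument of the attention layer's __call__ (the True default is removed), threads it through from cfg.enable_flash_attention at the decoder call site, and updates the test to pass False positionally. A minimal sketch below illustrates the resulting calling convention; ToyAttention is a hypothetical single-head module, not MaxText's attention class, and the flash-attention dispatch itself is omitted.

import jax
import jax.numpy as jnp
import flax.linen as nn


class ToyAttention(nn.Module):
  """Toy single-head attention; hypothetical, not MaxText's implementation."""
  features: int = 64

  @nn.compact
  def __call__(self,
               inputs_q,
               inputs_kv,
               enable_flash_attention,   # required: no default, callers pass it from config
               decoder_segment_ids=None):
    q = nn.Dense(self.features, name='query')(inputs_q)
    k = nn.Dense(self.features, name='key')(inputs_kv)
    v = nn.Dense(self.features, name='value')(inputs_kv)
    # In MaxText the flag would select a fused flash-attention kernel; this toy
    # sketch only has the reference einsum path, so the flag (and segment ids)
    # do not change the computation here.
    del enable_flash_attention, decoder_segment_ids
    scores = jnp.einsum('bqd,bkd->bqk', q, k) / jnp.sqrt(self.features)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum('bqk,bkd->bqd', weights, v)


# Callers now supply the flag explicitly, mirroring the updated test (positional
# False) and the decoder call site (keyword argument taken from config):
rng = jax.random.PRNGKey(0)
x = jnp.ones((2, 16, 64))
layer = ToyAttention()
variables = layer.init(rng, x, x, False)
out = layer.apply(variables, x, x, enable_flash_attention=False)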