Skip to content

Commit

Permalink
Merge pull request #262 from google:mohit/fix_flash_flag
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 583172562
  • Loading branch information
maxtext authors committed Nov 16, 2023
2 parents 82ab361 + 4a47db8 commit da44730
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
3 changes: 2 additions & 1 deletion MaxText/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,8 @@ def wrap_flash_attention(query, key, value, decoder_segment_ids):
def __call__(self,
inputs_q: Array,
inputs_kv: Array,
enable_flash_attention,
decoder_segment_ids = None,
enable_flash_attention = True,
inputs_positions:Optional[Array] = None,
mask: Optional[Array] = None,
bias: Optional[Array] = None,
Expand Down Expand Up @@ -1046,6 +1046,7 @@ def __call__(self,
mesh = mesh)(
lnx,
lnx,
enable_flash_attention=cfg.enable_flash_attention,
decoder_segment_ids=decoder_segment_ids,
inputs_positions=decoder_positions,
mask=decoder_mask,
Expand Down
2 changes: 1 addition & 1 deletion MaxText/tests/attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def setUp(self):
config=self.cfg,
mesh = self.mesh)
self.variable = self.attention.init({'params': self.rng, 'aqt': self.rng}, jnp.ones((self.BS, self.MAX_TARGET_LENGTH, self.BASE_EMB_DIM)),
jnp.ones((self.BS, self.MAX_TARGET_LENGTH, self.BASE_EMB_DIM)), jnp.ones((self.BS, self.MAX_TARGET_LENGTH)))
jnp.ones((self.BS, self.MAX_TARGET_LENGTH, self.BASE_EMB_DIM)), False)


def get_decoder_mask(self):
Expand Down

0 comments on commit da44730

Please sign in to comment.