Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhiyuLi-goog committed Aug 15, 2024
1 parent 007a9e4 commit 4298c36
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion MaxText/attentions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,12 @@ def main(argv: Sequence[str]) -> None:
# empty variable states in pure flash attention
vars = {}
jitted = jax.jit(nn.apply(test_tpu_flash_attention, attention_op))
query, key, value, decoder_segment_ids = next(train_iter)
for step in np.arange(start_step, config.steps):
if step == first_profiling_step:
prof.activate()
with jax.profiler.StepTraceAnnotation("train", step_num=step):
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
query, key, value, decoder_segment_ids = next(train_iter)
out = jitted(vars, query, key, value, decoder_segment_ids)
# jax.block_until_ready(out)
# pdb.set_trace()
Expand Down

0 comments on commit 4298c36

Please sign in to comment.