Causal Mask HLO for jnp.tril(jnp.ones()) can be simplified #9709
-
This looks like a useful optimisation to me. However, I think it could also be possible to pre-optimize this on the JAX side; maybe you can try
instead of converting to bool at the end. Still, ideally the compiler should be able to figure out that this can be simplified.
-
A more direct way to generate this mask is mask = jnp.tri(seq_len, dtype=bool). That results in more or less the HLO you mentioned.
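To make the comparison concrete, here is a minimal sketch of three ways to build the same lower-triangular boolean mask. The value of seq_len and the variable names are assumptions for illustration, and the third variant is just a hand-written reading of the "iota + iota + compare" form discussed in this thread, not necessarily what XLA emits.

```python
import jax
import jax.numpy as jnp

seq_len = 8  # hypothetical sequence length, chosen only for illustration

# Pattern from the thread title: float ones -> tril -> convert to bool at the end.
mask_tril = jnp.tril(jnp.ones((seq_len, seq_len))).astype(bool)

# Direct construction suggested above: boolean lower-triangular matrix.
mask_tri = jnp.tri(seq_len, dtype=bool)

# Hand-written "iota + iota + compare": row index >= column index.
rows = jax.lax.broadcasted_iota(jnp.int32, (seq_len, seq_len), 0)
cols = jax.lax.broadcasted_iota(jnp.int32, (seq_len, seq_len), 1)
mask_iota = rows >= cols

# All three produce the same values; only the traced ops differ.
same_values = bool(jnp.array_equal(mask_tril, mask_tri)) and bool(
    jnp.array_equal(mask_tri, mask_iota)
)
```

All three masks hold the same values, so any difference shows up only in the lowered HLO, which is worth checking on your own JAX/XLA version.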
-
I think we identified common JAX code issues encountered when generating a causal mask. The user either neglected to specify
I opened PR-9867 (merged) to simplify the potentially suboptimal causal mask HLO.
-
Quite often we see the following JAX code to create and apply a causal mask in the self-attention layer.
If we jit + lower + compile + as_text this function,
then we get the following HLO:
We noticed that fused_computation can be simplified. In particular, the mask-related code can be simplified to just iota + iota + compare.
As a result, the following several ops can be removed:
What do you think about recognizing such a pattern and applying the described simplification to it?
Seems like a very common use-case in LLM models.
Our team will be happy to work on it.
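For anyone who wants to reproduce this, below is a rough sketch of the kind of function and the jit + lower + compile + as_text steps described in this post. The function name masked_scores, the input shape, and the use of the dtype minimum as the fill value are assumptions for illustration, not the exact code from the original report.

```python
import jax
import jax.numpy as jnp


def masked_scores(scores):
    """Apply a causal mask to (seq_len, seq_len) attention logits (hypothetical helper)."""
    seq_len = scores.shape[-1]
    # The common pattern under discussion: float ones -> tril -> bool.
    mask = jnp.tril(jnp.ones((seq_len, seq_len))).astype(bool)
    # Masked-out positions get a large negative value (an assumed convention).
    return jnp.where(mask, scores, jnp.finfo(scores.dtype).min)


scores = jnp.zeros((8, 8), dtype=jnp.float32)  # assumed example input

lowered = jax.jit(masked_scores).lower(scores)  # trace and lower
compiled = lowered.compile()                    # run the XLA compiler
hlo_text = compiled.as_text()                   # optimized HLO as a string

out = masked_scores(scores)
```

Printing hlo_text shows the optimized HLO for your backend; the exact ops and fusion names depend on the JAX/XLA version and the device.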