Feedback on API for implicit arrays #16280
-
@chaokunyang re: #16289, would this be applicable to your use case? I don't know if there's enough demand for nullability for it to become a first class JAX feature but I think you could implement it in "user space" this way. Edit: I took a crack at prototyping it and it seems to work alright!
-
I can see this must be part of the code you've omitted. Why is this necessary? As per #16259 we can probably do this without new avals.
I think this is definitely desirable! If I were doing this myself I'd using
Actually, my main criticism of the above approach is that I think we should go full multiple dispatch. For example this means not defining

    @register_multiple_dispatch(jax.lax.mul_p)
    def _(x: ImplicitZeros, y):
        return x

    @register_multiple_dispatch(jax.lax.mul_p)
    def _(x, y: ImplicitZeros):
        return y

(Side note: I'm also using primitives themselves, not strings, to refer to the operation to be overloaded.)
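The registry behind a decorator like register_multiple_dispatch is not shown in this thread. As a rough illustration only, here is a plain-Python sketch of how handlers could be keyed on a primitive object and selected by argument annotations. Primitive, mul_p, dispatch and the _HANDLERS table are hypothetical stand-ins invented for this sketch; nothing here calls JAX, and the toy ImplicitZeros only stores a shape.

```python
import inspect
from collections import defaultdict


class Primitive:
    """Hypothetical stand-in for a JAX primitive such as jax.lax.mul_p."""

    def __init__(self, name):
        self.name = name


mul_p = Primitive("mul")


class ImplicitZeros:
    """Toy symbolic zero: stores a shape, no data."""

    def __init__(self, shape):
        self.shape = shape


# primitive object -> list of (per-argument type constraints, handler)
_HANDLERS = defaultdict(list)


def register_multiple_dispatch(primitive):
    """Register a handler keyed on the primitive object, not on its name."""

    def decorator(fn):
        params = inspect.signature(fn).parameters.values()
        # Unannotated parameters accept any argument.
        types = tuple(
            object if p.annotation is inspect.Parameter.empty else p.annotation
            for p in params
        )
        _HANDLERS[primitive].append((types, fn))
        return fn

    return decorator


def dispatch(primitive, *args):
    """Call the first registered handler whose annotations match the args."""
    for types, fn in _HANDLERS[primitive]:
        if len(types) == len(args) and all(
            isinstance(a, t) for a, t in zip(args, types)
        ):
            return fn(*args)
    raise NotImplementedError(f"no handler for {primitive.name}")


@register_multiple_dispatch(mul_p)
def _(x: ImplicitZeros, y):
    return x


@register_multiple_dispatch(mul_p)
def _(x, y: ImplicitZeros):
    return y
```

This sketch picks the first matching handler in registration order. A fuller multiple-dispatch system would pick the most specific match and report ambiguities, which matters once both arguments can be implicit types.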
I'd say probably not. I don't think this is needed.
Hehehe.
-
Edit (June 6): I've uploaded my transform as Qax in case anyone wants to try it out.
As I mentioned in #16259, I noticed that there were several cases where I wanted something which represented an array without fully instantiating it. This only requires a subset of the full power of custom interpreters, and it involves a bunch of repeated tasks each time, such as flattening args and passing the transformation through for higher-level primitives such as remat.
I ended up with an abstract ImplicitArray type which may be subclassed by a class which implements: implicit_op("<primitive_name>"), and can return either JAX types or further ImplicitArray values.
Here's the base class. I've omitted the actual Tracer implementation and some helper functions for brevity:
(You may want to just scroll past and look at the toy example since I think it gets the idea across.)
ImplicitArray interface
Here's a toy example of using it to represent a symbolic zero (I think you can probably get XLA to accomplish the same thing with constant folding, so it's just for demonstration purposes):
Symbolic zero example
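The collapsed example is not reproduced in this text. As a stand-in, and not the author's code, here is a plain-Python sketch of the general pattern: an implicit value handles the primitives it knows about and is materialized when an unknown one arrives. The handlers table, apply_primitive, and the dense_* helpers are hypothetical names made up for this sketch. It uses nested lists instead of JAX arrays and only looks at the first argument.

```python
class ImplicitZeros:
    """Toy symbolic zero: remembers its shape but allocates nothing."""

    def __init__(self, shape):
        self.shape = shape

    def materialize(self):
        # Build the full matrix only when something actually needs it.
        rows, cols = self.shape
        return [[0.0] * cols for _ in range(rows)]

    def handle_mul(self, other):
        # 0 * anything is still 0, so the result stays implicit.
        return ImplicitZeros(self.shape)

    # Hypothetical handler table: primitive name -> handler function.
    handlers = {"mul": handle_mul}


def apply_primitive(name, dense_impl, x, *rest):
    """Use an implicit handler if one exists, else materialize and fall back."""
    if isinstance(x, ImplicitZeros):
        handler = x.handlers.get(name)
        if handler is not None:
            return handler(x, *rest)
        x = x.materialize()
    return dense_impl(x, *rest)


def dense_mul(x, y):
    # Dense fallback; y is a scalar in this toy.
    return [[a * y for a in row] for row in x]


def dense_reduce_sum(x):
    return sum(sum(row) for row in x)


x = ImplicitZeros((2, 3))
product = apply_primitive("mul", dense_mul, x, 5.0)  # handled implicitly
total = apply_primitive("reduce_sum", dense_reduce_sum, product)  # materializes
```

Here product is still an ImplicitZeros, and the full 2x3 matrix is only built inside the reduce_sum call, which has no implicit handler.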
Since I've implemented mul only, the output of x * y will be another ImplicitZeros, but then it gets materialized into a full matrix when the unknown reduce_sum op is hit.
The things I wasn't sure about were:
- jax.core.pytype_aval_mappings?
- __add__ and __radd__ in python). Maybe the separate handlers setup would be preferable?
- ImplicitArray can look at the params for a primitive and refuse to handle it by returning None. For example in my LoRA implementation I only allow gathers which look like embedding lookups, otherwise I materialize the array. Should the "decide whether to handle this op" be separate from the handling itself? Also, is None ever a valid return value for an op? If so I'll need to switch to some other null result, or maybe exceptional control flow.
I'm interested in any other feedback as well.
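On the question of None as a refusal signal, one option is a dedicated sentinel object, so that None remains available as an ordinary return value. The sketch below is only an illustration in plain Python. UNHANDLED, ImplicitTable, handle_gather and the axis parameter are all hypothetical and do not mirror the real gather primitive's params; the "embedding lookup" test is reduced to "axis is 0".

```python
class _Unhandled:
    """Sentinel meaning 'this handler declines'; distinct from None."""

    def __repr__(self):
        return "UNHANDLED"


UNHANDLED = _Unhandled()


class ImplicitTable:
    """Toy implicit array: a table whose rows are computed on demand."""

    def __init__(self, num_rows, width):
        self.num_rows = num_rows
        self.width = width

    def row(self, i):
        return [float(i)] * self.width

    def materialize(self):
        return [self.row(i) for i in range(self.num_rows)]

    def handle_gather(self, indices, *, axis):
        # Only row lookups (axis 0) resemble an embedding lookup here;
        # decline everything else so the caller can materialize.
        if axis != 0:
            return UNHANDLED
        return [self.row(i) for i in indices]


def dense_gather(table, indices, *, axis):
    if axis == 0:
        return [table[i] for i in indices]
    return [[row[i] for i in indices] for row in table]


def gather(x, indices, *, axis):
    """Try the implicit handler first; fall back to the dense path if it declines."""
    if isinstance(x, ImplicitTable):
        result = x.handle_gather(indices, axis=axis)
        if result is not UNHANDLED:
            return result
        x = x.materialize()
    return dense_gather(x, indices, axis=axis)
```

This keeps the decision and the handling in one function. It follows the same convention Python uses for __add__ and __radd__, where a method returns NotImplemented to let the interpreter try something else.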
P.S. This should be compatible with Equinox, just in case anyone pops through to ask about that ;)
Edit: Here's another example, based on the request from #16289. It's not feature complete, but it demonstrates the idea:
Masked array example
You can also nest ImplicitArrays inside each other:
Masked symbolic zero
Edit June 13: I used Qax to implement 4-bit absmax quantization here. It only has a materialize rule since my adventures with triton kernels have mostly been unsuccessful so far. This isn't really new since my GPTQ code works the same way, but it was really easy to rewrite it this way.
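For readers unfamiliar with absmax quantization, here is a small plain-Python sketch of the idea with a materialize rule only. It is not the linked implementation: the class name, the block size of 4, and the symmetric signed range [-7, 7] are assumptions chosen for illustration, and the codes are kept as a list of ints, not packed into 4-bit storage.

```python
class AbsmaxQuantized4Bit:
    """Toy implicit array: small integer codes plus one absmax scale per block."""

    LEVELS = 7  # assumed symmetric signed 4-bit range [-7, 7]

    def __init__(self, codes, scales, block_size):
        self.codes = codes
        self.scales = scales
        self.block_size = block_size

    @classmethod
    def quantize(cls, values, block_size=4):
        codes, scales = [], []
        for start in range(0, len(values), block_size):
            block = values[start:start + block_size]
            absmax = max(abs(v) for v in block)
            # Map the largest magnitude in the block onto the largest code.
            scale = absmax / cls.LEVELS if absmax > 0 else 1.0
            scales.append(scale)
            codes.extend(round(v / scale) for v in block)
        return cls(codes, scales, block_size)

    def materialize(self):
        # Dequantize: multiply each code by the scale of its block.
        return [
            code * self.scales[i // self.block_size]
            for i, code in enumerate(self.codes)
        ]


values = [0.5, -1.0, 0.25, 0.0, 7.0, -3.5, 0.0, 1.0]
quantized = AbsmaxQuantized4Bit.quantize(values, block_size=4)
approx = quantized.materialize()
```

Each reconstructed value differs from the original by at most half of its block's scale, which is the usual rounding bound for this scheme.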