Replies: 3 comments
-
This is complicated and not well documented, and unfortunately there's not really any one-size-fits-all solution (it depends a lot on what you want your transform to do). But here's one example of flattening arguments to a function before passing it to a custom jaxpr evaluator: https://github.com/google/jax/blob/1b0be5095a62064820301d1fa25f3c38596e1ae2/jax/experimental/sparse/transform.py#L394-L402

One way to proceed is roughly this:
Let me know if anything is unclear!
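Here is a minimal sketch of the flattening pattern, assuming a recent JAX release. It is not the code from the linked sparse transform. The argument pytree is flattened with `jax.tree_util.tree_flatten`, and the flat function is traced with `jax.make_jaxpr(..., return_shape=True)`, whose second result provides the output tree structure. A hand-written evaluator then runs the jaxpr, and `tree_unflatten` rebuilds the outputs. The names `eval_flat`, `interpret` and `model` are hypothetical, and the `Literal` import location differs between JAX versions.

```python
import jax
import jax.numpy as jnp
from jax import tree_util

try:
    from jax.extend.core import Literal  # location in newer JAX releases
except ImportError:
    from jax.core import Literal  # location in older JAX releases


def eval_flat(jaxpr, consts, *args):
    """Minimal jaxpr evaluator that re-binds every equation unchanged."""
    env = {}

    def read(v):
        # Literals carry their value inline; variables are looked up.
        return v.val if isinstance(v, Literal) else env[v]

    for v, val in zip(jaxpr.constvars, consts):
        env[v] = val
    for v, val in zip(jaxpr.invars, args):
        env[v] = val
    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        # A custom transform would replace or wrap this bind call.
        subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
        outvals = eqn.primitive.bind(*subfuns, *invals, **bind_params)
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        for v, val in zip(eqn.outvars, outvals):
            env[v] = val
    return [read(v) for v in jaxpr.outvars]


def interpret(f):
    """Run `f` through `eval_flat`, flattening pytree inputs and outputs."""

    def wrapped(*args):
        flat_args, in_tree = tree_util.tree_flatten(args)

        def flat_fun(*flat):
            # Rebuild the caller's argument structure before calling `f`.
            return f(*tree_util.tree_unflatten(in_tree, flat))

        closed, out_shape = jax.make_jaxpr(flat_fun, return_shape=True)(*flat_args)
        out_tree = tree_util.tree_structure(out_shape)
        flat_out = eval_flat(closed.jaxpr, closed.consts, *flat_args)
        return tree_util.tree_unflatten(out_tree, flat_out)

    return wrapped


def model(params, x):
    return {"y": params["w"] * x + params["b"], "sq": x * x}


params = {"w": jnp.float32(2.0), "b": jnp.float32(1.0)}
out = interpret(model)(params, jnp.float32(3.0))  # same dict structure as model(...)
```

The evaluator only ever sees flat lists of arrays, so the wrapper handles any number of inputs and outputs. A custom transform would only need to change what happens inside the equation loop.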
-
Adding on to Jake's answer: handling higher-order primitives like … The high-level idea is to wrap a recursive call to your interpreter in …
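Here is a minimal sketch of the recursive idea, using a toy jaxpr interpreter that replaces every `add` with `mul`. When it meets a jit equation, it recurses into that equation's inner jaxpr (`eqn.params["jaxpr"]`, a `ClosedJaxpr`) instead of binding it, so the swap also reaches code hidden behind `jit`. This inlines the jit rather than preserving it. The names `swap_eval`, `swap` and `JIT_NAMES` are hypothetical. The assumption that the jit primitive is named `"pjit"` in older releases and `"jit"` in newer ones should be checked against your version. Other higher-order primitives (`scan`, `cond`, `custom_jvp_call`, ...) are bound unchanged here and would each need their own rule.

```python
import jax
import jax.numpy as jnp
from jax import lax

try:
    from jax.extend.core import Literal  # location in newer JAX releases
except ImportError:
    from jax.core import Literal  # location in older JAX releases

# Assumption: the jit primitive's name across JAX versions.
JIT_NAMES = ("pjit", "jit")


def swap_eval(jaxpr, consts, *args):
    """Evaluate `jaxpr`, replacing add with mul, including inside nested jits."""
    env = {}

    def read(v):
        return v.val if isinstance(v, Literal) else env[v]

    for v, val in zip(jaxpr.constvars, consts):
        env[v] = val
    for v, val in zip(jaxpr.invars, args):
        env[v] = val
    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        if eqn.primitive is lax.add_p:
            outvals = [lax.mul(*invals)]
        elif eqn.primitive.name in JIT_NAMES:
            # Higher-order primitive: recurse into its body so the swap applies there too.
            inner = eqn.params["jaxpr"]
            outvals = swap_eval(inner.jaxpr, inner.consts, *invals)
        else:
            subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
            outvals = eqn.primitive.bind(*subfuns, *invals, **bind_params)
            if not eqn.primitive.multiple_results:
                outvals = [outvals]
        for v, val in zip(eqn.outvars, outvals):
            env[v] = val
    return [read(v) for v in jaxpr.outvars]


def swap(f):
    """Trace `f` (flat array arguments only) and run it through `swap_eval`."""

    def wrapped(*args):
        closed = jax.make_jaxpr(f)(*args)
        outs = swap_eval(closed.jaxpr, closed.consts, *args)
        return outs[0] if len(outs) == 1 else outs

    return wrapped


add_fn = jax.jit(lambda a, b: a + b)
nested = jax.jit(lambda a, b: add_fn(a, b) + a)  # a jit inside a jit
a, b = jnp.float32(2.0), jnp.float32(5.0)
swapped = swap(nested)(a, b)  # both adds become muls: (2 * 5) * 2
```

Recursing on the jaxpr keeps the interpreter simple. The trade-off is that the original jit boundaries disappear. If you want to keep them, you could wrap the recursive call in `jax.jit` yourself.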
-
Hi all, I am new to JAX and I am trying to learn its inner machinery. First, let me say that I am a huge fan of your work! I am writing because I am interested in developing a "final style" custom interpreter, as suggested here. Here is a simple example I am working on, where I want to interpret a sum as a multiplication:

```python
from jax import core, lax
import jax


def fun(a, b):
    return a + b


class SwapTracer(core.Tracer):
    def __init__(self, trace, value):
        self._trace = trace
        self.value = value

    @property
    def aval(self):
        return core.get_aval(self.value)

    def full_lower(self):
        return self


class SwapTrace(core.Trace):
    def __init__(self, parent_trace):
        self.parent_trace = parent_trace

    def process_primitive(self, primitive, tracers, params):
        print(f"Primitive: {primitive}")
        print(f"Tracers: {tracers}")
        print(f"Params: {params}")
        # Unwrap our own tracers; leave any other values untouched.
        invals = [t.value if isinstance(t, SwapTracer) else t for t in tracers]
        print(f"Invals: {invals}")
        if primitive is lax.add_p:
            # Replace add with mul, evaluated on the parent trace.
            with core.set_current_trace(self.parent_trace):
                outvals = lax.mul_p.bind(*invals, **params)
        else:
            outvals = primitive.bind_with_trace(self.parent_trace, invals, params)
        print(f"Outvals: {outvals}")
        if primitive.multiple_results:
            out_tracers = [SwapTracer(self, outval) for outval in outvals]
        else:
            out_tracers = SwapTracer(self, outvals)
        print(f"Outtracers: {out_tracers}\n")
        return out_tracers

    def process_call(self, call_primitive, f, tracers, params):
        print(f"Call Primitive: {call_primitive}")
        print(f"Tracers: {tracers}")
        print(f"Params: {params}")
        invals = [t.value if isinstance(t, SwapTracer) else t for t in tracers]
        outvals = call_primitive.bind(swap(f), *invals, **params)
        return SwapTracer(self, outvals)


def swap(f):
    def wrapped(*args):
        with core.take_current_trace() as parent_trace:
            print(f"Parent:{parent_trace}")
            trace = SwapTrace(parent_trace)
            print(f"Current:{trace}\n")
            in_tracers = [SwapTracer(trace, arg) for arg in args]
            with core.set_current_trace(trace):
                out_tracers = f(*in_tracers)
            return out_tracers.value
    return wrapped


# This works
with jax.disable_jit(disable=True):
    result = swap(fun)(1., 5.)
    print(f"Result without pjit: {result}\n")

# How to intercept PJIT?
with jax.disable_jit(disable=False):
    result = swap(fun)(1., 5.)
    print(f"Result with pjit: {result}\n")
```

Apologies if the question is trivial, but I cannot seem to find a way around it. Thank you very much!

Best
-
Hello, I'm trying to implement a custom interpreter that casts intermediate values between full and half precision. The goal is to use performant half-precision primitives where available without compromising the results of brittle ops like `reduce_sum` or `exp`. I give more context here. I'm following the guide here, my code is visible here, and I have two main questions.

Firstly, how do I handle non-unary functions? The tutorial says "we handle unary functions so we don't worry about flattening/unflattening". I can't make this assumption. I also can't figure out where to get the output tree structure so that I can `tree_unflatten` the outputs of my routine.

Secondly, how do I handle subexpressions? I assume it should be possible to apply `amp` recursively to the `subfuns` obtained in the line `subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)`. However, those are actually of type `linear_util.WrappedFun`, which confuses me a bit. Currently, applying my code to a function transformed with `jit` makes my interpreter have no effect.

Finally, I would like some feedback on this idea in general. In the linked Flax discussion thread, I got the feedback that implementing AMP as a JAX transform, rather than in Flax specifically, is a good idea. That said, I recently found a Haiku implementation that is Haiku-specific and rather complicated to use compared to a functional transform. Is there a reason the authors chose a Haiku-specific implementation? Note: I am aware that this project also needs a loss/gradient scaling method, but Flax already provides that.
Thank you for your help.
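Here is a minimal sketch of the casting idea as a jaxpr interpreter, assuming a recent JAX release. The policy set `HALF_PRIMS` and the names `cast_eval`, `mixed_precision` and `f` are hypothetical choices for illustration, not a recommended AMP policy. For primitives in `HALF_PRIMS`, float32 inputs are cast to float16 before binding. Outputs are cast back to the dtype recorded in the jaxpr, so downstream equations such as `exp` and `reduce_sum` still run in float32. The sketch handles any number of flat array inputs and returns a flat list of outputs. For pytrees, inputs can be flattened with `jax.tree_util.tree_flatten`, and the output structure can be taken from `jax.make_jaxpr(..., return_shape=True)`. Higher-order primitives such as jit are bound unchanged here, so their bodies are not cast.

```python
import jax
import jax.numpy as jnp
from jax import lax

try:
    from jax.extend.core import Literal  # location in newer JAX releases
except ImportError:
    from jax.core import Literal  # location in older JAX releases

# Hypothetical policy: primitives that run with float16 inputs.
HALF_PRIMS = {"dot_general", "conv_general_dilated"}


def _to_half(x):
    x = jnp.asarray(x)
    return x.astype(jnp.float16) if x.dtype == jnp.float32 else x


def cast_eval(jaxpr, consts, *args):
    env = {}

    def read(v):
        return v.val if isinstance(v, Literal) else env[v]

    for v, val in zip(jaxpr.constvars, consts):
        env[v] = val
    for v, val in zip(jaxpr.invars, args):
        env[v] = val
    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        half = eqn.primitive.name in HALF_PRIMS
        if half:
            invals = [_to_half(x) for x in invals]
        subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
        outvals = eqn.primitive.bind(*subfuns, *invals, **bind_params)
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        for v, val in zip(eqn.outvars, outvals):
            # Restore the traced dtype so later equations see what they expect.
            env[v] = lax.convert_element_type(val, v.aval.dtype) if half else val
    return [read(v) for v in jaxpr.outvars]


def mixed_precision(f):
    def wrapped(*args):
        closed = jax.make_jaxpr(f)(*args)
        return cast_eval(closed.jaxpr, closed.consts, *args)  # flat list of outputs
    return wrapped


def f(x, w):
    h = x @ w  # traced as dot_general: runs with float16 inputs
    return h, jnp.sum(jnp.exp(h))  # exp and reduce_sum stay in float32


x = jnp.ones((4, 8), jnp.float32)
w = jnp.ones((8, 3), jnp.float32)
h, s = mixed_precision(f)(x, w)
ref_h, ref_s = f(x, w)
```

Casting back to the traced dtype after each half-precision equation keeps the rest of the jaxpr type-consistent. A real policy would also need to decide how to treat jit bodies and other higher-order primitives.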