I'm still trying to figure out whether gradients can be unnecessarily recomputed in certain cases for jitted functions, so I wrote this snippet with profiling:
import jax
import jax.numpy as jnp

def f(x):
    return x**2 + x**3 - x + 2*jnp.log(x)  # arbitrary operation

def g(x, z):
    y = f(x)
    return z*(y + y**8 + jnp.exp(x) + jnp.sin(y))  # arbitrary operation

def h(x):
    y = f(x)
    return y**2 + jnp.cos(x**5 - y) + 3  # arbitrary operation

@jax.jit
def F(x, z):
    # part of these gradients is shared (i.e. the gradient of f)
    u, grad_u = jax.value_and_grad(g)(x, z)
    v, grad_v = jax.value_and_grad(h)(x)
    return u*grad_u + v*grad_v

# first call: F is traced and compiled before it runs
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
    F(2., 5.).block_until_ready()

with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
    # running again on a function that has already been compiled
    F(8., 3.).block_until_ready()
The goal is to find out whether f and its gradient are computed twice. In the first Perfetto trace I do see two slices (which would be bad news), but in the second trace I see nothing except a call to PjitFunction(F), which itself contains DevicePut and TfrtCpuExecutable: there is no information at all about the underlying computations. I wonder whether the first Perfetto trace actually shows the tracing and compilation step, and is therefore irrelevant to my analysis. If so, is there no way to profile the jitted executable?
Hi all,
Thanks in advance!
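One way to look at this question without relying on the profiler is to inspect what JAX traced and what XLA compiled. The sketch below is only an illustration under a few assumptions: `F_raw` and `count_lines_with` are hypothetical names introduced here, the inputs are made explicit `float32` scalars, and counting lines that mention `log` is a rough heuristic (f contains the only `jnp.log` call), not a precise measure of recomputation. It uses ahead-of-time lowering (`.lower(...).compile()`) so that compilation is kept separate from execution.

```python
import jax
import jax.numpy as jnp


def f(x):
    return x**2 + x**3 - x + 2 * jnp.log(x)


def g(x, z):
    y = f(x)
    return z * (y + y**8 + jnp.exp(x) + jnp.sin(y))


def h(x):
    y = f(x)
    return y**2 + jnp.cos(x**5 - y) + 3


def F_raw(x, z):
    # Un-jitted version (hypothetical name), so the jaxpr is not wrapped in a pjit call.
    u, grad_u = jax.value_and_grad(g)(x, z)
    v, grad_v = jax.value_and_grad(h)(x)
    return u * grad_u + v * grad_v


F = jax.jit(F_raw)


def count_lines_with(text, needle):
    # Hypothetical helper: number of lines of `text` that contain `needle`.
    return sum(needle in line for line in text.splitlines())


x = jnp.float32(2.0)
z = jnp.float32(5.0)

# What JAX traced, before XLA has optimized anything.
jaxpr_text = str(jax.make_jaxpr(F_raw)(x, z))

# Ahead-of-time lowering and compilation, separate from execution.
lowered = F.lower(x, z)
compiled = lowered.compile()
hlo_text = compiled.as_text()  # text of the optimized computation

print("lines mentioning log in the jaxpr:", count_lines_with(jaxpr_text, "log"))
print("lines mentioning log in the compiled HLO:", count_lines_with(hlo_text, "log"))

# Running the already-compiled executable involves no tracing or compilation.
result = compiled(x, z)
result.block_until_ready()
```

If the compiled HLO mentions `log` less often than the jaxpr does, that would suggest XLA merged the duplicated forward pass of f, but reading the printed HLO directly is more reliable than the counts alone.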