-
After reading the autodidax manual, the JAX source code, and the unit tests for the persistent cache, I came up with this solution:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import jit
from jax._src import compilation_cache as cc
from jax._src import xla_bridge

@jit
def fun(x):
    print("tracing!")
    return x ** 2

x_input = jnp.array([1.0, 2.0])

# Compile and save the executable to disk; run this code only the first time.
cc.initialize_cache("mycache")
backend = xla_bridge.get_backend()
devices = np.array([[jax.local_devices()[0]]])
compile_options = xla_bridge.get_compile_options(
    num_replicas=1, num_partitions=1
)
computation = (
    jax.jit(fun)
    .lower(x_input)
    .compiler_ir()
)
executable = backend.compile(str(computation), compile_options)
cc.put_executable("myexecutable", "afun", executable, backend)

# After a Python restart, just run the following:
cc.initialize_cache("mycache")
backend = xla_bridge.get_backend()
compile_options = xla_bridge.get_compile_options(
    num_replicas=1, num_partitions=1
)
executable = cc.get_executable("myexecutable", compile_options, backend)
executable.execute([jnp.array([2.0, 3.0])])[0]
```

This loads the executable from disk and runs it directly, without invoking the tracers, so it is really fast. I am aware that the executable is compiled only for a specific input shape and will not work if the shape changes, but I can manage that by recompiling the function for different input shapes. One question though: is this solution safe to use?
-
Is it possible to write a persistent cache for jit tracers such that tracing won't be executed again after a Python restart, but will instead be retrieved from disk?
For example:
On the second run (after a Python restart) I get:

```
DEBUG:jax._src.dispatch:Finished tracing + transforming fun for pjit in 0.0014641284942626953 sec
```

If I read that correctly, this line means that tracing is performed. For my actual function the tracing takes a long time to finish, and I was wondering how I can store the tracing information so it won't be recomputed.
If I run the function again (without a Python restart), no tracing is performed this time.
In my actual case I have one function that takes many combinations of input dimensions, and tracing being re-run after each dimension change is troublesome: it takes many seconds, while the actual computation is done in milliseconds. Why do we have this mechanism at all when the persistent compiled cache is on?
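The behavior described here, where the persistent cache skips XLA compilation but not tracing, can be illustrated with a small JAX-free sketch of the two cache levels involved. All names are hypothetical: `trace` stands in for JAX's tracing step, `process_cache` for the in-memory jit cache that dies with the process, and `persistent_cache` for the on-disk compilation cache.

```python
# Toy model of the two cache levels discussed above: an in-process trace
# cache (lost on restart) in front of a persistent compile cache (survives
# restart). `trace` is a hypothetical stand-in for JAX's tracing step.

trace_events = []

persistent_cache = {}  # survives "restarts" in this toy model
process_cache = {}     # cleared on each "restart"

def trace(f, shape):
    """Pretend to trace `f` for a given input shape (the slow step)."""
    trace_events.append(shape)
    return ("traced", f.__name__, shape)

def jit_call(f, xs):
    shape = (len(xs),)
    key = (f.__name__, shape)
    if key not in process_cache:
        # Tracing is NOT skipped by the persistent cache: it must run
        # again in every new process before the cache key can be built.
        jaxpr = trace(f, shape)
        executable = persistent_cache.setdefault(key, ("compiled", jaxpr))
        process_cache[key] = executable
    return process_cache[key]

def fun(x):
    return x

jit_call(fun, [1.0, 2.0])
jit_call(fun, [3.0, 4.0])  # same shape: no new trace
process_cache.clear()      # simulate a Python restart
jit_call(fun, [1.0, 2.0])  # traced again, though the compile is cached
print(trace_events)        # → [(2,), (2,)]
```

This mirrors why the DEBUG line reappears after a restart: the compiled executable is cached on disk, but the trace cache lives only in memory, and tracing is what produces the key used to look up the persistent entry.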