JAX jitting can be extremely slow when there are large constants in the graph. We could add a helper that converts any large constants to symbolic inputs (we already do some constant folding on our end anyway), so JAX can't get hung up on those.
The idea is to have a `pytensor.graph.replace.replace_large_constants_by_inputs` that returns the graph with large constants replaced by PyTensor input variables, together with the respective values of those constants.
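To make the proposal concrete, here is a minimal sketch of the replacement logic on a toy graph. It deliberately does not use PyTensor: the `Constant`, `Input`, and `Apply` classes, the `size_threshold` parameter, and the `(new_outputs, replaced)` return format are all assumptions made for illustration, and the real helper may well look different.

```python
from dataclasses import dataclass


# Toy graph nodes: hypothetical stand-ins, NOT PyTensor classes.
# eq=False keeps identity-based hashing, so nodes can be used as dict keys.
@dataclass(frozen=True, eq=False)
class Constant:
    value: tuple  # flat tuple of numbers


@dataclass(frozen=True, eq=False)
class Input:
    name: str


@dataclass(frozen=True, eq=False)
class Apply:
    op: str        # "add" or "mul", applied elementwise
    inputs: tuple  # exactly two parent nodes


def replace_large_constants_by_inputs(outputs, size_threshold=10):
    """Return (new_outputs, replaced), where `replaced` maps each new Input
    to the value of the large Constant it stands in for.

    `size_threshold` is an assumed knob: constants with more elements
    than this are lifted out of the graph.
    """
    replaced = {}  # new Input -> constant value
    memo = {}      # id(old node) -> new node, so shared nodes stay shared

    def rebuild(node):
        if id(node) in memo:
            return memo[id(node)]
        if isinstance(node, Constant) and len(node.value) > size_threshold:
            new = Input(f"const_{len(replaced)}")
            replaced[new] = node.value
        elif isinstance(node, Apply):
            new = Apply(node.op, tuple(rebuild(i) for i in node.inputs))
        else:
            new = node  # small constants and existing inputs are kept as-is
        memo[id(node)] = new
        return new

    return [rebuild(out) for out in outputs], replaced


def evaluate(node, env):
    """Evaluate a toy graph given a dict of values for its Input nodes."""
    if isinstance(node, Constant):
        return node.value
    if isinstance(node, Input):
        return env[node]
    a, b = (evaluate(i, env) for i in node.inputs)
    n = max(len(a), len(b))
    a, b = (v * n if len(v) == 1 else v for v in (a, b))  # broadcast scalars
    fn = (lambda p, q: p + q) if node.op == "add" else (lambda p, q: p * q)
    return tuple(fn(p, q) for p, q in zip(a, b))


# Usage: `big` appears twice in the graph but becomes a single new input.
x = Input("x")
big = Constant(tuple(range(1000)))
two = Constant((2,))
out = Apply("add", (Apply("mul", (x, big)), Apply("mul", (two, big))))
(new_out,), replaced = replace_large_constants_by_inputs([out])
```

The caller would then compile `new_out` with the extra inputs and pass the values in `replaced` at call time, so the compiler only sees them as arguments rather than as constants embedded in the graph. Memoizing on the original node keeps a constant that is shared by several nodes as one input instead of duplicating it.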