diff --git a/autobound/jax/jaxpr_editor.py b/autobound/jax/jaxpr_editor.py index d5bca20..5c2306c 100644 --- a/autobound/jax/jaxpr_editor.py +++ b/autobound/jax/jaxpr_editor.py @@ -169,7 +169,11 @@ def vertex_to_var_or_literal(vertex): if vertex[0]: _, count, suffix, aval = vertex if count not in count_to_var: - count_to_var[count] = jax.core.Var(count, suffix, aval) + if jax.__version__ >= '0.4.25': + # count argument was removed in jax 0.4.25: https://github.com/google/jax/pull/10573 + count_to_var[count] = jax.core.Var(suffix, aval) + else: + count_to_var[count] = jax.core.Var(count, suffix, aval) return count_to_var[count] else: _, val, aval = vertex