Behavior of custom_vmap with no batched dimensions #16414
-
Thanks for the question! The issue is that when none of the axes are mapped, the batching tracer falls back to calling the primitive itself. The code is here: https://github.com/google/jax/blob/bfe8acb31e04a540daad3f568239ec0e5c3f0d0f/jax/_src/interpreters/batching.py#L393-L394
I think changing this would be a big task: if we were to call the batching rule even in cases where there was no batching, every single batching rule currently in use would have to be updated to handle this non-batched case. That would be a big project just for the batching rules that already exist.
I wonder if there are other ways you could handle this in your case. For example, could you define a default implementation that handles the non-batched calling path for your primitive?
-
Hello,
I'm interested in writing custom vmap rules for some functions. I've been investigating the use of jax.custom_batching.custom_vmap. Everything seems to make sense, except I can't quite understand what happens when I don't provide any batched dimensions to vmap.
Here's a simple example:
If I run this, I get this output:
No problem so far. The output is as expected, and this all seems to make sense. But consider running this code instead:
If I run this, I get this output:
This gives the expected output. But my custom vmap rule was never called. From the jaxpr, I think that's because vmap thinks it's pointless to run the function multiple times—it can just run the function once and then make a copy of the result.
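The listings referred to above did not survive in this copy of the thread. As a stand-in, here is a minimal sketch that should reproduce the behavior described. It is not the code from the original post: the names f, f_vmap_rule, and rule_calls are illustrative, and it assumes the unbatched call is made with in_axes=None plus an explicit axis_size.

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap

# Records (axis_size, x_batched) every time the custom rule runs.
# (Illustrative bookkeeping only; not part of any JAX API.)
rule_calls = []


@custom_vmap
def f(x):
    return jnp.sin(x)


@f.def_vmap
def f_vmap_rule(axis_size, in_batched, x):
    (x_batched,) = in_batched
    rule_calls.append((axis_size, bool(x_batched)))
    # Inside the rule, x carries the batch dimension at the front,
    # so the elementwise result is batched along axis 0.
    return jnp.sin(x), True


xs = jnp.arange(3.0)

# Case 1: the argument is mapped, so the custom rule is invoked.
mapped_out = jax.vmap(lambda x: f(x))(xs)
calls_after_mapped = len(rule_calls)

# Case 2: nothing is mapped. f runs once on the unbatched input and
# vmap broadcasts the result to axis_size; the rule is not invoked.
unmapped_out = jax.vmap(lambda x: f(x), in_axes=None, axis_size=3)(jnp.asarray(1.0))
calls_after_unmapped = len(rule_calls)

print(mapped_out.shape, unmapped_out.shape)
print(calls_after_mapped, calls_after_unmapped)
```

Both results have a leading axis of size 3, but only the first call goes through f_vmap_rule; the call count does not change after the second vmap.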
That's clearly a good optimization in most cases. But is there any way to avoid it? The problem for me is that I want to create "abstract" primitives that represent distributions. These don't themselves have any implementation. Instead, I have written program transformations like sample() and prob() that inject functionality (PRNGKeys, etc). But the injected functionality depends on the transformation, so I don't want to inject it when calling vmap; I only want to do it later.
For "normal" vmap calls where at least one dimension is batched, I can intercept the vmap call with a def_vmap rule that can bind a new "batched" abstract primitive, which I can then deal with in the program transformation. That all seems to work very well! But in the above case, the def_vmap rule gets skipped, and I seem to be out of luck: the jaxpr represents calling the abstract primitive once, and then making a bunch of copies of that. I don't want that because I want to eventually feed different PRNGKeys to each of the "branches".
Is there any solution to this? If it were possible, the ideal would just be a way to disable this "no batched dims optimization", so that my def_vmap rule always gets triggered by jax.vmap.