jnp.vectorize static argument #19064
Answered by jakevdp
chrisflesher asked this question in Ideas
Hello, I like the syntactic sugar of using
But when I run this it says:
Is there a way to set a default value for
Answered by jakevdp on Dec 20, 2023
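No, unfortunately that isn't supported at the moment. As a workaround, you can put the default on a plain wrapper function and forward the value positionally to the vectorized implementation, so that argument 2 is always present:

```python
import functools
import typing

import jax
import jax.numpy as jnp


def _cross_correlate(source_image: jax.Array,
                     target_image: jax.Array,
                     ifft: typing.Callable = lambda x: jnp.real(jnp.fft.ifft2(x)),
                     ) -> jax.Array:
    # The default lives here; passing it positionally means index 2 always exists
    # for the `excluded=(2,)` setting on the implementation below.
    return _cross_correlate_impl(source_image, target_image, ifft)


@functools.partial(jnp.vectorize, signature='(a,b),(a,b)->(c,d)', excluded=(2,))
def _cross_correlate_impl(source_image: jax.Array, target_image: jax.Array,
                          ifft: typing.Callable) -> jax.Array:
    fft = jnp.fft.fft2
    # Multiplying by the FFT of the flipped target is a circular convolution
    # with the flipped target, i.e. a cross-correlation.
    return ifft(fft(source_image) * fft(target_image[::-1, ::-1]))
```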
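As a quick check of the workaround above, here is a usage sketch (the two functions are restated so the block runs on its own). The batch of four random 8x8 image pairs is an arbitrary choice for illustration; the point is that `jnp.vectorize` maps over the leading batch axis while the wrapper still fills in the default `ifft`, and an explicit callable can still be passed in its place.

```python
import functools
import typing

import jax
import jax.numpy as jnp


def _cross_correlate(source_image, target_image,
                     ifft: typing.Callable = lambda x: jnp.real(jnp.fft.ifft2(x))):
    # Wrapper supplies the default and always passes it as positional arg 2.
    return _cross_correlate_impl(source_image, target_image, ifft)


@functools.partial(jnp.vectorize, signature='(a,b),(a,b)->(c,d)', excluded=(2,))
def _cross_correlate_impl(source_image, target_image, ifft):
    fft = jnp.fft.fft2
    return ifft(fft(source_image) * fft(target_image[::-1, ::-1]))


key_src, key_tgt = jax.random.split(jax.random.PRNGKey(0))
# A batch of 4 image pairs, each 8x8; only the leading axis is vectorized over.
sources = jax.random.normal(key_src, (4, 8, 8))
targets = jax.random.normal(key_tgt, (4, 8, 8))

batched = _cross_correlate(sources, targets)  # uses the default (real) ifft
print(batched.shape)  # (4, 8, 8)

# Should match processing the pairs one at a time.
looped = jnp.stack([_cross_correlate(s, t) for s, t in zip(sources, targets)])
print(jnp.allclose(batched, looped, rtol=1e-4, atol=1e-4))

# Overriding the static callable, here keeping the complex-valued result.
complex_out = _cross_correlate(sources, targets, jnp.fft.ifft2)
print(jnp.iscomplexobj(complex_out))
```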
I opened #19066 to track adding this feature. Once that is complete, I believe passing `excluded=(2, "ifft")` to the `jnp.vectorize` call on your original function would be sufficient to get the behavior you want.
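On a JAX version without that feature, a different workaround (not from the thread, just a sketch) avoids `excluded` altogether: bind `ifft` with `functools.partial` inside a wrapper and vectorize only over the two array arguments. The public name `cross_correlate` and the test shapes below are hypothetical choices for illustration.

```python
import functools
import typing

import jax
import jax.numpy as jnp


def _cross_correlate_impl(source_image: jax.Array, target_image: jax.Array,
                          ifft: typing.Callable) -> jax.Array:
    fft = jnp.fft.fft2
    return ifft(fft(source_image) * fft(target_image[::-1, ::-1]))


def cross_correlate(source_image: jax.Array,
                    target_image: jax.Array,
                    ifft: typing.Callable = lambda x: jnp.real(jnp.fft.ifft2(x)),
                    ) -> jax.Array:
    # Hypothetical wrapper: bind the callable first, so jnp.vectorize only sees
    # the two array arguments and no `excluded` entry (int or str) is needed.
    vectorized = jnp.vectorize(functools.partial(_cross_correlate_impl, ifft=ifft),
                               signature='(a,b),(a,b)->(c,d)')
    return vectorized(source_image, target_image)


k_src, k_tgt = jax.random.split(jax.random.PRNGKey(1))
sources = jax.random.normal(k_src, (3, 5, 5))  # batch of 3 image pairs, each 5x5
targets = jax.random.normal(k_tgt, (3, 5, 5))

out = cross_correlate(sources, targets)                         # default ifft
out_complex = cross_correlate(sources, targets, jnp.fft.ifft2)  # explicit override
```

Building the vectorized function on each call adds a little Python overhead but keeps the default on an ordinary signature. If the wrapper is itself put under `jax.jit`, `ifft` would need to be marked static, since a Python callable is not a valid traced argument.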