hk.switch does not work inside a hk.vmap function when hk.set_state is used #785

Open
rsmath opened this issue Jul 2, 2024 · 1 comment

rsmath commented Jul 2, 2024

Hi,

I have been using Haiku (amazing tool, btw!) for about 8 months now. Until now, I wrapped my custom Haiku modules with hk.transform. Inside my module, I vmapped a function (using hk.vmap) that contains an hk.switch call to evaluate a chosen branch function.

I recently moved to stateful modules, which require the hk.transform_with_state transform. I also need to track a specific value over time in my machine learning model that must not be updated by the optimizer. For this, I use hk.set_state("name", val) to store it and read it back later. However, as soon as I add any set_state call anywhere in the model, the vmapped function fails with the error:

ValueError: vmap has mapped output but out_axes is None

Is there any way to use hk.switch inside a hk.vmap function when hk.set_state is used in the module?

Thank you.

rsmath commented Jul 2, 2024

Code

import haiku as hk
import numpy as np
import jax
import jax.numpy as jnp


def forward(input_):
    # hk.set_state("some_state_val", 2)  # uncommenting this line triggers the ValueError

    def func(i, x):
        temp = hk.switch(i.squeeze(), applys, x)
        return temp

    func_vmap = hk.vmap(func, in_axes=(0, None), split_rng=True)

    applys = []
    for i in range(10):
        temp_apply = lambda x: x**2
        applys.append(temp_apply)

    pred = func_vmap(input_[1], input_[0])

    return pred

rng_key = jax.random.PRNGKey(4)

stateful_forward = hk.transform_with_state(forward)

data = jnp.asarray(np.random.rand(100, 30))
idx = jnp.asarray(np.random.randint(0, 10, (100, 1))).astype(jnp.int32)

inp = [data, idx]

init, apply = stateful_forward.init, stateful_forward.apply

init = jax.jit(init)
apply = jax.jit(apply)

params, state = init(rng_key, inp)

output, state = apply(params, state, rng_key, inp)
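
For comparison, here is a minimal sketch of the same vmapped switch written in plain JAX, using jax.vmap and jax.lax.switch in place of hk.vmap and hk.switch. It rests on the assumption that the branch functions are pure (no Haiku params, state, or RNG), as the lambdas in the reproduction are. The names branches, select_branch, and select_branch_vmap are illustrative only. This is not a confirmed fix for the hk.vmap error and it has not been tested inside a hk.transform_with_state function; it only shows that the mapped-index pattern itself runs in plain JAX.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Ten identical pure branches, mirroring the reproduction above.
# Assumption: the branches do not touch Haiku params, state, or RNG.
branches = [lambda x: x**2 for _ in range(10)]


def select_branch(i, x):
    # Each mapped element of i has shape (1,); lax.switch needs a scalar index.
    return jax.lax.switch(i.squeeze(), branches, x)


# Map over the index only; x is shared across the batch (in_axes=None),
# matching in_axes=(0, None) in the original hk.vmap call.
select_branch_vmap = jax.vmap(select_branch, in_axes=(0, None))

data = jnp.asarray(np.random.rand(100, 30))
idx = jnp.asarray(np.random.randint(0, 10, (100, 1))).astype(jnp.int32)

pred = select_branch_vmap(idx, data)
print(pred.shape)  # (100, 100, 30)
```

Because the data array is not mapped, each of the 100 indices yields a full (100, 30) result, so the stacked output has shape (100, 100, 30), the same layout the hk.vmap call in the reproduction would produce.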
