Proposal for Code Structure Improvement Using jax.lax.cond #245

Open
1 task done
helpingstar opened this issue Sep 13, 2024 · 9 comments
Labels
enhancement New feature or request good first issue Good for newcomers help wanted Extra attention is needed

Comments

@helpingstar
Contributor

Is your feature request related to a problem? Please describe

This is a simple question regarding code style. It is not related to any bugs.

timestep = jax.lax.cond(
    done | (new_state.step_count >= self.time_limit),
    lambda: termination(
        reward=reward,
        observation=observation,
        extras=extras,
    ),
    lambda: transition(
        reward=reward,
        observation=observation,
        extras=extras,
    ),
)
return new_state, timestep

timestep = jax.lax.cond(
    done,
    lambda: termination(
        reward=reward,
        observation=observation,
        extras=extras,
    ),
    lambda: transition(
        reward=reward,
        observation=observation,
        extras=extras,
    ),
)
return state, timestep

Rather than wrapping each branch in a lambda and repeating the same keyword arguments, as in the code above, it seems cleaner to follow the functional style of jax.lax.cond and pass the branch functions and their operands directly, as in the solution code below.

It seems that there is little to no difference in performance.
If this is a minor issue, I will close it.

Describe the solution you'd like

timestep = jax.lax.cond(
    done,
    termination,
    transition,
    reward,
    observation,
    extras,
)
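As a rough check that the two call styles are interchangeable, here is a minimal sketch. The `TimeStep` tuple and the `termination` and `transition` functions below are simplified, hypothetical stand-ins (not Jumanji's real ones), and they assume both branches accept `(reward, observation, extras)` in the same positional order.

```python
from typing import Dict, NamedTuple

import jax
import jax.numpy as jnp


class TimeStep(NamedTuple):
    # Hypothetical, simplified stand-in for the real TimeStep type.
    step_type: jax.Array
    reward: jax.Array
    discount: jax.Array
    observation: jax.Array
    extras: Dict[str, jax.Array]


def termination(reward, observation, extras):
    # Stand-in: last step of an episode, discount 0.
    return TimeStep(jnp.int32(2), reward, jnp.float32(0.0), observation, extras)


def transition(reward, observation, extras):
    # Stand-in: mid-episode step, discount 1.
    return TimeStep(jnp.int32(1), reward, jnp.float32(1.0), observation, extras)


def step_with_lambdas(done, reward, observation, extras):
    # Closure style: each branch captures the variables it needs.
    return jax.lax.cond(
        done,
        lambda: termination(reward=reward, observation=observation, extras=extras),
        lambda: transition(reward=reward, observation=observation, extras=extras),
    )


def step_with_operands(done, reward, observation, extras):
    # Operand style: cond forwards the operands positionally to the chosen branch.
    return jax.lax.cond(done, termination, transition, reward, observation, extras)
```

Both helpers should return the same timestep for the same inputs; the operand style only works because the stand-in branches share one positional signature.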

Describe alternatives you've considered

None

Additional context

next_timestep = jax.lax.cond(
    done,
    termination,
    transition,
    reward,
    next_observation,
)
return next_state, next_timestep

timestep = lax.cond(
    done,
    termination,
    transition,
    reward,
    obs,
)
return next_state, timestep

next_timestep = jax.lax.cond(
    done,
    termination,
    transition,
    reward,
    next_observation,
)
return next_state, next_timestep

timestep = jax.lax.cond(
    done,
    termination,
    transition,
    reward,
    observation,
)
return next_state, timestep

next_timestep = jax.lax.cond(
    done,
    termination,
    transition,
    reward,
    next_observation,
)
return next_state, next_timestep


Misc

  • Check for duplicate requests.
@helpingstar helpingstar added the enhancement New feature or request label Sep 13, 2024
@sash-a
Collaborator

sash-a commented Sep 13, 2024

Hi, thanks for the suggestion! Agreed, it isn't as clean as the others, but it needs to be done this way because the transition and termination functions take different arguments:

def transition(
    reward: Array,
    observation: Observation,
    discount: Optional[Array] = None,
    extras: Optional[Dict] = None,
    shape: Union[int, Sequence[int]] = (),
) -> TimeStep:
    ...  # body omitted

def termination(
    reward: Array,
    observation: Observation,
    extras: Optional[Dict] = None,
    shape: Union[int, Sequence[int]] = (),
) -> TimeStep:
    ...  # body omitted

If we simply passed unnamed arguments through the cond, as is done in the other envs, extras would be passed as discount to the transition branch. There isn't really a clean solution here unless JAX allows named arguments to jax.lax.cond.
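The mismatch described above can be checked without running JAX at all. The sketch below copies the parameter order of the two signatures quoted in this thread, with stub bodies and no type annotations; the `bound_names` helper is hypothetical and only there for illustration. It uses `inspect.signature` to show which parameter each positional argument binds to.

```python
import inspect


def transition(reward, observation, discount=None, extras=None, shape=()):
    ...  # stub: only the parameter order matters here


def termination(reward, observation, extras=None, shape=()):
    ...  # stub: only the parameter order matters here


def bound_names(fn, *args):
    """Return the parameter names that the positional args bind to, in order."""
    return list(inspect.signature(fn).bind(*args).arguments)


reward, observation, extras = 1.0, "obs", {"info": 0}

# The third positional argument lands on different parameters in each function.
print(bound_names(termination, reward, observation, extras))
print(bound_names(transition, reward, observation, extras))
```

With these signatures, the third positional argument binds to `extras` in `termination` but to `discount` in `transition`, which is why the positional form of `jax.lax.cond` cannot be used as-is.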

@clement-bonnet
Collaborator

Hi, thank you for your comment and your observation. Yes, we would rather have what you suggested, i.e. a simple jax.lax.cond statement. However, as far as I know, cond statements only take positional arguments (as opposed to kwargs), and the presence of extras in Connector and Game2048 that you mentioned above does not play well with the termination and transition functions. The transition function has an additional optional discount argument between observation and extras, which the termination function does not have. Therefore, their signatures differ in the first three positional arguments, so you can't simply use a cond statement.

@clement-bonnet
Collaborator

Thank you @sash-a for replying so fast that I didn't see your comment! Agree with what you said.

A solution would be to move the discount argument of transition to after the extras argument, so that both functions share the same (reward, observation, extras) positional arguments. The alternative, as you mentioned, would be for JAX to implement kwargs for cond.
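A minimal sketch of that reordering, assuming simplified stand-ins for `TimeStep`, `termination`, and `transition` (hypothetical definitions, not the library's actual code, and with the `shape` parameter left out for brevity). Once `discount` follows `extras`, both branches accept `(reward, observation, extras)` positionally and a plain `jax.lax.cond` call works.

```python
from typing import Dict, NamedTuple, Optional

import jax
import jax.numpy as jnp


class TimeStep(NamedTuple):
    # Hypothetical, simplified stand-in for the real TimeStep type.
    step_type: jax.Array
    reward: jax.Array
    discount: jax.Array
    observation: jax.Array
    extras: Optional[Dict[str, jax.Array]]


def termination(reward, observation, extras=None):
    # Stand-in: terminal step, discount fixed at 0.
    return TimeStep(jnp.int32(2), reward, jnp.zeros_like(reward), observation, extras)


def transition(reward, observation, extras=None, discount=None):
    # Reordered stand-in: discount now comes AFTER extras.
    discount = jnp.ones_like(reward) if discount is None else discount
    return TimeStep(jnp.int32(1), reward, discount, observation, extras)


def make_timestep(done, reward, observation, extras):
    # The three operands are forwarded positionally to whichever branch runs.
    return jax.lax.cond(done, termination, transition, reward, observation, extras)
```

In this sketch `extras` reaches the `extras` parameter in both branches, and `transition` falls back to a default discount of 1 because no discount is passed.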

@sash-a
Collaborator

sash-a commented Sep 13, 2024

Yeah, that seems like a reasonable solution.

@helpingstar
Contributor Author

helpingstar commented Sep 13, 2024

It seems that the cases written in the lambda form are the ones that also pass extras. Thank you both for your kind guidance. I also think @clement-bonnet's solution looks good.

Since all instances calling transition use extras as keyword arguments, it seems possible to change the function definition. However, considering potential version issues, I’m cautious about submitting a PR right away.
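The compatibility concern can be illustrated with a small pure-Python sketch. `old_transition` and `new_transition` are hypothetical stubs that only record how their arguments were bound; they assume the current parameter order and the proposed reordering respectively.

```python
def old_transition(reward, observation, discount=None, extras=None):
    # Stub with the current order: discount before extras.
    return dict(reward=reward, observation=observation, discount=discount, extras=extras)


def new_transition(reward, observation, extras=None, discount=None):
    # Stub with the proposed order: extras before discount.
    return dict(reward=reward, observation=observation, discount=discount, extras=extras)


# Callers that pass extras by keyword see no change after the reordering.
same = old_transition(1.0, "obs", extras={"k": 1}) == new_transition(1.0, "obs", extras={"k": 1})

# A caller that passes discount as the third POSITIONAL argument would change meaning.
before = old_transition(1.0, "obs", 0.5)  # 0.5 is bound to discount
after = new_transition(1.0, "obs", 0.5)   # 0.5 is now bound to extras
```

So the reordering is only risky for external code that passes `discount` positionally, which is the kind of version issue worth checking before a release.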

I’m learning a lot from studying Jumanji’s code and can see the effort put into systematically building the JAX reinforcement learning environment throughout. Thank you very much for writing such excellent code.

I hope @clement-bonnet's solution is implemented (or that jax.lax.cond will support keyword arguments), and I hope there's something I can help with. I'll leave the issue open. Feel free to close it at any time!

@sash-a sash-a added good first issue Good for newcomers help wanted Extra attention is needed labels Oct 25, 2024
@sash-a
Collaborator

sash-a commented Nov 1, 2024

@helpingstar would you be interested in making a PR to fix this?

@helpingstar
Contributor Author

@sash-a I’d be glad to help with this. Is there a preferred timeline?

@sash-a
Collaborator

sash-a commented Nov 1, 2024

No rush honestly, whenever you have time! I think @clement-bonnet's suggested fix should work well.

@helpingstar
Contributor Author

@sash-a Got it, thanks for letting me know!

3 participants