Skip to content

Commit

Permalink
[Feature] flexible batch_locked for jumanji
Browse files Browse the repository at this point in the history
ghstack-source-id: 600bf2237122d1db878625d2e1889dab6a603f74
Pull Request resolved: #2382
  • Loading branch information
vmoens committed Aug 9, 2024
1 parent e4fcd86 commit 96e0458
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 37 deletions.
25 changes: 21 additions & 4 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1553,7 +1553,7 @@ def test_jumanji_seeding(self, envname):

@pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)])
def test_jumanji_batch_size(self, envname, batch_size):
env = JumanjiEnv(envname, batch_size=batch_size)
env = JumanjiEnv(envname, batch_size=batch_size, jit=True)
env.set_seed(0)
tdreset = env.reset()
tdrollout = env.rollout(max_steps=50)
Expand All @@ -1564,7 +1564,7 @@ def test_jumanji_batch_size(self, envname, batch_size):

@pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)])
def test_jumanji_spec_rollout(self, envname, batch_size):
env = JumanjiEnv(envname, batch_size=batch_size)
env = JumanjiEnv(envname, batch_size=batch_size, jit=True)
env.set_seed(0)
check_env_specs(env)

Expand All @@ -1575,7 +1575,7 @@ def test_jumanji_consistency(self, envname, batch_size):
import numpy as onp
from torchrl.envs.libs.jax_utils import _tree_flatten

env = JumanjiEnv(envname, batch_size=batch_size)
env = JumanjiEnv(envname, batch_size=batch_size, jit=True)
obs_keys = list(env.observation_spec.keys(True))
env.set_seed(1)
rollout = env.rollout(10)
Expand Down Expand Up @@ -1613,7 +1613,7 @@ def test_jumanji_consistency(self, envname, batch_size):
@pytest.mark.parametrize("batch_size", [[3], []])
def test_jumanji_rendering(self, envname, batch_size):
# check that this works with a batch-size
env = JumanjiEnv(envname, from_pixels=True, batch_size=batch_size)
env = JumanjiEnv(envname, from_pixels=True, batch_size=batch_size, jit=True)
env.set_seed(0)
env.transform.transform_observation_spec(env.base_env.observation_spec)

Expand All @@ -1626,6 +1626,23 @@ def test_jumanji_rendering(self, envname, batch_size):

check_env_specs(env)

@pytest.mark.parametrize("jit", [True, False])
def test_jumanji_batch_unlocked(self, envname, jit):
torch.manual_seed(0)
env = JumanjiEnv(envname, jit=jit)
env.set_seed(0)
assert not env.batch_locked
reset = env.reset(TensorDict(batch_size=[16]))
assert reset.batch_size == (16,)
env.rand_step(reset)
t0 = time.time()
r = env.rollout(
20, auto_reset=False, tensordict=reset, break_when_all_done=True
)
assert r.batch_size[0] == 16
done = r["next", "done"].float()
assert (done.cumprod(-2) == done).all()


ENVPOOL_CLASSIC_CONTROL_ENVS = [
PENDULUM_VERSIONED(),
Expand Down
22 changes: 18 additions & 4 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1473,6 +1473,16 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
"""
# sanity check
self._assert_tensordict_shape(tensordict)
if not self.batch_locked:
# Batched envs have their own way of dealing with this - batched envs that are not batched-locked may fail here
partial_steps = tensordict.get("_step", None)
if partial_steps is not None:
if partial_steps.all():
partial_steps = None
else:
tensordict_batch_size = tensordict.batch_size
partial_steps = partial_steps.view(tensordict_batch_size)
tensordict = tensordict[partial_steps]
next_preset = tensordict.get("next", None)

next_tensordict = self._step(tensordict)
Expand All @@ -1485,6 +1495,10 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
next_preset.exclude(*next_tensordict.keys(True, True))
)
tensordict.set("next", next_tensordict)
if partial_steps is not None:
result = tensordict.new_zeros(tensordict_batch_size)
result[partial_steps] = tensordict
return result
return tensordict

@classmethod
Expand Down Expand Up @@ -2696,7 +2710,7 @@ def _rollout_stop_early(
if break_when_all_done:
if partial_steps is not True:
# At least one partial step has been done
del td_append["_partial_steps"]
del td_append["_step"]
td_append = torch.where(
partial_steps.view(td_append.shape), td_append, tensordicts[-1]
)
Expand All @@ -2722,17 +2736,17 @@ def _rollout_stop_early(
_terminated_or_truncated(
tensordict,
full_done_spec=self.output_spec["full_done_spec"],
key="_partial_steps",
key="_step",
write_full_false=False,
)
partial_step_curr = tensordict.get("_partial_steps", None)
partial_step_curr = tensordict.get("_step", None)
if partial_step_curr is not None:
partial_step_curr = ~partial_step_curr
partial_steps = partial_steps & partial_step_curr
if partial_steps is not True:
if not partial_steps.any():
break
tensordict.set("_partial_steps", partial_steps)
tensordict.set("_step", partial_steps)

if callback is not None:
callback(self, tensordict)
Expand Down
12 changes: 8 additions & 4 deletions torchrl/envs/libs/jax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,19 +102,21 @@ def _object_to_tensordict(obj, device, batch_size) -> TensorDictBase:
return None


def _tensordict_to_object(tensordict: TensorDictBase, object_example):
def _tensordict_to_object(tensordict: TensorDictBase, object_example, batch_size=None):
"""Converts a TensorDict to a namedtuple or a dataclass."""
from jax import dlpack as jax_dlpack, numpy as jnp

if batch_size is None:
batch_size = []
t = {}
_fields = _get_object_fields(object_example)
for name, example in _fields.items():
value = tensordict.get(name, None)
if isinstance(value, TensorDictBase):
t[name] = _tensordict_to_object(value, example)
t[name] = _tensordict_to_object(value, example, batch_size=batch_size)
elif value is None:
if isinstance(example, dict):
t[name] = _tensordict_to_object({}, example)
t[name] = _tensordict_to_object({}, example, batch_size=batch_size)
else:
t[name] = None
else:
Expand All @@ -140,7 +142,9 @@ def _tensordict_to_object(tensordict: TensorDictBase, object_example):
t[name] = value
else:
value = jnp.reshape(value, tuple(shape))
t[name] = value.view(example.dtype).reshape(example.shape)
t[name] = value.view(example.dtype).reshape(
(*batch_size, *example.shape)
)
return type(object_example)(**t)


Expand Down
138 changes: 113 additions & 25 deletions torchrl/envs/libs/jumanji.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,12 @@ class JumanjiWrapper(GymLikeEnv, metaclass=_JumanjiMakeRender):
Paper: https://arxiv.org/abs/2306.09884
.. note:: For better performance, turn `jit` on when instantiating this class.
The `jit` attribute can also be flipped during code execution:
>>> env.jit = True # Used jit
>>> env.jit = False # eager
Args:
env (jumanji.env.Environment): the env to wrap.
categorical_action_encoding (bool, optional): if ``True``, categorical
Expand All @@ -136,6 +142,22 @@ class JumanjiWrapper(GymLikeEnv, metaclass=_JumanjiMakeRender):
Defaults to ``False``.
Keyword Args:
batch_size (torch.Size, optional): the batch size of the environment.
With ``jumanji``, this indicates the number of vectorized environments.
If the batch-size is empty, the environment is not batch-locked and an arbitrary number
of environments can be executed simultaneously.
Defaults to ``torch.Size([])``.
>>> import jumanji
>>> from torchrl.envs import JumanjiWrapper
>>> base_env = jumanji.make("Snake-v1")
>>> env = JumanjiWrapper(base_env)
>>> # Set the batch-size of the TensorDict instead of the env allows to control the number
>>> # of envs being run simultaneously
>>> tdreset = env.reset(TensorDict(batch_size=[32]))
>>> # Execute a rollout until all envs are done or max steps is reached, whichever comes first
>>> rollout = env.rollout(100, break_when_all_done=True, auto_reset=False, tensordict=tdreset)
from_pixels (bool, optional): Whether the environment should render its output.
This will drastically impact the environment throughput. Only the first environment
will be rendered. See :meth:`~torchrl.envs.JumanjiWrapper.render` for more information.
Expand All @@ -146,17 +168,15 @@ class JumanjiWrapper(GymLikeEnv, metaclass=_JumanjiMakeRender):
of rewards across steps.
device (torch.device, optional): if provided, the device on which the data
is to be cast. Defaults to ``torch.device("cpu")``.
batch_size (torch.Size, optional): the batch size of the environment.
With ``jumanji``, this indicates the number of vectorized environments.
Defaults to ``torch.Size([])``.
allow_done_after_reset (bool, optional): if ``True``, it is tolerated
for envs to be ``done`` just after :meth:`~.reset` is called.
Defaults to ``False``.
jit (bool, optional): whether the step and reset method should be wrapped in `jit`.
Defaults to ``False``.
Attributes:
available_envs: environments availalbe to build
Examples:
Examples:
>>> import jumanji
>>> from torchrl.envs import JumanjiWrapper
Expand Down Expand Up @@ -334,6 +354,7 @@ def __init__(
self,
env: "jumanji.env.Environment" = None, # noqa: F821
categorical_action_encoding=True,
jit: bool = True,
**kwargs,
):
if not _has_jumanji:
Expand All @@ -343,7 +364,26 @@ def __init__(
self.categorical_action_encoding = categorical_action_encoding
if env is not None:
kwargs["env"] = env
batch_locked = kwargs.pop("batch_locked", kwargs.get("batch_size") is not None)
super().__init__(**kwargs)
self._batch_locked = batch_locked
self.jit = jit

@property
def jit(self):
return self._jit

@jit.setter
def jit(self, value):
self._jit = value
if value:
import jax

self._env_reset = jax.jit(self._env.reset)
self._env_step = jax.jit(self._env.step)
else:
self._env_reset = self._env.reset
self._env_step = self._env.step

def _build_env(
self,
Expand Down Expand Up @@ -486,17 +526,21 @@ def _set_seed(self, seed):
raise Exception("Jumanji requires an integer seed.")
self.key = jax.random.PRNGKey(seed)

def read_state(self, state):
state_dict = _object_to_tensordict(state, self.device, self.batch_size)
def read_state(self, state, batch_size=None):
state_dict = _object_to_tensordict(
state, self.device, self.batch_size if batch_size is None else batch_size
)
return self.state_spec["state"].encode(state_dict)

def read_obs(self, obs):
def read_obs(self, obs, batch_size=None):
from jax import numpy as jnp

if isinstance(obs, (list, jnp.ndarray, np.ndarray)):
obs_dict = _ndarray_to_tensor(obs).to(self.device)
else:
obs_dict = _object_to_tensordict(obs, self.device, self.batch_size)
obs_dict = _object_to_tensordict(
obs, self.device, self.batch_size if batch_size is None else batch_size
)
return super().read_obs(obs_dict)

def render(
Expand Down Expand Up @@ -561,7 +605,11 @@ def render(
isinteractive = plt.isinteractive()
plt.ion()
buf = io.BytesIO()
state = _tensordict_to_object(tensordict.get("state"), _state_example)
state = _tensordict_to_object(
tensordict.get("state"),
_state_example,
batch_size=tensordict.batch_size if not self.batch_locked else None,
)
self._env.render(state, **kwargs)
plt.savefig(buf, format="png")
buf.seek(0)
Expand All @@ -580,24 +628,33 @@ def render(
def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
import jax

if self.batch_locked:
batch_size = self.batch_size
else:
batch_size = tensordict.batch_size

# prepare inputs
state = _tensordict_to_object(tensordict.get("state"), self._state_example)
state = _tensordict_to_object(
tensordict.get("state"),
self._state_example,
batch_size=tensordict.batch_size if not self.batch_locked else None,
)
action = self.read_action(tensordict.get("action"))

# flatten batch size into vector
state = _tree_flatten(state, self.batch_size)
action = _tree_flatten(action, self.batch_size)
state = _tree_flatten(state, batch_size)
action = _tree_flatten(action, batch_size)

# jax vectorizing map on env.step
state, timestep = jax.vmap(self._env.step)(state, action)
state, timestep = jax.vmap(self._env_step)(state, action)

# reshape batch size from vector
state = _tree_reshape(state, self.batch_size)
timestep = _tree_reshape(timestep, self.batch_size)
state = _tree_reshape(state, batch_size)
timestep = _tree_reshape(timestep, batch_size)

# collect outputs
state_dict = self.read_state(state)
obs_dict = self.read_obs(timestep.observation)
state_dict = self.read_state(state, batch_size=batch_size)
obs_dict = self.read_obs(timestep.observation, batch_size=batch_size)
reward = self.read_reward(np.asarray(timestep.reward))
done = timestep.step_type == self.lib.types.StepType.LAST
done = _ndarray_to_tensor(done).view(torch.bool).to(self.device)
Expand All @@ -622,32 +679,63 @@ def _reset(
import jax
from jax import numpy as jnp

if self.batch_locked:
numel = self.numel()
batch_size = self.batch_size
else:
numel = tensordict.numel()
batch_size = tensordict.batch_size

# generate random keys
self.key, *keys = jax.random.split(self.key, self.numel() + 1)
self.key, *keys = jax.random.split(self.key, numel + 1)

# jax vectorizing map on env.reset
state, timestep = jax.vmap(self._env.reset)(jnp.stack(keys))
state, timestep = jax.vmap(self._env_reset)(jnp.stack(keys))

# reshape batch size from vector
state = _tree_reshape(state, self.batch_size)
timestep = _tree_reshape(timestep, self.batch_size)
state = _tree_reshape(state, batch_size)
timestep = _tree_reshape(timestep, batch_size)

# collect outputs
state_dict = self.read_state(state)
obs_dict = self.read_obs(timestep.observation)
done_td = self.full_done_spec.zero()
state_dict = self.read_state(state, batch_size=batch_size)
obs_dict = self.read_obs(timestep.observation, batch_size=batch_size)
if not self.batch_locked:
done_td = self.full_done_spec.zero(batch_size)
else:
done_td = self.full_done_spec.zero()

# build results
tensordict_out = TensorDict(
source=obs_dict,
batch_size=self.batch_size,
batch_size=batch_size,
device=self.device,
)
tensordict_out.update(done_td)
tensordict_out["state"] = state_dict

return tensordict_out

def read_reward(self, reward):
"""Reads the reward and maps it to the reward space.
Args:
reward (torch.Tensor or TensorDict): reward to be mapped.
"""
if isinstance(reward, int) and reward == 0:
return self.reward_spec.zero()
if self.batch_locked:
reward = self.reward_spec.encode(reward, ignore_device=True)
else:
reward = torch.as_tensor(reward)
if reward.shape[-1] != self.reward_spec.shape[-1]:
reward = reward.unsqueeze(-1)

if reward is None:
reward = torch.tensor(np.nan).expand(self.reward_spec.shape)

return reward

def _output_transform(self, step_outputs_tuple: Tuple) -> Tuple:
...

Expand Down

0 comments on commit 96e0458

Please sign in to comment.