Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 2, 2024
1 parent b5e3ea6 commit a409233
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 51 deletions.
69 changes: 69 additions & 0 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1927,3 +1927,72 @@ def _step(
def _set_seed(self, seed: Optional[int]):
self.manual_seed = seed
return seed


class EnvWithScalarAction(EnvBase):
def __init__(self, singleton: bool = False, **kwargs):
super().__init__(**kwargs)
self.singleton = singleton
self.action_spec = Bounded(
-1,
1,
shape=(
*self.batch_size,
1,
)
if self.singleton
else self.batch_size,
)
self.observation_spec = Composite(
observation=Unbounded(
shape=(
*self.batch_size,
3,
)
),
shape=self.batch_size,
)
self.done_spec = Composite(
done=Unbounded(self.batch_size + (1,), dtype=torch.bool),
terminated=Unbounded(self.batch_size + (1,), dtype=torch.bool),
truncated=Unbounded(self.batch_size + (1,), dtype=torch.bool),
shape=self.batch_size,
)
self.reward_spec = Unbounded(
shape=(
*self.batch_size,
1,
)
)

def _reset(self, td: TensorDict):
return TensorDict(
observation=torch.randn(*self.batch_size, 3, device=self.device),
done=torch.zeros(*self.batch_size, 1, dtype=torch.bool, device=self.device),
truncated=torch.zeros(
*self.batch_size, 1, dtype=torch.bool, device=self.device
),
terminated=torch.zeros(
*self.batch_size, 1, dtype=torch.bool, device=self.device
),
device=self.device,
)

def _step(
self,
tensordict: TensorDictBase,
) -> TensorDictBase:
return TensorDict(
observation=torch.randn(*self.batch_size, 3, device=self.device),
reward=torch.zeros(1, device=self.device),
done=torch.zeros(*self.batch_size, 1, dtype=torch.bool, device=self.device),
truncated=torch.zeros(
*self.batch_size, 1, dtype=torch.bool, device=self.device
),
terminated=torch.zeros(
*self.batch_size, 1, dtype=torch.bool, device=self.device
),
)

def _set_seed(self, seed: Optional[int]):
...
103 changes: 83 additions & 20 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
CountingEnvCountPolicy,
DiscreteActionConvMockEnv,
DiscreteActionConvMockEnvNumpy,
EnvWithScalarAction,
IncrementingEnv,
MockBatchedLockedEnv,
MockBatchedUnLockedEnv,
Expand All @@ -66,6 +67,7 @@
CountingEnvCountPolicy,
DiscreteActionConvMockEnv,
DiscreteActionConvMockEnvNumpy,
EnvWithScalarAction,
IncrementingEnv,
MockBatchedLockedEnv,
MockBatchedUnLockedEnv,
Expand Down Expand Up @@ -11781,17 +11783,33 @@ def test_transform_inverse(self):

class TestActionDiscretizer(TransformBase):
@pytest.mark.parametrize("categorical", [True, False])
def test_single_trans_env_check(self, categorical):
base_env = ContinuousActionVecMockEnv()
@pytest.mark.parametrize(
"env_cls",
[
ContinuousActionVecMockEnv,
partial(EnvWithScalarAction, singleton=True),
partial(EnvWithScalarAction, singleton=False),
],
)
def test_single_trans_env_check(self, categorical, env_cls):
base_env = env_cls()
env = base_env.append_transform(
ActionDiscretizer(num_intervals=5, categorical=categorical)
)
check_env_specs(env)

@pytest.mark.parametrize("categorical", [True, False])
def test_serial_trans_env_check(self, categorical):
@pytest.mark.parametrize(
"env_cls",
[
ContinuousActionVecMockEnv,
partial(EnvWithScalarAction, singleton=True),
partial(EnvWithScalarAction, singleton=False),
],
)
def test_serial_trans_env_check(self, categorical, env_cls):
def make_env():
base_env = ContinuousActionVecMockEnv()
base_env = env_cls()
return base_env.append_transform(
ActionDiscretizer(num_intervals=5, categorical=categorical)
)
Expand All @@ -11800,9 +11818,17 @@ def make_env():
check_env_specs(env)

@pytest.mark.parametrize("categorical", [True, False])
def test_parallel_trans_env_check(self, categorical):
@pytest.mark.parametrize(
"env_cls",
[
ContinuousActionVecMockEnv,
partial(EnvWithScalarAction, singleton=True),
partial(EnvWithScalarAction, singleton=False),
],
)
def test_parallel_trans_env_check(self, categorical, env_cls):
def make_env():
base_env = ContinuousActionVecMockEnv()
base_env = env_cls()
env = base_env.append_transform(
ActionDiscretizer(num_intervals=5, categorical=categorical)
)
Expand All @@ -11812,17 +11838,33 @@ def make_env():
check_env_specs(env)

@pytest.mark.parametrize("categorical", [True, False])
def test_trans_serial_env_check(self, categorical):
env = SerialEnv(2, ContinuousActionVecMockEnv).append_transform(
@pytest.mark.parametrize(
"env_cls",
[
ContinuousActionVecMockEnv,
partial(EnvWithScalarAction, singleton=True),
partial(EnvWithScalarAction, singleton=False),
],
)
def test_trans_serial_env_check(self, categorical, env_cls):
env = SerialEnv(2, env_cls).append_transform(
ActionDiscretizer(num_intervals=5, categorical=categorical)
)
check_env_specs(env)

@pytest.mark.parametrize("categorical", [True, False])
def test_trans_parallel_env_check(self, categorical):
env = ParallelEnv(
2, ContinuousActionVecMockEnv, mp_start_method=mp_ctx
).append_transform(ActionDiscretizer(num_intervals=5, categorical=categorical))
@pytest.mark.parametrize(
"env_cls",
[
ContinuousActionVecMockEnv,
partial(EnvWithScalarAction, singleton=True),
partial(EnvWithScalarAction, singleton=False),
],
)
def test_trans_parallel_env_check(self, categorical, env_cls):
env = ParallelEnv(2, env_cls, mp_start_method=mp_ctx).append_transform(
ActionDiscretizer(num_intervals=5, categorical=categorical)
)
check_env_specs(env)

def test_transform_no_env(self):
Expand All @@ -11838,7 +11880,6 @@ def test_transform_compose(self):
check_env_specs(env)

@pytest.mark.skipif(not _has_gym, reason="gym required for this test")
@pytest.mark.parametrize("envname", ["cheetah", "pendulum"])
@pytest.mark.parametrize("interval_as_tensor", [False, True])
@pytest.mark.parametrize("categorical", [True, False])
@pytest.mark.parametrize(
Expand All @@ -11851,15 +11892,37 @@ def test_transform_compose(self):
ActionDiscretizer.SamplingStrategy.RANDOM,
],
)
def test_transform_env(self, envname, interval_as_tensor, categorical, sampling):
@pytest.mark.parametrize(
"env_cls",
[
"cheetah",
"pendulum",
partial(EnvWithScalarAction, singleton=True),
partial(EnvWithScalarAction, singleton=False),
],
)
def test_transform_env(self, env_cls, interval_as_tensor, categorical, sampling):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_env = GymEnv(
HALFCHEETAH_VERSIONED() if envname == "cheetah" else PENDULUM_VERSIONED(),
device=device,
)
if interval_as_tensor:
num_intervals = torch.arange(5, 11 if envname == "cheetah" else 6)
if env_cls == "cheetah":
base_env = GymEnv(
HALFCHEETAH_VERSIONED(),
device=device,
)
num_intervals = torch.arange(5, 11)
elif env_cls == "pendulum":
base_env = GymEnv(
PENDULUM_VERSIONED(),
device=device,
)
num_intervals = torch.arange(5, 6)
else:
base_env = env_cls(
device=device,
)
num_intervals = torch.arange(5, 6)

if not interval_as_tensor:
# override
num_intervals = 5
t = ActionDiscretizer(
num_intervals=num_intervals,
Expand Down
77 changes: 46 additions & 31 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8594,17 +8594,23 @@ def transform_input_spec(self, input_spec):
n_act = action_spec.shape
if not n_act:
n_act = ()
empty_shape = True
else:
n_act = (n_act[-1],)
empty_shape = False
self.n_act = n_act

self.dtype = action_spec.dtype
interval = (action_spec.high - action_spec.low)
if action_spec.ndimension():
interval = interval.unsqueeze(-1)
interval = action_spec.high - action_spec.low

num_intervals = self.num_intervals

if not empty_shape:
interval = interval.unsqueeze(-1)
elif isinstance(num_intervals, torch.Tensor):
num_intervals = int(num_intervals.squeeze())
self.num_intervals = torch.as_tensor(num_intervals)

def custom_arange(nint):
result = torch.arange(
start=0.0,
Expand All @@ -8627,14 +8633,13 @@ def custom_arange(nint):

if isinstance(num_intervals, int):
arange = (
custom_arange(num_intervals).expand((*n_act, num_intervals)) * interval
custom_arange(num_intervals).expand((*n_act, num_intervals))
* interval
)
low = action_spec.low
if action_spec.ndimension():
if not empty_shape:
low = low.unsqueeze(-1)
self.register_buffer(
"intervals", low + arange
)
self.register_buffer("intervals", low + arange)
else:
arange = [
custom_arange(_num_intervals) * interval
Expand All @@ -8649,15 +8654,17 @@ def custom_arange(nint):
)
]


if not isinstance(num_intervals, torch.Tensor):
nvec = torch.as_tensor(num_intervals, device=action_spec.device)
else:
nvec = num_intervals
if nvec.ndim > 1:
raise RuntimeError(f"Cannot use num_intervals with shape {nvec.shape}")
if (nvec.ndim == 0 or nvec.numel() == 1) and action_spec.shape:
nvec = nvec.expand(action_spec.shape[-1])
if nvec.ndim == 0 or nvec.numel() == 1:
if not empty_shape:
nvec = nvec.expand(action_spec.shape[-1])
else:
nvec = nvec.squeeze()
self.register_buffer("nvec", nvec)
if self.sampling == self.SamplingStrategy.RANDOM:
# compute jitters
Expand All @@ -8668,7 +8675,7 @@ def custom_arange(nint):
else (*action_spec.shape[:-1], nvec.sum())
)

if action_spec.shape:
if not empty_shape:
cls = (
functools.partial(MultiCategorical, remove_singleton=False)
if self.categorical
Expand All @@ -8677,13 +8684,12 @@ def custom_arange(nint):
action_spec = cls(nvec=nvec, shape=shape, device=action_spec.device)

else:
cls = (
Categorical
if self.categorical
else OneHot
)
action_spec = cls(n=nvec, shape=shape, device=action_spec.device)
cls = Categorical if self.categorical else OneHot
action_spec = cls(n=int(nvec), shape=shape, device=action_spec.device)

batch_size = self.parent.batch_size
if batch_size:
action_spec = action_spec.expand(batch_size + action_spec.shape)
input_spec["full_action_spec", self.out_keys_inv[0]] = action_spec

if self.out_keys_inv[0] != self.in_keys_inv[0]:
Expand Down Expand Up @@ -8721,8 +8727,8 @@ def _inv_call(self, tensordict):
if self.categorical:
action = action.unsqueeze(-1)
if isinstance(intervals, torch.Tensor):
print('action', action, action.shape)
print('intervals', intervals, intervals.shape)
shape = action.shape[: -intervals.ndim]
intervals = intervals.expand(shape + intervals.shape)
action = intervals.gather(index=action, dim=-1).squeeze(-1)
else:
action = torch.stack(
Expand All @@ -8733,17 +8739,26 @@ def _inv_call(self, tensordict):
-1,
)
else:
nvec = self.nvec.tolist()
action = action.split(nvec, dim=-1)
if isinstance(intervals, torch.Tensor):
intervals = intervals.unbind(-2)
action = torch.stack(
[
intervals[action].view(action.shape[:-1])
for (intervals, action) in zip(intervals, action)
],
-1,
)
nvec = self.nvec
empty_shape = not nvec.ndim
if not empty_shape:
nvec = nvec.tolist()
if isinstance(intervals, torch.Tensor):
shape = action.shape[: (-intervals.ndim + 1)]
intervals = intervals.expand(shape + intervals.shape)
intervals = intervals.unbind(-2)
action = action.split(nvec, dim=-1)
action = torch.stack(
[
intervals[action].view(action.shape[:-1])
for (intervals, action) in zip(intervals, action)
],
-1,
)
else:
shape = action.shape[: -intervals.ndim]
intervals = intervals.expand(shape + intervals.shape)
action = intervals[action].squeeze(-1)

if self.sampling == self.SamplingStrategy.RANDOM:
action = action + self.jitters * torch.rand_like(self.jitters)
Expand Down

0 comments on commit a409233

Please sign in to comment.