[BugFix] ActionDiscretizer scalar integration
ghstack-source-id: 1fea2492610f8983b849f970cda7414181afad96
Pull Request resolved: #2619
vmoens committed Nov 29, 2024
1 parent d537dcb commit a4d8477
Showing 1 changed file with 32 additions and 14 deletions.

torchrl/envs/transforms/transforms.py: 46 changes (32 additions & 14 deletions)
@@ -8585,21 +8585,23 @@ def _indent(s):

     def transform_input_spec(self, input_spec):
         try:
-            action_spec = input_spec["full_action_spec", self.in_keys_inv[0]]
+            action_spec = self.parent.full_action_spec_unbatched[self.in_keys_inv[0]]
             if not isinstance(action_spec, Bounded):
                 raise TypeError(
-                    f"action spec type {type(action_spec)} is not supported."
+                    f"action spec type {type(action_spec)} is not supported. The action spec type must be Bounded."
                 )

             n_act = action_spec.shape
             if not n_act:
-                n_act = 1
+                n_act = ()
             else:
-                n_act = n_act[-1]
+                n_act = (n_act[-1],)
             self.n_act = n_act

             self.dtype = action_spec.dtype
-            interval = (action_spec.high - action_spec.low).unsqueeze(-1)
+            interval = (action_spec.high - action_spec.low)
+            if action_spec.ndimension():
+                interval = interval.unsqueeze(-1)

             num_intervals = self.num_intervals
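The guarded unsqueeze is the heart of the scalar fix: for a 0-dim Bounded spec, high - low is itself 0-dim, and the old unconditional unsqueeze(-1) manufactured a spurious (1,) dimension. A minimal shape sketch (plain PyTorch, independent of the transform; the bounds and num_intervals values are made up):

    import torch

    num_intervals = 4
    arange = torch.linspace(0, 1, num_intervals)  # stand-in for custom_arange

    # Vector spec, shape (2,): the span needs a trailing singleton dim so it
    # broadcasts against the (n_act, num_intervals) grid.
    low, high = torch.tensor([-1.0, 0.0]), torch.tensor([1.0, 2.0])
    interval = (high - low).unsqueeze(-1)                  # (2, 1)
    grid = low.unsqueeze(-1) + arange.expand(2, num_intervals) * interval
    assert grid.shape == (2, num_intervals)

    # Scalar spec, shape (): the span is 0-dim and must stay 0-dim, so the
    # grid comes out as (num_intervals,), which is what the guarded branch
    # now produces.
    low_s, high_s = torch.tensor(-1.0), torch.tensor(1.0)
    interval_s = high_s - low_s                            # shape ()
    grid_s = low_s + arange.expand((num_intervals,)) * interval_s
    assert grid_s.shape == (num_intervals,)

The same zero-dim guard is applied to low in the hunk below.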

@@ -8625,10 +8627,13 @@ def custom_arange(nint):

             if isinstance(num_intervals, int):
                 arange = (
-                    custom_arange(num_intervals).expand(n_act, num_intervals) * interval
+                    custom_arange(num_intervals).expand((*n_act, num_intervals)) * interval
                 )
+                low = action_spec.low
+                if action_spec.ndimension():
+                    low = low.unsqueeze(-1)
                 self.register_buffer(
-                    "intervals", action_spec.low.unsqueeze(-1) + arange
+                    "intervals", low + arange
                 )
             else:
                 arange = [
@@ -8644,19 +8649,14 @@ def custom_arange(nint):
                     )
                 ]

-            cls = (
-                functools.partial(MultiCategorical, remove_singleton=False)
-                if self.categorical
-                else MultiOneHot
-            )
-
             if not isinstance(num_intervals, torch.Tensor):
                 nvec = torch.as_tensor(num_intervals, device=action_spec.device)
             else:
                 nvec = num_intervals
             if nvec.ndim > 1:
                 raise RuntimeError(f"Cannot use num_intervals with shape {nvec.shape}")
-            if nvec.ndim == 0 or nvec.numel() == 1:
+            if (nvec.ndim == 0 or nvec.numel() == 1) and action_spec.shape:
                 nvec = nvec.expand(action_spec.shape[-1])
             self.register_buffer("nvec", nvec)
             if self.sampling == self.SamplingStrategy.RANDOM:
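The added `and action_spec.shape` condition avoids indexing into an empty shape: for a scalar spec, `action_spec.shape[-1]` raises IndexError. A quick stand-alone illustration (plain PyTorch; the nvec value is made up):

    import torch

    shape = torch.Size([])   # a scalar spec has an empty shape
    nvec = torch.as_tensor(10)

    # Without the `and shape` guard, shape[-1] raises IndexError for a
    # scalar spec; with it, a 0-dim nvec is simply kept as-is.
    if (nvec.ndim == 0 or nvec.numel() == 1) and shape:
        nvec = nvec.expand(shape[-1])
    assert nvec.shape == ()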
@@ -8667,7 +8667,23 @@ def custom_arange(nint):
                 if self.categorical
                 else (*action_spec.shape[:-1], nvec.sum())
             )
-            action_spec = cls(nvec=nvec, shape=shape, device=action_spec.device)
+
+            if action_spec.shape:
+                cls = (
+                    functools.partial(MultiCategorical, remove_singleton=False)
+                    if self.categorical
+                    else MultiOneHot
+                )
+                action_spec = cls(nvec=nvec, shape=shape, device=action_spec.device)
+
+            else:
+                cls = (
+                    Categorical
+                    if self.categorical
+                    else OneHot
+                )
+                action_spec = cls(n=nvec, shape=shape, device=action_spec.device)

             input_spec["full_action_spec", self.out_keys_inv[0]] = action_spec

             if self.out_keys_inv[0] != self.in_keys_inv[0]:
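For a shaped spec the transform still builds a MultiCategorical/MultiOneHot, but a scalar spec now gets a plain Categorical/OneHot, so the discretized action keeps an empty batch shape. A hedged sketch of the two outcomes using TorchRL's spec classes (assuming torchrl is importable; the nvec value is arbitrary):

    import torch
    from torchrl.data import Categorical, MultiCategorical

    nvec = torch.as_tensor(5)

    # Shaped spec, e.g. 3 action dims: one categorical entry per dim.
    multi = MultiCategorical(nvec=nvec.expand(3), shape=(3,))
    print(multi.rand())   # e.g. tensor([4, 0, 2])

    # Scalar spec: a single Categorical with an empty shape.
    single = Categorical(n=5, shape=())
    print(single.rand())  # e.g. tensor(3)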
@@ -8705,6 +8721,8 @@ def _inv_call(self, tensordict):
         if self.categorical:
             action = action.unsqueeze(-1)
         if isinstance(intervals, torch.Tensor):
+            print('action', action, action.shape)
+            print('intervals', intervals, intervals.shape)
             action = intervals.gather(index=action, dim=-1).squeeze(-1)
         else:
             action = torch.stack(
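The gather call above is the inverse mapping at work: a categorical action indexes into the precomputed intervals grid to recover a continuous action. A stand-alone illustration with made-up grid values (plain PyTorch):

    import torch

    # One row of candidate continuous actions per action dim.
    intervals = torch.linspace(-1.0, 1.0, 4).expand(2, 4)  # (n_act, num_intervals)
    action = torch.tensor([[1], [3]])                      # categorical, unsqueezed
    continuous = intervals.gather(index=action, dim=-1).squeeze(-1)
    print(continuous)  # tensor([-0.3333,  1.0000])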
