Skip to content

Commit

Permalink
Replace view by reshape to support unclear operations
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Sep 30, 2024
1 parent fe57d20 commit f30e9f7
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions skrl/utils/spaces/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,29 +33,29 @@ def tensorize_space(space: spaces.Space, x: Any, device: Optional[Union[str, tor
# Box
if isinstance(space, spaces.Box):
if isinstance(x, torch.Tensor):
return x.view(-1, *space.shape)
return x.reshape(-1, *space.shape)
elif isinstance(x, np.ndarray):
return torch.tensor(x, device=device, dtype=torch.float32).view(-1, *space.shape)
return torch.tensor(x, device=device, dtype=torch.float32).reshape(-1, *space.shape)
else:
raise ValueError(f"Unsupported type ({type(x)}) for the given space ({space})")
# Discrete
elif isinstance(space, spaces.Discrete):
if isinstance(x, torch.Tensor):
return x.view(-1, 1)
return x.reshape(-1, 1)
elif isinstance(x, np.ndarray):
return torch.tensor(x, device=device, dtype=torch.int32).view(-1, 1)
return torch.tensor(x, device=device, dtype=torch.int32).reshape(-1, 1)
elif isinstance(x, np.number) or type(x) in [int, float]:
return torch.tensor([x], device=device, dtype=torch.int32).view(-1, 1)
return torch.tensor([x], device=device, dtype=torch.int32).reshape(-1, 1)
else:
raise ValueError(f"Unsupported type ({type(x)}) for the given space ({space})")
# MultiDiscrete
elif isinstance(space, spaces.MultiDiscrete):
if isinstance(x, torch.Tensor):
return x.view(-1, *space.shape)
return x.reshape(-1, *space.shape)
elif isinstance(x, np.ndarray):
return torch.tensor(x, device=device, dtype=torch.int32).view(-1, *space.shape)
return torch.tensor(x, device=device, dtype=torch.int32).reshape(-1, *space.shape)
elif type(x) in [list, tuple]:
return torch.tensor([x], device=device, dtype=torch.int32).view(-1, *space.shape)
return torch.tensor([x], device=device, dtype=torch.int32).reshape(-1, *space.shape)
else:
raise ValueError(f"Unsupported type ({type(x)}) for the given space ({space})")
# composite spaces
Expand All @@ -79,7 +79,7 @@ def flatten_tensorized_space(x: Any) -> torch.Tensor:
# fundamental spaces
# Box / Discrete / MultiDiscrete
if isinstance(x, torch.Tensor):
return x.view(x.shape[0], -1) if x.ndim > 1 else x.view(1, -1)
return x.reshape(x.shape[0], -1) if x.ndim > 1 else x.reshape(1, -1)
# composite spaces
# Dict
elif isinstance(x, dict):
Expand All @@ -102,13 +102,13 @@ def unflatten_tensorized_space(space: Union[spaces.Space, Sequence[int], int], x
# fundamental spaces
# Box
if isinstance(space, spaces.Box):
return x.view(-1, *space.shape)
return x.reshape(-1, *space.shape)
# Discrete
elif isinstance(space, spaces.Discrete):
return x.view(-1, 1)
return x.reshape(-1, 1)
# MultiDiscrete
elif isinstance(space, spaces.MultiDiscrete):
return x.view(-1, *space.shape)
return x.reshape(-1, *space.shape)
# composite spaces
# Dict
elif isinstance(space, spaces.Dict):
Expand Down

0 comments on commit f30e9f7

Please sign in to comment.