Skip to content

Commit

Permalink
Allow to store tensors/arrays with their original dimensions and make…
Browse files Browse the repository at this point in the history
… it the default option
  • Loading branch information
Toni-SM committed Sep 10, 2024
1 parent 33e7c01 commit 6907e58
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
12 changes: 7 additions & 5 deletions skrl/memories/jax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ def _get_space_size(self,
raise ValueError(f"Space type {type(space)} not supported")

def _get_tensors_view(self, name):
if self.tensors_keep_dimensions[name]:
return self.tensors_view[name] if self._views else self.tensors[name].reshape(-1, *self.tensors_keep_dimensions[name])
return self.tensors_view[name] if self._views else self.tensors[name].reshape(-1, self.tensors[name].shape[-1])

def share_memory(self) -> None:
Expand Down Expand Up @@ -202,7 +204,7 @@ def create_tensor(self,
name: str,
size: Union[int, Tuple[int], gym.Space, gymnasium.Space],
dtype: Optional[np.dtype] = None,
keep_dimensions: bool = False) -> bool:
keep_dimensions: bool = True) -> bool:
"""Create a new internal tensor in memory
The tensor will have a 3-components shape (memory size, number of environments, size).
Expand Down Expand Up @@ -247,7 +249,7 @@ def create_tensor(self,
self.tensors[name] = getattr(self, f"_tensor_{name}")
with jax.default_device(self.device):
self.tensors_view[name] = self.tensors[name].reshape(*view_shape)
self.tensors_keep_dimensions[name] = keep_dimensions
self.tensors_keep_dimensions[name] = size if keep_dimensions else None
# fill the tensors (float tensors) with NaN
for name, tensor in self.tensors.items():
if tensor.dtype == np.float32 or tensor.dtype == np.float64:
Expand Down Expand Up @@ -309,7 +311,7 @@ def add_samples(self, **tensors: Mapping[str, Union[np.ndarray, jax.Array]]) ->
dim, shape = tmp.ndim, tmp.shape

# multi environment (number of environments equals num_envs)
if dim == 2 and shape[0] == self.num_envs:
if dim > 1 and shape[0] == self.num_envs:
if self._jax:
for name, tensor in tensors.items():
if name in self.tensors:
Expand All @@ -320,14 +322,14 @@ def add_samples(self, **tensors: Mapping[str, Union[np.ndarray, jax.Array]]) ->
self.tensors[name][self.memory_index] = tensor
self.memory_index += 1
# multi environment (number of environments less than num_envs)
elif dim == 2 and shape[0] < self.num_envs:
elif dim > 1 and shape[0] < self.num_envs:
raise NotImplementedError # TODO:
for name, tensor in tensors.items():
if name in self.tensors:
self.tensors[name] = self.tensors[name].at[self.memory_index, self.env_index:self.env_index + tensor.shape[0]].set(tensor)
self.env_index += tensor.shape[0]
# single environment - multi sample (number of environments greater than num_envs (num_envs = 1))
elif dim == 2 and self.num_envs == 1:
elif dim > 1 and self.num_envs == 1:
raise NotImplementedError # TODO:
for name, tensor in tensors.items():
if name in self.tensors:
Expand Down
8 changes: 4 additions & 4 deletions skrl/memories/torch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def create_tensor(self,
name: str,
size: Union[int, Tuple[int], gym.Space, gymnasium.Space],
dtype: Optional[torch.dtype] = None,
keep_dimensions: bool = False) -> bool:
keep_dimensions: bool = True) -> bool:
"""Create a new internal tensor in memory
The tensor will have a 3-components shape (memory size, number of environments, size).
Expand Down Expand Up @@ -264,19 +264,19 @@ def add_samples(self, **tensors: torch.Tensor) -> None:
dim, shape = tmp.ndim, tmp.shape

# multi environment (number of environments equals num_envs)
if dim == 2 and shape[0] == self.num_envs:
if dim > 1 and shape[0] == self.num_envs:
for name, tensor in tensors.items():
if name in self.tensors:
self.tensors[name][self.memory_index].copy_(tensor)
self.memory_index += 1
# multi environment (number of environments less than num_envs)
elif dim == 2 and shape[0] < self.num_envs:
elif dim > 1 and shape[0] < self.num_envs:
for name, tensor in tensors.items():
if name in self.tensors:
self.tensors[name][self.memory_index, self.env_index:self.env_index + tensor.shape[0]].copy_(tensor)
self.env_index += tensor.shape[0]
# single environment - multi sample (number of environments greater than num_envs (num_envs = 1))
elif dim == 2 and self.num_envs == 1:
elif dim > 1 and self.num_envs == 1:
for name, tensor in tensors.items():
if name in self.tensors:
num_samples = min(shape[0], self.memory_size - self.memory_index)
Expand Down

0 comments on commit 6907e58

Please sign in to comment.