Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jun 5, 2024
1 parent e5c3e32 commit b320ed8
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion torchrl/envs/libs/jax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,11 @@ def _tensordict_to_object(tensordict: TensorDictBase, object_example):
else:
if value.dtype is torch.bool:
value = value.to(torch.uint8)
value = jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(value.contiguous()))
shape = value.shape
# We need to flatten to fix https://github.com/pytorch/rl/issues/2184
value = value.contiguous()
value = jax_dlpack.from_dlpack(value.detach().flatten())
value = value.reshape(shape)
t[name] = value.reshape(example.shape).view(example.dtype)
return type(object_example)(**t)

Expand All @@ -149,3 +153,4 @@ def _extract_spec(data: Union[torch.Tensor, TensorDictBase], key=None) -> Tensor
)
else:
raise TypeError(f"Unsupported data type {type(data)}")

0 comments on commit b320ed8

Please sign in to comment.