Skip to content

Commit

Permalink
[BugFix] better device consistency in EGreedy (#1867)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Feb 4, 2024
1 parent 0672359 commit 5f82601
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions torchrl/modules/tensordict_module/exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:

out = action_tensordict.get(action_key)
eps = self.eps.item()
cond = (
torch.rand(action_tensordict.shape, device=action_tensordict.device)
< eps
).to(out.dtype)
cond = torch.rand(action_tensordict.shape, device=out.device) < eps
cond = expand_as_right(cond, out)
spec = self.spec
if spec is not None:
Expand All @@ -177,7 +174,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
f"Action mask key {self.action_mask_key} not found in {tensordict}."
)
spec.update_mask(action_mask)
out = cond * spec.rand().to(out.device) + (1 - cond) * out
out = torch.where(cond, spec.rand().to(out.device), out)
else:
raise RuntimeError("spec must be provided to the exploration wrapper.")
action_tensordict.set(action_key, out)
Expand Down

0 comments on commit 5f82601

Please sign in to comment.