Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 15, 2024
2 parents 3c8f8c3 + 8e78398 commit 98d7450
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions torchrl/objectives/iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,7 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
state_action_value = td_q.get(self.tensor_keys.state_action_value)
action = tensordict.get(self.tensor_keys.action)
if self.action_space == "categorical":
if action.shape != state_action_value.shape:
if action.ndim < state_action_value.ndim:
# unsqueeze the action if it lacks on trailing singleton dim
action = action.unsqueeze(-1)
chosen_state_action_value = torch.gather(
Expand Down Expand Up @@ -830,7 +830,7 @@ def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
state_action_value = td_q.get(self.tensor_keys.state_action_value)
action = tensordict.get(self.tensor_keys.action)
if self.action_space == "categorical":
if action.shape != state_action_value.shape:
if action.ndim < state_action_value.ndim:
# unsqueeze the action if it lacks on trailing singleton dim
action = action.unsqueeze(-1)
chosen_state_action_value = torch.gather(
Expand Down Expand Up @@ -867,7 +867,7 @@ def qvalue_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
state_action_value = td_q.get(self.tensor_keys.state_action_value)
action = tensordict.get(self.tensor_keys.action)
if self.action_space == "categorical":
if action.shape != state_action_value.shape:
if action.ndim < state_action_value.ndim:
# unsqueeze the action if it lacks on trailing singleton dim
action = action.unsqueeze(-1)
pred_val = torch.gather(state_action_value, -1, index=action).squeeze(-1)
Expand Down

0 comments on commit 98d7450

Please sign in to comment.