From 2d56cff1890bf4bd65dcadbd307e53e420896f72 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 11:59:16 -0800 Subject: [PATCH] Update (base update) [ghstack-poisoned] --- torchrl/objectives/iql.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index 9b83995b1cd..26cd8e2e89a 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -785,7 +785,7 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: state_action_value = td_q.get(self.tensor_keys.state_action_value) action = tensordict.get(self.tensor_keys.action) if self.action_space == "categorical": - if action.shape != state_action_value.shape: + if action.ndim < state_action_value.ndim: # unsqueeze the action if it lacks on trailing singleton dim action = action.unsqueeze(-1) chosen_state_action_value = torch.gather( @@ -830,7 +830,7 @@ def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: state_action_value = td_q.get(self.tensor_keys.state_action_value) action = tensordict.get(self.tensor_keys.action) if self.action_space == "categorical": - if action.shape != state_action_value.shape: + if action.ndim < state_action_value.ndim: # unsqueeze the action if it lacks on trailing singleton dim action = action.unsqueeze(-1) chosen_state_action_value = torch.gather( @@ -867,7 +867,7 @@ def qvalue_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: state_action_value = td_q.get(self.tensor_keys.state_action_value) action = tensordict.get(self.tensor_keys.action) if self.action_space == "categorical": - if action.shape != state_action_value.shape: + if action.ndim < state_action_value.ndim: # unsqueeze the action if it lacks on trailing singleton dim action = action.unsqueeze(-1) pred_val = torch.gather(state_action_value, -1, index=action).squeeze(-1)