From 08b992e1117e24dd082ca33d9a046520cba34f3b Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Sun, 15 Dec 2024 12:04:01 -0800
Subject: [PATCH] Update (base update)

[ghstack-poisoned]
---
 torchrl/objectives/iql.py | 33 +++++++++++++++++++++++----------
 1 file changed, 23 insertions(+), 10 deletions(-)

diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py
index 26cd8e2e89a..039d5fc1c34 100644
--- a/torchrl/objectives/iql.py
+++ b/torchrl/objectives/iql.py
@@ -785,12 +785,15 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
         state_action_value = td_q.get(self.tensor_keys.state_action_value)
         action = tensordict.get(self.tensor_keys.action)
         if self.action_space == "categorical":
-            if action.ndim < state_action_value.ndim:
+            if action.ndim < (state_action_value.ndim - (td_q.ndim - tensordict.ndim)):
                 # unsqueeze the action if it lacks a trailing singleton dim
                 action = action.unsqueeze(-1)
-            chosen_state_action_value = torch.gather(
-                state_action_value, -1, index=action
-            ).squeeze(-1)
+            chosen_state_action_value = torch.vmap(
+                lambda state_action_value, action: torch.gather(
+                    state_action_value, -1, index=action
+                ).squeeze(-1),
+                (0, None),
+            )(state_action_value, action)
         elif self.action_space == "one_hot":
             action = action.to(torch.float)
             chosen_state_action_value = (state_action_value * action).sum(-1)
@@ -830,12 +833,17 @@ def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
         state_action_value = td_q.get(self.tensor_keys.state_action_value)
         action = tensordict.get(self.tensor_keys.action)
         if self.action_space == "categorical":
-            if action.ndim < state_action_value.ndim:
+            if action.ndim < (
+                state_action_value.ndim - (td_q.ndim - tensordict.ndim)
+            ):
                 # unsqueeze the action if it lacks a trailing singleton dim
                 action = action.unsqueeze(-1)
-            chosen_state_action_value = torch.gather(
-                state_action_value, -1, index=action
-            ).squeeze(-1)
+            chosen_state_action_value = torch.vmap(
+                lambda state_action_value, action: torch.gather(
+                    state_action_value, -1, index=action
+                ).squeeze(-1),
+                (0, None),
+            )(state_action_value, action)
         elif self.action_space == "one_hot":
             action = action.to(torch.float)
             chosen_state_action_value = (state_action_value * action).sum(-1)
@@ -867,10 +875,15 @@ def qvalue_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
         state_action_value = td_q.get(self.tensor_keys.state_action_value)
         action = tensordict.get(self.tensor_keys.action)
         if self.action_space == "categorical":
-            if action.ndim < state_action_value.ndim:
+            if action.ndim < (state_action_value.ndim - (td_q.ndim - tensordict.ndim)):
                 # unsqueeze the action if it lacks a trailing singleton dim
                 action = action.unsqueeze(-1)
-            pred_val = torch.gather(state_action_value, -1, index=action).squeeze(-1)
+            pred_val = torch.vmap(
+                lambda state_action_value, action: torch.gather(
+                    state_action_value, -1, index=action
+                ).squeeze(-1),
+                (0, None),
+            )(state_action_value, action)
         elif self.action_space == "one_hot":
             action = action.to(torch.float)
             pred_val = (state_action_value * action).sum(-1)
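
Note (editor's sketch, not part of the patch): all three hunks replace a plain torch.gather with the same torch.vmap pattern. The motivation, as the new condition suggests, is that td_q comes out of a vmapped Q-network ensemble and so carries extra leading dims relative to the input tensordict (td_q.ndim - tensordict.ndim of them, typically one) that the sampled action lacks; torch.gather requires its index to have the same number of dims as its input, so the un-expanded action cannot index the ensemble-batched values directly. Mapping over the ensemble dim with in_dims=(0, None) reuses the same action for every ensemble member. The shapes and names below (num_nets, batch, num_actions) are illustrative assumptions, not values from the codebase.

import torch

num_nets, batch, num_actions = 2, 5, 4  # illustrative shapes only
# stand-in for td_q's ensemble output: one extra leading dim vs. the action
state_action_value = torch.randn(num_nets, batch, num_actions)
action = torch.randint(num_actions, (batch,))  # categorical action, no ensemble dim

# the patch's pattern: vmap maps dim 0 of the values and broadcasts the
# same action to every ensemble member (in_dims=(0, None))
chosen = torch.vmap(
    lambda sav, act: torch.gather(sav, -1, index=act.unsqueeze(-1)).squeeze(-1),
    (0, None),
)(state_action_value, action)
assert chosen.shape == (num_nets, batch)

# equivalent without vmap: expand the action across the ensemble dim first
index = action.expand(num_nets, batch).unsqueeze(-1)
assert torch.allclose(
    chosen, torch.gather(state_action_value, -1, index=index).squeeze(-1)
)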