From 56cfabaa417a70a43866326fe4d5e13983d6a222 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 11:57:24 -0800 Subject: [PATCH] Update (base update) [ghstack-poisoned] --- torchrl/objectives/iql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index 48667728071..9b83995b1cd 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -782,7 +782,7 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: # Min Q value td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False) td_q = self._vmap_qvalue_networkN0(td_q, self.target_qvalue_network_params) - state_action_value = td_q.get(self.tensor_keys.chosen_state_action_value) + state_action_value = td_q.get(self.tensor_keys.state_action_value) action = tensordict.get(self.tensor_keys.action) if self.action_space == "categorical": if action.shape != state_action_value.shape: