diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index 48667728071..9b83995b1cd 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -782,7 +782,7 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: # Min Q value td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False) td_q = self._vmap_qvalue_networkN0(td_q, self.target_qvalue_network_params) - state_action_value = td_q.get(self.tensor_keys.chosen_state_action_value) + state_action_value = td_q.get(self.tensor_keys.state_action_value) action = tensordict.get(self.tensor_keys.action) if self.action_space == "categorical": if action.shape != state_action_value.shape: