Skip to content

Commit

Permalink
Update (base update)
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 15, 2024
1 parent eff8ce1 commit 2a0d9fa
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchrl/objectives/iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
# Min Q value
td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
td_q = self._vmap_qvalue_networkN0(td_q, self.target_qvalue_network_params)
state_action_value = td_q.get(self.tensor_keys.chosen_state_action_value)
state_action_value = td_q.get(self.tensor_keys.state_action_value)
action = tensordict.get(self.tensor_keys.action)
if self.action_space == "categorical":
if action.shape != state_action_value.shape:
Expand Down

0 comments on commit 2a0d9fa

Please sign in to comment.