diff --git a/sota-implementations/iql/utils.py b/sota-implementations/iql/utils.py
index 68acb9bfed4..168416d80da 100644
--- a/sota-implementations/iql/utils.py
+++ b/sota-implementations/iql/utils.py
@@ -388,6 +388,7 @@ def make_discrete_loss(loss_cfg, model):
         loss_function=loss_cfg.loss_function,
         temperature=loss_cfg.temperature,
         expectile=loss_cfg.expectile,
+        action_space="categorical",
     )
     loss_module.make_value_estimator(gamma=loss_cfg.gamma)
     target_net_updater = HardUpdate(
diff --git a/torchrl/data/utils.py b/torchrl/data/utils.py
index db2c8afca10..d43cbd7810d 100644
--- a/torchrl/data/utils.py
+++ b/torchrl/data/utils.py
@@ -307,7 +307,7 @@ def _process_action_space_spec(action_space, spec):
     return action_space, spec
 
 
-def _find_action_space(action_space):
+def _find_action_space(action_space) -> str:
     if isinstance(action_space, TensorSpec):
         if isinstance(action_space, Composite):
             if "action" in action_space.keys():
diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py
index 71d1a22e17b..48667728071 100644
--- a/torchrl/objectives/iql.py
+++ b/torchrl/objectives/iql.py
@@ -782,7 +782,7 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
         # Min Q value
         td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
         td_q = self._vmap_qvalue_networkN0(td_q, self.target_qvalue_network_params)
-        state_action_value = td_q.get(self.tensor_keys.chosen_state_action_value)
+        state_action_value = td_q.get(self.tensor_keys.state_action_value)
         action = tensordict.get(self.tensor_keys.action)
         if self.action_space == "categorical":
             if action.shape != state_action_value.shape:
@@ -791,9 +791,11 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
             chosen_state_action_value = torch.gather(
                 state_action_value, -1, index=action
             ).squeeze(-1)
-        else:
+        elif self.action_space == "one_hot":
             action = action.to(torch.float)
             chosen_state_action_value = (state_action_value * action).sum(-1)
+        else:
+            raise RuntimeError(f"Unknown action space {self.action_space}.")
         min_Q, _ = torch.min(chosen_state_action_value, dim=0)
         if log_prob.shape != min_Q.shape:
             raise RuntimeError(
@@ -834,9 +836,11 @@ def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
             chosen_state_action_value = torch.gather(
                 state_action_value, -1, index=action
             ).squeeze(-1)
-        else:
+        elif self.action_space == "one_hot":
             action = action.to(torch.float)
             chosen_state_action_value = (state_action_value * action).sum(-1)
+        else:
+            raise RuntimeError(f"Unknown action space {self.action_space}.")
         min_Q, _ = torch.min(chosen_state_action_value, dim=0)
         # state value
         td_copy = tensordict.select(*self.value_network.in_keys, strict=False)
@@ -867,9 +871,11 @@ def qvalue_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
                 # unsqueeze the action if it lacks on trailing singleton dim
                 action = action.unsqueeze(-1)
             pred_val = torch.gather(state_action_value, -1, index=action).squeeze(-1)
-        else:
+        elif self.action_space == "one_hot":
             action = action.to(torch.float)
             pred_val = (state_action_value * action).sum(-1)
+        else:
+            raise RuntimeError(f"Unknown action space {self.action_space}.")
         td_error = (pred_val - target_value.expand_as(pred_val)).pow(2)
         loss_qval = distance_loss(
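
Note on the branches changed above: in every loss term the module needs the Q-value of the action that was actually taken. With action_space="categorical" the stored action is an integer index and the value is picked out with torch.gather; with action_space="one_hot" the action is a one-hot vector and the same value is obtained by masking and summing over the action dimension. The action_space="categorical" argument added in sota-implementations/iql/utils.py selects the gather path. The snippet below is a minimal standalone sketch of that equivalence; the tensor names and shapes are illustrative assumptions, not the tensordict keys used by DiscreteIQLLoss.

# Minimal sketch (not part of the patch): selecting the chosen Q-value under the
# two discrete action-space conventions handled by the new elif/else branches.
import torch

batch, n_actions = 4, 3
state_action_value = torch.randn(batch, n_actions)  # Q(s, a) for every action
action_idx = torch.randint(n_actions, (batch,))     # "categorical" actions

# "categorical": gather along the action dimension using integer indices
chosen_categorical = torch.gather(
    state_action_value, -1, action_idx.unsqueeze(-1)
).squeeze(-1)

# "one_hot": mask with the one-hot action and sum over the action dimension
action_one_hot = torch.nn.functional.one_hot(action_idx, n_actions).to(torch.float)
chosen_one_hot = (state_action_value * action_one_hot).sum(-1)

assert torch.allclose(chosen_categorical, chosen_one_hot)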