
Commit 3cd393e

Update

[ghstack-poisoned]

vmoens committed Dec 15, 2024
2 parents 0634dfd + 1f05619
Showing 3 changed files with 12 additions and 5 deletions.
1 change: 1 addition & 0 deletions sota-implementations/iql/utils.py
@@ -388,6 +388,7 @@ def make_discrete_loss(loss_cfg, model):
         loss_function=loss_cfg.loss_function,
         temperature=loss_cfg.temperature,
         expectile=loss_cfg.expectile,
+        action_space="categorical",
     )
     loss_module.make_value_estimator(gamma=loss_cfg.gamma)
     target_net_updater = HardUpdate(
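For context only (not part of the commit): action_space="categorical" declares that the replay data stores actions as integer indices rather than one-hot vectors. A minimal sketch of the two encodings, assuming a hypothetical 4-action discrete space:

import torch
import torch.nn.functional as F

# "categorical": actions stored as integer indices, shape [batch]
categorical_actions = torch.tensor([2, 0, 3, 1])

# "one_hot": the same actions stored as one-hot vectors, shape [batch, num_actions]
one_hot_actions = F.one_hot(categorical_actions, num_classes=4).to(torch.float)

print(categorical_actions.shape)  # torch.Size([4])
print(one_hot_actions.shape)      # torch.Size([4, 4])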
2 changes: 1 addition & 1 deletion torchrl/data/utils.py
@@ -307,7 +307,7 @@ def _process_action_space_spec(action_space, spec):
     return action_space, spec


-def _find_action_space(action_space):
+def _find_action_space(action_space) -> str:
     if isinstance(action_space, TensorSpec):
         if isinstance(action_space, Composite):
             if "action" in action_space.keys():
14 changes: 10 additions & 4 deletions torchrl/objectives/iql.py
@@ -782,7 +782,7 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
         # Min Q value
         td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
         td_q = self._vmap_qvalue_networkN0(td_q, self.target_qvalue_network_params)
-        state_action_value = td_q.get(self.tensor_keys.state_action_value)
+        state_action_value = td_q.get(self.tensor_keys.chosen_state_action_value)
         action = tensordict.get(self.tensor_keys.action)
         if self.action_space == "categorical":
             if action.shape != state_action_value.shape:
@@ -791,9 +791,11 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
             chosen_state_action_value = torch.gather(
                 state_action_value, -1, index=action
             ).squeeze(-1)
-        else:
+        elif self.action_space == "one_hot":
             action = action.to(torch.float)
             chosen_state_action_value = (state_action_value * action).sum(-1)
+        else:
+            raise RuntimeError(f"Unknown action space {self.action_space}.")
         min_Q, _ = torch.min(chosen_state_action_value, dim=0)
         if log_prob.shape != min_Q.shape:
             raise RuntimeError(
@@ -834,9 +836,11 @@ def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
             chosen_state_action_value = torch.gather(
                 state_action_value, -1, index=action
             ).squeeze(-1)
-        else:
+        elif self.action_space == "one_hot":
             action = action.to(torch.float)
             chosen_state_action_value = (state_action_value * action).sum(-1)
+        else:
+            raise RuntimeError(f"Unknown action space {self.action_space}.")
         min_Q, _ = torch.min(chosen_state_action_value, dim=0)
         # state value
         td_copy = tensordict.select(*self.value_network.in_keys, strict=False)
@@ -867,9 +871,11 @@ def qvalue_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
                 # unsqueeze the action if it lacks on trailing singleton dim
                 action = action.unsqueeze(-1)
             pred_val = torch.gather(state_action_value, -1, index=action).squeeze(-1)
-        else:
+        elif self.action_space == "one_hot":
             action = action.to(torch.float)
             pred_val = (state_action_value * action).sum(-1)
+        else:
+            raise RuntimeError(f"Unknown action space {self.action_space}.")

         td_error = (pred_val - target_value.expand_as(pred_val)).pow(2)
         loss_qval = distance_loss(
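For illustration only (not part of the diff): a rough sketch of what the categorical and one_hot branches above compute, assuming toy shapes for an ensemble of Q-networks. Both routes select the Q-value of the taken action before the min over the ensemble, so they agree numerically.

import torch
import torch.nn.functional as F

num_nets, batch, num_actions = 2, 5, 4
state_action_value = torch.randn(num_nets, batch, num_actions)
action_idx = torch.randint(num_actions, (batch,))

# categorical branch: gather the Q-value at the integer action index
idx = action_idx.expand(num_nets, batch).unsqueeze(-1)  # [num_nets, batch, 1]
chosen_cat = torch.gather(state_action_value, -1, index=idx).squeeze(-1)

# one_hot branch: mask with the one-hot action and sum over the action dim
action_one_hot = F.one_hot(action_idx, num_actions).to(torch.float)
chosen_oh = (state_action_value * action_one_hot).sum(-1)

assert torch.allclose(chosen_cat, chosen_oh)

# min over the Q-network ensemble, as in min_Q above
min_Q, _ = torch.min(chosen_cat, dim=0)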
