
Commit

Update
[ghstack-poisoned]
vmoens committed Dec 15, 2024
2 parents 98d7450 + 54d9949 commit 67a91d4
Showing 1 changed file with 23 additions and 10 deletions.
33 changes: 23 additions & 10 deletions torchrl/objectives/iql.py
@@ -785,12 +785,15 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
         state_action_value = td_q.get(self.tensor_keys.state_action_value)
         action = tensordict.get(self.tensor_keys.action)
         if self.action_space == "categorical":
-            if action.ndim < state_action_value.ndim:
+            if action.ndim < (state_action_value.ndim - (td_q.ndim - tensordict.ndim)):
                 # unsqueeze the action if it lacks on trailing singleton dim
                 action = action.unsqueeze(-1)
-            chosen_state_action_value = torch.gather(
-                state_action_value, -1, index=action
-            ).squeeze(-1)
+            chosen_state_action_value = torch.vmap(
+                lambda state_action_value, action: torch.gather(
+                    state_action_value, -1, index=action
+                ).squeeze(-1),
+                (0, None),
+            )(state_action_value, action)
         elif self.action_space == "one_hot":
             action = action.to(torch.float)
             chosen_state_action_value = (state_action_value * action).sum(-1)
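
Side note on the revised dimension check above: td_q appears to carry extra leading batch dims relative to tensordict (td_q.ndim - tensordict.ndim of them), so the action is now compared against the remaining per-sample dims of state_action_value rather than its full ndim; the old check could unsqueeze an action that already carried the trailing singleton dim. A minimal, self-contained sketch of the corrected check, with made-up names and shapes (assumptions for illustration, not the torchrl code):

import torch

td_q_ndim, tensordict_ndim = 2, 1            # assumed: one extra leading dim in td_q
state_action_value = torch.randn(2, 4, 3)    # [extra_dim, batch, num_actions]
action = torch.randint(3, (4, 1))            # [batch, 1] -- already has the singleton

extra_dims = td_q_ndim - tensordict_ndim
if action.ndim < state_action_value.ndim - extra_dims:   # 2 < 2 -> False, no extra unsqueeze
    action = action.unsqueeze(-1)
# the old check (action.ndim < state_action_value.ndim, i.e. 2 < 3) would have
# produced a [4, 1, 1] index here
print(action.shape)  # torch.Size([4, 1])
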
@@ -830,12 +833,17 @@ def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
         state_action_value = td_q.get(self.tensor_keys.state_action_value)
         action = tensordict.get(self.tensor_keys.action)
         if self.action_space == "categorical":
-            if action.ndim < state_action_value.ndim:
+            if action.ndim < (
+                state_action_value.ndim - (td_q.ndim - tensordict.ndim)
+            ):
                 # unsqueeze the action if it lacks on trailing singleton dim
                 action = action.unsqueeze(-1)
-            chosen_state_action_value = torch.gather(
-                state_action_value, -1, index=action
-            ).squeeze(-1)
+            chosen_state_action_value = torch.vmap(
+                lambda state_action_value, action: torch.gather(
+                    state_action_value, -1, index=action
+                ).squeeze(-1),
+                (0, None),
+            )(state_action_value, action)
         elif self.action_space == "one_hot":
             action = action.to(torch.float)
             chosen_state_action_value = (state_action_value * action).sum(-1)
@@ -867,10 +875,15 @@ def qvalue_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
         state_action_value = td_q.get(self.tensor_keys.state_action_value)
         action = tensordict.get(self.tensor_keys.action)
         if self.action_space == "categorical":
-            if action.ndim < state_action_value.ndim:
+            if action.ndim < (state_action_value.ndim - (td_q.ndim - tensordict.ndim)):
                 # unsqueeze the action if it lacks on trailing singleton dim
                 action = action.unsqueeze(-1)
-            pred_val = torch.gather(state_action_value, -1, index=action).squeeze(-1)
+            pred_val = torch.vmap(
+                lambda state_action_value, action: torch.gather(
+                    state_action_value, -1, index=action
+                ).squeeze(-1),
+                (0, None),
+            )(state_action_value, action)
         elif self.action_space == "one_hot":
             action = action.to(torch.float)
             pred_val = (state_action_value * action).sum(-1)
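For readers less familiar with torch.vmap: the new gather pattern maps the lookup over the extra leading dimension of state_action_value (in_dims=(0, None)), so the same action index tensor is reused for each slice instead of being expanded by hand. A minimal sketch with assumed shapes (an illustration of the pattern, not the torchrl code itself):

import torch

num_nets, batch, num_actions = 2, 4, 3
state_action_value = torch.randn(num_nets, batch, num_actions)   # assumed extra leading dim
action = torch.randint(num_actions, (batch, 1))                  # [batch, 1] index tensor

# torch.gather(state_action_value, -1, index=action) would fail here: the index
# must have the same ndim as the input. Instead, vmap over dim 0 of the values only:
chosen = torch.vmap(
    lambda sav, act: torch.gather(sav, -1, index=act).squeeze(-1),
    (0, None),
)(state_action_value, action)
print(chosen.shape)  # torch.Size([2, 4])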
