Added the necessary transforms for Hindsight Experience Replay
Dimitrios Tsaras authored and Dimitrios Tsaras committed Dec 19, 2024
1 parent 133d709 commit 5dbbd26
Showing 1 changed file with 163 additions and 0 deletions.
163 changes: 163 additions & 0 deletions torchrl/envs/transforms/transforms.py
@@ -40,6 +40,7 @@
    TensorDictBase,
    unravel_key,
    unravel_key_list,
    pad_sequence,
)
from tensordict.nn import dispatch, TensorDictModuleBase
from tensordict.utils import (
@@ -9264,3 +9265,165 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
            high=torch.iinfo(torch.int64).max,
        )
        return super().transform_observation_spec(observation_spec)


class HERSubGoalSampler(Transform):
    """Samples subgoal indices from each trajectory for Hindsight Experience Replay.

    Returns a TensorDict with a key ``subgoal_idx`` of shape ``[batch_size, num_samples]``
    representing the subgoal indices. Available strategies are ``"last"`` and ``"future"``.
    The ``"last"`` strategy assigns the last state of the trajectory as the subgoal. The
    ``"future"`` strategy samples up to ``num_samples`` subgoals from the future states of
    the same trajectory.

    Args:
        num_samples (int): number of subgoals to sample from each trajectory. Defaults to 4.
        subgoal_idx_key (str): the key under which the subgoal indices are stored. Defaults to ``"subgoal_idx"``.
        strategy (str): the subgoal sampling strategy, either ``"last"`` or ``"future"``. Defaults to ``"future"``.
    """

    def __init__(
        self,
        num_samples: int = 4,
        subgoal_idx_key: str = "subgoal_idx",
        strategy: str = "future",
    ):
        super().__init__(
            in_keys=None,
            in_keys_inv=None,
            out_keys_inv=None,
        )
        self.num_samples = num_samples
        self.subgoal_idx_key = subgoal_idx_key
        self.strategy = strategy

    def forward(self, trajectories: TensorDictBase) -> TensorDictBase:
        if len(trajectories.shape) == 1:
            trajectories = trajectories.unsqueeze(0)

        batch_size, trajectory_len = trajectories.shape

        if self.strategy == "last":
            # Use the last state of each trajectory as the subgoal (index -1).
            return TensorDict(
                {"subgoal_idx": torch.full((batch_size, 1), -1)},
                batch_size=batch_size,
            )
        else:
            # Sample up to `num_samples` distinct intermediate states per trajectory.
            subgoal_idxs = []
            for i in range(batch_size):
                subgoal_idxs.append(
                    TensorDict(
                        {
                            "subgoal_idx": (torch.randperm(trajectory_len - 2) + 1)[
                                : self.num_samples
                            ]
                        },
                        batch_size=torch.Size(),
                    )
                )
            return pad_sequence(subgoal_idxs, pad_dim=0, return_mask=True)


class HERSubGoalAssigner(Transform):
    """Assigns a new goal to a trajectory according to a given subgoal index.

    The goal achieved at the subgoal index replaces the desired goal of every step of the
    trajectory, and the subgoal step is marked as done.

    Args:
        achieved_goal_key (str): the key of the goal achieved at each step. Defaults to ``"achieved_goal"``.
        desired_goal_key (str): the key of the goal the trajectory is conditioned on. Defaults to ``"desired_goal"``.
    """

    def __init__(
        self,
        achieved_goal_key: str = "achieved_goal",
        desired_goal_key: str = "desired_goal",
    ):
        super().__init__(
            in_keys=None,
            in_keys_inv=None,
            out_keys_inv=None,
        )
        self.achieved_goal_key = achieved_goal_key
        self.desired_goal_key = desired_goal_key

    def forward(
        self, trajectories: TensorDictBase, subgoals_idxs: torch.Tensor
    ) -> TensorDictBase:
        batch_size, trajectory_len = trajectories.shape
        for i in range(batch_size):
            idx = int(subgoals_idxs[i])
            subgoal = trajectories[self.achieved_goal_key][i, idx]
            desired_goal_shape = trajectories[self.desired_goal_key][i].shape
            # Index the stored tensors directly so the relabeling writes through to `trajectories`.
            trajectories[self.desired_goal_key][i] = subgoal.expand(desired_goal_shape)
            # The relabeled trajectory terminates at the subgoal step.
            trajectories[("next", "done")][i, idx] = True
            # trajectories[i][subgoals_idxs[i]+1:]["truncated"] = True

        return trajectories


class HERRewardTransform(Transform):
    """Recomputes the reward of a trajectory according to the new subgoal.

    The default implementation is a pass-through that returns the trajectories unchanged;
    subclass it and override :meth:`forward` to compute an environment-specific reward.
    """

    def __init__(self):
        super().__init__(
            in_keys=None,
            in_keys_inv=None,
            out_keys_inv=None,
        )

    def forward(self, trajectories: TensorDictBase) -> TensorDictBase:
        return trajectories


class HindsightExperienceReplayTransform(Transform):
    """Hindsight Experience Replay (HER) data augmentation.

    HER is a technique that allows an agent to learn from failure by creating new
    experiences out of failed ones. This module is a wrapper that chains the following
    modules:

    - ``SubGoalSampler``: creates new trajectories by sampling future subgoals from the same trajectory.
    - ``SubGoalAssigner``: assigns the subgoal to the trajectory according to a given subgoal index.
    - ``RewardTransform``: assigns the reward to the trajectory according to the new subgoal.

    Args:
        SubGoalSampler (Transform): the transform that samples subgoal indices. Defaults to ``HERSubGoalSampler()``.
        SubGoalAssigner (Transform): the transform that relabels trajectories with the sampled subgoals. Defaults to ``HERSubGoalAssigner()``.
        RewardTransform (Transform): the transform that recomputes the rewards of the relabeled trajectories. Defaults to ``HERRewardTransform()``.
        assign_subgoal_idxs (bool): if ``True``, the sampled subgoal index is stored in each augmented trajectory. Defaults to ``False``.
    """

    def __init__(
        self,
        SubGoalSampler: Transform = HERSubGoalSampler(),
        SubGoalAssigner: Transform = HERSubGoalAssigner(),
        RewardTransform: Transform = HERRewardTransform(),
        assign_subgoal_idxs: bool = False,
    ):
        super().__init__(
            in_keys=None,
            in_keys_inv=None,
            out_keys_inv=None,
        )
        self.SubGoalSampler = SubGoalSampler
        self.SubGoalAssigner = SubGoalAssigner
        self.RewardTransform = RewardTransform
        self.assign_subgoal_idxs = assign_subgoal_idxs

    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
        # Concatenate the original batch with its HER-relabeled copies when data is
        # written through the inverse transform (e.g. into a replay buffer).
        augmentation_td = self.her_augmentation(tensordict)
        return torch.cat([tensordict, augmentation_td], dim=0)

    def _inv_apply_transform(self, tensordict: TensorDictBase) -> TensorDictBase:
        return self.her_augmentation(tensordict)

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        return tensordict

    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
        raise ValueError(self.ENV_ERR)

    def her_augmentation(self, trajectories: TensorDictBase):
        if len(trajectories.shape) == 1:
            trajectories = trajectories.unsqueeze(0)
        batch_size, trajectory_length = trajectories.shape

        new_trajectories = trajectories.clone(True)

        # Sample subgoal indices
        subgoal_idxs = self.SubGoalSampler(new_trajectories)

        # Create new trajectories
        augmented_trajectories = []
        list_idxs = []
        for i in range(batch_size):
            idxs = subgoal_idxs[i][self.SubGoalSampler.subgoal_idx_key]

            if "masks" in subgoal_idxs.keys():
                idxs = idxs[subgoal_idxs[i]["masks", self.SubGoalSampler.subgoal_idx_key]]

            list_idxs.append(idxs.unsqueeze(-1))
            new_traj = new_trajectories[i].expand((idxs.numel(), trajectory_length)).clone(True)

            if self.assign_subgoal_idxs:
                new_traj[self.SubGoalSampler.subgoal_idx_key] = idxs.unsqueeze(-1).repeat(
                    1, trajectory_length
                )

            augmented_trajectories.append(new_traj)
        augmented_trajectories = torch.cat(augmented_trajectories, dim=0)
        associated_idxs = torch.cat(list_idxs, dim=0)

        # Assign subgoals to the new trajectories
        augmented_trajectories = self.SubGoalAssigner.forward(augmented_trajectories, associated_idxs)

        # Adjust the rewards based on the new subgoals
        augmented_trajectories = self.RewardTransform.forward(augmented_trajectories)

        return augmented_trajectories
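

Example: the snippet below is an illustrative sketch, not part of this commit. It shows how the subgoal sampler alone could be exercised, assuming the class is importable from torchrl.envs.transforms.transforms and that trajectories are stored as a [batch, time] TensorDict with achieved_goal/desired_goal entries (the key names are the defaults above).

import torch
from tensordict import TensorDict

from torchrl.envs.transforms.transforms import HERSubGoalSampler

# A dummy batch of 3 trajectories of length 10 with 2-dimensional goals.
trajectories = TensorDict(
    {
        "achieved_goal": torch.randn(3, 10, 2),
        "desired_goal": torch.randn(3, 10, 2),
    },
    batch_size=[3, 10],
)

sampler = HERSubGoalSampler(num_samples=4, strategy="future")
subgoal_idx = sampler(trajectories)
print(subgoal_idx["subgoal_idx"].shape)  # torch.Size([3, 4])

With the "future" strategy, each trajectory contributes up to num_samples intermediate indices; with "last", a single index of -1 (the final step) is returned per trajectory.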
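An end-to-end sketch of the wrapper follows, again illustrative and not part of this commit. It assumes goal-conditioned trajectories carrying achieved_goal, desired_goal and ("next", "done") entries, and that the replay buffer applies inverse transforms on extend, so _inv_call stores both the original and the relabeled trajectories.

import torch
from tensordict import TensorDict

from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.envs.transforms.transforms import HindsightExperienceReplayTransform

# A dummy batch of 2 goal-conditioned trajectories of length 8 with 3-dimensional goals.
trajectories = TensorDict(
    {
        "observation": torch.randn(2, 8, 3),
        "achieved_goal": torch.randn(2, 8, 3),
        "desired_goal": torch.randn(2, 8, 3),
        ("next", "done"): torch.zeros(2, 8, 1, dtype=torch.bool),
        ("next", "reward"): torch.zeros(2, 8, 1),
    },
    batch_size=[2, 8],
)

her = HindsightExperienceReplayTransform()

# Direct augmentation: each trajectory is duplicated once per sampled subgoal,
# with desired_goal relabeled to the goal achieved at that subgoal index.
augmented = her.her_augmentation(trajectories)
print(augmented.shape)  # e.g. torch.Size([8, 8]) with 4 subgoals per trajectory

# Typical usage: append the transform to a replay buffer so that writing data
# stores the original trajectories together with their relabeled copies.
rb = TensorDictReplayBuffer(storage=LazyTensorStorage(1000))
rb.append_transform(her)
rb.extend(trajectories)
print(len(rb))  # e.g. 10 (2 original + 8 relabeled trajectories)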
