Skip to content

Commit

Permalink
change observations only from float64 to float32 before adding to the…
Browse files Browse the repository at this point in the history
… replay buffer

Summary: This is to prepare future diffs for supporting Atari games. In Atari games, the observations are uint8 type. Currently, pearl converts uint8 observations to float32 before adding them to the replay buffer. A huge amount of memory is used when the agent saves float32 images in the replay buffer. This diff reduces the memory usage by not making the conversion when the observations are added to the replay buffer, and instead performs the conversion when batches are sampled from the replay buffer.

Reviewed By: rodrigodesalvobraz

Differential Revision: D65923483

fbshipit-source-id: b43825a31c4da2aa4e58cea5bb47025379b0903e
  • Loading branch information
yiwan-rl authored and facebook-github-bot committed Dec 7, 2024
1 parent 3e52c4a commit 8941dbe
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions pearl/replay_buffers/tensor_based_replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def _create_transition_batch(
x.next_unavailable_actions_mask
)

state_batch = torch.cat(state_list)
state_batch = torch.cat(state_list).type(torch.float32)
action_batch = torch.cat(action_list)
reward_batch = torch.cat(reward_list)
terminated_batch = torch.cat(terminated_list)
Expand All @@ -358,7 +358,7 @@ def _create_transition_batch(
cost_batch = None
next_state_batch, next_action_batch = None, None
if has_next_state:
next_state_batch = torch.cat(next_state_list)
next_state_batch = torch.cat(next_state_list).type(torch.float32)
if has_next_action:
next_action_batch = torch.cat(next_action_list)
curr_available_actions_batch, curr_unavailable_actions_mask_batch = None, None
Expand Down
4 changes: 2 additions & 2 deletions pearl/utils/instantiations/environments/gym_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def reset(self, seed: int | None = None) -> tuple[Observation, ActionSpace]:
# TODO: Deprecate this part at some point and only support new
# version of Gymnasium?
observation = list(reset_result.values())[0] # pyre-ignore
if isinstance(observation, np.ndarray):
if isinstance(observation, np.float64):
observation = observation.astype(np.float32)
return observation, self.action_space

Expand Down Expand Up @@ -151,7 +151,7 @@ def step(self, action: Action) -> ActionResult:
else:
available_action_space = None

if isinstance(observation, np.ndarray):
if isinstance(observation, np.float64):
observation = observation.astype(np.float32)
if isinstance(reward, np.float64):
reward = reward.astype(np.float32)
Expand Down

0 comments on commit 8941dbe

Please sign in to comment.