Add the truncated field.
Summary: PPO needs to use the truncated signal. This diff adds that signal to the agent, the replay buffers, and the transition types.

Reviewed By: rodrigodesalvobraz

Differential Revision: D65910424

fbshipit-source-id: 487b6026d846abb427825969c7c85404cc90940b
yiwan-rl authored and facebook-github-bot committed Nov 21, 2024
1 parent 96c527b commit 7462526
Showing 25 changed files with 80 additions and 20 deletions.
1 change: 1 addition & 0 deletions pearl/pearl_agent.py
@@ -191,6 +191,7 @@ def observe(
else action_result.available_action_space
), # next_available_actions
terminated=action_result.terminated,
truncated=action_result.truncated,
# pyre-fixme[6]: For 8th argument expected `Optional[int]` but got
# `Union[None, Tensor, Module]`.
max_number_actions=(
2 changes: 1 addition & 1 deletion pearl/policy_learners/sequential_decision_making/ppo.py
@@ -278,7 +278,7 @@ def preprocess_replay_buffer(self, replay_buffer: ReplayBuffer) -> None:
td_error
+ self._discount_factor
* self._trace_decay_param
* (~transition.terminated)
* (not (transition.terminated or transition.truncated))
* gae
)
assert isinstance(transition, PPOTransition)
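The hunk above is the core behavioral change: PPO's generalized advantage estimation (GAE) now cuts its recursive trace at a truncation boundary as well as at a termination. A minimal sketch of that recursion, assuming per-step lists of TD errors and end-of-episode flags; the names are illustrative, not Pearl's actual API:

def compute_gae(td_errors, terminateds, truncateds, gamma=0.99, lam=0.95):
    # Sketch only: reset the trace whenever the episode ended for any reason,
    # whether the environment terminated it or a time limit truncated it.
    advantages = [0.0] * len(td_errors)
    gae = 0.0
    # Iterate backwards so each step reuses the trace from the step after it.
    for t in reversed(range(len(td_errors))):
        episode_continues = not (terminateds[t] or truncateds[t])
        gae = td_errors[t] + gamma * lam * episode_continues * gae
        advantages[t] = gae
    return advantages

In this sketch, truncation only keeps the trace from leaking across an episode boundary; whether the TD error at a truncated step still bootstraps from the next state's value is decided where the TD error itself is computed, not in this recursion.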
(diff for another changed file)
@@ -137,6 +137,7 @@ def learn(
_curr_available_actions,
_next_available_actions,
terminated,
truncated,
_max_number_actions,
_cost,
) = transition
@@ -170,14 +171,17 @@
self.q_values[(state, action.item())] = new_q_value

if self.debug:
self.print_debug_information(state, action, reward, next_state, terminated)
self.print_debug_information(
state, action, reward, next_state, terminated, truncated
)

return {
"state": state,
"action": action,
"reward": reward,
"next_state": next_state,
"terminated": terminated,
"truncated": truncated,
}

def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
@@ -190,12 +194,14 @@ def print_debug_information(
reward: Reward,
next_state: SubjectiveState,
terminated: bool,
truncated: bool,
) -> None:
print("state:", state)
print("action:", action)
print("reward:", reward)
print("next state:", next_state)
print("terminated:", terminated)
print("truncated:", truncated)
print("q-values:", self.q_values)

def __str__(self) -> str:
2 changes: 2 additions & 0 deletions pearl/replay_buffers/basic_replay_buffer.py
@@ -27,6 +27,7 @@ def _store_transition(
action: Action,
reward: Reward,
terminated: bool,
truncated: bool,
curr_available_actions_tensor_with_padding: Optional[Tensor],
curr_unavailable_actions_mask: Optional[Tensor],
next_state: Optional[SubjectiveState],
@@ -44,6 +45,7 @@
next_available_actions=next_available_actions_tensor_with_padding,
next_unavailable_actions_mask=next_unavailable_actions_mask,
terminated=self._process_single_terminated(terminated),
truncated=self._process_single_truncated(truncated),
cost=self._process_single_cost(cost),
)
self.memory.append(transition)
(diff for another changed file)
@@ -28,6 +28,7 @@
ActionSpace,
ActionSpace,
bool,
bool,
Optional[int],
Optional[float],
]
@@ -58,6 +59,7 @@ def push(
action: Action,
reward: Reward,
terminated: bool,
truncated: bool,
curr_available_actions: Optional[ActionSpace] = None,
next_state: Optional[SubjectiveState] = None,
next_available_actions: Optional[ActionSpace] = None,
@@ -74,6 +76,7 @@
curr_available_actions,
next_available_actions,
to_default_device_if_tensor(terminated),
to_default_device_if_tensor(truncated),
max_number_actions,
to_default_device_if_tensor(cost),
)
1 change: 1 addition & 0 deletions pearl/replay_buffers/replay_buffer.py
@@ -54,6 +54,7 @@ def push(
action: Action,
reward: Reward,
terminated: bool,
truncated: bool,
curr_available_actions: Optional[ActionSpace] = None,
next_state: Optional[SubjectiveState] = None,
next_available_actions: Optional[ActionSpace] = None,
(diff for another changed file)
@@ -56,6 +56,7 @@ def _store_transition(
action: Action,
reward: Reward,
terminated: bool,
truncated: bool,
curr_available_actions_tensor_with_padding: Optional[Tensor],
curr_unavailable_actions_mask: Optional[Tensor],
next_state: Optional[SubjectiveState],
@@ -78,6 +79,7 @@
next_available_actions=next_available_actions_tensor_with_padding,
next_unavailable_actions_mask=next_unavailable_actions_mask,
terminated=self._process_single_terminated(terminated),
truncated=self._process_single_truncated(truncated),
cost=self._process_single_cost(cost),
bootstrap_mask=bootstrap_mask,
)
@@ -109,5 +111,6 @@ def sample(self, batch_size: int) -> TransitionWithBootstrapMaskBatch:
next_available_actions=transition_batch.next_available_actions,
next_unavailable_actions_mask=transition_batch.next_unavailable_actions_mask,
terminated=transition_batch.terminated,
truncated=transition_batch.truncated,
bootstrap_mask=bootstrap_mask_batch,
)
(diff for another changed file)
@@ -61,6 +61,7 @@ def __init__(
ActionSpace,
ActionSpace,
bool,
bool,
Optional[int],
Optional[float],
]
@@ -72,6 +73,7 @@ def push(
action: Action,
reward: Reward,
terminated: bool,
truncated: bool,
curr_available_actions: Optional[ActionSpace] = None,
next_state: Optional[SubjectiveState] = None,
next_available_actions: Optional[ActionSpace] = None,
@@ -86,6 +88,7 @@
action,
reward,
terminated,
truncated,
curr_available_actions,
next_state,
next_available_actions,
@@ -111,11 +114,12 @@ def push(
curr_available_actions,
next_available_actions,
terminated,
truncated,
max_number_actions,
cost,
)
)
if terminated:
if terminated or truncated:
additional_goal = next_state[: -self._goal_dim] # final mode
for (
state,
@@ -124,6 +128,7 @@
curr_available_actions,
next_available_actions,
terminated,
truncated,
max_number_actions,
cost,
) in self._trajectory:
@@ -141,6 +146,7 @@
if self._terminated_fn is None
else self._terminated_fn(state, action)
),
truncated,
curr_available_actions,
next_state,
next_available_actions,
(diff for another changed file)
@@ -36,6 +36,7 @@ def _store_transition(
action: Action,
reward: Reward,
terminated: bool,
truncated: bool,
curr_available_actions_tensor_with_padding: Optional[Tensor],
curr_unavailable_actions_mask: Optional[Tensor],
next_state: Optional[SubjectiveState],
@@ -67,9 +68,10 @@ def _store_transition(
next_available_actions=self.cache.next_available_actions,
next_unavailable_actions_mask=self.cache.next_unavailable_actions_mask,
terminated=self.cache.terminated,
truncated=self.cache.truncated,
)
)
if not terminated:
if not (terminated or truncated):
# save current push into cache
self.cache = Transition(
state=current_state,
@@ -81,9 +83,10 @@
next_available_actions=next_available_actions_tensor_with_padding,
next_unavailable_actions_mask=next_unavailable_actions_mask,
terminated=self._process_single_terminated(terminated),
truncated=self._process_single_truncated(truncated),
)
else:
# for terminal state, push directly
# for terminal state or time out, push directly
self.memory.append(
Transition(
state=current_state,
@@ -97,5 +100,6 @@ def _store_transition(
next_available_actions=next_available_actions_tensor_with_padding,
next_unavailable_actions_mask=next_unavailable_actions_mask,
terminated=self._process_single_terminated(terminated),
truncated=self._process_single_truncated(truncated),
)
)
12 changes: 12 additions & 0 deletions pearl/replay_buffers/tensor_based_replay_buffer.py
@@ -44,6 +44,7 @@ def _store_transition(
action: Action,
reward: Reward,
terminated: bool,
truncated: bool,
curr_available_actions_tensor_with_padding: Optional[Tensor],
curr_unavailable_actions_mask: Optional[Tensor],
next_state: Optional[SubjectiveState],
@@ -62,6 +63,7 @@ def push(
action: Action,
reward: Reward,
terminated: bool,
truncated: bool,
curr_available_actions: Optional[ActionSpace] = None,
next_state: Optional[SubjectiveState] = None,
next_available_actions: Optional[ActionSpace] = None,
@@ -120,6 +122,7 @@ def push(
action,
reward,
terminated,
truncated,
curr_available_actions_tensor_with_padding,
curr_unavailable_actions_mask,
next_state,
@@ -169,6 +172,9 @@ def _process_single_cost(self, cost: Optional[float]) -> Optional[torch.Tensor]:
def _process_single_terminated(self, terminated: bool) -> torch.Tensor:
return torch.tensor([terminated]) # (1,)

def _process_single_truncated(self, truncated: bool) -> torch.Tensor:
return torch.tensor([truncated]) # (1,)

@staticmethod
def create_action_tensor_and_mask(
max_number_actions: Optional[int],
@@ -258,6 +264,7 @@ def sample(self, batch_size: int) -> TransitionBatch:
next_available_actions = tensor(batch_size, action_dim, action_dim),
next_available_actions_mask = tensor(batch_size, action_dim),
terminated = tensor(batch_size, ),
truncated = tensor(batch_size, ),
)
"""
if batch_size > len(self):
@@ -297,6 +304,7 @@ def _create_transition_batch(
next_available_actions=torch.empty(0),
next_unavailable_actions_mask=torch.empty(0),
terminated=torch.empty(0),
truncated=torch.empty(0),
cost=torch.empty(0),
).to(self.device_for_batches)

@@ -311,6 +319,7 @@
reward_list = []
cost_list = []
terminated_list = []
truncated_list = []
next_state_list = []
next_action_list = []
curr_available_actions_list = []
@@ -322,6 +331,7 @@
action_list.append(x.action)
reward_list.append(x.reward)
terminated_list.append(x.terminated)
truncated_list.append(x.truncated)
if has_cost_available:
cost_list.append(x.cost)
if has_next_state:
@@ -344,6 +354,7 @@
action_batch = torch.cat(action_list)
reward_batch = torch.cat(reward_list)
terminated_batch = torch.cat(terminated_list)
truncated_batch = torch.cat(truncated_list)
if has_cost_available:
cost_batch = torch.cat(cost_list)
else:
Expand Down Expand Up @@ -377,5 +388,6 @@ def _create_transition_batch(
next_available_actions=next_available_actions_batch,
next_unavailable_actions_mask=next_unavailable_actions_mask_batch,
terminated=terminated_batch,
truncated=truncated_batch,
cost=cost_batch,
).to(self.device_for_batches)
5 changes: 5 additions & 0 deletions pearl/replay_buffers/transition.py
@@ -28,6 +28,7 @@ class Transition:
action: torch.Tensor
reward: torch.Tensor
terminated: torch.Tensor = torch.tensor(True) # default True is useful for bandits
truncated: torch.Tensor = torch.tensor(True) # default True is useful for bandits
next_state: Optional[torch.Tensor] = None
next_action: Optional[torch.Tensor] = None
curr_available_actions: Optional[torch.Tensor] = None
@@ -65,6 +66,7 @@ class TransitionBatch:
action: torch.Tensor
reward: torch.Tensor
terminated: torch.Tensor = torch.tensor(True) # default True is useful for bandits
truncated: torch.Tensor = torch.tensor(True) # default True is useful for bandits
next_state: Optional[torch.Tensor] = None
next_action: Optional[torch.Tensor] = None
curr_available_actions: Optional[torch.Tensor] = None
@@ -133,17 +135,20 @@ def _filter_tensor(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
filtered_action = _filter_tensor(batch.action)
filtered_reward = _filter_tensor(batch.reward)
filtered_terminated = _filter_tensor(batch.terminated)
filtered_truncated = _filter_tensor(batch.truncated)

assert filtered_state is not None
assert filtered_action is not None
assert filtered_reward is not None
assert filtered_terminated is not None
assert filtered_truncated is not None

return TransitionBatch(
state=filtered_state,
action=filtered_action,
reward=filtered_reward,
terminated=filtered_terminated,
truncated=filtered_truncated,
next_state=_filter_tensor(batch.next_state),
next_action=_filter_tensor(batch.next_action),
curr_available_actions=_filter_tensor(batch.curr_available_actions),
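With the new field, a stored Transition carries both end-of-episode flags, each defaulting to torch.tensor(True) (the inline comment notes that this default suits bandit settings, where every interaction is a single step). A hypothetical construction for a mid-episode reinforcement-learning step, with illustrative shapes and values rather than anything taken from Pearl's own tests:

import torch

from pearl.replay_buffers.transition import Transition

# Illustrative only: a mid-episode step, so neither flag is set.
transition = Transition(
    state=torch.zeros(1, 4),
    action=torch.tensor([[0]]),
    reward=torch.tensor([1.0]),
    terminated=torch.tensor([False]),
    truncated=torch.tensor([False]),
    next_state=torch.zeros(1, 4),
)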
(diff for another changed file)
@@ -93,6 +93,7 @@ def create_offline_data(
"curr_available_actions": env.action_space,
"next_available_actions": env.action_space,
"terminated": action_result.terminated,
"truncated": action_result.truncated,
}

observation = action_result.observation
@@ -184,7 +185,7 @@ def get_data_collection_agent_returns(
data_collection_agent_returns = []
g = 0
for transition in list(data):
if transition["terminated"]:
if transition["terminated"] or transition["truncated"]:
data_collection_agent_returns.append(g)
g = 0
else:
4 changes: 2 additions & 2 deletions pearl/utils/functional_utils/learning/loss_fn_utils.py
@@ -26,8 +26,8 @@ def compute_cql_loss(
Inputs:
1) q_network: to compute the q values of every (state, action) pair.
2) batch: batch of data transitions (state, action, reward, terminated, next_state) along with
(current and next) available actions.
2) batch: batch of data transitions (state, action, reward, terminated, truncated, next_state)
along with (current and next) available actions.
3) batch_size: size of batch.
Outputs:
(diff for another changed file)
@@ -47,7 +47,7 @@ def get_offline_data_in_buffer(
- Assumes the offline data is an iterable consisting of transition tuples
(observation, action, reward, next_observation, curr_available_actions,
next_available_actions, terminated) as dictionaries.
next_available_actions, terminated, truncated) as dictionaries.
- Also assumes offline data is in a .pt file; reading from a
csv file can also be added later.
@@ -125,6 +125,7 @@ def get_offline_data_in_buffer(
curr_available_actions=transition["curr_available_actions"],
next_available_actions=transition["next_available_actions"],
terminated=transition["done"],
truncated=False,
max_number_actions=max_number_actions_if_discrete,
)

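The docstring change above spells out the dictionary format expected for each offline transition, and the create_offline_data hunk earlier in this commit shows the keys being written. A minimal sketch of one such record, with made-up values; the action-space entries are left as None here, though in practice they would be the environment's action space, as in create_offline_data:

import torch

# Sketch of a single offline transition record as described in the docstring above.
offline_transition = {
    "observation": torch.zeros(4),
    "action": torch.tensor(0),
    "reward": 1.0,
    "next_observation": torch.zeros(4),
    "curr_available_actions": None,  # an ActionSpace in practice
    "next_available_actions": None,
    "terminated": False,
    "truncated": False,
}

Note that the reading code in the hunk above still takes the episode-end flag from the key "done" and pushes truncated=False for offline data, so truncation information is not yet round-tripped through this path.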