fix the way we deal with truncation
Summary: When episodes end because of a timeout (truncation rather than true termination), TD-based methods should still bootstrap from the last state's value.

Reviewed By: rodrigodesalvobraz, jb3618columbia

Differential Revision: D55029520

fbshipit-source-id: 6b9c459cd2e6bb5c16dacd71f6d7f094583d1ad9
yiwan-rl authored and facebook-github-bot committed Apr 1, 2024
1 parent c6b74e0 commit f96642b
Showing 41 changed files with 122 additions and 110 deletions.
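For context, the change can be summarized with a small sketch (illustrative only: the function and tensor names below are assumptions, not Pearl's actual API). The bootstrap term of a TD target is dropped only on true termination, so an episode that ends by timeout (truncation) still bootstraps from the value of its last state.

import torch

def td_target(
    reward: torch.Tensor,      # shape: (batch_size,)
    next_value: torch.Tensor,  # shape: (batch_size,), e.g. V(s') or max_a Q(s', a)
    terminated: torch.Tensor,  # shape: (batch_size,), bool; True only on true termination
    discount_factor: float = 0.99,
) -> torch.Tensor:
    # Zero the bootstrap term only when the episode truly terminated.
    # A timeout (truncation) leaves `terminated` False, so gamma * V(s')
    # is still added to the target.
    return reward + discount_factor * next_value * (1 - terminated.float())

A single `done` flag conflates the two cases and drops the bootstrap term on timeouts, which is why the diffs below rename `done` to `terminated` throughout.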
2 changes: 1 addition & 1 deletion pearl/pearl_agent.py
@@ -192,7 +192,7 @@ def observe(
if action_result.available_action_space is None
else action_result.available_action_space
), # next_available_actions
- done=action_result.done,
+ terminated=action_result.terminated,
max_number_actions=(
self.policy_learner.action_representation_module.max_number_actions
if not self.policy_learner.is_action_continuous
@@ -112,7 +112,7 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
batch=batch_filtered, batch_size=batch_filtered.state.shape[0], z=z
)
* self._discount_factor
- * (1 - batch_filtered.done.float())
+ * (1 - batch_filtered.terminated.float())
) + batch_filtered.reward # (batch_size), r + gamma * V(s)

criterion = torch.nn.MSELoss()
2 changes: 1 addition & 1 deletion pearl/policy_learners/sequential_decision_making/ddpg.py
@@ -133,7 +133,7 @@ def _critic_loss(self, batch: TransitionBatch) -> torch.Tensor:
# r + gamma * (min{Qtarget_1(s', a from target actor network),
# Qtarget_2(s', a from target actor network)})
expected_state_action_values = (
- next_q * self._discount_factor * (1 - batch.done.float())
+ next_q * self._discount_factor * (1 - batch.terminated.float())
) + batch.reward # shape (batch_size)

assert isinstance(self._critic, TwinCritic), "DDPG requires TwinCritic critic"
@@ -261,12 +261,12 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
state_batch = batch.state # (batch_size x state_dim)
action_batch = batch.action # (batch_size x action_dim)
reward_batch = batch.reward # (batch_size)
- done_batch = batch.done # (batch_size)
+ terminated_batch = batch.terminated # (batch_size)

batch_size = state_batch.shape[0]
# sanity check they have same batch_size
assert reward_batch.shape[0] == batch_size
- assert done_batch.shape[0] == batch_size
+ assert terminated_batch.shape[0] == batch_size

state_action_values = self._Q.get_q_values(
state_batch=state_batch,
@@ -280,7 +280,7 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
expected_state_action_values = (
self.get_next_state_values(batch, batch_size)
* self._discount_factor
- * (1 - done_batch.float())
+ * (1 - terminated_batch.float())
) + reward_batch # (batch_size), r + gamma * V(s)

criterion = torch.nn.MSELoss()
@@ -264,7 +264,9 @@ def _critic_loss(self, batch: TransitionBatch) -> torch.Tensor:

# compute targets for batch of (state, action, next_state): target y = r + gamma * V(s')
target = (
- values_next_states * self._discount_factor * (1 - batch.done.float())
+ values_next_states
+ * self._discount_factor
+ * (1 - batch.terminated.float())
) + batch.reward # shape: (batch_size)

assert isinstance(
4 changes: 2 additions & 2 deletions pearl/policy_learners/sequential_decision_making/ppo.py
@@ -205,14 +205,14 @@ def preprocess_replay_buffer(self, replay_buffer: ReplayBuffer) -> None:
for i, transition in enumerate(reversed(replay_buffer.memory)):
td_error = (
transition.reward
- + self._discount_factor * next_value * (~transition.done)
+ + self._discount_factor * next_value * (~transition.terminated)
- state_values[i]
)
gae = (
td_error
+ self._discount_factor
* self._trace_decay_param
- * (~transition.done)
+ * (~transition.terminated)
* gae
)
assert isinstance(transition, OnPolicyTransition)
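As an aside, a minimal sketch of the generalized advantage estimation (GAE) recursion that the PPO hunk above computes (the function and argument names here are hypothetical, not Pearl's interface): both the TD error and the eligibility trace are cut only at true termination, so truncated episodes keep bootstrapping.

import torch

def compute_gae(
    rewards: torch.Tensor,      # shape: (T,)
    values: torch.Tensor,       # shape: (T,), V(s_t)
    next_values: torch.Tensor,  # shape: (T,), V(s_{t+1})
    terminated: torch.Tensor,   # shape: (T,), bool; True only on true termination
    gamma: float = 0.99,
    lam: float = 0.95,
) -> torch.Tensor:
    advantages = torch.zeros_like(rewards)
    gae = torch.zeros(())
    for t in reversed(range(rewards.shape[0])):
        not_terminated = (~terminated[t]).float()
        # TD error: bootstrap from V(s_{t+1}) unless the episode truly terminated.
        delta = rewards[t] + gamma * next_values[t] * not_terminated - values[t]
        # The trace is likewise cut only at true termination.
        gae = delta + gamma * lam * not_terminated * gae
        advantages[t] = gae
    return advantages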
@@ -208,14 +208,14 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:

"""
Step 2: compute Bellman target for each quantile location
- - add a dimension to the reward and (1-done) vectors so they
+ - add a dimension to the reward and (1-terminated) vectors so they
can be broadcasted with the next state quantiles
"""

with torch.no_grad():
quantile_next_state_greedy_action_values = self._get_next_state_quantiles(
batch, batch_size
- ) * self._discount_factor * (1 - batch.done.float()).unsqueeze(
+ ) * self._discount_factor * (1 - batch.terminated.float()).unsqueeze(
-1
) + batch.reward.unsqueeze(
-1
@@ -143,7 +143,7 @@ def learn(self, replay_buffer: ReplayBuffer) -> Dict[str, Any]:
# compute return for all states in the buffer
cum_reward = self._critic(
self._history_summarization_module(replay_buffer.memory[-1].next_state)
- ).detach() * (~replay_buffer.memory[-1].done)
+ ).detach() * (~replay_buffer.memory[-1].terminated)
for transition in reversed(replay_buffer.memory):
cum_reward += transition.reward
assert isinstance(transition, OnPolicyTransition)
@@ -117,13 +117,13 @@ def reset(self, action_space: ActionSpace) -> None:
def _critic_loss(self, batch: TransitionBatch) -> torch.Tensor:

reward_batch = batch.reward # (batch_size)
- done_batch = batch.done # (batch_size)
+ terminated_batch = batch.terminated # (batch_size)

- assert done_batch is not None
+ assert terminated_batch is not None
expected_state_action_values = (
self._get_next_state_expected_values(batch)
* self._discount_factor
- * (1 - done_batch.float())
+ * (1 - terminated_batch.float())
) + reward_batch # (batch_size), r + gamma * V(s)

assert isinstance(self._critic, TwinCritic)
@@ -139,16 +139,16 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
def _critic_loss(self, batch: TransitionBatch) -> torch.Tensor:

reward_batch = batch.reward # shape: (batch_size)
- done_batch = batch.done # shape: (batch_size)
+ terminated_batch = batch.terminated # shape: (batch_size)

- if done_batch is not None:
+ if terminated_batch is not None:
expected_state_action_values = (
self._get_next_state_expected_values(batch)
* self._discount_factor
- * (1 - done_batch.float())
+ * (1 - terminated_batch.float())
) + reward_batch # shape of expected_state_action_values: (batch_size)
else:
- raise AssertionError("done_batch should not be None")
+ raise AssertionError("terminated_batch should not be None")

loss = twin_critic_action_value_loss(
state_batch=batch.state,
@@ -122,7 +122,7 @@ def learn(
next_state,
_curr_available_actions,
_next_available_actions,
- done,
+ terminated,
_max_number_actions,
_cost,
) = transition
@@ -132,7 +132,7 @@ def learn(
for next_action in self._action_space
]

- if done:
+ if terminated:
next_state_value = 0
else:
# pyre-fixme[6]: For 1st argument expected
@@ -154,14 +154,14 @@ def learn(
self.q_values[(state, action.item())] = new_q_value

if self.debug:
- self.print_debug_information(state, action, reward, next_state, done)
+ self.print_debug_information(state, action, reward, next_state, terminated)

return {
"state": state,
"action": action,
"reward": reward,
"next_state": next_state,
- "done": done,
+ "terminated": terminated,
}

def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
@@ -173,13 +173,13 @@ def print_debug_information(
action: Action,
reward: Reward,
next_state: SubjectiveState,
- done: bool,
+ terminated: bool,
) -> None:
print("state:", state)
print("action:", action)
print("reward:", reward)
print("next state:", next_state)
- print("done:", done)
+ print("terminated:", terminated)
print("q-values:", self.q_values)

def __str__(self) -> str:
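For completeness, a sketch of the tabular Q-learning update the hunk above performs (the signature and the `q_values` dictionary layout are assumptions for illustration, not the module's exact interface): the greedy next-state value is zeroed only when the episode truly terminated, so a timeout still bootstraps.

from typing import Dict, Hashable, Iterable, Tuple

def q_learning_update(
    q_values: Dict[Tuple[Hashable, Hashable], float],
    state: Hashable,
    action: Hashable,
    reward: float,
    next_state: Hashable,
    terminated: bool,
    action_space: Iterable[Hashable],
    learning_rate: float = 0.1,
    discount_factor: float = 0.99,
) -> None:
    # Bootstrap from the greedy next-state value unless the episode truly terminated.
    if terminated:
        next_state_value = 0.0
    else:
        next_state_value = max(
            q_values.get((next_state, a), 0.0) for a in action_space
        )
    old_q = q_values.get((state, action), 0.0)
    td_target = reward + discount_factor * next_state_value
    q_values[(state, action)] = old_q + learning_rate * (td_target - old_q)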
2 changes: 1 addition & 1 deletion pearl/policy_learners/sequential_decision_making/td3.py
@@ -171,7 +171,7 @@ def _critic_loss(self, batch: TransitionBatch) -> torch.Tensor:
# r + gamma * (min{Qtarget_1(s', a from target actor network),
# Qtarget_2(s', a from target actor network)})
expected_state_action_values = (
- next_q * self._discount_factor * (1 - batch.done.float())
+ next_q * self._discount_factor * (1 - batch.terminated.float())
) + batch.reward # (batch_size)

# update twin critics towards bellman target
@@ -28,7 +28,7 @@ class DiscreteContextualBanditReplayBuffer(TensorBasedReplayBuffer):
from other replay buffers
- No next action or next state related
- action is action idx instead of action value
- - done is not needed, as for contextual bandit, it is always True
+ - terminated is not needed, as for contextual bandit, it is always True
"""

def __init__(self, capacity: int) -> None:
@@ -47,7 +47,7 @@ def push(
next_state: SubjectiveState,
curr_available_actions: ActionSpace,
next_available_actions: ActionSpace,
- done: bool,
+ terminated: bool,
max_number_actions: Optional[int] = None,
cost: Optional[float] = None,
) -> None:
@@ -52,7 +52,7 @@ def push(
next_state: SubjectiveState,
curr_available_actions: ActionSpace,
next_available_actions: ActionSpace,
- done: bool,
+ terminated: bool,
max_number_actions: Optional[int] = None,
cost: Optional[float] = None,
) -> None:
@@ -63,7 +63,7 @@ def push(
next_state,
curr_available_actions,
next_available_actions,
- done,
+ terminated,
max_number_actions,
cost,
)
2 changes: 1 addition & 1 deletion pearl/replay_buffers/replay_buffer.py
@@ -42,7 +42,7 @@ def push(
next_state: SubjectiveState,
curr_available_actions: ActionSpace,
next_available_actions: ActionSpace,
- done: bool,
+ terminated: bool,
max_number_actions: Optional[int],
cost: Optional[float] = None,
) -> None:
@@ -64,7 +64,7 @@ def push(
next_state: SubjectiveState,
curr_available_actions: ActionSpace,
next_available_actions: ActionSpace,
- done: bool,
+ terminated: bool,
max_number_actions: Optional[int] = None,
cost: Optional[float] = None,
) -> None:
@@ -94,7 +94,7 @@ def push(
curr_unavailable_actions_mask=curr_unavailable_actions_mask,
next_available_actions=next_available_actions_tensor_with_padding,
next_unavailable_actions_mask=next_unavailable_actions_mask,
- done=self._process_single_done(done),
+ terminated=self._process_single_terminated(terminated),
cost=self._process_single_cost(cost),
bootstrap_mask=bootstrap_mask,
)
@@ -129,6 +129,6 @@ def sample(self, batch_size: int) -> TransitionWithBootstrapMaskBatch:
curr_unavailable_actions_mask=transition_batch.curr_unavailable_actions_mask,
next_available_actions=transition_batch.next_available_actions,
next_unavailable_actions_mask=transition_batch.next_unavailable_actions_mask,
- done=transition_batch.done,
+ terminated=transition_batch.terminated,
bootstrap_mask=bootstrap_mask_batch,
)
@@ -36,7 +36,7 @@ def push(
next_state: SubjectiveState,
curr_available_actions: ActionSpace,
next_available_actions: ActionSpace,
- done: bool,
+ terminated: bool,
max_number_actions: Optional[int],
cost: Optional[float] = None,
) -> None:
@@ -63,7 +63,7 @@ def push(
curr_unavailable_actions_mask=curr_unavailable_actions_mask,
next_available_actions=next_available_actions_tensor_with_padding,
next_unavailable_actions_mask=next_unavailable_actions_mask,
- done=self._process_single_done(done),
+ terminated=self._process_single_terminated(terminated),
cost=self._process_single_cost(cost),
).to(self.device)
)
@@ -35,7 +35,7 @@ def push(
next_state: SubjectiveState,
curr_available_actions: ActionSpace,
next_available_actions: ActionSpace,
- done: bool,
+ terminated: bool,
max_number_actions: Optional[int] = None,
cost: Optional[float] = None,
) -> None:
@@ -76,10 +76,10 @@ def push(
curr_unavailable_actions_mask=self.cache.curr_unavailable_actions_mask,
next_available_actions=self.cache.next_available_actions,
next_unavailable_actions_mask=self.cache.next_unavailable_actions_mask,
- done=self.cache.done,
+ terminated=self.cache.terminated,
).to(self.device)
)
- if not done:
+ if not terminated:
# save current push into cache
self.cache = Transition(
state=current_state,
@@ -90,7 +90,7 @@ def push(
curr_unavailable_actions_mask=curr_unavailable_actions_mask,
next_available_actions=next_available_actions_tensor_with_padding,
next_unavailable_actions_mask=next_unavailable_actions_mask,
- done=self._process_single_done(done),
+ terminated=self._process_single_terminated(terminated),
).to(self.device)
else:
# for terminal state, push directly
@@ -106,6 +106,6 @@ def push(
curr_unavailable_actions_mask=curr_unavailable_actions_mask,
next_available_actions=next_available_actions_tensor_with_padding,
next_unavailable_actions_mask=next_unavailable_actions_mask,
- done=self._process_single_done(done),
+ terminated=self._process_single_terminated(terminated),
).to(self.device)
)