change is_action_continuous to _is_action_continuous
Summary: Sometimes we use is_action_continuous and sometimes _is_action_continuous as an attribute of a policy learner. This diff makes the usage consistent by renaming is_action_continuous to _is_action_continuous.

Reviewed By: rodrigodesalvobraz

Differential Revision: D65853193

fbshipit-source-id: 847c82ddb31ae46a60da2698b02e232049abf91e
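
As a quick illustration of the convention this commit standardizes on, here is a minimal sketch (simplified stand-ins, not the actual Pearl classes): the underscore-prefixed attribute is set once on the policy learner and mirrored onto the replay buffer by the agent, as in the pearl_agent.py hunk below.

    class PolicyLearner:
        def __init__(self, is_action_continuous: bool) -> None:
            # Stored under the single canonical, underscore-prefixed name.
            self._is_action_continuous = is_action_continuous

    class ReplayBuffer:
        def __init__(self) -> None:
            # Default; overwritten by the agent to stay in sync with the learner.
            self._is_action_continuous = False

    class PearlAgent:
        def __init__(self, policy_learner: PolicyLearner, replay_buffer: ReplayBuffer) -> None:
            # Mirrors the pearl_agent.py change: buffer and learner stay in sync.
            replay_buffer._is_action_continuous = policy_learner._is_action_continuous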
yiwan-rl authored and facebook-github-bot committed Nov 13, 2024
1 parent 3ade5e5 commit 9c8db81
Showing 5 changed files with 8 additions and 8 deletions.
6 changes: 3 additions & 3 deletions pearl/pearl_agent.py
@@ -111,8 +111,8 @@ def __init__(
         )
 
         # set here so replay_buffer and policy_learner are in sync
-        self.replay_buffer.is_action_continuous = (
-            self.policy_learner.is_action_continuous
+        self.replay_buffer._is_action_continuous = (
+            self.policy_learner._is_action_continuous
         )
         self.replay_buffer.device_for_batches = self.device

@@ -192,7 +192,7 @@ def observe(
             terminated=action_result.terminated,
             max_number_actions=(
                 self.policy_learner.action_representation_module.max_number_actions
-                if not self.policy_learner.is_action_continuous
+                if not self.policy_learner._is_action_continuous
                 else None
             ),  # max number of actions for discrete action space
             cost=action_result.cost,
2 changes: 1 addition & 1 deletion pearl/policy_learners/policy_learner.py
@@ -108,7 +108,7 @@ def __init__(
         self._batch_size = batch_size
         self._training_steps = 0
         self.on_policy = on_policy
-        self.is_action_continuous = is_action_continuous
+        self._is_action_continuous = is_action_continuous
         self.distribution_enabled: bool = is_distribution_enabled()
         self.requires_tensors = requires_tensors

@@ -115,7 +115,7 @@ def __init__(

         self._action_dim: int = (
             self.action_representation_module.representation_dim
-            if self.is_action_continuous
+            if self._is_action_continuous
             else self.action_representation_module.max_number_actions
         )

@@ -239,7 +239,7 @@ def act(
         # Step 1: compute exploit_action
         # (action computed by actor network; and without any exploration)
         with torch.no_grad():
-            if self.is_action_continuous:
+            if self._is_action_continuous:
                 exploit_action = self._actor.sample_action(subjective_state)
                 action_probabilities = None
             else:
@@ -225,7 +225,7 @@ def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor:
             actor_loss = (advantage * loss).mean()
 
         else:
-            if self.is_action_continuous:
+            if self._is_action_continuous:
                 log_action_probabilities = self._actor.get_log_probability(
                     batch.state, batch.action
                 ).view(-1)
@@ -94,7 +94,7 @@ def sample(self, batch_size: int) -> TransitionWithBootstrapMaskBatch:
             # pyre-fixme[6]: For 1st argument expected `List[Transition]` but got
             # `List[Union[Transition, TransitionBatch]]`.
             transitions=samples,
-            is_action_continuous=self.is_action_continuous,
+            is_action_continuous=self._is_action_continuous,
         )
         # pyre-fixme[16]: Item `Transition` of `Union[Transition, TransitionBatch]`
         # has no attribute `bootstrap_mask`.
