Skip to content

Commit

Permalink
fix a bug in GaussianActorNetwork
Browse files Browse the repository at this point in the history
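Summary: Previously, GaussianActorNetwork.get_log_probability assumed that input actions were already in [-1, 1]. This diff adds action_unscaling, the inverse of action_scaling, and uses it to map actions from the action space's [low, high] range back to [-1, 1] before the log probability is computed.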

Reviewed By: rodrigodesalvobraz

Differential Revision: D65855658

fbshipit-source-id: b08cfa447386d56269f887572a45ed2ef320b644
  • Loading branch information
yiwan-rl authored and facebook-github-bot committed Nov 13, 2024
1 parent ff2b13c commit 4687b47
Showing 1 changed file with 21 additions and 6 deletions.
27 changes: 21 additions & 6 deletions pearl/neural_networks/sequential_decision_making/actor_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,20 @@ def action_scaling(
return centered_and_scaled_action


def action_unscaling(
    action_space: ActionSpace, input_action: torch.Tensor
) -> torch.Tensor:
    """
    Map an action from [low, high]^{action_dim} back to [-1, 1]^{action_dim}.

    This is the inverse of `action_scaling`: a value equal to `low` maps to -1
    and a value equal to `high` maps to 1. The mapping is linear and performs no
    clipping, so inputs outside [low, high] land outside [-1, 1].

    Args:
        action_space: The action space providing the `low` and `high` bounds.
            Must be a `BoxActionSpace`.
        input_action: The action tensor to normalize. The bounds are moved to
            this tensor's device before the arithmetic is performed.

    Returns:
        A tensor holding the normalized action.

    Note:
        Any dimension where `high == low` results in a division by zero.
    """
    assert isinstance(action_space, BoxActionSpace)

    # Work on detached copies of the bounds, placed on the same device as the input.
    target_device = input_action.device
    lower = action_space.low.clone().detach().to(target_device)
    upper = action_space.high.clone().detach().to(target_device)

    # Fraction of the way from `low` to `high`: 0 at `low`, 1 at `high`.
    fraction = (input_action - lower) / (upper - lower)

    # Stretch [0, 1] to [0, 2], then shift down to [-1, 1].
    return (fraction * 2) - 1


def noise_scaling(action_space: ActionSpace, input_noise: torch.Tensor) -> torch.Tensor:
"""
This function rescales any input vector from [-1, 1]^{action_dim} to [low, high]^{action_dim}.
Expand Down Expand Up @@ -473,16 +487,17 @@ def get_log_probability(
std = log_std.exp()
normal = Normal(mean, std)

# assume that input actions are in [-1, 1]^d
# TODO: change this to add a transform for unscaling and uncentering depending on the
# action space
unscaled_action_batch = torch.clip(action_batch, -1 + epsilon, 1 - epsilon)
normalized_action_batch = torch.clip(
action_unscaling(self._action_space, action_batch),
-1 + epsilon,
1 - epsilon,
)

# transform actions from [-1, 1]^d to [-inf, inf]^d
unnormalized_action_batch = torch.atanh(unscaled_action_batch)
unnormalized_action_batch = torch.atanh(normalized_action_batch)
log_prob = normal.log_prob(unnormalized_action_batch)
log_prob -= torch.log(
self._action_bound * (1 - unscaled_action_batch.pow(2)) + epsilon
self._action_bound * (1 - normalized_action_batch.pow(2)) + epsilon
)

# for multi-dimensional action space, sum log probabilities over individual
Expand Down

0 comments on commit 4687b47

Please sign in to comment.