def action_unscaling(
    action_space: ActionSpace, input_action: torch.Tensor
) -> torch.Tensor:
    """
    The reverse operation of action_scaling: maps an action expressed in the
    bounds of the action space, [low, high]^{action_dim}, back to the normalized
    range [-1, 1]^{action_dim} via ((input_action - low) / (high - low)) * 2 - 1.

    Dimensions where low == high have zero width, so the inverse map is not
    defined there (the plain formula would yield NaN or +/-inf). Such dimensions
    are mapped to 0, the center of the normalized range, so that downstream
    computations (e.g. clipping followed by atanh) stay finite.

    Args:
        action_space: The action space; it must be a BoxActionSpace whose `low`
            and `high` attributes are tensors.
        input_action: Action tensor expressed in the bounds of the action space.
            Presumably of shape (batch_size, action_dim) or (action_dim,), i.e.
            broadcastable against `low` and `high` -- confirm against callers.

    Returns:
        A tensor on the same device as `input_action` holding the normalized
        action. Inputs outside [low, high] map outside [-1, 1]; no clipping is
        applied here.
    """
    # Kept as an assert: it narrows the type for static checkers and preserves
    # the exception type raised on a wrong action space.
    assert isinstance(action_space, BoxActionSpace)
    device = input_action.device
    # Work on detached copies placed on the device of the input action.
    low = action_space.low.clone().detach().to(device)
    high = action_space.high.clone().detach().to(device)

    action_range = high - low
    zero_width = action_range == 0
    # Divide by 1 instead of 0 on zero-width dimensions; their result is
    # overwritten below, this only avoids producing NaN/inf in the division.
    safe_range = torch.where(
        zero_width, torch.ones_like(action_range), action_range
    )
    unscaled_action = (((input_action - low) / safe_range) * 2) - 1
    return torch.where(
        zero_width, torch.zeros_like(unscaled_action), unscaled_action
    )