apply Black 2024 style in fbcode (5/16)
Summary:
Formats the covered files with pyfmt.

paintitblack

Reviewed By: aleivag

Differential Revision: D54447730

fbshipit-source-id: 85ed104b2f8f5e26ae0dea9ee17392ecad8b9407
amyreese authored and facebook-github-bot committed Mar 3, 2024
1 parent f8b1814 commit d2e79d9
Showing 24 changed files with 112 additions and 81 deletions.
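Most of the hunks below make the same mechanical change from the 2024 style: a conditional (ternary) expression that has to span several lines as a keyword argument is now wrapped in its own parentheses, instead of letting the `if`/`else` continuation lines hang at the argument's indentation. A minimal, runnable before/after sketch of that pattern (the function and variable names here are illustrative, not taken from the diff):

from typing import List, Optional

def make_network(hidden_dims: List[int], exploration: str) -> dict:
    # Stand-in constructor; only the call-site formatting below is of interest.
    return {"hidden_dims": hidden_dims, "exploration": exploration}

default_hidden_dims: List[int] = [64, 64]
requested_hidden_dims: Optional[List[int]] = None
requested_exploration: Optional[str] = None

# Pre-2024 layout: the conditional expressions hang across lines inside the call.
# network = make_network(
#     hidden_dims=default_hidden_dims
#     if requested_hidden_dims is None
#     else requested_hidden_dims,
#     exploration=requested_exploration
#     if requested_exploration is not None
#     else "epsilon_greedy",
# )

# 2024 layout: each multi-line conditional expression gets its own parentheses.
network = make_network(
    hidden_dims=(
        default_hidden_dims if requested_hidden_dims is None else requested_hidden_dims
    ),
    exploration=(
        requested_exploration if requested_exploration is not None else "epsilon_greedy"
    ),
)
print(network)  # {'hidden_dims': [64, 64], 'exploration': 'epsilon_greedy'}

Two smaller adjustments also appear in the hunks: long subscripted assignments keep the target on one line and parenthesize the right-hand side instead (a sketch of this follows the tensor_based_replay_buffer.py hunk below), and stray blank lines between a function signature and its docstring are removed.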
@@ -285,9 +285,9 @@ def __init__(
        # advantage architecture
        self.advantage_arch = VanillaValueNetwork(
            input_dim=hidden_dims[-1] + action_dim,  # state_arch out dim + action_dim
-            hidden_dims=hidden_dims
-            if advantage_hidden_dims is None
-            else advantage_hidden_dims,
+            hidden_dims=(
+                hidden_dims if advantage_hidden_dims is None else advantage_hidden_dims
+            ),
            output_dim=output_dim,  # output_dim=1
        )

@@ -481,12 +481,12 @@ def __init__(
        super().__init__(
            state_input_dim=state_dim,
            action_input_dim=action_dim,
-            state_output_dim=state_dim
-            if state_output_dim is None
-            else state_output_dim,
-            action_output_dim=action_dim
-            if action_output_dim is None
-            else action_output_dim,
+            state_output_dim=(
+                state_dim if state_output_dim is None else state_output_dim
+            ),
+            action_output_dim=(
+                action_dim if action_output_dim is None else action_output_dim
+            ),
            state_hidden_dims=[] if state_hidden_dims is None else state_hidden_dims,
            action_hidden_dims=[] if action_hidden_dims is None else action_hidden_dims,
            hidden_dims=hidden_dims,
16 changes: 10 additions & 6 deletions pearl/pearl_agent.py
@@ -185,13 +185,17 @@ def observe(
            # pyre-fixme[6]: this can be removed when tabular Q learning test uses tensors
            next_state=new_history,
            curr_available_actions=self._action_space,  # curr_available_actions
-            next_available_actions=self._action_space
-            if action_result.available_action_space is None
-            else action_result.available_action_space,  # next_available_actions
+            next_available_actions=(
+                self._action_space
+                if action_result.available_action_space is None
+                else action_result.available_action_space
+            ),  # next_available_actions
            done=action_result.done,
-            max_number_actions=self.policy_learner.action_representation_module.max_number_actions
-            if not self.policy_learner.is_action_continuous
-            else None,  # max number of actions for discrete action space
+            max_number_actions=(
+                self.policy_learner.action_representation_module.max_number_actions
+                if not self.policy_learner.is_action_continuous
+                else None
+            ),  # max number of actions for discrete action space
            cost=action_result.cost,
        )

8 changes: 5 additions & 3 deletions pearl/policy_learners/contextual_bandits/disjoint_bandit.py
@@ -102,9 +102,11 @@ def _partition_batch_by_arm(self, batch: TransitionBatch) -> List[TransitionBatc
                TransitionBatch(
                    state=state[mask],
                    reward=batch.reward[mask],
-                    weight=batch.weight[mask]
-                    if batch.weight is not None
-                    else torch.ones_like(mask, dtype=torch.float),
+                    weight=(
+                        batch.weight[mask]
+                        if batch.weight is not None
+                        else torch.ones_like(mask, dtype=torch.float)
+                    ),
                    # empty action features since disjoint model used
                    # action as index of per-arm model
                    # if arms need different features, use 3D `state` instead
8 changes: 5 additions & 3 deletions pearl/policy_learners/policy_learner.py
@@ -79,9 +79,11 @@ def __init__(
        # the agent does not need dynamic action space support.
        self._action_representation_module: ActionRepresentationModule = (
            IdentityActionRepresentationModule(
-                max_number_actions=action_space.n
-                if isinstance(action_space, DiscreteActionSpace)
-                else -1,
+                max_number_actions=(
+                    action_space.n
+                    if isinstance(action_space, DiscreteActionSpace)
+                    else -1
+                ),
                representation_dim=action_space.action_dim,
            )
        )
@@ -110,13 +110,17 @@ def __init__(

        # actor network takes state as input and outputs an action vector
        self._actor: nn.Module = actor_network_type(
-            input_dim=state_dim + self._action_dim
-            if actor_network_type is DynamicActionActorNetwork
-            else state_dim,
+            input_dim=(
+                state_dim + self._action_dim
+                if actor_network_type is DynamicActionActorNetwork
+                else state_dim
+            ),
            hidden_dims=actor_hidden_dims,
-            output_dim=1
-            if actor_network_type is DynamicActionActorNetwork
-            else self._action_dim,
+            output_dim=(
+                1
+                if actor_network_type is DynamicActionActorNetwork
+                else self._action_dim
+            ),
            action_space=action_space,
        )
        self._actor.apply(init_weights)
@@ -132,13 +136,17 @@ def __init__(
        self._actor_soft_update_tau = actor_soft_update_tau
        if self._use_actor_target:
            self._actor_target: nn.Module = actor_network_type(
-                input_dim=state_dim + self._action_dim
-                if actor_network_type is DynamicActionActorNetwork
-                else state_dim,
+                input_dim=(
+                    state_dim + self._action_dim
+                    if actor_network_type is DynamicActionActorNetwork
+                    else state_dim
+                ),
                hidden_dims=actor_hidden_dims,
-                output_dim=1
-                if actor_network_type is DynamicActionActorNetwork
-                else self._action_dim,
+                output_dim=(
+                    1
+                    if actor_network_type is DynamicActionActorNetwork
+                    else self._action_dim
+                ),
                action_space=action_space,
            )
            update_target_network(self._actor_target, self._actor, tau=1)
8 changes: 5 additions & 3 deletions pearl/policy_learners/sequential_decision_making/ddpg.py
@@ -73,9 +73,11 @@ def __init__(
            actor_soft_update_tau=actor_soft_update_tau,
            critic_soft_update_tau=critic_soft_update_tau,
            use_twin_critic=True,  # we need to make this optional to users
-            exploration_module=exploration_module
-            if exploration_module is not None
-            else NormalDistributionExploration(mean=0.0, std_dev=0.1),
+            exploration_module=(
+                exploration_module
+                if exploration_module is not None
+                else NormalDistributionExploration(mean=0.0, std_dev=0.1)
+            ),
            discount_factor=discount_factor,
            training_rounds=training_rounds,
            batch_size=batch_size,
@@ -95,9 +95,11 @@ class and uses `act` and `learn_batch` methods of that class. We only implement
        """

        super(DeepQLearning, self).__init__(
-            exploration_module=exploration_module
-            if exploration_module is not None
-            else EGreedyExploration(0.05),
+            exploration_module=(
+                exploration_module
+                if exploration_module is not None
+                else EGreedyExploration(0.05)
+            ),
            on_policy=False,
            state_dim=state_dim,
            action_space=action_space,
@@ -41,9 +41,11 @@ def __init__(
        super(DeepSARSA, self).__init__(
            state_dim=state_dim,
            action_space=action_space,
-            exploration_module=exploration_module
-            if exploration_module is not None
-            else EGreedyExploration(0.05),
+            exploration_module=(
+                exploration_module
+                if exploration_module is not None
+                else EGreedyExploration(0.05)
+            ),
            on_policy=True,
            action_representation_module=action_representation_module,
            **kwargs,
@@ -105,9 +105,11 @@ def __init__(
            use_critic_target=True,
            critic_soft_update_tau=critic_soft_update_tau,
            use_twin_critic=True,
-            exploration_module=exploration_module
-            if exploration_module is not None
-            else NoExploration(),
+            exploration_module=(
+                exploration_module
+                if exploration_module is not None
+                else NoExploration()
+            ),
            discount_factor=discount_factor,
            training_rounds=training_rounds,
            batch_size=batch_size,
8 changes: 5 additions & 3 deletions pearl/policy_learners/sequential_decision_making/ppo.py
@@ -78,9 +78,11 @@ def __init__(
            actor_soft_update_tau=0.0,  # not used
            critic_soft_update_tau=0.0,  # not used
            use_twin_critic=False,
-            exploration_module=exploration_module
-            if exploration_module is not None
-            else PropensityExploration(),
+            exploration_module=(
+                exploration_module
+                if exploration_module is not None
+                else PropensityExploration()
+            ),
            discount_factor=discount_factor,
            training_rounds=training_rounds,
            batch_size=batch_size,
@@ -62,9 +62,11 @@ def __init__(
            state_dim=state_dim,
            action_space=action_space,
            on_policy=on_policy,
-            exploration_module=exploration_module
-            if exploration_module is not None
-            else EGreedyExploration(0.10),
+            exploration_module=(
+                exploration_module
+                if exploration_module is not None
+                else EGreedyExploration(0.10)
+            ),
            hidden_dims=hidden_dims,
            num_quantiles=num_quantiles,
            learning_rate=learning_rate,
8 changes: 5 additions & 3 deletions pearl/policy_learners/sequential_decision_making/reinforce.py
@@ -82,9 +82,11 @@ def __init__(
            actor_soft_update_tau=0.0,  # not used
            critic_soft_update_tau=0.0,  # not used
            use_twin_critic=False,
-            exploration_module=exploration_module
-            if exploration_module is not None
-            else PropensityExploration(),
+            exploration_module=(
+                exploration_module
+                if exploration_module is not None
+                else PropensityExploration()
+            ),
            discount_factor=discount_factor,
            training_rounds=training_rounds,
            batch_size=batch_size,
@@ -40,6 +40,7 @@
# Currently available actions is not used. Needs to be updated once we know the input
# structure of production stack on this param.

+
# TODO: to make things easier with a single optimizer, we need to polish this method.
class SoftActorCritic(ActorCriticBase):
    """
@@ -78,9 +79,11 @@ def __init__(
            actor_soft_update_tau=0.0,  # not used
            critic_soft_update_tau=critic_soft_update_tau,
            use_twin_critic=True,
-            exploration_module=exploration_module
-            if exploration_module is not None
-            else PropensityExploration(),
+            exploration_module=(
+                exploration_module
+                if exploration_module is not None
+                else PropensityExploration()
+            ),
            discount_factor=discount_factor,
            training_rounds=training_rounds,
            batch_size=batch_size,
@@ -73,9 +73,11 @@ def __init__(
            actor_soft_update_tau=0.0,
            critic_soft_update_tau=critic_soft_update_tau,
            use_twin_critic=True,
-            exploration_module=exploration_module
-            if exploration_module is not None
-            else NoExploration(),
+            exploration_module=(
+                exploration_module
+                if exploration_module is not None
+                else NoExploration()
+            ),
            discount_factor=discount_factor,
            training_rounds=training_rounds,
            batch_size=batch_size,
6 changes: 3 additions & 3 deletions pearl/replay_buffers/tensor_based_replay_buffer.py
@@ -118,9 +118,9 @@ def _create_action_tensor_and_mask(
                dtype=torch.float32,
            )  # (1 x action_space_size x action_dim)
            available_actions_tensor = available_action_space.actions_batch
-            available_actions_tensor_with_padding[
-                0, : available_action_space.n, :
-            ] = available_actions_tensor
+            available_actions_tensor_with_padding[0, : available_action_space.n, :] = (
+                available_actions_tensor
+            )

            unavailable_actions_mask = torch.zeros(
                (1, max_number_actions), device=self._device
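The tensor_based_replay_buffer.py hunk above and the benchmark.py hunk below show the second pattern in this commit: when an assignment statement does not fit on one line, the subscripted target stays on a single line and the right-hand side is parenthesized, rather than splitting inside the subscript on the left. A small runnable sketch of that layout; the tensor shapes and the `available_action_space_n` stand-in are assumptions for illustration, not values from the diff:

import torch

# Hypothetical shapes, chosen only to make this sketch runnable.
max_number_actions, action_dim = 8, 4
available_actions_tensor = torch.ones((5, action_dim))
available_action_space_n = 5  # stands in for available_action_space.n in the diff

available_actions_tensor_with_padding = torch.zeros(
    (1, max_number_actions, action_dim),
    dtype=torch.float32,
)  # (1 x action_space_size x action_dim)

# Old layout: the long statement was split inside the subscript on the left-hand side.
# available_actions_tensor_with_padding[
#     0, :available_action_space_n, :
# ] = available_actions_tensor

# 2024 layout: the subscripted target keeps one line and the right-hand side
# is wrapped in parentheses instead.
available_actions_tensor_with_padding[0, :available_action_space_n, :] = (
    available_actions_tensor
)
print(available_actions_tensor_with_padding.sum())  # tensor(20.)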
1 change: 0 additions & 1 deletion pearl/safety_modules/reward_constrained_safety_module.py
@@ -130,7 +130,6 @@ def constraint_lambda_update(
        policy_learner: PolicyLearner,
        cost_critic: nn.Module,
    ) -> None:
-
        """
        Update the lambda constraint based on the cost critic via a projected gradient descent update rule.
        """
1 change: 0 additions & 1 deletion pearl/safety_modules/risk_sensitive_safety_modules.py
@@ -67,7 +67,6 @@ def get_q_values_under_risk_metric(
        action_batch: Tensor,
        q_value_distribution_network: DistributionalQValueNetwork,
    ) -> torch.Tensor:
-
        """Returns Q(s, a), given s and a
        Args:
            state_batch: a batch of state tensors (batch_size, state_dim)
@@ -28,7 +28,6 @@ def create_offline_data(
    evaluation_episodes: int = 100,
    seed: Optional[int] = None,
) -> List[Value]:
-
    """
    This function creates offline data by interacting with a given environment using a specified
    Pearl Agent. This is mostly for illustration with standard benchmark environments only.
@@ -73,9 +72,11 @@ def create_offline_data(
            "curr_available_actions": env.action_space,
            "next_available_actions": env.action_space,
            "done": action_result.done,
-            "max_number_actions": env.action_space.n
-            if isinstance(env.action_space, DiscreteActionSpace)
-            else None,
+            "max_number_actions": (
+                env.action_space.n
+                if isinstance(env.action_space, DiscreteActionSpace)
+                else None
+            ),
        }

        observation = action_result.observation
@@ -146,7 +147,6 @@ def get_data_collection_agent_returns(
    data_path: str,
    returns_file_path: Optional[str] = None,
) -> List[Value]:
-
    """
    This function returns episode returns of a Pearl Agent using for offline data collection.
    The returns file can be directly provided or we can stitch together trajectories in the offline
1 change: 0 additions & 1 deletion pearl/utils/functional_utils/learning/loss_fn_utils.py
@@ -19,7 +19,6 @@
def compute_cql_loss(
    q_network: QValueNetwork, batch: TransitionBatch, batch_size: int
) -> torch.Tensor:
-
    """
    Compute CQL loss for a batch of data.
2 changes: 1 addition & 1 deletion pearl/utils/instantiations/environments/gym_environment.py
@@ -41,7 +41,7 @@ def tensor_to_numpy(x: Tensor) -> np.ndarray:

GYM_TO_PEARL_ACTION_SPACE = {
    "Discrete": DiscreteActionSpace,
-    "Box": BoxActionSpace
+    "Box": BoxActionSpace,
    # Add more here as needed
}
GYM_TO_PEARL_OBSERVATION_SPACE = {
6 changes: 3 additions & 3 deletions pearl/utils/scripts/benchmark.py
@@ -173,9 +173,9 @@ def evaluate_single(
            method["history_summarization_module"].__name__
            == "LSTMHistorySummarizationModule"
        ):
-            method["history_summarization_module_args"][
-                "observation_dim"
-            ] = env.observation_space.shape[0]
+            method["history_summarization_module_args"]["observation_dim"] = (
+                env.observation_space.shape[0]
+            )
            method["history_summarization_module_args"]["action_dim"] = (
                policy_learner_args["action_representation_module"].representation_dim
                if "action_representation_module" in policy_learner_args
2 changes: 0 additions & 2 deletions pearl/utils/scripts/benchmark_offline_rl.py
@@ -54,7 +54,6 @@ def get_random_agent_returns(
    evaluation_episodes: int = 500,
    seed: Optional[int] = None,
) -> List[float]:
-
    """
    This function returns a list of episode returns of a Pearl Agent interacting with the input
    environment using a randomly instantiated policy learner. This is needed to compute
@@ -128,7 +127,6 @@ def evaluate_offline_rl(
    data_size: int = 1000000,
    seed: Optional[int] = None,
) -> List[float]:
-
    """
    This function trains and evaluates an offline RL agent on the given environment. Training data
    can be provided through a url or by specifying a local file path. If neither are provided,
8 changes: 5 additions & 3 deletions pearl/utils/scripts/figure_gen.py
@@ -38,9 +38,11 @@

def moving_average(data: List[Value]) -> Value:
    return [
-        sum(data[int(i - MA_WINDOW_SIZE + 1) : i + 1]) / MA_WINDOW_SIZE  # pyre-ignore
-        if i >= MA_WINDOW_SIZE
-        else sum(data[: i + 1]) * 1.0 / (i + 1)  # pyre-ignore
+        (
+            sum(data[int(i - MA_WINDOW_SIZE + 1) : i + 1]) / MA_WINDOW_SIZE  # pyre-ignore
+            if i >= MA_WINDOW_SIZE
+            else sum(data[: i + 1]) * 1.0 / (i + 1)
+        )  # pyre-ignore
        for i in range(len(data))
    ]
