From 2f1670935fd7ad417ec361c3a6ff46c72dc10581 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Wed, 20 Nov 2024 18:25:36 -0800 Subject: [PATCH] Prepare for "Fix type-safety of `torch.nn.Module` instances": fbcode/p* Summary: X-link: https://github.com/pytorch/captum/pull/1448 See D52890934 Reviewed By: r-barnes Differential Revision: D66235323 fbshipit-source-id: a8781d76a63bf8003761055d7808190f73dea5e9 --- pearl/neural_networks/common/value_networks.py | 1 + .../q_value_networks.py | 1 + pearl/pearl_agent.py | 3 +++ .../contextual_bandits/disjoint_bandit.py | 1 + .../thompson_sampling_exploration.py | 1 + .../contextual_bandits/ucb_exploration.py | 1 + .../actor_critic_base.py | 8 +++++++- .../sequential_decision_making/ddpg.py | 5 +++++ .../deep_td_learning.py | 6 ++++++ .../implicit_q_learning.py | 7 +++++++ .../sequential_decision_making/ppo.py | 4 ++++ .../quantile_regression_deep_q_learning.py | 2 ++ .../quantile_regression_deep_td_learning.py | 2 ++ .../sequential_decision_making/reinforce.py | 1 + .../soft_actor_critic.py | 11 +++++++++++ .../soft_actor_critic_continuous.py | 18 +++++++++++++++++- .../tabular_q_learning.py | 5 +++++ .../sequential_decision_making/td3.py | 4 ++++ .../reward_constrained_safety_module.py | 12 +++++++++++- .../functional_utils/learning/action_utils.py | 4 ++++ .../functional_utils/learning/critic_utils.py | 6 ++++++ .../unit/with_pytorch/test_disjoint_bandits.py | 1 + test/unit/with_pytorch/test_epinet.py | 2 ++ 23 files changed, 103 insertions(+), 3 deletions(-) diff --git a/pearl/neural_networks/common/value_networks.py b/pearl/neural_networks/common/value_networks.py index a75b341d..597f7829 100644 --- a/pearl/neural_networks/common/value_networks.py +++ b/pearl/neural_networks/common/value_networks.py @@ -52,6 +52,7 @@ def forward(self, x: Tensor) -> Tensor: # default initialization in linear and conv layers of a F.sequential model is Kaiming def xavier_init(self) -> None: + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. for layer in self._model: if isinstance(layer, nn.Linear): nn.init.xavier_normal_(layer.weight) diff --git a/pearl/neural_networks/sequential_decision_making/q_value_networks.py b/pearl/neural_networks/sequential_decision_making/q_value_networks.py index 088bdec3..55c3c27a 100644 --- a/pearl/neural_networks/sequential_decision_making/q_value_networks.py +++ b/pearl/neural_networks/sequential_decision_making/q_value_networks.py @@ -594,6 +594,7 @@ def __init__( self._action_dim = action_dim def forward(self, x: Tensor) -> Tensor: + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. return self._model(x) def get_q_values( diff --git a/pearl/pearl_agent.py b/pearl/pearl_agent.py index 350e4606..59b4ee69 100644 --- a/pearl/pearl_agent.py +++ b/pearl/pearl_agent.py @@ -92,6 +92,7 @@ def __init__( # adds the safety module to the policy learner as well # @jalaj, we need to follow the practice below for safety module + # pyre-fixme[16]: `PolicyLearner` has no attribute `safety_module`. self.policy_learner.safety_module = self.safety_module self.replay_buffer: ReplayBuffer = ( @@ -190,6 +191,8 @@ def observe( else action_result.available_action_space ), # next_available_actions terminated=action_result.terminated, + # pyre-fixme[6]: For 8th argument expected `Optional[int]` but got + # `Union[None, Tensor, Module]`. max_number_actions=( self.policy_learner.action_representation_module.max_number_actions if not self.policy_learner._is_action_continuous diff --git a/pearl/policy_learners/contextual_bandits/disjoint_bandit.py b/pearl/policy_learners/contextual_bandits/disjoint_bandit.py index 78c83824..b0a745b0 100644 --- a/pearl/policy_learners/contextual_bandits/disjoint_bandit.py +++ b/pearl/policy_learners/contextual_bandits/disjoint_bandit.py @@ -228,6 +228,7 @@ def get_scores( @property def optimizer(self) -> torch.optim.Optimizer: + # pyre-fixme[7]: Expected `Optimizer` but got `Union[Tensor, Module]`. return self._optimizer def set_history_summarization_module( diff --git a/pearl/policy_learners/exploration_modules/contextual_bandits/thompson_sampling_exploration.py b/pearl/policy_learners/exploration_modules/contextual_bandits/thompson_sampling_exploration.py index d78e16ea..113c4318 100644 --- a/pearl/policy_learners/exploration_modules/contextual_bandits/thompson_sampling_exploration.py +++ b/pearl/policy_learners/exploration_modules/contextual_bandits/thompson_sampling_exploration.py @@ -52,6 +52,7 @@ def get_scores( expected_reward = representation(subjective_state) # batch_size, action_count, 1 assert expected_reward.shape == subjective_state.shape[:-1] + # pyre-fixme[29]: `Union[Tensor, Module]` is not a function. sigma = representation.calculate_sigma(subjective_state) # batch_size, action_count, 1 assert sigma.shape == subjective_state.shape[:-1] diff --git a/pearl/policy_learners/exploration_modules/contextual_bandits/ucb_exploration.py b/pearl/policy_learners/exploration_modules/contextual_bandits/ucb_exploration.py index 50052f61..c9bbcb39 100644 --- a/pearl/policy_learners/exploration_modules/contextual_bandits/ucb_exploration.py +++ b/pearl/policy_learners/exploration_modules/contextual_bandits/ucb_exploration.py @@ -44,6 +44,7 @@ def sigma( Returns: sigma with shape (batch_size, action_count) or (batch_size, 1) """ + # pyre-fixme[29]: `Union[Tensor, Module]` is not a function. sigma = representation.calculate_sigma(subjective_state) nan_check = torch.isnan(sigma) sigma = torch.where(nan_check, torch.zeros_like(sigma), sigma) diff --git a/pearl/policy_learners/sequential_decision_making/actor_critic_base.py b/pearl/policy_learners/sequential_decision_making/actor_critic_base.py index adb93d94..d11810b1 100644 --- a/pearl/policy_learners/sequential_decision_making/actor_critic_base.py +++ b/pearl/policy_learners/sequential_decision_making/actor_critic_base.py @@ -238,6 +238,7 @@ def act( # (action computed by actor network; and without any exploration) with torch.no_grad(): if self._is_action_continuous: + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. exploit_action = self._actor.sample_action(subjective_state) action_probabilities = None else: @@ -245,6 +246,7 @@ def act( actions = self.action_representation_module( available_action_space.actions_batch ) + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. action_probabilities = self._actor.get_policy_distribution( state_batch=subjective_state, available_actions=actions, @@ -267,6 +269,7 @@ def act( ) def reset(self, action_space: ActionSpace) -> None: + # pyre-fixme[16]: `ActorCriticBase` has no attribute `_action_space`. self._action_space = action_space def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]: @@ -334,7 +337,10 @@ def preprocess_batch(self, batch: TransitionBatch) -> TransitionBatch: # change reward to be the lambda_constraint weighted sum of reward and cost if hasattr(self.safety_module, "lambda_constraint"): batch.reward = ( - batch.reward - self.safety_module.lambda_constraint * batch.cost + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no + # attribute `lambda_constraint`. + batch.reward + - self.safety_module.lambda_constraint * batch.cost ) batch = super().preprocess_batch(batch) diff --git a/pearl/policy_learners/sequential_decision_making/ddpg.py b/pearl/policy_learners/sequential_decision_making/ddpg.py index 73be5184..2f867a48 100644 --- a/pearl/policy_learners/sequential_decision_making/ddpg.py +++ b/pearl/policy_learners/sequential_decision_making/ddpg.py @@ -100,9 +100,12 @@ def __init__( def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor: # sample a batch of actions from the actor network; shape (batch_size, action_dim) + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. action_batch = self._actor.sample_action(batch.state) # obtain q values for (batch.state, action_batch) from critic 1 + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `get_q_values`. q1 = self._critic._critic_1.get_q_values( state_batch=batch.state, action_batch=action_batch ) @@ -116,10 +119,12 @@ def _critic_loss(self, batch: TransitionBatch) -> torch.Tensor: with torch.no_grad(): # sample a batch of next actions from target actor network; + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. next_action = self._actor_target.sample_action(batch.next_state) # (batch_size, action_dim) # get q values of (batch.next_state, next_action) from targets of twin critic + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. next_q1, next_q2 = self._critic_target.get_q_values( state_batch=batch.next_state, action_batch=next_action, diff --git a/pearl/policy_learners/sequential_decision_making/deep_td_learning.py b/pearl/policy_learners/sequential_decision_making/deep_td_learning.py index c95e60d7..5018061d 100644 --- a/pearl/policy_learners/sequential_decision_making/deep_td_learning.py +++ b/pearl/policy_learners/sequential_decision_making/deep_td_learning.py @@ -134,6 +134,8 @@ def make_specified_network() -> QValueNetwork: if network_type is TwoTowerQValueNetwork: return network_type( state_dim=state_dim, + # pyre-fixme[6]: For 2nd argument expected `int` but got + # `Union[Tensor, Module]`. action_dim=self._action_representation_module.representation_dim, hidden_dims=hidden_dims, state_output_dim=state_output_dim, @@ -149,6 +151,8 @@ def make_specified_network() -> QValueNetwork: ) return network_type( state_dim=state_dim, + # pyre-fixme[6]: For 2nd argument expected `int` but got + # `Union[Tensor, Module]`. action_dim=self._action_representation_module.representation_dim, hidden_dims=hidden_dims, output_dim=1, @@ -289,6 +293,8 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]: # Conservative TD updates for offline learning. if self._is_conservative: cql_loss = compute_cql_loss(self._Q, batch, batch_size) + # pyre-fixme[58]: `*` is not supported for operand types + # `Optional[float]` and `Tensor`. loss = self._conservative_alpha * cql_loss + bellman_loss else: loss = bellman_loss diff --git a/pearl/policy_learners/sequential_decision_making/implicit_q_learning.py b/pearl/policy_learners/sequential_decision_making/implicit_q_learning.py index cd70c6e0..7cb373ae 100644 --- a/pearl/policy_learners/sequential_decision_making/implicit_q_learning.py +++ b/pearl/policy_learners/sequential_decision_making/implicit_q_learning.py @@ -167,7 +167,11 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]: self._history_summarization_optimizer.step() # update critic and target Twin networks; update_target_networks( + # pyre-fixme[6]: For 1st argument expected `Union[List[Module], + # ModuleList]` but got `Union[Tensor, Module]`. self._critic_target._critic_networks_combined, + # pyre-fixme[6]: For 2nd argument expected `Union[List[Module], + # ModuleList]` but got `Union[Tensor, Module]`. self._critic._critic_networks_combined, self._critic_soft_update_tau, ) @@ -181,6 +185,7 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]: def _value_loss(self, batch: TransitionBatch) -> torch.Tensor: with torch.no_grad(): + # pyre-fixme[29]: `Union[Tensor, Module]` is not a function. q1, q2 = self._critic_target.get_q_values(batch.state, batch.action) # random ensemble distillation. random_index = torch.randint(0, 2, (1,)).item() @@ -197,6 +202,7 @@ def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor: Performs policy extraction using advantage weighted regression """ with torch.no_grad(): + # pyre-fixme[29]: `Union[Tensor, Module]` is not a function. q1, q2 = self._critic_target.get_q_values(batch.state, batch.action) # random ensemble distillation. random_index = torch.randint(0, 2, (1,)).item() @@ -226,6 +232,7 @@ def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor: else: if self._is_action_continuous: + # pyre-fixme[29]: `Union[Tensor, Module]` is not a function. log_action_probabilities = self._actor.get_log_probability( batch.state, batch.action ).view(-1) diff --git a/pearl/policy_learners/sequential_decision_making/ppo.py b/pearl/policy_learners/sequential_decision_making/ppo.py index e63d182f..2f3097b4 100644 --- a/pearl/policy_learners/sequential_decision_making/ppo.py +++ b/pearl/policy_learners/sequential_decision_making/ppo.py @@ -152,6 +152,7 @@ def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor: # TODO need to support continuous action # TODO: change the output shape of value networks assert isinstance(batch, PPOTransitionBatch) + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. action_probs = self._actor.get_action_prob( state_batch=batch.state, action_batch=batch.action, @@ -167,6 +168,8 @@ def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor: clip = torch.clamp( r_thelta, min=1.0 - self._epsilon, max=1.0 + self._epsilon ) # shape (batch_size) + # pyre-fixme[58]: `*` is not supported for operand types `Tensor` and + # `Optional[Tensor]`. loss = torch.sum(-torch.min(r_thelta * batch.gae, clip * batch.gae)) # entropy entropy: torch.Tensor = torch.distributions.Categorical( @@ -236,6 +239,7 @@ def preprocess_replay_buffer(self, replay_buffer: ReplayBuffer) -> None: state_values = self._critic(history_summary_batch).detach() action_probs = ( + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. self._actor.get_action_prob( state_batch=history_summary_batch, action_batch=action_representation_batch, diff --git a/pearl/policy_learners/sequential_decision_making/quantile_regression_deep_q_learning.py b/pearl/policy_learners/sequential_decision_making/quantile_regression_deep_q_learning.py index e8a2b587..fc771e6a 100644 --- a/pearl/policy_learners/sequential_decision_making/quantile_regression_deep_q_learning.py +++ b/pearl/policy_learners/sequential_decision_making/quantile_regression_deep_q_learning.py @@ -118,6 +118,8 @@ def _get_next_state_quantiles( # get q values from a q value distribution under a risk metric # instead of using the 'get_q_values' method of the QuantileQValueNetwork, # we invoke a method from the risk sensitive safety module + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `get_q_values_under_risk_metric`. next_state_action_values = self.safety_module.get_q_values_under_risk_metric( next_state_batch_repeated, next_available_actions_batch, self._Q_target ).view( diff --git a/pearl/policy_learners/sequential_decision_making/quantile_regression_deep_td_learning.py b/pearl/policy_learners/sequential_decision_making/quantile_regression_deep_td_learning.py index 96ccb263..372cdd48 100644 --- a/pearl/policy_learners/sequential_decision_making/quantile_regression_deep_td_learning.py +++ b/pearl/policy_learners/sequential_decision_making/quantile_regression_deep_td_learning.py @@ -146,6 +146,8 @@ def act( # instead of using the 'get_q_values' method of the QuantileQValueNetwork, # we invoke a method from the risk sensitive safety module + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `get_q_values_under_risk_metric`. q_values = self.safety_module.get_q_values_under_risk_metric( states_repeated, actions, self._Q ) diff --git a/pearl/policy_learners/sequential_decision_making/reinforce.py b/pearl/policy_learners/sequential_decision_making/reinforce.py index e3741fdd..ab964a82 100644 --- a/pearl/policy_learners/sequential_decision_making/reinforce.py +++ b/pearl/policy_learners/sequential_decision_making/reinforce.py @@ -145,6 +145,7 @@ def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor: batch.state ) # (batch_size x state_dim) note that here batch_size = episode length return_batch = batch.cum_reward # (batch_size) + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. policy_propensities = self._actor.get_action_prob( batch.state, batch.action, diff --git a/pearl/policy_learners/sequential_decision_making/soft_actor_critic.py b/pearl/policy_learners/sequential_decision_making/soft_actor_critic.py index 5a2f2bc2..49b200e4 100644 --- a/pearl/policy_learners/sequential_decision_making/soft_actor_critic.py +++ b/pearl/policy_learners/sequential_decision_making/soft_actor_critic.py @@ -113,6 +113,7 @@ def __init__( # sac uses a learning rate scheduler specifically def reset(self, action_space: ActionSpace) -> None: + # pyre-fixme[16]: `SoftActorCritic` has no attribute `_action_space`. self._action_space = action_space self.scheduler.step() @@ -152,11 +153,14 @@ def _get_next_state_expected_values(self, batch: TransitionBatch) -> torch.Tenso assert next_available_actions_batch is not None next_state_batch_repeated = torch.repeat_interleave( next_state_batch.unsqueeze(1), + # pyre-fixme[6]: For 2nd argument expected `Tensor` but got + # `Union[Module, Tensor]`. self.action_representation_module.max_number_actions, dim=1, ) # (batch_size x action_space_size x state_dim) # get q values of (states, all actions) from twin critics + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. next_q1, next_q2 = self._critic_target.get_q_values( state_batch=next_state_batch_repeated, action_batch=next_available_actions_batch, @@ -179,6 +183,7 @@ def _get_next_state_expected_values(self, batch: TransitionBatch) -> torch.Tenso if next_unavailable_actions_mask_batch is not None: next_state_action_values[next_unavailable_actions_mask_batch] = 0.0 + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. next_state_policy_dist = self._actor.get_policy_distribution( state_batch=next_state_batch, available_actions=next_available_actions_batch, @@ -197,6 +202,8 @@ def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor: state_batch = batch.state # (batch_size x state_dim) state_batch_repeated = torch.repeat_interleave( state_batch.unsqueeze(1), + # pyre-fixme[6]: For 2nd argument expected `Tensor` but got + # `Union[Module, Tensor]`. self.action_representation_module.max_number_actions, dim=1, ) # (batch_size x action_space_size x state_dim) @@ -206,6 +213,7 @@ def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor: ) # (batch_size x action_space_size x action_dim) # get q values of (states, all actions) from twin critics + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. q1, q2 = self._critic.get_q_values( state_batch=state_batch_repeated, action_batch=available_actions ) @@ -216,6 +224,7 @@ def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor: batch.curr_unavailable_actions_mask ) # (batch_size x action_space_size) + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. new_policy_dist = self._actor.get_policy_distribution( state_batch=state_batch, available_actions=available_actions, @@ -223,6 +232,8 @@ def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor: ) # (batch_size x action_space_size) state_action_values = q.view( + # pyre-fixme[6]: For 1st argument expected `dtype` but got `Tuple[int, + # Union[Module, Tensor]]`. (state_batch.shape[0], self.action_representation_module.max_number_actions) ) # (batch_size x action_space_size) diff --git a/pearl/policy_learners/sequential_decision_making/soft_actor_critic_continuous.py b/pearl/policy_learners/sequential_decision_making/soft_actor_critic_continuous.py index cb593966..d5440f2f 100644 --- a/pearl/policy_learners/sequential_decision_making/soft_actor_critic_continuous.py +++ b/pearl/policy_learners/sequential_decision_making/soft_actor_critic_continuous.py @@ -104,8 +104,15 @@ def __init__( torch.nn.Parameter(torch.zeros(1, requires_grad=True)), ) self._entropy_optimizer: torch.optim.Optimizer = optim.AdamW( - [self._log_entropy], lr=critic_learning_rate, amsgrad=True + # pyre-fixme[6]: For 1st argument expected `Union[Iterable[Dict[str, + # Any]], Iterable[Tuple[str, Tensor]], Iterable[Tensor]]` but got + # `List[Union[Module, Tensor]]`. + [self._log_entropy], + lr=critic_learning_rate, + amsgrad=True, ) + # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `Union[Module, Tensor]`. self.register_buffer("_entropy_coef", torch.exp(self._log_entropy).detach()) assert isinstance(action_space, BoxSpace) self.register_buffer( @@ -120,11 +127,14 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]: if self._entropy_autotune: with torch.no_grad(): + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. _, action_batch_log_prob = self._actor.sample_action( state_batch, get_log_prob=True ) entropy_optimizer_loss = ( + # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `Union[Module, Tensor]`. -torch.exp(self._log_entropy) * (action_batch_log_prob + self._target_entropy) ).mean() @@ -133,6 +143,8 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]: entropy_optimizer_loss.backward() self._entropy_optimizer.step() + # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `Union[Module, Tensor]`. self._entropy_coef = torch.exp(self._log_entropy).detach() {**actor_critic_loss, **{"entropy_coef": entropy_optimizer_loss}} @@ -171,8 +183,10 @@ def _get_next_state_expected_values(self, batch: TransitionBatch) -> torch.Tenso ( next_action_batch, next_action_batch_log_prob, + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. ) = self._actor.sample_action(next_state_batch, get_log_prob=True) + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. next_q1, next_q2 = self._critic_target.get_q_values( state_batch=next_state_batch, action_batch=next_action_batch, @@ -198,8 +212,10 @@ def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor: ( action_batch, action_batch_log_prob, + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. ) = self._actor.sample_action(state_batch, get_log_prob=True) + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. q1, q2 = self._critic.get_q_values( state_batch=state_batch, action_batch=action_batch ) # shape: (batch_size, 1) diff --git a/pearl/policy_learners/sequential_decision_making/tabular_q_learning.py b/pearl/policy_learners/sequential_decision_making/tabular_q_learning.py index 3e532edb..30f2677a 100644 --- a/pearl/policy_learners/sequential_decision_making/tabular_q_learning.py +++ b/pearl/policy_learners/sequential_decision_making/tabular_q_learning.py @@ -67,7 +67,10 @@ def __init__( self.debug: bool = debug def reset(self, action_space: ActionSpace) -> None: + # pyre-fixme[16]: `TabularQLearning` has no attribute `_action_space`. self._action_space = action_space + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. for i, action in enumerate(self._action_space): if int(action.item()) != i: raise ValueError( @@ -140,6 +143,8 @@ def learn( old_q_value = self.q_values.get((state, action.item()), 0) next_q_values = [ self.q_values.get((next_state, next_action.item()), 0) + # pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not + # a function. for next_action in self._action_space ] diff --git a/pearl/policy_learners/sequential_decision_making/td3.py b/pearl/policy_learners/sequential_decision_making/td3.py index 5a0aaf8c..3286dd0b 100644 --- a/pearl/policy_learners/sequential_decision_making/td3.py +++ b/pearl/policy_learners/sequential_decision_making/td3.py @@ -137,6 +137,7 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]: def _critic_loss(self, batch: TransitionBatch) -> torch.Tensor: with torch.no_grad(): # sample next_action from actor's target network; shape (batch_size, action_dim) + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. next_action = self._actor_target.sample_action(batch.next_state) # sample clipped gaussian noise @@ -164,6 +165,7 @@ def _critic_loss(self, batch: TransitionBatch) -> torch.Tensor: ) # shape (batch_size, action_dim) # sample q values of (next_state, next_action) from targets of critics + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. next_q1, next_q2 = self._critic_target.get_q_values( state_batch=batch.next_state, action_batch=next_action, @@ -249,9 +251,11 @@ def __init__( def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor: # sample a batch of actions from the actor network; shape (batch_size, action_dim) + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. action_batch = self._actor.sample_action(batch.state) # samples q values for (batch.state, action_batch) from twin critics + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. q, _ = self._critic.get_q_values( state_batch=batch.state, action_batch=action_batch ) diff --git a/pearl/safety_modules/reward_constrained_safety_module.py b/pearl/safety_modules/reward_constrained_safety_module.py index 2b9bd9cc..851c9315 100644 --- a/pearl/safety_modules/reward_constrained_safety_module.py +++ b/pearl/safety_modules/reward_constrained_safety_module.py @@ -137,8 +137,11 @@ def constraint_lambda_update( """ with torch.no_grad(): + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. cost_q1, cost_q2 = cost_critic.get_q_values( state_batch=batch.state, + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no + # attribute `sample_action`. action_batch=policy_learner._actor.sample_action(batch.state), ) cost_q = torch.maximum(cost_q1, cost_q2) @@ -158,9 +161,12 @@ def cost_critic_learn_batch( with torch.no_grad(): # sample next_action from actor's target network; shape (batch_size, action_dim) + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `sample_action`. next_action = policy_learner._actor.sample_action(batch.next_state) # sample q values of (next_state, next_action) from targets of critics + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. next_q1, next_q2 = self.target_of_cost_critic.get_q_values( state_batch=batch.next_state, action_batch=next_action, @@ -173,7 +179,11 @@ def cost_critic_learn_batch( # cost + gamma * (min{Qtarget_1(s', a from target actor network), # Qtarget_2(s', a from target actor network)}) expected_state_action_values = ( - next_q * self.cost_discount_factor * (1 - batch.terminated.float()) + next_q + * self.cost_discount_factor + * (1 - batch.terminated.float()) + # pyre-fixme[58]: `+` is not supported for operand types `Tensor` and + # `Optional[Tensor]`. ) + batch.cost # (batch_size) # update twin critics towards bellman target diff --git a/pearl/utils/functional_utils/learning/action_utils.py b/pearl/utils/functional_utils/learning/action_utils.py index 05f325b9..3c4cf4f7 100644 --- a/pearl/utils/functional_utils/learning/action_utils.py +++ b/pearl/utils/functional_utils/learning/action_utils.py @@ -136,8 +136,12 @@ def concatenate_actions_to_state( # (batch_size, action_count, state_dim + action_dim) new_feature = torch.cat([expanded_state, expanded_action], dim=2) torch._assert( + # pyre-fixme[58]: `+` is not supported for operand types `int` and + # `Union[int, Tensor, Module]`. new_feature.shape == (batch_size, action_count, state_dim + action_dim), "The shape of the concatenated feature is wrong. Expected " + # pyre-fixme[58]: `+` is not supported for operand types `int` and + # `Union[int, Tensor, Module]`. f"{(batch_size, action_count, state_dim + action_dim)}, got {new_feature.shape}", ) return new_feature.to(subjective_state.device) diff --git a/pearl/utils/functional_utils/learning/critic_utils.py b/pearl/utils/functional_utils/learning/critic_utils.py index 0f511c29..43718140 100644 --- a/pearl/utils/functional_utils/learning/critic_utils.py +++ b/pearl/utils/functional_utils/learning/critic_utils.py @@ -121,16 +121,22 @@ def update_critic_target_network( if isinstance(target_network, TwinCritic): update_target_networks( target_network._critic_networks_combined, + # pyre-fixme[6]: For 2nd argument expected `Union[List[Module], + # ModuleList]` but got `Union[Module, Tensor]`. network._critic_networks_combined, tau=tau, ) else: update_target_network( ( + # pyre-fixme[6]: For 1st argument expected `Module` but got + # `Union[Module, Tensor]`. target_network._model if hasattr(target_network, "_model") else target_network ), + # pyre-fixme[6]: For 2nd argument expected `Module` but got + # `Union[Module, Tensor]`. network._model if hasattr(network, "_model") else network, tau=tau, ) diff --git a/test/unit/with_pytorch/test_disjoint_bandits.py b/test/unit/with_pytorch/test_disjoint_bandits.py index 25c381c3..ca95d360 100644 --- a/test/unit/with_pytorch/test_disjoint_bandits.py +++ b/test/unit/with_pytorch/test_disjoint_bandits.py @@ -398,6 +398,7 @@ def test_get_scores(self) -> None: for i in range(self.action_space.n): model = policy_learner.models[i] # model for arm i mus = model(features) + # pyre-fixme[29]: `Union[Tensor, Module]` is not a function. sigmas = model.calculate_sigma(features) expected_scores.append(mus + alpha * sigmas) expected_scores = torch.cat(expected_scores, dim=1) diff --git a/test/unit/with_pytorch/test_epinet.py b/test/unit/with_pytorch/test_epinet.py index 00e654bf..4544062e 100644 --- a/test/unit/with_pytorch/test_epinet.py +++ b/test/unit/with_pytorch/test_epinet.py @@ -133,6 +133,7 @@ def test_priornet_constant(self) -> None: init_priornet_weight = copy.deepcopy( self.network.priornet.models[0][0][0].weight ) + # pyre-fixme[29]: `Union[Tensor, Module]` is not a function. init_epinet_weight = copy.deepcopy(self.network.epinet[0][0].weight) for _ in range(self.num_epochs): @@ -159,6 +160,7 @@ def test_priornet_constant(self) -> None: final_priornet_weight = copy.deepcopy( self.network.priornet.models[0][0][0].weight ) + # pyre-fixme[29]: `Union[Tensor, Module]` is not a function. final_epinet_weight = copy.deepcopy(self.network.epinet[0][0].weight) tt.assert_close( final_priornet_weight,