Prepare for "Fix type-safety of torch.nn.Module instances": fbcode/p*
Summary:
X-link: pytorch/captum#1448

See D52890934

Reviewed By: r-barnes

Differential Revision: D66235323

fbshipit-source-id: a8781d76a63bf8003761055d7808190f73dea5e9
ezyang authored and facebook-github-bot committed Nov 21, 2024
1 parent fe19943 commit 2f16709
Showing 23 changed files with 103 additions and 3 deletions.
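A note on the pattern being suppressed throughout this diff: with the stricter stubs, attribute access that pyre resolves through `nn.Module.__getattr__` is typed as `Tensor | Module`, so pyre no longer treats such attributes as callable or as having submodule-specific methods; the `pyre-fixme[29]`, `[16]`, and `[6]` comments below work around that. A minimal sketch of the issue and of an explicit-narrowing alternative (hypothetical class, not taken from this commit):

import torch
import torch.nn as nn


class TinyQNetwork(nn.Module):
    def __init__(self, state_dim: int, action_dim: int) -> None:
        super().__init__()
        self._model = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, action_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pyre resolves `self._model` through nn.Module.__getattr__ and sees
        # `Tensor | Module`, which it does not treat as callable; the hunks
        # below suppress this with `# pyre-fixme[29]`.
        return self._model(x)

    def forward_narrowed(self, x: torch.Tensor) -> torch.Tensor:
        # Alternative to suppression: narrow the union explicitly.
        model = self._model
        assert isinstance(model, nn.Sequential)
        return model(x)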
1 change: 1 addition & 0 deletions pearl/neural_networks/common/value_networks.py
@@ -52,6 +52,7 @@ def forward(self, x: Tensor) -> Tensor:

# default initialization in linear and conv layers of a F.sequential model is Kaiming
def xavier_init(self) -> None:
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
for layer in self._model:
if isinstance(layer, nn.Linear):
nn.init.xavier_normal_(layer.weight)

@@ -594,6 +594,7 @@ def __init__(
self._action_dim = action_dim

def forward(self, x: Tensor) -> Tensor:
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
return self._model(x)

def get_q_values(

3 changes: 3 additions & 0 deletions pearl/pearl_agent.py
@@ -92,6 +92,7 @@ def __init__(

# adds the safety module to the policy learner as well
# @jalaj, we need to follow the practice below for safety module
# pyre-fixme[16]: `PolicyLearner` has no attribute `safety_module`.
self.policy_learner.safety_module = self.safety_module

self.replay_buffer: ReplayBuffer = (

@@ -190,6 +191,8 @@ def observe(
else action_result.available_action_space
), # next_available_actions
terminated=action_result.terminated,
# pyre-fixme[6]: For 8th argument expected `Optional[int]` but got
# `Union[None, Tensor, Module]`.
max_number_actions=(
self.policy_learner.action_representation_module.max_number_actions
if not self.policy_learner._is_action_continuous

@@ -228,6 +228,7 @@ def get_scores(

@property
def optimizer(self) -> torch.optim.Optimizer:
# pyre-fixme[7]: Expected `Optimizer` but got `Union[Tensor, Module]`.
return self._optimizer

def set_history_summarization_module(

@@ -52,6 +52,7 @@ def get_scores(
expected_reward = representation(subjective_state)
# batch_size, action_count, 1
assert expected_reward.shape == subjective_state.shape[:-1]
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
sigma = representation.calculate_sigma(subjective_state)
# batch_size, action_count, 1
assert sigma.shape == subjective_state.shape[:-1]

@@ -44,6 +44,7 @@ def sigma(
Returns:
sigma with shape (batch_size, action_count) or (batch_size, 1)
"""
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
sigma = representation.calculate_sigma(subjective_state)
nan_check = torch.isnan(sigma)
sigma = torch.where(nan_check, torch.zeros_like(sigma), sigma)

@@ -238,13 +238,15 @@ def act(
# (action computed by actor network; and without any exploration)
with torch.no_grad():
if self._is_action_continuous:
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
exploit_action = self._actor.sample_action(subjective_state)
action_probabilities = None
else:
assert isinstance(available_action_space, DiscreteActionSpace)
actions = self.action_representation_module(
available_action_space.actions_batch
)
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
action_probabilities = self._actor.get_policy_distribution(
state_batch=subjective_state,
available_actions=actions,

@@ -267,6 +269,7 @@ def act(
)

def reset(self, action_space: ActionSpace) -> None:
# pyre-fixme[16]: `ActorCriticBase` has no attribute `_action_space`.
self._action_space = action_space

def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:

@@ -334,7 +337,10 @@ def preprocess_batch(self, batch: TransitionBatch) -> TransitionBatch:
# change reward to be the lambda_constraint weighted sum of reward and cost
if hasattr(self.safety_module, "lambda_constraint"):
batch.reward = (
-    batch.reward - self.safety_module.lambda_constraint * batch.cost
+    # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no
+    # attribute `lambda_constraint`.
+    batch.reward
+    - self.safety_module.lambda_constraint * batch.cost
)
batch = super().preprocess_batch(batch)


5 changes: 5 additions & 0 deletions pearl/policy_learners/sequential_decision_making/ddpg.py
@@ -100,9 +100,12 @@ def __init__(
def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor:

# sample a batch of actions from the actor network; shape (batch_size, action_dim)
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
action_batch = self._actor.sample_action(batch.state)

# obtain q values for (batch.state, action_batch) from critic 1
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `get_q_values`.
q1 = self._critic._critic_1.get_q_values(
state_batch=batch.state, action_batch=action_batch
)

@@ -116,10 +119,12 @@ def _critic_loss(self, batch: TransitionBatch) -> torch.Tensor:

with torch.no_grad():
# sample a batch of next actions from target actor network;
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
next_action = self._actor_target.sample_action(batch.next_state)
# (batch_size, action_dim)

# get q values of (batch.next_state, next_action) from targets of twin critic
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
next_q1, next_q2 = self._critic_target.get_q_values(
state_batch=batch.next_state,
action_batch=next_action,

@@ -134,6 +134,8 @@ def make_specified_network() -> QValueNetwork:
if network_type is TwoTowerQValueNetwork:
return network_type(
state_dim=state_dim,
# pyre-fixme[6]: For 2nd argument expected `int` but got
# `Union[Tensor, Module]`.
action_dim=self._action_representation_module.representation_dim,
hidden_dims=hidden_dims,
state_output_dim=state_output_dim,

@@ -149,6 +151,8 @@ def make_specified_network() -> QValueNetwork:
)
return network_type(
state_dim=state_dim,
# pyre-fixme[6]: For 2nd argument expected `int` but got
# `Union[Tensor, Module]`.
action_dim=self._action_representation_module.representation_dim,
hidden_dims=hidden_dims,
output_dim=1,

@@ -289,6 +293,8 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
# Conservative TD updates for offline learning.
if self._is_conservative:
cql_loss = compute_cql_loss(self._Q, batch, batch_size)
# pyre-fixme[58]: `*` is not supported for operand types
# `Optional[float]` and `Tensor`.
loss = self._conservative_alpha * cql_loss + bellman_loss
else:
loss = bellman_loss

@@ -167,7 +167,11 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
self._history_summarization_optimizer.step()
# update critic and target Twin networks;
update_target_networks(
# pyre-fixme[6]: For 1st argument expected `Union[List[Module],
# ModuleList]` but got `Union[Tensor, Module]`.
self._critic_target._critic_networks_combined,
# pyre-fixme[6]: For 2nd argument expected `Union[List[Module],
# ModuleList]` but got `Union[Tensor, Module]`.
self._critic._critic_networks_combined,
self._critic_soft_update_tau,
)

@@ -181,6 +185,7 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
def _value_loss(self, batch: TransitionBatch) -> torch.Tensor:

with torch.no_grad():
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
q1, q2 = self._critic_target.get_q_values(batch.state, batch.action)
# random ensemble distillation.
random_index = torch.randint(0, 2, (1,)).item()

@@ -197,6 +202,7 @@ def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor:
Performs policy extraction using advantage weighted regression
"""
with torch.no_grad():
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
q1, q2 = self._critic_target.get_q_values(batch.state, batch.action)
# random ensemble distillation.
random_index = torch.randint(0, 2, (1,)).item()

@@ -226,6 +232,7 @@ def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor:

else:
if self._is_action_continuous:
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
log_action_probabilities = self._actor.get_log_probability(
batch.state, batch.action
).view(-1)

4 changes: 4 additions & 0 deletions pearl/policy_learners/sequential_decision_making/ppo.py
@@ -152,6 +152,7 @@ def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor:
# TODO need to support continuous action
# TODO: change the output shape of value networks
assert isinstance(batch, PPOTransitionBatch)
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
action_probs = self._actor.get_action_prob(
state_batch=batch.state,
action_batch=batch.action,

@@ -167,6 +168,8 @@ def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor:
clip = torch.clamp(
r_thelta, min=1.0 - self._epsilon, max=1.0 + self._epsilon
) # shape (batch_size)
# pyre-fixme[58]: `*` is not supported for operand types `Tensor` and
# `Optional[Tensor]`.
loss = torch.sum(-torch.min(r_thelta * batch.gae, clip * batch.gae))
# entropy
entropy: torch.Tensor = torch.distributions.Categorical(

@@ -236,6 +239,7 @@ def preprocess_replay_buffer(self, replay_buffer: ReplayBuffer) -> None:

state_values = self._critic(history_summary_batch).detach()
action_probs = (
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
self._actor.get_action_prob(
state_batch=history_summary_batch,
action_batch=action_representation_batch,

@@ -118,6 +118,8 @@ def _get_next_state_quantiles(
# get q values from a q value distribution under a risk metric
# instead of using the 'get_q_values' method of the QuantileQValueNetwork,
# we invoke a method from the risk sensitive safety module
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `get_q_values_under_risk_metric`.
next_state_action_values = self.safety_module.get_q_values_under_risk_metric(
next_state_batch_repeated, next_available_actions_batch, self._Q_target
).view(

@@ -146,6 +146,8 @@ def act(

# instead of using the 'get_q_values' method of the QuantileQValueNetwork,
# we invoke a method from the risk sensitive safety module
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `get_q_values_under_risk_metric`.
q_values = self.safety_module.get_q_values_under_risk_metric(
states_repeated, actions, self._Q
)

@@ -145,6 +145,7 @@ def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor:
batch.state
) # (batch_size x state_dim) note that here batch_size = episode length
return_batch = batch.cum_reward # (batch_size)
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
policy_propensities = self._actor.get_action_prob(
batch.state,
batch.action,

@@ -113,6 +113,7 @@ def __init__(

# sac uses a learning rate scheduler specifically
def reset(self, action_space: ActionSpace) -> None:
# pyre-fixme[16]: `SoftActorCritic` has no attribute `_action_space`.
self._action_space = action_space
self.scheduler.step()


@@ -152,11 +153,14 @@ def _get_next_state_expected_values(self, batch: TransitionBatch) -> torch.Tensor:
assert next_available_actions_batch is not None
next_state_batch_repeated = torch.repeat_interleave(
next_state_batch.unsqueeze(1),
# pyre-fixme[6]: For 2nd argument expected `Tensor` but got
# `Union[Module, Tensor]`.
self.action_representation_module.max_number_actions,
dim=1,
) # (batch_size x action_space_size x state_dim)

# get q values of (states, all actions) from twin critics
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
next_q1, next_q2 = self._critic_target.get_q_values(
state_batch=next_state_batch_repeated,
action_batch=next_available_actions_batch,

@@ -179,6 +183,7 @@ def _get_next_state_expected_values(self, batch: TransitionBatch) -> torch.Tensor:
if next_unavailable_actions_mask_batch is not None:
next_state_action_values[next_unavailable_actions_mask_batch] = 0.0

# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
next_state_policy_dist = self._actor.get_policy_distribution(
state_batch=next_state_batch,
available_actions=next_available_actions_batch,

@@ -197,6 +202,8 @@ def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor:
state_batch = batch.state # (batch_size x state_dim)
state_batch_repeated = torch.repeat_interleave(
state_batch.unsqueeze(1),
# pyre-fixme[6]: For 2nd argument expected `Tensor` but got
# `Union[Module, Tensor]`.
self.action_representation_module.max_number_actions,
dim=1,
) # (batch_size x action_space_size x state_dim)

@@ -206,6 +213,7 @@ def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor:
) # (batch_size x action_space_size x action_dim)

# get q values of (states, all actions) from twin critics
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
q1, q2 = self._critic.get_q_values(
state_batch=state_batch_repeated, action_batch=available_actions
)

@@ -216,13 +224,16 @@ def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor:
batch.curr_unavailable_actions_mask
) # (batch_size x action_space_size)

# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
new_policy_dist = self._actor.get_policy_distribution(
state_batch=state_batch,
available_actions=available_actions,
unavailable_actions_mask=unavailable_actions_mask,
) # (batch_size x action_space_size)

state_action_values = q.view(
# pyre-fixme[6]: For 1st argument expected `dtype` but got `Tuple[int,
# Union[Module, Tensor]]`.
(state_batch.shape[0], self.action_representation_module.max_number_actions)
) # (batch_size x action_space_size)


@@ -104,8 +104,15 @@ def __init__(
torch.nn.Parameter(torch.zeros(1, requires_grad=True)),
)
self._entropy_optimizer: torch.optim.Optimizer = optim.AdamW(
-    [self._log_entropy], lr=critic_learning_rate, amsgrad=True
+    # pyre-fixme[6]: For 1st argument expected `Union[Iterable[Dict[str,
+    # Any]], Iterable[Tuple[str, Tensor]], Iterable[Tensor]]` but got
+    # `List[Union[Module, Tensor]]`.
+    [self._log_entropy],
+    lr=critic_learning_rate,
+    amsgrad=True,
)
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
# `Union[Module, Tensor]`.
self.register_buffer("_entropy_coef", torch.exp(self._log_entropy).detach())
assert isinstance(action_space, BoxSpace)
self.register_buffer(

@@ -120,11 +127,14 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:

if self._entropy_autotune:
with torch.no_grad():
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
_, action_batch_log_prob = self._actor.sample_action(
state_batch, get_log_prob=True
)

entropy_optimizer_loss = (
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
# `Union[Module, Tensor]`.
-torch.exp(self._log_entropy)
* (action_batch_log_prob + self._target_entropy)
).mean()
Expand All @@ -133,6 +143,8 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
entropy_optimizer_loss.backward()
self._entropy_optimizer.step()

# pyre-fixme[6]: For 1st argument expected `Tensor` but got
# `Union[Module, Tensor]`.
self._entropy_coef = torch.exp(self._log_entropy).detach()
{**actor_critic_loss, **{"entropy_coef": entropy_optimizer_loss}}


@@ -171,8 +183,10 @@ def _get_next_state_expected_values(self, batch: TransitionBatch) -> torch.Tensor:
(
next_action_batch,
next_action_batch_log_prob,
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
) = self._actor.sample_action(next_state_batch, get_log_prob=True)

# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
next_q1, next_q2 = self._critic_target.get_q_values(
state_batch=next_state_batch,
action_batch=next_action_batch,

@@ -198,8 +212,10 @@ def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor:
(
action_batch,
action_batch_log_prob,
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
) = self._actor.sample_action(state_batch, get_log_prob=True)

# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
q1, q2 = self._critic.get_q_values(
state_batch=state_batch, action_batch=action_batch
) # shape: (batch_size, 1)

@@ -67,7 +67,10 @@ def __init__(
self.debug: bool = debug

def reset(self, action_space: ActionSpace) -> None:
# pyre-fixme[16]: `TabularQLearning` has no attribute `_action_space`.
self._action_space = action_space
# pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got
# `Union[Tensor, Module]`.
for i, action in enumerate(self._action_space):
if int(action.item()) != i:
raise ValueError(

@@ -140,6 +143,8 @@ def learn(
old_q_value = self.q_values.get((state, action.item()), 0)
next_q_values = [
self.q_values.get((next_state, next_action.item()), 0)
# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not
# a function.
for next_action in self._action_space
]


(remaining changed files not shown)