Skip to content

Commit

Permalink
Add type error suppressions for upcoming upgrade
Browse files Browse the repository at this point in the history
Differential Revision: D66831077

fbshipit-source-id: 4eae8ca13aea331ad08cb2cef073ea22009313d1
  • Loading branch information
generatedunixname89002005307016 authored and facebook-github-bot committed Dec 6, 2024
1 parent f768ca3 commit 7f022a9
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 0 deletions.
1 change: 1 addition & 0 deletions pearl/neural_networks/common/epistemic_neural_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def generate_params_buffers(self) -> None:
"""
Generate parameters and buffers for the priornet.
"""
# pyre-fixme[6]: For 1st argument expected `List[Module]` but got `ModuleList`.
self.params, self.buffers = torch.func.stack_module_state(self.models)

def call_single_model(
Expand Down
2 changes: 2 additions & 0 deletions pearl/neural_networks/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ def wrapper(
) -> torch.Tensor:
return torch.func.functional_call(models[0], (params, buffers), data)

# pyre-fixme[6]: For 1st argument expected `List[Module]` but got
# `Union[List[Module], ModuleList]`.
params, buffers = stack_module_state(models)
values = torch.vmap(wrapper)(params, buffers, features).view(
(-1, batch_size)
Expand Down

0 comments on commit 7f022a9

Please sign in to comment.