[WIP] Type hints for NN and param store #2865

Open · wants to merge 6 commits into base: dev
43 changes: 21 additions & 22 deletions pyro/nn/auto_reg_nn.py
@@ -3,12 +3,14 @@

import warnings

from typing import List, Optional, Union

import torch
import torch.nn as nn
from torch.nn import functional as F


def sample_mask_indices(input_dim, hidden_dim, simple=True):
def sample_mask_indices(input_dim: int, hidden_dim: int, simple: bool = True) -> Union[int, torch.Tensor]:
"""
Samples the indices assigned to hidden units during the construction of MADE masks

@@ -32,9 +34,7 @@ def sample_mask_indices(input_dim, hidden_dim, simple=True):
return ints


def create_mask(
input_dim, context_dim, hidden_dims, permutation, output_dim_multiplier
):
def create_mask(input_dim: int, context_dim: int, hidden_dims: List[int], permutation: torch.LongTensor, output_dim_multiplier: int):
"""
Creates MADE masks for a conditional distribution

@@ -109,7 +109,7 @@ class MaskedLinear(nn.Linear):
:type bias: bool
"""

def __init__(self, in_features, out_features, mask, bias=True):
def __init__(self, in_features: int, out_features: int, mask: torch.Tensor, bias: bool = True):
super().__init__(in_features, out_features, bias)
self.register_buffer("mask", mask.data)

@@ -165,15 +165,15 @@ class ConditionalAutoRegressiveNN(nn.Module):
"""

def __init__(
self,
input_dim,
context_dim,
hidden_dims,
param_dims=[1, 1],
permutation=None,
skip_connections=False,
nonlinearity=nn.ReLU(),
):
self,
input_dim: int,
context_dim: int,
hidden_dims: List[int],
param_dims: List[int] = [1, 1],
permutation: Optional[torch.LongTensor] = None,
skip_connections: bool = False,
nonlinearity: nn.Module = nn.ReLU(),
):
super().__init__()
if input_dim == 1:
warnings.warn(
@@ -327,14 +327,13 @@ class AutoRegressiveNN(ConditionalAutoRegressiveNN):
"""

def __init__(
self,
input_dim,
hidden_dims,
param_dims=[1, 1],
permutation=None,
skip_connections=False,
nonlinearity=nn.ReLU(),
):
self,
input_dim: int,
hidden_dims: List[int],
param_dims: List[int] = [1, 1],
permutation: Optional[torch.LongTensor] = None,
skip_connections: bool = False,
nonlinearity: nn.Module = nn.ReLU(),
):
super(AutoRegressiveNN, self).__init__(
input_dim,
0,
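With these annotations in place, the autoregressive constructors document their expected argument types in the signature itself. A minimal usage sketch (not part of this diff; the sizes below are arbitrary):

import torch
from pyro.nn import AutoRegressiveNN, ConditionalAutoRegressiveNN

# param_dims=[1, 1] makes the network emit two parameter tensors per call,
# e.g. a mean and a log-scale for an affine autoregressive flow.
arn = AutoRegressiveNN(input_dim=10, hidden_dims=[40, 40], param_dims=[1, 1])
mean, log_scale = arn(torch.randn(5, 10))

# The conditional variant also takes a context tensor of size context_dim;
# permutation may be a torch.LongTensor or left as None.
carn = ConditionalAutoRegressiveNN(input_dim=10, context_dim=3, hidden_dims=[40, 40])
mean_c, log_scale_c = carn(torch.randn(5, 10), context=torch.randn(5, 3))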
30 changes: 18 additions & 12 deletions pyro/nn/dense_nn.py
@@ -1,6 +1,8 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from typing import List, Tuple, Union

import torch


@@ -34,13 +36,13 @@ class ConditionalDenseNN(torch.nn.Module):
"""

def __init__(
self,
input_dim,
context_dim,
hidden_dims,
param_dims=[1, 1],
nonlinearity=torch.nn.ReLU(),
):
self,
input_dim: int,
context_dim: int,
hidden_dims: List[int],
param_dims: List[int] = [1, 1],
nonlinearity: torch.nn.Module = torch.nn.ReLU(),
):

super().__init__()

self.input_dim = input_dim
@@ -65,14 +67,14 @@ def __init__(
# Save the nonlinearity
self.f = nonlinearity

def forward(self, x, context):
def forward(self, x: torch.Tensor, context: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
# We must be able to broadcast the size of the context over the input
context = context.expand(x.size()[:-1] + (context.size(-1),))

x = torch.cat([context, x], dim=-1)
return self._forward(x)

def _forward(self, x):
def _forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
"""
The forward method
"""
@@ -122,11 +124,15 @@ class DenseNN(ConditionalDenseNN):
"""

def __init__(
self, input_dim, hidden_dims, param_dims=[1, 1], nonlinearity=torch.nn.ReLU()
):
self,
input_dim: int,
hidden_dims: List[int],
param_dims: List[int] = [1, 1],
nonlinearity: torch.nn.Module = torch.nn.ReLU(),
) -> None:

super(DenseNN, self).__init__(
input_dim, 0, hidden_dims, param_dims=param_dims, nonlinearity=nonlinearity
)

def forward(self, x):
def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
return self._forward(x)
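The dense networks follow the same pattern. A short, illustrative sketch of how the annotated constructors and forward methods are called (sizes are arbitrary); it also shows why _forward is annotated as returning either a single tensor or a tuple:

import torch
from pyro.nn import ConditionalDenseNN, DenseNN

# With more than one entry in param_dims the networks return a tuple of
# tensors, otherwise a single tensor -- hence the Union[...] return type.
net = DenseNN(input_dim=5, hidden_dims=[32, 32], param_dims=[1, 1])
loc, scale = net(torch.randn(8, 5))

cond_net = ConditionalDenseNN(input_dim=5, context_dim=3, hidden_dims=[32])
loc_c, scale_c = cond_net(torch.randn(8, 5), context=torch.randn(8, 3))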
41 changes: 23 additions & 18 deletions pyro/nn/module.py
@@ -15,12 +15,13 @@
import functools
import inspect
from collections import OrderedDict, namedtuple
from typing import Callable, Dict, Optional, Union

import torch
from torch.distributions import constraints, transform_to

import pyro
from pyro.poutine.runtime import _PYRO_PARAM_STORE
import pyro # type: ignore
from pyro.poutine.runtime import _PYRO_PARAM_STORE # type: ignore


class PyroParam(namedtuple("PyroParam", ("init_value", "constraint", "event_dim"))):
@@ -156,7 +157,7 @@ def __get__(self, obj, obj_type):
return obj.__getattr__(self.name)


def _make_name(prefix, name):
def _make_name(prefix: str, name: str) -> str:
return "{}.{}".format(prefix, name) if prefix else name


@@ -210,7 +211,7 @@ def _get_pyro_params(module):


class _PyroModuleMeta(type):
_pyro_mixin_cache = {}
_pyro_mixin_cache: Dict = {}

# Unpickling helper to create an empty object of type PyroModule[Module].
class _New:
@@ -374,15 +375,15 @@ class PyroLinear(nn.Linear, PyroModule):
:param str name: Optional name for a root PyroModule. This is ignored in
sub-PyroModules of another PyroModule.
"""

def __init__(self, name=""):
def __init__(self, name:str=""):
self._pyro_name = name
self._pyro_context = _Context() # shared among sub-PyroModules
self._pyro_params = OrderedDict()
self._pyro_samples = OrderedDict()
self._pyro_params: OrderedDict = OrderedDict()
self._pyro_samples: OrderedDict = OrderedDict()
super().__init__()

def add_module(self, name, module):
def add_module(self, name: str, module):
"""
Adds a child module to the current module.
"""
@@ -392,7 +393,7 @@ def add_module(self, name, module):
)
super().add_module(name, module)

def named_pyro_params(self, prefix="", recurse=True):
def named_pyro_params(self, prefix: str = "", recurse: bool = True):
"""
Returns an iterator over PyroModule parameters, yielding both the
name of the parameter as well as the parameter itself.
@@ -407,7 +408,7 @@ def named_pyro_params(self, prefix="", recurse=True):
for elem in gen:
yield elem

def _pyro_set_supermodule(self, name, context):
def _pyro_set_supermodule(self, name: str, context):
self._pyro_name = name
self._pyro_context = context
for key, value in self._modules.items():
@@ -417,6 +418,10 @@ def _pyro_set_supermodule(self, name, context):
), "submodule {} has executed outside of supermodule".format(name)
value._pyro_set_supermodule(_make_name(name, key), context)

def _pyro_get_fullname(self, name: str) -> str:
assert self.__dict__["_pyro_context"].used, "fullname is not yet defined"
return _make_name(self.__dict__["_pyro_name"], name)
@@ -425,7 +430,7 @@ def __call__(self, *args, **kwargs):
with self._pyro_context:
return super().__call__(*args, **kwargs)

def __getattr__(self, name):
def __getattr__(self, name: str):
# PyroParams trigger pyro.param statements.
if "_pyro_params" in self.__dict__:
_pyro_params = self.__dict__["_pyro_params"]
@@ -507,7 +512,7 @@ def __getattr__(self, name):

return result

def __setattr__(self, name, value):
def __setattr__(self, name: str, value: Union[PyroParam, "PyroModule", torch.nn.Parameter, torch.Tensor]):
if isinstance(value, PyroModule):
# Create a new sub PyroModule, overwriting any old value.
try:
@@ -585,7 +590,7 @@ def __setattr__(self, name, value):

super().__setattr__(name, value)

def __delattr__(self, name):
def __delattr__(self, name: str):
if name in self._parameters:
del self._parameters[name]
if self._pyro_context.used:
@@ -621,7 +626,7 @@ def __delattr__(self, name):
super().__delattr__(name)


def pyro_method(fn):
def pyro_method(fn: Callable):
"""
Decorator for top-level methods of a :class:`PyroModule` to enable pyro
effects and cache ``pyro.sample`` statements.
@@ -638,7 +643,7 @@ def cached_fn(self, *args, **kwargs):
return cached_fn


def clear(mod):
def clear(mod: PyroModule) -> None:
"""
Removes data from both a :class:`PyroModule` and the param store.

@@ -653,7 +658,7 @@ def clear(mod):
delattr(mod, name)


def to_pyro_module_(m, recurse=True):
def to_pyro_module_(m: torch.nn.Module, recurse: bool = True) -> None:
"""
Converts an ordinary :class:`torch.nn.Module` instance to a
:class:`PyroModule` **in-place**.
@@ -714,7 +719,7 @@ def to_pyro_module_(m, recurse=True):
# attribute. This is required if any attribute is set to a PyroParam or
# PyroSample. For motivation, see https://github.com/pyro-ppl/pyro/issues/2390
class _FlatWeightsDescriptor:
def __get__(self, obj, obj_type=None):
def __get__(self, obj, obj_type: Optional[type] = None):
if obj is None:
return self
return [getattr(obj, name) for name in obj._flat_weights_names]
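In pyro/nn/module.py the annotations mostly cover the attribute plumbing (__setattr__, __getattr__, to_pyro_module_, clear). A hedged sketch of the usage those hints describe; the class below is invented purely for illustration:

import torch
from torch import nn
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroParam, PyroSample
from pyro.nn.module import to_pyro_module_

class Scaler(PyroModule):
    def __init__(self):
        super().__init__()
        # Assigning a PyroParam goes through the annotated __setattr__ and is
        # mirrored in the param store under the module's full name.
        self.scale = PyroParam(torch.ones(1), constraint=dist.constraints.positive)
        # A PyroSample attribute triggers a pyro.sample statement on access.
        self.bias = PyroSample(dist.Normal(0.0, 1.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.scale * x + self.bias

# to_pyro_module_ converts an existing torch.nn.Module in place, matching the
# m: torch.nn.Module annotation above.
lin = nn.Linear(3, 1)
to_pyro_module_(lin)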
4 changes: 2 additions & 2 deletions pyro/optim/dct_adam.py
@@ -119,7 +119,7 @@ def step(self, closure: Optional[Callable] = None) -> Optional[float]:

return loss

def _step_param(self, group: Dict, p) -> None:
def _step_param(self, group: Dict, p: torch.Tensor) -> None:
grad = p.grad.data
grad.clamp_(-group["clip_norm"], group["clip_norm"])

@@ -160,7 +160,7 @@ def _step_param(self, group: Dict, p: torch.Tensor) -> None:
step = _transform_inverse(exp_avg / denom, time_dim, duration)
p.data.add_(step.mul_(-step_size))

def _step_param_subsample(self, group: Dict, p, subsample) -> None:
def _step_param_subsample(self, group: Dict, p: torch.Tensor, subsample) -> None:
mask = _get_mask(p, subsample)

grad = p.grad.data.masked_select(mask)