From 6129509a45651445dc9d5c9f867d49eb9d0164ea Mon Sep 17 00:00:00 2001
From: "kamathhrishi@gmail.com" <kamathhrishi@gmail.com>
Date: Tue, 8 Jun 2021 22:57:24 +0530
Subject: [PATCH 1/2] Type hints for Params and NN

---
 pyro/optim/dct_adam.py     |  4 ++--
 pyro/params/param_store.py | 47 +++++++++++++++++++-----------------
 pyro/poutine/util.py       | 10 +++++-----
 pyro/util.py               |  7 +++---
 setup.cfg                  |  3 ---
 5 files changed, 36 insertions(+), 35 deletions(-)

diff --git a/pyro/optim/dct_adam.py b/pyro/optim/dct_adam.py
index ac1ceb5590..87cfe2c5dc 100644
--- a/pyro/optim/dct_adam.py
+++ b/pyro/optim/dct_adam.py
@@ -104,7 +104,7 @@ def step(self, closure: Optional[Callable] = None) -> Optional[float]:

         return loss

-    def _step_param(self, group: Dict, p) -> None:
+    def _step_param(self, group: Dict, p: torch.Tensor) -> None:
         grad = p.grad.data
         grad.clamp_(-group['clip_norm'], group['clip_norm'])

@@ -145,7 +145,7 @@ def _step_param(self, group: Dict, p) -> None:
         step = _transform_inverse(exp_avg / denom, time_dim, duration)
         p.data.add_(step.mul_(-step_size))

-    def _step_param_subsample(self, group: Dict, p, subsample) -> None:
+    def _step_param_subsample(self, group: Dict, p: torch.Tensor, subsample) -> None:
         mask = _get_mask(p, subsample)

         grad = p.grad.data.masked_select(mask)
diff --git a/pyro/params/param_store.py b/pyro/params/param_store.py
index 3b31046a3c..b370b472f9 100644
--- a/pyro/params/param_store.py
+++ b/pyro/params/param_store.py
@@ -4,6 +4,7 @@
 import re
 import warnings
 import weakref
+from typing import Callable, Dict, Iterable, Iterator, KeysView, Optional, Tuple, Union

 import torch
 from torch.distributions import constraints, transform_to
@@ -46,7 +47,7 @@ def __init__(self):
         self._param_to_name = {}  # dictionary from unconstrained param to param name
         self._constraints = {}  # dictionary from param name to constraint object

-    def clear(self):
+    def clear(self) -> None:
         """
         Clear the ParamStore
         """
@@ -62,7 +63,7 @@ def items(self):
         for name in self._params:
             yield name, self[name]

-    def keys(self):
+    def keys(self) -> KeysView:
         """
         Iterate over param names.
         """
@@ -75,22 +76,22 @@ def values(self):
         for name, constrained_param in self.items():
             yield constrained_param

-    def __bool__(self):
+    def __bool__(self) -> bool:
         return bool(self._params)

-    def __len__(self):
+    def __len__(self) -> int:
         return len(self._params)

-    def __contains__(self, name):
+    def __contains__(self, name: str) -> bool:
         return name in self._params

-    def __iter__(self):
+    def __iter__(self) -> Iterator[str]:
         """
         Iterate over param names.
         """
         return iter(self.keys())

-    def __delitem__(self, name):
+    def __delitem__(self, name: str) -> None:
         """
         Remove a parameter from the param store.
         """
@@ -98,7 +99,7 @@ def __delitem__(self, name):
         self._param_to_name.pop(unconstrained_value)
         self._constraints.pop(name)

-    def __getitem__(self, name):
+    def __getitem__(self, name: str) -> torch.Tensor:
         """
         Get the *constrained* value of a named parameter.
         """
@@ -111,7 +112,7 @@ def __getitem__(self, name):

         return constrained_value

-    def __setitem__(self, name, new_constrained_value):
+    def __setitem__(self, name: str, new_constrained_value: torch.Tensor) -> None:
         """
         Set the constrained value of an existing parameter, or the value of a
         new *unconstrained* parameter. To declare a new parameter with
@@ -131,7 +132,7 @@ def __setitem__(self, name, new_constrained_value):
         self._params[name] = unconstrained_value
         self._param_to_name[unconstrained_value] = name

-    def setdefault(self, name, init_constrained_value, constraint=constraints.real):
+    def setdefault(self, name: str, init_constrained_value: Union[torch.Tensor, Callable[[], torch.Tensor]], constraint: constraints.Constraint = constraints.real) -> torch.Tensor:
         """
         Retrieve a *constrained* parameter value from the if it exists, otherwise
         set the initial value. Note that this is a little fancier than
@@ -169,7 +170,7 @@ def setdefault(self, name, init_constrained_value, constraint=constraints.real):

     # -------------------------------------------------------------------------------
     # Old non-dict interface

-    def named_parameters(self):
+    def named_parameters(self) -> Iterable[Tuple[str, torch.Tensor]]:
         """
         Returns an iterator over ``(name, unconstrained_value)`` tuples for
         each parameter in the ParamStore. Note that, in the event the parameter is constrained,
         the unconstrained value is returned.
         """
         return self._params.items()

-    def get_all_param_names(self):
+    def get_all_param_names(self) -> KeysView:
         warnings.warn("ParamStore.get_all_param_names() is deprecated; use .keys() instead.",
                       DeprecationWarning)
         return self.keys()

-    def replace_param(self, param_name, new_param, old_param):
+    def replace_param(self, param_name: str, new_param: torch.Tensor, old_param: torch.Tensor) -> None:
         warnings.warn("ParamStore.replace_param() is deprecated; use .__setitem__() instead.",
                       DeprecationWarning)
         assert self._params[param_name] is old_param.unconstrained()
         self[param_name] = new_param

-    def get_param(self, name, init_tensor=None, constraint=constraints.real, event_dim=None):
+    def get_param(self, name: str, init_tensor: Optional[torch.Tensor] = None, constraint: constraints.Constraint = constraints.real, event_dim: Optional[int] = None) -> torch.Tensor:
         """
         Get parameter from its name. If it does not yet exist in the
         ParamStore, it will be created and stored.
@@ -209,7 +210,7 @@ def get_param(self, name, init_tensor=None, constraint=constraints.real, event_d
         else:
             return self.setdefault(name, init_tensor, constraint)

-    def match(self, name):
+    def match(self, name: str) -> Dict[str, torch.Tensor]:
         """
         Get all parameters that match regex. The parameter must exist.

@@ -220,7 +221,7 @@ def match(self, name):
         pattern = re.compile(name)
         return {name: self[name] for name in self if pattern.match(name)}

-    def param_name(self, p):
+    def param_name(self, p: torch.Tensor) -> Optional[str]:
         """
         Get parameter name from parameter

@@ -229,7 +230,7 @@ def param_name(self, p):
         """
         return self._param_to_name.get(p)

-    def get_state(self):
+    def get_state(self) -> Dict:
         """
         Get the ParamStore state.
""" @@ -257,7 +259,7 @@ def set_state(self, state): constraint = constraints.real self._constraints[param_name] = constraint - def save(self, filename): + def save(self, filename:str) -> None: """ Save parameters to disk @@ -267,7 +269,7 @@ def save(self, filename): with open(filename, "wb") as output_file: torch.save(self.get_state(), output_file) - def load(self, filename, map_location=None): + def load(self, filename: str, map_location: Optional[Union[Callable, torch.device, str, Dict]] = None) -> None: """ Loads parameters from disk @@ -293,19 +295,19 @@ def load(self, filename, map_location=None): _MODULE_NAMESPACE_DIVIDER = "$$$" -def param_with_module_name(pyro_name, param_name): +def param_with_module_name(pyro_name: str, param_name: str) -> str: return _MODULE_NAMESPACE_DIVIDER.join([pyro_name, param_name]) -def module_from_param_with_module_name(param_name): +def module_from_param_with_module_name(param_name: str) -> str: return param_name.split(_MODULE_NAMESPACE_DIVIDER)[0] -def user_param_name(param_name): +def user_param_name(param_name: str) -> str: if _MODULE_NAMESPACE_DIVIDER in param_name: return param_name.split(_MODULE_NAMESPACE_DIVIDER)[1] return param_name -def normalize_param_name(name): +def normalize_param_name(name: str) -> str: return name.replace(_MODULE_NAMESPACE_DIVIDER, ".") diff --git a/pyro/poutine/util.py b/pyro/poutine/util.py index 4ef4c9c412..1c20a79118 100644 --- a/pyro/poutine/util.py +++ b/pyro/poutine/util.py @@ -38,7 +38,7 @@ def prune_subsample_sites(trace): return trace -def enum_extend(trace, msg, num_samples=None): +def enum_extend(trace, msg: str, num_samples: Optional[int]=None) -> List: """ :param trace: a partial trace :param msg: the message at a Pyro primitive site @@ -65,7 +65,7 @@ def enum_extend(trace, msg, num_samples=None): return extended_traces -def mc_extend(trace, msg, num_samples=None): +def mc_extend(trace, msg: str, num_samples:Optional[int] = None) -> List: """ :param trace: a partial trace :param msg: the message at a Pyro primitive site @@ -90,7 +90,7 @@ def mc_extend(trace, msg, num_samples=None): return extended_traces -def discrete_escape(trace, msg): +def discrete_escape(trace, msg: str) -> bool: """ :param trace: a partial trace :param msg: the message at a Pyro primitive site @@ -107,7 +107,7 @@ def discrete_escape(trace, msg): (getattr(msg["fn"], "has_enumerate_support", False)) -def all_escape(trace, msg): +def all_escape(trace, msg: str) -> bool: """ :param trace: a partial trace :param msg: the message at a Pyro primitive site diff --git a/pyro/util.py b/pyro/util.py index fd6299b8e6..95ecabdc2a 100644 --- a/pyro/util.py +++ b/pyro/util.py @@ -11,6 +11,7 @@ from collections import defaultdict from contextlib import contextmanager from itertools import zip_longest +from typing import Dict import numpy as np import torch @@ -18,7 +19,7 @@ from pyro.poutine.util import site_is_subsample -def set_rng_seed(rng_seed): +def set_rng_seed(rng_seed : int) -> None: """ Sets seeds of `torch` and `torch.cuda` (if available). 
@@ -29,11 +30,11 @@ def set_rng_seed(rng_seed):
     np.random.seed(rng_seed)


-def get_rng_state():
+def get_rng_state() -> Dict:
     return {'torch': torch.get_rng_state(), 'random': random.getstate(), 'numpy': np.random.get_state()}


-def set_rng_state(state):
+def set_rng_state(state: Dict) -> None:
     torch.set_rng_state(state['torch'])
     random.setstate(state['random'])
     if 'numpy' in state:
diff --git a/setup.cfg b/setup.cfg
index 3dc45864e6..35e3537483 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -74,9 +74,6 @@ warn_unused_ignores = True
 warn_incomplete_stub = True
 warn_unused_ignores = True

-[mypy-pyro.params.*]
-ignore_errors = True
-warn_unused_ignores = True

 [mypy-pyro.poutine.*]
 ignore_errors = True

From b4009803aea6180890a85f865c61b6b9449960ed Mon Sep 17 00:00:00 2001
From: "kamathhrishi@gmail.com" <kamathhrishi@gmail.com>
Date: Mon, 5 Jul 2021 10:21:59 +0530
Subject: [PATCH 2/2] Some more types

---
 pyro/nn/auto_reg_nn.py     | 30 ++++++++++++++++--------------
 pyro/nn/dense_nn.py        | 26 ++++++++++++++------------
 pyro/nn/module.py          | 37 +++++++++++++++++++------------------
 pyro/params/param_store.py |  2 +-
 pyro/util.py               | 12 ++++++------
 5 files changed, 56 insertions(+), 51 deletions(-)

diff --git a/pyro/nn/auto_reg_nn.py b/pyro/nn/auto_reg_nn.py
index c33d9f8ec8..962e2e615f 100644
--- a/pyro/nn/auto_reg_nn.py
+++ b/pyro/nn/auto_reg_nn.py
@@ -3,12 +3,14 @@

 import warnings

+from typing import List, Optional
+
 import torch
 import torch.nn as nn
 from torch.nn import functional as F


-def sample_mask_indices(input_dim, hidden_dim, simple=True):
+def sample_mask_indices(input_dim: int, hidden_dim: int, simple: bool = True) -> torch.Tensor:
     """
     Samples the indices assigned to hidden units during the construction of MADE masks

@@ -30,7 +32,7 @@ def sample_mask_indices(input_dim, hidden_dim, simple=True):
     return ints


-def create_mask(input_dim, context_dim, hidden_dims, permutation, output_dim_multiplier):
+def create_mask(input_dim: int, context_dim: int, hidden_dims: List[int], permutation: torch.LongTensor, output_dim_multiplier: int):
     """
     Creates MADE masks for a conditional distribution

@@ -91,7 +93,7 @@ class MaskedLinear(nn.Linear):
     :type bias: bool
     """

-    def __init__(self, in_features, out_features, mask, bias=True):
+    def __init__(self, in_features: int, out_features: int, mask: torch.Tensor, bias: bool = True):
         super().__init__(in_features, out_features, bias)
         self.register_buffer('mask', mask.data)

@@ -148,12 +150,12 @@ class ConditionalAutoRegressiveNN(nn.Module):

     def __init__(
             self,
-            input_dim,
-            context_dim,
-            hidden_dims,
-            param_dims=[1, 1],
-            permutation=None,
-            skip_connections=False,
+            input_dim: int,
+            context_dim: int,
+            hidden_dims: List[int],
+            param_dims: List[int] = [1, 1],
+            permutation: Optional[torch.LongTensor] = None,
+            skip_connections: bool = False,
             nonlinearity=nn.ReLU()):
         super().__init__()
         if input_dim == 1:
@@ -294,11 +296,11 @@ class AutoRegressiveNN(ConditionalAutoRegressiveNN):

     def __init__(
             self,
-            input_dim,
-            hidden_dims,
-            param_dims=[1, 1],
-            permutation=None,
-            skip_connections=False,
+            input_dim: int,
+            hidden_dims: List[int],
+            param_dims: List[int] = [1, 1],
+            permutation: Optional[torch.LongTensor] = None,
+            skip_connections: bool = False,
             nonlinearity=nn.ReLU()):
         super(
             AutoRegressiveNN,
diff --git a/pyro/nn/dense_nn.py b/pyro/nn/dense_nn.py
index b8af6ef0f1..1f1bbf6635 100644
--- a/pyro/nn/dense_nn.py
+++ b/pyro/nn/dense_nn.py
@@ -1,6 +1,8 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0

+from typing import List, Tuple, Union
+
 import torch


@@ -35,11 +37,11 @@ class ConditionalDenseNN(torch.nn.Module):

     def __init__(
             self,
-            input_dim,
-            context_dim,
-            hidden_dims,
-            param_dims=[1, 1],
-            nonlinearity=torch.nn.ReLU()):
+            input_dim: int,
+            context_dim: int,
+            hidden_dims: List[int],
+            param_dims: List[int] = [1, 1],
+            nonlinearity: torch.nn.Module = torch.nn.ReLU()):
         super().__init__()

         self.input_dim = input_dim
@@ -64,14 +66,14 @@ def __init__(
         # Save the nonlinearity
         self.f = nonlinearity

-    def forward(self, x, context):
+    def forward(self, x: torch.Tensor, context: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
         # We must be able to broadcast the size of the context over the input
         context = context.expand(x.size()[:-1] + (context.size(-1),))

         x = torch.cat([context, x], dim=-1)
         return self._forward(x)

-    def _forward(self, x):
+    def _forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
         """
         The forward method
         """
@@ -122,10 +124,10 @@ class DenseNN(ConditionalDenseNN):

     def __init__(
             self,
-            input_dim,
-            hidden_dims,
-            param_dims=[1, 1],
-            nonlinearity=torch.nn.ReLU()):
+            input_dim: int,
+            hidden_dims: List[int],
+            param_dims: List[int] = [1, 1],
+            nonlinearity: torch.nn.Module = torch.nn.ReLU()) -> None:
         super(DenseNN, self).__init__(
             input_dim,
             0,
@@ -134,5 +136,5 @@ def __init__(
             nonlinearity=nonlinearity
         )

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
         return self._forward(x)
diff --git a/pyro/nn/module.py b/pyro/nn/module.py
index 10d4c779b0..198e21785c 100644
--- a/pyro/nn/module.py
+++ b/pyro/nn/module.py
@@ -15,12 +15,13 @@
 import functools
 import inspect
 from collections import OrderedDict, namedtuple
+from typing import Callable, Dict, Optional, Union

 import torch
 from torch.distributions import constraints, transform_to

-import pyro
-from pyro.poutine.runtime import _PYRO_PARAM_STORE
+import pyro  # type: ignore
+from pyro.poutine.runtime import _PYRO_PARAM_STORE  # type: ignore


 class PyroParam(namedtuple("PyroParam", ("init_value", "constraint", "event_dim"))):
@@ -152,7 +153,7 @@ def __get__(self, obj, obj_type):
         return obj.__getattr__(self.name)


-def _make_name(prefix, name):
+def _make_name(prefix: str, name: str) -> str:
     return "{}.{}".format(prefix, name) if prefix else name


@@ -202,7 +203,7 @@ def _get_pyro_params(module):


 class _PyroModuleMeta(type):
-    _pyro_mixin_cache = {}
+    _pyro_mixin_cache: Dict = {}

     # Unpickling helper to create an empty object of type PyroModule[Module].
     class _New:
@@ -365,14 +366,14 @@ class PyroLinear(nn.Linear, PyroModule):
     :param str name: Optional name for a root PyroModule. This is ignored in
         sub-PyroModules of another PyroModule.
     """

-    def __init__(self, name=""):
+    def __init__(self, name: str = ""):
         self._pyro_name = name
         self._pyro_context = _Context()  # shared among sub-PyroModules
-        self._pyro_params = OrderedDict()
-        self._pyro_samples = OrderedDict()
+        self._pyro_params: OrderedDict = OrderedDict()
+        self._pyro_samples: OrderedDict = OrderedDict()
         super().__init__()

-    def add_module(self, name, module):
+    def add_module(self, name: str, module: torch.nn.Module) -> None:
         """
         Adds a child module to the current module.
""" @@ -380,7 +381,7 @@ def add_module(self, name, module): module._pyro_set_supermodule(_make_name(self._pyro_name, name), self._pyro_context) super().add_module(name, module) - def named_pyro_params(self, prefix='', recurse=True): + def named_pyro_params(self, prefix:str='', recurse:bool=True): """ Returns an iterator over PyroModule parameters, yielding both the name of the parameter as well as the parameter itself. @@ -395,7 +396,7 @@ def named_pyro_params(self, prefix='', recurse=True): for elem in gen: yield elem - def _pyro_set_supermodule(self, name, context): + def _pyro_set_supermodule(self, name:str, context): self._pyro_name = name self._pyro_context = context for key, value in self._modules.items(): @@ -404,7 +405,7 @@ def _pyro_set_supermodule(self, name, context): "submodule {} has executed outside of supermodule".format(name) value._pyro_set_supermodule(_make_name(name, key), context) - def _pyro_get_fullname(self, name): + def _pyro_get_fullname(self, name:str): assert self.__dict__['_pyro_context'].used, "fullname is not yet defined" return _make_name(self.__dict__['_pyro_name'], name) @@ -412,7 +413,7 @@ def __call__(self, *args, **kwargs): with self._pyro_context: return super().__call__(*args, **kwargs) - def __getattr__(self, name): + def __getattr__(self, name:str): # PyroParams trigger pyro.param statements. if '_pyro_params' in self.__dict__: _pyro_params = self.__dict__['_pyro_params'] @@ -479,7 +480,7 @@ def __getattr__(self, name): return result - def __setattr__(self, name, value): + def __setattr__(self, name:str, value:Union[PyroParam,"PyroModule",torch.nn.Parameter,torch.Tensor]): if isinstance(value, PyroModule): # Create a new sub PyroModule, overwriting any old value. try: @@ -550,7 +551,7 @@ def __setattr__(self, name, value): super().__setattr__(name, value) - def __delattr__(self, name): + def __delattr__(self, name:str): if name in self._parameters: del self._parameters[name] if self._pyro_context.used: @@ -586,7 +587,7 @@ def __delattr__(self, name): super().__delattr__(name) -def pyro_method(fn): +def pyro_method(fn: Callable): """ Decorator for top-level methods of a :class:`PyroModule` to enable pyro effects and cache ``pyro.sample`` statements. @@ -603,7 +604,7 @@ def cached_fn(self, *args, **kwargs): return cached_fn -def clear(mod): +def clear(mod:PyroModule): """ Removes data from both a :class:`PyroModule` and the param store. @@ -618,7 +619,7 @@ def clear(mod): delattr(mod, name) -def to_pyro_module_(m, recurse=True): +def to_pyro_module_(m:torch.nn.Module, recurse:bool=True): """ Converts an ordinary :class:`torch.nn.Module` instance to a :class:`PyroModule` **in-place**. @@ -679,7 +680,7 @@ def to_pyro_module_(m, recurse=True): # attribute. This is required if any attribute is set to a PyroParam or # PyroSample. 
 class _FlatWeightsDescriptor:
-    def __get__(self, obj, obj_type=None):
+    def __get__(self, obj, obj_type: Optional[type] = None):
         if obj is None:
             return self
         return [getattr(obj, name) for name in obj._flat_weights_names]
diff --git a/pyro/params/param_store.py b/pyro/params/param_store.py
index b370b472f9..ea832e292d 100644
--- a/pyro/params/param_store.py
+++ b/pyro/params/param_store.py
@@ -240,7 +240,7 @@ def get_state(self) -> Dict:
         }
         return state

-    def set_state(self, state):
+    def set_state(self, state: Dict) -> None:
         """
         Set the ParamStore state using state from a previous get_state() call
         """
diff --git a/pyro/util.py b/pyro/util.py
index 95ecabdc2a..a36294f9a6 100644
--- a/pyro/util.py
+++ b/pyro/util.py
@@ -11,7 +11,7 @@
 from collections import defaultdict
 from contextlib import contextmanager
 from itertools import zip_longest
-from typing import Dict
+from typing import Dict, Optional, Union

 import numpy as np
 import torch
@@ -42,7 +42,7 @@ def set_rng_state(state: Dict) -> None:
     np.random.set_state(state['numpy'])


-def torch_isnan(x):
+def torch_isnan(x: Union[float, torch.Tensor]) -> Union[bool, torch.Tensor]:
     """
     A convenient function to check if a Tensor contains any nan; also works with numbers
     """
@@ -51,7 +51,7 @@ def torch_isnan(x):
     return torch.isnan(x).any()


-def torch_isinf(x):
+def torch_isinf(x: Union[float, torch.Tensor]) -> Union[bool, torch.Tensor]:
     """
     A convenient function to check if a Tensor contains any +inf; also works with numbers
     """
@@ -60,7 +60,7 @@ def torch_isinf(x):
     return (x == math.inf).any() or (x == -math.inf).any()


-def warn_if_nan(value, msg="", *, filename=None, lineno=None):
+def warn_if_nan(value: torch.Tensor, msg: str = "", *, filename: Optional[str] = None, lineno: Optional[int] = None) -> torch.Tensor:
     """
     A convenient function to warn if a Tensor or its grad contains any nan,
     also works with numbers.
@@ -85,8 +85,8 @@ def warn_if_nan(value, msg="", *, filename=None, lineno=None):
     return value


-def warn_if_inf(value, msg="", allow_posinf=False, allow_neginf=False, *,
-                filename=None, lineno=None):
+def warn_if_inf(value: torch.Tensor, msg: str = "", allow_posinf: bool = False, allow_neginf: bool = False, *,
+                filename: Optional[str] = None, lineno: Optional[int] = None) -> torch.Tensor:
     """
     A convenient function to warn if a Tensor or its grad contains any inf,
     also works with numbers.
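
A few usage sketches follow, showing how the annotated APIs read at call sites. They assume the patched tree is importable; all parameter names, sizes and file paths below are illustrative, not taken from the patches.

First, the ParamStoreDict surface touched by patch 1. The Union annotation on setdefault() documents that the initializer may be an eager tensor or a lazy callable; iteration yields names and __getitem__ returns constrained values:

# Sketch: exercising the annotated ParamStoreDict API ("scale" and the
# file name are made up for this example).
import torch
from torch.distributions import constraints

from pyro.params.param_store import ParamStoreDict

store = ParamStoreDict()

# setdefault() accepts an eager tensor or a lazy () -> Tensor initializer,
# per Union[torch.Tensor, Callable[[], torch.Tensor]], and returns the
# *constrained* value.
scale = store.setdefault("scale", lambda: torch.ones(3),
                         constraint=constraints.positive)
assert isinstance(scale, torch.Tensor)

# __iter__ yields parameter names; __getitem__ returns constrained values.
for name in store:
    print(name, store[name].shape)

# get_state() -> Dict round-trips through save()/load().
store.save("params.pt")
store.clear()
store.load("params.pt")
assert "scale" in store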
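The RNG helpers in pyro.util are annotated as a seeding procedure plus a state snapshot dict. A minimal sketch of the intended round trip:

# Sketch: set_rng_seed() seeds torch, random and numpy together and
# returns None; get_rng_state() returns a Dict holding all three states.
import torch
from pyro.util import get_rng_state, set_rng_seed, set_rng_state

set_rng_seed(0)
snapshot = get_rng_state()       # {'torch': ..., 'random': ..., 'numpy': ...}
a = torch.randn(3)

set_rng_state(snapshot)          # rewind all three generators
b = torch.randn(3)
assert torch.equal(a, b)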
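The Union[torch.Tensor, Tuple[torch.Tensor, ...]] return type on the dense networks reflects the param_dims contract: forward() returns one tensor per entry when there is more than one head. A sketch with made-up sizes:

# Sketch: two heads of width 5 each, so forward() returns a 2-tuple.
import torch
from pyro.nn.dense_nn import ConditionalDenseNN, DenseNN

net = DenseNN(input_dim=10, hidden_dims=[32, 32], param_dims=[5, 5])
loc, log_scale = net(torch.randn(7, 10))
assert loc.shape == (7, 5) and log_scale.shape == (7, 5)

# The conditional variant broadcasts the context over the input batch.
cnet = ConditionalDenseNN(input_dim=10, context_dim=4,
                          hidden_dims=[32], param_dims=[5, 5])
loc, log_scale = cnet(torch.randn(7, 10), torch.randn(7, 4))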
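The Union accepted by PyroModule.__setattr__ covers PyroParam, PyroModule, Parameter and Tensor values. A sketch of the PyroParam path (class and attribute names are invented):

# Sketch: assigning a PyroParam registers a constrained parameter in the
# global param store under "scaler.scale"; attribute access triggers the
# corresponding pyro.param statement.
import torch
from torch.distributions import constraints
from pyro.nn.module import PyroModule, PyroParam

class Scaler(PyroModule):
    def __init__(self):
        super().__init__("scaler")
        self.scale = PyroParam(torch.ones(2), constraint=constraints.positive)

    def forward(self, x):
        return self.scale * x

m = Scaler()
print(m(torch.randn(2)))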
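Finally, the NaN/Inf helpers accept plain Python numbers as well as tensors, which is why torch_isnan is annotated with Union on both sides:

# Sketch: number path returns a bool; tensor path returns a 0-dim bool tensor.
import math
import torch
from pyro.util import torch_isnan, warn_if_nan

print(torch_isnan(math.nan))             # True  (number path)
print(torch_isnan(torch.zeros(2)))       # tensor(False)

x = torch.tensor([0.0, math.nan])
warn_if_nan(x, msg="loss input")         # emits a UserWarning, returns x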