Commit

fix parameterlist parameterdict (#10)
hanjr92 authored Jun 22, 2022
1 parent 62c49f6 commit fec76fc
Showing 3 changed files with 73 additions and 72 deletions.
examples/basic_tutorials/Parameter_Container.py (10 changes: 7 additions & 3 deletions)

@@ -1,8 +1,8 @@
 import os
-os.environ['TL_BACKEND'] = 'tensorflow'
+# os.environ['TL_BACKEND'] = 'tensorflow'
 # os.environ['TL_BACKEND'] = 'mindspore'
 # os.environ['TL_BACKEND'] = 'paddle'
-# os.environ['TL_BACKEND'] = 'torch'
+os.environ['TL_BACKEND'] = 'torch'
 
 import tensorlayerx as tlx
 from tensorlayerx.nn import Module, Parameter, ParameterList, ParameterDict
@@ -28,6 +28,10 @@ def forward(self, x, choice):

 input = tlx.nn.Input(shape=(5,5))
 net = MyModule()
+
+trainable_weights = net.trainable_weights
+print("-----------------------------trainable_weights-------------------------------")
+for weight in trainable_weights:
+    print(weight)
 print("-----------------------------------output------------------------------------")
 output = net(input, choice = 'right')
 print(output)
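Note: the hunks above elide the body of MyModule. For orientation, here is a minimal sketch of the pattern this example exercises; the module body and tensor shapes are assumptions, not part of the diff, and tlx.ones, tlx.matmul, and the Parameter constructor are used in their usual backend-agnostic forms:

import os
os.environ['TL_BACKEND'] = 'torch'

import tensorlayerx as tlx
from tensorlayerx.nn import Module, Parameter, ParameterList, ParameterDict

class MyModule(Module):

    def __init__(self):
        super(MyModule, self).__init__()
        # Parameters registered by index: "0", "1", ...
        self.params = ParameterList([Parameter(tlx.ones((5, 5))) for _ in range(2)])
        # Parameters registered by key: "left", "right".
        self.params_dict = ParameterDict({
            'left': Parameter(tlx.ones((5, 5))),
            'right': Parameter(tlx.ones((5, 5))),
        })

    def forward(self, x, choice):
        for p in self.params:
            x = tlx.matmul(x, p)
        # 'choice' picks which dict entry multiplies the result.
        return tlx.matmul(x, self.params_dict[choice])

input = tlx.nn.Input(shape=(5, 5))
net = MyModule()
for weight in net.trainable_weights:
    print(weight)
print(net(input, choice='right'))

The trainable_weights printout added in this hunk presumably verifies that parameters held only inside the two containers now show up alongside ordinary layer weights.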
tensorlayerx/nn/core/core_tensorflow.py (133 changes: 65 additions & 68 deletions)

@@ -55,8 +55,8 @@ class Module(object):
     def __init__(self, name=None, act=None, *args, **kwargs):
         self._params = OrderedDict()
         self._layers = OrderedDict()
-        self._params_list = OrderedDict()
-        self._params_dict = OrderedDict()
+        # self._params_list = OrderedDict()
+        # self._params_dict = OrderedDict()
         self._params_status = OrderedDict()
         self._parameter_layout_dict = {}
         self._create_time = int(time.time() * 1e9)
@@ -148,11 +148,11 @@ def __setattr__(self, name, value):
                 raise TypeError("Expected type is Module, but got Parameter.")
             self.insert_param_to_layer(name, value)
 
-        elif isinstance(value, ParameterList):
-            self.set_attr_for_parameter_tuple(name, value)
-
-        elif isinstance(value, ParameterDict):
-            self.set_attr_for_parameter_dict(name, value)
+        # elif isinstance(value, ParameterList):
+        #     self.set_attr_for_parameter_tuple(name, value)
+        #
+        # elif isinstance(value, ParameterDict):
+        #     self.set_attr_for_parameter_dict(name, value)
 
         elif isinstance(value, Module):
             if layers is None:
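Note: with these two branches commented out, an assignment such as self.params = ParameterList(...) no longer takes a dedicated path through Module.__setattr__. Both containers appear to be Module subclasses in this file (their methods further down call insert_param_to_layer, a Module method), so the assignment falls through to the isinstance(value, Module) branch and the container is registered as an ordinary sublayer. A sketch of the visible effect, under that assumption:

import tensorlayerx as tlx
from tensorlayerx.nn import Module, Parameter, ParameterList

class Net(Module):

    def __init__(self):
        super(Net, self).__init__()
        # Handled by the isinstance(value, Module) branch: the container is
        # stored like any sublayer, and its parameters are collected
        # recursively rather than through the _params_list bookkeeping.
        self.plist = ParameterList([Parameter(tlx.ones((3, 3)))])

net = Net()
print(len(net.trainable_weights))  # expected: 1, via the sublayer path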
@@ -255,46 +255,46 @@ def _set_mode_for_layers(self, is_train):
             if isinstance(layer, Module):
                 layer.is_train = is_train
 
-    def set_attr_for_parameter_dict(self, name, value):
-        """Set attr for parameter in ParameterDict."""
-        params = self.__dict__.get('_params')
-        params_dict = self.__dict__.get('_params_dict')
-        if params is None:
-            raise AttributeError("For 'Module', can not assign params before Module.__init__() is called.")
-        exist_names = set("")
-        for item in value:
-            self.insert_param_to_layer(item, value[item], check_name=False)
-            if item in exist_names:
-                raise ValueError("The value {} , its name '{}' already exists.".
-                                 format(value[item], item))
-            exist_names.add(item)
-
-        if name in self.__dict__:
-            del self.__dict__[name]
-        if name in params:
-            del params[name]
-        params_dict[name] = value
-
-    def set_attr_for_parameter_tuple(self, name, value):
-        """Set attr for parameter in ParameterTuple."""
-        params = self.__dict__.get('_params')
-        params_list = self.__dict__.get('_params_list')
-        if params is None:
-            raise AttributeError("For 'Module', can not assign params before Module.__init__() is called.")
-        exist_names = set("")
-
-        for item in value:
-            self.insert_param_to_layer(item.name, item, check_name=False)
-            if item.name in exist_names:
-                raise ValueError("The value {} , its name '{}' already exists.".
-                                 format(value, item.name))
-            exist_names.add(item.name)
-
-        if name in self.__dict__:
-            del self.__dict__[name]
-        if name in params:
-            del params[name]
-        params_list[name] = value
+    # def set_attr_for_parameter_dict(self, name, value):
+    #     """Set attr for parameter in ParameterDict."""
+    #     params = self.__dict__.get('_params')
+    #     params_dict = self.__dict__.get('_params_dict')
+    #     if params is None:
+    #         raise AttributeError("For 'Module', can not assign params before Module.__init__() is called.")
+    #     exist_names = set("")
+    #     for item in value:
+    #         self.insert_param_to_layer(item, value[item], check_name=False)
+    #         if item in exist_names:
+    #             raise ValueError("The value {} , its name '{}' already exists.".
+    #                              format(value[item], item))
+    #         exist_names.add(item)
+    #
+    #     if name in self.__dict__:
+    #         del self.__dict__[name]
+    #     if name in params:
+    #         del params[name]
+    #     params_dict[name] = value
+    #
+    # def set_attr_for_parameter_tuple(self, name, value):
+    #     """Set attr for parameter in ParameterTuple."""
+    #     params = self.__dict__.get('_params')
+    #     params_list = self.__dict__.get('_params_list')
+    #     if params is None:
+    #         raise AttributeError("For 'Module', can not assign params before Module.__init__() is called.")
+    #     exist_names = set("")
+    #
+    #     for item in value:
+    #         self.insert_param_to_layer(item.name, item, check_name=False)
+    #         if item.name in exist_names:
+    #             raise ValueError("The value {} , its name '{}' already exists.".
+    #                              format(value, item.name))
+    #         exist_names.add(item.name)
+    #
+    #     if name in self.__dict__:
+    #         del self.__dict__[name]
+    #     if name in params:
+    #         del params[name]
+    #     params_list[name] = value
 
     def set_train(self):
         """Set this network in training mode. After calling this method,
@@ -354,7 +354,6 @@ def insert_param_to_layer(self, param_name, param, check_name=True):
         Determines whether the name input is compatible. Default: True.
         """
-
         if not param_name:
             raise KeyError("The name of parameter should not be null.")
         if check_name and '.' in param_name:
@@ -388,15 +387,15 @@ def __getattr__(self, name):
             params_status = self.__dict__['_params_status']
             if name in params_status:
                 return params_status[name]
-        if '_params_list' in self.__dict__:
-            params_list = self.__dict__['_params_list']
-            if name in params_list:
-                para_list = params_list[name]
-                return para_list
-        if '_params_dict' in self.__dict__:
-            params_dict = self.__dict__['_params_dict']
-            if name in params_dict:
-                return params_dict[name]
+        # if '_params_list' in self.__dict__:
+        #     params_list = self.__dict__['_params_list']
+        #     if name in params_list:
+        #         para_list = params_list[name]
+        #         return para_list
+        # if '_params_dict' in self.__dict__:
+        #     params_dict = self.__dict__['_params_dict']
+        #     if name in params_dict:
+        #         return params_dict[name]
         raise AttributeError("'{}' object has no attribute '{}'.".format(type(self).__name__, name))
 
     def __delattr__(self, name):
@@ -1142,10 +1141,10 @@ def __setitem__(self, idx, parameter):
         idx = self._get_abs_string_index(idx)
         self._params[str(idx)] = parameter
 
-    # def __setattr__(self, key, value):
-    #     if not hasattr(self, key) and not isinstance(value, tf.Variable):
-    #         warnings.warn("Setting attributes on ParameterList is not supported.")
-    #     super(ParameterList, self).__setattr__(key, value)
+    def __setattr__(self, key, value):
+        # if not hasattr(self, key) and not isinstance(value, tf.Variable):
+        #     warnings.warn("Setting attributes on ParameterList is not supported.")
+        super(ParameterList, self).__setattr__(key, value)
 
     def __len__(self):
         return len(self._params)
@@ -1162,7 +1161,7 @@ def __dir__(self):
         return keys
 
     def append(self, parameter):
-        self._params[str(len(self))] = parameter
+        self.insert_param_to_layer(str(len(self)), parameter)
         return self
 
     def extend(self, parameters):
@@ -1173,7 +1172,7 @@ def extend(self, parameters):
             )
         offset = len(self)
         for i, para in enumerate(parameters):
-            self._params[str(offset + i)] = para
+            self.insert_param_to_layer(str(offset + i), para)
         return self
 
     def __call__(self, input):
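Note: these two hunks are the heart of the fix on the ParameterList side. append and extend previously wrote straight into the container's _params dict; routing them through insert_param_to_layer (the Module method shown earlier in this file) means every parameter gets a validated name and goes through the same registration path as any other layer parameter. A minimal sketch, assuming the container can be constructed empty as in the torch API it mirrors:

import tensorlayerx as tlx
from tensorlayerx.nn import Parameter, ParameterList

plist = ParameterList()
plist.append(Parameter(tlx.ones((2, 2))))   # registered under the name "0"
plist.append(Parameter(tlx.ones((2, 2))))   # registered under the name "1"
print(len(plist))  # 2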
@@ -1248,15 +1247,13 @@ def __getitem__(self, key):
         return self._params[key]
 
     def __setitem__(self, key, parameter):
-        self._params[key] = parameter
+        self.insert_param_to_layer(key, parameter)
 
     def __delitem__(self, key):
         del self._params[key]
 
-    # def __setattr__(self, key, value):
-    #     if not hasattr(self, key) and not isinstance(value, tf.Variable):
-    #         warnings.warn("Setting attributes on ParameterDict is not supported.")
-    #     super(ParameterDict, self).__setattr__(key, value)
+    def __setattr__(self, key, value):
+        super(ParameterDict, self).__setattr__(key, value)
 
     def __len__(self) -> int:
         return len(self._params)
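Note: the same treatment for ParameterDict. Item assignment now registers the parameter under its key via insert_param_to_layer instead of writing to _params directly, and __setattr__ is reinstated as a plain delegation to the parent class, dropping the commented-out warning about attribute assignment. A minimal sketch, again assuming empty construction as in the torch container:

import tensorlayerx as tlx
from tensorlayerx.nn import Parameter, ParameterDict

pdict = ParameterDict()
pdict['left'] = Parameter(tlx.ones((5, 5)))    # __setitem__ -> insert_param_to_layer
pdict['right'] = Parameter(tlx.ones((5, 5)))
print(len(pdict))       # 2
print(pdict['right'])   # retrievable under its key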
tensorlayerx/nn/core/core_torch.py (2 changes: 1 addition & 1 deletion)

@@ -617,7 +617,7 @@ def __dir__(self):
         keys = [key for key in keys if not key.isdigit()]
         return keys
 
-    def append(self, parameter: 'Parameter') -> 'ParameterList':
+    def append(self, parameter):
 
         self.register_parameter(str(len(self)), parameter)
         return self
