From b92d5dd0c9f9e7c94e8817621e0c7fe396c608de Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 1 Oct 2024 22:23:35 +0100 Subject: [PATCH] tentative cleanup --- src/brevitas/core/stats/stats_wrapper.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/brevitas/core/stats/stats_wrapper.py b/src/brevitas/core/stats/stats_wrapper.py index 49bf62a82..8e148b8c1 100644 --- a/src/brevitas/core/stats/stats_wrapper.py +++ b/src/brevitas/core/stats/stats_wrapper.py @@ -97,19 +97,17 @@ def __init__( super(_ParameterListStats, self).__init__() self.stats_input_concat_dim = stats_input_concat_dim - if len(tracked_parameter_list) >= 1: - self.first_tracked_param = _ViewParameterWrapper( - tracked_parameter_list[0], stats_input_view_shape_impl) - else: - self.first_tracked_param = _ViewParameter(stats_input_view_shape_impl) if len(tracked_parameter_list) > 1: + self.first_tracked_param = _ViewParameterWrapper( + tracked_parameter_list[0], stats_input_view_shape_impl) extra_list = [ _ViewCatParameterWrapper( param, stats_input_view_shape_impl, stats_input_concat_dim) for param in tracked_parameter_list[1:]] self.extra_tracked_params_list = torch.nn.ModuleList(extra_list) else: + self.first_tracked_param = _ViewParameter(stats_input_view_shape_impl) self.extra_tracked_params_list = None self.stats = _Stats(stats_impl, stats_output_shape)