@@ -92,6 +92,43 @@ def init_weights(module, initialization):
 ACTIVATIONS = ["ReLU", "Softplus", "Tanh", "SELU", "LeakyReLU", "PReLU", "Sigmoid"]


+class MLP(nn.Module):
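+    """Simple fully connected network: Linear -> activation (-> Dropout) per hidden size, then a final Linear to out_features."""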
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        hidden_size: list[int],
+        activation: str,
+        dropout: float,
+    ):
+        super().__init__()
+
+        activ = getattr(nn, activation)()
+
+        self.layers: nn.Sequential
+
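+        # input layer: in_features -> hidden_size[0], followed by the activation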
+        layers = [
+            nn.Linear(in_features, hidden_size[0]),
+        ]
+        layers.append(activ)
+
+        if dropout > 0:
+            layers.append(nn.Dropout(p=dropout))
+
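+        # hidden layers: hidden_size[i] -> hidden_size[i + 1], each followed by the activation (and Dropout if enabled)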
+        for i in range(len(hidden_size) - 1):
+            layers.append(nn.Linear(hidden_size[i], hidden_size[i + 1]))
+            layers.append(activ)
+
+            if dropout > 0:
+                layers.append(nn.Dropout(p=dropout))
+
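+        # output layer: hidden_size[-1] -> out_features, with no activation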
+        layers.append(nn.Linear(hidden_size[-1], out_features))
+        self.layers = nn.Sequential(*layers)
+
+    def forward(self, X: torch.Tensor) -> torch.Tensor:
+        return self.layers(X)
+
+
 class NHiTSBlock(nn.Module):
     """
     N-HiTS block which takes a basis function as an argument.
@@ -137,15 +174,18 @@ def __init__(
         self.batch_normalization = batch_normalization
         self.dropout = dropout

-        self.hidden_size = [
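+        # flattened input to the block MLP: pooled target history, encoder/decoder covariates and the static embedding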
+        mlp_in_features = (
             self.context_length_pooled * len(self.output_size)
             + self.context_length * self.encoder_covariate_size
             + self.prediction_length * self.decoder_covariate_size
             + self.static_hidden_size
-        ] + hidden_size
+        )
+
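+        # output size of the block MLP: context_length values per target plus n_theta basis coefficients per target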
+        mlp_out_features = context_length * len(output_size) + n_theta * sum(
+            output_size
+        )

         assert activation in ACTIVATIONS, f"{activation} is not in {ACTIVATIONS}"
-        activ = getattr(nn, activation)()

         if pooling_mode == "max":
             self.pooling_layer = nn.MaxPool1d(
@@ -160,40 +200,21 @@ def __init__(
                 ceil_mode=True,
             )

-        hidden_layers = []
-        for i in range(n_layers):
-            hidden_layers.append(
-                nn.Linear(
-                    in_features=self.hidden_size[i],
-                    out_features=self.hidden_size[i + 1],
-                )
-            )
-            hidden_layers.append(activ)
-
-            if self.batch_normalization:
-                hidden_layers.append(
-                    nn.BatchNorm1d(num_features=self.hidden_size[i + 1])
-                )
-
-            if self.dropout > 0:
-                hidden_layers.append(nn.Dropout(p=self.dropout))
-
-        output_layer = [
-            nn.Linear(
-                in_features=self.hidden_size[-1],
-                out_features=context_length * len(output_size)
-                + n_theta * sum(output_size),
-            )
-        ]
-        layers = hidden_layers + output_layer
-
         # static_size is computed with data, static_hidden_size is provided by user,
         # if 0 no statics are used
         if (self.static_size > 0) and (self.static_hidden_size > 0):
             self.static_encoder = StaticFeaturesEncoder(
                 in_features=static_size, out_features=static_hidden_size
             )
-        self.layers = nn.Sequential(*layers)
+
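+        # single MLP covering the block's hidden layers and the final projection to mlp_out_features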
+        self.layers = MLP(
+            in_features=mlp_in_features,
+            out_features=mlp_out_features,
+            hidden_size=hidden_size,
+            activation=activation,
+            dropout=self.dropout,
+        )
+
         self.basis = basis

     def forward(