Skip to content

Commit f499773

Browse files
authored
[ENH] rename "model metadata" classes to "model package" classes (#1892)
Proposed solution for, and fixes #1886 This PR contains changes naming convention of "model metadata" class to "model package container" nomenclature Not breaking as the v2 package/metadata layer has not been released yet. Also contains: * minor docstring fixes * tests for the naming pattern * a small bugfix to the test framework, for the above test
1 parent 57a6a1c commit f499773

File tree

18 files changed

+136
-98
lines changed

18 files changed

+136
-98
lines changed

pytorch_forecasting/models/base/_base_object.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class _BaseObject(_SkbaseBaseObject):
1212

1313

1414
class _BasePtForecaster(_BaseObject):
15-
"""Base class for all PyTorch Forecasting forecaster metadata.
15+
"""Base class for all PyTorch Forecasting forecaster packages.
1616
1717
This class points to model objects and contains metadata as tags.
1818
"""
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""DeepAR: Probabilistic forecasting with autoregressive recurrent networks."""
22

33
from pytorch_forecasting.models.deepar._deepar import DeepAR
4-
from pytorch_forecasting.models.deepar._deepar_metadata import DeepARMetadata
4+
from pytorch_forecasting.models.deepar._deepar_pkg import DeepAR_pkg
55

6-
__all__ = ["DeepAR", "DeepARMetadata"]
6+
__all__ = ["DeepAR", "DeepAR_pkg"]

pytorch_forecasting/models/deepar/_deepar.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535

3636

3737
class DeepAR(AutoRegressiveBaseModelWithCovariates):
38+
"""DeepAR: Probabilistic forecasting with autoregressive recurrent networks."""
39+
3840
def __init__(
3941
self,
4042
cell_type: str = "LSTM",

pytorch_forecasting/models/deepar/_deepar_metadata.py renamed to pytorch_forecasting/models/deepar/_deepar_pkg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
"""DeepAR metadata container."""
1+
"""DeepAR package container."""
22

33
from pytorch_forecasting.models.base._base_object import _BasePtForecaster
44

55

6-
class DeepARMetadata(_BasePtForecaster):
7-
"""DeepAR metadata container."""
6+
class DeepAR_pkg(_BasePtForecaster):
7+
"""DeepAR package container."""
88

99
_tags = {
1010
"info:name": "DeepAR",
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Simple models based on fully connected networks."""
22

33
from pytorch_forecasting.models.mlp._decodermlp import DecoderMLP
4-
from pytorch_forecasting.models.mlp._decodermlp_metadata import DecoderMLPMetadata
4+
from pytorch_forecasting.models.mlp._decodermlp_pkg import DecoderMLP_pkg
55
from pytorch_forecasting.models.mlp.submodules import FullyConnectedModule
66

7-
__all__ = ["DecoderMLP", "DecoderMLPMetadata", "FullyConnectedModule"]
7+
__all__ = ["DecoderMLP", "DecoderMLP_pkg", "FullyConnectedModule"]

pytorch_forecasting/models/mlp/_decodermlp.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@
2424

2525

2626
class DecoderMLP(BaseModelWithCovariates):
27-
"""
28-
MLP on the decoder.
27+
"""MLP on the decoder.
2928
3029
MLP that predicts output only based on information available in the decoder.
3130
"""

pytorch_forecasting/models/mlp/_decodermlp_metadata.py renamed to pytorch_forecasting/models/mlp/_decodermlp_pkg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
"""DecoderMLP metadata container."""
1+
"""DecoderMLP package container."""
22

33
from pytorch_forecasting.models.base._base_object import _BasePtForecaster
44

55

6-
class DecoderMLPMetadata(_BasePtForecaster):
7-
"""DecoderMLP metadata container."""
6+
class DecoderMLP_pkg(_BasePtForecaster):
7+
"""DecoderMLP package container."""
88

99
_tags = {
1010
"info:name": "DecoderMLP",
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""N-Beats model for timeseries forecasting without covariates."""
22

33
from pytorch_forecasting.models.nbeats._nbeats import NBeats
4-
from pytorch_forecasting.models.nbeats._nbeats_metadata import NBeatsMetadata
4+
from pytorch_forecasting.models.nbeats._nbeats_pkg import NBeats_pkg
55
from pytorch_forecasting.models.nbeats.sub_modules import (
66
NBEATSGenericBlock,
77
NBEATSSeasonalBlock,
@@ -11,7 +11,7 @@
1111
__all__ = [
1212
"NBeats",
1313
"NBEATSGenericBlock",
14-
"NBeatsMetadata",
14+
"NBeats_pkg",
1515
"NBEATSSeasonalBlock",
1616
"NBEATSTrendBlock",
1717
]

pytorch_forecasting/models/nbeats/_nbeats.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121

2222
class NBeats(BaseModel):
23+
"""N-Beats model for timeseries forecasting without covariates."""
24+
2325
def __init__(
2426
self,
2527
stack_types: Optional[list[str]] = None,

pytorch_forecasting/models/nbeats/_nbeats_metadata.py renamed to pytorch_forecasting/models/nbeats/_nbeats_pkg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
"""NBeats metadata container."""
1+
"""NBeats package container."""
22

33
from pytorch_forecasting.models.base._base_object import _BasePtForecaster
44

55

6-
class NBeatsMetadata(_BasePtForecaster):
7-
"""NBeats metadata container."""
6+
class NBeats_pkg(_BasePtForecaster):
7+
"""NBeats package container."""
88

99
_tags = {
1010
"info:name": "NBeats",

0 commit comments

Comments
 (0)