Skip to content

Commit

Permalink
Validation Parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
JrtPec committed Dec 6, 2023
1 parent a056692 commit 7714adf
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 21 deletions.
3 changes: 2 additions & 1 deletion openenergyid/mvlr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Multi-variable linear regression (MVLR) module."""

from .mvlr import MultiVariableLinearRegression, find_best_mvlr
from .mvlr import MultiVariableLinearRegression, find_best_mvlr, ValidationParameters
from .models import IndependentVariable, MultiVariableRegressionResult

__all__ = [
"MultiVariableLinearRegression",
"MultiVariableRegressionResult",
"IndependentVariable",
"find_best_mvlr",
"ValidationParameters",
]
20 changes: 14 additions & 6 deletions openenergyid/mvlr/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@
from .mvlr import MultiVariableLinearRegression


class ConfidenceInterval(BaseModel):
    """Confidence interval reported for a single regression coefficient.

    Attributes
    ----------
    confidence : float
        Confidence level the interval was computed at (for example 0.95).
    lower : float
        Lower bound of the interval.
    upper : float
        Upper bound of the interval.
    """

    confidence: float
    lower: float
    upper: float


class IndependentVariable(BaseModel):
"""Independent variable for a multivariable linear regression model."""

Expand All @@ -18,7 +26,7 @@ class IndependentVariable(BaseModel):
t_stat: Optional[float] = None
p_value: Optional[float] = None
std_err: Optional[float] = None
confidence_interval: Optional[dict[str, float]] = None
confidence_interval: Optional[ConfidenceInterval] = None

@classmethod
def from_fit(cls, fit: fm.ols, name: str) -> "IndependentVariable":
Expand All @@ -29,11 +37,11 @@ def from_fit(cls, fit: fm.ols, name: str) -> "IndependentVariable":
t_stat=fit.tvalues[name],
p_value=fit.pvalues[name],
std_err=fit.bse[name],
confidence_interval={
"confidence": 0.95,
"lower": fit.conf_int().transpose()[name][0],
"upper": fit.conf_int().transpose()[name][1],
},
confidence_interval=ConfidenceInterval(
confidence=0.95,
lower=fit.conf_int().transpose()[name][0],
upper=fit.conf_int().transpose()[name][1],
),
)


Expand Down
36 changes: 22 additions & 14 deletions openenergyid/mvlr/mvlr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import pandas as pd
from pydantic import BaseModel, Field
import statsmodels.formula.api as fm
from patsy import LookupFactor, ModelDesc, Term # pylint: disable=no-name-in-module
from statsmodels.sandbox.regression.predstd import wls_prediction_std
Expand All @@ -12,6 +13,20 @@
from .helpers import resample_input_data


class ValidationParameters(BaseModel):
    """Thresholds used to decide whether a fitted regression model is acceptable.

    Every threshold is constrained to the closed interval [0, 1].

    Attributes
    ----------
    rsquared : float, default=0.75
        Minimum acceptable value for the adjusted R-squared.
    f_pvalue : float, default=0.05
        Maximum acceptable p-value of the F-statistic.
    pvalues : float, default=0.05
        Maximum acceptable value for the p-values of the t-statistic.
    """

    rsquared: float = Field(
        0.75, ge=0, le=1, description="Minimum acceptable value for the adjusted R-squared"
    )
    # The threshold is compared against the p-value of the F-test, not against
    # the F-statistic itself, so the description must say "p-value".
    f_pvalue: float = Field(
        0.05, ge=0, le=1, description="Maximum acceptable value for the p-value of the F-statistic"
    )
    pvalues: float = Field(
        0.05, ge=0, le=1, description="Maximum acceptable value for the p-values of the t-statistic"
    )


class MultiVariableLinearRegression:
"""Multi-variable linear regression.
Expand Down Expand Up @@ -41,7 +56,7 @@ def __init__(
confint: float = 0.95,
cross_validation: bool = False,
allow_negative_predictions: bool = False,
validation_params: dict = None,
validation_params: ValidationParameters = None,
granularity: Granularity = None,
):
"""Parameters
Expand All @@ -65,15 +80,8 @@ def __init__(
If True, allow predictions to be negative.
For gas consumption or PV production, this is not physical
so allow_negative_predictions should be False
validation_params : dict, default=None
Dictionary with parameters to validate the model.
The following parameters are supported:
- "rsquared": float, default=0.75
Minimum acceptable value for the adjusted R-squared
- "f_pvalue": float, default=0.05
Maximum acceptable value for the F-statistic
- "pvalues": float, default=0.05
Maximum acceptable value for the p-values of the t-statistic
validation_params : ValidationParameters, default=None
Parameters to validate the model.
"""
self.data = data.copy()
if y not in self.data.columns:
Expand All @@ -87,7 +95,7 @@ def __init__(
self.confint = confint
self.cross_validation = cross_validation
self.allow_negative_predictions = allow_negative_predictions
self.validation_params = validation_params or {}
self.validation_params = validation_params or ValidationParameters()
self.granularity = granularity
self._fit = None
self._list_of_fits = []
Expand Down Expand Up @@ -400,16 +408,16 @@ def is_valid(self) -> bool:
-------
bool: True if the model is valid, False otherwise.
"""
if self.fit.rsquared_adj < self.validation_params.get("rsquared", 0.75):
if self.fit.rsquared_adj < self.validation_params.rsquared:
return False

if self.fit.f_pvalue > self.validation_params.get("f_pvalue", 0.05):
if self.fit.f_pvalue > self.validation_params.f_pvalue:
return False

param_keys = self.fit.pvalues.keys().tolist()
param_keys.remove("Intercept")
for k in param_keys:
if self.fit.pvalues[k] > self.validation_params.get("pvalues", 0.05):
if self.fit.pvalues[k] > self.validation_params.pvalues:
return False

return True
Expand Down

0 comments on commit 7714adf

Please sign in to comment.