Commit 106486b

commented out cell 9
kaiguender committed Oct 9, 2024
1 parent 22cb361 commit 106486b
Showing 13 changed files with 383 additions and 239 deletions.
298 changes: 152 additions & 146 deletions _proc/03_wSAA.ipynb

Large diffs are not rendered by default.

Binary file modified dddex/__pycache__/__init__.cpython-39.pyc
Binary file modified dddex/__pycache__/_modidx.cpython-39.pyc
Binary file modified dddex/__pycache__/baseClasses.cpython-39.pyc
Binary file modified dddex/__pycache__/crossValidation.cpython-39.pyc
Binary file modified dddex/__pycache__/levelSetKDEx_multivariate.cpython-39.pyc
Binary file modified dddex/__pycache__/levelSetKDEx_univariate.cpython-39.pyc
Binary file modified dddex/__pycache__/loadData.cpython-39.pyc
Binary file modified dddex/__pycache__/utils.cpython-39.pyc
Binary file modified dddex/__pycache__/wSAA.cpython-39.pyc
5 changes: 5 additions & 0 deletions dddex/_modidx.py
@@ -172,6 +172,11 @@
'dddex.wSAA.RandomForestWSAA.getWeights': ('wsaa.html#randomforestwsaa.getweights', 'dddex/wSAA.py'),
'dddex.wSAA.RandomForestWSAA.pointPredict': ('wsaa.html#randomforestwsaa.pointpredict', 'dddex/wSAA.py'),
'dddex.wSAA.RandomForestWSAA.predict': ('wsaa.html#randomforestwsaa.predict', 'dddex/wSAA.py'),
'dddex.wSAA.RandomForestWSAA2': ('wsaa.html#randomforestwsaa2', 'dddex/wSAA.py'),
'dddex.wSAA.RandomForestWSAA2.fit': ('wsaa.html#randomforestwsaa2.fit', 'dddex/wSAA.py'),
'dddex.wSAA.RandomForestWSAA2.getWeights': ('wsaa.html#randomforestwsaa2.getweights', 'dddex/wSAA.py'),
'dddex.wSAA.RandomForestWSAA2.pointPredict': ('wsaa.html#randomforestwsaa2.pointpredict', 'dddex/wSAA.py'),
'dddex.wSAA.RandomForestWSAA2.predict': ('wsaa.html#randomforestwsaa2.predict', 'dddex/wSAA.py'),
'dddex.wSAA.RandomForestWSAA_LGBM': ('wsaa.html#randomforestwsaa_lgbm', 'dddex/wSAA.py'),
'dddex.wSAA.RandomForestWSAA_LGBM.fit': ('wsaa.html#randomforestwsaa_lgbm.fit', 'dddex/wSAA.py'),
'dddex.wSAA.RandomForestWSAA_LGBM.getWeights': ('wsaa.html#randomforestwsaa_lgbm.getweights', 'dddex/wSAA.py'),
135 changes: 134 additions & 1 deletion dddex/wSAA.py
@@ -21,7 +21,7 @@
from .utils import restructureWeightsDataList, restructureWeightsDataList_multivariate

# %% auto 0
__all__ = ['RandomForestWSAA', 'RandomForestWSAA_LGBM', 'SampleAverageApproximation']
__all__ = ['RandomForestWSAA', 'RandomForestWSAA2', 'RandomForestWSAA_LGBM', 'SampleAverageApproximation']

# %% ../nbs/03_wSAA.ipynb 7
class RandomForestWSAA(RandomForestRegressor, BaseWeightsBasedEstimator):
Expand Down Expand Up @@ -100,6 +100,139 @@ def getWeights(self,



return weightsDataList

#---

def predict(self: BaseWeightsBasedEstimator,
X: np.ndarray, # Feature matrix for which conditional quantiles are computed.
probs: list, # Probabilities for which quantiles are computed.
outputAsDf: bool=True, # Determines output. Either a dataframe with probs as columns or a dict with probs as keys.
# Optional. List with length X.shape[0]. Values are multiplied to the predictions
# of each sample to rescale values.
scalingList: list=None,
):

__doc__ = BaseWeightsBasedEstimator.predict.__doc__

return super(MetaEstimatorMixin, self).predict(X = X,
probs = probs,
scalingList = scalingList)

#---

def pointPredict(self,
X: np.ndarray, # Feature Matrix
**kwargs):
"""Original `predict` method to generate point forecasts"""

return super().predict(X = X,
**kwargs)


# %% ../nbs/03_wSAA.ipynb 9
# We attempt here to speed up the computation of the weights by interpreting every single
# tree as a lookup table that maps a leaf index to the indices of the training samples
# falling into that leaf. This way we don't have to compare the leaf-index arrays of
# every training sample against every test sample.
# Unfortunately, while this strategy works very well for a single tree, it doesn't work
# well for the whole forest: the structure of the per-tree lookup output makes it
# difficult to aggregate the resulting per-tree weights over all trees.

class RandomForestWSAA2(RandomForestRegressor, BaseWeightsBasedEstimator):

def fit(self,
X: np.ndarray, # Feature matrix
y: np.ndarray, # Target values
**kwargs):

super().fit(X = X,
y = y,
**kwargs)

self.yTrain = y

leafIndices = self.apply(X)

indicesPerBinPerTree = list()

for indexTree in range(self.n_estimators):
leafIndicesPerTree = leafIndices[:, indexTree]

indicesPerBin = defaultdict(list)

for index, leafIndex in enumerate(leafIndicesPerTree):
indicesPerBin[leafIndex].append(index)

indicesPerBinPerTree.append(indicesPerBin)

self.indicesPerBinPerTree = indicesPerBinPerTree



#---

def getWeights(self,
X: np.ndarray, # Feature matrix for which conditional density estimates are computed.
# Specifies structure of the returned density estimates. One of:
# 'all', 'onlyPositiveWeights', 'summarized', 'cumDistribution', 'cumDistributionSummarized'
outputType: str='onlyPositiveWeights',
# Optional. List with length X.shape[0]. Values are multiplied to the estimated
# density of each sample for scaling purposes.
scalingList: list=None,
) -> list: # List whose elements are the conditional density estimates for the samples specified by `X`.

__doc__ = BaseWeightsBasedEstimator.getWeights.__doc__

#---

leafIndicesPerTree = self.apply(X)

weightsDataList = list()

for leafIndices in leafIndicesPerTree:

weights = np.zeros(self.yTrain.shape[0])

for indexTree in range(len(leafIndices)):
indicesPosWeight = self.indicesPerBinPerTree[indexTree][leafIndices[indexTree]]

weightsNew = np.zeros(self.yTrain.shape[0])
np.put(weightsNew, indicesPosWeight, 1 / len(indicesPosWeight))

weights = weights + weightsNew

weights = weights / len(leafIndices)

weightsPosIndex = np.where(weights > 0)[0]

weightsDataList.append((weights[weightsPosIndex], weightsPosIndex))

#---

# Check if self.yTrain is a 2D array with more than one column.
if len(self.yTrain.shape) > 1:
if self.yTrain.shape[1] > 1:

if outputType not in ['all', 'onlyPositiveWeights', 'summarized']:
raise ValueError("outputType must be one of 'all', 'onlyPositiveWeights', 'summarized' for multivariate y.")

weightsDataList = restructureWeightsDataList_multivariate(weightsDataList = weightsDataList,
outputType = outputType,
y = self.yTrain,
scalingList = scalingList,
equalWeights = False)

else:
weightsDataList = restructureWeightsDataList(weightsDataList = weightsDataList,
outputType = outputType,
y = self.yTrain,
scalingList = scalingList,
equalWeights = False)




return weightsDataList

#---
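For context (not part of the commit): below is a minimal sketch of how the new RandomForestWSAA2 class added in this diff might be exercised. It assumes RandomForestWSAA2 is importable from dddex.wSAA, that it accepts the usual RandomForestRegressor constructor arguments, and that its predict keeps the probs-based signature of RandomForestWSAA.predict shown above; the data is synthetic.

import numpy as np
from dddex.wSAA import RandomForestWSAA2

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))                 # synthetic feature matrix
y = X[:, 0] + 0.1 * rng.normal(size=200)      # synthetic univariate target

model = RandomForestWSAA2(n_estimators=50, max_depth=4)
model.fit(X, y)                               # also builds the per-tree leaf lookup tables

# Conditional density estimates: one (weights, trainingIndices) pair per test sample.
weightsDataList = model.getWeights(X[:5], outputType='onlyPositiveWeights')

# Weights-based conditional quantiles via the inherited `predict` (dataframe output assumed).
quantilesDf = model.predict(X[:5], probs=[0.1, 0.5, 0.9])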
