Skip to content

Commit

Permalink
Include an OriginalRepresentation that is just a dummy doing nothing
Browse files Browse the repository at this point in the history
  • Loading branch information
mirkobunse committed Oct 11, 2024
1 parent 86af4d4 commit 59f1b62
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 5 deletions.
2 changes: 2 additions & 0 deletions docs/source/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ You can use the `CombinedLoss` to create arbitrary, weighted sums of losses and
.. autoclass:: qunfold.LaplacianKernelRepresentation
.. autoclass:: qunfold.GaussianRFFKernelRepresentation
.. autoclass:: qunfold.OriginalRepresentation
```


Expand Down
3 changes: 2 additions & 1 deletion qunfold/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.1.5-rc4"
__version__ = "0.1.5-rc5"

from .methods.linear.losses import (
LeastSquaresLoss,
Expand All @@ -19,6 +19,7 @@
LaplacianKernelRepresentation,
GaussianKernelRepresentation,
GaussianRFFKernelRepresentation,
OriginalRepresentation,
)

from .methods.linear import (
Expand Down
14 changes: 14 additions & 0 deletions qunfold/methods/linear/representations.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,3 +394,17 @@ def _transform_after_preprocessor(self, X):
Xw = X @ self.w.T
C = np.concatenate((np.cos(Xw), np.sin(Xw)), axis=1)
return np.sqrt(2 / self.n_rff) * np.mean(C, axis=0)

class OriginalRepresentation(AbstractRepresentation):
"""A dummy representation that simply returns the data as it is."""
def fit_transform(self, X, y, average=True, n_classes=None):
check_y(y, n_classes)
self.p_trn = class_prevalences(y, n_classes)
if average:
return np.array([ X[y==c].mean(axis=0) for c in range(len(self.p_trn)) ]).T # = M
return X, y
def transform(self, X, average=True):
n_classes = len(self.p_trn)
if average:
return X.mean(axis=0) # = q
return X
8 changes: 4 additions & 4 deletions qunfold/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ def test_methods(self):
p_run = qunfold.RUN(qunfold.ClassRepresentation(rf), tau=1e6).fit(X_trn, y_trn).predict(X_tst)
p_hdy = qunfold.HDy(rf, 3).fit(X_trn, y_trn).predict(X_tst)
p_edy = qunfold.EDy(rf).fit(X_trn, y_trn).predict(X_tst)
p_cstm = qunfold.LinearMethod( # a custom method
p_orig = qunfold.LinearMethod( # a custom method
qunfold.LeastSquaresLoss(),
qunfold.HistogramRepresentation(3)
qunfold.OriginalRepresentation()
).fit(X_trn, y_trn, n_classes).predict(X_tst)
p_kmme = qunfold.KMM('energy').fit(X_trn, y_trn).predict(X_tst)
p_rff = qunfold.KMM('rff').fit(X_trn, y_trn).predict(X_tst)
Expand All @@ -74,8 +74,8 @@ def test_methods(self):
f" {p_hdy.nit} it.; {p_hdy.message}",
f" p_edy = {p_edy} (RAE {qp.error.rae(p_edy, p_tst):.4f})",
f" {p_edy.nit} it.; {p_edy.message}",
f" p_cstm = {p_cstm} (RAE {qp.error.rae(p_cstm, p_tst):.4f})",
f" {p_cstm.nit} it.; {p_cstm.message}",
f" p_orig = {p_orig} (RAE {qp.error.rae(p_orig, p_tst):.4f})",
f" {p_orig.nit} it.; {p_orig.message}",
f" p_kmme = {p_kmme} (RAE {qp.error.rae(p_kmme, p_tst):.4f})",
f" {p_kmme.nit} it.; {p_kmme.message}",
f" p_rff = {p_rff} (RAE {qp.error.rae(p_rff, p_tst):.4f})",
Expand Down

0 comments on commit 59f1b62

Please sign in to comment.