diff --git a/qunfold/__init__.py b/qunfold/__init__.py index 6f0d960..706ba28 100644 --- a/qunfold/__init__.py +++ b/qunfold/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.1.5-rc" +__version__ = "0.1.5-rc2" from .methods.linear.losses import ( LeastSquaresLoss, diff --git a/qunfold/methods/likelihood.py b/qunfold/methods/likelihood.py index 05b5577..a0d8f0f 100644 --- a/qunfold/methods/likelihood.py +++ b/qunfold/methods/likelihood.py @@ -82,13 +82,29 @@ def fit(self, X, y, n_classes=None): self.classifier.fit(X, y) return self def predict(self, X): - pYX_pY = jnp.array(self.classifier.predict_proba(X) / self.p_trn) # P(Y|X) / P_trn(Y) - p_prev = jnp.array(self.p_trn) # the current estimate - for n_iter in range(self.max_iter): - pYX = pYX_pY * p_prev - pYX = pYX / pYX.sum(axis=1, keepdims=True) - p_next = pYX.mean(axis=0) - if jnp.linalg.norm(p_next - p_prev) < self.tol: - return Result(p_next, n_iter+1, "Optimization terminated successfully.") - p_prev = p_next - return Result(p_prev, self.max_iter, "Maximum number of iterations reached.") + return maximize_expectation( + jnp.array(self.classifier.predict_proba(X)), # P(Y|X) + jnp.array(self.p_trn), + self.max_iter, + self.tol, + ) + +def maximize_expectation(pYX, p_trn, max_iter=100, tol=1e-8): + """The expectation maximization routine that is part of the `ExpectationMaximizer` by Saerens et al. (2002). + + Args: + pYX: A JAX matrix of the posterior probabilities of a classifier, `P(Y|X)`. This matrix has to have the shape `(n_items, n_classes)`, as returned by some `classifier.predict_proba(X)`. + p_trn: A JAX array of prior probabilities of the classifier. This array has to have the shape `(n_classes,)`. + max_iter (optional): The maximum number of iterations. Defaults to `100`. + tol (optional): The convergence tolerance for the L2 norm between iterations. Defaults to `1e-8`. + """ + pYX_pY = pYX / p_trn # P(Y|X) / P_trn(Y) + p_prev = jnp.array(p_trn) # copy p_trn to get the first estimate + for n_iter in range(max_iter): + pYX = pYX_pY * p_prev + pYX = pYX / pYX.sum(axis=1, keepdims=True) + p_next = pYX.mean(axis=0) + if jnp.linalg.norm(p_next - p_prev) < tol: + return Result(p_next, n_iter+1, "Optimization terminated successfully.") + p_prev = p_next + return Result(p_prev, max_iter, "Maximum number of iterations reached.")