Skip to content

Commit

Permalink
Allow to omit_result_conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
mirkobunse committed Sep 21, 2024
1 parent ad63445 commit aa58316
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion qunfold/methods/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,15 @@ def predict(self, X):
self.tol,
)

def maximize_expectation(pYX, p_trn, max_iter=100, tol=1e-8):
def maximize_expectation(pYX, p_trn, max_iter=100, tol=1e-8, omit_result_conversion=False):
"""The expectation maximization routine that is part of the `ExpectationMaximizer` by Saerens et al. (2002).
Args:
pYX: A JAX matrix of the posterior probabilities of a classifier, `P(Y|X)`. This matrix has to have the shape `(n_items, n_classes)`, as returned by some `classifier.predict_proba(X)`. Multiple bags, with shape `(n_bags, n_items_per_bag, n_classes)` are also supported.
p_trn: A JAX array of prior probabilities of the classifier. This array has to have the shape `(n_classes,)`.
max_iter (optional): The maximum number of iterations. Defaults to `100`.
tol (optional): The convergence tolerance for the L2 norm between iterations or None to disable convergence checks. Defaults to `1e-8`.
omit_result_conversion (optional): Whether to omit the conversion into a `Result` type.
"""
pYX_pY = pYX / p_trn # P(Y|X) / P_trn(Y)
p_prev = p_trn
Expand All @@ -106,12 +107,16 @@ def maximize_expectation(pYX, p_trn, max_iter=100, tol=1e-8):
p_next = pYX.mean(axis=-2, keepdims=True) # shape (n_bags, [1], n_classes)
if tol is not None:
if jnp.all(jnp.linalg.norm(p_next - p_prev, axis=-1) < tol):
if omit_result_conversion:
return jnp.squeeze(p_next, axis=-2)
return Result(
jnp.squeeze(p_next, axis=-2),
n_iter+1,
"Optimization terminated successfully."
)
p_prev = p_next
if omit_result_conversion:
return jnp.squeeze(p_prev, axis=-2)
return Result(
jnp.squeeze(p_prev, axis=-2),
max_iter,
Expand Down

0 comments on commit aa58316

Please sign in to comment.