
Implement clustering accuracy #2767

Open
moetayuko opened this issue Oct 3, 2024 · 4 comments · May be fixed by #2777
Labels: enhancement (New feature or request), New metric
Milestone: future

Comments

@moetayuko

🚀 Feature

Motivation

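Clustering accuracy is a popular metric for evaluating clustering. Unlike plain classification accuracy, it first uses the Hungarian algorithm to find the best one-to-one mapping between the predicted pseudo labels (cluster ids) and the ground-truth labels, and then computes accuracy under that mapping.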
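For reference, a common way to state the metric (standard notation assumed here, not quoted from a specific paper): with $N$ samples, ground-truth labels $y_i$, predicted cluster ids $c_i$, and $\mathcal{M}$ the set of one-to-one mappings from cluster ids to labels,

$$
\mathrm{ACC} = \max_{m \in \mathcal{M}} \frac{1}{N} \sum_{i=1}^{N} \mathbb{1}\left[ y_i = m(c_i) \right]
$$

Equivalently, if $W_{jk}$ counts the samples in cluster $j$ with label $k$, the maximization picks one entry per row and column of $W$ with the largest total, which is a linear sum assignment problem; this is the step the Hungarian algorithm solves.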

Current implementations of clustering accuracy use either `scipy.optimize.linear_sum_assignment` or the `munkres` package for the Hungarian algorithm. I'm not sure whether such a dependency is allowed in torchmetrics; if not, a custom implementation would need to be added.
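A minimal sketch of what such a scipy-based implementation typically looks like, using NumPy and `scipy.optimize.linear_sum_assignment` with `maximize=True` (available since SciPy 1.4). The function name `clustering_accuracy` is hypothetical, and the sketch assumes non-empty 1-D integer label sequences with ids starting at 0:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def clustering_accuracy(preds, target):
    """Clustering accuracy via linear sum assignment (hypothetical helper).

    preds: predicted cluster ids; target: ground-truth labels.
    Both are assumed to be non-empty, 1-D, same length, with ids starting at 0.
    """
    preds = np.asarray(preds, dtype=np.int64)
    target = np.asarray(target, dtype=np.int64)
    k = int(max(preds.max(), target.max())) + 1
    # w[i, j] = number of samples in predicted cluster i with true label j
    w = np.zeros((k, k), dtype=np.int64)
    np.add.at(w, (preds, target), 1)
    # One-to-one cluster -> label mapping that maximizes the matched samples
    row_ind, col_ind = linear_sum_assignment(w, maximize=True)
    return w[row_ind, col_ind].sum() / preds.size
```

Because cluster ids are arbitrary, `clustering_accuracy([0, 0, 1, 1], [1, 1, 0, 0])` scores the clustering as perfect even though no prediction literally equals its label.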

Pitch

Implement clustering accuracy in `torchmetrics.clustering`.
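If a scipy dependency isn't acceptable, one dependency-free option is the classic O(n^3) Hungarian algorithm with row and column potentials. The sketch below is plain Python; `hungarian_min_cost` and `clustering_accuracy` are hypothetical names, not an existing torchmetrics API, and a torch-based metric could feed the solver `confmat.tolist()`:

```python
def hungarian_min_cost(cost):
    """Minimum-cost assignment for an n x m cost matrix (lists), n <= m.

    Returns assignment[i] = column assigned to row i.
    Hungarian algorithm with potentials; uses 1-based indices internally.
    """
    n, m = len(cost), len(cost[0])
    inf = float("inf")
    u = [0.0] * (n + 1)  # row potentials
    v = [0.0] * (m + 1)  # column potentials
    p = [0] * (m + 1)    # p[j] = row matched to column j (0 means free)
    way = [0] * (m + 1)  # previous column on the current augmenting path
    for i in range(1, n + 1):
        p[0] = i
        j0 = 0
        minv = [inf] * (m + 1)
        used = [False] * (m + 1)
        while True:
            used[j0] = True
            i0, delta, j1 = p[j0], inf, 0
            for j in range(1, m + 1):
                if not used[j]:
                    cur = cost[i0 - 1][j - 1] - u[i0] - v[j]
                    if cur < minv[j]:
                        minv[j], way[j] = cur, j0
                    if minv[j] < delta:
                        delta, j1 = minv[j], j
            # Shift potentials so at least one new edge becomes tight
            for j in range(m + 1):
                if used[j]:
                    u[p[j]] += delta
                    v[j] -= delta
                else:
                    minv[j] -= delta
            j0 = j1
            if p[j0] == 0:  # reached a free column: augmenting path found
                break
        # Flip the matching along the augmenting path
        while j0:
            j1 = way[j0]
            p[j0] = p[j1]
            j0 = j1
    assignment = [-1] * n
    for j in range(1, m + 1):
        if p[j]:
            assignment[p[j] - 1] = j - 1
    return assignment


def clustering_accuracy(preds, target):
    """Clustering accuracy with the pure-Python solver (hypothetical helper)."""
    k = max(max(preds), max(target)) + 1
    # counts[c][y] = number of samples in cluster c with true label y
    counts = [[0] * k for _ in range(k)]
    for c, y in zip(preds, target):
        counts[c][y] += 1
    # Maximizing matched counts is the same as minimizing (top - count)
    top = max(max(row) for row in counts)
    cost = [[top - x for x in row] for row in counts]
    assignment = hungarian_min_cost(cost)
    matched = sum(counts[c][assignment[c]] for c in range(k))
    return matched / len(preds)
```

The solver loops in Python, so this is only practical for modest numbers of clusters; it is meant to show that the assignment step does not strictly require scipy.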

moetayuko added the enhancement (New feature or request) label Oct 3, 2024

github-actions bot commented Oct 3, 2024

Hi! Thanks for your contribution, great first issue!

@SkafteNicki
Member

@moetayuko thanks for opening this issue. Do you have a reference to a source (possibly a research paper) where the metric is described in detail?

SkafteNicki added this to the future milestone Oct 8, 2024
@SkafteNicki
Member

@moetayuko thanks for the references, they really helped me understand how the metric is intended to work.
Hopefully I will have time to fully implement the metric in the next couple of days. I already have the logic figured out, using https://github.com/ivan-chai/torch-linear-assignment to solve the linear sum assignment problem:

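```python
import torch
from torchmetrics.functional.classification import multiclass_confusion_matrix
# pip install git+https://github.com/ivan-chai/torch-linear-assignment.git@main
from torch_linear_assignment import batch_linear_assignment

# Cluster ids are arbitrary: this clustering is perfect up to swapping 0 and 1
preds = torch.tensor([0, 0, 1, 1])
target = torch.tensor([1, 1, 0, 0])

# confmat[i, j] = number of samples with true class i and predicted cluster j
confmat = multiclass_confusion_matrix(preds, target, num_classes=5)
print(confmat)

# batch_linear_assignment expects a batch of cost matrices with shape (B, W, T)
confmat = confmat[None]

# It minimizes total cost, so turn counts into costs to maximize matched samples
assignment = batch_linear_assignment(confmat.max() - confmat)
print(assignment)

confmat = confmat[0]

# assignment[0, i] is the cluster matched to class i; the matched counts are the hits
tps = confmat[torch.arange(confmat.size(0)), assignment.flatten()]

acc = tps.sum() / len(preds)
print(acc)
```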
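For small numbers of classes, the assignment-based result can be cross-checked by brute force: clustering accuracy is the best accuracy over all one-to-one relabelings, which `itertools.permutations` enumerates directly. This is a sketch of a hypothetical test oracle, exponential in the number of classes, so it is only meant for tiny inputs:

```python
from itertools import permutations


def clustering_accuracy_bruteforce(preds, target):
    """Exhaustive clustering accuracy (hypothetical helper); O(k!) in k classes."""
    k = max(max(preds), max(target)) + 1
    best = 0
    # perm[c] is the class label assigned to predicted cluster c
    for perm in permutations(range(k)):
        correct = sum(perm[c] == y for c, y in zip(preds, target))
        best = max(best, correct)
    return best / len(preds)
```

On the `preds`/`target` pair above (converted with `.tolist()`), it should agree with the accuracy computed from the linear assignment.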

SkafteNicki linked a pull request Oct 12, 2024 that will close this issue