Skip to content

Commit

Permalink
Merge branch 'master' into ci/labels
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Aug 25, 2023
2 parents 37dfa44 + 1b184f4 commit 2e2fd9c
Show file tree
Hide file tree
Showing 18 changed files with 588 additions and 20 deletions.
7 changes: 5 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-
- Added `MutualInformationScore` metric to cluster package ([#2008](https://github.com/Lightning-AI/torchmetrics/pull/2008)


### Changed
Expand All @@ -26,7 +26,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

-
- Fixed bug in `PearsonCorrCoef` is updated on single samples at a time ([#2019](https://github.com/Lightning-AI/torchmetrics/pull/2019)


- Fixed support for pixelwise MSE ([#2017](https://github.com/Lightning-AI/torchmetrics/pull/2017)


## [1.1.0] - 2023-08-22
Expand Down
Binary file modified docs/source/_static/images/logo.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
25 changes: 13 additions & 12 deletions docs/source/_static/images/logo.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
21 changes: 21 additions & 0 deletions docs/source/clustering/mutual_info_score.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
.. customcarditem::
:header: Mutual Information Score
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/clustering.svg
:tags: Clustering

.. include:: ../links.rst

########################
Mutual Information Score
########################

Module Interface
________________

.. autoclass:: torchmetrics.clustering.MutualInfoScore
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.clustering.mutual_info_score
8 changes: 8 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,14 @@ Or directly from conda

classification/*

.. toctree::
:maxdepth: 2
:name: clustering
:caption: Clustering
:glob:

clustering/*

.. toctree::
:maxdepth: 2
:name: detection
Expand Down
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -150,4 +150,5 @@
.. _CIOU: https://arxiv.org/abs/2005.03572
.. _DIOU: https://arxiv.org/abs/1911.08287v1
.. _GIOU: https://arxiv.org/abs/1902.09630
.. _Mutual Information Score: https://en.wikipedia.org/wiki/Mutual_information
.. _pycocotools: https://github.com/cocodataset/cocoapi/tree/master/PythonAPI/pycocotools
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,4 @@
numpy >1.20.0
torch >=1.8.1, <=2.0.1
typing-extensions; python_version < '3.9'
packaging # hotfix for utils, can be dropped with lit-utils >=0.5
lightning-utilities >=0.8.0, <0.10.0
18 changes: 18 additions & 0 deletions src/torchmetrics/clustering/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.clustering.mutual_info_score import MutualInfoScore

__all__ = [
"MutualInfoScore",
]
125 changes: 125 additions & 0 deletions src/torchmetrics/clustering/mutual_info_score.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Sequence, Union

from torch import Tensor

from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["MutualInfoScore.plot"]


class MutualInfoScore(Metric):
r"""Compute `Mutual Information Score`_.
.. math::
MI(U,V) = \sum_{i=1}^{\abs{U}} \sum_{j=1}^{\abs{V}} \frac{\abs{U_i\cap V_j}}{N}
\log\frac{N\abs{U_i\cap V_j}}{\abs{U_i}\abs{V_j}}
Where :math:`U` is a tensor of target values, :math:`V` is a tensor of predictions,
:math:`\abs{U_i}` is the number of samples in cluster :math:`U_i`, and
:math:`\abs{V_i}` is the number of samples in cluster :math:`V_i`.
The metric is symmetric, therefore swapping :math:`U` and :math:`V` yields
the same mutual information score.
As input to ``forward`` and ``update`` the metric accepts the following input:
- ``preds`` (:class:`~torch.Tensor`): either single output float tensor with shape ``(N,)``
- ``target`` (:class:`~torch.Tensor`): either single output tensor with shape ``(N,)``
As output of ``forward`` and ``compute`` the metric returns the following output:
- ``mi_score`` (:class:`~torch.Tensor`): A tensor with the Mutual Information Score
Args:
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Example:
>>> import torch
>>> from torchmetrics.clustering import MutualInfoScore
>>> preds = torch.tensor([2, 1, 0, 1, 0])
>>> target = torch.tensor([0, 2, 1, 1, 0])
>>> mi_score = MutualInfoScore()
>>> mi_score(preds, target)
tensor(0.5004)
"""

is_differentiable = True
higher_is_better = None
full_state_update: bool = True
plot_lower_bound: float = 0.0
preds: List[Tensor]
target: List[Tensor]
contingency: Tensor

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets."""
self.preds.append(preds)
self.target.append(target)

def compute(self) -> Tensor:
"""Compute mutual information over state."""
return mutual_info_score(dim_zero_cat(self.preds), dim_zero_cat(self.target))

def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.clustering import MutualInfoScore
>>> metric = MutualInfoScore()
>>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
>>> fig_, ax_ = metric.plot(metric.compute())
.. plot::
:scale: 75
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.clustering import MutualInfoScore
>>> metric = MutualInfoScore()
>>> for _ in range(10):
... metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
>>> fig_, ax_ = metric.plot(metric.compute())
"""
return self._plot(val, ax)
16 changes: 16 additions & 0 deletions src/torchmetrics/functional/clustering/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score

__all__ = ["mutual_info_score"]
79 changes: 79 additions & 0 deletions src/torchmetrics/functional/clustering/mutual_info_score.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor, tensor

from torchmetrics.functional.clustering.utils import calculate_contingency_matrix, check_cluster_labels


def _mutual_info_score_update(preds: Tensor, target: Tensor) -> Tensor:
"""Update and return variables required to compute the mutual information score.
Args:
preds: predicted class labels
target: ground truth class labels
Returns:
contingency: contingency matrix
"""
check_cluster_labels(preds, target)
return calculate_contingency_matrix(preds, target)


def _mutual_info_score_compute(contingency: Tensor) -> Tensor:
"""Compute the mutual information score based on the contingency matrix.
Args:
contingency: contingency matrix
Returns:
mutual_info: mutual information score
"""
n = contingency.sum()
u = contingency.sum(dim=1)
v = contingency.sum(dim=0)

# Check if preds or target labels only have one cluster
if u.size() == 1 or v.size() == 1:
return tensor(0.0)

# Find indices of nonzero values in U and V
nzu, nzv = torch.nonzero(contingency, as_tuple=True)
contingency = contingency[nzu, nzv]

# Calculate MI using entries corresponding to nonzero contingency matrix entries
log_outer = torch.log(u[nzu]) + torch.log(v[nzv])
mutual_info = contingency / n * (torch.log(n) + torch.log(contingency) - log_outer)
return mutual_info.sum()


def mutual_info_score(preds: Tensor, target: Tensor) -> Tensor:
"""Compute mutual information between two clusterings.
Args:
preds: predicted classes
target: ground truth classes
Example:
>>> from torchmetrics.functional.clustering import mutual_info_score
>>> target = torch.tensor([0, 3, 2, 2, 1])
>>> preds = torch.tensor([1, 3, 2, 0, 1])
>>> mutual_info_score(preds, target)
tensor(1.0549)
"""
contingency = _mutual_info_score_update(preds, target)
return _mutual_info_score_compute(contingency)
Loading

0 comments on commit 2e2fd9c

Please sign in to comment.