Commit b680b0c

Merge branch 'master' into update-pth-versions

vfdev-5 authored Aug 5, 2024 · 2 parents 347d11d + 8be85b4
Showing 8 changed files with 185 additions and 10 deletions.
13 changes: 6 additions & 7 deletions .github/workflows/tpu-tests.yml
@@ -36,10 +36,10 @@ jobs:
 
     steps:
       - uses: actions/checkout@v4
-      - name: Set up Python 3.9
+      - name: Set up Python 3.10
         uses: actions/setup-python@v4
         with:
-          python-version: "3.9"
+          python-version: "3.10"
           architecture: "x64"
 
       - name: Get year & week number
@@ -50,7 +50,7 @@ jobs:
       - name: Get pip cache dir
         id: pip-cache
         run: |
-          pip3 install -U pip
+          pip3 install -U "pip<24"
           echo "pip_cache=$(pip cache dir)" >> $GITHUB_OUTPUT
         shell: bash -l {0}
 
@@ -70,10 +70,9 @@ jobs:
           pip install mkl==2021.4.0
          ## Install torch & xla and torchvision
-          pip install --pre https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch-nightly-cp39-cp39-linux_x86_64.whl
-          pip install --pre https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-nightly-cp39-cp39-linux_x86_64.whl
-          pip install --pre https://storage.googleapis.com/tpu-pytorch/wheels/colab/torchvision-nightly-cp39-cp39-linux_x86_64.whl
+          pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
+          pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp310-cp310-linux_x86_64.whl
           # Check installation
           python -c "import torch"
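For context, the workflow's post-install check could be extended along these lines. This is a hypothetical smoke test, not part of the commit, and it assumes the torch_xla wheel installed above exposes the usual xla_model helper:

    # Hypothetical smoke test (not in this commit): verify both wheels import
    # and that an XLA device can be created.
    import torch
    import torch_xla.core.xla_model as xm  # provided by the torch_xla wheel installed above

    print("torch:", torch.__version__)
    device = xm.xla_device()  # acquires the TPU device (or whatever XLA is configured for)
    print(torch.ones(2, 2, device=device).sum())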
1 change: 1 addition & 0 deletions docs/source/metrics.rst
@@ -335,6 +335,7 @@ Complete list of metrics
     MeanPairwiseDistance
     MeanSquaredError
     metric.Metric
+    metric_group.MetricGroup
     metrics_lambda.MetricsLambda
     MultiLabelConfusionMatrix
     MutualInformation
4 changes: 2 additions & 2 deletions ignite/engine/engine.py
@@ -157,7 +157,7 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]):
         _check_signature(process_function, "process_function", self, None)
 
         # generator provided by self._internal_run_as_gen
-        self._internal_run_generator: Optional[Generator] = None
+        self._internal_run_generator: Optional[Generator[Any, None, State]] = None
 
     def register_events(
         self, *event_names: Union[List[str], List[EventEnum]], event_to_attr: Optional[dict] = None
@@ -951,7 +951,7 @@ def _internal_run(self) -> State:
         self._internal_run_generator = None
         return out.value
 
-    def _internal_run_as_gen(self) -> Generator:
+    def _internal_run_as_gen(self) -> Generator[Any, None, State]:
         self.should_terminate = self.should_terminate_single_epoch = self.should_interrupt = False
         self._init_timers(self.state)
         try:
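The tightened annotations spell out that the internal run generator yields intermediate values, is never sent anything, and returns the final State when it finishes. A generic, hedged sketch of how three-parameter Generator typing behaves (illustrative names only, not Ignite's engine code):

    # Generic sketch of Generator[YieldType, SendType, ReturnType] annotations;
    # the names here are illustrative, not taken from ignite.engine.
    from typing import Any, Generator, Optional


    class State:
        epoch: int = 0


    def run_as_gen() -> Generator[Any, None, State]:
        state = State()
        for state.epoch in range(3):
            yield f"epoch {state.epoch} done"  # yields Any, accepts no sent values
        return state                           # the return value rides on StopIteration


    # Optional[...] mirrors the engine attribute that starts out as None.
    gen: Optional[Generator[Any, None, State]] = run_as_gen()
    try:
        while True:
            print(next(gen))
    except StopIteration as stop:
        final_state: State = stop.value        # the generator's declared return type
        print(final_state.epoch)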
1 change: 1 addition & 0 deletions ignite/handlers/wandb_logger.py
@@ -134,6 +134,7 @@ def __init__(self, *args: Any, **kwargs: Any):
                 "You man install wandb with the command:\n pip install wandb\n"
             )
         if kwargs.get("init", True):
+            kwargs.pop("init", None)
             wandb.init(*args, **kwargs)
 
     def __getattr__(self, attr: Any) -> Any:
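The added kwargs.pop("init", None) removes the wrapper-only init switch before the remaining keyword arguments are forwarded, presumably so that wandb.init never receives a keyword it does not expect. A generic, hedged sketch of the pattern (hypothetical names, no wandb dependency):

    # Hedged sketch of the pattern (hypothetical names): consume a wrapper-level
    # control flag before forwarding the remaining kwargs to a backend initializer.
    def start_backend(**kwargs):
        # Stand-in for a third-party init call; echoes what it actually received.
        return dict(kwargs)


    def make_logger(*args, **kwargs):
        if kwargs.get("init", True):      # wrapper-level switch, defaults to on
            kwargs.pop("init", None)      # drop it so the backend never sees it
            return start_backend(*args, **kwargs)
        return None


    print(make_logger(project="demo", init=True))  # -> {'project': 'demo'}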
2 changes: 2 additions & 0 deletions ignite/metrics/__init__.py
@@ -22,6 +22,7 @@
 from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance
 from ignite.metrics.mean_squared_error import MeanSquaredError
 from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, MetricUsage
+from ignite.metrics.metric_group import MetricGroup
 from ignite.metrics.metrics_lambda import MetricsLambda
 from ignite.metrics.multilabel_confusion_matrix import MultiLabelConfusionMatrix
 from ignite.metrics.mutual_information import MutualInformation
@@ -41,6 +42,7 @@
     "Metric",
     "Accuracy",
     "Loss",
+    "MetricGroup",
     "MetricsLambda",
     "MeanAbsoluteError",
     "MeanPairwiseDistance",
54 changes: 54 additions & 0 deletions ignite/metrics/metric_group.py
@@ -0,0 +1,54 @@
from typing import Any, Callable, Dict, Sequence

import torch

from ignite.metrics import Metric


class MetricGroup(Metric):
    """
    A class for grouping metrics so that users can manage them more easily.

    Args:
        metrics: a dictionary of names to metric instances.
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. The ``output_transform`` of each metric in the group
            is also called upon its update.

    Examples:
        We construct a group of metrics, attach them to the engine at once and retrieve their result.

        .. code-block:: python

            import torch

            metric_group = MetricGroup({'acc': Accuracy(), 'precision': Precision(), 'loss': Loss(nn.NLLLoss())})
            metric_group.attach(default_evaluator, "eval_metrics")
            y_true = torch.tensor([1, 0, 1, 1, 0, 1])
            y_pred = torch.tensor([1, 0, 1, 0, 1, 1])
            state = default_evaluator.run([[y_pred, y_true]])

            # Metrics individually available in `state.metrics`
            state.metrics["acc"], state.metrics["precision"], state.metrics["loss"]

            # And also altogether
            state.metrics["eval_metrics"]
    """

    _state_dict_all_req_keys = ("metrics",)

    def __init__(self, metrics: Dict[str, Metric], output_transform: Callable = lambda x: x):
        self.metrics = metrics
        super(MetricGroup, self).__init__(output_transform=output_transform)

    def reset(self) -> None:
        for m in self.metrics.values():
            m.reset()

    def update(self, output: Sequence[torch.Tensor]) -> None:
        for m in self.metrics.values():
            m.update(m._output_transform(output))

    def compute(self) -> Dict[str, Any]:
        return {k: m.compute() for k, m in self.metrics.items()}
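For quick reference, a minimal standalone sketch of the new API (an illustration, not part of the commit): it drives MetricGroup directly through update()/compute() rather than attaching it to an engine, and assumes a build of ignite that includes this file.

    # Minimal standalone usage sketch; requires an ignite build containing MetricGroup.
    import torch
    from ignite.metrics import Accuracy, MetricGroup, Precision

    group = MetricGroup({"acc": Accuracy(), "precision": Precision()})

    y_pred = torch.tensor([1, 0, 1, 0, 1, 1])
    y_true = torch.tensor([1, 0, 1, 1, 0, 1])

    group.update((y_pred, y_true))        # each member metric is updated in turn
    print(group.compute())                # e.g. {'acc': ~0.667, 'precision': ~0.75}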
2 changes: 1 addition & 1 deletion requirements-dev.txt
@@ -27,7 +27,7 @@ scikit-image
 py-rouge
 # temporary fix for python=3.12 and v3.8.1
 # nltk
-git+https://github.com/nltk/nltk
+git+https://github.com/nltk/nltk@aba99c8
 # Examples dependencies
 pandas
 gymnasium
118 changes: 118 additions & 0 deletions tests/ignite/metrics/test_metric_group.py
@@ -0,0 +1,118 @@
import pytest
import torch

from ignite import distributed as idist
from ignite.engine import Engine
from ignite.metrics import Accuracy, MetricGroup, Precision

torch.manual_seed(41)


def test_update():
    precision = Precision()
    accuracy = Accuracy()

    group = MetricGroup({"precision": Precision(), "accuracy": Accuracy()})

    y_pred = torch.randint(0, 2, (100,))
    y = torch.randint(0, 2, (100,))

    precision.update((y_pred, y))
    accuracy.update((y_pred, y))
    group.update((y_pred, y))

    assert precision.state_dict() == group.metrics["precision"].state_dict()
    assert accuracy.state_dict() == group.metrics["accuracy"].state_dict()


def test_output_transform():
    def drop_first(output):
        y_pred, y = output
        return (y_pred[1:], y[1:])

    precision = Precision(output_transform=drop_first)
    accuracy = Accuracy(output_transform=drop_first)

    group = MetricGroup(
        {"precision": Precision(output_transform=drop_first), "accuracy": Accuracy(output_transform=drop_first)}
    )

    y_pred = torch.randint(0, 2, (100,))
    y = torch.randint(0, 2, (100,))

    precision.update(drop_first(drop_first((y_pred, y))))
    accuracy.update(drop_first(drop_first((y_pred, y))))
    group.update(drop_first((y_pred, y)))

    assert precision.state_dict() == group.metrics["precision"].state_dict()
    assert accuracy.state_dict() == group.metrics["accuracy"].state_dict()


def test_compute():
    precision = Precision()
    accuracy = Accuracy()

    group = MetricGroup({"precision": Precision(), "accuracy": Accuracy()})

    for _ in range(3):
        y_pred = torch.randint(0, 2, (100,))
        y = torch.randint(0, 2, (100,))

        precision.update((y_pred, y))
        accuracy.update((y_pred, y))
        group.update((y_pred, y))

    assert group.compute() == {"precision": precision.compute(), "accuracy": accuracy.compute()}

    precision.reset()
    accuracy.reset()
    group.reset()

    assert precision.state_dict() == group.metrics["precision"].state_dict()
    assert accuracy.state_dict() == group.metrics["accuracy"].state_dict()


@pytest.mark.usefixtures("distributed")
class TestDistributed:
    def test_integration(self):
        rank = idist.get_rank()
        torch.manual_seed(12 + rank)

        n_epochs = 3
        n_iters = 5
        batch_size = 10
        device = idist.device()

        y_true = torch.randint(0, 2, size=(n_iters * batch_size,)).to(device)
        y_pred = torch.randint(0, 2, (n_iters * batch_size,)).to(device)

        def update(_, i):
            return (
                y_pred[i * batch_size : (i + 1) * batch_size],
                y_true[i * batch_size : (i + 1) * batch_size],
            )

        engine = Engine(update)

        precision = Precision()
        precision.attach(engine, "precision")

        accuracy = Accuracy()
        accuracy.attach(engine, "accuracy")

        group = MetricGroup({"eval_metrics.accuracy": Accuracy(), "eval_metrics.precision": Precision()})
        group.attach(engine, "eval_metrics")

        data = list(range(n_iters))
        engine.run(data=data, max_epochs=n_epochs)

        assert "eval_metrics" in engine.state.metrics
        assert "eval_metrics.accuracy" in engine.state.metrics
        assert "eval_metrics.precision" in engine.state.metrics

        assert engine.state.metrics["eval_metrics"] == {
            "eval_metrics.accuracy": engine.state.metrics["accuracy"],
            "eval_metrics.precision": engine.state.metrics["precision"],
        }
        assert engine.state.metrics["eval_metrics.accuracy"] == engine.state.metrics["accuracy"]
        assert engine.state.metrics["eval_metrics.precision"] == engine.state.metrics["precision"]
