Skip to content

Commit

Permalink
Merge branch 'master' into ci/PT-2.6
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Dec 14, 2024
2 parents c6e2224 + cd24d2b commit 44ae858
Show file tree
Hide file tree
Showing 13 changed files with 55 additions and 26 deletions.
12 changes: 6 additions & 6 deletions .github/workflows/ci-checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,19 @@ concurrency:

jobs:
check-code:
uses: Lightning-AI/utilities/.github/workflows/[email protected].8
uses: Lightning-AI/utilities/.github/workflows/[email protected].9
with:
actions-ref: v0.11.8
actions-ref: v0.11.9
extra-typing: "typing"

check-schema:
uses: Lightning-AI/utilities/.github/workflows/[email protected].8
uses: Lightning-AI/utilities/.github/workflows/[email protected].9

check-package:
if: github.event.pull_request.draft == false
uses: Lightning-AI/utilities/.github/workflows/[email protected].8
uses: Lightning-AI/utilities/.github/workflows/[email protected].9
with:
actions-ref: v0.11.8
actions-ref: v0.11.9
artifact-name: dist-packages-${{ github.sha }}
import-name: "torchmetrics"
testing-matrix: |
Expand All @@ -35,7 +35,7 @@ jobs:
}
check-md-links:
uses: Lightning-AI/utilities/.github/workflows/[email protected].8
uses: Lightning-AI/utilities/.github/workflows/[email protected].9
with:
base-branch: master
config-file: ".github/markdown-links-config.json"
2 changes: 1 addition & 1 deletion .github/workflows/ci-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ jobs:
- name: Upload coverage to Codecov
# skip for PR if there is nothing to test, note that outside PR there is default 'unittests'
if: ${{ env.TEST_DIRS != '' }}
uses: codecov/codecov-action@v4
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
file: tests/coverage.xml
Expand Down
8 changes: 4 additions & 4 deletions .github/workflows/clear-cache.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,18 @@ on:
jobs:
cron-clear:
if: github.event_name == 'schedule' || github.event_name == 'pull_request'
uses: Lightning-AI/utilities/.github/workflows/[email protected].8
uses: Lightning-AI/utilities/.github/workflows/[email protected].9
with:
scripts-ref: v0.11.7
scripts-ref: v0.11.9
dry-run: ${{ github.event_name == 'pull_request' }}
pattern: "pip-latest"
age-days: 7

direct-clear:
if: github.event_name == 'workflow_dispatch' || github.event_name == 'pull_request'
uses: Lightning-AI/utilities/.github/workflows/[email protected].8
uses: Lightning-AI/utilities/.github/workflows/[email protected].9
with:
scripts-ref: v0.11.8
scripts-ref: v0.11.9
dry-run: ${{ github.event_name == 'pull_request' }}
pattern: ${{ inputs.pattern || 'pypi_wheels' }} # setting str in case of PR / debugging
age-days: ${{ fromJSON(inputs.age-days) || 0 }} # setting 0 in case of PR / debugging
4 changes: 2 additions & 2 deletions .github/workflows/publish-pkg.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ jobs:
- run: ls -lh dist/
# We do this, since failures on test.pypi aren't that bad
- name: Publish to Test PyPI
uses: pypa/gh-action-pypi-publish@v1.11.0
uses: pypa/gh-action-pypi-publish@v1.12.2
with:
user: __token__
password: ${{ secrets.test_pypi_password }}
Expand All @@ -94,7 +94,7 @@ jobs:
path: dist
- run: ls -lh dist/
- name: Publish distribution 📦 to PyPI
uses: pypa/gh-action-pypi-publish@v1.11.0
uses: pypa/gh-action-pypi-publish@v1.12.2
with:
user: __token__
password: ${{ secrets.pypi_password }}
Expand Down
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

-
- Fixed issue with shared state in metric collection when using dice score ([#2848](https://github.com/PyTorchLightning/metrics/pull/2848))


---
Expand Down
2 changes: 1 addition & 1 deletion requirements/_docs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ sphinx-autobuild ==2024.10.3
sphinx-gallery ==0.18.0

lightning >=1.8.0, <2.5.0
lightning-utilities ==0.11.8
lightning-utilities ==0.11.9
pydantic > 1.0.0, < 3.0.0

# integrations
Expand Down
4 changes: 2 additions & 2 deletions requirements/_doctest.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

pytest >=8.0, <9.0
pytest-doctestplus >=1.0, <1.3
pytest-rerunfailures >=13.0, <15.0
pytest-doctestplus >=1.0, <1.4
pytest-rerunfailures >=13.0, <16.0
6 changes: 3 additions & 3 deletions requirements/_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ coverage ==7.6.*
codecov ==2.1.13
pytest ==8.3.*
pytest-cov ==6.0.0
pytest-doctestplus ==1.2.1
pytest-rerunfailures ==14.0
pytest-doctestplus ==1.3.0
pytest-rerunfailures ==15.0
pytest-timeout ==2.3.1
pytest-xdist ==3.6.1
phmdoctest ==1.4.0
Expand All @@ -18,5 +18,5 @@ fire ==0.7.*

cloudpickle >1.3, <=3.1.0
scikit-learn ==1.2.*; python_version < "3.9"
scikit-learn ==1.5.*; python_version > "3.8" # we do not use `> =` because of oldest replcement
scikit-learn ==1.6.*; python_version > "3.8" # we do not use `> =` because of oldest replcement
cachier ==3.1.2
2 changes: 1 addition & 1 deletion requirements/multimodal.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

transformers >=4.42.3, <4.47.0
transformers >=4.42.3, <4.48.0
piq <=0.8.0
2 changes: 1 addition & 1 deletion requirements/text.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
nltk >3.8.1, <=3.9.1
tqdm <4.68.0
regex >=2021.9.24, <=2024.11.6
transformers >4.4.0, <4.47.0
transformers >4.4.0, <4.48.0
mecab-python3 >=1.0.6, <1.1.0
ipadic >=1.0.0, <1.1.0
sentencepiece >=0.2.0, <0.3.0
4 changes: 2 additions & 2 deletions src/torchmetrics/functional/classification/stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def _multiclass_stat_scores_tensor_validation(
)
if multidim_average != "global" and preds.ndim < 3:
raise ValueError(
"If `preds` have one dimension more than `target`, the shape of `preds` should "
"If `preds` have one dimension more than `target`, the shape of `preds` should be"
" at least 3D when multidim_average is set to `samplewise`"
)

Expand All @@ -303,7 +303,7 @@ def _multiclass_stat_scores_tensor_validation(
)
if multidim_average != "global" and preds.ndim < 2:
raise ValueError(
"When `preds` and `target` have the same shape, the shape of `preds` should "
"When `preds` and `target` have the same shape, the shape of `preds` should be"
" at least 2D when multidim_average is set to `samplewise`"
)
else:
Expand Down
3 changes: 1 addition & 2 deletions src/torchmetrics/segmentation/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,7 @@ def update(self, preds: Tensor, target: Tensor) -> None:
)
self.numerator.append(numerator)
self.denominator.append(denominator)
if self.average == "weighted":
self.support.append(support)
self.support.append(support)

def compute(self) -> Tensor:
"""Computes the Dice Score."""
Expand Down
30 changes: 30 additions & 0 deletions tests/unittests/segmentation/test_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pytest
import torch
from sklearn.metrics import f1_score
from torchmetrics import MetricCollection
from torchmetrics.functional.segmentation.dice import dice_score
from torchmetrics.segmentation.dice import DiceScore

Expand Down Expand Up @@ -106,3 +107,32 @@ def test_dice_score_functional(self, preds, target, input_format, include_backgr
"input_format": input_format,
},
)


@pytest.mark.parametrize("compute_groups", [True, False])
def test_dice_score_metric_collection(compute_groups: bool, num_batches: int = 4):
"""Test that the metric works within a metric collection with and without compute groups."""
metric_collection = MetricCollection(
metrics={
"DiceScore (micro)": DiceScore(
num_classes=NUM_CLASSES,
average="micro",
),
"DiceScore (macro)": DiceScore(
num_classes=NUM_CLASSES,
average="macro",
),
"DiceScore (weighted)": DiceScore(
num_classes=NUM_CLASSES,
average="weighted",
),
},
compute_groups=compute_groups,
)

for _ in range(num_batches):
metric_collection.update(_inputs1.preds, _inputs1.target)
result = metric_collection.compute()

assert isinstance(result, dict)
assert len(set(metric_collection.keys()) - set(result.keys())) == 0

0 comments on commit 44ae858

Please sign in to comment.