Skip to content

Commit

Permalink
manual fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Sep 22, 2023
1 parent d416473 commit 4145c59
Show file tree
Hide file tree
Showing 8 changed files with 18 additions and 19 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,11 @@ repos:
name: Upgrade code

- repo: https://github.com/codespell-project/codespell
rev: v2.2.4
rev: v2.2.5
hooks:
- id: codespell
additional_dependencies: [tomli]
args: ["--write-changes"]

- repo: https://github.com/PyCQA/docformatter
rev: v1.7.5
Expand Down
8 changes: 4 additions & 4 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,14 @@

def _set_root_image_path(page_path: str):
"""Set relative path to be from the root, drop all `../` in images used gallery."""
with open(page_path, encoding="UTF-8") as fo:
body = fo.read()
with open(page_path, encoding="UTF-8") as fopen:
body = fopen.read()
found = re.findall(r" :image: (.*)\.svg", body)
for occur in found:
occur_ = occur.replace("../", "")
body = body.replace(occur, occur_)
with open(page_path, "w", encoding="UTF-8") as fo:
fo.write(body)
with open(page_path, "w", encoding="UTF-8") as fopen:
fopen.write(body)


if SPHINX_FETCH_ASSETS:
Expand Down
6 changes: 2 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ addopts = [
"--color=yes",
"--disable-pytest-warnings",
]
# ToDo
#filterwarnings = ["error::FutureWarning"]
#filterwarnings = ["error::FutureWarning"] # ToDo
xfail_strict = true
junit_duration_report = "call"

Expand Down Expand Up @@ -48,8 +47,7 @@ blank = true
#skip = '*.py'
# comma separated list of words; waiting for:
# https://github.com/codespell-project/codespell/issues/2839#issuecomment-1731601603
ignore-words-list = "ROUGE, MAPE, fpr"
count = ''
ignore-words-list = "rouge, mape, wil, fpr"
quiet-level = 3


Expand Down
6 changes: 3 additions & 3 deletions src/torchmetrics/detection/_mean_ap.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,13 +827,13 @@ def __calculate_recall_precision_scores(
tp_sum = _cumsum(tps, dim=1, dtype=torch.float)
fp_sum = _cumsum(fps, dim=1, dtype=torch.float)
for idx, (tp, fp) in enumerate(zip(tp_sum, fp_sum)):
nd = len(tp)
tp_len = len(tp)
rc = tp / npig
pr = tp / (fp + tp + torch.finfo(torch.float64).eps)
prec = torch.zeros((num_rec_thrs,))
score = torch.zeros((num_rec_thrs,))

recall[idx, idx_cls, idx_bbox_area, idx_max_det_thrs] = rc[-1] if nd else 0
recall[idx, idx_cls, idx_bbox_area, idx_max_det_thrs] = rc[-1] if tp_len else 0

# Remove zigzags for AUC
diff_zero = torch.zeros((1,), device=pr.device)
Expand All @@ -843,7 +843,7 @@ def __calculate_recall_precision_scores(
pr += diff

inds = torch.searchsorted(rc, rec_thresholds.to(rc.device), right=False)
num_inds = inds.argmax() if inds.max() >= nd else num_rec_thrs
num_inds = inds.argmax() if inds.max() >= tp_len else num_rec_thrs
inds = inds[:num_inds]
prec[:num_inds] = pr[inds]
score[:num_inds] = det_scores_sorted[inds]
Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/retrieval/_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ class _RetrievalFallOut(RetrievalFallOut):
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> fo = _RetrievalFallOut(top_k=2)
>>> fo(preds, target, indexes=indexes)
>>> rfo = _RetrievalFallOut(top_k=2)
>>> rfo(preds, target, indexes=indexes)
tensor(0.5000)
"""
Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/retrieval/fall_out.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ class RetrievalFallOut(RetrievalMetric):
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> fo = RetrievalFallOut(top_k=2)
>>> fo(preds, target, indexes=indexes)
>>> rfo = RetrievalFallOut(top_k=2)
>>> rfo(preds, target, indexes=indexes)
tensor(0.5000)
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/unittests/bases/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_error_on_wrong_input():
with pytest.raises(ValueError, match="Expected keyword argument `dist_sync_fn` to be an callable function.*"):
DummyMetric(dist_sync_fn=[2, 3])

with pytest.raises(ValueError, match="Expected keyword argument `compute_on_cpu` to be an `bool` bu.*"):
with pytest.raises(ValueError, match="Expected keyword argument `compute_on_cpu` to be an `bool` but.*"):
DummyMetric(compute_on_cpu=None)

with pytest.raises(ValueError, match="Expected keyword argument `sync_on_compute` to be a `bool` but.*"):
Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/helpers/testers.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,9 +320,9 @@ def _assert_dtype_support(
class MetricTester:
"""Test class for all metrics.
Class used for efficiently run alot of parametrized tests in ddp mode. Makes sure that ddp is only setup once and
Class used for efficiently run alot of parametrized tests in DDP mode. Makes sure that DDP is only setup once and
that pool of processes are used for all tests. All tests should subclass from this and implement a new method called
`test_metric_name` where the method `self.run_metric_test` is called inside.
``test_metric_name`` where the method ``self.run_metric_test`` is called inside.
"""

Expand Down

0 comments on commit 4145c59

Please sign in to comment.