From 4145c59a51b3231355d43c59b3684cf1e70c2492 Mon Sep 17 00:00:00 2001 From: Jirka Date: Fri, 22 Sep 2023 22:38:22 +0200 Subject: [PATCH] manual fixes --- .pre-commit-config.yaml | 3 ++- docs/source/conf.py | 8 ++++---- pyproject.toml | 6 ++---- src/torchmetrics/detection/_mean_ap.py | 6 +++--- src/torchmetrics/retrieval/_deprecated.py | 4 ++-- src/torchmetrics/retrieval/fall_out.py | 4 ++-- tests/unittests/bases/test_metric.py | 2 +- tests/unittests/helpers/testers.py | 4 ++-- 8 files changed, 18 insertions(+), 19 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ae2a85c8969..63fcf1bdf80 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -45,10 +45,11 @@ repos: name: Upgrade code - repo: https://github.com/codespell-project/codespell - rev: v2.2.4 + rev: v2.2.5 hooks: - id: codespell additional_dependencies: [tomli] + args: ["--write-changes"] - repo: https://github.com/PyCQA/docformatter rev: v1.7.5 diff --git a/docs/source/conf.py b/docs/source/conf.py index d8dde4988f6..3f116bd77d3 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -67,14 +67,14 @@ def _set_root_image_path(page_path: str): """Set relative path to be from the root, drop all `../` in images used gallery.""" - with open(page_path, encoding="UTF-8") as fo: - body = fo.read() + with open(page_path, encoding="UTF-8") as fopen: + body = fopen.read() found = re.findall(r" :image: (.*)\.svg", body) for occur in found: occur_ = occur.replace("../", "") body = body.replace(occur, occur_) - with open(page_path, "w", encoding="UTF-8") as fo: - fo.write(body) + with open(page_path, "w", encoding="UTF-8") as fopen: + fopen.write(body) if SPHINX_FETCH_ASSETS: diff --git a/pyproject.toml b/pyproject.toml index 2c3c71ffed0..cfff089737b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,8 +19,7 @@ addopts = [ "--color=yes", "--disable-pytest-warnings", ] -# ToDo -#filterwarnings = ["error::FutureWarning"] +#filterwarnings = ["error::FutureWarning"] # ToDo xfail_strict = true junit_duration_report = "call" @@ -48,8 +47,7 @@ blank = true #skip = '*.py' # comma separated list of words; waiting for: # https://github.com/codespell-project/codespell/issues/2839#issuecomment-1731601603 -ignore-words-list = "ROUGE, MAPE, fpr" -count = '' +ignore-words-list = "rouge, mape, wil, fpr" quiet-level = 3 diff --git a/src/torchmetrics/detection/_mean_ap.py b/src/torchmetrics/detection/_mean_ap.py index bd9717ff73a..7ff44165b8c 100644 --- a/src/torchmetrics/detection/_mean_ap.py +++ b/src/torchmetrics/detection/_mean_ap.py @@ -827,13 +827,13 @@ def __calculate_recall_precision_scores( tp_sum = _cumsum(tps, dim=1, dtype=torch.float) fp_sum = _cumsum(fps, dim=1, dtype=torch.float) for idx, (tp, fp) in enumerate(zip(tp_sum, fp_sum)): - nd = len(tp) + tp_len = len(tp) rc = tp / npig pr = tp / (fp + tp + torch.finfo(torch.float64).eps) prec = torch.zeros((num_rec_thrs,)) score = torch.zeros((num_rec_thrs,)) - recall[idx, idx_cls, idx_bbox_area, idx_max_det_thrs] = rc[-1] if nd else 0 + recall[idx, idx_cls, idx_bbox_area, idx_max_det_thrs] = rc[-1] if tp_len else 0 # Remove zigzags for AUC diff_zero = torch.zeros((1,), device=pr.device) @@ -843,7 +843,7 @@ def __calculate_recall_precision_scores( pr += diff inds = torch.searchsorted(rc, rec_thresholds.to(rc.device), right=False) - num_inds = inds.argmax() if inds.max() >= nd else num_rec_thrs + num_inds = inds.argmax() if inds.max() >= tp_len else num_rec_thrs inds = inds[:num_inds] prec[:num_inds] = pr[inds] score[:num_inds] = det_scores_sorted[inds] diff --git a/src/torchmetrics/retrieval/_deprecated.py b/src/torchmetrics/retrieval/_deprecated.py index 5e4e9ca4c11..45bff9431f2 100644 --- a/src/torchmetrics/retrieval/_deprecated.py +++ b/src/torchmetrics/retrieval/_deprecated.py @@ -19,8 +19,8 @@ class _RetrievalFallOut(RetrievalFallOut): >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1]) >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]) >>> target = tensor([False, False, True, False, True, False, True]) - >>> fo = _RetrievalFallOut(top_k=2) - >>> fo(preds, target, indexes=indexes) + >>> rfo = _RetrievalFallOut(top_k=2) + >>> rfo(preds, target, indexes=indexes) tensor(0.5000) """ diff --git a/src/torchmetrics/retrieval/fall_out.py b/src/torchmetrics/retrieval/fall_out.py index 7c4f031e1e6..5df0a877bd3 100644 --- a/src/torchmetrics/retrieval/fall_out.py +++ b/src/torchmetrics/retrieval/fall_out.py @@ -72,8 +72,8 @@ class RetrievalFallOut(RetrievalMetric): >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1]) >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]) >>> target = tensor([False, False, True, False, True, False, True]) - >>> fo = RetrievalFallOut(top_k=2) - >>> fo(preds, target, indexes=indexes) + >>> rfo = RetrievalFallOut(top_k=2) + >>> rfo(preds, target, indexes=indexes) tensor(0.5000) """ diff --git a/tests/unittests/bases/test_metric.py b/tests/unittests/bases/test_metric.py index 5e6a1c2e979..c9d1b8bd5ea 100644 --- a/tests/unittests/bases/test_metric.py +++ b/tests/unittests/bases/test_metric.py @@ -42,7 +42,7 @@ def test_error_on_wrong_input(): with pytest.raises(ValueError, match="Expected keyword argument `dist_sync_fn` to be an callable function.*"): DummyMetric(dist_sync_fn=[2, 3]) - with pytest.raises(ValueError, match="Expected keyword argument `compute_on_cpu` to be an `bool` bu.*"): + with pytest.raises(ValueError, match="Expected keyword argument `compute_on_cpu` to be an `bool` but.*"): DummyMetric(compute_on_cpu=None) with pytest.raises(ValueError, match="Expected keyword argument `sync_on_compute` to be a `bool` but.*"): diff --git a/tests/unittests/helpers/testers.py b/tests/unittests/helpers/testers.py index 3740e7bf335..5954e4d0801 100644 --- a/tests/unittests/helpers/testers.py +++ b/tests/unittests/helpers/testers.py @@ -320,9 +320,9 @@ def _assert_dtype_support( class MetricTester: """Test class for all metrics. - Class used for efficiently run alot of parametrized tests in ddp mode. Makes sure that ddp is only setup once and + Class used for efficiently run alot of parametrized tests in DDP mode. Makes sure that DDP is only setup once and that pool of processes are used for all tests. All tests should subclass from this and implement a new method called - `test_metric_name` where the method `self.run_metric_test` is called inside. + ``test_metric_name`` where the method ``self.run_metric_test`` is called inside. """