Skip to content

Commit

Permalink
check typos & fixing some... (#2102)
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Sep 26, 2023
1 parent 9aecaf4 commit 1ddacab
Show file tree
Hide file tree
Showing 14 changed files with 52 additions and 22 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci-checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ on:
branches: [master, "release/*"]
pull_request:
branches: [master, "release/*"]
types: [opened, reopened, ready_for_review, synchronize]

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/docs-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ on:
tags: ["*"]
pull_request:
branches: ["master", "release/*"]
types: [opened, reopened, ready_for_review, synchronize]

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
Expand Down
9 changes: 9 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,15 @@ repos:
- id: codespell
additional_dependencies: [tomli]
#args: ["--write-changes"]
exclude: pyproject.toml

- repo: https://github.com/crate-ci/typos
rev: v1.16.12
hooks:
- id: typos
# empty to do not write fixes
args: []
exclude: pyproject.toml

- repo: https://github.com/PyCQA/docformatter
rev: v1.7.5
Expand Down
19 changes: 19 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,25 @@ ignore-words-list = """
archiv
"""

[tool.typos.default]
extend-ignore-identifiers-re = [
# *sigh* this just isn't worth the cost of fixing
"AttributeID.*Supress.*",
]

[tool.typos.default.extend-identifiers]
# *sigh* this just isn't worth the cost of fixing
MAPE = "MAPE"
WIL = "WIL"
Raison = "Raison"

[tool.typos.default.extend-words]
# Don't correct the surname "Teh"
fpr = "fpr"
mape = "mape"
wil = "wil"



[tool.ruff]
line-length = 120
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class BinaryAccuracy(BinaryStatScores):
As output to ``forward`` and ``compute`` the metric returns the following output:
- ``ba`` (:class:`~torch.Tensor`): If ``multidim_average`` is set to ``global``, metric returns a scalar value.
- ``acc`` (:class:`~torch.Tensor`): If ``multidim_average`` is set to ``global``, metric returns a scalar value.
If ``multidim_average`` is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar
value per sample.
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/regression/tweedie_deviance.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _tweedie_deviance_score_compute(sum_deviance_score: Tensor, num_observations
"""Compute Deviance Score.
Args:
sum_deviance_score: Sum of deviance scores accumalated until now.
sum_deviance_score: Sum of deviance scores accumulated until now.
num_observations: Number of observations encountered until now.
Example:
Expand Down
8 changes: 4 additions & 4 deletions src/torchmetrics/functional/text/chrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def _get_characters(sentence: str, whitespace: bool) -> List[str]:
return list(sentence.strip().replace(" ", ""))


def _separate_word_and_punctiation(word: str) -> List[str]:
def _separate_word_and_punctuation(word: str) -> List[str]:
"""Separates out punctuations from beginning and end of words for chrF.
Adapted from https://github.com/m-popovic/chrF and
Expand All @@ -117,7 +117,7 @@ def _separate_word_and_punctiation(word: str) -> List[str]:
return [word]


def _get_words_and_punctiation(sentence: str) -> List[str]:
def _get_words_and_punctuation(sentence: str) -> List[str]:
"""Separates out punctuations from beginning and end of words for chrF for all words in the sentence.
Args:
Expand All @@ -127,7 +127,7 @@ def _get_words_and_punctiation(sentence: str) -> List[str]:
An aggregated list of separated words and punctuations.
"""
return sum((_separate_word_and_punctiation(word) for word in sentence.strip().split()), [])
return sum((_separate_word_and_punctuation(word) for word in sentence.strip().split()), [])


def _ngram_counts(char_or_word_list: List[str], n_gram_order: int) -> Dict[int, Dict[Tuple[str, ...], Tensor]]:
Expand Down Expand Up @@ -180,7 +180,7 @@ def _char_and_word_ngrams_counts(
if lowercase:
sentence = sentence.lower()
char_n_grams_counts = _ngram_counts(_get_characters(sentence, whitespace), n_char_order)
word_n_grams_counts = _ngram_counts(_get_words_and_punctiation(sentence), n_word_order)
word_n_grams_counts = _ngram_counts(_get_words_and_punctuation(sentence), n_word_order)
return char_n_grams_counts, word_n_grams_counts

def _get_total_ngrams(n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]]) -> Dict[int, Tensor]:
Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/functional/text/infolm.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,9 @@ def __init__(
self.alpha = alpha or 0
self.beta = beta or 0

def __call__(self, preds_distribution: Tensor, target_distribtuion: Tensor) -> Tensor:
def __call__(self, preds_distribution: Tensor, target_distribution: Tensor) -> Tensor:
information_measure_function = getattr(self, f"_calculate_{self.information_measure.value}")
return torch.nan_to_num(information_measure_function(preds_distribution, target_distribtuion))
return torch.nan_to_num(information_measure_function(preds_distribution, target_distribution))

@staticmethod
def _calculate_kl_divergence(preds_distribution: Tensor, target_distribution: Tensor) -> Tensor:
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/text/ter.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def _normalize_general_and_western(sentence: str) -> str:
(r">", ">"),
# tokenize punctuation
(r"([{-~[-` -&(-+:-@/])", r" \1 "),
# handle possesives
# handle possessive
(r"'s ", r" 's "),
(r"'s$", r" 's"),
# tokenize period and comma unless preceded by a digit
Expand Down
10 changes: 5 additions & 5 deletions src/torchmetrics/functional/text/wil.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
from torchmetrics.functional.text.helper import _edit_distance


def _wil_update(
def _word_info_lost_update(
preds: Union[str, List[str]],
target: Union[str, List[str]],
) -> Tuple[Tensor, Tensor, Tensor]:
"""Update the wil score with the current set of references and predictions.
"""Update the WIL score with the current set of references and predictions.
Args:
preds: Transcription(s) to score as a string or list of strings
Expand Down Expand Up @@ -54,7 +54,7 @@ def _wil_update(
return errors - total, target_total, preds_total


def _wil_compute(errors: Tensor, target_total: Tensor, preds_total: Tensor) -> Tensor:
def _word_info_lost_compute(errors: Tensor, target_total: Tensor, preds_total: Tensor) -> Tensor:
"""Compute the Word Information Lost.
Args:
Expand Down Expand Up @@ -90,5 +90,5 @@ def word_information_lost(preds: Union[str, List[str]], target: Union[str, List[
tensor(0.6528)
"""
errors, target_total, preds_total = _wil_update(preds, target)
return _wil_compute(errors, target_total, preds_total)
errors, target_total, preds_total = _word_info_lost_update(preds, target)
return _word_info_lost_compute(errors, target_total, preds_total)
6 changes: 3 additions & 3 deletions src/torchmetrics/image/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,13 +313,13 @@ def __init__(
raise ValueError("Argument `normalize` expected to be a bool")
self.normalize = normalize

mx_num_feets = (num_features, num_features)
mx_num_feats = (num_features, num_features)
self.add_state("real_features_sum", torch.zeros(num_features).double(), dist_reduce_fx="sum")
self.add_state("real_features_cov_sum", torch.zeros(mx_num_feets).double(), dist_reduce_fx="sum")
self.add_state("real_features_cov_sum", torch.zeros(mx_num_feats).double(), dist_reduce_fx="sum")
self.add_state("real_features_num_samples", torch.tensor(0).long(), dist_reduce_fx="sum")

self.add_state("fake_features_sum", torch.zeros(num_features).double(), dist_reduce_fx="sum")
self.add_state("fake_features_cov_sum", torch.zeros(mx_num_feets).double(), dist_reduce_fx="sum")
self.add_state("fake_features_cov_sum", torch.zeros(mx_num_feats).double(), dist_reduce_fx="sum")
self.add_state("fake_features_num_samples", torch.tensor(0).long(), dist_reduce_fx="sum")

def update(self, imgs: Tensor, real: bool) -> None:
Expand Down
6 changes: 3 additions & 3 deletions src/torchmetrics/text/wil.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from torch import Tensor, tensor

from torchmetrics.functional.text.wil import _wil_compute, _wil_update
from torchmetrics.functional.text.wil import _word_info_lost_compute, _word_info_lost_update
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
Expand Down Expand Up @@ -82,14 +82,14 @@ def __init__(

def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None:
"""Update state with predictions and targets."""
errors, target_total, preds_total = _wil_update(preds, target)
errors, target_total, preds_total = _word_info_lost_update(preds, target)
self.errors += errors
self.target_total += target_total
self.preds_total += preds_total

def compute(self) -> Tensor:
"""Calculate the Word Information Lost."""
return _wil_compute(self.errors, self.target_total, self.preds_total)
return _word_info_lost_compute(self.errors, self.target_total, self.preds_total)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
Expand Down
2 changes: 1 addition & 1 deletion tests/unittests/bases/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ def test_compute_on_different_dtype():


def test_error_on_wrong_specified_compute_groups():
"""Test that error is raised if user mis-specify the compute groups."""
"""Test that error is raised if user miss-specify the compute groups."""
with pytest.raises(ValueError, match="Input MulticlassAccuracy in `compute_groups`.*"):
MetricCollection(
MulticlassConfusionMatrix(3),
Expand Down
2 changes: 1 addition & 1 deletion tests/unittests/pairwise/test_pairwise_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def test_error_on_wrong_shapes(metric):
(partial(pairwise_minkowski_distance, exponent=3), partial(pairwise_distances, metric="minkowski", p=3)),
],
)
def test_precison_case(metric_functional, sk_fn):
def test_precision_case(metric_functional, sk_fn):
"""Test that metrics are robust towars cases where high precision is needed."""
x = torch.tensor([[772.0, 112.0], [772.20001, 112.0]])
res1 = metric_functional(x, zero_diagonal=False)
Expand Down

0 comments on commit 1ddacab

Please sign in to comment.