check typos & fixing some... #2102

Merged · 10 commits · Sep 26, 2023
1 change: 1 addition & 0 deletions .github/workflows/ci-checks.yml
@@ -5,6 +5,7 @@ on:
branches: [master, "release/*"]
pull_request:
branches: [master, "release/*"]
types: [opened, reopened, ready_for_review, synchronize]

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
1 change: 1 addition & 0 deletions .github/workflows/docs-build.yml
@@ -6,6 +6,7 @@ on:
tags: ["*"]
pull_request:
branches: ["master", "release/*"]
types: [opened, reopened, ready_for_review, synchronize]

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
9 changes: 9 additions & 0 deletions .pre-commit-config.yaml
@@ -50,6 +50,15 @@ repos:
- id: codespell
additional_dependencies: [tomli]
#args: ["--write-changes"]
exclude: pyproject.toml

- repo: https://github.com/crate-ci/typos
rev: v1.16.12
hooks:
- id: typos
# empty to do not write fixes
args: []
exclude: pyproject.toml

- repo: https://github.com/PyCQA/docformatter
rev: v1.7.5
19 changes: 19 additions & 0 deletions pyproject.toml
@@ -60,6 +60,25 @@ ignore-words-list = """
archiv
"""

[tool.typos.default]
extend-ignore-identifiers-re = [
# *sigh* this just isn't worth the cost of fixing
"AttributeID.*Supress.*",
]

[tool.typos.default.extend-identifiers]
# *sigh* this just isn't worth the cost of fixing
MAPE = "MAPE"
WIL = "WIL"
Raison = "Raison"

[tool.typos.default.extend-words]
# Don't correct the surname "Teh"
fpr = "fpr"
mape = "mape"
wil = "wil"



[tool.ruff]
line-length = 120
2 changes: 1 addition & 1 deletion src/torchmetrics/classification/accuracy.py
@@ -45,7 +45,7 @@ class BinaryAccuracy(BinaryStatScores):

As output to ``forward`` and ``compute`` the metric returns the following output:

- ``ba`` (:class:`~torch.Tensor`): If ``multidim_average`` is set to ``global``, metric returns a scalar value.
- ``acc`` (:class:`~torch.Tensor`): If ``multidim_average`` is set to ``global``, metric returns a scalar value.
If ``multidim_average`` is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar
value per sample.

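
For context, a minimal sketch of the samplewise behaviour the corrected ``acc`` docstring describes, using the public BinaryAccuracy class; the input values are illustrative assumptions, not taken from the PR:

    import torch
    from torchmetrics.classification import BinaryAccuracy

    preds = torch.tensor([[0.9, 0.2, 0.7, 0.4], [0.1, 0.8, 0.6, 0.3]])  # probabilities, thresholded at 0.5
    target = torch.tensor([[1, 0, 1, 1], [0, 1, 0, 0]])

    metric = BinaryAccuracy(multidim_average="samplewise")
    acc = metric(preds, target)  # tensor of shape (2,): one accuracy value per sample
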
@@ -88,7 +88,7 @@ def _tweedie_deviance_score_compute(sum_deviance_score: Tensor, num_observations
"""Compute Deviance Score.

Args:
sum_deviance_score: Sum of deviance scores accumalated until now.
sum_deviance_score: Sum of deviance scores accumulated until now.
num_observations: Number of observations encountered until now.

Example:
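
For reference, a minimal usage sketch of the public function whose docstring is corrected above, assuming tweedie_deviance_score is exported from torchmetrics.functional.regression; the inputs are illustrative:

    import torch
    from torchmetrics.functional.regression import tweedie_deviance_score

    preds = torch.tensor([2.0, 0.5, 1.0, 4.0])
    target = torch.tensor([1.0, 0.5, 2.0, 3.0])
    # power=2 corresponds to Gamma deviance; the accumulated sum of deviance scores is averaged internally
    print(tweedie_deviance_score(preds, target, power=2))
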
8 changes: 4 additions & 4 deletions src/torchmetrics/functional/text/chrf.py
@@ -94,7 +94,7 @@ def _get_characters(sentence: str, whitespace: bool) -> List[str]:
return list(sentence.strip().replace(" ", ""))


def _separate_word_and_punctiation(word: str) -> List[str]:
def _separate_word_and_punctuation(word: str) -> List[str]:
"""Separates out punctuations from beginning and end of words for chrF.

Adapted from https://github.com/m-popovic/chrF and
@@ -117,7 +117,7 @@ def _separate_word_and_punctiation(word: str) -> List[str]:
return [word]


def _get_words_and_punctiation(sentence: str) -> List[str]:
def _get_words_and_punctuation(sentence: str) -> List[str]:
"""Separates out punctuations from beginning and end of words for chrF for all words in the sentence.

Args:
@@ -127,7 +127,7 @@ def _get_words_and_punctiation(sentence: str) -> List[str]:
An aggregated list of separated words and punctuations.

"""
return sum((_separate_word_and_punctiation(word) for word in sentence.strip().split()), [])
return sum((_separate_word_and_punctuation(word) for word in sentence.strip().split()), [])


def _ngram_counts(char_or_word_list: List[str], n_gram_order: int) -> Dict[int, Dict[Tuple[str, ...], Tensor]]:
@@ -180,7 +180,7 @@ def _char_and_word_ngrams_counts(
if lowercase:
sentence = sentence.lower()
char_n_grams_counts = _ngram_counts(_get_characters(sentence, whitespace), n_char_order)
word_n_grams_counts = _ngram_counts(_get_words_and_punctiation(sentence), n_word_order)
word_n_grams_counts = _ngram_counts(_get_words_and_punctuation(sentence), n_word_order)
return char_n_grams_counts, word_n_grams_counts

def _get_total_ngrams(n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]]) -> Dict[int, Tensor]:
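
To illustrate what the renamed helpers do, here is a simplified, self-contained sketch of splitting edge punctuation off a word; the function name and logic are illustrative only and are not the library's exact implementation:

    from typing import List

    def split_edge_punctuation(word: str) -> List[str]:
        """Split a single leading or trailing punctuation mark off a word (simplified)."""
        punctuation = ".,!?;:'\""
        if len(word) > 1 and word[0] in punctuation:
            return [word[0], word[1:]]
        if len(word) > 1 and word[-1] in punctuation:
            return [word[:-1], word[-1]]
        return [word]

    print(split_edge_punctuation("mat,"))  # ['mat', ',']
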
4 changes: 2 additions & 2 deletions src/torchmetrics/functional/text/infolm.py
@@ -137,9 +137,9 @@ def __init__(
self.alpha = alpha or 0
self.beta = beta or 0

def __call__(self, preds_distribution: Tensor, target_distribtuion: Tensor) -> Tensor:
def __call__(self, preds_distribution: Tensor, target_distribution: Tensor) -> Tensor:
information_measure_function = getattr(self, f"_calculate_{self.information_measure.value}")
return torch.nan_to_num(information_measure_function(preds_distribution, target_distribtuion))
return torch.nan_to_num(information_measure_function(preds_distribution, target_distribution))

@staticmethod
def _calculate_kl_divergence(preds_distribution: Tensor, target_distribution: Tensor) -> Tensor:
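
As a rough illustration of what the renamed target_distribution parameter feeds into, this sketch computes a KL divergence between two distributions and zeroes out NaNs with torch.nan_to_num; the exact formula and direction used by InfoLM may differ, so treat this as a generic example only:

    import torch

    preds_distribution = torch.tensor([[0.7, 0.2, 0.1]])
    target_distribution = torch.tensor([[0.6, 0.3, 0.1]])

    # KL(target || preds) under one common convention; nan_to_num guards against 0 * log(0) terms
    kl = torch.nan_to_num((target_distribution * (target_distribution / preds_distribution).log()).sum(dim=-1))
    print(kl)
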
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/text/ter.py
@@ -134,7 +134,7 @@ def _normalize_general_and_western(sentence: str) -> str:
(r"&gt;", ">"),
# tokenize punctuation
(r"([{-~[-` -&(-+:-@/])", r" \1 "),
# handle possesives
# handle possessive
(r"'s ", r" 's "),
(r"'s$", r" 's"),
# tokenize period and comma unless preceded by a digit
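
A minimal illustration of the possessive-handling rules from the hunk above, using only the regular expressions shown there; the sample sentence is an assumption:

    import re

    sentence = "the cat's hat is on the mat "
    sentence = re.sub(r"'s ", r" 's ", sentence)  # split "'s" followed by a space
    sentence = re.sub(r"'s$", r" 's", sentence)   # split a trailing "'s"
    print(sentence)  # "the cat 's hat is on the mat "
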
10 changes: 5 additions & 5 deletions src/torchmetrics/functional/text/wil.py
@@ -19,11 +19,11 @@
from torchmetrics.functional.text.helper import _edit_distance


def _wil_update(
def _word_info_lost_update(
preds: Union[str, List[str]],
target: Union[str, List[str]],
) -> Tuple[Tensor, Tensor, Tensor]:
"""Update the wil score with the current set of references and predictions.
"""Update the WIL score with the current set of references and predictions.

Args:
preds: Transcription(s) to score as a string or list of strings
@@ -54,7 +54,7 @@ def _wil_update(
return errors - total, target_total, preds_total


def _wil_compute(errors: Tensor, target_total: Tensor, preds_total: Tensor) -> Tensor:
def _word_info_lost_compute(errors: Tensor, target_total: Tensor, preds_total: Tensor) -> Tensor:
"""Compute the Word Information Lost.

Args:
@@ -90,5 +90,5 @@ def word_information_lost(preds: Union[str, List[str]], target: Union[str, List[
tensor(0.6528)

"""
errors, target_total, preds_total = _wil_update(preds, target)
return _wil_compute(errors, target_total, preds_total)
errors, target_total, preds_total = _word_info_lost_update(preds, target)
return _word_info_lost_compute(errors, target_total, preds_total)
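
For reference, a short usage sketch of the public entry point touched above; the inputs follow the usual documentation example (an assumption), and the import path matches the file shown in this hunk:

    from torchmetrics.functional.text.wil import word_information_lost

    preds = ["this is the prediction", "there is an other sample"]
    target = ["this is the reference", "there is another one"]
    # Expected to print tensor(0.6528) if these inputs match the docstring example quoted above
    print(word_information_lost(preds, target))
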
6 changes: 3 additions & 3 deletions src/torchmetrics/image/fid.py
@@ -313,13 +313,13 @@ def __init__(
raise ValueError("Argument `normalize` expected to be a bool")
self.normalize = normalize

mx_num_feets = (num_features, num_features)
mx_num_feats = (num_features, num_features)
self.add_state("real_features_sum", torch.zeros(num_features).double(), dist_reduce_fx="sum")
self.add_state("real_features_cov_sum", torch.zeros(mx_num_feets).double(), dist_reduce_fx="sum")
self.add_state("real_features_cov_sum", torch.zeros(mx_num_feats).double(), dist_reduce_fx="sum")
self.add_state("real_features_num_samples", torch.tensor(0).long(), dist_reduce_fx="sum")

self.add_state("fake_features_sum", torch.zeros(num_features).double(), dist_reduce_fx="sum")
self.add_state("fake_features_cov_sum", torch.zeros(mx_num_feets).double(), dist_reduce_fx="sum")
self.add_state("fake_features_cov_sum", torch.zeros(mx_num_feats).double(), dist_reduce_fx="sum")
self.add_state("fake_features_num_samples", torch.tensor(0).long(), dist_reduce_fx="sum")

def update(self, imgs: Tensor, real: bool) -> None:
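
A minimal sketch of how the renamed mx_num_feats tuple shows up in the metric's state, assuming torch-fidelity is installed (InceptionV3 weights are downloaded on first use) and that feature=64 is a supported feature size:

    from torchmetrics.image.fid import FrechetInceptionDistance

    fid = FrechetInceptionDistance(feature=64, normalize=True)
    # The covariance-sum buffers use the (num_features, num_features) shape built above
    print(fid.real_features_cov_sum.shape)  # torch.Size([64, 64])
    print(fid.fake_features_cov_sum.shape)  # torch.Size([64, 64])
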
6 changes: 3 additions & 3 deletions src/torchmetrics/text/wil.py
@@ -15,7 +15,7 @@

from torch import Tensor, tensor

from torchmetrics.functional.text.wil import _wil_compute, _wil_update
from torchmetrics.functional.text.wil import _word_info_lost_compute, _word_info_lost_update
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
@@ -82,14 +82,14 @@ def __init__(

def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None:
"""Update state with predictions and targets."""
errors, target_total, preds_total = _wil_update(preds, target)
errors, target_total, preds_total = _word_info_lost_update(preds, target)
self.errors += errors
self.target_total += target_total
self.preds_total += preds_total

def compute(self) -> Tensor:
"""Calculate the Word Information Lost."""
return _wil_compute(self.errors, self.target_total, self.preds_total)
return _word_info_lost_compute(self.errors, self.target_total, self.preds_total)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
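
A short sketch of the module-based metric whose update and compute paths now call the renamed helpers; the inputs are illustrative and the class name WordInfoLost is assumed from the torchmetrics.text namespace:

    from torchmetrics.text import WordInfoLost

    preds = ["this is the prediction", "there is an other sample"]
    target = ["this is the reference", "there is another one"]

    wil = WordInfoLost()
    wil.update(preds, target)  # accumulates errors, target_total and preds_total
    print(wil.compute())       # Word Information Lost over everything seen so far
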
2 changes: 1 addition & 1 deletion tests/unittests/bases/test_collections.py
@@ -557,7 +557,7 @@ def test_compute_on_different_dtype():


def test_error_on_wrong_specified_compute_groups():
"""Test that error is raised if user mis-specify the compute groups."""
"""Test that error is raised if user miss-specify the compute groups."""
with pytest.raises(ValueError, match="Input MulticlassAccuracy in `compute_groups`.*"):
MetricCollection(
MulticlassConfusionMatrix(3),
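
For contrast with the mis-specified case the test exercises, a hedged sketch of a correctly specified compute_groups argument; the metric choices and group layout are assumptions, not taken from the test:

    from torchmetrics import MetricCollection
    from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision

    # Every name listed in compute_groups must match a metric actually present in the collection,
    # otherwise MetricCollection raises the ValueError matched in the test above.
    collection = MetricCollection(
        MulticlassAccuracy(num_classes=3),
        MulticlassPrecision(num_classes=3),
        compute_groups=[["MulticlassAccuracy", "MulticlassPrecision"]],
    )
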
2 changes: 1 addition & 1 deletion tests/unittests/pairwise/test_pairwise_distance.py
@@ -153,7 +153,7 @@ def test_error_on_wrong_shapes(metric):
(partial(pairwise_minkowski_distance, exponent=3), partial(pairwise_distances, metric="minkowski", p=3)),
],
)
def test_precison_case(metric_functional, sk_fn):
def test_precision_case(metric_functional, sk_fn):
"""Test that metrics are robust towars cases where high precision is needed."""
x = torch.tensor([[772.0, 112.0], [772.20001, 112.0]])
res1 = metric_functional(x, zero_diagonal=False)
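
A minimal sketch of the functional under test, reusing the high-precision input from the test body; it assumes pairwise_minkowski_distance is exported from torchmetrics.functional with the exponent and zero_diagonal arguments used above:

    import torch
    from torchmetrics.functional import pairwise_minkowski_distance

    x = torch.tensor([[772.0, 112.0], [772.20001, 112.0]])
    # exponent=3 mirrors sklearn's pairwise_distances(metric="minkowski", p=3) in the parametrization above
    print(pairwise_minkowski_distance(x, exponent=3, zero_diagonal=False))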