diff --git a/pyproject.toml b/pyproject.toml index df8b8891776..dc393ed6708 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -117,7 +117,6 @@ ignore = [ "D104", # todo: Missing docstring in public package "D107", # Missing docstring in `__init__` "ANN101", # Missing type annotation for `self` in method - "ANN102", # Missing type annotation for `cls` in classmethod "S301", # todo: `pickle` and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue # todo "S310", # todo: Audit URL open for permitted schemes. Allowing use of `file:` or custom schemes is often unexpected. # todo "B905", # todo: `zip()` without an explicit `strict=` parameter diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py index 117a89cb667..6c84ec47b97 100644 --- a/src/torchmetrics/classification/accuracy.py +++ b/src/torchmetrics/classification/accuracy.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from typing import Any, Optional, Sequence, Type, Union from torch import Tensor from typing_extensions import Literal @@ -484,7 +484,7 @@ class Accuracy(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls, + cls: Type["Accuracy"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py index 428c961ec1d..4e7c3edf982 100644 --- a/src/torchmetrics/classification/auroc.py +++ b/src/torchmetrics/classification/auroc.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from typing import Any, List, Optional, Sequence, Type, Union from torch import Tensor from typing_extensions import Literal @@ -502,7 +502,7 @@ class AUROC(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls, + cls: Type["AUROC"], task: Literal["binary", "multiclass", "multilabel"], thresholds: Optional[Union[int, List[float], Tensor]] = None, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/average_precision.py b/src/torchmetrics/classification/average_precision.py index 569c00b73d8..be025da5d02 100644 --- a/src/torchmetrics/classification/average_precision.py +++ b/src/torchmetrics/classification/average_precision.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from typing import Any, List, Optional, Sequence, Type, Union from torch import Tensor from typing_extensions import Literal @@ -512,7 +512,7 @@ class AveragePrecision(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls, + cls: Type["AveragePrecision"], task: Literal["binary", "multiclass", "multilabel"], thresholds: Optional[Union[int, List[float], Tensor]] = None, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/calibration_error.py b/src/torchmetrics/classification/calibration_error.py index d939bf99daa..aaf4794ec3f 100644 --- a/src/torchmetrics/classification/calibration_error.py +++ b/src/torchmetrics/classification/calibration_error.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from typing import Any, List, Optional, Sequence, Type, Union from torch import Tensor from typing_extensions import Literal @@ -368,7 +368,7 @@ class CalibrationError(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls, + cls: Type["CalibrationError"], task: Literal["binary", "multiclass"], n_bins: int = 15, norm: Literal["l1", "l2", "max"] = "l1", diff --git a/src/torchmetrics/classification/cohen_kappa.py b/src/torchmetrics/classification/cohen_kappa.py index 213fa261ea3..431af8325ee 100644 --- a/src/torchmetrics/classification/cohen_kappa.py +++ b/src/torchmetrics/classification/cohen_kappa.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from typing import Any, Optional, Sequence, Type, Union from torch import Tensor from typing_extensions import Literal @@ -312,7 +312,7 @@ class labels. """ def __new__( # type: ignore[misc] - cls, + cls: Type["CohenKappa"], task: Literal["binary", "multiclass"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py index b01007cfd94..e1d74c79bc0 100644 --- a/src/torchmetrics/classification/confusion_matrix.py +++ b/src/torchmetrics/classification/confusion_matrix.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional +from typing import Any, List, Optional, Type import torch from torch import Tensor @@ -505,7 +505,7 @@ class ConfusionMatrix(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls, + cls: Type["ConfusionMatrix"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/exact_match.py b/src/torchmetrics/classification/exact_match.py index 20ab5344373..7679bb3eab2 100644 --- a/src/torchmetrics/classification/exact_match.py +++ b/src/torchmetrics/classification/exact_match.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from typing import Any, Optional, Sequence, Type, Union import torch from torch import Tensor @@ -393,7 +393,7 @@ class ExactMatch(_ClassificationTaskWrapper): """ def __new__( - cls, + cls: Type["ExactMatch"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/f_beta.py b/src/torchmetrics/classification/f_beta.py index 0386a8b2eb9..48098a7d0dd 100644 --- a/src/torchmetrics/classification/f_beta.py +++ b/src/torchmetrics/classification/f_beta.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from typing import Any, Optional, Sequence, Type, Union from torch import Tensor from typing_extensions import Literal @@ -1053,7 +1053,7 @@ class FBetaScore(_ClassificationTaskWrapper): """ def __new__( - cls, + cls: Type["FBetaScore"], task: Literal["binary", "multiclass", "multilabel"], beta: float = 1.0, threshold: float = 0.5, @@ -1115,7 +1115,7 @@ class F1Score(_ClassificationTaskWrapper): """ def __new__( - cls, + cls: Type["F1Score"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/hamming.py b/src/torchmetrics/classification/hamming.py index 340a647aa8d..6b3bacd45c7 100644 --- a/src/torchmetrics/classification/hamming.py +++ b/src/torchmetrics/classification/hamming.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from typing import Any, Optional, Sequence, Type, Union from torch import Tensor from typing_extensions import Literal @@ -493,7 +493,7 @@ class HammingDistance(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls, + cls: Type["HammingDistance"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/hinge.py b/src/torchmetrics/classification/hinge.py index 0530fd60c81..fbdf3ada50a 100644 --- a/src/torchmetrics/classification/hinge.py +++ b/src/torchmetrics/classification/hinge.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from typing import Any, Optional, Sequence, Type, Union import torch from torch import Tensor @@ -351,7 +351,7 @@ class HingeLoss(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls, + cls: Type["HingeLoss"], task: Literal["binary", "multiclass"], num_classes: Optional[int] = None, squared: bool = False, diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py index a4f07febb34..c2f1138cc9f 100644 --- a/src/torchmetrics/classification/jaccard.py +++ b/src/torchmetrics/classification/jaccard.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from typing import Any, Optional, Sequence, Type, Union from torch import Tensor from typing_extensions import Literal @@ -442,7 +442,7 @@ class JaccardIndex(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls, + cls: Type["JaccardIndex"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/matthews_corrcoef.py b/src/torchmetrics/classification/matthews_corrcoef.py index cc86ccfc9bb..8fab20badf5 100644 --- a/src/torchmetrics/classification/matthews_corrcoef.py +++ b/src/torchmetrics/classification/matthews_corrcoef.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from typing import Any, Optional, Sequence, Type, Union from torch import Tensor from typing_extensions import Literal @@ -390,7 +390,7 @@ class MatthewsCorrCoef(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls, + cls: Type["MatthewsCorrCoef"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/precision_fixed_recall.py b/src/torchmetrics/classification/precision_fixed_recall.py index ebaed8896fd..18536c6a0c2 100644 --- a/src/torchmetrics/classification/precision_fixed_recall.py +++ b/src/torchmetrics/classification/precision_fixed_recall.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Tuple, Union +from typing import Any, List, Optional, Sequence, Tuple, Type, Union from torch import Tensor from typing_extensions import Literal @@ -482,7 +482,7 @@ class PrecisionAtFixedRecall(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls, + cls: Type["PrecisionAtFixedRecall"], task: Literal["binary", "multiclass", "multilabel"], min_recall: float, thresholds: Optional[Union[int, List[float], Tensor]] = None, diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py index d221584c336..9b658e8dac7 100644 --- a/src/torchmetrics/classification/precision_recall.py +++ b/src/torchmetrics/classification/precision_recall.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from typing import Any, Optional, Sequence, Type, Union from torch import Tensor from typing_extensions import Literal @@ -925,7 +925,7 @@ class Precision(_ClassificationTaskWrapper): """ def __new__( - cls, + cls: Type["Precision"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, @@ -988,7 +988,7 @@ class Recall(_ClassificationTaskWrapper): """ def __new__( - cls, + cls: Type["Recall"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py index 9996cfd683e..c9e42d41b71 100644 --- a/src/torchmetrics/classification/precision_recall_curve.py +++ b/src/torchmetrics/classification/precision_recall_curve.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Type, Union import torch from torch import Tensor @@ -650,7 +650,7 @@ class PrecisionRecallCurve(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls, + cls: Type["PrecisionRecallCurve"], task: Literal["binary", "multiclass", "multilabel"], thresholds: Optional[Union[int, List[float], Tensor]] = None, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/recall_fixed_precision.py b/src/torchmetrics/classification/recall_fixed_precision.py index bfbf5d68cd9..c9936f09411 100644 --- a/src/torchmetrics/classification/recall_fixed_precision.py +++ b/src/torchmetrics/classification/recall_fixed_precision.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Tuple, Union +from typing import Any, List, Optional, Sequence, Tuple, Type, Union from torch import Tensor from typing_extensions import Literal @@ -481,7 +481,7 @@ class RecallAtFixedPrecision(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls, + cls: Type["RecallAtFixedPrecision"], task: Literal["binary", "multiclass", "multilabel"], min_precision: float, thresholds: Optional[Union[int, List[float], Tensor]] = None, diff --git a/src/torchmetrics/classification/roc.py b/src/torchmetrics/classification/roc.py index a391cd2046f..69400953d6d 100644 --- a/src/torchmetrics/classification/roc.py +++ b/src/torchmetrics/classification/roc.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Type, Union from torch import Tensor from typing_extensions import Literal @@ -559,7 +559,7 @@ class ROC(_ClassificationTaskWrapper): """ def __new__( - cls, + cls: Type["ROC"], task: Literal["binary", "multiclass", "multilabel"], thresholds: Optional[Union[int, List[float], Tensor]] = None, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/specificity.py b/src/torchmetrics/classification/specificity.py index d9124968cfc..ec7a56b0fc7 100644 --- a/src/torchmetrics/classification/specificity.py +++ b/src/torchmetrics/classification/specificity.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from typing import Any, Optional, Sequence, Type, Union from torch import Tensor from typing_extensions import Literal @@ -474,7 +474,7 @@ class Specificity(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls, + cls: Type["Specificity"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/specificity_sensitivity.py b/src/torchmetrics/classification/specificity_sensitivity.py index c199f4dc7ad..bf4ba1a06b5 100644 --- a/src/torchmetrics/classification/specificity_sensitivity.py +++ b/src/torchmetrics/classification/specificity_sensitivity.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Type, Union from torch import Tensor from typing_extensions import Literal @@ -343,7 +343,7 @@ class SpecificityAtSensitivity(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls, + cls: Type["SpecificityAtSensitivity"], task: Literal["binary", "multiclass", "multilabel"], min_sensitivity: float, thresholds: Optional[Union[int, List[float], Tensor]] = None, diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index ce671e202bf..021c8582d30 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Type, Union import torch from torch import Tensor @@ -513,7 +513,7 @@ class StatScores(_ClassificationTaskWrapper): """ def __new__( - cls, + cls: Type["StatScores"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/functional/text/sacre_bleu.py b/src/torchmetrics/functional/text/sacre_bleu.py index 69e2ae882f3..33c3afb0beb 100644 --- a/src/torchmetrics/functional/text/sacre_bleu.py +++ b/src/torchmetrics/functional/text/sacre_bleu.py @@ -41,7 +41,7 @@ import re import tempfile from functools import partial -from typing import Any, ClassVar, Dict, Optional, Sequence +from typing import Any, ClassVar, Dict, Optional, Sequence, Type import torch from torch import Tensor, tensor @@ -58,7 +58,7 @@ ) AVAILABLE_TOKENIZERS = ("none", "13a", "zh", "intl", "char", "ja-mecab", "ko-mecab", "flores101", "flores200") -_Tokenizers_list = Literal["none", "13a", "zh", "intl", "char", "ja-mecab", "ko-mecab", "flores101", "flores200"] +_TokenizersLiteral = Literal["none", "13a", "zh", "intl", "char", "ja-mecab", "ko-mecab", "flores101", "flores200"] _UCODE_RANGES = ( ("\u3400", "\u4db5"), # CJK Unified Ideographs Extension A, release 3.0 @@ -143,7 +143,7 @@ class _SacreBLEUTokenizer: # Keep it as class variable to avoid initializing over and over again sentencepiece_processors: ClassVar[Dict[str, Optional[Any]]] = {"flores101": None, "flores200": None} - def __init__(self, tokenize: _Tokenizers_list, lowercase: bool = False) -> None: + def __init__(self, tokenize: _TokenizersLiteral, lowercase: bool = False) -> None: self._check_tokenizers_validity(tokenize) self.tokenize_fn = getattr(self, self._TOKENIZE_FN[tokenize]) @@ -154,7 +154,12 @@ def __call__(self, line: str) -> Sequence[str]: return self._lower(tokenized_line, self.lowercase).split() @classmethod - def tokenize(cls, line: str, tokenize: _Tokenizers_list, lowercase: bool = False) -> Sequence[str]: + def tokenize( + cls: Type["_SacreBLEUTokenizer"], + line: str, + tokenize: _TokenizersLiteral, + lowercase: bool = False, + ) -> Sequence[str]: cls._check_tokenizers_validity(tokenize) tokenize_fn = getattr(cls, cls._TOKENIZE_FN[tokenize]) @@ -162,7 +167,7 @@ def tokenize(cls, line: str, tokenize: _Tokenizers_list, lowercase: bool = False return cls._lower(tokenized_line, lowercase).split() @classmethod - def _tokenize_regex(cls, line: str) -> str: + def _tokenize_regex(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: """Post-processing tokenizer for `13a` and `zh` tokenizers. Args: @@ -191,7 +196,7 @@ def _is_chinese_char(uchar: str) -> bool: return any(start <= uchar <= end for start, end in _UCODE_RANGES) @classmethod - def _tokenize_base(cls, line: str) -> str: + def _tokenize_base(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: """Tokenizes an input line with the tokenizer. Args: @@ -204,7 +209,7 @@ def _tokenize_base(cls, line: str) -> str: return line @classmethod - def _tokenize_13a(cls, line: str) -> str: + def _tokenize_13a(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: """Tokenizes a line using a relatively minimal tokenization that is equivalent to mteval-v13a, used by WMT. Args: @@ -228,7 +233,7 @@ def _tokenize_13a(cls, line: str) -> str: return cls._tokenize_regex(f" {line} ") @classmethod - def _tokenize_zh(cls, line: str) -> str: + def _tokenize_zh(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: """Tokenization of Chinese text. This is done in two steps: separate each Chinese characters (by utf-8 encoding) and afterwards tokenize the @@ -256,7 +261,7 @@ def _tokenize_zh(cls, line: str) -> str: return cls._tokenize_regex(line_in_chars) @classmethod - def _tokenize_international(cls, line: str) -> str: + def _tokenize_international(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: r"""Tokenizes a string following the official BLEU implementation. See github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983 @@ -289,7 +294,7 @@ def _tokenize_international(cls, line: str) -> str: return " ".join(line.split()) @classmethod - def _tokenize_char(cls, line: str) -> str: + def _tokenize_char(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: """Tokenizes all the characters in the input line. Args: @@ -302,7 +307,7 @@ def _tokenize_char(cls, line: str) -> str: return " ".join(char for char in line) @classmethod - def _tokenize_ja_mecab(cls, line: str) -> str: + def _tokenize_ja_mecab(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: """Tokenizes a Japanese string line using MeCab morphological analyzer. Args: @@ -321,7 +326,7 @@ def _tokenize_ja_mecab(cls, line: str) -> str: return tagger.parse(line).strip() @classmethod - def _tokenize_ko_mecab(cls, line: str) -> str: + def _tokenize_ko_mecab(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: """Tokenizes a Korean string line using MeCab-korean morphological analyzer. Args: @@ -340,7 +345,9 @@ def _tokenize_ko_mecab(cls, line: str) -> str: return tagger.parse(line).strip() @classmethod - def _tokenize_flores(cls, line: str, tokenize: Literal["flores101", "flores200"]) -> str: + def _tokenize_flores( + cls: Type["_SacreBLEUTokenizer"], line: str, tokenize: Literal["flores101", "flores200"] + ) -> str: """Tokenizes a string line using sentencepiece tokenizer. Args: @@ -365,7 +372,7 @@ def _tokenize_flores(cls, line: str, tokenize: Literal["flores101", "flores200"] return " ".join(cls.sentencepiece_processors[tokenize].EncodeAsPieces(line)) # type: ignore[union-attr] @classmethod - def _tokenize_flores_101(cls, line: str) -> str: + def _tokenize_flores_101(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: """Tokenizes a string line using sentencepiece tokenizer according to `FLORES-101`_ dataset. Args: @@ -378,7 +385,7 @@ def _tokenize_flores_101(cls, line: str) -> str: return cls._tokenize_flores(line, "flores101") @classmethod - def _tokenize_flores_200(cls, line: str) -> str: + def _tokenize_flores_200(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: """Tokenizes a string line using sentencepiece tokenizer according to `FLORES-200`_ dataset. Args: @@ -397,7 +404,7 @@ def _lower(line: str, lowercase: bool) -> str: return line @classmethod - def _check_tokenizers_validity(cls, tokenize: _Tokenizers_list) -> None: + def _check_tokenizers_validity(cls: Type["_SacreBLEUTokenizer"], tokenize: _TokenizersLiteral) -> None: """Check if a supported tokenizer is chosen. Also check all dependencies of a given tokenizers are installed. @@ -453,7 +460,7 @@ def sacre_bleu_score( target: Sequence[Sequence[str]], n_gram: int = 4, smooth: bool = False, - tokenize: _Tokenizers_list = "13a", + tokenize: _TokenizersLiteral = "13a", lowercase: bool = False, weights: Optional[Sequence[float]] = None, ) -> Tensor: diff --git a/src/torchmetrics/functional/text/ter.py b/src/torchmetrics/functional/text/ter.py index 4020f4e1d63..51d19120039 100644 --- a/src/torchmetrics/functional/text/ter.py +++ b/src/torchmetrics/functional/text/ter.py @@ -35,7 +35,7 @@ import re from functools import lru_cache -from typing import Dict, Iterator, List, Optional, Sequence, Tuple, Union +from typing import Dict, Iterator, List, Optional, Sequence, Tuple, Type, Union from torch import Tensor, tensor @@ -150,7 +150,7 @@ def _normalize_general_and_western(sentence: str) -> str: return sentence @classmethod - def _normalize_asian(cls, sentence: str) -> str: + def _normalize_asian(cls: Type["_TercomTokenizer"], sentence: str) -> str: """Split Chinese chars and Japanese kanji down to character level.""" # 4E00—9FFF CJK Unified Ideographs # 3400—4DBF CJK Unified Ideographs Extension A @@ -182,7 +182,7 @@ def _remove_punct(sentence: str) -> str: return re.sub(r"[\.,\?:;!\"\(\)]", "", sentence) @classmethod - def _remove_asian_punct(cls, sentence: str) -> str: + def _remove_asian_punct(cls: Type["_TercomTokenizer"], sentence: str) -> str: """Remove asian punctuation from an input sentence string.""" sentence = re.sub(cls._ASIAN_PUNCTUATION, r"", sentence) return re.sub(cls._FULL_WIDTH_PUNCTUATION, r"", sentence) diff --git a/src/torchmetrics/text/sacre_bleu.py b/src/torchmetrics/text/sacre_bleu.py index 0708f51544f..8e2e63ad82b 100644 --- a/src/torchmetrics/text/sacre_bleu.py +++ b/src/torchmetrics/text/sacre_bleu.py @@ -22,7 +22,7 @@ from torch import Tensor from torchmetrics.functional.text.bleu import _bleu_score_update -from torchmetrics.functional.text.sacre_bleu import _SacreBLEUTokenizer, _Tokenizers_list +from torchmetrics.functional.text.sacre_bleu import _SacreBLEUTokenizer, _TokenizersLiteral from torchmetrics.text.bleu import BLEUScore from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE @@ -91,7 +91,7 @@ def __init__( self, n_gram: int = 4, smooth: bool = False, - tokenize: _Tokenizers_list = "13a", + tokenize: _TokenizersLiteral = "13a", lowercase: bool = False, weights: Optional[Sequence[float]] = None, **kwargs: Any, diff --git a/src/torchmetrics/utilities/enums.py b/src/torchmetrics/utilities/enums.py index c5fce04c6c2..bfc2fd20190 100644 --- a/src/torchmetrics/utilities/enums.py +++ b/src/torchmetrics/utilities/enums.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Type + from lightning_utilities.core.enums import StrEnum from typing_extensions import Literal @@ -23,7 +25,7 @@ def _name() -> str: return "Task" @classmethod - def from_str(cls, value: str, source: Literal["key", "value", "any"] = "key") -> "EnumStr": + def from_str(cls: Type["EnumStr"], value: str, source: Literal["key", "value", "any"] = "key") -> "EnumStr": """Load from string. Raises: diff --git a/tests/unittests/text/test_sacre_bleu.py b/tests/unittests/text/test_sacre_bleu.py index 0fc6de59e11..c35603fa8b1 100644 --- a/tests/unittests/text/test_sacre_bleu.py +++ b/tests/unittests/text/test_sacre_bleu.py @@ -17,7 +17,7 @@ import pytest from torch import Tensor, tensor -from torchmetrics.functional.text.sacre_bleu import AVAILABLE_TOKENIZERS, _Tokenizers_list, sacre_bleu_score +from torchmetrics.functional.text.sacre_bleu import AVAILABLE_TOKENIZERS, _TokenizersLiteral, sacre_bleu_score from torchmetrics.text.sacre_bleu import SacreBLEUScore from torchmetrics.utilities.imports import _SACREBLEU_AVAILABLE @@ -128,4 +128,4 @@ def test_tokenize_ko_mecab(): def test_equivalence_of_available_tokenizers_and_annotation(): """Test equivalence of SacreBLEU available tokenizers and corresponding type annotation.""" - assert set(AVAILABLE_TOKENIZERS) == set(_Tokenizers_list.__args__) + assert set(AVAILABLE_TOKENIZERS) == set(_TokenizersLiteral.__args__)