diff --git a/docs/advanced/PreTrainedModelsHF.md b/docs/advanced/PreTrainedModelsHF.md
index d3a9a18..0b4d099 100644
--- a/docs/advanced/PreTrainedModelsHF.md
+++ b/docs/advanced/PreTrainedModelsHF.md
@@ -1 +1,260 @@
-# Use pre-trained models from HuggingFace
+# Hugging Face pre-trained models in the Melusine framework
+
+[HuggingFace](https://huggingface.co/)
+
+The Hugging Face library has reshaped the landscape of natural language processing (NLP) and beyond, establishing itself as an indispensable tool for researchers, data scientists, and developers. By bridging the gap between cutting-edge research and practical implementation, it simplifies the complexities of model deployment and fosters innovation across industries, enabling applications that were once considered out of reach.
+
+Renowned for its user-friendly interface and extensive collection of pre-trained models, Hugging Face lets users tackle a diverse range of tasks, from text classification and sentiment analysis to machine translation and question answering. This versatility makes it a cornerstone of modern AI development, providing accurate and efficient models.
+
+**Melusine** provides a framework for streamlining and optimizing email workflows. Its flexible architecture allows machine learning models to be integrated seamlessly into its detectors, as demonstrated in the `hugging_face` folder, so that users can harness advanced AI capabilities for enhanced performance.
+
+### Tutorial: Dissatisfaction detection using Hugging Face models
+
+Whether it is using pre-trained models from Hugging Face, such as BERT or DistilBERT, for email classification, integrating Named Entity Recognition (NER) models to extract key information, leveraging topic modeling transformers to organize emails by theme, or using translation models to convert emails into multiple languages, all of these capabilities are achievable through the Melusine framework.
+
+By integrating these models into Melusine, businesses can unlock advanced email processing capabilities, streamline workflows, and enhance productivity across their operations. Transformers-based models from Hugging Face can significantly enhance detection capabilities and act as a complementary approach that strengthens predictions.
+The integration of these models is primarily facilitated through **Melusine detectors**.
+
+**Model selection**
+
+The choice of model depends on the specific detection task. For example, for **sentiment detection in French text**, suitable models include CamemBERT and DistilCamemBERT.
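+
+As a quick illustration, such a model can be pulled from the Hugging Face Hub in a couple of lines. The snippet below is a minimal sketch: the checkpoint name is only an illustrative choice, not a Melusine requirement, and any sequence-classification checkpoint suited to the task can be substituted.
+
+```python
+# Minimal sketch: load a candidate French sentiment model from the Hugging Face Hub.
+# The checkpoint name is an example; pick whichever model fits your detection task.
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+model_name = "cmarkea/distilcamembert-base-sentiment"  # example checkpoint
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForSequenceClassification.from_pretrained(model_name)
+```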
+
+**Implementing the solution**
+
+As usual, the detector inherits from a base class, **MelusineTransformerDetector**, adheres to the standard structure of a Melusine detector, and adds a method enabling machine learning-based detection.
+The MelusineTransformerDetector class defines one additional method, **by_ml_detect**, as demonstrated below:
+
+``` python
+class MelusineTransformerDetector(BaseMelusineDetector, ABC):
+    """
+    Defines an interface for detectors.
+    All detectors used in a MelusinePipeline should inherit from the MelusineDetector class and
+    implement the abstract methods.
+    This ensures homogeneous coding style throughout the application.
+    Alternatively, Melusine users can define their own interface (inheriting from the BaseMelusineDetector)
+    to suit their needs.
+    """
+
+    @property
+    def transform_methods(self) -> list[Callable]:
+        """
+        Specify the sequence of methods to be called by the transform method.
+
+        Returns
+        -------
+        _: list[Callable]
+            List of methods to be called by the transform method.
+        """
+        return [
+            self.pre_detect,
+            self.by_regex_detect,
+            self.by_ml_detect,
+            self.post_detect,
+        ]
+
+    @abstractmethod
+    def pre_detect(self, row: MelusineItem, debug_mode: bool = False) -> MelusineItem:
+        """What needs to be done before detection."""
+
+    @abstractmethod
+    def by_regex_detect(
+        self, row: MelusineItem, debug_mode: bool = False
+    ) -> MelusineItem:
+        """Run regex-based detection."""
+
+    @abstractmethod
+    def by_ml_detect(self, row: MelusineItem, debug_mode: bool = False) -> MelusineItem:
+        """Run machine learning-based detection."""
+
+    @abstractmethod
+    def post_detect(self, row: MelusineItem, debug_mode: bool = False) -> MelusineItem:
+        """What needs to be done after detection (e.g., mapping columns)."""
+```
+
+**The detection method can be one of the following three:**
+
+* Purely deterministic: using the MelusineRegex functionality
+* Machine learning-based: using Hugging Face models
+* Combined: mixing deterministic and machine learning-based methods
+
+```mermaid
+graph LR
+    A[PRE-DETECT] -- deterministic --> B(by_regex_detect)
+    A -- machine-learning based --> C(by_ml_detect)
+    A -- combined methods --> D(by_regex_detect & by_ml_detect)
+    B --> E[POST-DETECT]
+    C --> E
+    D --> E
+```
+
+* To detect dissatisfaction by regex, a DissatisfactionRegex class inheriting from MelusineRegex is required.
+The implementation can be found in `melusine/regex/dissatisfaction_regex.py`.
+Once the DissatisfactionRegex class is available, the by_regex_detect method can be implemented as demonstrated in the DissatisfactionDetector (`hugging_face/detectors.py`).
+
+## The Machine Learning Approach to Detect Dissatisfaction: Two Methods
+
+* **Using a pre-trained model directly**: a Hugging Face token may be required, as mentioned in the model class.
+The model can be loaded directly from the Hugging Face platform, along with its tokenizer, for immediate use in detecting dissatisfaction.
+
+* **Fine-tuning the model**: a pre-trained model can be fine-tuned using various methods, including:
+
+    * the Hugging Face Trainer API
+    * [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/)
+
+    Fine-tuning approaches:
+
+    1. **Full fine-tuning**: updates all layers of the model during training; typically used to adapt the model to a specific task.
+    2. **LoRA (Low-Rank Adaptation)**, a Parameter-Efficient Fine-Tuning (PEFT) method: reduces computational and memory costs by fine-tuning only a small subset of parameters while maintaining high performance.
+
+Fine-tuning allows the model to be customized for specific tasks, improving its performance on datasets relevant to dissatisfaction detection.
+A fine-tuned model can then be stored locally and loaded from a path, as sketched below.
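+
+The following sketch illustrates such a fine-tuning run with the Hugging Face Trainer API. It is not part of Melusine: the checkpoint name, the toy dataset, and the hyper-parameters are placeholders to adapt to your own labelled data. A LoRA variant would simply wrap the model with `peft.get_peft_model` before handing it to the Trainer.
+
+```python
+# Illustrative fine-tuning sketch (Hugging Face Trainer API); all names are placeholders.
+from datasets import Dataset
+from transformers import (
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    Trainer,
+    TrainingArguments,
+)
+
+model_name = "distilcamembert-base"  # example checkpoint
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
+
+# Toy dataset: replace with real labelled emails (0 = neutral, 1 = dissatisfied).
+train_dataset = Dataset.from_dict(
+    {"text": ["je suis satisfait", "j'en ai marre de ce service"], "label": [0, 1]}
+)
+train_dataset = train_dataset.map(
+    lambda batch: tokenizer(batch["text"], truncation=True), batched=True
+)
+
+trainer = Trainer(
+    model=model,
+    args=TrainingArguments(output_dir="dissatisfaction_model", num_train_epochs=3),
+    train_dataset=train_dataset,
+    tokenizer=tokenizer,  # enables dynamic padding via the default data collator
+)
+trainer.train()
+trainer.save_model("dissatisfaction_model")  # local path usable by the detector later
+```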
+
+Inside the model class, loading the pre-trained model and running inference look as follows:
+
+```python
+def load_hfmodel(self, model_name="distilcamembert-base") -> None:
+    """
+    Load a pre-trained model and its tokenizer from the Hugging Face Hub.
+
+    Parameters
+    ----------
+    model_name: str
+        Name or path of the pre-trained model.
+
+    Returns
+    -------
+    None
+    """
+    self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+    self.model = AutoModelForSequenceClassification.from_pretrained(
+        model_name, num_labels=2
+    )
+
+
+def predict(self, text: str) -> Tuple[List, List]:
+    """
+    Apply the model and get predictions.
+
+    Parameters
+    ----------
+    text: str
+        Email text.
+
+    Returns
+    -------
+    predictions, scores: Tuple[List, List]
+        Predicted labels and associated scores after softmax.
+    """
+    inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
+    # Forward pass through the model
+    outputs = self.model(**inputs)
+    # Extract logits
+    self.logits = outputs.logits
+    # Convert logits to probabilities using softmax
+    probs = torch.nn.functional.softmax(self.logits, dim=-1)
+    probs = probs.detach().cpu().numpy()
+    # Convert predictions and scores to lists
+    predictions = probs.argmax(axis=1).tolist()
+    scores = probs.max(axis=1).tolist()
+    return predictions, scores
+```
+
+The by_ml_detect method applies the model to the input text and stores both the predicted label and its score. A threshold can then be defined in the detector configuration so that the final result only takes the machine learning prediction into account when its score crosses that threshold.
+
+```python
+def by_ml_detect(self, row: MelusineItem, debug_mode: bool = False) -> MelusineItem:
+    """
+    Use machine learning model to detect dissatisfaction.
+
+    Parameters
+    ----------
+    row: MelusineItem
+        Content of an email.
+    debug_mode: bool
+        Debug mode activation flag.
+
+    Returns
+    -------
+    row: MelusineItem
+        Updated row.
+    """
+    predictions, scores = self.melusine_model.predict(row[self.CONST_TEXT_COL_NAME])
+    debug_info: Dict[str, Any] = {}
+
+    row[self.DISSATISFACTION_ML_MATCH_COL], row[self.DISSATISFACTION_ML_SCORE_COL] = (
+        bool(predictions[0]),
+        scores[0],
+    )
+    # Save debug data
+    if debug_mode:
+        debug_info[self.DISSATISFACTION_ML_MATCH_COL] = row[
+            self.DISSATISFACTION_ML_MATCH_COL
+        ]
+        debug_info[self.DISSATISFACTION_ML_SCORE_COL] = row[
+            self.DISSATISFACTION_ML_SCORE_COL
+        ]
+        row[self.debug_dict_col].update(debug_info)
+    return row
+```
+
+The final detection result is computed in the **post_detect** method using a predefined condition, for example:
+`by_regex_detect OR (by_ml_detect AND by_ml_detect.score > 0.9)`
+
+```python
+def post_detect(self, row: MelusineItem, debug_mode: bool = False) -> MelusineItem:
+    """
+    Apply final eligibility rules.
+
+    Parameters
+    ----------
+    row: MelusineItem
+        Content of an email.
+    debug_mode: bool
+        Debug mode activation flag.
+
+    Returns
+    -------
+    row: MelusineItem
+        Updated row.
+    """
+    # Dissatisfaction is detected when the regex matches or when the model is confident enough
+    ml_result = (row[self.DISSATISFACTION_ML_SCORE_COL] > 0.9) and row[
+        self.DISSATISFACTION_ML_MATCH_COL
+    ]
+    deterministic_result = row[self.DISSATISFACTION_BY_REGEX_MATCH_COL]
+    row[self.result_column] = deterministic_result or ml_result
+    return row
+```
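+
+To wrap up, here is a minimal usage sketch that mirrors the unit tests shipped with this tutorial. The checkpoint names are placeholders: any compatible sequence-classification model, either a Hub name or a local path to a fine-tuned model, can be used.
+
+```python
+# Run the DissatisfactionDetector on a pandas DataFrame (checkpoint names are examples).
+import pandas as pd
+
+from hugging_face.detectors import DissatisfactionDetector
+
+detector = DissatisfactionDetector(
+    name="dissatisfaction",
+    text_column="det_normalized_last_body",
+    tokenizer_name_or_path="cmarkea/distilcamembert-base",  # example checkpoint
+    model_name_or_path="cmarkea/distilcamembert-base",  # example checkpoint
+    token=None,  # set a Hugging Face token here for gated or private models
+)
+
+df = pd.DataFrame(
+    {"det_normalized_last_body": ["Je suis complètement insatisfait de votre service."]}
+)
+df = detector.transform(df)
+print(df["dissatisfaction_result"].iloc[0])  # True when dissatisfaction is detected
+```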
+
+**Melusine already automates email workflows using deterministic, regex-based methods. However, the rapid growth of artificial intelligence applications in the NLP landscape remains largely untapped. This tutorial offers a glimpse into integrating state-of-the-art models into your workflows. Feel free to experiment with different model types, preprocessing methods, and use cases while keeping the general structure of the detector.
+The core purpose of Melusine lies in its modularity and versatility, enabling it to handle a wide range of applications and modeling tools effectively.**
+
diff --git a/hugging_face/__init__.py b/hugging_face/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/hugging_face/detectors.py b/hugging_face/detectors.py
new file mode 100644
index 0000000..66d4104
--- /dev/null
+++ b/hugging_face/detectors.py
@@ -0,0 +1,194 @@
+"""
+Classes of detectors.
+
+Implemented classes: [DissatisfactionDetector]
+
+"""
+
+from typing import Any, Dict, List, Optional
+
+from hugging_face.models.model import TextClassifier
+from melusine.base import MelusineItem, MelusineRegex, MelusineTransformerDetector
+from melusine.regex import DissatisfactionRegex
+
+
+class DissatisfactionDetector(MelusineTransformerDetector):
+    """
+    Class to detect emails containing dissatisfaction emotion.
+
+    Ex:
+    je vous deteste,
+    Cordialement
+    """
+
+    # Intermediate columns
+    CONST_TEXT_COL_NAME: str = "effective_text"
+    DISSATISFACTION_TEXT_COL: str = "dissatisfaction_text"
+    CONST_DEBUG_TEXT_KEY: str = "text"
+    CONST_DEBUG_PARTS_KEY: str = "parts"
+
+    # Results columns
+    DISSATISFACTION_ML_SCORE_COL: str = "dissatisfaction_ml_score"
+    DISSATISFACTION_ML_MATCH_COL: str = "dissatisfaction_ml_result"
+    DISSATISFACTION_BY_REGEX_MATCH_COL: str = "dissatisfaction_regex_result"
+
+    def __init__(
+        self,
+        text_column: str,
+        name: str,
+        tokenizer_name_or_path: str,
+        model_name_or_path: str,
+        token: Optional[str] = None,
+    ) -> None:
+        """
+        Attributes initialization.
+
+        Parameters
+        ----------
+        text_column: str
+            Name of the column containing the email text.
+        name: str
+            Name of the detector.
+        tokenizer_name_or_path: str
+            Name or path of the tokenizer.
+        model_name_or_path: str
+            Name or path of the model.
+        token: Optional[str]
+            Hugging Face token (for gated or private models).
+        """
+
+        # Input columns
+        self.text_column = text_column
+        input_columns: List[str] = [text_column]
+
+        # Output columns
+        self.result_column = f"{name}_result"
+        output_columns: List[str] = [self.result_column]
+
+        # Detection regex
+        self.dissatisfaction_regex: MelusineRegex = DissatisfactionRegex()
+        self.token = token
+
+        super().__init__(
+            name=name,
+            input_columns=input_columns,
+            output_columns=output_columns,
+        )
+        self.melusine_model = TextClassifier(
+            tokenizer_name_or_path=tokenizer_name_or_path, model_name_or_path=model_name_or_path, token=self.token
+        )
+
+    def pre_detect(self, row: MelusineItem, debug_mode: bool = False) -> MelusineItem:
+        """
+        Extract text to analyse.
+
+        Parameters
+        ----------
+        row: MelusineItem
+            Content of an email.
+        debug_mode: bool
+            Debug mode activation flag.
+
+        Returns
+        -------
+        row: MelusineItem
+            Updated row.
+        """
+
+        # Last message body
+        message_text: str = row[self.text_column]
+
+        row[self.CONST_TEXT_COL_NAME] = "\n".join([message_text])
+
+        # Prepare and save debug data
+        if debug_mode:
+            debug_dict: Dict[str, Any] = {
+                self.CONST_DEBUG_TEXT_KEY: row[self.CONST_TEXT_COL_NAME],
+            }
+            row[self.debug_dict_col] = debug_dict
+
+        return row
+
+    def by_regex_detect(self, row: MelusineItem, debug_mode: bool = False) -> MelusineItem:
+        """
+        Use regex to detect dissatisfaction.
+
+        Parameters
+        ----------
+        row: MelusineItem
+            Content of an email.
+        debug_mode: bool
+            Debug mode activation flag.
+
+        Returns
+        -------
+        row: MelusineItem
+            Updated row.
+        """
+        debug_info: Dict[str, Any] = {}
+        text: str = row[self.CONST_TEXT_COL_NAME]
+        detection_data = self.dissatisfaction_regex(text)
+        detection_result = detection_data[self.dissatisfaction_regex.MATCH_RESULT]
+
+        # Save debug data
+        if debug_mode:
+            debug_info[self.dissatisfaction_regex.regex_name] = detection_data
+            row[self.debug_dict_col].update(debug_info)
+
+        # Create new columns
+        row[self.DISSATISFACTION_BY_REGEX_MATCH_COL] = detection_result
+        return row
+
+    def by_ml_detect(self, row: MelusineItem, debug_mode: bool = False) -> MelusineItem:
+        """
+        Use machine learning model to detect dissatisfaction.
+
+        Parameters
+        ----------
+        row: MelusineItem
+            Content of an email.
+        debug_mode: bool
+            Debug mode activation flag.
+
+        Returns
+        -------
+        row: MelusineItem
+            Updated row.
+        """
+
+        predictions, scores = self.melusine_model.predict(row[self.CONST_TEXT_COL_NAME])
+        debug_info: Dict[str, Any] = {}
+
+        row[self.DISSATISFACTION_ML_MATCH_COL], row[self.DISSATISFACTION_ML_SCORE_COL] = bool(predictions[0]), scores[0]
+        # Save debug data
+        if debug_mode:
+            debug_info[self.DISSATISFACTION_ML_MATCH_COL] = row[self.DISSATISFACTION_ML_MATCH_COL]
+            debug_info[self.DISSATISFACTION_ML_SCORE_COL] = row[self.DISSATISFACTION_ML_SCORE_COL]
+            row[self.debug_dict_col].update(debug_info)
+        return row
+
+    def post_detect(self, row: MelusineItem, debug_mode: bool = False) -> MelusineItem:
+        """
+        Apply final eligibility rules.
+
+        Parameters
+        ----------
+        row: MelusineItem
+            Content of an email.
+        debug_mode: bool
+            Debug mode activation flag.
+
+        Returns
+        -------
+        row: MelusineItem
+            Updated row.
+        """
+
+        # Dissatisfaction is detected when the regex matches or when the model is confident enough
+        ml_result = (row[self.DISSATISFACTION_ML_SCORE_COL] > 0.9) and row[self.DISSATISFACTION_ML_MATCH_COL]
+        deterministic_result = row[self.DISSATISFACTION_BY_REGEX_MATCH_COL]
+        row[self.result_column] = deterministic_result or ml_result
+        return row
diff --git a/hugging_face/models/__init__.py b/hugging_face/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/hugging_face/models/model.py b/hugging_face/models/model.py
new file mode 100644
index 0000000..598e177
--- /dev/null
+++ b/hugging_face/models/model.py
@@ -0,0 +1,81 @@
+from typing import List, Optional, Tuple
+
+import torch
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+
+class TextClassifier:
+    """
+    Text classification model wrapping a Hugging Face tokenizer and
+    sequence-classification model.
+    """
+
+    def __init__(self, tokenizer_name_or_path: str, model_name_or_path: str, token: Optional[str]):
+        """
+        Attributes initialization.
+
+        Parameters
+        ----------
+        tokenizer_name_or_path: str
+            Tokenizer name or path.
+        model_name_or_path: str
+            Model name or path.
+        token: Optional[str]
+            Hugging Face token.
+
+        """
+        self.tokenizer_name_or_path = tokenizer_name_or_path
+        self.model_name_or_path = model_name_or_path
+        self.hf_token = token
+        self.load_model()
+
+    def load_model(self) -> None:
+        """
+        Load the tokenizer and the model from the Hugging Face Hub or a local path.
+
+        Returns
+        -------
+        None
+        """
+        if self.hf_token:
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                pretrained_model_name_or_path=self.tokenizer_name_or_path, use_auth_token=self.hf_token
+            )
+            self.model = AutoModelForSequenceClassification.from_pretrained(
+                pretrained_model_name_or_path=self.model_name_or_path, num_labels=2, use_auth_token=self.hf_token
+            )
+        else:
+            self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=self.tokenizer_name_or_path)
+            self.model = AutoModelForSequenceClassification.from_pretrained(
+                pretrained_model_name_or_path=self.model_name_or_path, num_labels=2
+            )
+
+    def predict(self, text: str) -> Tuple[List, List]:
+        """
+        Apply the model and get predictions.
+
+        Parameters
+        ----------
+        text: str
+            Email text
+        Returns
+        -------
+        predictions, scores: Tuple[List, List]
+            Predicted labels and associated scores after softmax.
+        """
+
+        inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
+        # Forward pass through the model
+        outputs = self.model(**inputs)
+        # Extract logits
+        self.logits = outputs.logits
+        # Convert logits to probabilities using softmax
+        probs = torch.nn.functional.softmax(self.logits, dim=-1)
+        probs = probs.detach().cpu().numpy()
+        # Convert predictions and scores to lists
+        predictions = probs.argmax(axis=1).tolist()
+        scores = probs.max(axis=1).tolist()
+        return predictions, scores
diff --git a/hugging_face/models/model_1.onnx b/hugging_face/models/model_1.onnx
new file mode 100644
index 0000000..e69de29
diff --git a/melusine/base.py b/melusine/base.py
index 36c347a..03ca017 100644
--- a/melusine/base.py
+++ b/melusine/base.py
@@ -8,7 +8,8 @@
     BaseLabelProcessor,
     MissingModelInputFieldError,
     MissingFieldError,
-    MelusineFeatureEncoder
+    MelusineFeatureEncoder,
+    MelusineTransformerDetector
 ]
 """
 
@@ -302,6 +303,45 @@
         """What needs to be done after detection (e.g., mapping columns)."""
 
 
+class MelusineTransformerDetector(BaseMelusineDetector, ABC):
+    """
+    Defines an interface for detectors.
+    All detectors used in a MelusinePipeline should inherit from the MelusineDetector class and
+    implement the abstract methods.
+    This ensures homogeneous coding style throughout the application.
+    Alternatively, Melusine users can define their own interface (inheriting from the BaseMelusineDetector)
+    to suit their needs.
+    """
+
+    @property
+    def transform_methods(self) -> list[Callable]:
+        """
+        Specify the sequence of methods to be called by the transform method.
+
+        Returns
+        -------
+        _: list[Callable]
+            List of methods to be called by the transform method.
+        """
+        return [self.pre_detect, self.by_regex_detect, self.by_ml_detect, self.post_detect]
+
+    @abstractmethod
+    def pre_detect(self, row: MelusineItem, debug_mode: bool = False) -> MelusineItem:
+        """What needs to be done before detection."""
+
+    @abstractmethod
+    def by_regex_detect(self, row: MelusineItem, debug_mode: bool = False) -> MelusineItem:
+        """Run regex-based detection."""
+
+    @abstractmethod
+    def by_ml_detect(self, row: MelusineItem, debug_mode: bool = False) -> MelusineItem:
+        """Run machine learning-based detection."""
+
+    @abstractmethod
+    def post_detect(self, row: MelusineItem, debug_mode: bool = False) -> MelusineItem:
+        """What needs to be done after detection (e.g., mapping columns)."""
+
+
 class MissingFieldError(Exception):
     """
     Exception raised when a missing field is encountered by a MelusineTransformer
diff --git a/melusine/regex/__init__.py b/melusine/regex/__init__.py
index f27b8a8..59ab848 100644
--- a/melusine/regex/__init__.py
+++ b/melusine/regex/__init__.py
@@ -2,10 +2,11 @@
 The melusine.regex module includes tools for handling regexes.
 """
 
+from melusine.regex.dissatisfaction_regex import DissatisfactionRegex
 from melusine.regex.emergency_regex import EmergencyRegex
 from melusine.regex.reply_regex import ReplyRegex
 from melusine.regex.thanks_regex import ThanksRegex
 from melusine.regex.transfer_regex import TransferRegex
 from melusine.regex.vacation_reply_regex import VacationReplyRegex
 
-__all__ = ["EmergencyRegex", "ReplyRegex", "ThanksRegex", "TransferRegex", "VacationReplyRegex"]
+__all__ = ["EmergencyRegex", "ReplyRegex", "ThanksRegex", "TransferRegex", "VacationReplyRegex", "DissatisfactionRegex"]
diff --git a/melusine/regex/dissatisfaction_regex.py b/melusine/regex/dissatisfaction_regex.py
new file mode 100644
index 0000000..056049d
--- /dev/null
+++ b/melusine/regex/dissatisfaction_regex.py
@@ -0,0 +1,61 @@
+from typing import Dict, List, Optional, Union
+
+from melusine.base import MelusineRegex
+
+
+class DissatisfactionRegex(MelusineRegex):
+    """
+    Detect dissatisfaction patterns such as "insatisfait" or "j'en ai marre".
+    """
+
+    @property
+    def positive(self) -> Union[str, Dict[str, str]]:
+        """
+        Define regex patterns required to activate the MelusineRegex.
+
+        Returns:
+            _: Regex pattern or dict of regex patterns.
+        """
+        return r"\b(j'en ai marre|insatisfait|c'est nul|trop déçu|décevant|inadmissible|insupportable|intolérable|honteux|lamentable|catastrophe)\b"
+
+    @property
+    def neutral(self) -> Optional[Union[str, Dict[str, str]]]:
+        """
+        Define regex patterns to be ignored when running detection.
+
+        Returns:
+            _: Regex pattern or dict of regex patterns.
+        """
+        return None
+
+    @property
+    def negative(self) -> Optional[Union[str, Dict[str, str]]]:
+        """
+        Define regex patterns prohibited to activate the MelusineRegex.
+
+        Returns:
+            _: Regex pattern or dict of regex patterns.
+        """
+        return None
+
+    @property
+    def match_list(self) -> List[str]:
+        """
+        List of texts that should activate the MelusineRegex.
+
+        Returns:
+            _: List of texts.
+        """
+        return [
+            "complétement insatisfait de ce que vous faites",
+        ]
+
+    @property
+    def no_match_list(self) -> List[str]:
+        """
+        List of texts that should NOT activate the MelusineRegex.
+
+        Returns:
+            _: List of texts.
+ """ + return [] diff --git a/pyproject.toml b/pyproject.toml index 085cb93..cc8ef01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,16 +38,18 @@ classifiers = [ dependencies = [ "arrow", "pandas>2", - "scikit-learn>=1", + "scikit-learn<1.6", "tqdm>=4.34", "omegaconf>=2.0", ] + dynamic = ["version"] [project.optional-dependencies] # Optional -dev = ["tox", "pre-commit", "black", "flake8", "isort", "mypy", "pytest", "coverage", "build", "ruff"] +dev = ["tox", "pre-commit", "black", "flake8", "isort", "mypy", "pytest", "coverage", "build", "ruff" ] test = ["pytest", "coverage", "pytest-cov", "google-auth-oauthlib", "google-api-python-client"] -transformers = ["transformers>4"] +transformers = ["transformers>4" ] +torch-cpu = ["torch>=2.0.0"] connectors = ["exchangelib", "google-auth-oauthlib", "google-api-python-client"] docs = ["mkdocs", "markdown", "mkdocs-material", "mdx-include"] diff --git a/tests/conftest.py b/tests/conftest.py index 678606f..0ae7903 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,6 +15,7 @@ "tests.fixtures.docs", "tests.fixtures.pipelines", "tests.fixtures.processors", + "tests.huggingface.models", ] diff --git a/tests/huggingface/models.py b/tests/huggingface/models.py new file mode 100644 index 0000000..f8fa80c --- /dev/null +++ b/tests/huggingface/models.py @@ -0,0 +1,55 @@ +from unittest.mock import MagicMock, patch + +import pytest +import torch + +from hugging_face.detectors import DissatisfactionDetector +from hugging_face.models.model import TextClassifier + + +def return_value(resp, content): + return content + + +@pytest.fixture +def mock_tokenizer(): + tokenizer = MagicMock() + tokenizer.return_value = {"input_ids": [[101, 102]], "attention_mask": [[1, 1]]} + return tokenizer + + +@pytest.fixture +def mock_model(): + model = MagicMock() + model.return_value.logits = torch.tensor([[0.1, 0.9]]) # Simulated logits + return model + + +@pytest.fixture +def mock_detector(mock_tokenizer, mock_model): + with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer): + with patch("transformers.AutoModelForSequenceClassification.from_pretrained", return_value=mock_model): + # Create a TextClassifier instance + classifier = TextClassifier( + tokenizer_name_or_path="mock_tokenizer", + model_name_or_path="mock_model", + token=None, + ) + + # Create the DissatisfactionDetector using the mock classifier + detector = DissatisfactionDetector( + name="dissatisfaction", + text_column="det_normalized_last_body", + model_name_or_path="mock_model_path", + tokenizer_name_or_path="mock_tokenizer_path", + token=None, + ) + detector.melusine_model = classifier + return detector + + +# Example test using the mock_detector fixture +def test_mock_detector_instantiation(mock_detector): + assert isinstance(mock_detector, DissatisfactionDetector) + assert mock_detector.name == "dissatisfaction" + assert mock_detector.text_column == "det_normalized_last_body" diff --git a/tests/huggingface/test_dissatisfaction_detector.py b/tests/huggingface/test_dissatisfaction_detector.py new file mode 100644 index 0000000..94cdcc5 --- /dev/null +++ b/tests/huggingface/test_dissatisfaction_detector.py @@ -0,0 +1,147 @@ +""" +Unit tests of the DissatisfactionDetector +The model used inside of the detector is mocked in the fixtures tests +""" + +from unittest.mock import MagicMock, patch + +import pytest +from pandas import DataFrame + +from hugging_face.detectors import DissatisfactionDetector +from hugging_face.models.model import TextClassifier + + 
+@pytest.mark.usefixtures("mock_detector") +def test_instantiation(mock_detector): + """Test that the mock detector is instantiated correctly.""" + assert isinstance(mock_detector, DissatisfactionDetector) + assert mock_detector.name == "dissatisfaction" + assert mock_detector.text_column == "det_normalized_last_body" + + +@pytest.mark.usefixtures("mock_detector") +@pytest.mark.parametrize( + "row, good_deterministic_result", + [ + ( + {"det_normalized_last_body": "je suis content de votre service."}, + False, + ), + ( + {"det_normalized_last_body": "je suis complètement insatisfait de votre service."}, + True, + ), + ( + { + "det_normalized_last_body": "Franchement, j'en ai marre de ce genre de service qui ne respecte pas ses engagements." + }, + True, + ), + ( + {"det_normalized_last_body": "Je suis trop déçu par la qualité, je m'attendais à bien mieux pour ce prix."}, + True, + ), + ( + {"det_normalized_last_body": "C'est vraiment décevant de voir un tel manque de professionnalisme."}, + True, + ), + ], +) +def test_by_regex_detect(row, good_deterministic_result, mock_detector): + """Unit test of the transform() method.""" + df_copy = row.copy() + df_copy = mock_detector.pre_detect(df_copy, debug_mode=True) + df_copy = mock_detector.by_regex_detect(df_copy, debug_mode=True) + + deterministic_result = mock_detector.DISSATISFACTION_BY_REGEX_MATCH_COL + deterministic_debug_result = mock_detector.debug_dict_col + + assert deterministic_result in df_copy.keys() + assert deterministic_debug_result in df_copy.keys() + assert df_copy[deterministic_result] == good_deterministic_result + + +@pytest.mark.usefixtures("mock_detector") +@pytest.mark.parametrize( + "row, good_ml_result", + [ + ( + {"det_normalized_last_body": "je suis complètement insatisfait de votre service."}, + True, + ), + ( + { + "det_normalized_last_body": "Un service médiocre, avec des frais cachés qui ont presque doublé le coût final. Je ne ferai plus appel à eux." + }, + True, + ), + ( + { + "det_normalized_last_body": "Très déçu. L’article ne correspond pas du tout à la description, et la qualité laisse à désirer." + }, + True, + ), + ], +) +def test_by_ml_detection(row, good_ml_result, mock_detector): + """Unit test of the transform() method.""" + df_copy = row.copy() + # Test result + df_copy = mock_detector.pre_detect(df_copy, debug_mode=True) + df_copy = mock_detector.by_ml_detect(df_copy, debug_mode=True) + + # Test result + ml_result_col = mock_detector.DISSATISFACTION_ML_MATCH_COL + ml_score_col = mock_detector.DISSATISFACTION_ML_SCORE_COL + + assert ml_result_col in df_copy.keys() + assert ml_score_col in df_copy.keys() + assert df_copy[ml_result_col] == good_ml_result + assert isinstance(df_copy[ml_score_col], float) + assert df_copy[ml_score_col] > 0.5 + + +@pytest.mark.usefixtures("mock_detector") +@pytest.mark.parametrize( + "df, good_result", + [ + ( + DataFrame( + { + "det_normalized_last_body": ["je suis complètement insatisfait de votre service."], + } + ), + True, + ), + ( + DataFrame( + { + "det_normalized_last_body": [ + "Ce retard est une véritable catastrophe, cela m'a causé beaucoup de problèmes." + ], + } + ), + True, + ), + ( + DataFrame( + { + "det_normalized_last_body": [ + "Le traitement que j'ai reçu est honteux, surtout venant d'une entreprise comme la vôtre." 
+ ], + } + ), + True, + ), + ], +) +def test_by_transform_detection(df, good_result, mock_detector): + """Unit test of the transform() method.""" + df_copy = df.copy() + # Test result + df_copy = mock_detector.transform(df_copy) + # Test result + result_col = mock_detector.result_column + assert result_col in df_copy.keys() + assert bool(df_copy[result_col][0]) == good_result diff --git a/tox.ini b/tox.ini index a94c341..8d43513 100644 --- a/tox.ini +++ b/tox.ini @@ -1,12 +1,14 @@ [tox] requires = tox>=4 -env_list = clean, core38, core310, transformers, report +env_list = clean, core38, core310, core311, transformers, report [gh-actions] python = 3.8: clean, core38, transformers 3.10: core310 + 3.11: core311, transformers + [testenv] commands = pytest --cov --cov-append --cov-report xml @@ -15,9 +17,13 @@ deps = pytest-cov google-auth-oauthlib google-api-python-client + torch depends = {core38,transformers}: clean - report: core38,transformers + report: core38 +extras = transformers + + [testenv:core38] deps={[testenv]deps} @@ -27,6 +33,10 @@ commands={[testenv]commands} deps={[testenv]deps} commands=pytest tests +[testenv:core311] +deps={[testenv]deps} +commands=pytest tests + [testenv:clean] deps = coverage[toml] skip_install = true @@ -35,7 +45,7 @@ commands = coverage erase [testenv:transformers] description = run unit tests with the transformers dependency deps={[testenv]deps} -commands = pytest tests/huggingface --cov --cov-append --cov-report xml +commands = pytest tests/huggingface extras = transformers [testenv:report]