From 4529ca1297cc2204068655298913074507740e5f Mon Sep 17 00:00:00 2001 From: Hugo Perrier Date: Tue, 10 Dec 2024 17:17:44 +0100 Subject: [PATCH] :alien: Add a fit method to MelusineTransformer and BaseMelusineDetector --- melusine/base.py | 34 ++++++++++++++++++++++++++++++++++ tests/gmail/test_gmail.py | 11 ++++++----- 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/melusine/base.py b/melusine/base.py index b7cd199..016a910 100644 --- a/melusine/base.py +++ b/melusine/base.py @@ -99,6 +99,23 @@ def parse_column_list(columns: str | Iterable[str]) -> list[str]: columns = [columns] return list(columns) + def fit(self, X: MelusineDataset, y: Any = None) -> MelusineTransformer: + """A reference implementation of a fitting function. + + Parameters + ---------- + X : The training input samples. + + y : The target values (class labels in classification, real numbers in + regression). + + Returns + ------- + self : object + Returns self. + """ + return self + def transform(self, data: MelusineDataset) -> MelusineDataset: """ Transform input data. @@ -196,6 +213,23 @@ def transform_methods(self) -> list[Callable]: List of methods to be called by the transform method. """ + def fit(self, X: MelusineDataset, y: Any = None) -> MelusineTransformer: + """A reference implementation of a fitting function. + + Parameters + ---------- + X : The training input samples. + + y : The target values (class labels in classification, real numbers in + regression). + + Returns + ------- + self : object + Returns self. + """ + return self + def transform(self, df: MelusineDataset) -> MelusineDataset: """ Re-definition of super().transform() => specific detector's implementation diff --git a/tests/gmail/test_gmail.py b/tests/gmail/test_gmail.py index d074684..107f038 100644 --- a/tests/gmail/test_gmail.py +++ b/tests/gmail/test_gmail.py @@ -4,10 +4,11 @@ import pandas as pd import pytest -googleapiclient = pytest.importorskip("googleapiclient") -from google.oauth2.credentials import Credentials from unittest.mock import MagicMock, patch +google = pytest.importorskip("google") +googleapiclient = pytest.importorskip("googleapiclient") + from melusine.connectors.gmail import GmailConnector @@ -43,7 +44,7 @@ def mocked_gc(): return_value, ) mock_build.return_value = mock_service - mock_creds_from_file.return_value = Credentials("dummy") + mock_creds_from_file.return_value = google.oauth2.credentials.Credentials("dummy") return GmailConnector(token_json_path="token.json", done_label="TRASH", target_column="target") @@ -91,7 +92,7 @@ def test_init(mock_exists, mock_creds_from_file, mock_build, caplog): return_value, ) mock_build.return_value = mock_service - mock_creds_from_file.return_value = Credentials("dummy") + mock_creds_from_file.return_value = google.oauth2.credentials.Credentials("dummy") # Creating an instance of GmailConnector with caplog.at_level(logging.DEBUG): @@ -134,7 +135,7 @@ def test_init_without_creds(mock_flow, mock_build, caplog): return_value, ) mock_build.return_value = mock_service - mock_flow.return_value.run_local_server.return_value = Credentials("dummy") + mock_flow.return_value.run_local_server.return_value = google.oauth2.credentials.Credentials("dummy") # Creating an instance of GmailConnector with caplog.at_level(logging.DEBUG):