diff --git a/README.rst b/README.rst index a6bce47..8a4bdb3 100644 --- a/README.rst +++ b/README.rst @@ -3,4 +3,4 @@ pylint-ml About ----- -``pylint-ml`` is a pylint plugin for enhancing code analysis for machine learning and data science +``pylint-ml`` is a pylint plugin for enhancing code analysis for machine learning and data science projects diff --git a/pylint_ml/checkers/config.py b/pylint_ml/checkers/config.py new file mode 100644 index 0000000..4ea6082 --- /dev/null +++ b/pylint_ml/checkers/config.py @@ -0,0 +1,16 @@ +# Library names +PANDAS = "pandas" +PANDAS_ALIAS = "pd" + +NUMPY = "numpy" +NUMPY_ALIAS = "np" + +TENSORFLOW = "tensor" + +SCIPY = "scipy" + +SKLEARN = "sklearn" + +TORCH = "torch" + +MATPLOTLIB = "matplotlib" diff --git a/pylint_ml/checkers/library_base_checker.py b/pylint_ml/checkers/library_base_checker.py new file mode 100644 index 0000000..fef5145 --- /dev/null +++ b/pylint_ml/checkers/library_base_checker.py @@ -0,0 +1,46 @@ +from importlib.metadata import PackageNotFoundError, version + +from pylint.checkers import BaseChecker + + +class LibraryBaseChecker(BaseChecker): + + def __init__(self, linter): + super().__init__(linter) + self.imports = {} + + def visit_import(self, node): + for name, alias in node.names: + self.imports[alias or name] = name # E.g. {'pd': 'pandas'} + + def visit_importfrom(self, node): + base_module = node.modname.split(".")[0] # Extract the first part of the module name + + for name, alias in node.names: + full_name = f"{node.modname}.{name}" + self.imports[base_module] = full_name # E.g. {'scipy': 'scipy.optimize.minimize'} + + def is_library_imported_and_version_valid(self, lib_name, required_version): + """ + Checks if the library is imported and whether the installed version is valid (greater than or equal to the + required version). + + param lib_name: Name of the library (as a string). + param required_version: The required minimum version (as a string). + return: True if the library is imported and the version is valid, otherwise False. + """ + # Check if the library is imported + if not any(mod.startswith(lib_name) for mod in self.imports.values()): + return False + + # Check if the library version is valid + try: + installed_version = version(lib_name) + except PackageNotFoundError: + return False + + # Compare versions (this assumes versioning follows standard conventions like '1.2.3') + if required_version is not None and installed_version < required_version: + return False + + return True diff --git a/pylint_ml/checkers/matplotlib/matplotlib_parameter.py b/pylint_ml/checkers/matplotlib/matplotlib_parameter.py index d19fbfe..e630dcf 100644 --- a/pylint_ml/checkers/matplotlib/matplotlib_parameter.py +++ b/pylint_ml/checkers/matplotlib/matplotlib_parameter.py @@ -8,10 +8,12 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.library_handler import LibraryHandler +from pylint_ml.checkers.config import MATPLOTLIB +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import get_full_method_name -class MatplotlibParameterChecker(LibraryHandler): +class MatplotlibParameterChecker(LibraryBaseChecker): name = "matplotlib-parameter" msgs = { "W8111": ( @@ -47,11 +49,10 @@ class MatplotlibParameterChecker(LibraryHandler): @only_required_for_messages("matplotlib-parameter") def visit_call(self, node: nodes.Call) -> None: - # TODO Update - # if not self.is_library_imported('matplotlib') and self.is_library_version_valid(lib_version=): - # return + if not self.is_library_imported_and_version_valid(lib_name=MATPLOTLIB, required_version=None): + return - method_name = self._get_full_method_name(node) + method_name = get_full_method_name(node=node) if method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] @@ -62,15 +63,3 @@ def visit_call(self, node: nodes.Call) -> None: confidence=HIGH, args=(", ".join(missing_params), method_name), ) - - def _get_full_method_name(self, node: nodes.Call) -> str: - func = node.func - method_chain = [] - - while isinstance(func, nodes.Attribute): - method_chain.insert(0, func.attrname) - func = func.expr - if isinstance(func, nodes.Name): - method_chain.insert(0, func.name) - - return ".".join(method_chain) diff --git a/pylint_ml/checkers/numpy/numpy_dot.py b/pylint_ml/checkers/numpy/numpy_dot.py index 955656f..3c905d1 100644 --- a/pylint_ml/checkers/numpy/numpy_dot.py +++ b/pylint_ml/checkers/numpy/numpy_dot.py @@ -10,10 +10,12 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.library_handler import LibraryHandler +from pylint_ml.checkers.config import NUMPY +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import infer_specific_module_from_call -class NumpyDotChecker(LibraryHandler): +class NumpyDotChecker(LibraryBaseChecker): name = "numpy-dot-checker" msgs = { "W8122": ( @@ -24,19 +26,16 @@ class NumpyDotChecker(LibraryHandler): ), } - def visit_import(self, node: nodes.Import): - super().visit_import(node=node) - @only_required_for_messages("numpy-dot-usage") def visit_call(self, node: nodes.Call) -> None: - if not self.is_library_imported("numpy"): + if not self.is_library_imported_and_version_valid(lib_name=NUMPY, required_version=None): return # Check if the function being called is np.dot - if isinstance(node.func, nodes.Attribute): - func_name = node.func.attrname - module_name = getattr(node.func.expr, "name", None) - - if func_name == "dot" and module_name == "np": - # Suggest using np.matmul() instead - self.add_message("numpy-dot-usage", node=node, confidence=HIGH) + if ( + isinstance(node.func, nodes.Attribute) + and node.func.attrname == "dot" + and infer_specific_module_from_call(node=node, module_name=NUMPY) + ): + # Suggest using np.matmul() instead + self.add_message("numpy-dot-usage", node=node, confidence=HIGH) diff --git a/pylint_ml/checkers/numpy/numpy_nan_comparison.py b/pylint_ml/checkers/numpy/numpy_nan_comparison.py index 4a1a7ad..7eedd4b 100644 --- a/pylint_ml/checkers/numpy/numpy_nan_comparison.py +++ b/pylint_ml/checkers/numpy/numpy_nan_comparison.py @@ -7,15 +7,18 @@ from __future__ import annotations from astroid import nodes -from pylint.checkers import BaseChecker from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH +from pylint_ml.checkers.config import NUMPY +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import infer_specific_module_from_attribute + COMPARISON_OP = frozenset(("<", "<=", ">", ">=", "!=", "==")) NUMPY_NAN = frozenset(("nan", "NaN", "NAN")) -class NumpyNaNComparisonChecker(BaseChecker): +class NumpyNaNComparisonChecker(LibraryBaseChecker): name = "numpy-nan-compare" msgs = { "W8001": ( @@ -28,15 +31,19 @@ class NumpyNaNComparisonChecker(BaseChecker): @classmethod def __is_np_nan_call(cls, node: nodes.Attribute) -> bool: """Check if the node represents a call to np.nan.""" - return node.attrname in NUMPY_NAN and isinstance(node.expr, nodes.Name) and node.expr.name == "np" + return node.attrname in NUMPY_NAN and (infer_specific_module_from_attribute(node=node, module_name=NUMPY)) @only_required_for_messages("numpy-nan-compare") def visit_compare(self, node: nodes.Compare) -> None: + if not self.is_library_imported_and_version_valid(lib_name=NUMPY, required_version=None): + return + # Check node.left first for numpy nan usage if isinstance(node.left, nodes.Attribute) and self.__is_np_nan_call(node.left): self.add_message("numpy-nan-compare", node=node, confidence=HIGH) return + # Check remaining nodes and operators for numpy nan usage for op, comparator in node.ops: if op in COMPARISON_OP and isinstance(comparator, nodes.Attribute) and self.__is_np_nan_call(comparator): self.add_message("numpy-nan-compare", node=node, confidence=HIGH) diff --git a/pylint_ml/checkers/numpy/numpy_parameter.py b/pylint_ml/checkers/numpy/numpy_parameter.py index e045e83..180fa84 100644 --- a/pylint_ml/checkers/numpy/numpy_parameter.py +++ b/pylint_ml/checkers/numpy/numpy_parameter.py @@ -5,12 +5,15 @@ """Check for proper usage of numpy functions with required parameters.""" from astroid import nodes -from pylint.checkers import BaseChecker from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH +from pylint_ml.checkers.config import NUMPY +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import infer_specific_module_from_call -class NumPyParameterChecker(BaseChecker): + +class NumPyParameterChecker(LibraryBaseChecker): name = "numpy-parameter" msgs = { "W8111": ( @@ -33,12 +36,12 @@ class NumPyParameterChecker(BaseChecker): "eye": ["N"], "identity": ["n"], # Random Sampling - "random.rand": ["d0"], - "random.randn": ["d0"], - "random.randint": ["low", "high"], - "random.choice": ["a"], - "random.uniform": ["low", "high"], - "random.normal": ["loc", "scale"], + "rand": ["d0"], + "randn": ["d0"], + "randint": ["low", "high"], + "choice": ["a"], + "uniform": ["low", "high"], + "normal": ["loc", "scale"], # Mathematical Functions "sum": ["a"], "mean": ["a"], @@ -59,9 +62,9 @@ class NumPyParameterChecker(BaseChecker): # Linear Algebra "dot": ["a", "b"], "matmul": ["a", "b"], - "linalg.inv": ["a"], - "linalg.eig": ["a"], - "linalg.solve": ["a", "b"], + "inv": ["a"], + "eig": ["a"], + "solve": ["a", "b"], # Statistical Functions "percentile": ["a", "q"], "quantile": ["a", "q"], @@ -71,11 +74,12 @@ class NumPyParameterChecker(BaseChecker): @only_required_for_messages("numpy-parameter") def visit_call(self, node: nodes.Call) -> None: - method_name = self._get_full_method_name(node) + if not self.is_library_imported_and_version_valid(lib_name=NUMPY, required_version=None): + return - if method_name in self.REQUIRED_PARAMS: + method_name = getattr(node.func, "attrname", None) + if infer_specific_module_from_call(node=node, module_name=NUMPY) and method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} - # Collect all missing parameters missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] if missing_params: self.add_message( @@ -84,21 +88,3 @@ def visit_call(self, node: nodes.Call) -> None: confidence=HIGH, args=(", ".join(missing_params), method_name), ) - - @staticmethod - def _get_full_method_name(node: nodes.Call) -> str: - """ - Extracts the full method name, including chained attributes (e.g., np.random.rand). - """ - func = node.func - method_chain = [] - - # Traverse the attribute chain - while isinstance(func, nodes.Attribute): - method_chain.insert(0, func.attrname) - func = func.expr - - # Check if the root of the chain is "np" (as NumPy functions are expected to use np. prefix) - if isinstance(func, nodes.Name) and func.name == "np": - return ".".join(method_chain) - return "" diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_bool.py b/pylint_ml/checkers/pandas/pandas_dataframe_bool.py index a519ca8..693385d 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_bool.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_bool.py @@ -7,14 +7,15 @@ from __future__ import annotations from astroid import nodes -from pylint.checkers import BaseChecker from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -# Todo add version deprecated +from pylint_ml.checkers.config import PANDAS +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import infer_specific_module_from_call -class PandasDataFrameBoolChecker(BaseChecker): +class PandasDataFrameBoolChecker(LibraryBaseChecker): name = "pandas-dataframe-bool" msgs = { "W8104": ( @@ -26,13 +27,19 @@ class PandasDataFrameBoolChecker(BaseChecker): @only_required_for_messages("pandas-dataframe-bool") def visit_call(self, node: nodes.Call) -> None: + if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version="2.1.0"): + return + if isinstance(node.func, nodes.Attribute): method_name = getattr(node.func, "attrname", None) - if method_name == "bool": # Check if the object calling .bool() has a name starting with 'df_' object_name = getattr(node.func.expr, "name", None) - if object_name and self._is_valid_dataframe_name(object_name): + if ( + infer_specific_module_from_call(node=node, module_name=PANDAS) + and object_name + and self._is_valid_dataframe_name(object_name) + ): self.add_message("pandas-dataframe-bool", node=node, confidence=HIGH) @staticmethod diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_column_selection.py b/pylint_ml/checkers/pandas/pandas_dataframe_column_selection.py index 3f3e480..8a92de9 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_column_selection.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_column_selection.py @@ -7,12 +7,15 @@ from __future__ import annotations from astroid import nodes -from pylint.checkers import BaseChecker from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH +from pylint_ml.checkers.config import PANDAS +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import infer_specific_module_from_attribute -class PandasColumnSelectionChecker(BaseChecker): + +class PandasColumnSelectionChecker(LibraryBaseChecker): name = "pandas-column-selection" msgs = { "W8118": ( @@ -25,6 +28,14 @@ class PandasColumnSelectionChecker(BaseChecker): @only_required_for_messages("pandas-column-selection") def visit_attribute(self, node: nodes.Attribute) -> None: """Check for attribute access that might be a column selection.""" - if isinstance(node.expr, nodes.Name) and node.expr.name.startswith("df_"): + + if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version=None): + return + + if ( + infer_specific_module_from_attribute(node=node, module_name=PANDAS) + and isinstance(node.expr, nodes.Name) + and node.expr.name.startswith("df_") + ): # Issue a warning for property-like access self.add_message("pandas-column-selection", node=node, confidence=HIGH) diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_empty_column.py b/pylint_ml/checkers/pandas/pandas_dataframe_empty_column.py index 3427f1b..b8595f5 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_empty_column.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_empty_column.py @@ -7,12 +7,14 @@ from __future__ import annotations from astroid import nodes -from pylint.checkers import BaseChecker -from pylint.checkers.utils import only_required_for_messages +from pylint.checkers.utils import only_required_for_messages, safe_infer from pylint.interfaces import HIGH +from pylint_ml.checkers.config import PANDAS +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker -class PandasEmptyColumnChecker(BaseChecker): + +class PandasEmptyColumnChecker(LibraryBaseChecker): name = "pandas-dataframe-empty-column" msgs = { "W8113": ( @@ -25,7 +27,15 @@ class PandasEmptyColumnChecker(BaseChecker): @only_required_for_messages("pandas-dataframe-empty-column") def visit_subscript(self, node: nodes.Subscript) -> None: - if isinstance(node.value, nodes.Name) and node.value.name.startswith("df_"): + if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version=None): + return + + if ( + isinstance(node.value, nodes.Name) + and node.value.name.startswith("df_") + and PANDAS in safe_infer(node.value).qname() + ): + print(node.value.name) if isinstance(node.slice, nodes.Const) and isinstance(node.parent, nodes.Assign): if isinstance(node.parent.value, nodes.Const): # Checking for filler values: 0 or empty string diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_iterrows.py b/pylint_ml/checkers/pandas/pandas_dataframe_iterrows.py index 99d2b0c..e9871bc 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_iterrows.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_iterrows.py @@ -7,12 +7,15 @@ from __future__ import annotations from astroid import nodes -from pylint.checkers import BaseChecker from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH +from pylint_ml.checkers.config import PANDAS +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import infer_specific_module_from_attribute -class PandasIterrowsChecker(BaseChecker): + +class PandasIterrowsChecker(LibraryBaseChecker): name = "pandas-iterrows" msgs = { "W8106": ( @@ -25,7 +28,12 @@ class PandasIterrowsChecker(BaseChecker): @only_required_for_messages("pandas-iterrows") def visit_call(self, node: nodes.Call) -> None: - if isinstance(node.func, nodes.Attribute): + if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version=None): + return + + if isinstance(node.func, nodes.Attribute) and infer_specific_module_from_attribute( + node=node.func, module_name=PANDAS + ): method_name = getattr(node.func, "attrname", None) if method_name == "iterrows": object_name = getattr(node.func.expr, "name", None) diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_naming.py b/pylint_ml/checkers/pandas/pandas_dataframe_naming.py index a0aaf2d..922e1a4 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_naming.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_naming.py @@ -7,12 +7,15 @@ from __future__ import annotations from astroid import nodes -from pylint.checkers import BaseChecker from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH +from pylint_ml.checkers.config import PANDAS +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import infer_specific_module_from_call -class PandasDataFrameNamingChecker(BaseChecker): + +class PandasDataFrameNamingChecker(LibraryBaseChecker): name = "pandas-dataframe-naming" msgs = { "W8103": ( @@ -24,11 +27,18 @@ class PandasDataFrameNamingChecker(BaseChecker): @only_required_for_messages("pandas-dataframe-naming") def visit_assign(self, node: nodes.Assign) -> None: + if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version=None): + return + if isinstance(node.value, nodes.Call): func_name = getattr(node.value.func, "attrname", None) module_name = getattr(node.value.func.expr, "name", None) - if func_name == "DataFrame" and module_name == "pd": + if ( + func_name == "DataFrame" + and module_name == "pd" + and infer_specific_module_from_call(node=node.value, module_name=PANDAS) + ): for target in node.targets: if isinstance(target, nodes.AssignName): var_name = target.name diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_values.py b/pylint_ml/checkers/pandas/pandas_dataframe_values.py index 13b382f..5651131 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_values.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_values.py @@ -7,12 +7,14 @@ from __future__ import annotations from astroid import nodes -from pylint.checkers import BaseChecker -from pylint.checkers.utils import only_required_for_messages +from pylint.checkers.utils import only_required_for_messages, safe_infer from pylint.interfaces import HIGH +from pylint_ml.checkers.config import PANDAS +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker -class PandasValuesChecker(BaseChecker): + +class PandasValuesChecker(LibraryBaseChecker): name = "pandas-dataframe-values" msgs = { "W8112": ( @@ -25,6 +27,13 @@ class PandasValuesChecker(BaseChecker): @only_required_for_messages("pandas-dataframe-values") def visit_attribute(self, node: nodes.Attribute) -> None: + if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version=None): + return + if isinstance(node.expr, nodes.Name): - if node.attrname == "values" and node.expr.name.startswith("df_"): + if ( + node.attrname == "values" + and node.expr.name.startswith("df_") + and PANDAS in safe_infer(node.expr).qname() + ): self.add_message("pandas-dataframe-values", node=node, confidence=HIGH) diff --git a/pylint_ml/checkers/pandas/pandas_inplace.py b/pylint_ml/checkers/pandas/pandas_inplace.py index c1eded6..7a13142 100644 --- a/pylint_ml/checkers/pandas/pandas_inplace.py +++ b/pylint_ml/checkers/pandas/pandas_inplace.py @@ -7,12 +7,15 @@ from __future__ import annotations from astroid import nodes -from pylint.checkers import BaseChecker from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH +from pylint_ml.checkers.config import PANDAS +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import infer_specific_module_from_attribute -class PandasInplaceChecker(BaseChecker): + +class PandasInplaceChecker(LibraryBaseChecker): name = "pandas-inplace" msgs = { "W8109": ( @@ -39,8 +42,13 @@ class PandasInplaceChecker(BaseChecker): @only_required_for_messages("pandas-inplace") def visit_call(self, node: nodes.Call) -> None: + if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version=None): + return + # Check if the call is to a method that supports 'inplace' - if isinstance(node.func, nodes.Attribute): + if isinstance(node.func, nodes.Attribute) and infer_specific_module_from_attribute( + node=node.func, module_name=PANDAS + ): method_name = node.func.attrname if method_name in self._inplace_methods: for keyword in node.keywords: diff --git a/pylint_ml/checkers/pandas/pandas_parameter.py b/pylint_ml/checkers/pandas/pandas_parameter.py index 3efee90..36067d9 100644 --- a/pylint_ml/checkers/pandas/pandas_parameter.py +++ b/pylint_ml/checkers/pandas/pandas_parameter.py @@ -5,12 +5,15 @@ """Check for proper usage of Pandas functions with required parameters.""" from astroid import nodes -from pylint.checkers import BaseChecker from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH +from pylint_ml.checkers.config import PANDAS +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import infer_specific_module_from_call -class PandasParameterChecker(BaseChecker): + +class PandasParameterChecker(LibraryBaseChecker): name = "pandas-parameter" msgs = { "W8111": ( @@ -64,10 +67,12 @@ class PandasParameterChecker(BaseChecker): @only_required_for_messages("pandas-parameter") def visit_call(self, node: nodes.Call) -> None: - method_name = self._get_method_name(node) - if method_name in self.REQUIRED_PARAMS: + if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version=None): + return + + method_name = getattr(node.func, "attrname", None) + if infer_specific_module_from_call(node=node, module_name=PANDAS) and method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} - # Collect all missing parameters missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] if missing_params: self.add_message( @@ -76,15 +81,3 @@ def visit_call(self, node: nodes.Call) -> None: confidence=HIGH, args=(", ".join(missing_params), method_name), ) - - @staticmethod - def _get_method_name(node: nodes.Call) -> str: - """Extracts the method name from a Call node, including handling chained calls.""" - func = node.func - while isinstance(func, nodes.Attribute): - func = func.expr - return ( - node.func.attrname - if isinstance(node.func, nodes.Attribute) - else func.name if isinstance(func, nodes.Name) else "" - ) diff --git a/pylint_ml/checkers/pandas/pandas_series_bool.py b/pylint_ml/checkers/pandas/pandas_series_bool.py index dafac68..eaac96d 100644 --- a/pylint_ml/checkers/pandas/pandas_series_bool.py +++ b/pylint_ml/checkers/pandas/pandas_series_bool.py @@ -7,14 +7,16 @@ from __future__ import annotations from astroid import nodes -from pylint.checkers import BaseChecker from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH # Todo add version deprecated +from pylint_ml.checkers.config import PANDAS +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import infer_specific_module_from_call -class PandasSeriesBoolChecker(BaseChecker): +class PandasSeriesBoolChecker(LibraryBaseChecker): name = "pandas-series-bool" msgs = { "W8105": ( @@ -26,7 +28,10 @@ class PandasSeriesBoolChecker(BaseChecker): @only_required_for_messages("pandas-series-bool") def visit_call(self, node: nodes.Call) -> None: - if isinstance(node.func, nodes.Attribute): + if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version=None): + return + + if isinstance(node.func, nodes.Attribute) and infer_specific_module_from_call(node=node, module_name=PANDAS): method_name = getattr(node.func, "attrname", None) if method_name == "bool": diff --git a/pylint_ml/checkers/pandas/pandas_series_naming.py b/pylint_ml/checkers/pandas/pandas_series_naming.py index 8e5e3e2..010b83e 100644 --- a/pylint_ml/checkers/pandas/pandas_series_naming.py +++ b/pylint_ml/checkers/pandas/pandas_series_naming.py @@ -7,12 +7,15 @@ from __future__ import annotations from astroid import nodes -from pylint.checkers import BaseChecker from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH +from pylint_ml.checkers.config import PANDAS +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import infer_specific_module_from_call -class PandasSeriesNamingChecker(BaseChecker): + +class PandasSeriesNamingChecker(LibraryBaseChecker): name = "pandas-series-naming" msgs = { "W8103": ( @@ -24,7 +27,10 @@ class PandasSeriesNamingChecker(BaseChecker): @only_required_for_messages("pandas-series-naming") def visit_assign(self, node: nodes.Assign) -> None: - if isinstance(node.value, nodes.Call): + if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version=None): + return + + if isinstance(node.value, nodes.Call) and infer_specific_module_from_call(node=node.value, module_name=PANDAS): func_name = getattr(node.value.func, "attrname", None) module_name = getattr(node.value.func.expr, "name", None) diff --git a/pylint_ml/checkers/scipy/scipy_parameter.py b/pylint_ml/checkers/scipy/scipy_parameter.py index 9b9464e..e754d25 100644 --- a/pylint_ml/checkers/scipy/scipy_parameter.py +++ b/pylint_ml/checkers/scipy/scipy_parameter.py @@ -5,12 +5,14 @@ """Check for proper usage of Scipy functions with required parameters.""" from astroid import nodes -from pylint.checkers import BaseChecker from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH +from pylint_ml.checkers.config import SCIPY +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker -class ScipyParameterChecker(BaseChecker): + +class ScipyParameterChecker(LibraryBaseChecker): name = "scipy-parameter" msgs = { "W8111": ( @@ -33,43 +35,48 @@ class ScipyParameterChecker(BaseChecker): # scipy.stats "ttest_ind": ["a", "b"], "ttest_rel": ["a", "b"], - "norm.pdf": ["x"], + "pdf": ["x"], # scipy.spatial - "distance.euclidean": ["u", "v"], # Full chain - "euclidean": ["u", "v"], # Direct import of euclidean - "KDTree.query": ["x"], + "euclidean": ["u", "v"], + "query": ["x"], } @only_required_for_messages("scipy-parameter") def visit_call(self, node: nodes.Call) -> None: - method_name = self._get_full_method_name(node) - if method_name in self.REQUIRED_PARAMS: - provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} - # Collect all missing parameters - missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] - if missing_params: - self.add_message( - "scipy-parameter", - node=node, - confidence=HIGH, - args=(", ".join(missing_params), method_name), - ) + if not self.is_library_imported_and_version_valid(lib_name=SCIPY, required_version=None): + return + + # Determine whether the function is a simple Name (method call) + if isinstance(node.func, nodes.Name): + method_name = node.func.name # For cases like minimize() + else: + return # Exit early - def _get_full_method_name(self, node: nodes.Call) -> str: - """ - Extracts the full method name, including handling chained attributes (e.g., scipy.spatial.distance.euclidean) - and also handles direct imports like euclidean. - """ - func = node.func - method_chain = [] + # Perform a lookup in the current scope for the function/method name + scope = node.scope() + name_lookup = scope.lookup(method_name) - # Traverse the attribute chain to get the full method name - while isinstance(func, nodes.Attribute): - method_chain.insert(0, func.attrname) - func = func.expr + if name_lookup: + _, assignments = name_lookup + if assignments: + assignment = assignments[0] + if isinstance(assignment, nodes.ImportFrom): + # Check if the import is from scipy.optimize + # Correctly unpack the names from the tuple + imported_names = [name for name, _ in assignment.names] - # If it's a direct function name, like `euclidean`, return it - if isinstance(func, nodes.Name): - method_chain.insert(0, func.name) + if SCIPY in assignment.modname and method_name in imported_names: + # Proceed with checking parameters + if method_name in self.REQUIRED_PARAMS: + provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} + missing_params = [ + param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords + ] - return ".".join(method_chain) + if missing_params: + self.add_message( + "scipy-parameter", + node=node, + confidence=HIGH, + args=(", ".join(missing_params), method_name), + ) diff --git a/pylint_ml/checkers/sklearn/sklearn_parameter.py b/pylint_ml/checkers/sklearn/sklearn_parameter.py index c9ef152..b8f3a95 100644 --- a/pylint_ml/checkers/sklearn/sklearn_parameter.py +++ b/pylint_ml/checkers/sklearn/sklearn_parameter.py @@ -5,12 +5,15 @@ """Check for proper usage of Scikit-learn functions with required parameters.""" from astroid import nodes -from pylint.checkers import BaseChecker from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH +from pylint_ml.checkers.config import SKLEARN +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import get_full_method_name -class SklearnParameterChecker(BaseChecker): + +class SklearnParameterChecker(LibraryBaseChecker): name = "sklearn-parameter" msgs = { "W8111": ( @@ -37,10 +40,12 @@ class SklearnParameterChecker(BaseChecker): @only_required_for_messages("sklearn-parameter") def visit_call(self, node: nodes.Call) -> None: - method_name = self._get_method_name(node) + if not self.is_library_imported_and_version_valid(lib_name=SKLEARN, required_version=None): + return + + method_name = get_full_method_name(node) if method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} - # Collect all missing parameters missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] if missing_params: self.add_message( @@ -49,15 +54,3 @@ def visit_call(self, node: nodes.Call) -> None: confidence=HIGH, args=(", ".join(missing_params), method_name), ) - - @staticmethod - def _get_method_name(node: nodes.Call) -> str: - """Extracts the method name from a Call node, including handling chained calls.""" - func = node.func - while isinstance(func, nodes.Attribute): - func = func.expr - return ( - node.func.attrname - if isinstance(node.func, nodes.Attribute) - else func.name if isinstance(func, nodes.Name) else "" - ) diff --git a/pylint_ml/checkers/tensorflow/tensor_parameter.py b/pylint_ml/checkers/tensorflow/tensor_parameter.py index c9a334f..cefdeaa 100644 --- a/pylint_ml/checkers/tensorflow/tensor_parameter.py +++ b/pylint_ml/checkers/tensorflow/tensor_parameter.py @@ -5,12 +5,15 @@ """Check for proper usage of Tensorflow functions with required parameters.""" from astroid import nodes -from pylint.checkers import BaseChecker from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH +from pylint_ml.checkers.config import TENSORFLOW +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import infer_specific_module_from_call -class TensorFlowParameterChecker(BaseChecker): + +class TensorFlowParameterChecker(LibraryBaseChecker): name = "tensor-parameter" msgs = { "W8111": ( @@ -35,10 +38,16 @@ class TensorFlowParameterChecker(BaseChecker): @only_required_for_messages("tensor-parameter") def visit_call(self, node: nodes.Call) -> None: - method_name = self._get_method_name(node) - if method_name in self.REQUIRED_PARAMS: + if not self.is_library_imported_and_version_valid(lib_name=TENSORFLOW, required_version=None): + return + + # TODO UPDATE SOLUTION + + # method_name = get_full_method_name(node) + # if method_name in self.REQUIRED_PARAMS: + method_name = getattr(node.func, "attrname", None) + if infer_specific_module_from_call(node=node, module_name=TENSORFLOW) and method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} - # Collect all missing parameters missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] if missing_params: self.add_message( @@ -47,15 +56,3 @@ def visit_call(self, node: nodes.Call) -> None: confidence=HIGH, args=(", ".join(missing_params), method_name), ) - - @staticmethod - def _get_method_name(node: nodes.Call) -> str: - """Extracts the method name from a Call node, including handling chained calls.""" - func = node.func - while isinstance(func, nodes.Attribute): - func = func.expr - return ( - node.func.attrname - if isinstance(node.func, nodes.Attribute) - else func.name if isinstance(func, nodes.Name) else "" - ) diff --git a/pylint_ml/checkers/torch/torch_parameter.py b/pylint_ml/checkers/torch/torch_parameter.py index 75d6b4a..4c631f7 100644 --- a/pylint_ml/checkers/torch/torch_parameter.py +++ b/pylint_ml/checkers/torch/torch_parameter.py @@ -5,12 +5,15 @@ """Check for proper usage of PyTorch functions with required parameters.""" from astroid import nodes -from pylint.checkers import BaseChecker from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH +from pylint_ml.checkers.config import TORCH +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import infer_specific_module_from_call -class PyTorchParameterChecker(BaseChecker): + +class PyTorchParameterChecker(LibraryBaseChecker): name = "pytorch-parameter" msgs = { "W8111": ( @@ -34,10 +37,14 @@ class PyTorchParameterChecker(BaseChecker): @only_required_for_messages("pytorch-parameter") def visit_call(self, node: nodes.Call) -> None: - method_name = self._get_method_name(node) - if method_name in self.REQUIRED_PARAMS: + if not self.is_library_imported_and_version_valid(lib_name=TORCH, required_version=None): + return + + # TODO UPDATE SOLUTION + + method_name = getattr(node.func, "attrname", None) + if infer_specific_module_from_call(node=node, module_name=TORCH) and method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} - # Collect all missing parameters missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] if missing_params: self.add_message( @@ -46,15 +53,3 @@ def visit_call(self, node: nodes.Call) -> None: confidence=HIGH, args=(", ".join(missing_params), method_name), ) - - @staticmethod - def _get_method_name(node: nodes.Call) -> str: - """Extracts the method name from a Call node, including handling chained calls.""" - func = node.func - while isinstance(func, nodes.Attribute): - func = func.expr - return ( - node.func.attrname - if isinstance(node.func, nodes.Attribute) - else func.name if isinstance(func, nodes.Name) else "" - ) diff --git a/pylint_ml/checkers/utils.py b/pylint_ml/checkers/utils.py new file mode 100644 index 0000000..a5c489f --- /dev/null +++ b/pylint_ml/checkers/utils.py @@ -0,0 +1,106 @@ +from astroid import nodes +from pylint.checkers.utils import safe_infer + + +def get_full_method_name(node: nodes.Call) -> str: + """ + Extracts the full method name from a Call node, including handling chained calls. + """ + func = node.func + method_chain = [] + + # Traverse the attribute chain to build the full method chain + while isinstance(func, nodes.Attribute): + method_chain.insert(0, func.attrname) + func = func.expr + + # Check if the root of the chain is a Name node (like a module or base name) + if isinstance(func, nodes.Name): + method_chain.insert(0, func.name) # Add the base name + + # Join the method chain to create the full method name + return ".".join(method_chain) + + +def is_specific_library_object(node: nodes.NodeNG, library_name: str) -> bool: + """ + Returns True if the given node is an object from the specified library/module. + + Args: + node: The AST node to check. + library_name: The name of the library/module to check (e.g., 'pandas', 'numpy'). + + Returns: + bool: True if the node belongs to the specified library, False otherwise. + """ + return node and node.root().name == library_name # Checks if the root module matches the library name + + +def infer_module_from_node_chain(start_node: nodes.NodeNG, module_name: str) -> bool: + """ + Traverses the chain of attributes and checks if the root module of the node chain + matches the specified module name (e.g., 'numpy' or 'pandas'). + + Args: + start_node (nodes.NodeNG): The starting node (either Attribute or Call). + module_name (str): The module name to check against (e.g., 'numpy', 'pandas'). + + Returns: + bool: True if the root module matches the specified module_name, False otherwise. + """ + current_node = start_node + + # Traverse backward through the chain, handling Attribute and Name node types + while isinstance(current_node, (nodes.Attribute, nodes.Name)): + print(current_node) + + if isinstance(current_node, nodes.Attribute): + # Infer the current expression (e.g., np.some) + inferred_object = safe_infer(current_node.expr) + if inferred_object is None: + current_node = current_node.expr + else: + current_node = current_node.expr # Step backwards + elif isinstance(current_node, nodes.Name): + # Base case: a Name node is likely a module or variable (e.g., 'np') + inferred_root = safe_infer(current_node) + if inferred_root: + # Check if the inferred object's name matches the module_name + # TODO update solution to handle MODULE and INSTANCE + + if module_name in inferred_root.qname() or inferred_root.qname() == module_name: + return True + else: + return False + else: + return False # If inference of the Name node fails + + return False # Return False if we couldn't infer a valid module + + +def infer_specific_module_from_call(node: nodes.Call, module_name: str) -> bool: + """ + Infers if the function call belongs to the specified module (e.g., 'numpy', 'pandas'). + + Args: + node (nodes.Call): The Call node representing the method call. + module_name (str): The module name to check against (e.g., 'numpy', 'pandas'). + + Returns: + bool: True if the root module matches the specified module_name, False otherwise. + """ + return infer_module_from_node_chain(node.func, module_name) + + +def infer_specific_module_from_attribute(node: nodes.Attribute, module_name: str) -> bool: + """ + Infers if the attribute access belongs to the specified module (e.g., 'numpy', 'pandas'). + + Args: + node (nodes.Attribute): The Attribute node representing the method or attribute access. + module_name (str): The module name to check against (e.g., 'numpy', 'pandas'). + + Returns: + bool: True if the root module matches the specified module_name, False otherwise. + """ + return infer_module_from_node_chain(node, module_name) diff --git a/pylint_ml/util/library_handler.py b/pylint_ml/util/library_handler.py deleted file mode 100644 index 2d54203..0000000 --- a/pylint_ml/util/library_handler.py +++ /dev/null @@ -1,34 +0,0 @@ -from pylint.checkers import BaseChecker - - -class LibraryHandler(BaseChecker): - - def __init__(self, linter): - super().__init__(linter) - self.imports = {} - - def visit_import(self, node): - for name, alias in node.names: - self.imports[alias or name] = name - - def visit_importfrom( - self, - node, - ): - # TODO Update method to handle either: - # 1. Check of specific method-name imported? - # 2. Store all method names importfrom libname? - - module = node.modname - for name, alias in node.names: - full_name = f"{module}.{name}" - self.imports[alias or name] = full_name - - def is_library_imported(self, library_name): - return any(mod.startswith(library_name) for mod in self.imports.values()) - - # def is_library_version_valid(self, lib_version): - # # TODO update solution - # if lib_version is None: - # pass - # return diff --git a/tests/checkers/test_numpy/test_numpy_dot.py b/tests/checkers/test_numpy/test_numpy_dot.py index f01b811..ccda7f8 100644 --- a/tests/checkers/test_numpy/test_numpy_dot.py +++ b/tests/checkers/test_numpy/test_numpy_dot.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH @@ -8,7 +10,9 @@ class TestNumpyDotChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = NumpyDotChecker - def test_warning_for_dot(self): + @patch("pylint_ml.checkers.library_base_checker.version") + def test_warning_for_dot(self, mock_version): + mock_version.return_value = "1.7.0" import_np, node = astroid.extract_node( """ import numpy as np #@ diff --git a/tests/checkers/test_numpy/test_numpy_nan_comparison.py b/tests/checkers/test_numpy/test_numpy_nan_comparison.py index 6191d2e..58bfa4a 100644 --- a/tests/checkers/test_numpy/test_numpy_nan_comparison.py +++ b/tests/checkers/test_numpy/test_numpy_nan_comparison.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH @@ -8,9 +10,13 @@ class TestNumpyNaNComparison(pylint.testutils.CheckerTestCase): CHECKER_CLASS = NumpyNaNComparisonChecker - def test_singleton_nan_compare(self): - singleton_node, chained_node, great_than_node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_singleton_nan_compare(self, mock_version): + mock_version.return_value = "2.1.1" + import_node, singleton_node, chained_node, great_than_node = astroid.extract_node( """ + import numpy as np #@ + a_nan = np.array([0, 1, np.nan]) np.nan == a_nan #@ 1 == 1 == np.nan #@ @@ -36,6 +42,7 @@ def test_singleton_nan_compare(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_compare(singleton_node) self.checker.visit_compare(chained_node) self.checker.visit_compare(great_than_node) diff --git a/tests/checkers/test_numpy/test_numpy_parameter.py b/tests/checkers/test_numpy/test_numpy_parameter.py index 40b8d52..969b573 100644 --- a/tests/checkers/test_numpy/test_numpy_parameter.py +++ b/tests/checkers/test_numpy/test_numpy_parameter.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH @@ -8,31 +10,36 @@ class TestNumPyParameterChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = NumPyParameterChecker - def test_array_missing_object(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_array_missing_object(self, mock_version): + mock_version.return_value = "2.1.1" + import_node, call_node = astroid.extract_node( """ - import numpy as np + import numpy as np #@ arr = np.array() #@ """ ) - array_call = node.value + call_node = call_node.value with self.assertAddsMessages( pylint.testutils.MessageTest( msg_id="numpy-parameter", confidence=HIGH, - node=array_call, + node=call_node, args=("object", "array"), ), ignore_position=True, ): - self.checker.visit_call(array_call) + self.checker.visit_import(import_node) + self.checker.visit_call(call_node) - def test_zeros_without_shape(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_zeros_without_shape(self, mock_version): + mock_version.return_value = "2.1.1" + import_node, node = astroid.extract_node( """ - import numpy as np + import numpy as np #@ arr = np.zeros() #@ """ ) @@ -48,13 +55,16 @@ def test_zeros_without_shape(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(zeros_call) - def test_random_rand_without_shape(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_random_rand_without_shape(self, mock_version): + mock_version.return_value = "2.1.1" + import_node, node = astroid.extract_node( """ - import numpy as np - arr = np.random.rand() #@ + import numpy as np #@ + arr = np.random.rand() #@ """ ) @@ -65,17 +75,20 @@ def test_random_rand_without_shape(self): msg_id="numpy-parameter", confidence=HIGH, node=rand_call, - args=("d0", "random.rand"), + args=("d0", "rand"), ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(rand_call) - def test_dot_without_b(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_dot_without_b(self, mock_version): + mock_version.return_value = "2.1.1" + import_node, node = astroid.extract_node( """ - import numpy as np - arr = np.dot(a=[1, 2, 3]) #@ + import numpy as np #@ + arr = np.dot(a=[1, 2, 3]) #@ """ ) @@ -90,13 +103,16 @@ def test_dot_without_b(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(dot_call) - def test_percentile_without_q(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_percentile_without_q(self, mock_version): + mock_version.return_value = "2.1.1" + import_node, node = astroid.extract_node( """ - import numpy as np - result = np.percentile(a=[1, 2, 3]) #@ + import numpy as np #@ + result = np.percentile(a=[1, 2, 3]) #@ """ ) @@ -111,4 +127,5 @@ def test_percentile_without_q(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(percentile_call) diff --git a/tests/checkers/test_pandas/test_pandas_dataframe_bool.py b/tests/checkers/test_pandas/test_pandas_dataframe_bool.py index ec1ae81..0560173 100644 --- a/tests/checkers/test_pandas/test_pandas_dataframe_bool.py +++ b/tests/checkers/test_pandas/test_pandas_dataframe_bool.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH @@ -8,10 +10,12 @@ class TestDataFrameBoolChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasDataFrameBoolChecker - def test_dataframe_bool_usage(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_dataframe_bool_usage(self, mock_version): + mock_version.return_value = "2.2.2" + import_node, call_node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df_customers = pd.DataFrame(data) df_customers.bool() #@ """ @@ -20,19 +24,23 @@ def test_dataframe_bool_usage(self): pylint.testutils.MessageTest( msg_id="pandas-dataframe-bool", confidence=HIGH, - node=node, + node=call_node, ), ignore_position=True, ): - self.checker.visit_call(node) + self.checker.visit_import(import_node) + self.checker.visit_call(call_node) - def test_no_bool_usage(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_no_bool_usage(self, mock_version): + mock_version.return_value = "2.2.2" + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df_customers = pd.DataFrame(data) df_customers.sum() #@ """ ) with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_call(node) diff --git a/tests/checkers/test_pandas/pandas_dataframe_column_selection.py b/tests/checkers/test_pandas/test_pandas_dataframe_column_selection.py similarity index 59% rename from tests/checkers/test_pandas/pandas_dataframe_column_selection.py rename to tests/checkers/test_pandas/test_pandas_dataframe_column_selection.py index 0bd3592..e65126d 100644 --- a/tests/checkers/test_pandas/pandas_dataframe_column_selection.py +++ b/tests/checkers/test_pandas/test_pandas_dataframe_column_selection.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH @@ -8,23 +10,26 @@ class TestPandasColumnSelectionChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasColumnSelectionChecker - def test_incorrect_column_selection(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_incorrect_column_selection(self, mock_version): + mock_version.return_value = "2.2.2" + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df_sales = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) value = df_sales.A #@ """ ) - column_attribute = node.value + attribute_node = node.value with self.assertAddsMessages( pylint.testutils.MessageTest( msg_id="pandas-column-selection", confidence=HIGH, - node=column_attribute, + node=attribute_node, ), ignore_position=True, ): - self.checker.visit_attribute(column_attribute) + self.checker.visit_import(import_node) + self.checker.visit_attribute(attribute_node) diff --git a/tests/checkers/test_pandas/test_pandas_dataframe_empty_column.py b/tests/checkers/test_pandas/test_pandas_dataframe_empty_column.py index 37875f6..4856fe1 100644 --- a/tests/checkers/test_pandas/test_pandas_dataframe_empty_column.py +++ b/tests/checkers/test_pandas/test_pandas_dataframe_empty_column.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH @@ -8,22 +10,26 @@ class TestPandasEmptyColumnChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasEmptyColumnChecker - def test_correct_empty_column_initialization(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_correct_empty_column_initialization(self, mock_version): + mock_version.return_value = "2.2.2" + import_node, node = astroid.extract_node( """ - import numpy as np - import pandas as pd + import pandas as pd #@ df_sales = pd.DataFrame() df_sales['new_col_str'] = pd.Series(dtype='object') #@ """ ) with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_subscript(node) - def test_incorrect_empty_column_initialization_with_zero(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_incorrect_empty_column_initialization_with_zero(self, mock_version): + mock_version.return_value = "2.2.2" + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df_sales = pd.DataFrame() df_sales['new_col_int'] = 0 #@ """ @@ -39,12 +45,15 @@ def test_incorrect_empty_column_initialization_with_zero(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_subscript(subscript_node) - def test_incorrect_empty_column_initialization_with_empty_string(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_incorrect_empty_column_initialization_with_empty_string(self, mock_version): + mock_version.return_value = "2.2.2" + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df_sales = pd.DataFrame() df_sales['new_col_str'] = '' #@ """ @@ -60,4 +69,5 @@ def test_incorrect_empty_column_initialization_with_empty_string(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_subscript(subscript_node) diff --git a/tests/checkers/test_pandas/test_pandas_dataframe_iterrows.py b/tests/checkers/test_pandas/test_pandas_dataframe_iterrows.py index 721a75e..5464624 100644 --- a/tests/checkers/test_pandas/test_pandas_dataframe_iterrows.py +++ b/tests/checkers/test_pandas/test_pandas_dataframe_iterrows.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH @@ -8,10 +10,12 @@ class TestPandasIterrowsChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasIterrowsChecker - def test_iterrows_used(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_iterrows_used(self, mock_version): + mock_version.return_value = "2.2.2" + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df_sales = pd.DataFrame({ "Product": ["A", "B", "C"], "Sales": [100, 200, 300] @@ -32,4 +36,5 @@ def test_iterrows_used(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(iterrows_call) diff --git a/tests/checkers/test_pandas/test_pandas_dataframe_naming.py b/tests/checkers/test_pandas/test_pandas_dataframe_naming.py index 558df80..3ab77ce 100644 --- a/tests/checkers/test_pandas/test_pandas_dataframe_naming.py +++ b/tests/checkers/test_pandas/test_pandas_dataframe_naming.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH @@ -8,20 +10,25 @@ class TestPandasDataFrameNamingChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasDataFrameNamingChecker - def test_correct_dataframe_naming(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_correct_dataframe_naming(self, mock_version): + mock_version.return_value = "2.2.2" + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df_customers = pd.DataFrame(data) #@ """ ) with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_assign(node) - def test_incorrect_dataframe_naming(self): - pandas_dataframe_node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_incorrect_dataframe_naming(self, mock_version): + mock_version.return_value = "2.2.2" + import_node, pandas_dataframe_node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ customers = pd.DataFrame(data) #@ """ ) @@ -33,12 +40,15 @@ def test_incorrect_dataframe_naming(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_assign(pandas_dataframe_node) - def test_incorrect_dataframe_name_length(self): - pandas_dataframe_node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_incorrect_dataframe_name_length(self, mock_version): + mock_version.return_value = "2.2.2" + import_node, pandas_dataframe_node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df_ = pd.DataFrame(data) #@ """ ) @@ -50,4 +60,5 @@ def test_incorrect_dataframe_name_length(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_assign(pandas_dataframe_node) diff --git a/tests/checkers/test_pandas/test_pandas_dataframe_values.py b/tests/checkers/test_pandas/test_pandas_dataframe_values.py index 232d9bf..5373c87 100644 --- a/tests/checkers/test_pandas/test_pandas_dataframe_values.py +++ b/tests/checkers/test_pandas/test_pandas_dataframe_values.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH @@ -8,10 +10,12 @@ class TestPandasValuesChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasValuesChecker - def test_values_usage_with_correct_naming(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_values_usage_with_correct_naming(self, mock_version): + mock_version.return_value = "2.2.2" + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df_sales = pd.DataFrame({ "A": [1, 2, 3], "B": [4, 5, 6] @@ -31,4 +35,5 @@ def test_values_usage_with_correct_naming(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_attribute(attribute_node) diff --git a/tests/checkers/test_pandas/test_pandas_inplace.py b/tests/checkers/test_pandas/test_pandas_inplace.py index 20ed034..0a05cfc 100644 --- a/tests/checkers/test_pandas/test_pandas_inplace.py +++ b/tests/checkers/test_pandas/test_pandas_inplace.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH @@ -8,15 +10,17 @@ class TestPandasInplaceChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasInplaceChecker - def test_inplace_used_in_drop(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_inplace_used_in_drop(self, mock_version): + mock_version.return_value = "2.2.2" + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df = pd.DataFrame({ "A": [1, 2, 3], "B": [4, 5, 6] }) - df.drop(columns=["A"], inplace=True) #@ + df.drop(columns=["A"], inplace=True) #@ """ ) with self.assertAddsMessages( @@ -27,17 +31,20 @@ def test_inplace_used_in_drop(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(node) - def test_inplace_used_in_fillna(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_inplace_used_in_fillna(self, mock_version): + mock_version.return_value = "2.2.2" + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df = pd.DataFrame({ "A": [1, None, 3], "B": [4, 5, None] }) - df.fillna(0, inplace=True) #@ + df.fillna(0, inplace=True) #@ """ ) with self.assertAddsMessages( @@ -48,17 +55,20 @@ def test_inplace_used_in_fillna(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(node) - def test_inplace_used_in_sort_values(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_inplace_used_in_sort_values(self, mock_version): + mock_version.return_value = "2.2.2" + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df = pd.DataFrame({ "A": [3, 2, 1], "B": [4, 5, 6] }) - df.sort_values(by="A", inplace=True) #@ + df.sort_values(by="A", inplace=True) #@ """ ) with self.assertAddsMessages( @@ -69,36 +79,43 @@ def test_inplace_used_in_sort_values(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(node) - def test_no_inplace(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_no_inplace(self, mock_version): + mock_version.return_value = "2.2.2" + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df = pd.DataFrame({ "A": [1, 2, 3], "B": [4, 5, 6] }) - df = df.drop(columns=["A"]) #@ + df = df.drop(columns=["A"]) #@ """ ) inplace_call = node.value with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_call(inplace_call) - def test_inplace_used_in_unsupported_method(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_inplace_used_in_unsupported_method(self, mock_version): + mock_version.return_value = "2.2.2" + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df = pd.DataFrame({ "A": [1, 2, 3], "B": [4, 5, 6] }) - df.append({"A": 4, "B": 7}, inplace=True) #@ + df.append({"A": 4, "B": 7}, inplace=True) #@ """ ) with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_call(node) diff --git a/tests/checkers/test_pandas/test_pandas_parameter.py b/tests/checkers/test_pandas/test_pandas_parameter.py index 6cfe8ca..ed4769f 100644 --- a/tests/checkers/test_pandas/test_pandas_parameter.py +++ b/tests/checkers/test_pandas/test_pandas_parameter.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH @@ -8,11 +10,13 @@ class TestPandasParameterChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasParameterChecker - def test_dataframe_missing_data(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_dataframe_missing_data(self, mock_version): + mock_version.return_value = "2.2.2" + import_node, node = astroid.extract_node( """ - import pandas as pd - df_yoda = pd.DataFrame() #@ + import pandas as pd #@ + df_yoda = pd.DataFrame() #@ """ ) @@ -27,15 +31,18 @@ def test_dataframe_missing_data(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(dataframe_call) - def test_merge_without_required_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_merge_without_required_params(self, mock_version): + mock_version.return_value = "2.2.2" + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df_yoda1 = pd.DataFrame({'A': [1, 2]}) df_yoda2 = pd.DataFrame({'A': [3, 4]}) - df_yoda_merged = df_yoda1.merge(df_yoda2) #@ + df_yoda_merged = df_yoda1.merge(df_yoda2) #@ """ ) @@ -50,13 +57,16 @@ def test_merge_without_required_params(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(merge_call) - def test_read_csv_without_filepath(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_read_csv_without_filepath(self, mock_version): + mock_version.return_value = "2.2.2" + import_node, node = astroid.extract_node( """ - import pandas as pd - df_yoda = pd.read_csv() #@ + import pandas as pd #@ + df_yoda = pd.read_csv() #@ """ ) @@ -71,14 +81,17 @@ def test_read_csv_without_filepath(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(read_csv_call) - def test_to_csv_without_path(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_to_csv_without_path(self, mock_version): + mock_version.return_value = "2.2.2" + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df_yoda = pd.DataFrame({'A': [1, 2]}) - df_yoda.to_csv() #@ + df_yoda.to_csv() #@ """ ) @@ -93,14 +106,17 @@ def test_to_csv_without_path(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(to_csv_call) - def test_groupby_without_by(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_groupby_without_by(self, mock_version): + mock_version.return_value = "2.2.2" + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df_yoda = pd.DataFrame({'A': [1, 2]}) - df_yoda.groupby() #@ + df_yoda.groupby() #@ """ ) @@ -115,14 +131,17 @@ def test_groupby_without_by(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(groupby_call) - def test_fillna_without_value(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_fillna_without_value(self, mock_version): + mock_version.return_value = "2.2.2" + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df_yoda = pd.DataFrame({'A': [1, None]}) - df_yoda.fillna() #@ + df_yoda.fillna() #@ """ ) @@ -137,14 +156,17 @@ def test_fillna_without_value(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(fillna_call) - def test_sort_values_without_by(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_sort_values_without_by(self, mock_version): + mock_version.return_value = "2.2.2" + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df_yoda = pd.DataFrame({'A': [1, 2]}) - df_yoda.sort_values() #@ + df_yoda.sort_values() #@ """ ) @@ -159,13 +181,17 @@ def test_sort_values_without_by(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(sort_values_call) - def test_merge_with_missing_validate(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_merge_with_missing_validate(self, mock_version): + mock_version.return_value = "2.2.2" + import_node, node = astroid.extract_node( """ - import pandas as pd - df_3 = df_1.merge(right=df_2, how='inner', on='col1') #@ + import pandas as pd #@ + df_1 = pd.DataFrame({'A': [1, 2]}) + df_3 = df_1.merge(right=df_2, how='inner', on='col1') #@ """ ) @@ -180,13 +206,17 @@ def test_merge_with_missing_validate(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(merge_call) - def test_merge_with_wrong_naming_and_missing_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_merge_with_wrong_naming_and_missing_params(self, mock_version): + mock_version.return_value = "2.2.2" + import_node, node = astroid.extract_node( """ - import pandas as pd - merged_df = df_1.merge(right=df_2) #@ + import pandas as pd #@ + df_1 = pd.DataFrame({'A': [1, 2]}) + merged_df = df_1.merge(right=df_2) #@ """ ) @@ -198,17 +228,22 @@ def test_merge_with_wrong_naming_and_missing_params(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(merge_call) - def test_merge_with_all_params_and_correct_naming(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_merge_with_all_params_and_correct_naming(self, mock_version): + mock_version.return_value = "2.2.2" + import_node, node = astroid.extract_node( """ - import pandas as pd - df_merged = df_1.merge(right=df_2, how='inner', on='col1', validate='1:1') #@ + import pandas as pd #@ + df_1 = pd.DataFrame({'A': [1, 2]}) + df_merged = df_1.merge(right=df_2, how='inner', on='col1', validate='1:1') #@ """ ) merge_call = node.value with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_call(merge_call) diff --git a/tests/checkers/test_pandas/test_pandas_series_bool.py b/tests/checkers/test_pandas/test_pandas_series_bool.py index b1d5b42..7efdc64 100644 --- a/tests/checkers/test_pandas/test_pandas_series_bool.py +++ b/tests/checkers/test_pandas/test_pandas_series_bool.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH @@ -8,12 +10,14 @@ class TestSeriesBoolChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasSeriesBoolChecker - def test_series_bool_usage(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_series_bool_usage(self, mock_version): + mock_version.return_value = "2.2.2" + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ ser_customer = pd.Series(data) - ser_customer.bool() #@ + ser_customer.bool() #@ """ ) with self.assertAddsMessages( @@ -24,15 +28,19 @@ def test_series_bool_usage(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(node) - def test_no_bool_usage(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_no_bool_usage(self, mock_version): + mock_version.return_value = "2.2.2" + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ ser_customer = pd.Series(data) - ser_customer.sum() #@ + ser_customer.sum() #@ """ ) with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_call(node) diff --git a/tests/checkers/test_pandas/test_pandas_series_naming.py b/tests/checkers/test_pandas/test_pandas_series_naming.py index c76d928..58c9cbc 100644 --- a/tests/checkers/test_pandas/test_pandas_series_naming.py +++ b/tests/checkers/test_pandas/test_pandas_series_naming.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH @@ -8,21 +10,26 @@ class TestPandasSeriesNamingChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasSeriesNamingChecker - def test_series_correct_naming(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_series_correct_naming(self, mock_version): + mock_version.return_value = "2.2.2" + import_node, node = astroid.extract_node( """ - import pandas as pd - ser_sales = pd.Series([100, 200, 300]) + import pandas as pd #@ + ser_sales = pd.Series([100, 200, 300]) #@ """ ) with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_assign(node) - def test_series_incorrect_naming(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_series_incorrect_naming(self, mock_version): + mock_version.return_value = "2.2.2" + import_node, node = astroid.extract_node( """ - import pandas as pd - df_sales = pd.Series([100, 200, 300]) + import pandas as pd #@ + df_sales = pd.Series([100, 200, 300]) #@ """ ) with self.assertAddsMessages( @@ -33,13 +40,16 @@ def test_series_incorrect_naming(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_assign(node) - def test_series_invalid_length_naming(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_series_invalid_length_naming(self, mock_version): + mock_version.return_value = "2.2.2" + import_node, node = astroid.extract_node( """ - import pandas as pd - ser_ = pd.Series([True]) + import pandas as pd #@ + ser_ = pd.Series([True]) #@ """ ) with self.assertAddsMessages( @@ -50,4 +60,5 @@ def test_series_invalid_length_naming(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_assign(node) diff --git a/tests/checkers/test_scipy/test_scipy_parameter.py b/tests/checkers/test_scipy/test_scipy_parameter.py index 1bec24b..45bbde7 100644 --- a/tests/checkers/test_scipy/test_scipy_parameter.py +++ b/tests/checkers/test_scipy/test_scipy_parameter.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH @@ -8,11 +10,13 @@ class TestScipyParameterChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = ScipyParameterChecker - def test_minimize_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_minimize_params(self, mock_version): + mock_version.return_value = "1.7.0" + importfrom_node, node = astroid.extract_node( """ - from scipy.optimize import minimize - result = minimize(x0=[1, 2, 3]) #@ + from scipy.optimize import minimize #@ + result = minimize(x0=[1, 2, 3]) #@ """ ) minimize_call = node.value @@ -26,12 +30,15 @@ def test_minimize_params(self): ), ignore_position=True, ): + self.checker.visit_importfrom(importfrom_node) self.checker.visit_call(minimize_call) - def test_curve_fit_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_curve_fit_params(self, mock_version): + mock_version.return_value = "1.7.0" + importfrom_node, node = astroid.extract_node( """ - from scipy.optimize import curve_fit + from scipy.optimize import curve_fit #@ params = curve_fit(xdata=[1, 2, 3], ydata=[4, 5, 6]) #@ """ ) @@ -46,12 +53,15 @@ def test_curve_fit_params(self): ), ignore_position=True, ): + self.checker.visit_importfrom(importfrom_node) self.checker.visit_call(curve_fit_call) - def test_quad_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_quad_params(self, mock_version): + mock_version.return_value = "1.7.0" + importfrom_node, node = astroid.extract_node( """ - from scipy.integrate import quad + from scipy.integrate import quad #@ result = quad(a=0, b=1) #@ """ ) @@ -66,12 +76,15 @@ def test_quad_params(self): ), ignore_position=True, ): + self.checker.visit_importfrom(importfrom_node) self.checker.visit_call(quad_call) - def test_solve_ivp_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_solve_ivp_params(self, mock_version): + mock_version.return_value = "1.7.0" + importfrom_node, node = astroid.extract_node( """ - from scipy.integrate import solve_ivp + from scipy.integrate import solve_ivp #@ result = solve_ivp(fun=None, t_span=[0, 1]) #@ """ ) @@ -86,12 +99,15 @@ def test_solve_ivp_params(self): ), ignore_position=True, ): + self.checker.visit_importfrom(importfrom_node) self.checker.visit_call(solve_ivp_call) - def test_ttest_ind_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_ttest_ind_params(self, mock_version): + mock_version.return_value = "1.7.0" + importfrom_node, node = astroid.extract_node( """ - from scipy.stats import ttest_ind + from scipy.stats import ttest_ind #@ result = ttest_ind(a=[1, 2]) #@ """ ) @@ -106,12 +122,15 @@ def test_ttest_ind_params(self): ), ignore_position=True, ): + self.checker.visit_importfrom(importfrom_node) self.checker.visit_call(ttest_ind_call) - def test_euclidean_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_euclidean_params(self, mock_version): + mock_version.return_value = "1.7.0" + importfrom_node, node = astroid.extract_node( """ - from scipy.spatial.distance import euclidean + from scipy.spatial.distance import euclidean #@ dist = euclidean(u=[1, 2, 3]) #@ """ ) @@ -126,4 +145,5 @@ def test_euclidean_params(self): ), ignore_position=True, ): + self.checker.visit_importfrom(importfrom_node) self.checker.visit_call(euclidean_call) diff --git a/tests/checkers/test_sklearn/test_sklearn_parameter.py b/tests/checkers/test_sklearn/test_sklearn_parameter.py index 9612965..77e744d 100644 --- a/tests/checkers/test_sklearn/test_sklearn_parameter.py +++ b/tests/checkers/test_sklearn/test_sklearn_parameter.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH @@ -8,11 +10,13 @@ class TestSklearnParameterChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = SklearnParameterChecker - def test_random_forest_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_random_forest_params(self, mock_version): + mock_version.return_value = "1.5.2" + importfrom_node, node = astroid.extract_node( """ - from sklearn.ensemble import RandomForestClassifier - clf = RandomForestClassifier() #@ + from sklearn.ensemble import RandomForestClassifier #@ + clf = RandomForestClassifier() #@ """ ) @@ -27,26 +31,32 @@ def test_random_forest_params(self): ), ignore_position=True, ): + self.checker.visit_importfrom(importfrom_node) self.checker.visit_call(forest_call) - def test_random_forest_with_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_random_forest_with_params(self, mock_version): + mock_version.return_value = "1.5.2" + importfrom_node, node = astroid.extract_node( """ - from sklearn.ensemble import RandomForestClassifier - clf = RandomForestClassifier(n_estimators=100) #@ + from sklearn.ensemble import RandomForestClassifier #@ + clf = RandomForestClassifier(n_estimators=100) #@ """ ) forest_call = node.value with self.assertNoMessages(): + self.checker.visit_importfrom(importfrom_node) self.checker.visit_call(forest_call) - def test_svc_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_svc_params(self, mock_version): + mock_version.return_value = "1.5.2" + importfrom_node, node = astroid.extract_node( """ - from sklearn.svm import SVC - clf = SVC() #@ + from sklearn.svm import SVC #@ + clf = SVC() #@ """ ) @@ -61,26 +71,32 @@ def test_svc_params(self): ), ignore_position=True, ): + self.checker.visit_importfrom(importfrom_node) self.checker.visit_call(svc_call) - def test_svc_with_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_svc_with_params(self, mock_version): + mock_version.return_value = "1.5.2" + importfrom_node, node = astroid.extract_node( """ - from sklearn.svm import SVC - clf = SVC(C=1.0, kernel='linear') #@ + from sklearn.svm import SVC #@ + clf = SVC(C=1.0, kernel='linear') #@ """ ) svc_call = node.value with self.assertNoMessages(): + self.checker.visit_importfrom(importfrom_node) self.checker.visit_call(svc_call) - def test_kmeans_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_kmeans_params(self, mock_version): + mock_version.return_value = "1.5.2" + importfrom_node, node = astroid.extract_node( """ - from sklearn.cluster import KMeans - kmeans = KMeans() #@ + from sklearn.cluster import KMeans #@ + kmeans = KMeans() #@ """ ) @@ -95,17 +111,21 @@ def test_kmeans_params(self): ), ignore_position=True, ): + self.checker.visit_importfrom(importfrom_node) self.checker.visit_call(kmeans_call) - def test_kmeans_with_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_kmeans_with_params(self, mock_version): + mock_version.return_value = "1.5.2" + importfrom_node, node = astroid.extract_node( """ - from sklearn.cluster import KMeans - kmeans = KMeans(n_clusters=8) #@ + from sklearn.cluster import KMeans #@ + kmeans = KMeans(n_clusters=8) #@ """ ) kmeans_call = node.value with self.assertNoMessages(): + self.checker.visit_importfrom(importfrom_node) self.checker.visit_call(kmeans_call) diff --git a/tests/checkers/test_tensorflow/test_tensor_parameter.py b/tests/checkers/test_tensorflow/test_tensor_parameter.py index 48197dd..b00db15 100644 --- a/tests/checkers/test_tensorflow/test_tensor_parameter.py +++ b/tests/checkers/test_tensorflow/test_tensor_parameter.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH @@ -8,11 +10,13 @@ class TestTensorParameterChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = TensorFlowParameterChecker - def test_sequential_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_sequential_params(self, mock_version): + mock_version.return_value = "1.5.2" + import_node, node = astroid.extract_node( """ - import tensorflow as tf - model = tf.keras.models.Sequential() #@ + import tensorflow as tf #@ + model = tf.keras.models.Sequential() #@ """ ) @@ -27,30 +31,33 @@ def test_sequential_params(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(sequential_call) - def test_sequential_with_layers(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_sequential_with_layers(self, mock_version): + mock_version.return_value = "1.5.2" + import_node, node = astroid.extract_node( """ - import tensorflow as tf - model = tf.keras.Sequential(layers=[ - tf.keras.layers.Dense(units=64, activation='relu'), - tf.keras.layers.Dense(units=10) - ]) + import tensorflow as tf #@ + model = tf.keras.Sequential(layers=[tf.keras.layers.Dense(units=64, activation='relu'),tf.keras.layers.Dense(units=10)]) #@ """ ) sequential_call = node.value with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_call(sequential_call) - def test_compile_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_compile_params(self, mock_version): + mock_version.return_value = "1.5.2" + import_node, node = astroid.extract_node( """ - import tensorflow as tf + import tensorflow as tf #@ model = tf.keras.models.Sequential() - model.compile() #@ + model.compile() #@ """ ) @@ -63,29 +70,35 @@ def test_compile_params(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(node) - def test_compile_with_all_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_compile_with_all_params(self, mock_version): + mock_version.return_value = "1.5.2" + import_node, node = astroid.extract_node( """ - import tensorflow as tf + import tensorflow as tf #@ model = tf.keras.models.Sequential() - model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) #@ + model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) #@ """ ) compile_call = node with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_call(compile_call) - def test_fit_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_fit_params(self, mock_version): + mock_version.return_value = "1.5.2" + import_node, node = astroid.extract_node( """ - import tensorflow as tf + import tensorflow as tf #@ model = tf.keras.models.Sequential() model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') - model.fit(epochs=10) #@ + model.fit(epochs=10) #@ """ ) @@ -100,28 +113,34 @@ def test_fit_params(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(fit_call) - def test_fit_with_all_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_fit_with_all_params(self, mock_version): + mock_version.return_value = "1.5.2" + import_node, node = astroid.extract_node( """ - import tensorflow as tf + import tensorflow as tf #@ model = tf.keras.models.Sequential() model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') - model.fit(x=train_data, y=train_labels, epochs=10) #@ + model.fit(x=train_data, y=train_labels, epochs=10) #@ """ ) fit_call = node with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_call(fit_call) - def test_conv2d_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_conv2d_params(self, mock_version): + mock_version.return_value = "1.5.2" + import_node, node = astroid.extract_node( """ - import tensorflow as tf - layer = tf.keras.layers.Conv2D(kernel_size=(3, 3)) #@ + import tensorflow as tf #@ + layer = tf.keras.layers.Conv2D(kernel_size=(3, 3)) #@ """ ) @@ -136,26 +155,32 @@ def test_conv2d_params(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(conv2d_call) - def test_conv2d_with_all_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_conv2d_with_all_params(self, mock_version): + mock_version.return_value = "1.5.2" + import_node, node = astroid.extract_node( """ - import tensorflow as tf - layer = tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3)) #@ + import tensorflow as tf #@ + layer = tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3)) #@ """ ) conv2d_call = node.value with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_call(conv2d_call) - def test_dense_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_dense_params(self, mock_version): + mock_version.return_value = "1.5.2" + import_node, node = astroid.extract_node( """ - import tensorflow as tf - layer = tf.keras.layers.Dense() #@ + import tensorflow as tf #@ + layer = tf.keras.layers.Dense() #@ """ ) @@ -170,17 +195,21 @@ def test_dense_params(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(dense_call) - def test_dense_with_all_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_dense_with_all_params(self, mock_version): + mock_version.return_value = "1.5.2" + import_node, node = astroid.extract_node( """ - import tensorflow as tf - layer = tf.keras.layers.Dense(units=64) #@ + import tensorflow as tf #@ + layer = tf.keras.layers.Dense(units=64) #@ """ ) dense_call = node.value with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_call(dense_call) diff --git a/tests/checkers/test_torch/test_torch_parameter.py b/tests/checkers/test_torch/test_torch_parameter.py index 6c81205..73f613b 100644 --- a/tests/checkers/test_torch/test_torch_parameter.py +++ b/tests/checkers/test_torch/test_torch_parameter.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH @@ -8,10 +10,12 @@ class TestTorchParameterChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PyTorchParameterChecker - def test_sgd_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_sgd_params(self, mock_version): + mock_version.return_value = "2.4.1" + import_node, node = astroid.extract_node( """ - import torch.optim as optim + import torch.optim as optim #@ optimizer = optim.SGD(model.parameters(), momentum=0.9) #@ """ ) @@ -27,12 +31,15 @@ def test_sgd_params(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(sgd_call) - def test_sgd_with_all_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_sgd_with_all_params(self, mock_version): + mock_version.return_value = "2.4.1" + import_node, node = astroid.extract_node( """ - import torch.optim as optim + import torch.optim as optim #@ optimizer = optim.SGD(lr=0.01) #@ """ ) @@ -40,12 +47,15 @@ def test_sgd_with_all_params(self): sgd_call = node.value with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_call(sgd_call) - def test_adam_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_adam_params(self, mock_version): + mock_version.return_value = "2.4.1" + import_node, node = astroid.extract_node( """ - import torch.optim as optim + import torch.optim as optim #@ optimizer = optim.Adam(model.parameters()) #@ """ ) @@ -61,12 +71,15 @@ def test_adam_params(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(adam_call) - def test_adam_with_all_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_adam_with_all_params(self, mock_version): + mock_version.return_value = "2.4.1" + import_node, node = astroid.extract_node( """ - import torch.optim as optim + import torch.optim as optim #@ optimizer = optim.Adam(lr=0.001) #@ """ ) @@ -74,12 +87,15 @@ def test_adam_with_all_params(self): adam_call = node.value with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_call(adam_call) - def test_conv2d_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_conv2d_params(self, mock_version): + mock_version.return_value = "2.4.1" + import_node, node = astroid.extract_node( """ - import torch.nn as nn + import torch.nn as nn #@ layer = nn.Conv2d(in_channels=3, kernel_size=3) #@ """ ) @@ -95,12 +111,15 @@ def test_conv2d_params(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(conv2d_call) - def test_conv2d_with_all_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_conv2d_with_all_params(self, mock_version): + mock_version.return_value = "2.4.1" + import_node, node = astroid.extract_node( """ - import torch.nn as nn + import torch.nn as nn #@ layer = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3) #@ """ ) @@ -108,12 +127,15 @@ def test_conv2d_with_all_params(self): conv2d_call = node.value with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_call(conv2d_call) - def test_linear_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_linear_params(self, mock_version): + mock_version.return_value = "2.4.1" + import_node, node = astroid.extract_node( """ - import torch.nn as nn + import torch.nn as nn #@ layer = nn.Linear(in_features=128) #@ """ ) @@ -129,12 +151,15 @@ def test_linear_params(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(linear_call) - def test_linear_with_all_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_linear_with_all_params(self, mock_version): + mock_version.return_value = "2.4.1" + import_node, node = astroid.extract_node( """ - import torch.nn as nn + import torch.nn as nn #@ layer = nn.Linear(in_features=128, out_features=64) #@ """ ) @@ -144,10 +169,12 @@ def test_linear_with_all_params(self): with self.assertNoMessages(): self.checker.visit_call(linear_call) - def test_lstm_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_lstm_params(self, mock_version): + mock_version.return_value = "2.4.1" + import_node, node = astroid.extract_node( """ - import torch.nn as nn + import torch.nn as nn #@ layer = nn.LSTM(input_size=128) #@ """ ) @@ -163,12 +190,15 @@ def test_lstm_params(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(lstm_call) - def test_lstm_with_all_params(self): - node = astroid.extract_node( + @patch("pylint_ml.checkers.library_base_checker.version") + def test_lstm_with_all_params(self, mock_version): + mock_version.return_value = "2.4.1" + import_node, node = astroid.extract_node( """ - import torch.nn as nn + import torch.nn as nn #@ layer = nn.LSTM(input_size=128, hidden_size=64) #@ """ ) @@ -176,4 +206,5 @@ def test_lstm_with_all_params(self): lstm_call = node.value with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_call(lstm_call)