Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

60 add library base checker #61

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ pylint-ml

About
-----
``pylint-ml`` is a pylint plugin for enhancing code analysis for machine learning and data science
``pylint-ml`` is a pylint plugin for enhancing code analysis for machine learning and data science projects.
16 changes: 16 additions & 0 deletions pylint_ml/checkers/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Library names
#
# Each constant is the top-level import name of a supported library. Checkers
# pass these values as ``lib_name`` when asking whether the library is imported
# and installed, so they must be spelled exactly as the package is imported.
PANDAS = "pandas"
PANDAS_ALIAS = "pd"

NUMPY = "numpy"
NUMPY_ALIAS = "np"

# The package is imported and distributed as "tensorflow"; a truncated name
# such as "tensor" cannot be resolved by an installed-version lookup.
TENSORFLOW = "tensorflow"

SCIPY = "scipy"

# Import name; the distribution on PyPI is published as "scikit-learn".
SKLEARN = "sklearn"

TORCH = "torch"

MATPLOTLIB = "matplotlib"
46 changes: 46 additions & 0 deletions pylint_ml/checkers/library_base_checker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from importlib.metadata import PackageNotFoundError, version

from pylint.checkers import BaseChecker


class LibraryBaseChecker(BaseChecker):
    """Base checker that records imports and gates library-specific checks.

    Subclasses call :meth:`is_library_imported_and_version_valid` at the top of
    their ``visit_*`` methods so that a check only runs when the target library
    is imported by the module being analysed and is installed in a suitable
    version.
    """

    # Import names whose installed distribution is published under another
    # name; ``importlib.metadata.version`` expects the distribution name.
    _DISTRIBUTION_NAMES = {"sklearn": "scikit-learn"}

    def __init__(self, linter):
        super().__init__(linter)
        # Maps a key (bound name or top-level package) to the imported module path.
        self.imports = {}

    def visit_module(self, node):
        """Start every module with an empty import table.

        One checker instance is reused for all modules of a pylint run, so
        without this reset the imports seen in one file would enable the
        library checks in every file analysed after it.
        """
        self.imports = {}

    def visit_import(self, node):
        """Record ``import x`` / ``import x as y`` statements."""
        for name, alias in node.names:
            self.imports[alias or name] = name  # E.g. {'pd': 'pandas'}

    def visit_importfrom(self, node):
        """Record ``from x.y import z`` statements under the top-level package."""
        base_module = node.modname.split(".")[0]  # Extract the first part of the module name

        # The key is the top-level package, so a later name imported from the
        # same package replaces the earlier entry; only the package prefix of
        # the stored value is consulted when checking for an import.
        for name, _alias in node.names:
            self.imports[base_module] = f"{node.modname}.{name}"  # E.g. {'scipy': 'scipy.optimize.minimize'}

    @staticmethod
    def _release_tuple(version_string):
        """Return the leading numeric release of a version string as a tuple of ints.

        ``"2.10.0"`` gives ``(2, 10, 0)``; parsing stops at the first component
        that is not purely numeric, so ``"2.1.0rc1"`` gives ``(2, 1, 0)`` and
        ``"1.26.4+cu118"`` gives ``(1, 26, 4)``.
        """
        release = []
        for part in version_string.strip().split("."):
            digits = ""
            for char in part:
                if not char.isdigit():
                    break
                digits += char
            if not digits:
                break
            release.append(int(digits))
            if digits != part:
                # A suffix such as 'rc1' or '+local' ends the release segment.
                break
        return tuple(release)

    def is_library_imported_and_version_valid(self, lib_name, required_version):
        """
        Checks if the library is imported and whether the installed version is valid (greater than or equal to the
        required version).

        :param lib_name: Import name of the library (as a string), e.g. ``"numpy"``.
        :param required_version: The required minimum version (as a string), or ``None`` to accept any version.
        :return: True if the library is imported and the version is valid, otherwise False.
        """
        # Match on a module-path boundary so that e.g. 'torch' does not match 'torchvision'.
        package_prefix = f"{lib_name}."
        if not any(
            module == lib_name or module.startswith(package_prefix) for module in self.imports.values()
        ):
            return False

        # Look up the installed distribution, which may be named differently from the import.
        distribution = self._DISTRIBUTION_NAMES.get(lib_name, lib_name)
        try:
            installed_version = version(distribution)
        except PackageNotFoundError:
            return False

        if required_version is None:
            return True

        # Compare numerically: as plain strings '2.10.0' would sort before '2.9.0'.
        installed = self._release_tuple(installed_version)
        required = self._release_tuple(required_version)
        # Pad to equal length so that '2.1' and '2.1.0' compare as equal.
        width = max(len(installed), len(required))
        installed += (0,) * (width - len(installed))
        required += (0,) * (width - len(required))
        return installed >= required
25 changes: 7 additions & 18 deletions pylint_ml/checkers/matplotlib/matplotlib_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
from pylint.checkers.utils import only_required_for_messages
from pylint.interfaces import HIGH

from pylint_ml.util.library_handler import LibraryHandler
from pylint_ml.checkers.config import MATPLOTLIB
from pylint_ml.checkers.library_base_checker import LibraryBaseChecker
from pylint_ml.checkers.utils import get_full_method_name


class MatplotlibParameterChecker(LibraryHandler):
class MatplotlibParameterChecker(LibraryBaseChecker):
name = "matplotlib-parameter"
msgs = {
"W8111": (
Expand Down Expand Up @@ -47,11 +49,10 @@ class MatplotlibParameterChecker(LibraryHandler):

@only_required_for_messages("matplotlib-parameter")
def visit_call(self, node: nodes.Call) -> None:
# TODO Update
# if not self.is_library_imported('matplotlib') and self.is_library_version_valid(lib_version=):
# return
if not self.is_library_imported_and_version_valid(lib_name=MATPLOTLIB, required_version=None):
return

method_name = self._get_full_method_name(node)
method_name = get_full_method_name(node=node)
if method_name in self.REQUIRED_PARAMS:
provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None}
missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords]
Expand All @@ -62,15 +63,3 @@ def visit_call(self, node: nodes.Call) -> None:
confidence=HIGH,
args=(", ".join(missing_params), method_name),
)

def _get_full_method_name(self, node: nodes.Call) -> str:
func = node.func
method_chain = []

while isinstance(func, nodes.Attribute):
method_chain.insert(0, func.attrname)
func = func.expr
if isinstance(func, nodes.Name):
method_chain.insert(0, func.name)

return ".".join(method_chain)
25 changes: 12 additions & 13 deletions pylint_ml/checkers/numpy/numpy_dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
from pylint.checkers.utils import only_required_for_messages
from pylint.interfaces import HIGH

from pylint_ml.util.library_handler import LibraryHandler
from pylint_ml.checkers.config import NUMPY
from pylint_ml.checkers.library_base_checker import LibraryBaseChecker
from pylint_ml.checkers.utils import infer_specific_module_from_call


class NumpyDotChecker(LibraryHandler):
class NumpyDotChecker(LibraryBaseChecker):
name = "numpy-dot-checker"
msgs = {
"W8122": (
Expand All @@ -24,19 +26,16 @@ class NumpyDotChecker(LibraryHandler):
),
}

def visit_import(self, node: nodes.Import):
super().visit_import(node=node)

@only_required_for_messages("numpy-dot-usage")
def visit_call(self, node: nodes.Call) -> None:
if not self.is_library_imported("numpy"):
if not self.is_library_imported_and_version_valid(lib_name=NUMPY, required_version=None):
return

# Check if the function being called is np.dot
if isinstance(node.func, nodes.Attribute):
func_name = node.func.attrname
module_name = getattr(node.func.expr, "name", None)

if func_name == "dot" and module_name == "np":
# Suggest using np.matmul() instead
self.add_message("numpy-dot-usage", node=node, confidence=HIGH)
if (
isinstance(node.func, nodes.Attribute)
and node.func.attrname == "dot"
and infer_specific_module_from_call(node=node, module_name=NUMPY)
):
# Suggest using np.matmul() instead
self.add_message("numpy-dot-usage", node=node, confidence=HIGH)
13 changes: 10 additions & 3 deletions pylint_ml/checkers/numpy/numpy_nan_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,18 @@
from __future__ import annotations

from astroid import nodes
from pylint.checkers import BaseChecker
from pylint.checkers.utils import only_required_for_messages
from pylint.interfaces import HIGH

from pylint_ml.checkers.config import NUMPY
from pylint_ml.checkers.library_base_checker import LibraryBaseChecker
from pylint_ml.checkers.utils import infer_specific_module_from_attribute

COMPARISON_OP = frozenset(("<", "<=", ">", ">=", "!=", "=="))
NUMPY_NAN = frozenset(("nan", "NaN", "NAN"))


class NumpyNaNComparisonChecker(BaseChecker):
class NumpyNaNComparisonChecker(LibraryBaseChecker):
name = "numpy-nan-compare"
msgs = {
"W8001": (
Expand All @@ -28,15 +31,19 @@ class NumpyNaNComparisonChecker(BaseChecker):
@classmethod
def __is_np_nan_call(cls, node: nodes.Attribute) -> bool:
"""Check if the node represents a call to np.nan."""
return node.attrname in NUMPY_NAN and isinstance(node.expr, nodes.Name) and node.expr.name == "np"
return node.attrname in NUMPY_NAN and (infer_specific_module_from_attribute(node=node, module_name=NUMPY))

@only_required_for_messages("numpy-nan-compare")
def visit_compare(self, node: nodes.Compare) -> None:
if not self.is_library_imported_and_version_valid(lib_name=NUMPY, required_version=None):
return

# Check node.left first for numpy nan usage
if isinstance(node.left, nodes.Attribute) and self.__is_np_nan_call(node.left):
self.add_message("numpy-nan-compare", node=node, confidence=HIGH)
return

# Check remaining nodes and operators for numpy nan usage
for op, comparator in node.ops:
if op in COMPARISON_OP and isinstance(comparator, nodes.Attribute) and self.__is_np_nan_call(comparator):
self.add_message("numpy-nan-compare", node=node, confidence=HIGH)
Expand Down
50 changes: 18 additions & 32 deletions pylint_ml/checkers/numpy/numpy_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@
"""Check for proper usage of numpy functions with required parameters."""

from astroid import nodes
from pylint.checkers import BaseChecker
from pylint.checkers.utils import only_required_for_messages
from pylint.interfaces import HIGH

from pylint_ml.checkers.config import NUMPY
from pylint_ml.checkers.library_base_checker import LibraryBaseChecker
from pylint_ml.checkers.utils import infer_specific_module_from_call

class NumPyParameterChecker(BaseChecker):

class NumPyParameterChecker(LibraryBaseChecker):
name = "numpy-parameter"
msgs = {
"W8111": (
Expand All @@ -33,12 +36,12 @@ class NumPyParameterChecker(BaseChecker):
"eye": ["N"],
"identity": ["n"],
# Random Sampling
"random.rand": ["d0"],
"random.randn": ["d0"],
"random.randint": ["low", "high"],
"random.choice": ["a"],
"random.uniform": ["low", "high"],
"random.normal": ["loc", "scale"],
"rand": ["d0"],
"randn": ["d0"],
"randint": ["low", "high"],
"choice": ["a"],
"uniform": ["low", "high"],
"normal": ["loc", "scale"],
# Mathematical Functions
"sum": ["a"],
"mean": ["a"],
Expand All @@ -59,9 +62,9 @@ class NumPyParameterChecker(BaseChecker):
# Linear Algebra
"dot": ["a", "b"],
"matmul": ["a", "b"],
"linalg.inv": ["a"],
"linalg.eig": ["a"],
"linalg.solve": ["a", "b"],
"inv": ["a"],
"eig": ["a"],
"solve": ["a", "b"],
# Statistical Functions
"percentile": ["a", "q"],
"quantile": ["a", "q"],
Expand All @@ -71,11 +74,12 @@ class NumPyParameterChecker(BaseChecker):

@only_required_for_messages("numpy-parameter")
def visit_call(self, node: nodes.Call) -> None:
method_name = self._get_full_method_name(node)
if not self.is_library_imported_and_version_valid(lib_name=NUMPY, required_version=None):
return

if method_name in self.REQUIRED_PARAMS:
method_name = getattr(node.func, "attrname", None)
if infer_specific_module_from_call(node=node, module_name=NUMPY) and method_name in self.REQUIRED_PARAMS:
provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None}
# Collect all missing parameters
missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords]
if missing_params:
self.add_message(
Expand All @@ -84,21 +88,3 @@ def visit_call(self, node: nodes.Call) -> None:
confidence=HIGH,
args=(", ".join(missing_params), method_name),
)

@staticmethod
def _get_full_method_name(node: nodes.Call) -> str:
"""
Extracts the full method name, including chained attributes (e.g., np.random.rand).
"""
func = node.func
method_chain = []

# Traverse the attribute chain
while isinstance(func, nodes.Attribute):
method_chain.insert(0, func.attrname)
func = func.expr

# Check if the root of the chain is "np" (as NumPy functions are expected to use np. prefix)
if isinstance(func, nodes.Name) and func.name == "np":
return ".".join(method_chain)
return ""
17 changes: 12 additions & 5 deletions pylint_ml/checkers/pandas/pandas_dataframe_bool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
from __future__ import annotations

from astroid import nodes
from pylint.checkers import BaseChecker
from pylint.checkers.utils import only_required_for_messages
from pylint.interfaces import HIGH

# Todo add version deprecated
from pylint_ml.checkers.config import PANDAS
from pylint_ml.checkers.library_base_checker import LibraryBaseChecker
from pylint_ml.checkers.utils import infer_specific_module_from_call


class PandasDataFrameBoolChecker(BaseChecker):
class PandasDataFrameBoolChecker(LibraryBaseChecker):
name = "pandas-dataframe-bool"
msgs = {
"W8104": (
Expand All @@ -26,13 +27,19 @@ class PandasDataFrameBoolChecker(BaseChecker):

@only_required_for_messages("pandas-dataframe-bool")
def visit_call(self, node: nodes.Call) -> None:
if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version="2.1.0"):
return

if isinstance(node.func, nodes.Attribute):
method_name = getattr(node.func, "attrname", None)

if method_name == "bool":
# Check if the object calling .bool() has a name starting with 'df_'
object_name = getattr(node.func.expr, "name", None)
if object_name and self._is_valid_dataframe_name(object_name):
if (
infer_specific_module_from_call(node=node, module_name=PANDAS)
and object_name
and self._is_valid_dataframe_name(object_name)
):
self.add_message("pandas-dataframe-bool", node=node, confidence=HIGH)

@staticmethod
Expand Down
17 changes: 14 additions & 3 deletions pylint_ml/checkers/pandas/pandas_dataframe_column_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@
from __future__ import annotations

from astroid import nodes
from pylint.checkers import BaseChecker
from pylint.checkers.utils import only_required_for_messages
from pylint.interfaces import HIGH

from pylint_ml.checkers.config import PANDAS
from pylint_ml.checkers.library_base_checker import LibraryBaseChecker
from pylint_ml.checkers.utils import infer_specific_module_from_attribute

class PandasColumnSelectionChecker(BaseChecker):

class PandasColumnSelectionChecker(LibraryBaseChecker):
name = "pandas-column-selection"
msgs = {
"W8118": (
Expand All @@ -25,6 +28,14 @@ class PandasColumnSelectionChecker(BaseChecker):
@only_required_for_messages("pandas-column-selection")
def visit_attribute(self, node: nodes.Attribute) -> None:
"""Check for attribute access that might be a column selection."""
if isinstance(node.expr, nodes.Name) and node.expr.name.startswith("df_"):

if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version=None):
return

if (
infer_specific_module_from_attribute(node=node, module_name=PANDAS)
and isinstance(node.expr, nodes.Name)
and node.expr.name.startswith("df_")
):
# Issue a warning for property-like access
self.add_message("pandas-column-selection", node=node, confidence=HIGH)
18 changes: 14 additions & 4 deletions pylint_ml/checkers/pandas/pandas_dataframe_empty_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
from __future__ import annotations

from astroid import nodes
from pylint.checkers import BaseChecker
from pylint.checkers.utils import only_required_for_messages
from pylint.checkers.utils import only_required_for_messages, safe_infer
from pylint.interfaces import HIGH

from pylint_ml.checkers.config import PANDAS
from pylint_ml.checkers.library_base_checker import LibraryBaseChecker

class PandasEmptyColumnChecker(BaseChecker):

class PandasEmptyColumnChecker(LibraryBaseChecker):
name = "pandas-dataframe-empty-column"
msgs = {
"W8113": (
Expand All @@ -25,7 +27,15 @@ class PandasEmptyColumnChecker(BaseChecker):

@only_required_for_messages("pandas-dataframe-empty-column")
def visit_subscript(self, node: nodes.Subscript) -> None:
if isinstance(node.value, nodes.Name) and node.value.name.startswith("df_"):
if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version=None):
return

if (
isinstance(node.value, nodes.Name)
and node.value.name.startswith("df_")
and PANDAS in safe_infer(node.value).qname()
):
print(node.value.name)
if isinstance(node.slice, nodes.Const) and isinstance(node.parent, nodes.Assign):
if isinstance(node.parent.value, nodes.Const):
# Checking for filler values: 0 or empty string
Expand Down
Loading