diff --git a/pylint/pyreverse/diadefslib.py b/pylint/pyreverse/diadefslib.py index c1f14ce727..f91a89eca8 100644 --- a/pylint/pyreverse/diadefslib.py +++ b/pylint/pyreverse/diadefslib.py @@ -8,6 +8,7 @@ import argparse import warnings +from collections import defaultdict from collections.abc import Generator, Sequence from typing import Any @@ -15,7 +16,7 @@ from astroid import nodes from astroid.modutils import is_stdlib_module -from pylint.pyreverse.diagrams import ClassDiagram, PackageDiagram +from pylint.pyreverse.diagrams import ClassDiagram, PackageDiagram, Relationship from pylint.pyreverse.inspector import Linker, Project from pylint.pyreverse.utils import LocalsVisitor @@ -272,6 +273,99 @@ def __init__(self, config: argparse.Namespace, args: Sequence[str]) -> None: self.config = config self.args = args + def _get_object_name(self, obj: Any) -> str: + """Get object name safely handling both title attributes and strings. + + :param obj: Object to get name from + :return: Object name/title + :rtype: str + """ + return obj.title if hasattr(obj, "title") else str(obj) + + def _process_relationship( + self, relationship: Relationship, unique_classes: dict[str, Any], obj: Any + ) -> None: + """Process a single relationship for deduplication. + + :param relationship: Relationship to process + :param unique_classes: Dict of unique classes + :param obj: Current object being processed + """ + if relationship.from_object == obj: + relationship.from_object = unique_classes[obj.node.qname()] + if relationship.to_object == obj: + relationship.to_object = unique_classes[obj.node.qname()] + + def _process_class_relationships( + self, diagram: ClassDiagram, obj: Any, unique_classes: dict[str, Any] + ) -> None: + """Merge relationships for a class. + + :param diagram: Current diagram + :param obj: Object whose relationships to process + :param unique_classes: Dict of unique classes + """ + for rel_type in ("specialization", "association", "aggregation"): + for rel in diagram.get_relationships(rel_type): + self._process_relationship(rel, unique_classes, obj) + + def deduplicate_classes(self, diagrams: list[ClassDiagram]) -> list[ClassDiagram]: + """Remove duplicate classes from diagrams.""" + for diagram in diagrams: + # Track unique classes by qualified name + unique_classes: dict[str, Any] = {} + duplicate_classes: Any = set() + + # First pass - identify duplicates + for obj in diagram.objects: + qname = obj.node.qname() + if qname in unique_classes: + duplicate_classes.add(obj) + self._process_class_relationships(diagram, obj, unique_classes) + else: + unique_classes[qname] = obj + + # Second pass - filter out duplicates + diagram.objects = [ + obj for obj in diagram.objects if obj not in duplicate_classes + ] + + return diagrams + + def _process_relationship_type( + self, + rel_list: list[Relationship], + seen: set[tuple[str, str, str, Any | None]], + unique_rels: dict[str, list[Relationship]], + rel_name: str, + ) -> None: + """Process a list of relationships of a single type. + + :param rel_list: List of relationships to process + :param seen: Set of seen relationships + :param unique_rels: Dict to store unique relationships + :param rel_name: Name of relationship type + """ + for rel in rel_list: + key = ( + self._get_object_name(rel.from_object), + self._get_object_name(rel.to_object), + type(rel).__name__, + getattr(rel, "name", None), + ) + if key not in seen: + seen.add(key) + unique_rels[rel_name].append(rel) + + def deduplicate_relationships(self, diagram: ClassDiagram) -> None: + """Remove duplicate relationships between objects.""" + seen: set[tuple[str, str, str, Any | None]] = set() + unique_rels: dict[str, list[Relationship]] = defaultdict(list) + for rel_name, rel_list in diagram.relationships.items(): + self._process_relationship_type(rel_list, seen, unique_rels, rel_name) + + diagram.relationships = dict(unique_rels) + def get_diadefs(self, project: Project, linker: Linker) -> list[ClassDiagram]: """Get the diagram's configuration data. @@ -292,4 +386,5 @@ def get_diadefs(self, project: Project, linker: Linker) -> list[ClassDiagram]: diagrams = DefaultDiadefGenerator(linker, self).visit(project) for diagram in diagrams: diagram.extract_relationships() - return diagrams + self.deduplicate_relationships(diagram) + return self.deduplicate_classes(diagrams) diff --git a/tests/pyreverse/functional/class_diagrams/aggregation/class_attribute_duplicate.mmd b/tests/pyreverse/functional/class_diagrams/aggregation/class_attribute_duplicate.mmd new file mode 100644 index 0000000000..5e451ec3af --- /dev/null +++ b/tests/pyreverse/functional/class_diagrams/aggregation/class_attribute_duplicate.mmd @@ -0,0 +1,9 @@ +classDiagram + class A { + var : int + } + class B { + a_obj + func() + } + A --* B : a_obj diff --git a/tests/pyreverse/functional/class_diagrams/aggregation/class_attribute_duplicate.py b/tests/pyreverse/functional/class_diagrams/aggregation/class_attribute_duplicate.py new file mode 100644 index 0000000000..eccd0aee86 --- /dev/null +++ b/tests/pyreverse/functional/class_diagrams/aggregation/class_attribute_duplicate.py @@ -0,0 +1,13 @@ +# Test for issue #9267 +class A: + def __init__(self) -> None: + self.var = 2 + + +class B: + def __init__(self) -> None: + self.a_obj = A() + + def func(self): + self.a_obj = A() + self.a_obj = A()