diff --git a/graphrag/entity_resolution.py b/graphrag/entity_resolution.py index c4c0dc3496..b3398015a9 100644 --- a/graphrag/entity_resolution.py +++ b/graphrag/entity_resolution.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +import itertools import logging import re import traceback @@ -93,16 +93,12 @@ def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = No node_clusters[graph.nodes[node]['entity_type']].append(node) candidate_resolution = {entity_type: [] for entity_type in entity_types} - for node_cluster in node_clusters.items(): + for k, v in node_clusters.items(): candidate_resolution_tmp = [] - for a in node_cluster[1]: - for b in node_cluster[1]: - if a == b: - continue - if self.is_similarity(a, b) and (b, a) not in candidate_resolution_tmp: - candidate_resolution_tmp.append((a, b)) - if candidate_resolution_tmp: - candidate_resolution[node_cluster[0]] = candidate_resolution_tmp + for a, b in itertools.permutations(v, 2): + if self.is_similarity(a, b) and (b, a) not in candidate_resolution_tmp: + candidate_resolution_tmp.append((a, b)) + candidate_resolution[k] = candidate_resolution_tmp gen_conf = {"temperature": 0.5} resolution_result = set()