diff --git a/magic_pdf/model/magic_model.py b/magic_pdf/model/magic_model.py index bd8e061a..673177b3 100644 --- a/magic_pdf/model/magic_model.py +++ b/magic_pdf/model/magic_model.py @@ -110,6 +110,26 @@ def __init__(self, model_list: list, docs: fitz.Document): self.__fix_by_remove_high_iou_and_low_confidence() self.__fix_footnote() + def _bbox_distance(self, bbox1, bbox2): + left, right, bottom, top = bbox_relative_pos(bbox1, bbox2) + flags = [left, right, bottom, top] + count = sum([1 if v else 0 for v in flags]) + if count > 1: + return float('inf') + if left or right: + l1 = bbox1[3] - bbox1[1] + l2 = bbox2[3] - bbox2[1] + minL, maxL = min(l1, l2), max(l1, l2) + if (maxL - minL) / minL > 0.5: + return float('inf') + if bottom or top: + l1 = bbox1[2] - bbox1[0] + l2 = bbox2[2] - bbox2[0] + minL, maxL = min(l1, l2), max(l1, l2) + if (maxL - minL) / minL > 0.5: + return float('inf') + return bbox_distance(bbox1, bbox2) + def __fix_footnote(self): # 3: figure, 5: table, 7: footnote for model_page_info in self.__model_list: @@ -144,7 +164,7 @@ def __fix_footnote(self): if pos_flag_count > 1: continue dis_figure_footnote[i] = min( - bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']), + self._bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']), dis_figure_footnote.get(i, float('inf')), ) for i in range(len(footnotes)): @@ -163,7 +183,7 @@ def __fix_footnote(self): continue dis_table_footnote[i] = min( - bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']), + self._bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']), dis_table_footnote.get(i, float('inf')), ) for i in range(len(footnotes)): @@ -350,7 +370,7 @@ def expand_bbbox(idxes): dis[j][i] = dis[i][j] continue - dis[i][j] = bbox_distance(all_bboxes[i]['bbox'], all_bboxes[j]['bbox']) + dis[i][j] = self._bbox_distance(all_bboxes[i]['bbox'], all_bboxes[j]['bbox']) dis[j][i] = dis[i][j] used = set() @@ -441,7 +461,7 @@ def expand_bbbox(idxes): if is_nearest: nx0, ny0, nx1, ny1 = expand_bbbox(list(seen) + [k]) - n_dis = bbox_distance( + n_dis = self._bbox_distance( all_bboxes[i]['bbox'], [nx0, ny0, nx1, ny1] ) if float_gt(dis[i][j], n_dis): @@ -537,7 +557,7 @@ def expand_bbbox(idxes): # 计算已经配对的 distance 距离 for i in subject_object_relation_map.keys(): for j in subject_object_relation_map[i]: - total_subject_object_dis += bbox_distance( + total_subject_object_dis += self._bbox_distance( all_bboxes[i]['bbox'], all_bboxes[j]['bbox'] )