diff --git a/magic_pdf/libs/clean_memory.py b/magic_pdf/libs/clean_memory.py new file mode 100644 index 00000000..6bfc174f --- /dev/null +++ b/magic_pdf/libs/clean_memory.py @@ -0,0 +1,10 @@ +# Copyright (c) Opendatalab. All rights reserved. +import torch +import gc + + +def clean_memory(): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + gc.collect() \ No newline at end of file diff --git a/magic_pdf/libs/draw_bbox.py b/magic_pdf/libs/draw_bbox.py index 1346c392..36265fb2 100644 --- a/magic_pdf/libs/draw_bbox.py +++ b/magic_pdf/libs/draw_bbox.py @@ -33,7 +33,7 @@ def draw_bbox_without_number(i, bbox_list, page, rgb_config, fill_config): ) # Draw the rectangle -def draw_bbox_with_number(i, bbox_list, page, rgb_config, fill_config): +def draw_bbox_with_number(i, bbox_list, page, rgb_config, fill_config, draw_bbox=True): new_rgb = [] for item in rgb_config: item = float(item) / 255 @@ -42,31 +42,31 @@ def draw_bbox_with_number(i, bbox_list, page, rgb_config, fill_config): for j, bbox in enumerate(page_data): x0, y0, x1, y1 = bbox rect_coords = fitz.Rect(x0, y0, x1, y1) # Define the rectangle - if fill_config: - page.draw_rect( - rect_coords, - color=None, - fill=new_rgb, - fill_opacity=0.3, - width=0.5, - overlay=True, - ) # Draw the rectangle - else: - page.draw_rect( - rect_coords, - color=new_rgb, - fill=None, - fill_opacity=1, - width=0.5, - overlay=True, - ) # Draw the rectangle + if draw_bbox: + if fill_config: + page.draw_rect( + rect_coords, + color=None, + fill=new_rgb, + fill_opacity=0.3, + width=0.5, + overlay=True, + ) # Draw the rectangle + else: + page.draw_rect( + rect_coords, + color=new_rgb, + fill=None, + fill_opacity=1, + width=0.5, + overlay=True, + ) # Draw the rectangle page.insert_text( - (x0, y0 + 10), str(j + 1), fontsize=10, color=new_rgb + (x1+2, y0 + 10), str(j + 1), fontsize=10, color=new_rgb ) # Insert the index in the top left corner of the rectangle def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename): - layout_bbox_list = [] dropped_bbox_list = [] tables_list, tables_body_list = [], [] tables_caption_list, tables_footnote_list = [], [] @@ -76,16 +76,14 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename): texts_list = [] interequations_list = [] for page in pdf_info: - page_layout_list = [] + page_dropped_list = [] tables, tables_body, tables_caption, tables_footnote = [], [], [], [] imgs, imgs_body, imgs_caption, imgs_footnote = [], [], [], [] titles = [] texts = [] interequations = [] - for layout in page['layout_bboxes']: - page_layout_list.append(layout['layout_bbox']) - layout_bbox_list.append(page_layout_list) + for dropped_bbox in page['discarded_blocks']: page_dropped_list.append(dropped_bbox['bbox']) dropped_bbox_list.append(page_dropped_list) @@ -129,9 +127,19 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename): texts_list.append(texts) interequations_list.append(interequations) + layout_bbox_list = [] + + for page in pdf_info: + page_block_list = [] + for block in page['para_blocks']: + bbox = block['bbox'] + page_block_list.append(bbox) + layout_bbox_list.append(page_block_list) + pdf_docs = fitz.open('pdf', pdf_bytes) + for i, page in enumerate(pdf_docs): - draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False) + draw_bbox_without_number(i, dropped_bbox_list, page, [158, 158, 158], True) draw_bbox_without_number(i, tables_list, page, [153, 153, 0], @@ -146,13 +154,15 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename): draw_bbox_without_number(i, 
imgs_body_list, page, [153, 255, 51], True) draw_bbox_without_number(i, imgs_caption_list, page, [102, 178, 255], True) - draw_bbox_with_number(i, imgs_footnote_list, page, [255, 178, 102], + draw_bbox_without_number(i, imgs_footnote_list, page, [255, 178, 102], True), draw_bbox_without_number(i, titles_list, page, [102, 102, 255], True) draw_bbox_without_number(i, texts_list, page, [153, 0, 76], True) draw_bbox_without_number(i, interequations_list, page, [0, 255, 0], True) + draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False, draw_bbox=False) + # Save the PDF pdf_docs.save(f'{out_path}/{filename}_layout.pdf') @@ -211,9 +221,9 @@ def get_span_info(span): # 构造其余useful_list for block in page['para_blocks']: if block['type'] in [ - BlockType.Text, - BlockType.Title, - BlockType.InterlineEquation, + BlockType.Text, + BlockType.Title, + BlockType.InterlineEquation, ]: for line in block['lines']: for span in line['spans']: @@ -232,10 +242,8 @@ def get_span_info(span): for i, page in enumerate(pdf_docs): # 获取当前页面的数据 draw_bbox_without_number(i, text_list, page, [255, 0, 0], False) - draw_bbox_without_number(i, inline_equation_list, page, [0, 255, 0], - False) - draw_bbox_without_number(i, interline_equation_list, page, [0, 0, 255], - False) + draw_bbox_without_number(i, inline_equation_list, page, [0, 255, 0], False) + draw_bbox_without_number(i, interline_equation_list, page, [0, 0, 255], False) draw_bbox_without_number(i, image_list, page, [255, 204, 0], False) draw_bbox_without_number(i, table_list, page, [204, 0, 255], False) draw_bbox_without_number(i, dropped_list, page, [158, 158, 158], False) @@ -244,7 +252,7 @@ def get_span_info(span): pdf_docs.save(f'{out_path}/{filename}_spans.pdf') -def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename): +def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename): dropped_bbox_list = [] tables_body_list, tables_caption_list, tables_footnote_list = [], [], [] imgs_body_list, imgs_caption_list, imgs_footnote_list = [], [], [] @@ -279,7 +287,7 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename): elif layout_det['category_id'] == CategoryId.ImageCaption: imgs_caption.append(bbox) elif layout_det[ - 'category_id'] == CategoryId.InterlineEquation_YOLO: + 'category_id'] == CategoryId.InterlineEquation_YOLO: interequations.append(bbox) elif layout_det['category_id'] == CategoryId.Abandon: page_dropped_list.append(bbox) @@ -316,3 +324,47 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename): # Save the PDF pdf_docs.save(f'{out_path}/{filename}_model.pdf') + + +def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename): + layout_bbox_list = [] + + for page in pdf_info: + page_line_list = [] + for block in page['preproc_blocks']: + if block['type'] in ['text', 'title', 'interline_equation']: + for line in block['lines']: + bbox = line['bbox'] + index = line['index'] + page_line_list.append({'index': index, 'bbox': bbox}) + if block['type'] in ['table', 'image']: + bbox = block['bbox'] + index = block['index'] + page_line_list.append({'index': index, 'bbox': bbox}) + # for line in block['lines']: + # bbox = line['bbox'] + # index = line['index'] + # page_line_list.append({'index': index, 'bbox': bbox}) + sorted_bboxes = sorted(page_line_list, key=lambda x: x['index']) + layout_bbox_list.append(sorted_bbox['bbox'] for sorted_bbox in sorted_bboxes) + pdf_docs = fitz.open('pdf', pdf_bytes) + for i, page in enumerate(pdf_docs): + draw_bbox_with_number(i, layout_bbox_list, page, [255, 
0, 0], False) + + pdf_docs.save(f'{out_path}/{filename}_line_sort.pdf') + + +def draw_layout_sort_bbox(pdf_info, pdf_bytes, out_path, filename): + layout_bbox_list = [] + + for page in pdf_info: + page_block_list = [] + for block in page['para_blocks']: + bbox = block['bbox'] + page_block_list.append(bbox) + layout_bbox_list.append(page_block_list) + pdf_docs = fitz.open('pdf', pdf_bytes) + for i, page in enumerate(pdf_docs): + draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False) + + pdf_docs.save(f'{out_path}/{filename}_layout_sort.pdf') diff --git a/magic_pdf/model/pdf_extract_kit.py b/magic_pdf/model/pdf_extract_kit.py index 6c5b9d18..f0bc468d 100644 --- a/magic_pdf/model/pdf_extract_kit.py +++ b/magic_pdf/model/pdf_extract_kit.py @@ -3,6 +3,7 @@ import time from magic_pdf.libs.Constants import * +from magic_pdf.libs.clean_memory import clean_memory from magic_pdf.model.model_list import AtomicModel os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新 @@ -330,6 +331,8 @@ def __call__(self, image): elif int(res['category_id']) in [5]: table_res_list.append(res) + clean_memory() + # ocr识别 if self.apply_ocr: ocr_start = time.time() diff --git a/magic_pdf/model/v3/__init__.py b/magic_pdf/model/v3/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/magic_pdf/model/v3/helpers.py b/magic_pdf/model/v3/helpers.py new file mode 100644 index 00000000..dfe71a89 --- /dev/null +++ b/magic_pdf/model/v3/helpers.py @@ -0,0 +1,125 @@ +from collections import defaultdict +from typing import List, Dict + +import torch +from transformers import LayoutLMv3ForTokenClassification + +MAX_LEN = 510 +CLS_TOKEN_ID = 0 +UNK_TOKEN_ID = 3 +EOS_TOKEN_ID = 2 + + +class DataCollator: + def __call__(self, features: List[dict]) -> Dict[str, torch.Tensor]: + bbox = [] + labels = [] + input_ids = [] + attention_mask = [] + + # clip bbox and labels to max length, build input_ids and attention_mask + for feature in features: + _bbox = feature["source_boxes"] + if len(_bbox) > MAX_LEN: + _bbox = _bbox[:MAX_LEN] + _labels = feature["target_index"] + if len(_labels) > MAX_LEN: + _labels = _labels[:MAX_LEN] + _input_ids = [UNK_TOKEN_ID] * len(_bbox) + _attention_mask = [1] * len(_bbox) + assert len(_bbox) == len(_labels) == len(_input_ids) == len(_attention_mask) + bbox.append(_bbox) + labels.append(_labels) + input_ids.append(_input_ids) + attention_mask.append(_attention_mask) + + # add CLS and EOS tokens + for i in range(len(bbox)): + bbox[i] = [[0, 0, 0, 0]] + bbox[i] + [[0, 0, 0, 0]] + labels[i] = [-100] + labels[i] + [-100] + input_ids[i] = [CLS_TOKEN_ID] + input_ids[i] + [EOS_TOKEN_ID] + attention_mask[i] = [1] + attention_mask[i] + [1] + + # padding to max length + max_len = max(len(x) for x in bbox) + for i in range(len(bbox)): + bbox[i] = bbox[i] + [[0, 0, 0, 0]] * (max_len - len(bbox[i])) + labels[i] = labels[i] + [-100] * (max_len - len(labels[i])) + input_ids[i] = input_ids[i] + [EOS_TOKEN_ID] * (max_len - len(input_ids[i])) + attention_mask[i] = attention_mask[i] + [0] * ( + max_len - len(attention_mask[i]) + ) + + ret = { + "bbox": torch.tensor(bbox), + "attention_mask": torch.tensor(attention_mask), + "labels": torch.tensor(labels), + "input_ids": torch.tensor(input_ids), + } + # set label > MAX_LEN to -100, because original labels may be > MAX_LEN + ret["labels"][ret["labels"] > MAX_LEN] = -100 + # set label > 0 to label-1, because original labels are 1-indexed + ret["labels"][ret["labels"] > 0] -= 1 + return ret + + +def boxes2inputs(boxes: List[List[int]]) -> 
Dict[str, torch.Tensor]: + bbox = [[0, 0, 0, 0]] + boxes + [[0, 0, 0, 0]] + input_ids = [CLS_TOKEN_ID] + [UNK_TOKEN_ID] * len(boxes) + [EOS_TOKEN_ID] + attention_mask = [1] + [1] * len(boxes) + [1] + return { + "bbox": torch.tensor([bbox]), + "attention_mask": torch.tensor([attention_mask]), + "input_ids": torch.tensor([input_ids]), + } + + +def prepare_inputs( + inputs: Dict[str, torch.Tensor], model: LayoutLMv3ForTokenClassification +) -> Dict[str, torch.Tensor]: + ret = {} + for k, v in inputs.items(): + v = v.to(model.device) + if torch.is_floating_point(v): + v = v.to(model.dtype) + ret[k] = v + return ret + + +def parse_logits(logits: torch.Tensor, length: int) -> List[int]: + """ + parse logits to orders + + :param logits: logits from model + :param length: input length + :return: orders + """ + logits = logits[1 : length + 1, :length] + orders = logits.argsort(descending=False).tolist() + ret = [o.pop() for o in orders] + while True: + order_to_idxes = defaultdict(list) + for idx, order in enumerate(ret): + order_to_idxes[order].append(idx) + # filter idxes len > 1 + order_to_idxes = {k: v for k, v in order_to_idxes.items() if len(v) > 1} + if not order_to_idxes: + break + # filter + for order, idxes in order_to_idxes.items(): + # find original logits of idxes + idxes_to_logit = {} + for idx in idxes: + idxes_to_logit[idx] = logits[idx, order] + idxes_to_logit = sorted( + idxes_to_logit.items(), key=lambda x: x[1], reverse=True + ) + # keep the highest logit as order, set others to next candidate + for idx, _ in idxes_to_logit[1:]: + ret[idx] = orders[idx].pop() + + return ret + + +def check_duplicate(a: List[int]) -> bool: + return len(a) != len(set(a)) diff --git a/magic_pdf/pdf_parse_by_ocr.py b/magic_pdf/pdf_parse_by_ocr.py index 42d9acbd..0686d59e 100644 --- a/magic_pdf/pdf_parse_by_ocr.py +++ b/magic_pdf/pdf_parse_by_ocr.py @@ -1,4 +1,4 @@ -from magic_pdf.pdf_parse_union_core import pdf_parse_union +from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union def parse_pdf_by_ocr(pdf_bytes, diff --git a/magic_pdf/pdf_parse_by_txt.py b/magic_pdf/pdf_parse_by_txt.py index 21d11766..bd8e202d 100644 --- a/magic_pdf/pdf_parse_by_txt.py +++ b/magic_pdf/pdf_parse_by_txt.py @@ -1,4 +1,4 @@ -from magic_pdf.pdf_parse_union_core import pdf_parse_union +from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union def parse_pdf_by_txt( diff --git a/magic_pdf/pdf_parse_union_core_v2.py b/magic_pdf/pdf_parse_union_core_v2.py new file mode 100644 index 00000000..ec5905e0 --- /dev/null +++ b/magic_pdf/pdf_parse_union_core_v2.py @@ -0,0 +1,451 @@ +import statistics +import time + +from loguru import logger + +from typing import List + +import torch + +from magic_pdf.libs.clean_memory import clean_memory +from magic_pdf.libs.commons import fitz, get_delta_time +from magic_pdf.libs.convert_utils import dict_to_list +from magic_pdf.libs.drop_reason import DropReason +from magic_pdf.libs.hash_utils import compute_md5 +from magic_pdf.libs.local_math import float_equal +from magic_pdf.libs.ocr_content_type import ContentType +from magic_pdf.model.magic_model import MagicModel +from magic_pdf.pre_proc.citationmarker_remove import remove_citation_marker +from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component_v2 +from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table +from magic_pdf.pre_proc.equations_replace import remove_chars_in_text_blocks, replace_equations_in_textblock, \ + combine_chars_to_pymudict +from magic_pdf.pre_proc.ocr_detect_all_bboxes import 
ocr_prepare_bboxes_for_layout_split_v2 +from magic_pdf.pre_proc.ocr_dict_merge import fill_spans_in_blocks, fix_block_spans, fix_discarded_block +from magic_pdf.pre_proc.ocr_span_list_modify import remove_overlaps_min_spans, get_qa_need_list_v2, \ + remove_overlaps_low_confidence_spans +from magic_pdf.pre_proc.resolve_bbox_conflict import check_useful_block_horizontal_overlap + + +def remove_horizontal_overlap_block_which_smaller(all_bboxes): + useful_blocks = [] + for bbox in all_bboxes: + useful_blocks.append({ + "bbox": bbox[:4] + }) + is_useful_block_horz_overlap, smaller_bbox, bigger_bbox = check_useful_block_horizontal_overlap(useful_blocks) + if is_useful_block_horz_overlap: + logger.warning( + f"skip this page, reason: {DropReason.USEFUL_BLOCK_HOR_OVERLAP}, smaller bbox is {smaller_bbox}, bigger bbox is {bigger_bbox}") + for bbox in all_bboxes.copy(): + if smaller_bbox == bbox[:4]: + all_bboxes.remove(bbox) + + return is_useful_block_horz_overlap, all_bboxes + + +def __replace_STX_ETX(text_str:str): + """ Replace \u0002 and \u0003, as these characters become garbled when extracted using pymupdf. In fact, they were originally quotation marks. +Drawback: This issue is only observed in English text; it has not been found in Chinese text so far. + + Args: + text_str (str): raw text + + Returns: + _type_: replaced text + """ + if text_str: + s = text_str.replace('\u0002', "'") + s = s.replace("\u0003", "'") + return s + return text_str + + +def txt_spans_extract(pdf_page, inline_equations, interline_equations): + text_raw_blocks = pdf_page.get_text("dict", flags=fitz.TEXTFLAGS_TEXT)["blocks"] + char_level_text_blocks = pdf_page.get_text("rawdict", flags=fitz.TEXTFLAGS_TEXT)[ + "blocks" + ] + text_blocks = combine_chars_to_pymudict(text_raw_blocks, char_level_text_blocks) + text_blocks = replace_equations_in_textblock( + text_blocks, inline_equations, interline_equations + ) + text_blocks = remove_citation_marker(text_blocks) + text_blocks = remove_chars_in_text_blocks(text_blocks) + spans = [] + for v in text_blocks: + for line in v["lines"]: + for span in line["spans"]: + bbox = span["bbox"] + if float_equal(bbox[0], bbox[2]) or float_equal(bbox[1], bbox[3]): + continue + if span.get('type') not in (ContentType.InlineEquation, ContentType.InterlineEquation): + spans.append( + { + "bbox": list(span["bbox"]), + "content": __replace_STX_ETX(span["text"]), + "type": ContentType.Text, + "score": 1.0, + } + ) + return spans + + +def replace_text_span(pymu_spans, ocr_spans): + return list(filter(lambda x: x["type"] != ContentType.Text, ocr_spans)) + pymu_spans + + +def model_init(model_name: str, local_path=None): + from transformers import LayoutLMv3ForTokenClassification + if torch.cuda.is_available(): + device = torch.device("cuda") + if torch.cuda.is_bf16_supported(): + supports_bfloat16 = True + else: + supports_bfloat16 = False + else: + device = torch.device("cpu") + supports_bfloat16 = False + + if model_name == "layoutreader": + if local_path: + model = LayoutLMv3ForTokenClassification.from_pretrained(local_path) + else: + model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader") + # 检查设备是否支持 bfloat16 + if supports_bfloat16: + model.bfloat16() + model.to(device).eval() + else: + logger.error("model name not allow") + exit(1) + return model + + +class ModelSingleton: + _instance = None + _models = {} + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def get_model(self, model_name: 
str, local_path=None): +        if model_name not in self._models: +            if local_path: +                self._models[model_name] = model_init(model_name=model_name, local_path=local_path) +            else: +                self._models[model_name] = model_init(model_name=model_name) +        return self._models[model_name] + + +def do_predict(boxes: List[List[int]], model) -> List[int]: +    from magic_pdf.model.v3.helpers import prepare_inputs, boxes2inputs, parse_logits +    inputs = boxes2inputs(boxes) +    inputs = prepare_inputs(inputs, model) +    logits = model(**inputs).logits.cpu().squeeze(0) +    return parse_logits(logits, len(boxes)) + + +def cal_block_index(fix_blocks, sorted_bboxes): +    for block in fix_blocks: +        # if block['type'] in ['text', 'title', 'interline_equation']: +        #     line_index_list = [] +        #     if len(block['lines']) == 0: +        #         block['index'] = sorted_bboxes.index(block['bbox']) +        #     else: +        #         for line in block['lines']: +        #             line['index'] = sorted_bboxes.index(line['bbox']) +        #             line_index_list.append(line['index']) +        #         median_value = statistics.median(line_index_list) +        #         block['index'] = median_value +        # +        # elif block['type'] in ['table', 'image']: +        #     block['index'] = sorted_bboxes.index(block['bbox']) + +        line_index_list = [] +        if len(block['lines']) == 0: +            block['index'] = sorted_bboxes.index(block['bbox']) +        else: +            for line in block['lines']: +                line['index'] = sorted_bboxes.index(line['bbox']) +                line_index_list.append(line['index']) +            median_value = statistics.median(line_index_list) +            block['index'] = median_value + +        # Remove the virtual line info from table/image blocks +        if block['type'] in ['table', 'image']: +            del block['lines'] + +    return fix_blocks + + +def insert_lines_into_block(block_bbox, line_height, page_w, page_h): +    # block_bbox is a tuple (x0, y0, x1, y1), where (x0, y0) is the top-left corner and (x1, y1) is the bottom-right corner +    x0, y0, x1, y1 = block_bbox + +    block_height = y1 - y0 +    block_weight = x1 - x0 + +    # If the block height is smaller than n lines of body text, return the block's bbox directly +    if line_height*3 < block_height: +        if block_height > page_h*0.25 and page_w*0.5 > block_weight > page_w*0.25:  # Likely a two-column layout, so split it into finer lines +            lines = int(block_height/line_height) +        else: +            # If the block is wider than 0.4 of the page width, split it into 3 lines +            if block_weight > page_w*0.4: +                line_height = (y1 - y0) / 3 +                lines = 3 +            elif block_weight > page_w*0.25:  # Otherwise split the block into 2 lines +                line_height = (y1 - y0) / 2 +                lines = 2 +            else:  # Check the aspect ratio +                if block_height/block_weight > 1.2:  # Tall and narrow: do not split +                    return [[x0, y0, x1, y1]] +                else:  # Not tall and narrow: still split into 2 lines +                    line_height = (y1 - y0) / 2 +                    lines = 2 + +        # Determine the y position to start drawing lines from +        current_y = y0 + +        # Store the line positions as [(x0, y), ...]
+ lines_positions = [] + + for i in range(lines): + lines_positions.append([x0, current_y, x1, current_y + line_height]) + current_y += line_height + return lines_positions + + else: + return [[x0, y0, x1, y1]] + + +def sort_lines_by_model(fix_blocks, page_w, page_h, line_height): + page_line_list = [] + for block in fix_blocks: + if block['type'] in ['text', 'title', 'interline_equation']: + if len(block['lines']) == 0: + bbox = block['bbox'] + lines = insert_lines_into_block(bbox, line_height, page_w, page_h) + for line in lines: + block['lines'].append({'bbox': line, 'spans': []}) + page_line_list.extend(lines) + else: + for line in block['lines']: + bbox = line['bbox'] + page_line_list.append(bbox) + elif block['type'] in ['table', 'image']: + bbox = block['bbox'] + lines = insert_lines_into_block(bbox, line_height, page_w, page_h) + block['lines'] = [] + for line in lines: + block['lines'].append({'bbox': line, 'spans': []}) + page_line_list.extend(lines) + + # 使用layoutreader排序 + x_scale = 1000.0 / page_w + y_scale = 1000.0 / page_h + boxes = [] + # logger.info(f"Scale: {x_scale}, {y_scale}, Boxes len: {len(page_line_list)}") + for left, top, right, bottom in page_line_list: + if left < 0: + logger.warning( + f"left < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}") + left = 0 + if right > page_w: + logger.warning( + f"right > page_w, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}") + right = page_w + if top < 0: + logger.warning( + f"top < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}") + top = 0 + if bottom > page_h: + logger.warning( + f"bottom > page_h, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}") + bottom = page_h + + left = round(left * x_scale) + top = round(top * y_scale) + right = round(right * x_scale) + bottom = round(bottom * y_scale) + assert ( + 1000 >= right >= left >= 0 and 1000 >= bottom >= top >= 0 + ), f"Invalid box. 
right: {right}, left: {left}, bottom: {bottom}, top: {top}" + boxes.append([left, top, right, bottom]) + model_manager = ModelSingleton() + model = model_manager.get_model("layoutreader") + with torch.no_grad(): + orders = do_predict(boxes, model) + sorted_bboxes = [page_line_list[i] for i in orders] + + return sorted_bboxes + + +def get_line_height(blocks): + page_line_height_list = [] + for block in blocks: + if block['type'] in ['text', 'title', 'interline_equation']: + for line in block['lines']: + bbox = line['bbox'] + page_line_height_list.append(int(bbox[3]-bbox[1])) + if len(page_line_height_list) > 0: + return statistics.median(page_line_height_list) + else: + return 10 + + +def parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode): + need_drop = False + drop_reason = [] + + '''从magic_model对象中获取后面会用到的区块信息''' + img_blocks = magic_model.get_imgs(page_id) + table_blocks = magic_model.get_tables(page_id) + discarded_blocks = magic_model.get_discarded(page_id) + text_blocks = magic_model.get_text_blocks(page_id) + title_blocks = magic_model.get_title_blocks(page_id) + inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations(page_id) + + page_w, page_h = magic_model.get_page_size(page_id) + + spans = magic_model.get_all_spans(page_id) + + '''根据parse_mode,构造spans''' + if parse_mode == "txt": + """ocr 中文本类的 span 用 pymu spans 替换!""" + pymu_spans = txt_spans_extract( + pdf_docs[page_id], inline_equations, interline_equations + ) + spans = replace_text_span(pymu_spans, spans) + elif parse_mode == "ocr": + pass + else: + raise Exception("parse_mode must be txt or ocr") + + '''删除重叠spans中置信度较低的那些''' + spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans) + '''删除重叠spans中较小的那些''' + spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans) + '''对image和table截图''' + spans = ocr_cut_image_and_table(spans, pdf_docs[page_id], page_id, pdf_bytes_md5, imageWriter) + + '''将所有区块的bbox整理到一起''' + # interline_equation_blocks参数不够准,后面切换到interline_equations上 + interline_equation_blocks = [] + if len(interline_equation_blocks) > 0: + all_bboxes, all_discarded_blocks = ocr_prepare_bboxes_for_layout_split_v2( + img_blocks, table_blocks, discarded_blocks, text_blocks, title_blocks, + interline_equation_blocks, page_w, page_h) + else: + all_bboxes, all_discarded_blocks = ocr_prepare_bboxes_for_layout_split_v2( + img_blocks, table_blocks, discarded_blocks, text_blocks, title_blocks, + interline_equations, page_w, page_h) + + '''先处理不需要排版的discarded_blocks''' + discarded_block_with_spans, spans = fill_spans_in_blocks(all_discarded_blocks, spans, 0.4) + fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans) + + '''如果当前页面没有bbox则跳过''' + if len(all_bboxes) == 0: + logger.warning(f"skip this page, not found useful bbox, page_id: {page_id}") + return ocr_construct_page_component_v2([], [], page_id, page_w, page_h, [], + [], [], interline_equations, fix_discarded_blocks, + need_drop, drop_reason) + + '''将span填入blocks中''' + block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.3) + + '''对block进行fix操作''' + fix_blocks = fix_block_spans(block_with_spans, img_blocks, table_blocks) + + '''获取所有line并计算正文line的高度''' + line_height = get_line_height(fix_blocks) + + '''获取所有line并对line排序''' + sorted_bboxes = sort_lines_by_model(fix_blocks, page_w, page_h, line_height) + + '''根据line的中位数算block的序列关系''' + fix_blocks = cal_block_index(fix_blocks, sorted_bboxes) + + '''重排block''' + sorted_blocks = 
sorted(fix_blocks, key=lambda b: b['index']) + + '''获取QA需要外置的list''' + images, tables, interline_equations = get_qa_need_list_v2(sorted_blocks) + + '''构造pdf_info_dict''' + page_info = ocr_construct_page_component_v2(sorted_blocks, [], page_id, page_w, page_h, [], + images, tables, interline_equations, fix_discarded_blocks, + need_drop, drop_reason) + return page_info + + +def pdf_parse_union(pdf_bytes, + model_list, + imageWriter, + parse_mode, + start_page_id=0, + end_page_id=None, + debug_mode=False, + ): + pdf_bytes_md5 = compute_md5(pdf_bytes) + pdf_docs = fitz.open("pdf", pdf_bytes) + + '''初始化空的pdf_info_dict''' + pdf_info_dict = {} + + '''用model_list和docs对象初始化magic_model''' + magic_model = MagicModel(model_list, pdf_docs) + + '''根据输入的起始范围解析pdf''' + # end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1 + end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(pdf_docs) - 1 + + if end_page_id > len(pdf_docs) - 1: + logger.warning("end_page_id is out of range, use pdf_docs length") + end_page_id = len(pdf_docs) - 1 + + '''初始化启动时间''' + start_time = time.time() + + for page_id, page in enumerate(pdf_docs): + '''debug时输出每页解析的耗时''' + if debug_mode: + time_now = time.time() + logger.info( + f"page_id: {page_id}, last_page_cost_time: {get_delta_time(start_time)}" + ) + start_time = time_now + + '''解析pdf中的每一页''' + if start_page_id <= page_id <= end_page_id: + page_info = parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode) + else: + page_w = page.rect.width + page_h = page.rect.height + page_info = ocr_construct_page_component_v2([], [], page_id, page_w, page_h, [], + [], [], [], [], + True, "skip page") + pdf_info_dict[f"page_{page_id}"] = page_info + + """分段""" + # para_split(pdf_info_dict, debug_mode=debug_mode) + for page_num, page in pdf_info_dict.items(): + page['para_blocks'] = page['preproc_blocks'] + + """dict转list""" + pdf_info_list = dict_to_list(pdf_info_dict) + new_pdf_info_dict = { + "pdf_info": pdf_info_list, + } + + clean_memory() + + return new_pdf_info_dict + + +if __name__ == '__main__': + pass diff --git a/magic_pdf/pre_proc/ocr_detect_all_bboxes.py b/magic_pdf/pre_proc/ocr_detect_all_bboxes.py index 9e2cd429..9767030b 100644 --- a/magic_pdf/pre_proc/ocr_detect_all_bboxes.py +++ b/magic_pdf/pre_proc/ocr_detect_all_bboxes.py @@ -60,6 +60,59 @@ def ocr_prepare_bboxes_for_layout_split(img_blocks, table_blocks, discarded_bloc return all_bboxes, all_discarded_blocks, drop_reasons +def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_blocks, text_blocks, + title_blocks, interline_equation_blocks, page_w, page_h): + all_bboxes = [] + all_discarded_blocks = [] + for image in img_blocks: + x0, y0, x1, y1 = image['bbox'] + all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Image, None, None, None, None, image["score"]]) + + for table in table_blocks: + x0, y0, x1, y1 = table['bbox'] + all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Table, None, None, None, None, table["score"]]) + + for text in text_blocks: + x0, y0, x1, y1 = text['bbox'] + all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Text, None, None, None, None, text["score"]]) + + for title in title_blocks: + x0, y0, x1, y1 = title['bbox'] + all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Title, None, None, None, None, title["score"]]) + + for interline_equation in interline_equation_blocks: + x0, y0, x1, y1 = interline_equation['bbox'] + all_bboxes.append([x0, y0, x1, y1, 
None, None, None, BlockType.InterlineEquation, None, None, None, None, interline_equation["score"]]) + + '''block嵌套问题解决''' + '''文本框与标题框重叠,优先信任文本框''' + all_bboxes = fix_text_overlap_title_blocks(all_bboxes) + '''任何框体与舍弃框重叠,优先信任舍弃框''' + all_bboxes = remove_need_drop_blocks(all_bboxes, discarded_blocks) + + # interline_equation 与title或text框冲突的情况,分两种情况处理 + '''interline_equation框与文本类型框iou比较接近1的时候,信任行间公式框''' + all_bboxes = fix_interline_equation_overlap_text_blocks_with_hi_iou(all_bboxes) + '''interline_equation框被包含在文本类型框内,且interline_equation比文本区块小很多时信任文本框,这时需要舍弃公式框''' + # 通过后续大框套小框逻辑删除 + + '''discarded_blocks中只保留宽度超过1/3页面宽度的,高度超过10的,处于页面下半50%区域的(限定footnote)''' + for discarded in discarded_blocks: + x0, y0, x1, y1 = discarded['bbox'] + all_discarded_blocks.append([x0, y0, x1, y1, None, None, None, BlockType.Discarded, None, None, None, None, discarded["score"]]) + # 将footnote加入到all_bboxes中,用来计算layout + # if (x1 - x0) > (page_w / 3) and (y1 - y0) > 10 and y0 > (page_h / 2): + # all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Footnote, None, None, None, None, discarded["score"]]) + + '''经过以上处理后,还存在大框套小框的情况,则删除小框''' + all_bboxes = remove_overlaps_min_blocks(all_bboxes) + all_discarded_blocks = remove_overlaps_min_blocks(all_discarded_blocks) + '''将剩余的bbox做分离处理,防止后面分layout时出错''' + # all_bboxes, drop_reasons = remove_overlap_between_bbox_for_block(all_bboxes) + + return all_bboxes, all_discarded_blocks + + def fix_interline_equation_overlap_text_blocks_with_hi_iou(all_bboxes): # 先提取所有text和interline block text_blocks = [] diff --git a/magic_pdf/tools/common.py b/magic_pdf/tools/common.py index 3939c28e..bae1224c 100644 --- a/magic_pdf/tools/common.py +++ b/magic_pdf/tools/common.py @@ -7,7 +7,7 @@ import magic_pdf.model as model_config from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_span_bbox, - drow_model_bbox) + draw_model_bbox, draw_line_sort_bbox) from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode from magic_pdf.pipe.OCRPipe import OCRPipe from magic_pdf.pipe.TXTPipe import TXTPipe @@ -39,17 +39,19 @@ def do_parse( f_dump_middle_json=True, f_dump_model_json=True, f_dump_orig_pdf=True, - f_dump_content_list=False, + f_dump_content_list=True, f_make_md_mode=MakeMode.MM_MD, f_draw_model_bbox=False, + f_draw_line_sort_bbox=False, start_page_id=0, end_page_id=None, lang=None, ): if debug_able: logger.warning('debug mode is on') - f_dump_content_list = True + # f_dump_content_list = True f_draw_model_bbox = True + f_draw_line_sort_bbox = True orig_model_list = copy.deepcopy(model_list) local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, @@ -90,7 +92,9 @@ def do_parse( if f_draw_span_bbox: draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name) if f_draw_model_bbox: - drow_model_bbox(copy.deepcopy(orig_model_list), pdf_bytes, local_md_dir, pdf_file_name) + draw_model_bbox(copy.deepcopy(orig_model_list), pdf_bytes, local_md_dir, pdf_file_name) + if f_draw_line_sort_bbox: + draw_line_sort_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name) md_content = pipe.pipe_mk_markdown(image_dir, drop_mode=DropMode.NONE, diff --git a/requirements.txt b/requirements.txt index 4e6a0f94..d0bd653e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,6 @@ pydantic>=2.7.2,<2.8.0 PyMuPDF>=1.24.9 scikit-learn>=1.0.2 wordninja>=2.0.0 +torch>=2.2.2,<=2.3.1 +transformers # The requirements.txt must ensure that only necessary external dependencies are introduced. If there are new dependencies to add, please contact the project administrator.
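For reviewers unfamiliar with the new v3 helpers, here is a minimal, self-contained sketch of how `parse_logits` resolves conflicting order predictions. The toy tensor values are invented for illustration and only exercise the code added in `magic_pdf/model/v3/helpers.py` above.

```python
import torch

from magic_pdf.model.v3.helpers import parse_logits

# Toy logits for 3 boxes (values invented). parse_logits slices off the CLS row
# and keeps only the first `length` columns, so we pass a (length + 2, length) tensor.
logits = torch.tensor([
    [0.0, 0.0, 0.0],  # CLS row, dropped by the slice inside parse_logits
    [0.1, 0.9, 0.2],  # box 0: highest score for order index 1
    [0.2, 0.8, 0.3],  # box 1: also highest for order index 1 -> conflict with box 0
    [0.9, 0.1, 0.5],  # box 2: highest score for order index 0
    [0.0, 0.0, 0.0],  # EOS row, ignored
])

# Box 0 and box 1 both want index 1; box 0 wins on the larger logit (0.9 vs 0.8),
# so box 1 falls back to its next-best candidate, index 2.
print(parse_logits(logits, length=3))  # [1, 2, 0]
```

`check_duplicate` from the same module can be used to confirm the returned orders are collision-free.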
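And a rough end-to-end sketch of the new reading-order path, mirroring `do_predict` and `sort_lines_by_model` from `pdf_parse_union_core_v2.py`. The line boxes below are made up and must already be scaled to the 0-1000 range that `sort_lines_by_model` produces; the first call will download the `hantian/layoutreader` weights unless a `local_path` is passed to `get_model`.

```python
import torch

from magic_pdf.pdf_parse_union_core_v2 import ModelSingleton, do_predict

# Hypothetical line bboxes, already scaled to 0-1000 the way sort_lines_by_model
# does it (x_scale = 1000.0 / page_w, y_scale = 1000.0 / page_h, then rounded).
page_line_list = [
    [50, 820, 950, 850],  # a line near the bottom of the page
    [50, 100, 950, 130],  # a line near the top
    [50, 140, 950, 170],  # the next line down
]

model = ModelSingleton().get_model("layoutreader")
with torch.no_grad():
    orders = do_predict(page_line_list, model)

# Same post-processing as sort_lines_by_model: reorder the bboxes with the
# predicted permutation (ideally the two top lines now come before the bottom one).
sorted_bboxes = [page_line_list[i] for i in orders]
```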