From 1279f2cd0f2639dcad2fb3437fceda7539455ae2 Mon Sep 17 00:00:00 2001 From: myhloli Date: Wed, 23 Oct 2024 17:07:39 +0800 Subject: [PATCH] feat(model): add support for DocLayout-YOLO model - Add new layout model option: DocLayout-YOLO - Implement model initialization and prediction for DocLayout-YOLO - Update configuration options to include new model- Modify existing code to support both LayoutLMv3 and DocLayout-YOLO models - Update Gradio app to support more Custom Switch --- magic-pdf.template.json | 2 +- magic_pdf/libs/Constants.py | 19 ++- magic_pdf/libs/boxbase.py | 35 +++++ magic_pdf/libs/config_reader.py | 23 +++- .../model/doc_analyze_by_custom_model.py | 52 ++++++-- magic_pdf/model/pdf_extract_kit.py | 125 ++++++++++++------ magic_pdf/model/ppTableModel.py | 4 +- magic_pdf/pipe/AbsPipe.py | 5 +- magic_pdf/pipe/OCRPipe.py | 12 +- magic_pdf/pipe/TXTPipe.py | 12 +- magic_pdf/pipe/UNIPipe.py | 15 ++- magic_pdf/pre_proc/ocr_detect_all_bboxes.py | 28 +++- .../resources/model_config/model_configs.yaml | 18 +-- magic_pdf/tools/common.py | 13 +- magic_pdf/user_api.py | 18 ++- old_docs/download_models.py | 29 ++-- old_docs/download_models_hf.py | 38 ++++-- projects/gradio_app/app.py | 47 ++++++- 18 files changed, 365 insertions(+), 130 deletions(-) diff --git a/magic-pdf.template.json b/magic-pdf.template.json index 487bbf08..114dfce3 100644 --- a/magic-pdf.template.json +++ b/magic-pdf.template.json @@ -7,7 +7,7 @@ "layoutreader-model-dir":"/tmp/layoutreader", "device-mode":"cpu", "layout-config": { - "model": "doclayout_yolo" + "model": "layoutlmv3" }, "formula-config": { "mfd_model": "yolo_v8_mfd", diff --git a/magic_pdf/libs/Constants.py b/magic_pdf/libs/Constants.py index 4e132290..2a51b2dc 100644 --- a/magic_pdf/libs/Constants.py +++ b/magic_pdf/libs/Constants.py @@ -10,18 +10,12 @@ # block中lines是否被删除 LINES_DELETED = "lines_deleted" -# struct eqtable -STRUCT_EQTABLE = "struct_eqtable" - # table recognition max time default value TABLE_MAX_TIME_VALUE = 400 # pp_table_result_max_length TABLE_MAX_LEN = 480 -# pp table structure algorithm -TABLE_MASTER = "TableMaster" - # table master structure dict TABLE_MASTER_DICT = "table_master_structure_dict.txt" @@ -38,3 +32,16 @@ REC_CHAR_DICT = "ppocr_keys_v1.txt" +class MODEL_NAME: + # pp table structure algorithm + TABLE_MASTER = "tablemaster" + # struct eqtable + STRUCT_EQTABLE = "struct_eqtable" + + DocLayout_YOLO = "doclayout_yolo" + + LAYOUTLMv3 = "layoutlmv3" + + YOLO_V8_MFD = "yolo_v8_mfd" + + UniMerNet_v2_Small = "unimernet_small" \ No newline at end of file diff --git a/magic_pdf/libs/boxbase.py b/magic_pdf/libs/boxbase.py index 0472328f..52779a22 100644 --- a/magic_pdf/libs/boxbase.py +++ b/magic_pdf/libs/boxbase.py @@ -445,3 +445,38 @@ def get_overlap_area(bbox1, bbox2): # The area of overlap area return (x_right - x_left) * (y_bottom - y_top) + + +def calculate_vertical_projection_overlap_ratio(block1, block2): + """ + Calculate the proportion of the x-axis covered by the vertical projection of two blocks. + + Args: + block1 (tuple): Coordinates of the first block (x0, y0, x1, y1). + block2 (tuple): Coordinates of the second block (x0, y0, x1, y1). + + Returns: + float: The proportion of the x-axis covered by the vertical projection of the two blocks. + """ + x0_1, _, x1_1, _ = block1 + x0_2, _, x1_2, _ = block2 + + # Calculate the intersection of the x-coordinates + x_left = max(x0_1, x0_2) + x_right = min(x1_1, x1_2) + + if x_right < x_left: + return 0.0 + + # Length of the intersection + intersection_length = x_right - x_left + + # Length of the x-axis projection of the first block + block1_length = x1_1 - x0_1 + + if block1_length == 0: + return 0.0 + + # Proportion of the x-axis covered by the intersection + # logger.info(f"intersection_length: {intersection_length}, block1_length: {block1_length}") + return intersection_length / block1_length diff --git a/magic_pdf/libs/config_reader.py b/magic_pdf/libs/config_reader.py index 9b4b7d8b..8a831d7f 100644 --- a/magic_pdf/libs/config_reader.py +++ b/magic_pdf/libs/config_reader.py @@ -8,6 +8,7 @@ from loguru import logger +from magic_pdf.libs.Constants import MODEL_NAME from magic_pdf.libs.commons import parse_bucket_key # 定义配置文件名常量 @@ -94,10 +95,30 @@ def get_table_recog_config(): table_config = config.get("table-config") if table_config is None: logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default") - return json.loads('{"is_table_recog_enable": false, "max_time": 400}') + return json.loads(f'{{"model": "{MODEL_NAME.TABLE_MASTER}","enable": false, "max_time": 400}}') else: return table_config +def get_layout_config(): + config = read_config() + layout_config = config.get("layout-config") + if layout_config is None: + logger.warning(f"'layout-config' not found in {CONFIG_FILE_NAME}, use '{MODEL_NAME.LAYOUTLMv3}' as default") + return json.loads(f'{{"model": "{MODEL_NAME.LAYOUTLMv3}"}}') + else: + return layout_config + + +def get_formula_config(): + config = read_config() + formula_config = config.get("formula-config") + if formula_config is None: + logger.warning(f"'formula-config' not found in {CONFIG_FILE_NAME}, use 'True' as default") + return json.loads(f'{{"mfd_model": "{MODEL_NAME.YOLO_V8_MFD}","mfr_model": "{MODEL_NAME.UniMerNet_v2_Small}","enable": true}}') + else: + return formula_config + + if __name__ == "__main__": ak, sk, endpoint = get_s3_config("llm-raw") diff --git a/magic_pdf/model/doc_analyze_by_custom_model.py b/magic_pdf/model/doc_analyze_by_custom_model.py index 3fbbea61..ee50d6eb 100644 --- a/magic_pdf/model/doc_analyze_by_custom_model.py +++ b/magic_pdf/model/doc_analyze_by_custom_model.py @@ -5,7 +5,8 @@ from loguru import logger from magic_pdf.libs.clean_memory import clean_memory -from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config +from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config, get_layout_config, \ + get_formula_config from magic_pdf.model.model_list import MODEL import magic_pdf.model as model_config @@ -68,14 +69,17 @@ def __new__(cls, *args, **kwargs): cls._instance = super().__new__(cls) return cls._instance - def get_model(self, ocr: bool, show_log: bool, lang=None): - key = (ocr, show_log, lang) + def get_model(self, ocr: bool, show_log: bool, lang=None, layout_model=None, formula_enable=None, table_enable=None): + key = (ocr, show_log, lang, layout_model, formula_enable, table_enable) if key not in self._models: - self._models[key] = custom_model_init(ocr=ocr, show_log=show_log, lang=lang) + self._models[key] = custom_model_init(ocr=ocr, show_log=show_log, lang=lang, layout_model=layout_model, + formula_enable=formula_enable, table_enable=table_enable) return self._models[key] -def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None): +def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None, + layout_model=None, formula_enable=None, table_enable=None): + model = None if model_config.__model_mode__ == "lite": @@ -95,14 +99,30 @@ def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None): # 从配置文件读取model-dir和device local_models_dir = get_local_models_dir() device = get_device() + + layout_config = get_layout_config() + if layout_model is not None: + layout_config["model"] = layout_model + + formula_config = get_formula_config() + if formula_enable is not None: + formula_config["enable"] = formula_enable + table_config = get_table_recog_config() - model_input = {"ocr": ocr, - "show_log": show_log, - "models_dir": local_models_dir, - "device": device, - "table_config": table_config, - "lang": lang, - } + if table_enable is not None: + table_config["enable"] = table_enable + + model_input = { + "ocr": ocr, + "show_log": show_log, + "models_dir": local_models_dir, + "device": device, + "table_config": table_config, + "layout_config": layout_config, + "formula_config": formula_config, + "lang": lang, + } + custom_model = CustomPEKModel(**model_input) else: logger.error("Not allow model_name!") @@ -117,10 +137,14 @@ def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None): def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False, - start_page_id=0, end_page_id=None, lang=None): + start_page_id=0, end_page_id=None, lang=None, + layout_model=None, formula_enable=None, table_enable=None): + + if lang == "": + lang = None model_manager = ModelSingleton() - custom_model = model_manager.get_model(ocr, show_log, lang) + custom_model = model_manager.get_model(ocr, show_log, lang, layout_model, formula_enable, table_enable) with fitz.open("pdf", pdf_bytes) as doc: pdf_page_num = doc.page_count diff --git a/magic_pdf/model/pdf_extract_kit.py b/magic_pdf/model/pdf_extract_kit.py index 1e391104..6072deda 100644 --- a/magic_pdf/model/pdf_extract_kit.py +++ b/magic_pdf/model/pdf_extract_kit.py @@ -25,6 +25,7 @@ from unimernet.common.config import Config import unimernet.tasks as tasks from unimernet.processors import load_processor + from doclayout_yolo import YOLOv10 except ImportError as e: logger.exception(e) @@ -41,7 +42,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'): - if table_model_type == STRUCT_EQTABLE: + if table_model_type == MODEL_NAME.STRUCT_EQTABLE: table_model = StructTableModel(model_path, max_time=max_time, device=_device_) else: config = { @@ -77,6 +78,11 @@ def layout_model_init(weight, config_file, device): return model +def doclayout_yolo_model_init(weight): + model = YOLOv10(weight) + return model + + def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None, use_dilation=True, det_db_unclip_ratio=2.4): if lang is not None: model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio) @@ -114,19 +120,27 @@ def __new__(cls, *args, **kwargs): return cls._instance def get_atom_model(self, atom_model_name: str, **kwargs): - if atom_model_name not in self._models: - self._models[atom_model_name] = atom_model_init(model_name=atom_model_name, **kwargs) - return self._models[atom_model_name] + lang = kwargs.get("lang", None) + layout_model_name = kwargs.get("layout_model_name", None) + key = (atom_model_name, layout_model_name, lang) + if key not in self._models: + self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs) + return self._models[key] def atom_model_init(model_name: str, **kwargs): if model_name == AtomicModel.Layout: - atom_model = layout_model_init( - kwargs.get("layout_weights"), - kwargs.get("layout_config_file"), - kwargs.get("device") - ) + if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3: + atom_model = layout_model_init( + kwargs.get("layout_weights"), + kwargs.get("layout_config_file"), + kwargs.get("device") + ) + elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO: + atom_model = doclayout_yolo_model_init( + kwargs.get("doclayout_yolo_weights"), + ) elif model_name == AtomicModel.MFD: atom_model = mfd_model_init( kwargs.get("mfd_weights") @@ -145,7 +159,7 @@ def atom_model_init(model_name: str, **kwargs): ) elif model_name == AtomicModel.Table: atom_model = table_model_init( - kwargs.get("table_model_type"), + kwargs.get("table_model_name"), kwargs.get("table_model_path"), kwargs.get("table_max_time"), kwargs.get("device") @@ -193,23 +207,35 @@ def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs): with open(config_path, "r", encoding='utf-8') as f: self.configs = yaml.load(f, Loader=yaml.FullLoader) # 初始化解析配置 - self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"]) - self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"]) + + # layout config + self.layout_config = kwargs.get("layout_config") + self.layout_model_name = self.layout_config.get("model", MODEL_NAME.DocLayout_YOLO) + + # formula config + self.formula_config = kwargs.get("formula_config") + self.mfd_model_name = self.formula_config.get("mfd_model", MODEL_NAME.YOLO_V8_MFD) + self.mfr_model_name = self.formula_config.get("mfr_model", MODEL_NAME.UniMerNet_v2_Small) + self.apply_formula = self.formula_config.get("enable", True) + # table config - self.table_config = kwargs.get("table_config", self.configs["config"]["table_config"]) - self.apply_table = self.table_config.get("is_table_recog_enable", False) + self.table_config = kwargs.get("table_config") + self.apply_table = self.table_config.get("enable", False) self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE) - self.table_model_type = self.table_config.get("model", TABLE_MASTER) + self.table_model_name = self.table_config.get("model", MODEL_NAME.TABLE_MASTER) + + # ocr config self.apply_ocr = ocr self.lang = kwargs.get("lang", None) + logger.info( - "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}, lang: {}".format( - self.apply_layout, self.apply_formula, self.apply_ocr, self.apply_table, self.lang + "DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, " + "apply_table: {}, table_model: {}, lang: {}".format( + self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name, self.lang ) ) - assert self.apply_layout, "DocAnalysis must contain layout model." # 初始化解析方案 - self.device = kwargs.get("device", self.configs["config"]["device"]) + self.device = kwargs.get("device", "cpu") logger.info("using device: {}".format(self.device)) models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models")) logger.info("using models_dir: {}".format(models_dir)) @@ -218,17 +244,16 @@ def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs): # 初始化公式识别 if self.apply_formula: + # 初始化公式检测模型 - # self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"]))) self.mfd_model = atom_model_manager.get_atom_model( atom_model_name=AtomicModel.MFD, - mfd_weights=str(os.path.join(models_dir, self.configs["weights"]["mfd"])) + mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name])) ) + # 初始化公式解析模型 - mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"]["mfr"])) + mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][self.mfr_model_name])) mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml")) - # self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device) - # self.mfr_transform = transforms.Compose([mfr_vis_processors, ]) self.mfr_model, self.mfr_transform = atom_model_manager.get_atom_model( atom_model_name=AtomicModel.MFR, mfr_weight_dir=mfr_weight_dir, @@ -237,17 +262,20 @@ def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs): ) # 初始化layout模型 - # self.layout_model = Layoutlmv3_Predictor( - # str(os.path.join(models_dir, self.configs['weights']['layout'])), - # str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")), - # device=self.device - # ) - self.layout_model = atom_model_manager.get_atom_model( - atom_model_name=AtomicModel.Layout, - layout_weights=str(os.path.join(models_dir, self.configs['weights']['layout'])), - layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")), - device=self.device - ) + if self.layout_model_name == MODEL_NAME.LAYOUTLMv3: + self.layout_model = atom_model_manager.get_atom_model( + atom_model_name=AtomicModel.Layout, + layout_model_name=MODEL_NAME.LAYOUTLMv3, + layout_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])), + layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")), + device=self.device + ) + elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO: + self.layout_model = atom_model_manager.get_atom_model( + atom_model_name=AtomicModel.Layout, + layout_model_name=MODEL_NAME.DocLayout_YOLO, + doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])) + ) # 初始化ocr if self.apply_ocr: @@ -260,12 +288,10 @@ def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs): ) # init table model if self.apply_table: - table_model_dir = self.configs["weights"][self.table_model_type] - # self.table_model = table_model_init(self.table_model_type, str(os.path.join(models_dir, table_model_dir)), - # max_time=self.table_max_time, _device_=self.device) + table_model_dir = self.configs["weights"][self.table_model_name] self.table_model = atom_model_manager.get_atom_model( atom_model_name=AtomicModel.Table, - table_model_type=self.table_model_type, + table_model_name=self.table_model_name, table_model_path=str(os.path.join(models_dir, table_model_dir)), table_max_time=self.table_max_time, device=self.device @@ -282,7 +308,21 @@ def __call__(self, image): # layout检测 layout_start = time.time() - layout_res = self.layout_model(image, ignore_catids=[]) + if self.layout_model_name == MODEL_NAME.LAYOUTLMv3: + # layoutlmv3 + layout_res = self.layout_model(image, ignore_catids=[]) + elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO: + # doclayout_yolo + layout_res = [] + doclayout_yolo_res = self.layout_model.predict(image, imgsz=1024, conf=0.15, iou=0.45, verbose=True, device=self.device)[0] + for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(), doclayout_yolo_res.boxes.cls.cpu()): + xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy] + new_item = { + 'category_id': int(cla.item()), + 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax], + 'score': round(float(conf.item()), 3), + } + layout_res.append(new_item) layout_cost = round(time.time() - layout_start, 2) logger.info(f"layout detection time: {layout_cost}") @@ -291,7 +331,7 @@ def __call__(self, image): if self.apply_formula: # 公式检测 mfd_start = time.time() - mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0] + mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0] logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}") for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()): xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy] @@ -303,7 +343,6 @@ def __call__(self, image): } layout_res.append(new_item) latex_filling_list.append(new_item) - # bbox_img = get_croped_image(pil_img, [xmin, ymin, xmax, ymax]) bbox_img = pil_img.crop((xmin, ymin, xmax, ymax)) mf_image_list.append(bbox_img) @@ -405,7 +444,7 @@ def __call__(self, image): # logger.info("------------------table recognition processing begins-----------------") latex_code = None html_code = None - if self.table_model_type == STRUCT_EQTABLE: + if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE: with torch.no_grad(): latex_code = self.table_model.image2latex(new_image)[0] else: diff --git a/magic_pdf/model/ppTableModel.py b/magic_pdf/model/ppTableModel.py index 310bcc79..933f31a0 100644 --- a/magic_pdf/model/ppTableModel.py +++ b/magic_pdf/model/ppTableModel.py @@ -52,11 +52,11 @@ def parse_args(self, **kwargs): rec_model_dir = os.path.join(model_dir, REC_MODEL_DIR) rec_char_dict_path = os.path.join(model_dir, REC_CHAR_DICT) device = kwargs.get("device", "cpu") - use_gpu = True if device == "cuda" else False + use_gpu = True if device.startswith("cuda") else False config = { "use_gpu": use_gpu, "table_max_len": kwargs.get("table_max_len", TABLE_MAX_LEN), - "table_algorithm": TABLE_MASTER, + "table_algorithm": "TableMaster", "table_model_dir": table_model_dir, "table_char_dict_path": table_char_dict_path, "det_model_dir": det_model_dir, diff --git a/magic_pdf/pipe/AbsPipe.py b/magic_pdf/pipe/AbsPipe.py index 93cb4b1b..19841374 100644 --- a/magic_pdf/pipe/AbsPipe.py +++ b/magic_pdf/pipe/AbsPipe.py @@ -17,7 +17,7 @@ class AbsPipe(ABC): PIP_TXT = "txt" def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False, - start_page_id=0, end_page_id=None, lang=None): + start_page_id=0, end_page_id=None, lang=None, layout_model=None, formula_enable=None, table_enable=None): self.pdf_bytes = pdf_bytes self.model_list = model_list self.image_writer = image_writer @@ -26,6 +26,9 @@ def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWr self.start_page_id = start_page_id self.end_page_id = end_page_id self.lang = lang + self.layout_model = layout_model + self.formula_enable = formula_enable + self.table_enable = table_enable def get_compress_pdf_mid_data(self): return JsonCompressor.compress_json(self.pdf_mid_data) diff --git a/magic_pdf/pipe/OCRPipe.py b/magic_pdf/pipe/OCRPipe.py index 7a30776b..71002a93 100644 --- a/magic_pdf/pipe/OCRPipe.py +++ b/magic_pdf/pipe/OCRPipe.py @@ -10,8 +10,10 @@ class OCRPipe(AbsPipe): def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False, - start_page_id=0, end_page_id=None, lang=None): - super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang) + start_page_id=0, end_page_id=None, lang=None, + layout_model=None, formula_enable=None, table_enable=None): + super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang, + layout_model, formula_enable, table_enable) def pipe_classify(self): pass @@ -19,12 +21,14 @@ def pipe_classify(self): def pipe_analyze(self): self.model_list = doc_analyze(self.pdf_bytes, ocr=True, start_page_id=self.start_page_id, end_page_id=self.end_page_id, - lang=self.lang) + lang=self.lang, layout_model=self.layout_model, + formula_enable=self.formula_enable, table_enable=self.table_enable) def pipe_parse(self): self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug, start_page_id=self.start_page_id, end_page_id=self.end_page_id, - lang=self.lang) + lang=self.lang, layout_model=self.layout_model, + formula_enable=self.formula_enable, table_enable=self.table_enable) def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF): result = super().pipe_mk_uni_format(img_parent_path, drop_mode) diff --git a/magic_pdf/pipe/TXTPipe.py b/magic_pdf/pipe/TXTPipe.py index 14c4f4e4..f0bc9b7b 100644 --- a/magic_pdf/pipe/TXTPipe.py +++ b/magic_pdf/pipe/TXTPipe.py @@ -11,8 +11,10 @@ class TXTPipe(AbsPipe): def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False, - start_page_id=0, end_page_id=None, lang=None): - super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang) + start_page_id=0, end_page_id=None, lang=None, + layout_model=None, formula_enable=None, table_enable=None): + super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang, + layout_model, formula_enable, table_enable) def pipe_classify(self): pass @@ -20,12 +22,14 @@ def pipe_classify(self): def pipe_analyze(self): self.model_list = doc_analyze(self.pdf_bytes, ocr=False, start_page_id=self.start_page_id, end_page_id=self.end_page_id, - lang=self.lang) + lang=self.lang, layout_model=self.layout_model, + formula_enable=self.formula_enable, table_enable=self.table_enable) def pipe_parse(self): self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug, start_page_id=self.start_page_id, end_page_id=self.end_page_id, - lang=self.lang) + lang=self.lang, layout_model=self.layout_model, + formula_enable=self.formula_enable, table_enable=self.table_enable) def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF): result = super().pipe_mk_uni_format(img_parent_path, drop_mode) diff --git a/magic_pdf/pipe/UNIPipe.py b/magic_pdf/pipe/UNIPipe.py index 226ae48f..a1ae7f90 100644 --- a/magic_pdf/pipe/UNIPipe.py +++ b/magic_pdf/pipe/UNIPipe.py @@ -14,9 +14,11 @@ class UNIPipe(AbsPipe): def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: AbsReaderWriter, is_debug: bool = False, - start_page_id=0, end_page_id=None, lang=None): + start_page_id=0, end_page_id=None, lang=None, + layout_model=None, formula_enable=None, table_enable=None): self.pdf_type = jso_useful_key["_pdf_type"] - super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id, lang) + super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id, + lang, layout_model, formula_enable, table_enable) if len(self.model_list) == 0: self.input_model_is_empty = True else: @@ -29,18 +31,21 @@ def pipe_analyze(self): if self.pdf_type == self.PIP_TXT: self.model_list = doc_analyze(self.pdf_bytes, ocr=False, start_page_id=self.start_page_id, end_page_id=self.end_page_id, - lang=self.lang) + lang=self.lang, layout_model=self.layout_model, + formula_enable=self.formula_enable, table_enable=self.table_enable) elif self.pdf_type == self.PIP_OCR: self.model_list = doc_analyze(self.pdf_bytes, ocr=True, start_page_id=self.start_page_id, end_page_id=self.end_page_id, - lang=self.lang) + lang=self.lang, layout_model=self.layout_model, + formula_enable=self.formula_enable, table_enable=self.table_enable) def pipe_parse(self): if self.pdf_type == self.PIP_TXT: self.pdf_mid_data = parse_union_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug, input_model_is_empty=self.input_model_is_empty, start_page_id=self.start_page_id, end_page_id=self.end_page_id, - lang=self.lang) + lang=self.lang, layout_model=self.layout_model, + formula_enable=self.formula_enable, table_enable=self.table_enable) elif self.pdf_type == self.PIP_OCR: self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug, diff --git a/magic_pdf/pre_proc/ocr_detect_all_bboxes.py b/magic_pdf/pre_proc/ocr_detect_all_bboxes.py index 8725b884..f0ae0924 100644 --- a/magic_pdf/pre_proc/ocr_detect_all_bboxes.py +++ b/magic_pdf/pre_proc/ocr_detect_all_bboxes.py @@ -1,7 +1,7 @@ from loguru import logger from magic_pdf.libs.boxbase import get_minbox_if_overlap_by_ratio, calculate_overlap_area_in_bbox1_area_ratio, \ - calculate_iou + calculate_iou, calculate_vertical_projection_overlap_ratio from magic_pdf.libs.drop_tag import DropTag from magic_pdf.libs.ocr_content_type import BlockType from magic_pdf.pre_proc.remove_bbox_overlap import remove_overlap_between_bbox_for_block @@ -97,12 +97,20 @@ def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_b # 通过后续大框套小框逻辑删除 '''discarded_blocks中只保留宽度超过1/3页面宽度的,高度超过10的,处于页面下半50%区域的(限定footnote)''' + footnote_blocks = [] for discarded in discarded_blocks: x0, y0, x1, y1 = discarded['bbox'] all_discarded_blocks.append([x0, y0, x1, y1, None, None, None, BlockType.Discarded, None, None, None, None, discarded["score"]]) # 将footnote加入到all_bboxes中,用来计算layout - # if (x1 - x0) > (page_w / 3) and (y1 - y0) > 10 and y0 > (page_h / 2): - # all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Footnote, None, None, None, None, discarded["score"]]) + if (x1 - x0) > (page_w / 3) and (y1 - y0) > 10 and y0 > (page_h / 2): + footnote_blocks.append([x0, y0, x1, y1]) + + '''移除在footnote下面的任何框''' + need_remove_blocks = find_blocks_under_footnote(all_bboxes, footnote_blocks) + if len(need_remove_blocks) > 0: + for block in need_remove_blocks: + all_bboxes.remove(block) + all_discarded_blocks.append(block) '''经过以上处理后,还存在大框套小框的情况,则删除小框''' all_bboxes = remove_overlaps_min_blocks(all_bboxes) @@ -113,6 +121,20 @@ def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_b return all_bboxes, all_discarded_blocks +def find_blocks_under_footnote(all_bboxes, footnote_blocks): + need_remove_blocks = [] + for block in all_bboxes: + block_x0, block_y0, block_x1, block_y1 = block[:4] + for footnote_bbox in footnote_blocks: + footnote_x0, footnote_y0, footnote_x1, footnote_y1 = footnote_bbox + # 如果footnote的纵向投影覆盖了block的纵向投影的80%且block的y0大于等于footnote的y1 + if block_y0 >= footnote_y1 and calculate_vertical_projection_overlap_ratio((block_x0, block_y0, block_x1, block_y1), footnote_bbox) >= 0.8: + if block not in need_remove_blocks: + need_remove_blocks.append(block) + break + return need_remove_blocks + + def fix_interline_equation_overlap_text_blocks_with_hi_iou(all_bboxes): # 先提取所有text和interline block text_blocks = [] diff --git a/magic_pdf/resources/model_config/model_configs.yaml b/magic_pdf/resources/model_config/model_configs.yaml index e9f0d588..e56d6ee1 100644 --- a/magic_pdf/resources/model_config/model_configs.yaml +++ b/magic_pdf/resources/model_config/model_configs.yaml @@ -1,15 +1,7 @@ -config: - device: cpu - layout: True - formula: True - table_config: - model: TableMaster - is_table_recog_enable: False - max_time: 400 - weights: - layout: Layout/model_final.pth - mfd: MFD/weights.pt - mfr: MFR/unimernet_small + layoutlmv3: Layout/LayoutLMv3/model_final.pth + doclayout_yolo: Layout/YOLO/doclayout_yolo_ft.pt + yolo_v8_mfd: MFD/YOLO/yolo_v8_ft.pt + unimernet_small: MFR/unimernet_small struct_eqtable: TabRec/StructEqTable - TableMaster: TabRec/TableMaster \ No newline at end of file + tablemaster: TabRec/TableMaster \ No newline at end of file diff --git a/magic_pdf/tools/common.py b/magic_pdf/tools/common.py index bae1224c..ba0a740d 100644 --- a/magic_pdf/tools/common.py +++ b/magic_pdf/tools/common.py @@ -46,10 +46,12 @@ def do_parse( start_page_id=0, end_page_id=None, lang=None, + layout_model=None, + formula_enable=None, + table_enable=None, ): if debug_able: logger.warning('debug mode is on') - # f_dump_content_list = True f_draw_model_bbox = True f_draw_line_sort_bbox = True @@ -64,13 +66,16 @@ def do_parse( if parse_method == 'auto': jso_useful_key = {'_pdf_type': '', 'model_list': model_list} pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True, - start_page_id=start_page_id, end_page_id=end_page_id, lang=lang) + start_page_id=start_page_id, end_page_id=end_page_id, lang=lang, + layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable) elif parse_method == 'txt': pipe = TXTPipe(pdf_bytes, model_list, image_writer, is_debug=True, - start_page_id=start_page_id, end_page_id=end_page_id, lang=lang) + start_page_id=start_page_id, end_page_id=end_page_id, lang=lang, + layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable) elif parse_method == 'ocr': pipe = OCRPipe(pdf_bytes, model_list, image_writer, is_debug=True, - start_page_id=start_page_id, end_page_id=end_page_id, lang=lang) + start_page_id=start_page_id, end_page_id=end_page_id, lang=lang, + layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable) else: logger.error('unknown parse method') exit(1) diff --git a/magic_pdf/user_api.py b/magic_pdf/user_api.py index c602fc33..2a4bd59e 100644 --- a/magic_pdf/user_api.py +++ b/magic_pdf/user_api.py @@ -101,11 +101,19 @@ def parse_pdf(method): if pdf_info_dict is None or pdf_info_dict.get("_need_drop", False): logger.warning(f"parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr") if input_model_is_empty: - pdf_models = doc_analyze(pdf_bytes, - ocr=True, - start_page_id=start_page_id, - end_page_id=end_page_id, - lang=lang) + layout_model = kwargs.get("layout_model", None) + formula_enable = kwargs.get("formula_enable", None) + table_enable = kwargs.get("table_enable", None) + pdf_models = doc_analyze( + pdf_bytes, + ocr=True, + start_page_id=start_page_id, + end_page_id=end_page_id, + lang=lang, + layout_model=layout_model, + formula_enable=formula_enable, + table_enable=table_enable, + ) pdf_info_dict = parse_pdf(parse_pdf_by_ocr) if pdf_info_dict is None: raise Exception("Both parse_pdf_by_txt and parse_pdf_by_ocr failed.") diff --git a/old_docs/download_models.py b/old_docs/download_models.py index 7f116a0c..ed1ee5c3 100644 --- a/old_docs/download_models.py +++ b/old_docs/download_models.py @@ -5,16 +5,21 @@ from modelscope import snapshot_download +def download_json(url): + # 下载JSON文件 + response = requests.get(url) + response.raise_for_status() # 检查请求是否成功 + return response.json() + + def download_and_modify_json(url, local_filename, modifications): if os.path.exists(local_filename): data = json.load(open(local_filename)) + config_version = data.get('config_version', '0.0.0') + if config_version < '1.0.0': + data = download_json(url) else: - # 下载JSON文件 - response = requests.get(url) - response.raise_for_status() # 检查请求是否成功 - - # 解析JSON内容 - data = response.json() + data = download_json(url) # 修改内容 for key, value in modifications.items(): @@ -26,13 +31,21 @@ def download_and_modify_json(url, local_filename, modifications): if __name__ == '__main__': - model_dir = snapshot_download('opendatalab/PDF-Extract-Kit') + mineru_patterns = [ + "models/Layout/LayoutLMv3/*", + "models/Layout/YOLO/*", + "models/MFD/YOLO/*", + "models/MFR/unimernet_small/*", + "models/TabRec/TableMaster/*", + "models/TabRec/StructEqTable/*", + ] + model_dir = snapshot_download('opendatalab/PDF-Extract-Kit-1.0', allow_patterns=mineru_patterns) layoutreader_model_dir = snapshot_download('ppaanngggg/layoutreader') model_dir = model_dir + '/models' print(f'model_dir is: {model_dir}') print(f'layoutreader_model_dir is: {layoutreader_model_dir}') - json_url = 'https://gitee.com/myhloli/MinerU/raw/master/magic-pdf.template.json' + json_url = 'https://gitee.com/myhloli/MinerU/raw/dev/magic-pdf.template.json' config_file_name = 'magic-pdf.json' home_dir = os.path.expanduser('~') config_file = os.path.join(home_dir, config_file_name) diff --git a/old_docs/download_models_hf.py b/old_docs/download_models_hf.py index 915f1a24..5e6b8dce 100644 --- a/old_docs/download_models_hf.py +++ b/old_docs/download_models_hf.py @@ -5,16 +5,21 @@ from huggingface_hub import snapshot_download +def download_json(url): + # 下载JSON文件 + response = requests.get(url) + response.raise_for_status() # 检查请求是否成功 + return response.json() + + def download_and_modify_json(url, local_filename, modifications): if os.path.exists(local_filename): data = json.load(open(local_filename)) + config_version = data.get('config_version', '0.0.0') + if config_version < '1.0.0': + data = download_json(url) else: - # 下载JSON文件 - response = requests.get(url) - response.raise_for_status() # 检查请求是否成功 - - # 解析JSON内容 - data = response.json() + data = download_json(url) # 修改内容 for key, value in modifications.items(): @@ -26,13 +31,28 @@ def download_and_modify_json(url, local_filename, modifications): if __name__ == '__main__': - model_dir = snapshot_download('opendatalab/PDF-Extract-Kit') - layoutreader_model_dir = snapshot_download('hantian/layoutreader') + + mineru_patterns = [ + "models/Layout/LayoutLMv3/*", + "models/Layout/YOLO/*", + "models/MFD/YOLO/*", + "models/MFR/unimernet_small/*", + "models/TabRec/TableMaster/*", + "models/TabRec/StructEqTable/*", + ] + model_dir = snapshot_download('opendatalab/PDF-Extract-Kit-1.0', allow_patterns=mineru_patterns) + + layoutreader_pattern = [ + "*.json", + "*.safetensors", + ] + layoutreader_model_dir = snapshot_download('hantian/layoutreader', allow_patterns=layoutreader_pattern) + model_dir = model_dir + '/models' print(f'model_dir is: {model_dir}') print(f'layoutreader_model_dir is: {layoutreader_model_dir}') - json_url = 'https://github.com/opendatalab/MinerU/raw/master/magic-pdf.template.json' + json_url = 'https://github.com/opendatalab/MinerU/raw/dev/magic-pdf.template.json' config_file_name = 'magic-pdf.json' home_dir = os.path.expanduser('~') config_file = os.path.join(home_dir, config_file_name) diff --git a/projects/gradio_app/app.py b/projects/gradio_app/app.py index aa576ecb..c0914877 100644 --- a/projects/gradio_app/app.py +++ b/projects/gradio_app/app.py @@ -23,7 +23,7 @@ def read_fn(path): return disk_rw.read(os.path.basename(path), AbsReaderWriter.MODE_BIN) -def parse_pdf(doc_path, output_dir, end_page_id, is_ocr): +def parse_pdf(doc_path, output_dir, end_page_id, is_ocr, layout_mode, formula_enable, table_enable, language): os.makedirs(output_dir, exist_ok=True) try: @@ -42,6 +42,10 @@ def parse_pdf(doc_path, output_dir, end_page_id, is_ocr): parse_method, False, end_page_id=end_page_id, + layout_model=layout_mode, + formula_enable=formula_enable, + table_enable=table_enable, + lang=language, ) return local_md_dir, file_name except Exception as e: @@ -93,9 +97,10 @@ def replace(match): return re.sub(pattern, replace, markdown_text) -def to_markdown(file_path, end_pages, is_ocr): +def to_markdown(file_path, end_pages, is_ocr, layout_mode, formula_enable, table_enable, language): # 获取识别的md文件以及压缩包文件路径 - local_md_dir, file_name = parse_pdf(file_path, './output', end_pages - 1, is_ocr) + local_md_dir, file_name = parse_pdf(file_path, './output', end_pages - 1, is_ocr, + layout_mode, formula_enable, table_enable, language) archive_zip_path = os.path.join("./output", compute_sha256(local_md_dir) + ".zip") zip_archive_success = compress_directory_to_zip(local_md_dir, archive_zip_path) if zip_archive_success == 0: @@ -138,6 +143,27 @@ def init_model(): header = file.read() +latin_lang = [ + 'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr', + 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'mt', 'nl', + 'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk', 'sl', 'sq', 'sv', + 'sw', 'tl', 'tr', 'uz', 'vi', 'french', 'german' +] +arabic_lang = ['ar', 'fa', 'ug', 'ur'] +cyrillic_lang = [ + 'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd', 'ava', + 'dar', 'inh', 'che', 'lbe', 'lez', 'tab' +] +devanagari_lang = [ + 'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom', + 'sa', 'bgc' +] +other_lang = ['ch', 'en', 'korean', 'japan', 'chinese_cht', 'ta', 'te', 'ka'] + +all_lang = [""] +all_lang.extend([*other_lang, *latin_lang, *arabic_lang, *cyrillic_lang, *devanagari_lang]) + + if __name__ == "__main__": with gr.Blocks() as demo: gr.HTML(header) @@ -145,8 +171,14 @@ def init_model(): with gr.Column(variant='panel', scale=5): pdf_show = gr.Markdown() max_pages = gr.Slider(1, 10, 5, step=1, label="Max convert pages") - with gr.Row() as bu_flow: - is_ocr = gr.Checkbox(label="Force enable OCR") + with gr.Row(): + layout_mode = gr.Dropdown(["layoutlmv3", "doclayout_yolo"], label="Layout model", value="layoutlmv3") + language = gr.Dropdown(all_lang, label="Language", value="") + with gr.Row(): + formula_enable = gr.Checkbox(label="Enable formula recognition", value=True) + is_ocr = gr.Checkbox(label="Force enable OCR", value=False) + table_enable = gr.Checkbox(label="Enable table recognition(test)", value=False) + with gr.Row(): change_bu = gr.Button("Convert") clear_bu = gr.ClearButton([pdf_show], value="Clear") pdf_show = PDF(label="Please upload pdf", interactive=True, height=800) @@ -166,7 +198,8 @@ def init_model(): latex_delimiters=latex_delimiters, line_breaks=True) with gr.Tab("Markdown text"): md_text = gr.TextArea(lines=45, show_copy_button=True) - change_bu.click(fn=to_markdown, inputs=[pdf_show, max_pages, is_ocr], outputs=[md, md_text, output_file, pdf_show]) + change_bu.click(fn=to_markdown, inputs=[pdf_show, max_pages, is_ocr, layout_mode, formula_enable, table_enable, language], + outputs=[md, md_text, output_file, pdf_show]) clear_bu.add([md, pdf_show, md_text, output_file, is_ocr]) - demo.launch() \ No newline at end of file + demo.launch(server_name="0.0.0.0") \ No newline at end of file