Skip to content

Commit

Permalink
Merge pull request #773 from myhloli/add-doclayout-yolo
Browse files Browse the repository at this point in the history
feat(model): add support for DocLayout-YOLO model
  • Loading branch information
myhloli authored Oct 23, 2024
2 parents efb5851 + 1279f2c commit c1ba9dc
Show file tree
Hide file tree
Showing 18 changed files with 365 additions and 130 deletions.
2 changes: 1 addition & 1 deletion magic-pdf.template.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"layoutreader-model-dir":"/tmp/layoutreader",
"device-mode":"cpu",
"layout-config": {
"model": "doclayout_yolo"
"model": "layoutlmv3"
},
"formula-config": {
"mfd_model": "yolo_v8_mfd",
Expand Down
19 changes: 13 additions & 6 deletions magic_pdf/libs/Constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,12 @@
# Whether the lines in the block have been deleted
LINES_DELETED = "lines_deleted"

# struct eqtable
STRUCT_EQTABLE = "struct_eqtable"

# table recognition max time default value
TABLE_MAX_TIME_VALUE = 400

# pp_table_result_max_length
TABLE_MAX_LEN = 480

# pp table structure algorithm
TABLE_MASTER = "TableMaster"

# table master structure dict
TABLE_MASTER_DICT = "table_master_structure_dict.txt"

Expand All @@ -38,3 +32,16 @@
REC_CHAR_DICT = "ppocr_keys_v1.txt"


class MODEL_NAME:
    """Namespace of the string identifiers used to name models in the config.

    Each attribute is a plain ``str``; the values are the names written in
    the ``model`` / ``mfd_model`` / ``mfr_model`` entries of the config file.
    """

    # Values for the "model" entry of the layout config.
    DocLayout_YOLO = "doclayout_yolo"
    LAYOUTLMv3 = "layoutlmv3"

    # Values for the "mfd_model" and "mfr_model" entries of the formula config.
    YOLO_V8_MFD = "yolo_v8_mfd"
    UniMerNet_v2_Small = "unimernet_small"

    # pp table structure algorithm (default "model" of the table config)
    TABLE_MASTER = "tablemaster"
    # struct eqtable
    STRUCT_EQTABLE = "struct_eqtable"
35 changes: 35 additions & 0 deletions magic_pdf/libs/boxbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,3 +445,38 @@ def get_overlap_area(bbox1, bbox2):

# The area of overlap area
return (x_right - x_left) * (y_bottom - y_top)


def calculate_vertical_projection_overlap_ratio(block1, block2):
    """
    Return the fraction of block1's width that block2 overlaps on the x-axis.

    Both blocks are projected onto the x-axis; the length of the shared
    interval is divided by the width of ``block1`` (not by the union or by
    ``block2``), so the result is not symmetric in its arguments.

    Args:
        block1 (tuple): Coordinates of the first block (x0, y0, x1, y1).
        block2 (tuple): Coordinates of the second block (x0, y0, x1, y1).

    Returns:
        float: ``overlap_length / block1_width``, or ``0.0`` when the two
        projections do not intersect or ``block1`` has zero width.
    """
    first_left, _, first_right, _ = block1
    second_left, _, second_right, _ = block2

    # Shared interval of the two x-projections.
    overlap_start = max(first_left, second_left)
    overlap_end = min(first_right, second_right)

    # Disjoint projections: nothing is covered.
    if overlap_end < overlap_start:
        return 0.0

    first_width = first_right - first_left
    # A zero-width block1 would divide by zero; report no coverage instead.
    return (overlap_end - overlap_start) / first_width if first_width != 0 else 0.0
23 changes: 22 additions & 1 deletion magic_pdf/libs/config_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from loguru import logger

from magic_pdf.libs.Constants import MODEL_NAME
from magic_pdf.libs.commons import parse_bucket_key

# Define the config file name constant
Expand Down Expand Up @@ -94,10 +95,30 @@ def get_table_recog_config():
table_config = config.get("table-config")
if table_config is None:
logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default")
return json.loads('{"is_table_recog_enable": false, "max_time": 400}')
return json.loads(f'{{"model": "{MODEL_NAME.TABLE_MASTER}","enable": false, "max_time": 400}}')
else:
return table_config


def get_layout_config():
    """Return the ``layout-config`` section of the config file.

    Returns:
        dict: The section as returned by ``read_config()``. When the section
        is missing, a warning is logged and a newly built default
        ``{"model": MODEL_NAME.LAYOUTLMv3}`` is returned instead.
    """
    layout_config = read_config().get("layout-config")
    if layout_config is None:
        logger.warning(f"'layout-config' not found in {CONFIG_FILE_NAME}, use '{MODEL_NAME.LAYOUTLMv3}' as default")
        # Build the default as a dict literal rather than interpolating the
        # constant into a JSON string and parsing it back: same result, but
        # it cannot break if the model name ever contains a quote/backslash.
        # A new dict is created on every call, so callers may mutate it.
        return {"model": MODEL_NAME.LAYOUTLMv3}
    return layout_config


def get_formula_config():
    """Return the ``formula-config`` section of the config file.

    Returns:
        dict: The section as returned by ``read_config()``. When the section
        is missing, a warning is logged and a newly built default is
        returned with ``mfd_model`` set to ``MODEL_NAME.YOLO_V8_MFD``,
        ``mfr_model`` set to ``MODEL_NAME.UniMerNet_v2_Small`` and
        ``enable`` set to ``True``.
    """
    formula_config = read_config().get("formula-config")
    if formula_config is None:
        logger.warning(f"'formula-config' not found in {CONFIG_FILE_NAME}, use 'True' as default")
        # Dict literal instead of an f-string JSON document fed to
        # json.loads: identical keys, order and values (JSON true -> True)
        # without the fragile string escaping. A new dict is created on
        # every call, so callers may mutate it.
        return {
            "mfd_model": MODEL_NAME.YOLO_V8_MFD,
            "mfr_model": MODEL_NAME.UniMerNet_v2_Small,
            "enable": True,
        }
    return formula_config


if __name__ == "__main__":
ak, sk, endpoint = get_s3_config("llm-raw")
52 changes: 38 additions & 14 deletions magic_pdf/model/doc_analyze_by_custom_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from loguru import logger

from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config
from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config, get_layout_config, \
get_formula_config
from magic_pdf.model.model_list import MODEL
import magic_pdf.model as model_config

Expand Down Expand Up @@ -68,14 +69,17 @@ def __new__(cls, *args, **kwargs):
cls._instance = super().__new__(cls)
return cls._instance

def get_model(self, ocr: bool, show_log: bool, lang=None):
key = (ocr, show_log, lang)
def get_model(self, ocr: bool, show_log: bool, lang=None, layout_model=None, formula_enable=None, table_enable=None):
key = (ocr, show_log, lang, layout_model, formula_enable, table_enable)
if key not in self._models:
self._models[key] = custom_model_init(ocr=ocr, show_log=show_log, lang=lang)
self._models[key] = custom_model_init(ocr=ocr, show_log=show_log, lang=lang, layout_model=layout_model,
formula_enable=formula_enable, table_enable=table_enable)
return self._models[key]


def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None,
layout_model=None, formula_enable=None, table_enable=None):

model = None

if model_config.__model_mode__ == "lite":
Expand All @@ -95,14 +99,30 @@ def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
# Read model-dir and device from the config file
local_models_dir = get_local_models_dir()
device = get_device()

layout_config = get_layout_config()
if layout_model is not None:
layout_config["model"] = layout_model

formula_config = get_formula_config()
if formula_enable is not None:
formula_config["enable"] = formula_enable

table_config = get_table_recog_config()
model_input = {"ocr": ocr,
"show_log": show_log,
"models_dir": local_models_dir,
"device": device,
"table_config": table_config,
"lang": lang,
}
if table_enable is not None:
table_config["enable"] = table_enable

model_input = {
"ocr": ocr,
"show_log": show_log,
"models_dir": local_models_dir,
"device": device,
"table_config": table_config,
"layout_config": layout_config,
"formula_config": formula_config,
"lang": lang,
}

custom_model = CustomPEKModel(**model_input)
else:
logger.error("Not allow model_name!")
Expand All @@ -117,10 +137,14 @@ def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):


def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
start_page_id=0, end_page_id=None, lang=None):
start_page_id=0, end_page_id=None, lang=None,
layout_model=None, formula_enable=None, table_enable=None):

if lang == "":
lang = None

model_manager = ModelSingleton()
custom_model = model_manager.get_model(ocr, show_log, lang)
custom_model = model_manager.get_model(ocr, show_log, lang, layout_model, formula_enable, table_enable)

with fitz.open("pdf", pdf_bytes) as doc:
pdf_page_num = doc.page_count
Expand Down
Loading

0 comments on commit c1ba9dc

Please sign in to comment.