Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(model): add support for DocLayout-YOLO model #773

Merged
merged 1 commit into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion magic-pdf.template.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"layoutreader-model-dir":"/tmp/layoutreader",
"device-mode":"cpu",
"layout-config": {
"model": "doclayout_yolo"
"model": "layoutlmv3"
},
"formula-config": {
"mfd_model": "yolo_v8_mfd",
Expand Down
19 changes: 13 additions & 6 deletions magic_pdf/libs/Constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,12 @@
# block中lines是否被删除
LINES_DELETED = "lines_deleted"

# struct eqtable
STRUCT_EQTABLE = "struct_eqtable"

# table recognition max time default value
TABLE_MAX_TIME_VALUE = 400

# pp_table_result_max_length
TABLE_MAX_LEN = 480

# pp table structure algorithm
TABLE_MASTER = "TableMaster"

# table master structure dict
TABLE_MASTER_DICT = "table_master_structure_dict.txt"

Expand All @@ -38,3 +32,16 @@
REC_CHAR_DICT = "ppocr_keys_v1.txt"


class MODEL_NAME:
    """Namespace of the string identifiers used to select models in the config."""

    # Values accepted for "layout-config" -> "model"
    LAYOUTLMv3 = 'layoutlmv3'
    DocLayout_YOLO = 'doclayout_yolo'

    # Defaults for "formula-config" -> "mfd_model" / "mfr_model"
    YOLO_V8_MFD = 'yolo_v8_mfd'
    UniMerNet_v2_Small = 'unimernet_small'

    # pp table structure algorithm (default "table-config" -> "model")
    TABLE_MASTER = 'tablemaster'

    # struct eqtable
    STRUCT_EQTABLE = 'struct_eqtable'
35 changes: 35 additions & 0 deletions magic_pdf/libs/boxbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,3 +445,38 @@ def get_overlap_area(bbox1, bbox2):

# The area of overlap area
return (x_right - x_left) * (y_bottom - y_top)


def calculate_vertical_projection_overlap_ratio(block1, block2):
    """
    Return the fraction of block1's width that is overlapped by block2 along the x-axis.

    Both blocks are projected onto the x-axis (their y-coordinates are ignored).
    The result is measured relative to block1 only, so swapping the arguments
    can change the value.

    Args:
        block1 (tuple): Coordinates of the first block (x0, y0, x1, y1); its width is the denominator.
        block2 (tuple): Coordinates of the second block (x0, y0, x1, y1).

    Returns:
        float: Overlap length divided by block1's width, or 0.0 when the
            projections do not overlap or block1 has zero width.
    """
    left1, _, right1, _ = block1
    left2, _, right2, _ = block2

    # The shared x-interval runs from the larger left edge to the smaller right edge.
    overlap_start = max(left1, left2)
    overlap_end = min(right1, right2)
    if overlap_end < overlap_start:
        return 0.0

    # Guard against dividing by a zero-width block1.
    width1 = right1 - left1
    if width1 == 0:
        return 0.0

    return (overlap_end - overlap_start) / width1
23 changes: 22 additions & 1 deletion magic_pdf/libs/config_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from loguru import logger

from magic_pdf.libs.Constants import MODEL_NAME
from magic_pdf.libs.commons import parse_bucket_key

# 定义配置文件名常量
Expand Down Expand Up @@ -94,10 +95,30 @@ def get_table_recog_config():
table_config = config.get("table-config")
if table_config is None:
logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default")
return json.loads('{"is_table_recog_enable": false, "max_time": 400}')
return json.loads(f'{{"model": "{MODEL_NAME.TABLE_MASTER}","enable": false, "max_time": 400}}')
else:
return table_config


def get_layout_config():
    """Return the "layout-config" section of the configuration file.

    Returns:
        dict: The configured section when present. Otherwise a freshly built
            default ``{"model": MODEL_NAME.LAYOUTLMv3}``, after logging a warning.
    """
    config = read_config()
    layout_config = config.get("layout-config")
    if layout_config is None:
        logger.warning(f"'layout-config' not found in {CONFIG_FILE_NAME}, use '{MODEL_NAME.LAYOUTLMv3}' as default")
        # Build the default as a dict literal rather than formatting a JSON
        # string and parsing it back; a new dict is created on every call.
        return {"model": MODEL_NAME.LAYOUTLMv3}
    return layout_config


def get_formula_config():
    """Return the "formula-config" section of the configuration file.

    Returns:
        dict: The configured section when present. Otherwise a freshly built
            default with keys ``"mfd_model"``, ``"mfr_model"`` and
            ``"enable"`` (True), after logging a warning.
    """
    config = read_config()
    formula_config = config.get("formula-config")
    if formula_config is None:
        logger.warning(f"'formula-config' not found in {CONFIG_FILE_NAME}, use 'True' as default")
        # Build the default as a dict literal rather than formatting a JSON
        # string and parsing it back; a new dict is created on every call.
        return {
            "mfd_model": MODEL_NAME.YOLO_V8_MFD,
            "mfr_model": MODEL_NAME.UniMerNet_v2_Small,
            "enable": True,
        }
    return formula_config


if __name__ == "__main__":
ak, sk, endpoint = get_s3_config("llm-raw")
52 changes: 38 additions & 14 deletions magic_pdf/model/doc_analyze_by_custom_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from loguru import logger

from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config
from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config, get_layout_config, \
get_formula_config
from magic_pdf.model.model_list import MODEL
import magic_pdf.model as model_config

Expand Down Expand Up @@ -68,14 +69,17 @@ def __new__(cls, *args, **kwargs):
cls._instance = super().__new__(cls)
return cls._instance

def get_model(self, ocr: bool, show_log: bool, lang=None):
key = (ocr, show_log, lang)
def get_model(self, ocr: bool, show_log: bool, lang=None, layout_model=None, formula_enable=None, table_enable=None):
key = (ocr, show_log, lang, layout_model, formula_enable, table_enable)
if key not in self._models:
self._models[key] = custom_model_init(ocr=ocr, show_log=show_log, lang=lang)
self._models[key] = custom_model_init(ocr=ocr, show_log=show_log, lang=lang, layout_model=layout_model,
formula_enable=formula_enable, table_enable=table_enable)
return self._models[key]


def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None,
layout_model=None, formula_enable=None, table_enable=None):

model = None

if model_config.__model_mode__ == "lite":
Expand All @@ -95,14 +99,30 @@ def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
# 从配置文件读取model-dir和device
local_models_dir = get_local_models_dir()
device = get_device()

layout_config = get_layout_config()
if layout_model is not None:
layout_config["model"] = layout_model

formula_config = get_formula_config()
if formula_enable is not None:
formula_config["enable"] = formula_enable

table_config = get_table_recog_config()
model_input = {"ocr": ocr,
"show_log": show_log,
"models_dir": local_models_dir,
"device": device,
"table_config": table_config,
"lang": lang,
}
if table_enable is not None:
table_config["enable"] = table_enable

model_input = {
"ocr": ocr,
"show_log": show_log,
"models_dir": local_models_dir,
"device": device,
"table_config": table_config,
"layout_config": layout_config,
"formula_config": formula_config,
"lang": lang,
}

custom_model = CustomPEKModel(**model_input)
else:
logger.error("Not allow model_name!")
Expand All @@ -117,10 +137,14 @@ def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):


def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
start_page_id=0, end_page_id=None, lang=None):
start_page_id=0, end_page_id=None, lang=None,
layout_model=None, formula_enable=None, table_enable=None):

if lang == "":
lang = None

model_manager = ModelSingleton()
custom_model = model_manager.get_model(ocr, show_log, lang)
custom_model = model_manager.get_model(ocr, show_log, lang, layout_model, formula_enable, table_enable)

with fitz.open("pdf", pdf_bytes) as doc:
pdf_page_num = doc.page_count
Expand Down
Loading
Loading