Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(model): improve timing information and performance #690

Merged
merged 1 commit into from
Oct 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions magic_pdf/model/doc_analyze_by_custom_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,10 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
page_info = {"page_no": index, "height": page_height, "width": page_width}
page_dict = {"layout_dets": result, "page_info": page_info}
model_json.append(page_dict)
doc_analyze_cost = time.time() - doc_analyze_start
logger.info(f"doc analyze cost: {doc_analyze_cost}")

doc_analyze_time = round(time.time() - doc_analyze_start, 2)
doc_analyze_speed = round( (end_page_id + 1 - start_page_id) / doc_analyze_time, 2)
logger.info(f"doc analyze time: {round(time.time() - doc_analyze_start, 2)},"
f" speed: {doc_analyze_speed} pages/second")

return model_json
22 changes: 14 additions & 8 deletions magic_pdf/model/pdf_extract_kit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from magic_pdf.model.model_list import AtomicModel

os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
try:
import cv2
import yaml
Expand Down Expand Up @@ -274,20 +275,24 @@ def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):

def __call__(self, image):

page_start = time.time()

latex_filling_list = []
mf_image_list = []

# layout检测
layout_start = time.time()
layout_res = self.layout_model(image, ignore_catids=[])
layout_cost = round(time.time() - layout_start, 2)
logger.info(f"layout detection cost: {layout_cost}")
logger.info(f"layout detection time: {layout_cost}")

pil_img = Image.fromarray(image)

if self.apply_formula:
# 公式检测
mfd_start = time.time()
mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}")
for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
new_item = {
Expand Down Expand Up @@ -381,15 +386,15 @@ def __call__(self, image):
})

ocr_cost = round(time.time() - ocr_start, 2)
logger.info(f"ocr cost: {ocr_cost}")
logger.info(f"ocr time: {ocr_cost}")

# 表格识别 table recognition
if self.apply_table:
table_start = time.time()
for res in table_res_list:
new_image, _ = crop_img(res, pil_img)
single_table_start_time = time.time()
logger.info("------------------table recognition processing begins-----------------")
# logger.info("------------------table recognition processing begins-----------------")
latex_code = None
html_code = None
if self.table_model_type == STRUCT_EQTABLE:
Expand All @@ -399,7 +404,7 @@ def __call__(self, image):
html_code = self.table_model.img2html(new_image)

run_time = time.time() - single_table_start_time
logger.info(f"------------table recognition processing ends within {run_time}s-----")
# logger.info(f"------------table recognition processing ends within {run_time}s-----")
if run_time > self.table_max_time:
logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
# 判断是否返回正常
Expand All @@ -410,12 +415,13 @@ def __call__(self, image):
if expected_ending:
res["latex"] = latex_code
else:
logger.warning(f"------------table recognition processing fails----------")
logger.warning(f"table recognition processing fails, not found expected LaTeX table end")
elif html_code:
res["html"] = html_code
else:
logger.warning(f"------------table recognition processing fails----------")
table_cost = round(time.time() - table_start, 2)
logger.info(f"table cost: {table_cost}")
logger.warning(f"table recognition processing fails, not get latex or html return")
logger.info(f"table time: {round(time.time() - table_start, 2)}")

logger.info(f"-----page total time: {round(time.time() - page_start, 2)}-----")

return layout_res
Loading