Skip to content

Commit

Permalink
refactor(model): improve timing information and performance
Browse files Browse the repository at this point in the history
- Enhance timing output precision to two decimal places for better readability
- Calculate and log document analysis speed in pages per second
- Optimize logging for YOLO and table recognition processes
- Remove unnecessary comments and improve code efficiency
  • Loading branch information
myhloli committed Oct 6, 2024
1 parent 14bb586 commit be1b1ae
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
7 changes: 5 additions & 2 deletions magic_pdf/model/doc_analyze_by_custom_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,10 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
page_info = {"page_no": index, "height": page_height, "width": page_width}
page_dict = {"layout_dets": result, "page_info": page_info}
model_json.append(page_dict)
doc_analyze_cost = time.time() - doc_analyze_start
logger.info(f"doc analyze cost: {doc_analyze_cost}")

doc_analyze_time = round(time.time() - doc_analyze_start, 2)
doc_analyze_speed = round( (end_page_id + 1 - start_page_id) / doc_analyze_time, 2)
logger.info(f"doc analyze time: {round(time.time() - doc_analyze_start, 2)},"
f" speed: {doc_analyze_speed} pages/second")

return model_json
22 changes: 14 additions & 8 deletions magic_pdf/model/pdf_extract_kit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from magic_pdf.model.model_list import AtomicModel

os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
try:
import cv2
import yaml
Expand Down Expand Up @@ -274,20 +275,24 @@ def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):

def __call__(self, image):

page_start = time.time()

latex_filling_list = []
mf_image_list = []

# layout检测
layout_start = time.time()
layout_res = self.layout_model(image, ignore_catids=[])
layout_cost = round(time.time() - layout_start, 2)
logger.info(f"layout detection cost: {layout_cost}")
logger.info(f"layout detection time: {layout_cost}")

pil_img = Image.fromarray(image)

if self.apply_formula:
# 公式检测
mfd_start = time.time()
mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}")
for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
new_item = {
Expand Down Expand Up @@ -381,15 +386,15 @@ def __call__(self, image):
})

ocr_cost = round(time.time() - ocr_start, 2)
logger.info(f"ocr cost: {ocr_cost}")
logger.info(f"ocr time: {ocr_cost}")

# 表格识别 table recognition
if self.apply_table:
table_start = time.time()
for res in table_res_list:
new_image, _ = crop_img(res, pil_img)
single_table_start_time = time.time()
logger.info("------------------table recognition processing begins-----------------")
# logger.info("------------------table recognition processing begins-----------------")
latex_code = None
html_code = None
if self.table_model_type == STRUCT_EQTABLE:
Expand All @@ -399,7 +404,7 @@ def __call__(self, image):
html_code = self.table_model.img2html(new_image)

run_time = time.time() - single_table_start_time
logger.info(f"------------table recognition processing ends within {run_time}s-----")
# logger.info(f"------------table recognition processing ends within {run_time}s-----")
if run_time > self.table_max_time:
logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
# 判断是否返回正常
Expand All @@ -410,12 +415,13 @@ def __call__(self, image):
if expected_ending:
res["latex"] = latex_code
else:
logger.warning(f"------------table recognition processing fails----------")
logger.warning(f"table recognition processing fails, not found expected LaTeX table end")
elif html_code:
res["html"] = html_code
else:
logger.warning(f"------------table recognition processing fails----------")
table_cost = round(time.time() - table_start, 2)
logger.info(f"table cost: {table_cost}")
logger.warning(f"table recognition processing fails, not get latex or html return")
logger.info(f"table time: {round(time.time() - table_start, 2)}")

logger.info(f"-----page total time: {round(time.time() - page_start, 2)}-----")

return layout_res

0 comments on commit be1b1ae

Please sign in to comment.