Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat:add layoutreader to sort blocks #672

Merged
merged 20 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
3cbcf2d
feat(draw_bbox): add layout sorting visualization
myhloli Sep 25, 2024
270ffb0
feat(draw_bbox): add layout sorting visualization
myhloli Sep 26, 2024
00cda7a
refactor(draw_bbox): clear cuda cache and update bbox sorting
myhloli Sep 26, 2024
36220d6
refactor(draw_bbox): clear cuda cache and update bbox sorting
myhloli Sep 27, 2024
1efebe4
refactor(pdf_parse_union): integrate LayoutLMv3 for block orderingRep…
myhloli Sep 27, 2024
34f8965
refactor(draw_bbox): add line sorting visualization
myhloli Sep 27, 2024
c56de49
refactor(draw_bbox): remove conditional layout bbox drawing
myhloli Sep 27, 2024
43a57d5
feat(draw_bbox): add option to toggle bounding box drawing
myhloli Sep 27, 2024
16b51c7
Merge remote-tracking branch 'origin/add-layoutreader' into add-layou…
myhloli Sep 27, 2024
b2790f6
refactor(drawing): simplify draw bbox functions and adjust debug config
myhloli Sep 27, 2024
b9dfdea
refactor(pdf_parse_union_core_v2): implement model initialization wit…
myhloli Sep 27, 2024
6561545
feat(requirements): add torch and transformers libraries
myhloli Sep 27, 2024
83c0738
refactor(draw_bbox): remove commented-out code and streamline bbox dr…
myhloli Sep 27, 2024
177ab08
refactor(pdf_parse): remove redundant sorting and optimize block inde…
myhloli Sep 27, 2024
2145a8b
fix(pdf_parse): handle blocks without lines and enable bf16 on compat…
myhloli Sep 28, 2024
5522d0a
refactor(pdf_parse_union_core_v2): update import paths to use new pac…
myhloli Sep 28, 2024
42a7d79
refactor(magic_pdf): import model helpers directly for clarity
myhloli Sep 28, 2024
4c9bf8a
refactor(memory management): remove unused clean_memory function
myhloli Sep 29, 2024
564c4ce
refactor(magic_pdf): improve line sorting and block indexing
myhloli Sep 29, 2024
fcf2424
chore: remove useless files
myhloli Sep 30, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions magic_pdf/libs/clean_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) Opendatalab. All rights reserved.
import torch
import gc


def clean_memory():
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
gc.collect()
124 changes: 88 additions & 36 deletions magic_pdf/libs/draw_bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def draw_bbox_without_number(i, bbox_list, page, rgb_config, fill_config):
) # Draw the rectangle


def draw_bbox_with_number(i, bbox_list, page, rgb_config, fill_config):
def draw_bbox_with_number(i, bbox_list, page, rgb_config, fill_config, draw_bbox=True):
new_rgb = []
for item in rgb_config:
item = float(item) / 255
Expand All @@ -42,31 +42,31 @@ def draw_bbox_with_number(i, bbox_list, page, rgb_config, fill_config):
for j, bbox in enumerate(page_data):
x0, y0, x1, y1 = bbox
rect_coords = fitz.Rect(x0, y0, x1, y1) # Define the rectangle
if fill_config:
page.draw_rect(
rect_coords,
color=None,
fill=new_rgb,
fill_opacity=0.3,
width=0.5,
overlay=True,
) # Draw the rectangle
else:
page.draw_rect(
rect_coords,
color=new_rgb,
fill=None,
fill_opacity=1,
width=0.5,
overlay=True,
) # Draw the rectangle
if draw_bbox:
if fill_config:
page.draw_rect(
rect_coords,
color=None,
fill=new_rgb,
fill_opacity=0.3,
width=0.5,
overlay=True,
) # Draw the rectangle
else:
page.draw_rect(
rect_coords,
color=new_rgb,
fill=None,
fill_opacity=1,
width=0.5,
overlay=True,
) # Draw the rectangle
page.insert_text(
(x0, y0 + 10), str(j + 1), fontsize=10, color=new_rgb
(x1+2, y0 + 10), str(j + 1), fontsize=10, color=new_rgb
) # Insert the index in the top left corner of the rectangle


def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
layout_bbox_list = []
dropped_bbox_list = []
tables_list, tables_body_list = [], []
tables_caption_list, tables_footnote_list = [], []
Expand All @@ -76,16 +76,14 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
texts_list = []
interequations_list = []
for page in pdf_info:
page_layout_list = []

page_dropped_list = []
tables, tables_body, tables_caption, tables_footnote = [], [], [], []
imgs, imgs_body, imgs_caption, imgs_footnote = [], [], [], []
titles = []
texts = []
interequations = []
for layout in page['layout_bboxes']:
page_layout_list.append(layout['layout_bbox'])
layout_bbox_list.append(page_layout_list)

for dropped_bbox in page['discarded_blocks']:
page_dropped_list.append(dropped_bbox['bbox'])
dropped_bbox_list.append(page_dropped_list)
Expand Down Expand Up @@ -129,9 +127,19 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
texts_list.append(texts)
interequations_list.append(interequations)

layout_bbox_list = []

for page in pdf_info:
page_block_list = []
for block in page['para_blocks']:
bbox = block['bbox']
page_block_list.append(bbox)
layout_bbox_list.append(page_block_list)

pdf_docs = fitz.open('pdf', pdf_bytes)

for i, page in enumerate(pdf_docs):
draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False)

draw_bbox_without_number(i, dropped_bbox_list, page, [158, 158, 158],
True)
draw_bbox_without_number(i, tables_list, page, [153, 153, 0],
Expand All @@ -146,13 +154,15 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
draw_bbox_without_number(i, imgs_body_list, page, [153, 255, 51], True)
draw_bbox_without_number(i, imgs_caption_list, page, [102, 178, 255],
True)
draw_bbox_with_number(i, imgs_footnote_list, page, [255, 178, 102],
draw_bbox_without_number(i, imgs_footnote_list, page, [255, 178, 102],
True),
draw_bbox_without_number(i, titles_list, page, [102, 102, 255], True)
draw_bbox_without_number(i, texts_list, page, [153, 0, 76], True)
draw_bbox_without_number(i, interequations_list, page, [0, 255, 0],
True)

draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False, draw_bbox=False)

# Save the PDF
pdf_docs.save(f'{out_path}/{filename}_layout.pdf')

Expand Down Expand Up @@ -211,9 +221,9 @@ def get_span_info(span):
# 构造其余useful_list
for block in page['para_blocks']:
if block['type'] in [
BlockType.Text,
BlockType.Title,
BlockType.InterlineEquation,
BlockType.Text,
BlockType.Title,
BlockType.InterlineEquation,
]:
for line in block['lines']:
for span in line['spans']:
Expand All @@ -232,10 +242,8 @@ def get_span_info(span):
for i, page in enumerate(pdf_docs):
# 获取当前页面的数据
draw_bbox_without_number(i, text_list, page, [255, 0, 0], False)
draw_bbox_without_number(i, inline_equation_list, page, [0, 255, 0],
False)
draw_bbox_without_number(i, interline_equation_list, page, [0, 0, 255],
False)
draw_bbox_without_number(i, inline_equation_list, page, [0, 255, 0], False)
draw_bbox_without_number(i, interline_equation_list, page, [0, 0, 255], False)
draw_bbox_without_number(i, image_list, page, [255, 204, 0], False)
draw_bbox_without_number(i, table_list, page, [204, 0, 255], False)
draw_bbox_without_number(i, dropped_list, page, [158, 158, 158], False)
Expand All @@ -244,7 +252,7 @@ def get_span_info(span):
pdf_docs.save(f'{out_path}/{filename}_spans.pdf')


def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
dropped_bbox_list = []
tables_body_list, tables_caption_list, tables_footnote_list = [], [], []
imgs_body_list, imgs_caption_list, imgs_footnote_list = [], [], []
Expand Down Expand Up @@ -279,7 +287,7 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
elif layout_det['category_id'] == CategoryId.ImageCaption:
imgs_caption.append(bbox)
elif layout_det[
'category_id'] == CategoryId.InterlineEquation_YOLO:
'category_id'] == CategoryId.InterlineEquation_YOLO:
interequations.append(bbox)
elif layout_det['category_id'] == CategoryId.Abandon:
page_dropped_list.append(bbox)
Expand Down Expand Up @@ -316,3 +324,47 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):

# Save the PDF
pdf_docs.save(f'{out_path}/{filename}_model.pdf')


def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
layout_bbox_list = []

for page in pdf_info:
page_line_list = []
for block in page['preproc_blocks']:
if block['type'] in ['text', 'title', 'interline_equation']:
for line in block['lines']:
bbox = line['bbox']
index = line['index']
page_line_list.append({'index': index, 'bbox': bbox})
if block['type'] in ['table', 'image']:
bbox = block['bbox']
index = block['index']
page_line_list.append({'index': index, 'bbox': bbox})
# for line in block['lines']:
# bbox = line['bbox']
# index = line['index']
# page_line_list.append({'index': index, 'bbox': bbox})
sorted_bboxes = sorted(page_line_list, key=lambda x: x['index'])
layout_bbox_list.append(sorted_bbox['bbox'] for sorted_bbox in sorted_bboxes)
pdf_docs = fitz.open('pdf', pdf_bytes)
for i, page in enumerate(pdf_docs):
draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False)

pdf_docs.save(f'{out_path}/{filename}_line_sort.pdf')


def draw_layout_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
layout_bbox_list = []

for page in pdf_info:
page_block_list = []
for block in page['para_blocks']:
bbox = block['bbox']
page_block_list.append(bbox)
layout_bbox_list.append(page_block_list)
pdf_docs = fitz.open('pdf', pdf_bytes)
for i, page in enumerate(pdf_docs):
draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False)

pdf_docs.save(f'{out_path}/{filename}_layout_sort.pdf')
3 changes: 3 additions & 0 deletions magic_pdf/model/pdf_extract_kit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import time

from magic_pdf.libs.Constants import *
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.model.model_list import AtomicModel

os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
Expand Down Expand Up @@ -330,6 +331,8 @@ def __call__(self, image):
elif int(res['category_id']) in [5]:
table_res_list.append(res)

clean_memory()

# ocr识别
if self.apply_ocr:
ocr_start = time.time()
Expand Down
Empty file added magic_pdf/model/v3/__init__.py
Empty file.
125 changes: 125 additions & 0 deletions magic_pdf/model/v3/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from collections import defaultdict
from typing import List, Dict

import torch
from transformers import LayoutLMv3ForTokenClassification

MAX_LEN = 510
CLS_TOKEN_ID = 0
UNK_TOKEN_ID = 3
EOS_TOKEN_ID = 2


class DataCollator:
def __call__(self, features: List[dict]) -> Dict[str, torch.Tensor]:
bbox = []
labels = []
input_ids = []
attention_mask = []

# clip bbox and labels to max length, build input_ids and attention_mask
for feature in features:
_bbox = feature["source_boxes"]
if len(_bbox) > MAX_LEN:
_bbox = _bbox[:MAX_LEN]
_labels = feature["target_index"]
if len(_labels) > MAX_LEN:
_labels = _labels[:MAX_LEN]
_input_ids = [UNK_TOKEN_ID] * len(_bbox)
_attention_mask = [1] * len(_bbox)
assert len(_bbox) == len(_labels) == len(_input_ids) == len(_attention_mask)
bbox.append(_bbox)
labels.append(_labels)
input_ids.append(_input_ids)
attention_mask.append(_attention_mask)

# add CLS and EOS tokens
for i in range(len(bbox)):
bbox[i] = [[0, 0, 0, 0]] + bbox[i] + [[0, 0, 0, 0]]
labels[i] = [-100] + labels[i] + [-100]
input_ids[i] = [CLS_TOKEN_ID] + input_ids[i] + [EOS_TOKEN_ID]
attention_mask[i] = [1] + attention_mask[i] + [1]

# padding to max length
max_len = max(len(x) for x in bbox)
for i in range(len(bbox)):
bbox[i] = bbox[i] + [[0, 0, 0, 0]] * (max_len - len(bbox[i]))
labels[i] = labels[i] + [-100] * (max_len - len(labels[i]))
input_ids[i] = input_ids[i] + [EOS_TOKEN_ID] * (max_len - len(input_ids[i]))
attention_mask[i] = attention_mask[i] + [0] * (
max_len - len(attention_mask[i])
)

ret = {
"bbox": torch.tensor(bbox),
"attention_mask": torch.tensor(attention_mask),
"labels": torch.tensor(labels),
"input_ids": torch.tensor(input_ids),
}
# set label > MAX_LEN to -100, because original labels may be > MAX_LEN
ret["labels"][ret["labels"] > MAX_LEN] = -100
# set label > 0 to label-1, because original labels are 1-indexed
ret["labels"][ret["labels"] > 0] -= 1
return ret


def boxes2inputs(boxes: List[List[int]]) -> Dict[str, torch.Tensor]:
bbox = [[0, 0, 0, 0]] + boxes + [[0, 0, 0, 0]]
input_ids = [CLS_TOKEN_ID] + [UNK_TOKEN_ID] * len(boxes) + [EOS_TOKEN_ID]
attention_mask = [1] + [1] * len(boxes) + [1]
return {
"bbox": torch.tensor([bbox]),
"attention_mask": torch.tensor([attention_mask]),
"input_ids": torch.tensor([input_ids]),
}


def prepare_inputs(
inputs: Dict[str, torch.Tensor], model: LayoutLMv3ForTokenClassification
) -> Dict[str, torch.Tensor]:
ret = {}
for k, v in inputs.items():
v = v.to(model.device)
if torch.is_floating_point(v):
v = v.to(model.dtype)
ret[k] = v
return ret


def parse_logits(logits: torch.Tensor, length: int) -> List[int]:
"""
parse logits to orders

:param logits: logits from model
:param length: input length
:return: orders
"""
logits = logits[1 : length + 1, :length]
orders = logits.argsort(descending=False).tolist()
ret = [o.pop() for o in orders]
while True:
order_to_idxes = defaultdict(list)
for idx, order in enumerate(ret):
order_to_idxes[order].append(idx)
# filter idxes len > 1
order_to_idxes = {k: v for k, v in order_to_idxes.items() if len(v) > 1}
if not order_to_idxes:
break
# filter
for order, idxes in order_to_idxes.items():
# find original logits of idxes
idxes_to_logit = {}
for idx in idxes:
idxes_to_logit[idx] = logits[idx, order]
idxes_to_logit = sorted(
idxes_to_logit.items(), key=lambda x: x[1], reverse=True
)
# keep the highest logit as order, set others to next candidate
for idx, _ in idxes_to_logit[1:]:
ret[idx] = orders[idx].pop()

return ret


def check_duplicate(a: List[int]) -> bool:
return len(a) != len(set(a))
2 changes: 1 addition & 1 deletion magic_pdf/pdf_parse_by_ocr.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from magic_pdf.pdf_parse_union_core import pdf_parse_union
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union


def parse_pdf_by_ocr(pdf_bytes,
Expand Down
2 changes: 1 addition & 1 deletion magic_pdf/pdf_parse_by_txt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from magic_pdf.pdf_parse_union_core import pdf_parse_union
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union


def parse_pdf_by_txt(
Expand Down
Loading
Loading