Skip to content

Commit

Permalink
feat(layoutreader): support local model directory and improve model l…
Browse files Browse the repository at this point in the history
…oading

- Add function to get local LayoutReader model directory- Check and use local model directory if available
- Fall back to online model if local directory not found
- Update model initialization to support local path
- Refactor model loading in singleton class
  • Loading branch information
myhloli committed Oct 8, 2024
1 parent 3fb0494 commit ded2818
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 8 deletions.
1 change: 1 addition & 0 deletions docs/download_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# use modelscope sdk download models
from modelscope import snapshot_download
model_dir = snapshot_download('opendatalab/PDF-Extract-Kit')
layoutreader_model_dir = snapshot_download('ppaanngggg/layoutreader')
print(f"model dir is: {model_dir}/models")
1 change: 1 addition & 0 deletions docs/download_models_hf.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from huggingface_hub import snapshot_download
model_dir = snapshot_download('opendatalab/PDF-Extract-Kit')
layoutreader_model_dir = snapshot_download('hantian/layoutreader')
print(f"model dir is: {model_dir}/models")
1 change: 1 addition & 0 deletions magic-pdf.template.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"bucket-name-2":["ak", "sk", "endpoint"]
},
"models-dir":"/tmp/models",
"layoutreader-model-dir":"/tmp/layoutreader",
"device-mode":"cpu",
"table-config": {
"model": "TableMaster",
Expand Down
12 changes: 12 additions & 0 deletions magic_pdf/libs/config_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,18 @@ def get_local_models_dir():
return models_dir


def get_local_layoutreader_model_dir():
config = read_config()
layoutreader_model_dir = config.get("layoutreader-model-dir")
if layoutreader_model_dir is None or not os.path.exists(layoutreader_model_dir):
home_dir = os.path.expanduser("~")
layoutreader_at_modelscope_dir_path = os.path.join(home_dir, ".cache/modelscope/hub/ppaanngggg/layoutreader")
logger.warning(f"'layoutreader-model-dir' not exists, use {layoutreader_at_modelscope_dir_path} as default")
return layoutreader_at_modelscope_dir_path
else:
return layoutreader_model_dir


def get_device():
config = read_config()
device = config.get("device-mode")
Expand Down
19 changes: 11 additions & 8 deletions magic_pdf/pdf_parse_union_core_v2.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import statistics
import time

Expand All @@ -9,6 +10,7 @@

from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.commons import fitz, get_delta_time
from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir
from magic_pdf.libs.convert_utils import dict_to_list
from magic_pdf.libs.drop_reason import DropReason
from magic_pdf.libs.hash_utils import compute_md5
Expand Down Expand Up @@ -95,7 +97,7 @@ def replace_text_span(pymu_spans, ocr_spans):
return list(filter(lambda x: x["type"] != ContentType.Text, ocr_spans)) + pymu_spans


def model_init(model_name: str, local_path=None):
def model_init(model_name: str):
from transformers import LayoutLMv3ForTokenClassification
if torch.cuda.is_available():
device = torch.device("cuda")
Expand All @@ -108,9 +110,13 @@ def model_init(model_name: str, local_path=None):
supports_bfloat16 = False

if model_name == "layoutreader":
if local_path:
model = LayoutLMv3ForTokenClassification.from_pretrained(local_path)
# 检测modelscope的缓存目录是否存在
layoutreader_model_dir = get_local_layoutreader_model_dir()
if os.path.exists(layoutreader_model_dir):
model = LayoutLMv3ForTokenClassification.from_pretrained(layoutreader_model_dir)
else:
logger.warning(
f"local layoutreader model not exists, use online model from huggingface")
model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")
# 检查设备是否支持 bfloat16
if supports_bfloat16:
Expand All @@ -131,12 +137,9 @@ def __new__(cls, *args, **kwargs):
cls._instance = super().__new__(cls)
return cls._instance

def get_model(self, model_name: str, local_path=None):
def get_model(self, model_name: str):
if model_name not in self._models:
if local_path:
self._models[model_name] = model_init(model_name=model_name, local_path=local_path)
else:
self._models[model_name] = model_init(model_name=model_name)
self._models[model_name] = model_init(model_name=model_name)
return self._models[model_name]


Expand Down

0 comments on commit ded2818

Please sign in to comment.