diff --git a/docs/download_models.py b/docs/download_models.py index 934c99be..9fbaea48 100644 --- a/docs/download_models.py +++ b/docs/download_models.py @@ -1,4 +1,5 @@ # use modelscope sdk download models from modelscope import snapshot_download model_dir = snapshot_download('opendatalab/PDF-Extract-Kit') +layoutreader_model_dir = snapshot_download('ppaanngggg/layoutreader') print(f"model dir is: {model_dir}/models") diff --git a/docs/download_models_hf.py b/docs/download_models_hf.py index 7436a06b..0c7079e9 100644 --- a/docs/download_models_hf.py +++ b/docs/download_models_hf.py @@ -1,3 +1,4 @@ from huggingface_hub import snapshot_download model_dir = snapshot_download('opendatalab/PDF-Extract-Kit') +layoutreader_model_dir = snapshot_download('hantian/layoutreader') print(f"model dir is: {model_dir}/models") diff --git a/magic-pdf.template.json b/magic-pdf.template.json index 1eb61101..fcb99955 100644 --- a/magic-pdf.template.json +++ b/magic-pdf.template.json @@ -4,6 +4,7 @@ "bucket-name-2":["ak", "sk", "endpoint"] }, "models-dir":"/tmp/models", + "layoutreader-model-dir":"/tmp/layoutreader", "device-mode":"cpu", "table-config": { "model": "TableMaster", diff --git a/magic_pdf/libs/config_reader.py b/magic_pdf/libs/config_reader.py index eb282903..9b4b7d8b 100644 --- a/magic_pdf/libs/config_reader.py +++ b/magic_pdf/libs/config_reader.py @@ -67,6 +67,18 @@ def get_local_models_dir(): return models_dir +def get_local_layoutreader_model_dir(): + config = read_config() + layoutreader_model_dir = config.get("layoutreader-model-dir") + if layoutreader_model_dir is None or not os.path.exists(layoutreader_model_dir): + home_dir = os.path.expanduser("~") + layoutreader_at_modelscope_dir_path = os.path.join(home_dir, ".cache/modelscope/hub/ppaanngggg/layoutreader") + logger.warning(f"'layoutreader-model-dir' not exists, use {layoutreader_at_modelscope_dir_path} as default") + return layoutreader_at_modelscope_dir_path + else: + return layoutreader_model_dir + + def get_device(): config = read_config() device = config.get("device-mode") diff --git a/magic_pdf/pdf_parse_union_core_v2.py b/magic_pdf/pdf_parse_union_core_v2.py index eee1c04a..1fd7604b 100644 --- a/magic_pdf/pdf_parse_union_core_v2.py +++ b/magic_pdf/pdf_parse_union_core_v2.py @@ -1,3 +1,4 @@ +import os import statistics import time @@ -9,6 +10,7 @@ from magic_pdf.libs.clean_memory import clean_memory from magic_pdf.libs.commons import fitz, get_delta_time +from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir from magic_pdf.libs.convert_utils import dict_to_list from magic_pdf.libs.drop_reason import DropReason from magic_pdf.libs.hash_utils import compute_md5 @@ -95,7 +97,7 @@ def replace_text_span(pymu_spans, ocr_spans): return list(filter(lambda x: x["type"] != ContentType.Text, ocr_spans)) + pymu_spans -def model_init(model_name: str, local_path=None): +def model_init(model_name: str): from transformers import LayoutLMv3ForTokenClassification if torch.cuda.is_available(): device = torch.device("cuda") @@ -108,9 +110,13 @@ def model_init(model_name: str, local_path=None): supports_bfloat16 = False if model_name == "layoutreader": - if local_path: - model = LayoutLMv3ForTokenClassification.from_pretrained(local_path) + # 检测modelscope的缓存目录是否存在 + layoutreader_model_dir = get_local_layoutreader_model_dir() + if os.path.exists(layoutreader_model_dir): + model = LayoutLMv3ForTokenClassification.from_pretrained(layoutreader_model_dir) else: + logger.warning( + f"local layoutreader model not exists, use online model from huggingface") model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader") # 检查设备是否支持 bfloat16 if supports_bfloat16: @@ -131,12 +137,9 @@ def __new__(cls, *args, **kwargs): cls._instance = super().__new__(cls) return cls._instance - def get_model(self, model_name: str, local_path=None): + def get_model(self, model_name: str): if model_name not in self._models: - if local_path: - self._models[model_name] = model_init(model_name=model_name, local_path=local_path) - else: - self._models[model_name] = model_init(model_name=model_name) + self._models[model_name] = model_init(model_name=model_name) return self._models[model_name]