main.py
"""Main inference script for Nougat."""
import argparse
import logging
import os
import re
from pathlib import Path
from typing import List, Union

import requests
from PIL import Image
from fastapi import FastAPI

from vqa.model_lib.nougat import NougatModel
from vqa.rasterize import rasterize_paper

app = FastAPI()


def download_papers(paper_ids: Union[str, List[str]], output_dir: Path) -> None:
    """Download papers from export.arxiv.org and save them in output_dir."""
    output_dir.mkdir(parents=True, exist_ok=True)
    if isinstance(paper_ids, str):
        paper_ids = [paper_ids]
    base_url = "https://export.arxiv.org/pdf/"
    for pid in paper_ids:
        url = base_url + pid + ".pdf"
        r = requests.get(url, stream=True)
        r.raise_for_status()  # fail loudly instead of writing an error page to disk
        with open(os.path.join(output_dir, pid.replace(".", "_") + ".pdf"), "wb") as f:
            for chunk in r.iter_content(chunk_size=1024):
                f.write(chunk)
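
# Quick sanity check (illustrative only; "2308.13418" is an example arXiv ID,
# not one referenced by this repo):
#
#     download_papers("2308.13418", Path("./output/pdfs"))
#
# The dot in the ID is replaced with an underscore, so the PDF lands at
# ./output/pdfs/2308_13418.pdf.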


@app.get("/nougat/{paper_id}")
def main(paper_id: str) -> None:
    checkpoint_dir = Path("./checkpoint")
    output_dir = Path("./output")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load the model
    logging.info("Loading the model...")
    model = NougatModel.from_pretrained(checkpoint_dir)

    # Download the papers
    logging.info("Downloading the papers...")
    pdf_dir = output_dir / "pdfs"
    download_papers(paper_id, pdf_dir)

    # Extract text from the papers
    logging.info("Extracting text from the papers...")
    paper_paths = list(pdf_dir.rglob("*.pdf"))
    for paper_path in paper_paths:
        # Rasterize the PDF into one image per page
        paper_pages = rasterize_paper(paper_path)
        logging.info("Processing file %s with %i pages", paper_path.name, len(paper_pages))
        predictions = []
        for i, paper_page in enumerate(paper_pages):
            page_image = Image.open(paper_page)
            page_image = model.encoder.prepare_input(page_image, random_padding=False)
            model_output = model.inference(image_tensors=page_image.unsqueeze(dim=0))
            output = model_output["predictions"][0]
            # Check whether the model output is faulty
            if output.strip() == "[MISSING_PAGE_POST]":
                # Uncaught repetitions -- most likely an empty page
                predictions.append(f"\n\n[MISSING_PAGE_EMPTY:{i}]\n\n")
            elif model_output["repeats"][0] is not None:
                if model_output["repeats"][0] > 0:
                    # The output is most likely incomplete and was truncated
                    logging.warning("Skipping page %i due to repetitions.", i)
                    predictions.append(f"\n\n[MISSING_PAGE_FAIL:{i}]\n\n")
                else:
                    # The page is too different from the training domain;
                    # this can happen e.g. for cover pages
                    predictions.append(f"\n\n[MISSING_PAGE_EMPTY:{i}]\n\n")
            else:
                predictions.append(output)
        out_text = "".join(predictions).strip()
        out_text = re.sub(r"\n{3,}", "\n\n", out_text).strip()
        out_path = output_dir / "txts" / (paper_path.stem + ".txt")
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(out_text)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_dir", "-c", type=Path, default="./checkpoint", help="Path to checkpoint directory.")
    parser.add_argument("--paper_id", "-pid", nargs="+", type=str, help="Paper ID(s) for inference.")
    parser.add_argument("--output_dir", "-o", type=Path, default="./output", help="Path to output directory.")
    args = parser.parse_args()
    # Note: main() currently uses the hard-coded ./checkpoint and ./output paths,
    # which match the argparse defaults above. download_papers() accepts either a
    # single ID or a list, so all requested papers are fetched in one call.
    main(args.paper_id)
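
# Usage sketch (assumptions: a Nougat checkpoint in ./checkpoint and uvicorn
# installed -- neither is declared in this file):
#
#   Serve the API:     uvicorn main:app --host 0.0.0.0 --port 8000
#   Query it:          curl http://localhost:8000/nougat/2308.13418
#   Or run from CLI:   python main.py --paper_id 2308.13418
#
# "2308.13418" is an example arXiv ID; extracted text is written to ./output/txts/.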