app.py
import argparse
import logging
from typing import Union

# most likely you'll need these modules/classes
from clams import ClamsApp, Restifier
from mmif import Mmif, View, Document, AnnotationTypes, DocumentTypes
from mmif.utils import video_document_helper as vdh
import torch
from transformers import Pix2StructForConditionalGeneration as psg
from transformers import Pix2StructProcessor as psp


class Pix2structChyrons(ClamsApp):
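    """
    A CLAMS app that answers a fixed set of questions about "chyron"
    TimeFrame annotations in a video, using the Pix2Struct DocVQA model,
    and writes each answer back into the MMIF as an aligned TextDocument.
    """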

    def __init__(self):
        super().__init__()
        # run on GPU when available, otherwise fall back to CPU
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # pin the model to a specific revision for reproducibility
        self.model = psg.from_pretrained(
            "google/pix2struct-docvqa-base",
            revision="4dde85b6b60b3765bb1e50c089828f7b0b2f999d").to(self.device)
        self.processor = psp.from_pretrained("google/pix2struct-docvqa-base")

    def _appmetadata(self):
        # see https://sdk.clams.ai/autodoc/clams.app.html#clams.app.ClamsApp._load_appmetadata
        # Also check out ``metadata.py`` in this directory.
        # When using ``metadata.py``, leave this do-nothing "pass" method here.
        pass

    def generate(self, img, questions):
        """
        Generate answers for a list of questions about a single image.

        :param img: a PIL image to query
        :param questions: a list of question strings
        :return: an iterable of (question, answer) pairs
        """
        # repeat the image once per question so the processor can pair them up
        inputs = self.processor(images=[img] * len(questions),
                                text=questions, return_tensors="pt").to(self.device)
        predictions = self.model.generate(**inputs, max_new_tokens=256)
        return zip(questions, self.processor.batch_decode(predictions, skip_special_tokens=True))
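
    # Note: answers come back paired with their questions in input order, so
    # the result can be collected into a dict, e.g. (hypothetical values):
    #   dict(self.generate(frame_image, ["What is the name of the person in the image"]))
    #   -> {"What is the name of the person in the image": "Jane Doe"}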

    def _annotate(self, mmif: Union[str, dict, Mmif], **parameters) -> Mmif:
        video_doc: Document = mmif.get_documents_by_type(DocumentTypes.VideoDocument)[0]
        input_view: View = mmif.get_views_for_document(video_doc.properties.id)[0]
        # fill in default values for any parameters not passed in at request time
        config = self.get_configuration(**parameters)
        new_view: View = mmif.new_view()
        self.sign_view(new_view, parameters)
        new_view.new_contain(
            AnnotationTypes.Alignment,
            document=video_doc.id,
        )
        # each query is mapped to the label its answer will carry
        query_to_label = {
            "What is the name of the person in the image": "person_name",
            "What is the person's description": "person_description"
        }
        queries = list(query_to_label.keys())
        for timeframe in input_view.get_annotations(AnnotationTypes.TimeFrame, frameType="chyron"):
            self.logger.debug(timeframe.properties)
            # grab a representative (middle) frame from the chyron time frame
            image = vdh.extract_mid_frame(mmif, timeframe, as_PIL=True)
            completions = self.generate(image, queries)
            for query, answer in completions:
                self.logger.debug(f"query: {query} answer: {answer}")
                # store each answer as a new text document and align it
                # with the time frame it was extracted from
                text_document = new_view.new_textdocument(answer)
                text_document.add_property("query", query)
                text_document.add_property("label", query_to_label[query])
                align_annotation = new_view.new_annotation(AnnotationTypes.Alignment)
                align_annotation.add_property("source", timeframe.id)
                align_annotation.add_property("target", text_document.id)
        return mmif


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", action="store", default="5000", help="set port to listen")
    parser.add_argument("--production", action="store_true", help="run gunicorn server")
    # add more arguments as needed
    # parser.add_argument(more_arg...)
    parsed_args = parser.parse_args()

    # create the app instance
    app = Pix2structChyrons()
    http_app = Restifier(app, port=int(parsed_args.port))
    # for running the application in production mode
    if parsed_args.production:
        http_app.serve_production()
    # development mode
    else:
        app.logger.setLevel(logging.DEBUG)
        http_app.run()
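
# A minimal smoke test, assuming the app is running locally on the default
# port and `input.mmif` holds a VideoDocument plus chyron TimeFrame
# annotations (the file name and port here are illustrative):
#   curl -X POST -d @input.mmif http://localhost:5000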