-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathocr.py
175 lines (156 loc) · 7.16 KB
/
ocr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
# from paddleocr import PaddleOCR
# from openai import OpenAI
# import re
# import requests
# from io import BytesIO
# from PIL import Image
# import json
# import os
# #OCR Initialization
# cls_path = 'files/paddleOCR/ch_ppocr_mobile_v2.0_cls_slim_infer'
# rec_path = 'files/paddleOCR/ch_PP-OCRv3_rec_slim_infer'
# det_path = 'files/paddleOCR/ch_PP-OCRv3_det_slim_infer'
# ocr = PaddleOCR(det_model_dir=det_path, rec_model_dir=rec_path, cls_model_dir=cls_path, use_angle_cls=True)
# #OpenAI initialisation
# client = OpenAI(
# api_key=os.environ.get("OPENAI_API_KEY"),
# )
# #OCR Hyperparameters
# left_determining_ratio = 0.2 # snippets whose left edge lies within this fraction of the image width are treated as left-side messages
# same_height_tolerance_ratio = 0.3 # two consecutive lines count as the same height if their heights differ by at most this fraction of the larger of the two heights
# same_left_alignment_tolerance_ratio = 0.01 # two consecutive lines count as left-aligned with each other if their left edges differ by at most this fraction of the image width
# threshold = 6 # the top 1/threshold of the OCR lines (at least one line) is skipped when looking for a timestamp
# #OpenAI Hyperparameters
# prompts = json.load(open("files/prompts.json"))
# ##FUNCTIONS##
# def perform_ocr(img_url):
# """
# Performs OCR on the image at the given url, and returns the OCR result and the image size
# """
# response = requests.get(img_url)
# img = Image.open(BytesIO(response.content))
# img_size = img.size
# result = ocr.ocr(response.content, cls=True)
# return result[0], img_size
# def group_messages(ocr_result, image_size):
# """
# Groups the OCR result into messages, and returns a list of dictionaries, each dictionary representing a message
# """
# img_width = image_size[0]
# img_height = image_size[1]
# def get_snippet_features(img_snippet):
# left_edge = (img_snippet[0][0][0] + img_snippet[0][3][0]) / 2
# right_edge = (img_snippet[0][1][0] + img_snippet[0][2][0]) / 2
# top_edge = (img_snippet[0][0][1] + img_snippet[0][1][1]) / 2
# bottom_edge = (img_snippet[0][2][1] + img_snippet[0][3][1]) / 2
# height = bottom_edge - top_edge
# return left_edge, right_edge, top_edge, bottom_edge, height
# def replace_time_format(s):
# pattern = '([01]?[0-9]|2[0-3]):[0-5][0-9]'
# return re.sub(pattern, "", s).replace("AM","").replace("PM","") #i notice sometimes it puts out AMV/PMV instead of AM/PM
# message_groups = []
# message_group = {}
# previous_left_edge = None
# previous_right_edge = None
# previous_top_edge = None
# previous_bottom_edge = None
# previous_height = None
# for index, snippet in enumerate(ocr_result):
# left_edge, right_edge, top_edge, bottom_edge, height = get_snippet_features(snippet)
# is_left = left_edge <= img_width * left_determining_ratio
# is_right = not is_left
# text = snippet[1][0]
# if index == 0:
# pass
# else:
# is_same_height_as_previous = abs(height - previous_height) <= max(height, previous_height)*same_height_tolerance_ratio
# is_same_left_alignment = abs(left_edge - previous_left_edge) <= img_width * same_left_alignment_tolerance_ratio
# is_timestamp = (len(text) - len(replace_time_format(text))) >= 2
# if is_timestamp:
# continue
# if not ((is_left and is_previous_left and is_same_height_as_previous and is_same_left_alignment) or (is_right and is_previous_right and is_same_height_as_previous)):
# message_groups.append(message_group)
# message_group = {}
# message_group["bottom_edge"] = max(bottom_edge, message_group.get("bottom_edge",0))
# message_group["right_edge"] = max(right_edge, message_group.get("right_edge",0))
# message_group["left_edge"] = min(left_edge, message_group.get("left_edge",right_edge))
# message_group["top_edge"] = min(top_edge, message_group.get("top_edge",bottom_edge))
# message_group["is_left"] = is_left
# if ("text" in message_group and previous_bottom_edge and ((top_edge - previous_bottom_edge) > height)):
# message_group["text"] = message_group.get("text","") + "\n\n"
# message_group["text"] = message_group.get("text","") + " "+ text
# previous_left_edge = left_edge
# previous_right_edge = right_edge
# previous_top_edge = top_edge
# previous_bottom_edge = bottom_edge
# previous_height = height
# is_previous_left = is_left
# is_previous_right = is_right
# if message_group:
# message_groups.append(message_group)
# return message_groups
# def extract_meaningful_groups(message_groups):
# """
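# Calls OpenAI to add spaces to the messages and extract the relevant ones.
# Returns the parsed JSON reply. If the reply is not valid JSON, returns the
# JSON string of the condensed input groups instead.
# NOTE(review): end_to_end calls .get() on this result, which raises
# AttributeError when the string fallback is returned.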
# """
# condensed_message_groups = json.dumps([{"is_left": message_group["is_left"], "text": message_group["text"]} for message_group in message_groups], indent=2)
# relevant_prompts = prompts.get("extract-convo",{})
# system_message_template = relevant_prompts.get("system","")
# assert system_message_template
# user_message_example = relevant_prompts.get("user-example","")
# assert user_message_example
# ai_message_example = relevant_prompts.get("ai-example","")
# assert ai_message_example
# MODEL = "gpt-3.5-turbo"
# response = client.chat.completions.create(
# model=MODEL,
# messages=[
# {"role": "system", "content": system_message_template},
# {"role": "user", "content": user_message_example},
# {"role": "assistant", "content": ai_message_example},
# {"role": "user", "content": condensed_message_groups},
# ],
# temperature=0,
# seed=11,
# #response_format= {"type":"json_object"}
# )
# try:
# return json.loads(response.choices[0].message.content)
# except Exception as e:
# print("Error:", e)
# return condensed_message_groups
# def check_convo(messages, threshold):
# """
# Check whether the screenshot looks like a conversation by searching the OCR lines for a standalone
# timestamp (e.g. "9:41" or "9:41 PM"), skipping the top 1/threshold of the lines (at least one line).
# """
# pattern = r'^([01]?[0-9]|2[0-3]):[0-5][0-9](?:\s?[apAP][mM])?$'
# message_length = len(messages)
# cutoff = max(1,int(message_length/threshold)) #only starts finding timestamps from below here
# for message in messages[cutoff:]:
# if re.match(pattern, message):
# return True
# return False
# def get_impt_message(final_output):
# """
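# Return the text of the longest entry in final_output["text_messages"], which is treated as the most
# important message. Returns "" if there are no entries or if final_output does not have the expected shape.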
# """
# try:
# if len(final_output["text_messages"]) > 0:
# longest_text_item = max(final_output["text_messages"], key=lambda x: len(x["text"]))
# return longest_text_item["text"]
# else:
# return ""
# except:
# return ""
# def end_to_end(img_url):
# """
# runs the process end to end
# """
# ocr_result, img_size = perform_ocr(img_url)
# message_groups = group_messages(ocr_result, img_size)
# output = extract_meaningful_groups(message_groups)
# ocr_output_processed = [item[1][0] for item in ocr_result]
# is_convo = check_convo(ocr_output_processed, threshold)
# important_message = get_impt_message(output)
# sender = output.get("sender_name_or_phone_number", "")
# return output, is_convo, important_message, sender