Skip to content
This repository has been archived by the owner on May 29, 2024. It is now read-only.

v2 ocr addon #9

Open
wants to merge 1 commit into
base: starlette-api
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from base import ROOT_DIR
from PIL import Image
import os
from utils import main

from utils import main
from base import reader
from routes import predict_v2

async def predict(request):
form = await request.form()
Expand All @@ -29,9 +31,33 @@ async def predict(request):
},
)

async def get_ocr_data(request):
form = await request.form()
filename = form["image"].filename
contents = await form["image"].read()
contents = io.BytesIO(contents)
save_path = str(ROOT_DIR) + "/image/" + filename
Image.open(contents).save(save_path)

data = reader.readtext(save_path)
data = [{"bounding_box": d[0], "text": d[1], "confidence": d[2]} for d in data]
for d in data:
d["bounding_box"] = [[int(x), int(y)] for x, y in d["bounding_box"]]
os.remove(save_path)

return JSONResponse(
{
"status": "success",
"data": data,
},
)



routes = [
Route('/api/predict', predict, methods=['POST'])
Route('/api/predict', predict, methods=['POST']),
Route('/api/get-ocr-data', get_ocr_data, methods=['POST']),
Route('/api/predict-v2', predict_v2, methods=['POST']),
]

app = Starlette(debug=True, routes=routes)
Expand Down
9 changes: 8 additions & 1 deletion base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from pathlib import Path
import easyocr
from dotenv import dotenv_values
import os

config = dotenv_values(".env")


# print("-"*20)
# print("Initializeds")
Expand All @@ -11,4 +16,6 @@

YOLO_PATH = str(ROOT_DIR / "yolov5")

reader = easyocr.Reader(['en'])
reader = easyocr.Reader(['en'])

OPENAI_API_KEY = config.get("OPENAI_API_KEY", os.environ.get("OPENAI_API_KEY"))
6 changes: 6 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ torchvision>=0.8.1
tqdm>=4.64.0
protobuf<4.21.3
easyocr>=1.6.2
python-dotenv==1.0.0
# Logging -------------------------------------
# tensorboard>=2.4.1
# wandb
Expand All @@ -31,3 +32,8 @@ thop>=0.1.1 # FLOPs computation
uvicorn==0.20.0
python-multipart==0.0.5
starlette==0.22.0

# Ocr V2 --------------------------------------
langchain==0.0.292
openai==0.28.1
Jinja2==3.1.2
1 change: 1 addition & 0 deletions routes/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from routes.predict_v2.index import predict_v2 # noqa
45 changes: 45 additions & 0 deletions routes/predict_v2/ai_helpers/chain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import json

from langchain.chat_models import ChatOpenAI
from langchain.prompts import (
SystemMessagePromptTemplate,
ChatPromptTemplate,
)
from langchain import PromptTemplate, LLMChain

from base import OPENAI_API_KEY


class ChatChain:
def __init__(self, ocr_data):
llm = ChatOpenAI(temperature=0.2, openai_api_key=OPENAI_API_KEY, model="gpt-4")

template = f"""You are an OCR to JSON converter for 5ParaMonitor. You are given ocr json data for a 5 Para monitor, analyze the brand and predict patients reading, you will output the readings in JSON format.
5ParaMonitor OCR data: {json.dumps(ocr_data)}
Tips to analyze the ocr data: monitor can be zoomed in or zoomed out, ocr data is read from left to right of an image from top to bottom(with every row you go down), most of the times readings that we want are at extreme right of the monitor screen, there can be params like spo2, temp, bp etc present in ocr data, use them to your benefits and identify the correct value, temperature is always a decimal value, don't repeat a single value for multiple params, if you are not sure about a value, you can answer it as null, use common sense to get the correct field of a value.
Example output:
{{"spo2": "value/null", "resp": "value/null", "temperature": "value/null", "pulse":"value/null", "bp":"value/null"}}
"""

system_prompt = PromptTemplate(
template=template, input_variables=[], template_format="jinja2"
)
system_message_prompt = SystemMessagePromptTemplate(prompt=system_prompt)

chat_prompt = ChatPromptTemplate.from_messages(
[
system_message_prompt,
]
)

self.chain = LLMChain(
llm=llm,
prompt=chat_prompt,
verbose=True,
)

async def async_predict(self):
prediction = await self.chain.apredict()

parsed_prediction = json.loads(prediction)
return parsed_prediction
25 changes: 25 additions & 0 deletions routes/predict_v2/index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import traceback

from starlette.responses import JSONResponse

from routes.predict_v2.ai_helpers.chain import ChatChain


async def predict_v2(request):

try:
data = await request.json()

if "ocr_data" not in data:
return JSONResponse(status_code= 400, content={"error": "ocr_data not found"})

ocr_data = data["ocr_data"]

chat_chain = ChatChain(ocr_data)
response = await chat_chain.async_predict()

return JSONResponse({"data": response})

except Exception as e:
traceback.print_exc()
return JSONResponse(status_code=500, content={"error": "Something went wrong"})