Skip to content

Commit

Permalink
Merge pull request #163 from sucuicong/master
Browse files Browse the repository at this point in the history
表格文字识别、条形码/二维码识别、身份证混贴识别、手写文字识别支持function call
  • Loading branch information
seiriosPlus authored Mar 6, 2024
2 parents 1138f55 + 9f6b4fc commit 8bd217d
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 50 deletions.
8 changes: 4 additions & 4 deletions appbuilder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,15 @@ def check_version(self):
'Message',
'AnimalRecognition',
'DocCropEnhance',
QRcodeOCR,
TableOCR,
'QRcodeOCR',
'TableOCR',

'Embedding',

'Matching',

"PlantRecognition",
HandwriteOCR,
"HandwriteOCR",
"ImageUnderstand",
MixCardOCR,
"MixCardOCR",
]
44 changes: 19 additions & 25 deletions appbuilder/core/components/handwrite_ocr/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def run(self, message: Message, timeout: float = None, retry: int = 0) -> Messag
return Message(content=out.model_dump())

def tool_eval(self, name: str, streaming: bool, **kwargs):

result = {}
file_names = kwargs.get("file_names", None)
if not file_names:
Expand All @@ -108,6 +109,24 @@ def tool_eval(self, name: str, streaming: bool, **kwargs):
file_url = file_urls.get(file_name, None)
if file_url is None:
raise InvalidRequestArgumentError(f"file {file_name} url does not exist")
req = HandwriteOCRRequest()
req.url = file_url
req.recognize_granularity = "big"
req.probability = "false"
req.detect_direction = "true"
req.detect_alteration = "true"
response = self._recognize(req)
out = HandwriteOCROutMsg()
out.direction = response.direction
[out.contents.append(
Content(text=w.words))
for w in response.words_result]
result[file_name] = out.dict()

if streaming:
yield json.dumps(result, ensure_ascii=False)
else:
return json.dumps(result, ensure_ascii=False)

def _recognize(self, request: HandwriteOCRRequest, timeout: float = None, retry: int = 0) -> HandwriteOCRResponse:
r"""调用底层接口进行通用文字识别
Expand Down Expand Up @@ -150,28 +169,3 @@ def _check_service_error(request_id: str, data: dict):
service_err_message=data.get("error_msg")
)

req = HandwriteOCRRequest()
req.url = file_url
req.recognize_granularity = "big"
req.probability = "false"
req.detect_direction = "true"
req.detect_alteration = "true"
response = self._recognize(req)
out = HandwriteOCROutMsg()
out.direction = response.direction
[out.contents.append(
Content(text=w.words,
position=Position(
left=w.location.left,
top=w.location.top,
width=w.location.width,
height=w.location.height
)))
for w in response.words_result]
result[file_name] = out.dict()

if streaming:
yield json.dumps(result)
else:
return json.dumps(result)

5 changes: 2 additions & 3 deletions appbuilder/core/components/handwrite_ocr/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""手写文字识别数据类"""
import proto
from typing import List
from typing import List, Optional
from pydantic import BaseModel


Expand Down Expand Up @@ -218,7 +218,6 @@ class HandwriteOCRInMsg(BaseModel):
url: str = "" # 图片可下载链接



class Position(BaseModel):
"""位置信息
Expand All @@ -243,7 +242,7 @@ class Content(BaseModel):
position(Position): 文字内容的位置信息
"""
text: str
position: Position
position: Optional[Position] = None


class HandwriteOCROutMsg(BaseModel):
Expand Down
10 changes: 3 additions & 7 deletions appbuilder/core/components/mix_card_ocr/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,16 +173,12 @@ def tool_eval(self, name: str, streaming: bool, **kwargs):
ref = out.front
if card_type == "idcard_back":
ref = out.back
loc = res.card_info.card_location
ref.position = MixCardPosition(left=loc.left, top=loc.top, width=loc.width, height=loc.height)
for key, val in res.card_result.items():
position = MixCardPosition(left=val.location.left, top=val.location.top, width=val.location.width,
height=val.location.height)
ref.fields.append(MixCardField(key=key, value=val.words, position=position))
ref.fields.append(MixCardField(key=key, value=val.words, position=None))
out.direction = response.direction
result[file_name] = out.dict()

if streaming:
yield json.dumps(result)
yield json.dumps(result, ensure_ascii=False)
else:
return json.dumps(result)
return json.dumps(result, ensure_ascii=False)
8 changes: 4 additions & 4 deletions appbuilder/core/components/mix_card_ocr/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""身份证混贴数据类"""
import proto
from typing import List
from typing import List, Optional
from pydantic import BaseModel


Expand Down Expand Up @@ -223,7 +223,7 @@ class MixCardField(BaseModel):

key: str
value: str
position: MixCardPosition
position: Optional[MixCardPosition] = None


class MixCardContent(BaseModel):
Expand All @@ -245,6 +245,6 @@ class MixCardOCROutMsg(BaseModel):
back(MixCardField): 国徽面信息
direction(int): 图像旋转角度,0(正向),- 1(逆时针90度),- 2(逆时针180度),- 3(逆时针270度)
"""
front: MixCardField = MixCardContent()
back: MixCardField = MixCardContent()
front: MixCardContent = MixCardContent()
back: MixCardContent = MixCardContent()
direction: int = 0
8 changes: 4 additions & 4 deletions appbuilder/core/components/qrcode_ocr/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def _check_service_error(request_id: str, data: dict):
def tool_eval(self, name: str, streaming: bool, **kwargs):
result = {}
file_names = kwargs.get("file_names", None)
location = kwargs.get("locations", "true")
location = kwargs.get("locations", "false")
if not file_names:
file_names = kwargs.get("files")

Expand All @@ -163,8 +163,8 @@ def tool_eval(self, name: str, streaming: bool, **kwargs):
raise InvalidRequestArgumentError("location must be a string with value 'true' or 'false'")
req.location = location
resp = self._recognize(req)
result[file_name] = proto.Message.to_dict(resp)
result[file_name] = proto.Message.to_dict(resp)["codes_result"]
if streaming:
yield json.dumps(result)
yield json.dumps(result, ensure_ascii=False)
else:
return json.dumps(result)
return json.dumps(result, ensure_ascii=False)
38 changes: 35 additions & 3 deletions appbuilder/core/components/table_ocr/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,36 @@ def _check_service_error(request_id: str, data: dict):
service_err_message=data.get("error_msg")
)

def get_table_markdown(self, tables_result):
    """Convert recognized tables into Markdown table strings.

    Args:
        tables_result: Iterable of table dicts. Each table's ``"body"`` is a
            list of cell dicts with the keys ``row_start``, ``row_end``,
            ``col_start``, ``col_end`` and ``words``.
            NOTE(review): ``row_end``/``col_end`` are treated as exclusive
            upper bounds (the grid is sized by their maximum and indexed by
            ``row_start``/``col_start``) -- confirm against the OCR response.

    Returns:
        A list with one Markdown string per non-empty table. The first grid
        row is used as the header, followed by a ``---`` separator row; each
        string ends with a newline. Tables without any body cells are skipped.
    """
    markdowns = []
    for table in tables_result:
        cells = table.get("body") or []
        if not cells:
            # Nothing to render; max() below would raise on an empty body.
            continue

        max_row = max(cell['row_end'] for cell in cells)
        max_col = max(cell['col_end'] for cell in cells)

        # Start from an empty grid; positions no cell starts at stay blank.
        grid = [[''] * max_col for _ in range(max_row)]
        for cell in cells:
            # A Markdown table row must stay on one line, and a raw "|"
            # would be read as a column delimiter, so neutralize both.
            text = "<br>".join(cell['words'].splitlines())
            # Only the top-left position of a (possibly merged) cell is filled.
            grid[cell['row_start']][cell['col_start']] = text.replace("|", "\\|")

        lines = ["| " + " | ".join(row) + " |" for row in grid]
        # The separator row goes directly below the header (first) row.
        lines.insert(1, "| " + " | ".join(['---'] * max_col) + " |")
        markdowns.append("\n".join(lines) + "\n")
    return markdowns

def tool_eval(self, name: str, streaming: bool, **kwargs):
result = {}
file_names = kwargs.get("file_names", None)
Expand All @@ -157,8 +187,10 @@ def tool_eval(self, name: str, streaming: bool, **kwargs):
req.url = file_url
req.cell_contents = "false"
resp = self._recognize(req)
result[file_name] = proto.Message.to_dict(resp)
tables_result = proto.Message.to_dict(resp)["tables_result"]
markdowns = self.get_table_markdown(tables_result)
result[file_name] = markdowns
if streaming:
yield json.dumps(result)
yield json.dumps(result, ensure_ascii=False)
else:
return json.dumps(result)
return json.dumps(result, ensure_ascii=False)

0 comments on commit 8bd217d

Please sign in to comment.