Skip to content

Commit

Permalink
support saveimage for text_rec and fix PIL version
Browse files: browse the repository at this point in the history
  • Loading branch information
zhangyubo0722 authored and TingquanGao committed Oct 16, 2024
1 parent 3051653 commit bffbdb1
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 5 deletions.
11 changes: 8 additions & 3 deletions paddlex/inference/results/clas.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ class MLClassResult(TopkResult):
def _to_img(self):
"""Draw label on image"""
image = self._img_reader.read(self["input_path"])
image = image.convert("RGB")
label_names = self["label_names"]
scores = self["scores"]
image = image.convert("RGB")
Expand All @@ -104,7 +103,10 @@ def _to_img(self):
row_text = "\t"
for label_name, score in zip(label_names, scores):
text = f"{label_name}({score})\t"
text_width, row_height = font.getsize(text)
if int(PIL.__version__.split(".")[0]) < 10:
text_width, row_height = font.getsize(text)
else:
text_width, row_height = font.getbbox(text)[2:]
if row_width + text_width <= image_width:
row_text += text
row_width += text_width
Expand All @@ -122,7 +124,10 @@ def _to_img(self):
draw = ImageDraw.Draw(new_image)
font_color = tuple(self._get_font_colormap(3))
for i, text in enumerate(text_lines):
text_width, _ = font.getsize(text)
if int(PIL.__version__.split(".")[0]) < 10:
text_width, _ = font.getsize(text)
else:
text_width, _ = font.getbbox(text)[2:]
draw.text(
(0, image_height + i * int(row_height * 1.2)),
text,
Expand Down
48 changes: 46 additions & 2 deletions paddlex/inference/results/text_rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,55 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import PIL
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import cv2

from ...utils.fonts import PINGFANG_FONT_FILE_PATH
from ...utils import logging
from .base import CVResult


class TextRecResult(CVResult):
    """Text-recognition result that can be rendered as an annotated image."""

    def _to_img(self):
        """Draw the recognized text and its score below the input image.

        Reads the image referenced by ``self["input_path"]``, appends a white
        strip underneath it and writes ``"<rec_text> (<rec_score>)"`` into
        that strip in black.

        Returns:
            A new RGB image: the input image with the caption strip appended.
        """
        # NOTE(review): the reader is presumably returning a PIL image, since
        # ``convert`` and ``size`` are used on it -- confirm against CVResult.
        image = self._img_reader.read(self["input_path"])
        rec_text = self["rec_text"]
        rec_score = self["rec_score"]
        image = image.convert("RGB")
        image_width, image_height = image.size

        text = f"{rec_text} ({rec_score})"
        font = self.adjust_font_size(image_width, text, PINGFANG_FONT_FILE_PATH)
        # Measure through the version-aware helper so this path also works on
        # Pillow releases where only one of getsize/getbbox is available.
        _, row_height = self._text_size(font, text)

        # Leave 20% extra vertical room so the caption is not clipped.
        new_image_height = image_height + int(row_height * 1.2)
        new_image = Image.new("RGB", (image_width, new_image_height), (255, 255, 255))
        new_image.paste(image, (0, 0))

        draw = ImageDraw.Draw(new_image)
        draw.text(
            (0, image_height),
            text,
            fill=(0, 0, 0),
            font=font,
        )
        return new_image

    def adjust_font_size(self, image_width, text, font_path):
        """Pick the largest font, starting from 6% of the width, that fits.

        Args:
            image_width: Width in pixels that ``text`` must fit into.
            text: The caption that will be drawn.
            font_path: Path to the TrueType font file to load.

        Returns:
            An ``ImageFont`` loaded from ``font_path``. Its size is shrunk one
            point at a time until ``text`` fits in ``image_width``, but never
            below 1, because a zero or negative size is rejected by
            ``ImageFont.truetype``.
        """
        # Narrow crops (< 17 px wide) would otherwise start at size 0.
        font_size = max(int(image_width * 0.06), 1)
        font = ImageFont.truetype(font_path, font_size)

        # Stop at size 1 even if the text still overflows: a clipped caption
        # is preferable to crashing on an invalid font size.
        while font_size > 1 and self._text_size(font, text)[0] > image_width:
            font_size -= 1
            font = ImageFont.truetype(font_path, font_size)

        return font

    @staticmethod
    def _text_size(font, text):
        """Return ``(width, height)`` of ``text`` for ``font`` on any Pillow.

        ``getsize`` was removed in Pillow 10, so newer releases use the
        right/bottom coordinates of ``getbbox`` instead.
        """
        if int(PIL.__version__.split(".")[0]) < 10:
            return font.getsize(text)
        right, bottom = font.getbbox(text)[2:]
        return right, bottom

0 comments on commit bffbdb1

Please sign in to comment.