Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dev(narugo): add quick gradio demo for classifiers/yolos #108

Merged
merged 5 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/api_doc/generic/classify.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ ClassifyModel
-----------------------------------------

.. autoclass:: ClassifyModel
:members: __init__, predict_score, predict, clear
:members: __init__, predict_score, predict, clear, make_ui, launch_demo



Expand Down
2 changes: 1 addition & 1 deletion docs/source/api_doc/generic/yolo.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ YOLOModel
-----------------------------------------

.. autoclass:: YOLOModel
:members: __init__, predict, clear
:members: __init__, predict, clear, make_ui, launch_demo



Expand Down
112 changes: 112 additions & 0 deletions imgutils/generic/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,37 @@

import numpy as np
from PIL import Image
from hfutils.operate import get_hf_client
from hfutils.repository import hf_hub_repo_url
from hfutils.utils import hf_fs_path, hf_normpath
from huggingface_hub import hf_hub_download, HfFileSystem

from ..data import rgb_encode, ImageTyping, load_image
from ..utils import open_onnx_model

try:
import gradio as gr
except (ImportError, ModuleNotFoundError):
gr = None

__all__ = [
'ClassifyModel',
'classify_predict_score',
'classify_predict',
]


def _check_gradio_env():
"""
Check if the Gradio library is installed and available.

:raises EnvironmentError: If Gradio is not installed.
"""
if gr is None:
raise EnvironmentError(f'Gradio required for launching webui-based demo.\n'

Check warning on line 53 in imgutils/generic/classify.py

View check run for this annotation

Codecov / codecov/patch

imgutils/generic/classify.py#L52-L53

Added lines #L52 - L53 were not covered by tests
f'Please install it with `pip install dghs-imgutils[demo]`.')


def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
normalize: Optional[Tuple[float, float]] = (0.5, 0.5)):
"""
Expand Down Expand Up @@ -287,6 +305,100 @@
self._models.clear()
self._labels.clear()

def make_ui(self, default_model_name: Optional[str] = None):
"""
Create the user interface components for the classifier model demo.

This method sets up the Gradio UI components including an image input, model selection dropdown,
submit button, and output label. It also configures the interaction between these components.

:param default_model_name: The name of the default model to be selected in the dropdown.
If None, the most recently updated model will be selected.
:type default_model_name: Optional[str]

:raises ImportError: If Gradio is not installed or properly configured.

:Example:
>>> model = ClassifyModel("username/repo_name")
>>> model.make_ui(default_model_name="model_v1")
"""

# demo for classifier model
_check_gradio_env()
model_list = self.model_names
if not default_model_name:
hf_client = get_hf_client(hf_token=self._get_hf_token())
selected_model_name, selected_time = None, None
for fileitem in hf_client.get_paths_info(

Check warning on line 332 in imgutils/generic/classify.py

View check run for this annotation

Codecov / codecov/patch

imgutils/generic/classify.py#L327-L332

Added lines #L327 - L332 were not covered by tests
repo_id=self.repo_id,
repo_type='model',
paths=[f'{model_name}/model.onnx' for model_name in model_list],
expand=True,
):
if not selected_time or fileitem.last_commit.date > selected_time:
selected_model_name = os.path.dirname(fileitem.path)
selected_time = fileitem.last_commit.date
default_model_name = selected_model_name

Check warning on line 341 in imgutils/generic/classify.py

View check run for this annotation

Codecov / codecov/patch

imgutils/generic/classify.py#L338-L341

Added lines #L338 - L341 were not covered by tests

with gr.Row():
with gr.Column():
gr_input_image = gr.Image(type='pil', label='Original Image')
gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model')
gr_submit = gr.Button(value='Submit', variant='primary')

Check warning on line 347 in imgutils/generic/classify.py

View check run for this annotation

Codecov / codecov/patch

imgutils/generic/classify.py#L343-L347

Added lines #L343 - L347 were not covered by tests

with gr.Column():
gr_output = gr.Label(label='Prediction')

Check warning on line 350 in imgutils/generic/classify.py

View check run for this annotation

Codecov / codecov/patch

imgutils/generic/classify.py#L349-L350

Added lines #L349 - L350 were not covered by tests

gr_submit.click(

Check warning on line 352 in imgutils/generic/classify.py

View check run for this annotation

Codecov / codecov/patch

imgutils/generic/classify.py#L352

Added line #L352 was not covered by tests
self.predict_score,
inputs=[
gr_input_image,
gr_model,
],
outputs=[gr_output],
)

def launch_demo(self, default_model_name: Optional[str] = None,
server_name: Optional[str] = None, server_port: Optional[int] = None, **kwargs):
"""
Launch the Gradio demo for the classifier model.

This method creates a Gradio Blocks interface, sets up the UI components using make_ui(),
and launches the demo server.

:param default_model_name: The name of the default model to be selected in the dropdown.
:type default_model_name: Optional[str]
:param server_name: The name of the server to run the demo on. Defaults to None.
:type server_name: Optional[str]
:param server_port: The port number to run the demo on. Defaults to None.
:type server_port: Optional[int]
:param kwargs: Additional keyword arguments to pass to the Gradio launch method.

:raises ImportError: If Gradio is not installed or properly configured.

:Example:
>>> model = ClassifyModel("username/repo_name")
>>> model.launch_demo(default_model_name="model_v1", server_name="0.0.0.0", server_port=7860)
"""

_check_gradio_env()
with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
repo_url = hf_hub_repo_url(repo_id=self.repo_id, repo_type='model')
gr.HTML(f'<h2 style="text-align: center;">Classifier Demo For {self.repo_id}</h2>')
gr.Markdown(f'This is the quick demo for classifier model [{self.repo_id}]({repo_url}). '

Check warning on line 390 in imgutils/generic/classify.py

View check run for this annotation

Codecov / codecov/patch

imgutils/generic/classify.py#L384-L390

Added lines #L384 - L390 were not covered by tests
f'Powered by `dghs-imgutils`\'s quick demo module.')

with gr.Row():
self.make_ui(default_model_name=default_model_name)

Check warning on line 394 in imgutils/generic/classify.py

View check run for this annotation

Codecov / codecov/patch

imgutils/generic/classify.py#L393-L394

Added lines #L393 - L394 were not covered by tests

demo.launch(

Check warning on line 396 in imgutils/generic/classify.py

View check run for this annotation

Codecov / codecov/patch

imgutils/generic/classify.py#L396

Added line #L396 was not covered by tests
server_name=server_name,
server_port=server_port,
**kwargs,
)


@lru_cache()
def _open_models_for_repo_id(repo_id: str, hf_token: Optional[str] = None) -> ClassifyModel:
Expand Down
180 changes: 180 additions & 0 deletions imgutils/generic/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,61 @@

import numpy as np
from PIL import Image
from hbutils.color import rnd_colors
from hfutils.operate import get_hf_client
from hfutils.repository import hf_hub_repo_url
from hfutils.utils import hf_fs_path, hf_normpath
from huggingface_hub import HfFileSystem, hf_hub_download

from ..data import load_image, rgb_encode, ImageTyping
from ..utils import open_onnx_model

try:
import gradio as gr
except (ImportError, ModuleNotFoundError):
gr = None

__all__ = [
'YOLOModel',
'yolo_predict',
]


def _check_gradio_env():
"""
Check if the Gradio library is installed and available.

:raises EnvironmentError: If Gradio is not installed.
"""
if gr is None:
raise EnvironmentError(f'Gradio required for launching webui-based demo.\n'

Check warning on line 50 in imgutils/generic/yolo.py

View check run for this annotation

Codecov / codecov/patch

imgutils/generic/yolo.py#L49-L50

Added lines #L49 - L50 were not covered by tests
f'Please install it with `pip install dghs-imgutils[demo]`.')


def _v_fix(v):
"""
Round and convert a float value to an integer.

:param v: The float value to be rounded and converted.
:type v: float
:return: The rounded integer value.
:rtype: int
"""
return int(round(v))

Check warning on line 63 in imgutils/generic/yolo.py

View check run for this annotation

Codecov / codecov/patch

imgutils/generic/yolo.py#L63

Added line #L63 was not covered by tests


def _bbox_fix(bbox):
"""
Fix the bounding box coordinates by rounding them to integers.

:param bbox: The bounding box coordinates.
:type bbox: tuple
:return: A tuple of fixed (rounded to integer) bounding box coordinates.
:rtype: tuple
"""
return tuple(map(_v_fix, bbox))

Check warning on line 75 in imgutils/generic/yolo.py

View check run for this annotation

Codecov / codecov/patch

imgutils/generic/yolo.py#L75

Added line #L75 was not covered by tests


def _yolo_xywh2xyxy(x: np.ndarray) -> np.ndarray:
"""
Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format.
Expand Down Expand Up @@ -403,9 +446,146 @@
def clear(self):
"""
Clear cached model and metadata.

This method removes all cached models and their associated metadata from memory.
It's useful for freeing up memory or ensuring that the latest versions of models are loaded.
"""
self._models.clear()

def make_ui(self, default_model_name: Optional[str] = None,
default_conf_threshold: float = 0.25, default_iou_threshold: float = 0.7):
"""
Create a Gradio-based user interface for object detection.

This method sets up an interactive UI that allows users to upload images,
select models, and adjust detection parameters. It uses the Gradio library
to create the interface.

:param default_model_name: The name of the default model to use.
If None, the most recently updated model is selected.
:type default_model_name: Optional[str]
:param default_conf_threshold: Default confidence threshold for the UI. Default is 0.25.
:type default_conf_threshold: float
:param default_iou_threshold: Default IoU threshold for the UI. Default is 0.7.
:type default_iou_threshold: float

:raises ImportError: If Gradio is not installed in the environment.

:Example:

>>> model = YOLOModel("username/repo_name")
>>> model.make_ui(default_model_name="yolov5s")
"""
_check_gradio_env()
model_list = self.model_names
if not default_model_name:
hf_client = get_hf_client(hf_token=self._get_hf_token())
selected_model_name, selected_time = None, None
for fileitem in hf_client.get_paths_info(

Check warning on line 484 in imgutils/generic/yolo.py

View check run for this annotation

Codecov / codecov/patch

imgutils/generic/yolo.py#L479-L484

Added lines #L479 - L484 were not covered by tests
repo_id=self.repo_id,
repo_type='model',
paths=[f'{model_name}/model.onnx' for model_name in model_list],
expand=True,
):
if not selected_time or fileitem.last_commit.date > selected_time:
selected_model_name = os.path.dirname(fileitem.path)
selected_time = fileitem.last_commit.date
default_model_name = selected_model_name

Check warning on line 493 in imgutils/generic/yolo.py

View check run for this annotation

Codecov / codecov/patch

imgutils/generic/yolo.py#L490-L493

Added lines #L490 - L493 were not covered by tests

def _gr_detect(image: ImageTyping, model_name: str,

Check warning on line 495 in imgutils/generic/yolo.py

View check run for this annotation

Codecov / codecov/patch

imgutils/generic/yolo.py#L495

Added line #L495 was not covered by tests
iou_threshold: float = 0.7, score_threshold: float = 0.25) \
-> gr.AnnotatedImage:
_, _, labels = self._open_model(model_name=model_name)
_colors = list(map(str, rnd_colors(len(labels))))
_color_map = dict(zip(labels, _colors))
return gr.AnnotatedImage(

Check warning on line 501 in imgutils/generic/yolo.py

View check run for this annotation

Codecov / codecov/patch

imgutils/generic/yolo.py#L498-L501

Added lines #L498 - L501 were not covered by tests
value=(image, [
(_bbox_fix(bbox), label)
for bbox, label, _ in self.predict(
image=image,
model_name=model_name,
iou_threshold=iou_threshold,
conf_threshold=score_threshold,
)
]),
color_map=_color_map,
label='Labeled',
)

with gr.Row():
with gr.Column():
gr_input_image = gr.Image(type='pil', label='Original Image')
gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model')
with gr.Row():
gr_iou_threshold = gr.Slider(0.0, 1.0, default_iou_threshold, label='IOU Threshold')
gr_score_threshold = gr.Slider(0.0, 1.0, default_conf_threshold, label='Score Threshold')

Check warning on line 521 in imgutils/generic/yolo.py

View check run for this annotation

Codecov / codecov/patch

imgutils/generic/yolo.py#L515-L521

Added lines #L515 - L521 were not covered by tests

gr_submit = gr.Button(value='Submit', variant='primary')

Check warning on line 523 in imgutils/generic/yolo.py

View check run for this annotation

Codecov / codecov/patch

imgutils/generic/yolo.py#L523

Added line #L523 was not covered by tests

with gr.Column():
gr_output_image = gr.AnnotatedImage(label="Labeled")

Check warning on line 526 in imgutils/generic/yolo.py

View check run for this annotation

Codecov / codecov/patch

imgutils/generic/yolo.py#L525-L526

Added lines #L525 - L526 were not covered by tests

gr_submit.click(

Check warning on line 528 in imgutils/generic/yolo.py

View check run for this annotation

Codecov / codecov/patch

imgutils/generic/yolo.py#L528

Added line #L528 was not covered by tests
_gr_detect,
inputs=[
gr_input_image,
gr_model,
gr_iou_threshold,
gr_score_threshold,
],
outputs=[gr_output_image],
)

def launch_demo(self, default_model_name: Optional[str] = None,
default_conf_threshold: float = 0.25, default_iou_threshold: float = 0.7,
server_name: Optional[str] = None, server_port: Optional[int] = None, **kwargs):
"""
Launch a Gradio demo for object detection.

This method creates and launches a Gradio demo that allows users to interactively
perform object detection on uploaded images using the YOLO model.

:param default_model_name: The name of the default model to use.
If None, the most recently updated model is selected.
:type default_model_name: Optional[str]
:param default_conf_threshold: Default confidence threshold for the demo. Default is 0.25.
:type default_conf_threshold: float
:param default_iou_threshold: Default IoU threshold for the demo. Default is 0.7.
:type default_iou_threshold: float
:param server_name: The name of the server to run the demo on. Default is None.
:type server_name: Optional[str]
:param server_port: The port to run the demo on. Default is None.
:type server_port: Optional[int]
:param kwargs: Additional keyword arguments to pass to gr.Blocks.launch().

:raises EnvironmentError: If Gradio is not installed in the environment.

Example:
>>> model = YOLOModel("username/repo_name")
>>> model.launch_demo(default_model_name="yolov5s", server_name="0.0.0.0", server_port=7860)
"""
_check_gradio_env()
with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
repo_url = hf_hub_repo_url(repo_id=self.repo_id, repo_type='model')
gr.HTML(f'<h2 style="text-align: center;">YOLO Demo For {self.repo_id}</h2>')
gr.Markdown(f'This is the quick demo for YOLO model [{self.repo_id}]({repo_url}). '

Check warning on line 573 in imgutils/generic/yolo.py

View check run for this annotation

Codecov / codecov/patch

imgutils/generic/yolo.py#L567-L573

Added lines #L567 - L573 were not covered by tests
f'Powered by `dghs-imgutils`\'s quick demo module.')

with gr.Row():
self.make_ui(

Check warning on line 577 in imgutils/generic/yolo.py

View check run for this annotation

Codecov / codecov/patch

imgutils/generic/yolo.py#L576-L577

Added lines #L576 - L577 were not covered by tests
default_model_name=default_model_name,
default_conf_threshold=default_conf_threshold,
default_iou_threshold=default_iou_threshold,
)

demo.launch(

Check warning on line 583 in imgutils/generic/yolo.py

View check run for this annotation

Codecov / codecov/patch

imgutils/generic/yolo.py#L583

Added line #L583 was not covered by tests
server_name=server_name,
server_port=server_port,
**kwargs,
)


@lru_cache()
def _open_models_for_repo_id(repo_id: str, hf_token: Optional[str] = None) -> YOLOModel:
Expand Down
1 change: 1 addition & 0 deletions requirements-demo.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
gradio>=4.44.0
Loading