From 8a4098af1cc477c745d30e272b99bd8313168bc6 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Wed, 18 Sep 2024 01:23:03 +0800 Subject: [PATCH 1/5] dev(narugo): add gradio demo for yolo --- imgutils/generic/yolo.py | 87 ++++++++++++++++++++++++++++++++++++++++ requirements-demo.txt | 1 + 2 files changed, 88 insertions(+) create mode 100644 requirements-demo.txt diff --git a/imgutils/generic/yolo.py b/imgutils/generic/yolo.py index 3d51f5eca6f..11c684506bc 100644 --- a/imgutils/generic/yolo.py +++ b/imgutils/generic/yolo.py @@ -20,18 +20,39 @@ import numpy as np from PIL import Image +from hbutils.color import rnd_colors from hfutils.utils import hf_fs_path, hf_normpath from huggingface_hub import HfFileSystem, hf_hub_download +from natsort import natsorted from ..data import load_image, rgb_encode, ImageTyping from ..utils import open_onnx_model +try: + import gradio as gr +except (ImportError, ModuleNotFoundError): + gr = None + __all__ = [ 'YOLOModel', 'yolo_predict', ] +def _check_gradio_env(): + if gr is None: + raise EnvironmentError(f'Gradio required for launching webui-based demo.\n' + f'Please install it with `pip install dghs-imgutils[demo]`.') + + +def _v_fix(v): + return int(round(v)) + + +def _bbox_fix(bbox): + return tuple(map(_v_fix, bbox)) + + def _yolo_xywh2xyxy(x: np.ndarray) -> np.ndarray: """ Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format. @@ -406,6 +427,72 @@ def clear(self): """ self._models.clear() + def make_ui(self, default_model_name: Optional[str] = None, + default_conf_threshold: float = 0.25, default_iou_threshold: float = 0.7): + _check_gradio_env() + model_list = self.model_names + default_model_name = default_model_name or natsorted(self.model_names)[-1] + + def _gr_detect(image: ImageTyping, model_name: str, + iou_threshold: float = 0.7, score_threshold: float = 0.25) \ + -> gr.AnnotatedImage: + _, _, labels = self._open_model(model_name=model_name) + _colors = list(map(str, rnd_colors(len(labels)))) + _color_map = dict(zip(labels, _colors)) + return gr.AnnotatedImage( + value=(image, [ + (_bbox_fix(bbox), label) + for bbox, label, _ in self.predict( + image=image, + model_name=model_name, + iou_threshold=iou_threshold, + conf_threshold=score_threshold, + ) + ]), + color_map=_color_map, + label='Labeled', + ) + + with gr.Row(): + with gr.Column(): + gr_input_image = gr.Image(type='pil', label='Original Image') + gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model') + with gr.Row(): + gr_iou_threshold = gr.Slider(0.0, 1.0, default_iou_threshold, label='IOU Threshold') + gr_score_threshold = gr.Slider(0.0, 1.0, default_conf_threshold, label='Score Threshold') + + gr_submit = gr.Button(value='Submit', variant='primary') + + with gr.Column(): + gr_output_image = gr.AnnotatedImage(label="Labeled") + + gr_submit.click( + _gr_detect, + inputs=[ + gr_input_image, + gr_model, + gr_iou_threshold, + gr_score_threshold, + ], + outputs=[gr_output_image], + ) + + def launch_demo(self, default_model_name: Optional[str] = None, + default_conf_threshold: float = 0.25, default_iou_threshold: float = 0.7, + server_port: int = 7860, **kwargs): + _check_gradio_env() + with gr.Blocks() as demo: + self.make_ui( + default_model_name=default_model_name, + default_conf_threshold=default_conf_threshold, + default_iou_threshold=default_iou_threshold, + ) + + demo.launch( + server_port=server_port, + **kwargs, + ) + @lru_cache() def _open_models_for_repo_id(repo_id: str, hf_token: Optional[str] = None) -> YOLOModel: diff --git a/requirements-demo.txt b/requirements-demo.txt new file mode 100644 index 00000000000..f85fcdfb026 --- /dev/null +++ b/requirements-demo.txt @@ -0,0 +1 @@ +gradio>=4.44.0 \ No newline at end of file From 27640891d3a8bf607e77c4fcf97508d398f22f5e Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Wed, 18 Sep 2024 01:38:09 +0800 Subject: [PATCH 2/5] dev(narugo): optimize ui --- imgutils/generic/yolo.py | 38 ++++++++++++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/imgutils/generic/yolo.py b/imgutils/generic/yolo.py index 11c684506bc..12392d88dcd 100644 --- a/imgutils/generic/yolo.py +++ b/imgutils/generic/yolo.py @@ -21,9 +21,10 @@ import numpy as np from PIL import Image from hbutils.color import rnd_colors +from hfutils.operate import get_hf_client +from hfutils.repository import hf_hub_repo_url from hfutils.utils import hf_fs_path, hf_normpath from huggingface_hub import HfFileSystem, hf_hub_download -from natsort import natsorted from ..data import load_image, rgb_encode, ImageTyping from ..utils import open_onnx_model @@ -431,7 +432,19 @@ def make_ui(self, default_model_name: Optional[str] = None, default_conf_threshold: float = 0.25, default_iou_threshold: float = 0.7): _check_gradio_env() model_list = self.model_names - default_model_name = default_model_name or natsorted(self.model_names)[-1] + if not default_model_name: + hf_client = get_hf_client(hf_token=self._get_hf_token()) + selected_model_name, selected_time = None, None + for fileitem in hf_client.get_paths_info( + repo_id=self.repo_id, + repo_type='model', + paths=[f'{model_name}/model.onnx' for model_name in model_list], + expand=True, + ): + if not selected_time or fileitem.last_commit.date > selected_time: + selected_model_name = os.path.dirname(fileitem.path) + selected_time = fileitem.last_commit.date + default_model_name = selected_model_name def _gr_detect(image: ImageTyping, model_name: str, iou_threshold: float = 0.7, score_threshold: float = 0.25) \ @@ -479,16 +492,25 @@ def _gr_detect(image: ImageTyping, model_name: str, def launch_demo(self, default_model_name: Optional[str] = None, default_conf_threshold: float = 0.25, default_iou_threshold: float = 0.7, - server_port: int = 7860, **kwargs): + server_name: Optional[str] = None, server_port: Optional[int] = None, **kwargs): _check_gradio_env() with gr.Blocks() as demo: - self.make_ui( - default_model_name=default_model_name, - default_conf_threshold=default_conf_threshold, - default_iou_threshold=default_iou_threshold, - ) + with gr.Row(): + with gr.Column(): + repo_url = hf_hub_repo_url(repo_id=self.repo_id, repo_type='model') + gr.HTML(f'

YOLO Demo For {self.repo_id}

') + gr.Markdown(f'This is the quick demo for YOLO model [{self.repo_id}]({repo_url}). ' + f'Powered by `dghs-imgutils`\'s quick demo module.') + + with gr.Row(): + self.make_ui( + default_model_name=default_model_name, + default_conf_threshold=default_conf_threshold, + default_iou_threshold=default_iou_threshold, + ) demo.launch( + server_name=server_name, server_port=server_port, **kwargs, ) From 31a1d232a3d0e61b86a6d8341b8749384a887bdc Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Wed, 18 Sep 2024 19:27:09 +0800 Subject: [PATCH 3/5] dev(narugo): add docs --- docs/source/api_doc/generic/yolo.rst | 2 +- imgutils/generic/yolo.py | 71 ++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 1 deletion(-) diff --git a/docs/source/api_doc/generic/yolo.rst b/docs/source/api_doc/generic/yolo.rst index cd8882ab120..0097cfa7066 100644 --- a/docs/source/api_doc/generic/yolo.rst +++ b/docs/source/api_doc/generic/yolo.rst @@ -11,7 +11,7 @@ YOLOModel ----------------------------------------- .. autoclass:: YOLOModel - :members: __init__, predict, clear + :members: __init__, predict, clear, make_ui, launch_demo diff --git a/imgutils/generic/yolo.py b/imgutils/generic/yolo.py index 12392d88dcd..af0c5a34479 100644 --- a/imgutils/generic/yolo.py +++ b/imgutils/generic/yolo.py @@ -41,16 +41,37 @@ def _check_gradio_env(): + """ + Check if the Gradio library is installed and available. + + :raises EnvironmentError: If Gradio is not installed. + """ if gr is None: raise EnvironmentError(f'Gradio required for launching webui-based demo.\n' f'Please install it with `pip install dghs-imgutils[demo]`.') def _v_fix(v): + """ + Round and convert a float value to an integer. + + :param v: The float value to be rounded and converted. + :type v: float + :return: The rounded integer value. + :rtype: int + """ return int(round(v)) def _bbox_fix(bbox): + """ + Fix the bounding box coordinates by rounding them to integers. + + :param bbox: The bounding box coordinates. + :type bbox: tuple + :return: A tuple of fixed (rounded to integer) bounding box coordinates. + :rtype: tuple + """ return tuple(map(_v_fix, bbox)) @@ -425,11 +446,36 @@ def predict(self, image: ImageTyping, model_name: str, def clear(self): """ Clear cached model and metadata. + + This method removes all cached models and their associated metadata from memory. + It's useful for freeing up memory or ensuring that the latest versions of models are loaded. """ self._models.clear() def make_ui(self, default_model_name: Optional[str] = None, default_conf_threshold: float = 0.25, default_iou_threshold: float = 0.7): + """ + Create a Gradio-based user interface for object detection. + + This method sets up an interactive UI that allows users to upload images, + select models, and adjust detection parameters. It uses the Gradio library + to create the interface. + + :param default_model_name: The name of the default model to use. + If None, the most recently updated model is selected. + :type default_model_name: Optional[str] + :param default_conf_threshold: Default confidence threshold for the UI. Default is 0.25. + :type default_conf_threshold: float + :param default_iou_threshold: Default IoU threshold for the UI. Default is 0.7. + :type default_iou_threshold: float + + :raises ImportError: If Gradio is not installed in the environment. + + :Example: + + >>> model = YOLOModel("username/repo_name") + >>> model.make_ui(default_model_name="yolov5s") + """ _check_gradio_env() model_list = self.model_names if not default_model_name: @@ -493,6 +539,31 @@ def _gr_detect(image: ImageTyping, model_name: str, def launch_demo(self, default_model_name: Optional[str] = None, default_conf_threshold: float = 0.25, default_iou_threshold: float = 0.7, server_name: Optional[str] = None, server_port: Optional[int] = None, **kwargs): + """ + Launch a Gradio demo for object detection. + + This method creates and launches a Gradio demo that allows users to interactively + perform object detection on uploaded images using the YOLO model. + + :param default_model_name: The name of the default model to use. + If None, the most recently updated model is selected. + :type default_model_name: Optional[str] + :param default_conf_threshold: Default confidence threshold for the demo. Default is 0.25. + :type default_conf_threshold: float + :param default_iou_threshold: Default IoU threshold for the demo. Default is 0.7. + :type default_iou_threshold: float + :param server_name: The name of the server to run the demo on. Default is None. + :type server_name: Optional[str] + :param server_port: The port to run the demo on. Default is None. + :type server_port: Optional[int] + :param kwargs: Additional keyword arguments to pass to gr.Blocks.launch(). + + :raises EnvironmentError: If Gradio is not installed in the environment. + + Example: + >>> model = YOLOModel("username/repo_name") + >>> model.launch_demo(default_model_name="yolov5s", server_name="0.0.0.0", server_port=7860) + """ _check_gradio_env() with gr.Blocks() as demo: with gr.Row(): From 50cd29173b7e81b8c50cff512d64717cf8912a68 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Wed, 18 Sep 2024 19:37:46 +0800 Subject: [PATCH 4/5] dev(narugo): add gradio demo for classifiers --- docs/source/api_doc/generic/classify.rst | 2 +- imgutils/generic/classify.py | 73 ++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 1 deletion(-) diff --git a/docs/source/api_doc/generic/classify.rst b/docs/source/api_doc/generic/classify.rst index b3aeafd8114..d99ecd81087 100644 --- a/docs/source/api_doc/generic/classify.rst +++ b/docs/source/api_doc/generic/classify.rst @@ -11,7 +11,7 @@ ClassifyModel ----------------------------------------- .. autoclass:: ClassifyModel - :members: __init__, predict_score, predict, clear + :members: __init__, predict_score, predict, clear, make_ui, launch_demo diff --git a/imgutils/generic/classify.py b/imgutils/generic/classify.py index 57f73784341..432c2cc2253 100644 --- a/imgutils/generic/classify.py +++ b/imgutils/generic/classify.py @@ -23,12 +23,19 @@ import numpy as np from PIL import Image +from hfutils.operate import get_hf_client +from hfutils.repository import hf_hub_repo_url from hfutils.utils import hf_fs_path, hf_normpath from huggingface_hub import hf_hub_download, HfFileSystem from ..data import rgb_encode, ImageTyping, load_image from ..utils import open_onnx_model +try: + import gradio as gr +except (ImportError, ModuleNotFoundError): + gr = None + __all__ = [ 'ClassifyModel', 'classify_predict_score', @@ -36,6 +43,17 @@ ] +def _check_gradio_env(): + """ + Check if the Gradio library is installed and available. + + :raises EnvironmentError: If Gradio is not installed. + """ + if gr is None: + raise EnvironmentError(f'Gradio required for launching webui-based demo.\n' + f'Please install it with `pip install dghs-imgutils[demo]`.') + + def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384), normalize: Optional[Tuple[float, float]] = (0.5, 0.5)): """ @@ -287,6 +305,61 @@ def clear(self): self._models.clear() self._labels.clear() + def make_ui(self, default_model_name: Optional[str] = None): + _check_gradio_env() + model_list = self.model_names + if not default_model_name: + hf_client = get_hf_client(hf_token=self._get_hf_token()) + selected_model_name, selected_time = None, None + for fileitem in hf_client.get_paths_info( + repo_id=self.repo_id, + repo_type='model', + paths=[f'{model_name}/model.onnx' for model_name in model_list], + expand=True, + ): + if not selected_time or fileitem.last_commit.date > selected_time: + selected_model_name = os.path.dirname(fileitem.path) + selected_time = fileitem.last_commit.date + default_model_name = selected_model_name + + with gr.Row(): + with gr.Column(): + gr_input_image = gr.Image(type='pil', label='Original Image') + gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model') + gr_submit = gr.Button(value='Submit', variant='primary') + + with gr.Column(): + gr_output = gr.Label(label='Prediction') + + gr_submit.click( + self.predict_score, + inputs=[ + gr_input_image, + gr_model, + ], + outputs=[gr_output], + ) + + def launch_demo(self, default_model_name: Optional[str] = None, + server_name: Optional[str] = None, server_port: Optional[int] = None, **kwargs): + _check_gradio_env() + with gr.Blocks() as demo: + with gr.Row(): + with gr.Column(): + repo_url = hf_hub_repo_url(repo_id=self.repo_id, repo_type='model') + gr.HTML(f'

Classifier Demo For {self.repo_id}

') + gr.Markdown(f'This is the quick demo for classifier model [{self.repo_id}]({repo_url}). ' + f'Powered by `dghs-imgutils`\'s quick demo module.') + + with gr.Row(): + self.make_ui(default_model_name=default_model_name) + + demo.launch( + server_name=server_name, + server_port=server_port, + **kwargs, + ) + @lru_cache() def _open_models_for_repo_id(repo_id: str, hf_token: Optional[str] = None) -> ClassifyModel: From dbdba3521ad112087b24ab268fb18e28d10ce0ab Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Wed, 18 Sep 2024 19:42:00 +0800 Subject: [PATCH 5/5] dev(narugo): add docs for demo --- imgutils/generic/classify.py | 39 ++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/imgutils/generic/classify.py b/imgutils/generic/classify.py index 432c2cc2253..3b3a04485d4 100644 --- a/imgutils/generic/classify.py +++ b/imgutils/generic/classify.py @@ -306,6 +306,24 @@ def clear(self): self._labels.clear() def make_ui(self, default_model_name: Optional[str] = None): + """ + Create the user interface components for the classifier model demo. + + This method sets up the Gradio UI components including an image input, model selection dropdown, + submit button, and output label. It also configures the interaction between these components. + + :param default_model_name: The name of the default model to be selected in the dropdown. + If None, the most recently updated model will be selected. + :type default_model_name: Optional[str] + + :raises ImportError: If Gradio is not installed or properly configured. + + :Example: + >>> model = ClassifyModel("username/repo_name") + >>> model.make_ui(default_model_name="model_v1") + """ + + # demo for classifier model _check_gradio_env() model_list = self.model_names if not default_model_name: @@ -342,6 +360,27 @@ def make_ui(self, default_model_name: Optional[str] = None): def launch_demo(self, default_model_name: Optional[str] = None, server_name: Optional[str] = None, server_port: Optional[int] = None, **kwargs): + """ + Launch the Gradio demo for the classifier model. + + This method creates a Gradio Blocks interface, sets up the UI components using make_ui(), + and launches the demo server. + + :param default_model_name: The name of the default model to be selected in the dropdown. + :type default_model_name: Optional[str] + :param server_name: The name of the server to run the demo on. Defaults to None. + :type server_name: Optional[str] + :param server_port: The port number to run the demo on. Defaults to None. + :type server_port: Optional[int] + :param kwargs: Additional keyword arguments to pass to the Gradio launch method. + + :raises ImportError: If Gradio is not installed or properly configured. + + :Example: + >>> model = ClassifyModel("username/repo_name") + >>> model.launch_demo(default_model_name="model_v1", server_name="0.0.0.0", server_port=7860) + """ + _check_gradio_env() with gr.Blocks() as demo: with gr.Row():