Skip to content

Commit

Permalink
dev(narugo): optimize ui
Browse files Browse the repository at this point in the history
  • Loading branch information
narugo1992 committed Sep 17, 2024
1 parent 8a4098a commit 2764089
Showing 1 changed file with 30 additions and 8 deletions.
38 changes: 30 additions & 8 deletions imgutils/generic/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@
import numpy as np
from PIL import Image
from hbutils.color import rnd_colors
from hfutils.operate import get_hf_client
from hfutils.repository import hf_hub_repo_url
from hfutils.utils import hf_fs_path, hf_normpath
from huggingface_hub import HfFileSystem, hf_hub_download
from natsort import natsorted

from ..data import load_image, rgb_encode, ImageTyping
from ..utils import open_onnx_model
Expand Down Expand Up @@ -431,7 +432,19 @@ def make_ui(self, default_model_name: Optional[str] = None,
default_conf_threshold: float = 0.25, default_iou_threshold: float = 0.7):
_check_gradio_env()
model_list = self.model_names
default_model_name = default_model_name or natsorted(self.model_names)[-1]
if not default_model_name:
hf_client = get_hf_client(hf_token=self._get_hf_token())
selected_model_name, selected_time = None, None
for fileitem in hf_client.get_paths_info(
repo_id=self.repo_id,
repo_type='model',
paths=[f'{model_name}/model.onnx' for model_name in model_list],
expand=True,
):
if not selected_time or fileitem.last_commit.date > selected_time:
selected_model_name = os.path.dirname(fileitem.path)
selected_time = fileitem.last_commit.date
default_model_name = selected_model_name

def _gr_detect(image: ImageTyping, model_name: str,
iou_threshold: float = 0.7, score_threshold: float = 0.25) \
Expand Down Expand Up @@ -479,16 +492,25 @@ def _gr_detect(image: ImageTyping, model_name: str,

def launch_demo(self, default_model_name: Optional[str] = None,
default_conf_threshold: float = 0.25, default_iou_threshold: float = 0.7,
server_port: int = 7860, **kwargs):
server_name: Optional[str] = None, server_port: Optional[int] = None, **kwargs):
_check_gradio_env()
with gr.Blocks() as demo:
self.make_ui(
default_model_name=default_model_name,
default_conf_threshold=default_conf_threshold,
default_iou_threshold=default_iou_threshold,
)
with gr.Row():
with gr.Column():
repo_url = hf_hub_repo_url(repo_id=self.repo_id, repo_type='model')
gr.HTML(f'<h2 style="text-align: center;">YOLO Demo For {self.repo_id}</h2>')
gr.Markdown(f'This is the quick demo for YOLO model [{self.repo_id}]({repo_url}). '
f'Powered by `dghs-imgutils`\'s quick demo module.')

with gr.Row():
self.make_ui(
default_model_name=default_model_name,
default_conf_threshold=default_conf_threshold,
default_iou_threshold=default_iou_threshold,
)

demo.launch(
server_name=server_name,
server_port=server_port,
**kwargs,
)
Expand Down

0 comments on commit 2764089

Please sign in to comment.