Skip to content

Commit

Permalink
Merge pull request #89 from deepghs/dev/cdc
Browse files Browse the repository at this point in the history
dev(narugo): better cdc model enhancement layer
  • Loading branch information
narugo1992 authored May 10, 2024
2 parents f72c4f4 + cfbad99 commit 0d41d13
Show file tree
Hide file tree
Showing 14 changed files with 463 additions and 212 deletions.
31 changes: 31 additions & 0 deletions docs/source/api_doc/generic/classify.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
imgutils.generic.classify
=======================================

.. currentmodule:: imgutils.generic.classify

.. automodule:: imgutils.generic.classify



ClassifyModel
-----------------------------------------

.. autoclass:: ClassifyModel
    :members: __init__, predict_score, predict, clear



classify_predict_score
-----------------------------------------

.. autofunction:: classify_predict_score



classify_predict
-----------------------------------------

.. autofunction:: classify_predict



17 changes: 17 additions & 0 deletions docs/source/api_doc/generic/enhance.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
imgutils.generic.enhance
=======================================

.. currentmodule:: imgutils.generic.enhance

.. automodule:: imgutils.generic.enhance



ImageEnhancer
-----------------------------------------

.. autoclass:: ImageEnhancer
    :members: __init__, process



13 changes: 13 additions & 0 deletions docs/source/api_doc/generic/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
imgutils.generic
=====================

.. currentmodule:: imgutils.generic

.. automodule:: imgutils.generic


.. toctree::
    :maxdepth: 3

    classify
    enhance
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ configuration file's structure and their versions.
api_doc/data/index
api_doc/detect/index
api_doc/edge/index
api_doc/generic/index
api_doc/metrics/index
api_doc/ocr/index
api_doc/operate/index
Expand Down
5 changes: 5 additions & 0 deletions imgutils/generic/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
"""
Overview:
Generic utilities for some more features.
"""
from .classify import *
from .enhance import *
146 changes: 146 additions & 0 deletions imgutils/generic/classify.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""
Overview:
Generic tools for classification models.
"""
import json
import os
from functools import lru_cache
Expand All @@ -19,6 +23,19 @@

def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
normalize: Optional[Tuple[float, float]] = (0.5, 0.5)):
"""
Encode an image into a numpy array.
:param image: The input image.
:type image: Image.Image
:param size: The size to resize the image to, defaults to (384, 384).
:type size: Tuple[int, int], optional
:param normalize: The mean and standard deviation for normalization, defaults to (0.5, 0.5).
:type normalize: Optional[Tuple[float, float]], optional
:return: The encoded image as a numpy array.
:rtype: np.ndarray
"""
image = image.resize(size, Image.BILINEAR)
data = rgb_encode(image, order_='CHW')

Expand All @@ -32,18 +49,50 @@ def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),


class ClassifyModel:
"""
Class for managing classification models.
This class provides methods for loading classification models, predicting scores, and predictions.
Methods:
predict_score: Predicts the scores for each class.
predict: Predicts the class with the highest score.
clear: Clears the loaded models and labels.
Attributes:
None
"""

def __init__(self, repo_id: str):
"""
Initialize the ClassifyModel instance.
:param repo_id: The repository ID containing the models.
:type repo_id: str
"""
self.repo_id = repo_id
self._model_names = None
self._models = {}
self._labels = {}

@classmethod
def _get_hf_token(cls):
"""
Get the Hugging Face token from the environment variable.
:return: The Hugging Face token.
:rtype: str
"""
return os.environ.get('HF_TOKEN')

@property
def model_names(self) -> List[str]:
"""
Get the model names available in the repository.
:return: The list of model names.
:rtype: List[str]
"""
if self._model_names is None:
hf_fs = HfFileSystem(token=self._get_hf_token())
self._model_names = [
Expand All @@ -54,11 +103,28 @@ def model_names(self) -> List[str]:
return self._model_names

def _check_model_name(self, model_name: str):
"""
Check if the model name is valid.
:param model_name: The name of the model.
:type model_name: str
:raises ValueError: If the model name is invalid.
"""
if model_name not in self.model_names:
raise ValueError(f'Unknown model {model_name!r} in model repository {self.repo_id!r}, '
f'models {self.model_names!r} are available.')

def _open_model(self, model_name: str):
"""
Open the specified model.
:param model_name: The name of the model.
:type model_name: str
:return: The opened model.
:rtype: Any
"""
if model_name not in self._models:
self._check_model_name(model_name)
self._models[model_name] = open_onnx_model(hf_hub_download(
Expand All @@ -69,6 +135,15 @@ def _open_model(self, model_name: str):
return self._models[model_name]

def _open_label(self, model_name: str) -> List[str]:
"""
Open the labels file for the specified model.
:param model_name: The name of the model.
:type model_name: str
:return: The list of labels.
:rtype: List[str]
"""
if model_name not in self._labels:
self._check_model_name(model_name)
with open(hf_hub_download(
Expand All @@ -80,6 +155,17 @@ def _open_label(self, model_name: str) -> List[str]:
return self._labels[model_name]

def _raw_predict(self, image: ImageTyping, model_name: str):
"""
Make a raw prediction on the specified image using the specified model.
:param image: The input image.
:type image: ImageTyping
:param model_name: The name of the model.
:type model_name: str
:return: The raw prediction.
:rtype: np.ndarray
"""
image = load_image(image, force_background='white', mode='RGB')
model = self._open_model(model_name)
batch, channels, height, width = model.get_inputs()[0].shape
Expand All @@ -95,28 +181,88 @@ def _raw_predict(self, image: ImageTyping, model_name: str):
return output

def predict_score(self, image: ImageTyping, model_name: str) -> Dict[str, float]:
    """
    Compute one score per label for the given image.

    The first row of the raw model output is paired with the model's
    labels in order, and each value is converted with ``.item()``.

    :param image: The input image.
    :type image: ImageTyping
    :param model_name: The name of the model.
    :type model_name: str

    :return: A mapping from label to score.
    :rtype: Dict[str, float]
    """
    raw_output = self._raw_predict(image, model_name)
    labels = self._open_label(model_name)
    return {label: score.item() for label, score in zip(labels, raw_output[0])}

def predict(self, image: ImageTyping, model_name: str) -> Tuple[str, float]:
    """
    Pick the label whose score is the largest for the given image.

    Only the first row of the raw model output is considered.

    :param image: The input image.
    :type image: ImageTyping
    :param model_name: The name of the model.
    :type model_name: str

    :return: The best label together with its score.
    :rtype: Tuple[str, float]
    """
    scores = self._raw_predict(image, model_name)[0]
    best_index = np.argmax(scores)
    labels = self._open_label(model_name)
    best_label = labels[best_index]
    return best_label, scores[best_index].item()

def clear(self):
    """
    Empty the model cache and the label cache in place.

    The cached list of model names is left untouched.
    """
    for cache in (self._models, self._labels):
        cache.clear()


@lru_cache()
def _open_models_for_repo_id(repo_id: str) -> ClassifyModel:
    """
    Build the :class:`ClassifyModel` for a repository.

    The result is memoized by ``lru_cache``, so calls with the same
    ``repo_id`` get the cached instance back while it stays in the cache.

    :param repo_id: The repository ID containing the models.
    :type repo_id: str

    :return: The ClassifyModel instance for the repository.
    :rtype: ClassifyModel
    """
    manager = ClassifyModel(repo_id)
    return manager


def classify_predict_score(image: ImageTyping, repo_id: str, model_name: str) -> Dict[str, float]:
    """
    Score an image against every label of a model from a repository.

    Delegates to ``predict_score`` on the cached :class:`ClassifyModel`
    for ``repo_id``.

    :param image: The input image.
    :type image: ImageTyping
    :param repo_id: The repository ID containing the models.
    :type repo_id: str
    :param model_name: The name of the model.
    :type model_name: str

    :return: A mapping from label to score.
    :rtype: Dict[str, float]
    """
    manager = _open_models_for_repo_id(repo_id)
    return manager.predict_score(image, model_name)


def classify_predict(image: ImageTyping, repo_id: str, model_name: str) -> Tuple[str, float]:
    """
    Find the best-scoring label for an image using a model from a repository.

    Delegates to ``predict`` on the cached :class:`ClassifyModel` for
    ``repo_id``.

    :param image: The input image.
    :type image: ImageTyping
    :param repo_id: The repository ID containing the models.
    :type repo_id: str
    :param model_name: The name of the model.
    :type model_name: str

    :return: The best label together with its score.
    :rtype: Tuple[str, float]
    """
    manager = _open_models_for_repo_id(repo_id)
    return manager.predict(image, model_name)
Loading

0 comments on commit 0d41d13

Please sign in to comment.