Skip to content

Commit

Permalink
Merge pull request #7 from hiseulgi/6-train-deep-learning-model-and-d…
Browse files Browse the repository at this point in the history
…eployment

Train Deep Learning Model and Deployment
  • Loading branch information
hiseulgi authored Dec 28, 2023
2 parents 11389ca + 34529bd commit 3122314
Show file tree
Hide file tree
Showing 31 changed files with 513 additions and 72 deletions.
2 changes: 1 addition & 1 deletion .env.example
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
API_PORT=6969
WEB_PORT=8051
WEB_PORT=8501
13 changes: 11 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
# Medical Leaf Image Classification

![Web App Demo](asset/00_web_app_demo.gif)

Unofficial implementation of [*Mengenali Jenis Tanaman Obat Berbasis Pola Citra Daun Dengan Algoritma K-Nearest Neighbors*](https://ejournal.unesa.ac.id/index.php/jinacs/article/download/42746/36728) (**Recognizing Types of Medicinal Plants Based on Leaf Image Patterns with K-Nearest Neighbors Algorithm**). This project is a part of my final project in Image Processing course.

![Web App Architecture](asset/05_web_app_arch.jpg)

## Dataset

The dataset used in this project is [Medical Leaf Image Dataset](https://data.mendeley.com/datasets/3f83gxmv57/1).
Expand Down Expand Up @@ -94,12 +98,17 @@ bash scripts/run_docker.sh
4. Open and test the service at API docs `http://localhost:6969/`
![API Docs Swagger UI](asset/02_fastapi_docs.png)

5. Open and test the service at Web App `http://localhost:8051/`
5. Open and test the service at Web App `http://localhost:8501/`
![Streamlit Web App](asset/03_web_app.png)

## Extra (Deep Learning Model)

According to KNN and other machine learning model result, I think the problem is in the dataset. So, I tried to train with deep learning model. I used MobileNetV3 as the base model and trained with transfer learning. The result was better than KNN and other machine learning model.
![Deep Learning Model Result](asset/04_mobilenetv3_result.png)
Here the training notebook: [Medical Leaf Image Classification (Deep Learning)](https://colab.research.google.com/drive/1-YK-djfIu3LtHOH6UiUHG7oScG-BzU0h?usp=sharing)

## Future Works

* [x] Deployment API
* [x] Web App Deployment (Streamlit / Gradio)
* [ ] Train with Deep Learning
* [x] Train with Deep Learning
Binary file added asset/00_web_app_demo.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added asset/04_mobilenetv3_result.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added asset/05_web_app_arch.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
64 changes: 32 additions & 32 deletions src/api/core/knn_core.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import json
import pickle
from io import BytesIO
from typing import Dict, List
from pathlib import Path
from typing import Dict

import numpy as np
import rootutils
from PIL import Image
from skimage.color import rgb2gray
from skimage.filters import median, threshold_otsu
from skimage.measure import label, regionprops_table
Expand All @@ -27,35 +26,39 @@


class KnnCore:
def __init__(self):
def __init__(
self,
model_path: str = str(
ROOT / "src" / "api" / "static" / "model" / "knn_model.pkl"
),
scaler_path: str = str(
ROOT / "src" / "api" / "static" / "model" / "scaler.pkl"
),
class_mapping_path: str = str(
ROOT / "src" / "api" / "static" / "class_mapping.json"
),
) -> None:
"""Initialize KNN Core"""

self.model_path = Path(model_path)
self.scaler_path = Path(scaler_path)
self.class_mapping_path = Path(class_mapping_path)
self._setup()

def _setup(self):
def _setup(self) -> None:
"""Setup KNN Core"""
# load scaler
with open(ROOT / "src" / "api" / "static" / "model" / "scaler.pkl", "rb") as f:
with open(self.scaler_path, "rb") as f:
self.scaler: MinMaxScaler = pickle.load(f)

# load model
with open(
ROOT / "src" / "api" / "static" / "model" / "knn_model.pkl", "rb"
) as f:
with open(self.model_path, "rb") as f:
self.model: KNeighborsClassifier = pickle.load(f)

# load class mapping from json
with open(ROOT / "src" / "api" / "static" / "class_mapping.json", "r") as f:
with open(self.class_mapping_path, "r") as f:
self.class_mapping: Dict[str, str] = json.load(f)

async def preprocess_img_bytes(self, img_bytes: bytes) -> np.ndarray:
"""Preprocess image bytes."""
img = Image.open(BytesIO(img_bytes))
img = np.array(img)
# if PNG, convert to RGB
if img.shape[-1] == 4:
img = img[..., :3]

return img

async def preprocess_img_knn(self, img_np: np.ndarray) -> np.ndarray:
"""Preprocess image for KNN."""
# read as grayscale
Expand Down Expand Up @@ -92,12 +95,9 @@ async def feature_extraction_knn(self, img_np: np.ndarray) -> np.ndarray:
props_np = np.array(list(props.values())).reshape(1, -1)
return props_np

async def predict(self, img_bytes: bytes) -> List[PredictionsResultSchema]:
async def predict(self, img_np: np.ndarray) -> PredictionsResultSchema:
"""Predict using KNN model."""

# read image as numpy array
img_np = await self.preprocess_img_bytes(img_bytes)

# preprocess image
img_np = await self.preprocess_img_knn(img_np)

Expand All @@ -111,16 +111,16 @@ async def predict(self, img_bytes: bytes) -> List[PredictionsResultSchema]:
predictions = self.model.predict_proba(features.reshape(1, -1))

# get result
result: List[PredictionsResultSchema] = []
labels = []
scores = []
top_5_pred = np.argsort(predictions, axis=1)[0, -5:][::-1]

for idx in top_5_pred:
result.append(
PredictionsResultSchema(
label=self.class_mapping[str(idx)],
score=predictions[0, idx],
)
)
for i in top_5_pred:
labels.append(self.class_mapping[str(i)])
scores.append(float(predictions[0, i]))

result = PredictionsResultSchema(labels=labels, scores=scores)

log.info(f"Predictions: {result}")

return result
142 changes: 142 additions & 0 deletions src/api/core/mobilenet_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import json
from typing import List, Union

import cv2
import numpy as np
import rootutils

from src.api.core.onnx_core import OnnxCore
from src.api.schema.predictions_schema import PredictionsResultSchema
from src.api.utils.logger import get_logger

ROOT = rootutils.setup_root(
search_from=__file__,
indicator=[".project-root"],
pythonpath=True,
dotenv=True,
)

log = get_logger()


class MobilenetCore(OnnxCore):
"""Mobilenet Core runtime engine module"""

def __init__(
self,
engine_path: str = str(ROOT / "src/api/static/model/mobilenetv3_best.onnx"),
class_path: str = str(ROOT / "src/api/static/class_mapping.json"),
provider: str = "cpu",
) -> None:
"""
Initialize Mobilenet Core runtime engine module.
Args:
engine_path (str): Path to ONNX runtime engine file.
class_path (str): Path to class mapping json file.
provider (str): Provider for ONNX runtime engine.
"""
super().__init__(engine_path, provider)
self.class_path = class_path
self._open_class_mapping()

def _open_class_mapping(self) -> None:
"""Open class mapping json file."""
with open(self.class_path, "r") as f:
self.class_mapping = json.load(f)

def predict(
self, imgs: Union[np.ndarray, List[np.ndarray]]
) -> List[PredictionsResultSchema]:
"""
Classify image(s) (batch) and return top 5 predictions.
Args:
imgs (np.ndarray): Input image.
Returns:
List[PredictionsResultSchema]: List of predictions result, in size (Batch, 5).
"""
if isinstance(imgs, np.ndarray):
imgs = [imgs]

imgs = self.preprocess_imgs(imgs)
outputs = self.engine.run(None, {self.metadata[0].input_name: imgs})
outputs = self.postprocess_imgs(outputs)
return outputs

def preprocess_imgs(
self,
imgs: Union[np.ndarray, List[np.ndarray]],
normalize: bool = False,
) -> np.ndarray:
"""
Preprocess image(s) (batch) like resize and normalize.
Args:
imgs (Union[np.ndarray, List[np.ndarray]]): Image(s) to preprocess.
normalize (bool, optional): Whether to normalize image(s). Defaults to True.
Returns:
np.ndarray: Preprocessed image(s) in size (B, C, H, W).
"""
if isinstance(imgs, np.ndarray):
imgs = [imgs]

# resize images
dst_h, dst_w = self.img_shape
resized_imgs = np.zeros((len(imgs), dst_h, dst_w, 3), dtype=np.float32)

for i, img in enumerate(imgs):
# resize img to 224x224 (according to model input)
img = cv2.resize(img, dsize=(dst_h, dst_w), interpolation=cv2.INTER_CUBIC)
resized_imgs[i] = img

# normalize images
# resized_imgs = resized_imgs.transpose(0, 3, 1, 2)
resized_imgs /= 255.0 if normalize else 1.0

return resized_imgs

def postprocess_imgs(
self, outputs: List[np.ndarray]
) -> List[PredictionsResultSchema]:
"""
Postprocess model output(s) into top 5 predictions probability.
Args:
outputs (List[np.ndarray]): Model output(s) (batch), in size (Batch, Class).
Returns:
List[PredictionsResultSchema]: List of predictions result, in size (Batch, 5).
"""
results: List[PredictionsResultSchema] = []
for output in outputs:
softmax_output = self.softmax(output[0])

labels = []
scores = []
top_5_pred = np.argsort(softmax_output)[::-1][:5]

for i in top_5_pred:
labels.append(self.class_mapping[str(i)])
scores.append(float(softmax_output[i]))

results.append(PredictionsResultSchema(labels=labels, scores=scores))

log.info(f"Predictions: {results}")

return results

def softmax(self, x: np.ndarray) -> np.ndarray:
"""
Compute softmax values for each sets of scores in x.
Args:
x (np.ndarray): Input logits.
Returns:
np.ndarray: Softmax calculation result.
"""
e_x = np.exp(x - np.max(x))
return e_x / e_x.sum(axis=0)
91 changes: 91 additions & 0 deletions src/api/core/onnx_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from pathlib import Path
from typing import List, Union

import onnxruntime as ort
import rootutils

from src.api.schema.onnx_schema import OnnxMetadataSchema
from src.api.utils.logger import get_logger

ROOT = rootutils.setup_root(
search_from=__file__,
indicator=[".project-root"],
pythonpath=True,
dotenv=True,
)

log = get_logger()


class OnnxCore:
"""Common ONNX runtime engine module."""

def __init__(self, engine_path: str, provider: str = "cpu") -> None:
"""
Initialize ONNX runtime common engine.
Args:
engine_path (str): Path to ONNX runtime engine file.
provider (str): Provider for ONNX runtime engine.
"""
self.engine_path = Path(engine_path)
self.provider = provider
self.provider = self.check_providers(provider)

def setup(self) -> None:
"""Setup ONNX runtime engine."""
log.info(f"Setup ONNX engine")
self.engine = ort.InferenceSession(
str(self.engine_path), providers=self.provider
)
self.metadata = self.get_metadata()

# img_shape tergantung pada file onnx-nya (lihat di netron)
self.img_shape = self.metadata[0].input_shape[1:3]

log.info(f"ONNX engine is ready!")

def get_metadata(self) -> List[OnnxMetadataSchema]:
"""
Get model metadata.
Returns:
List[OnnxMetadataSchema]: List of model metadata.
"""
inputs = self.engine.get_inputs()
outputs = self.engine.get_outputs()

result: List[OnnxMetadataSchema] = []
for inp, out in zip(inputs, outputs):
result.append(
OnnxMetadataSchema(
input_name=inp.name,
input_shape=inp.shape,
output_name=out.name,
output_shape=out.shape,
)
)

return result

def check_providers(self, provider: Union[str, List]) -> List:
"""
Check available providers. If provider is not available, use CPU instead.
Args:
provider (Union[str, List]): Provider for ONNX runtime engine.
Returns:
List: List of available providers.
"""
assert provider in ["cpu", "gpu"], "Invalid provider"
available_providers = ort.get_available_providers()
log.debug(f"Available providers: {available_providers}")
if provider == "cpu" and "OpenVINOExecutionProvider" in available_providers:
provider = ["CPUExecutionProvider", "OpenVINOExecutionProvider"]
elif provider == "gpu" and "CUDAExecutionProvider" in available_providers:
provider = ["CUDAExecutionProvider"]
else:
provider = ["CPUExecutionProvider"]

return provider
2 changes: 2 additions & 0 deletions src/api/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
colorlog==6.8.0
fastapi==0.105.0
numpy==1.26.2
onnxruntime==1.15.1
opencv-python-headless==4.8.*
pandas==2.1.4
Pillow==10.0.1
pydantic==1.10.13
Expand Down
Loading

0 comments on commit 3122314

Please sign in to comment.