Skip to content

Commit

Permalink
nit
Browse files Browse the repository at this point in the history
  • Loading branch information
nikp1172 committed Feb 19, 2024
1 parent 41130fc commit f0696a4
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
3 changes: 2 additions & 1 deletion mnist-classifaction/deploy_model/fastapi_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pydantic import BaseModel
from predict import predict_fn, load_model
import tensorflow as tf
import numpy as np
import os

model_path = os.path.join(os.environ.get("MODEL_DOWNLOAD_PATH", "."), "mnist_model.h5")
Expand All @@ -14,7 +15,7 @@ class ImageUrl(BaseModel):
url: str = "https://conx.readthedocs.io/en/latest/_images/MNIST_6_0.png"


def load_image(img_url: str):
def load_image(img_url: str) -> np.ndarray:
img_path = tf.keras.utils.get_file("image.jpg", img_url)
img = tf.keras.preprocessing.image.load_img(img_path, target_size=(28, 28))
img_arr = tf.keras.preprocessing.image.img_to_array(img)
Expand Down
4 changes: 2 additions & 2 deletions mnist-classifaction/deploy_model/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
import numpy as np


def load_model(model_path: str):
def load_model(model_path: str) -> tf.keras.Model:
# Load the trained model
model = tf.keras.models.load_model(model_path)
return model


def predict_fn(model, img_arr: np.ndarray):
def predict_fn(model, img_arr: np.ndarray) -> str:
# Preprocess the image before passing it to the model
img_arr = tf.expand_dims(img_arr, 0)
img_arr = img_arr[:, :, :, 0] # Keep only the first channel (grayscale)
Expand Down

0 comments on commit f0696a4

Please sign in to comment.