-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #19 from truefoundry/np-add-service-mnist
add service deployment mnist
- Loading branch information
Showing
11 changed files
with
494 additions
and
35 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import argparse | ||
import logging | ||
from servicefoundry import ( | ||
Build, | ||
PythonBuild, | ||
Service, | ||
Resources, | ||
Port, | ||
ArtifactsDownload, | ||
TruefoundryArtifactSource, | ||
) | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--workspace_fqn", type=str, required=True) | ||
parser.add_argument("--model_version_fqn", type=str, required=True) | ||
parser.add_argument("--host", type=str, required=True) | ||
parser.add_argument("--path", type=str, required=False) | ||
args = parser.parse_args() | ||
|
||
service = Service( | ||
name="mnist-classification-svc", | ||
image=Build( | ||
build_spec=PythonBuild( | ||
command="python gradio_demo.py", | ||
requirements_path="requirements.txt", | ||
) | ||
), | ||
ports=[Port(port=8000, host=args.host, path=args.path)], | ||
resources=Resources( | ||
memory_limit=500, | ||
memory_request=500, | ||
ephemeral_storage_limit=600, | ||
ephemeral_storage_request=600, | ||
cpu_limit=0.3, | ||
cpu_request=0.3, | ||
), | ||
artifacts_download=ArtifactsDownload( | ||
artifacts=[ | ||
TruefoundryArtifactSource( | ||
artifact_version_fqn=args.model_version_fqn, | ||
download_path_env_variable="MODEL_DOWNLOAD_PATH", | ||
) | ||
] | ||
), | ||
) | ||
service.deploy(workspace_fqn=args.workspace_fqn) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from fastapi import FastAPI | ||
from pydantic import BaseModel | ||
from predict import predict_fn, load_model | ||
import tensorflow as tf | ||
import numpy as np | ||
import os | ||
|
||
model_path = os.path.join(os.environ.get("MODEL_DOWNLOAD_PATH", "."), "mnist_model.h5") | ||
model = load_model(model_path) | ||
|
||
app = FastAPI(docs_url="/", root_path=os.getenv("TFY_SERVICE_ROOT_PATH", "/")) | ||
|
||
|
||
class ImageUrl(BaseModel): | ||
url: str = "https://conx.readthedocs.io/en/latest/_images/MNIST_6_0.png" | ||
|
||
|
||
def load_image(img_url: str) -> np.ndarray: | ||
img_path = tf.keras.utils.get_file("image.jpg", img_url) | ||
img = tf.keras.preprocessing.image.load_img(img_path, target_size=(28, 28)) | ||
img_arr = tf.keras.preprocessing.image.img_to_array(img) | ||
return img_arr | ||
|
||
|
||
app = FastAPI() | ||
|
||
|
||
@app.post("/predict") | ||
async def predict(body: ImageUrl): | ||
img_arr = load_image(body.url) | ||
prediction = predict_fn(model, img_arr) | ||
return {"prediction": prediction} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from predict import predict_fn, load_model | ||
import os | ||
import gradio as gr | ||
|
||
model_path = os.path.join(os.environ.get("MODEL_DOWNLOAD_PATH", "."), "mnist_model.h5") | ||
model = load_model(model_path) | ||
|
||
|
||
def get_inference(img_arr): | ||
return predict_fn(model, img_arr) | ||
|
||
|
||
interface = gr.Interface( | ||
fn=get_inference, | ||
inputs="image", | ||
outputs="label", | ||
examples=[["sample_images/0.jpg"], ["sample_images/1.jpg"]], | ||
) | ||
|
||
interface.launch(server_name="0.0.0.0", server_port=8000) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import tensorflow as tf | ||
import numpy as np | ||
|
||
|
||
def load_model(model_path: str) -> tf.keras.Model: | ||
# Load the trained model | ||
model = tf.keras.models.load_model(model_path) | ||
return model | ||
|
||
|
||
def predict_fn(model, img_arr: np.ndarray) -> str: | ||
# Preprocess the image before passing it to the model | ||
img_arr = tf.expand_dims(img_arr, 0) | ||
img_arr = img_arr[:, :, :, 0] # Keep only the first channel (grayscale) | ||
|
||
# Make predictions | ||
predictions = model.predict(img_arr) | ||
predicted_label = tf.argmax(predictions[0]).numpy() | ||
|
||
return str(predicted_label) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
tensorflow==2.15.0 | ||
gradio==3.39.0 | ||
fastapi==0.89.1 |
File renamed without changes
File renamed without changes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters