app.py

import torch, torchaudio, torchvision
import os
import gradio as gr
import numpy as np
from preprocess import process_audio_data, process_image_data
from train import WatermelonModel
from infer import infer
def load_model(model_path):
    global device
    # Prefer CUDA, fall back to Apple MPS, then CPU
    device = torch.device(
        "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    )
    print(f"\033[92mINFO\033[0m: Using device: {device}")

    model = WatermelonModel().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    print(f"\033[92mINFO\033[0m: Loaded model from {model_path}")
    return model
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Watermelon sweetness predictor")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the trained model")
    args = parser.parse_args()

    model = load_model(args.model_path)

    def predict(audio, image):
        # Gradio's Audio component yields a (sample_rate, data) tuple
        audio, sr = audio[-1], audio[0]
        audio = np.transpose(audio)
        audio = torch.tensor(audio).float()

        mfcc = process_audio_data(audio, sr)
        img = torch.tensor(image).float()
        img = process_image_data(img)

        # Only move tensors to the device and run inference once both inputs are valid
        if mfcc is not None and img is not None:
            sweetness = infer(mfcc.to(device), img.to(device))
            return sweetness.item()
        return None

    audio_input = gr.Audio(label="Upload or Record Audio")
    image_input = gr.Image(label="Upload or Capture Image")
    output = gr.Textbox(label="Predicted Sweetness")

    interface = gr.Interface(
        fn=predict,
        inputs=[audio_input, image_input],
        outputs=output,
        title="Watermelon Sweetness Predictor",
        description="Upload an audio file and an image to predict the sweetness of a watermelon.",
    )

    interface.launch()
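
# Example invocation (the model path below is illustrative; point it at your own checkpoint):
#   python app.py --model_path ./models/watermelon_model.pt
# Gradio then serves the interface locally, at http://127.0.0.1:7860 by default.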