-
Notifications
You must be signed in to change notification settings - Fork 5
/
app.py
44 lines (40 loc) · 1.41 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import gradio as gr
import gcvit
from gcvit.utils import get_gradcam_model, get_gradcam_prediction
# Instantiating a pretrained model and wrapping it for Grad-CAM is the slow
# part of a request, so each model is built once and reused on later calls.
_GRADCAM_MODELS = {}


def _get_cached_gradcam_model(model_name):
    """Return the Grad-CAM model for `model_name`, building it on first use."""
    if model_name not in _GRADCAM_MODELS:
        # `model_name` must be the name of a model constructor exposed by the
        # `gcvit` package; an unknown name raises AttributeError here.
        model = getattr(gcvit, model_name)(pretrain=True)
        _GRADCAM_MODELS[model_name] = get_gradcam_model(model)
    return _GRADCAM_MODELS[model_name]


def predict_fn(image, model_name):
    """Classify `image` with the selected GCViT model and render a Grad-CAM overlay.

    Invoked by gradio for every submitted request.

    Args:
        image: The input image, passed through unchanged to
            `get_gradcam_prediction`.
        model_name: Name of a model constructor attribute on the `gcvit`
            module (for example "GCViTTiny").

    Returns:
        A two-element list `[preds, overlay]`, where `preds` is a dict built
        from the prediction rows (row[1] -> float(row[2])) and `overlay` is
        the second value returned by `get_gradcam_prediction`.
    """
    gradcam_model = _get_cached_gradcam_model(model_name)
    preds, overlay = get_gradcam_prediction(
        image, gradcam_model, cmap="jet", alpha=0.4, pred_index=None
    )
    # NOTE(review): rows look like (class_id, label, score) tuples, as produced
    # by Keras-style `decode_predictions` -- confirm against gcvit.utils.
    preds = {x[1]: float(x[2]) for x in preds}
    return [preds, overlay]
# Gradio UI: takes an image and a model-size choice, shows the class
# confidences and the Grad-CAM overlay returned by `predict_fn`.
# Components use the top-level `gr.*` classes throughout (matching the
# existing `gr.Radio`) instead of the deprecated `gr.inputs` / `gr.outputs`
# namespaces; the Grad-CAM output in particular is an output image, not an
# input component.
demo = gr.Interface(
    fn=predict_fn,
    inputs=[
        gr.Image(label="Input Image"),
        gr.Radio(
            ["GCViTTiny", "GCViTSmall", "GCViTBase"],
            value="GCViTTiny",
            label="Model Size",
        ),
    ],
    outputs=[
        gr.Label(label="Prediction"),
        gr.Image(label="GradCAM"),
    ],
    title="Global Context Vision Transformer (GCViT) Demo",
    description="Image Classification with ImageNet Pretrain Models.",
    # Each example row is [image path, model size], in the same order as `inputs`.
    examples=[
        ["example/hot_air_ballon.jpg", "GCViTTiny"],
        ["example/chelsea.png", "GCViTTiny"],
        ["example/german_shepherd.jpg", "GCViTTiny"],
        ["example/panda.jpg", "GCViTTiny"],
        ["example/jellyfish.jpg", "GCViTTiny"],
        ["example/penguin.JPG", "GCViTTiny"],
        ["example/bus.jpg", "GCViTTiny"],
        ["example/cat_dog.JPG", "GCViTTiny"],
    ],
)
demo.launch()