-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsample_prediction.py
45 lines (34 loc) · 11.6 KB
/
sample_prediction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# Sample template for requesting online predictions from a model deployed on Google AI Platform.
import googleapiclient.discovery
# Deployment target: the Google Cloud project hosting the model, and the
# name of the model deployed on AI Platform within that project.
project, model = 'tensorflow-deployment', 'fashion_mnist_test'
# Request payload: a list holding ONE instance, which is a 28 x 28 grid
# (28 rows of 28 floats) with every value in [0.0, 1.0].
# NOTE(review): presumably a Fashion-MNIST image (per the model name) whose
# 8-bit pixel values were divided by 255 — confirm this matches the input
# shape and preprocessing the deployed model expects.
instances = [[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.050980392156862744, 0.2627450980392157, 0.0, 0.0, 0.0, 0.0, 0.19607843137254902, 0.14901960784313725, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03137254901960784, 0.47058823529411764, 0.8196078431372549, 0.8862745098039215, 0.9686274509803922, 0.9294117647058824, 1.0, 1.0, 1.0, 0.9686274509803922, 0.9333333333333333, 0.9215686274509803, 0.6745098039215687, 0.2823529411764706, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5372549019607843, 0.9372549019607843, 0.9882352941176471, 0.9529411764705882, 0.9176470588235294, 0.8980392156862745, 0.9333333333333333, 0.9568627450980393, 0.9647058823529412, 0.9411764705882353, 0.9019607843137255, 0.9098039215686274, 0.9372549019607843, 0.9725490196078431, 0.984313725490196, 0.7607843137254902, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.4, 1.0, 0.9058823529411765, 0.8941176470588236, 0.8901960784313725, 0.8941176470588236, 0.9137254901960784, 0.9019607843137255, 0.9019607843137255, 0.8980392156862745, 0.8941176470588236, 0.9098039215686274, 0.9098039215686274, 0.9058823529411765, 0.8901960784313725, 0.8784313725490196, 0.9882352941176471, 0.7019607843137254, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.9137254901960784, 0.9450980392156862, 0.8980392156862745, 0.9058823529411765, 1.0, 1.0, 0.9333333333333333, 0.9058823529411765, 0.8901960784313725, 0.9333333333333333, 0.9647058823529412, 0.8941176470588236, 0.9019607843137255, 0.8901960784313725, 0.9176470588235294, 0.9215686274509803, 0.8980392156862745, 0.9450980392156862, 0.0784313725490196, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.9725490196078431, 0.9450980392156862, 0.9058823529411765, 1.0, 0.5843137254901961, 0.1843137254901961, 0.9882352941176471, 0.8941176470588236, 1.0, 0.9490196078431372, 0.8470588235294118, 0.9333333333333333, 0.9098039215686274, 1.0, 0.8941176470588236, 0.8627450980392157, 
0.9176470588235294, 0.9803921568627451, 0.21176470588235294, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.9411764705882353, 0.9098039215686274, 1.0, 0.058823529411764705, 0.0, 1.0, 0.9294117647058824, 0.7490196078431373, 0.0, 0.0, 0.8392156862745098, 1.0, 0.050980392156862744, 0.4823529411764706, 1.0, 0.9176470588235294, 0.9882352941176471, 0.4470588235294118, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.023529411764705882, 1.0, 0.9333333333333333, 0.9372549019607843, 1.0, 0.6941176470588235, 0.0, 1.0, 1.0, 0.0, 0.5098039215686274, 0.4549019607843137, 0.1843137254901961, 0.2549019607843137, 0.16862745098039217, 0.1450980392156863, 1.0, 0.9254901960784314, 0.9764705882352941, 0.6352941176470588, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.12549019607843137, 1.0, 0.9254901960784314, 0.9607843137254902, 1.0, 0.8, 0.0, 1.0, 0.32941176470588235, 0.0, 0.1450980392156863, 0.10980392156862745, 0.12156862745098039, 0.0, 0.09803921568627451, 0.050980392156862744, 1.0, 0.9254901960784314, 0.9764705882352941, 0.7803921568627451, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.20784313725490197, 1.0, 0.9254901960784314, 0.9803921568627451, 0.9803921568627451, 0.9058823529411765, 0.00784313725490196, 1.0, 0.08235294117647059, 0.0, 0.8666666666666667, 1.0, 0.9254901960784314, 0.21176470588235294, 0.9607843137254902, 0.7764705882352941, 0.9529411764705882, 0.9333333333333333, 0.9607843137254902, 0.8745098039215686, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.3137254901960784, 1.0, 0.9294117647058824, 0.9803921568627451, 0.9411764705882353, 1.0, 0.0, 0.0, 0.15294117647058825, 0.615686274509804, 0.0, 0.0, 0.8431372549019608, 0.3686274509803922, 0.0784313725490196, 0.49411764705882355, 1.0, 0.9294117647058824, 0.9372549019607843, 0.9803921568627451, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.396078431372549, 1.0, 0.9215686274509803, 0.9921568627450981, 0.9568627450980393, 0.9529411764705882, 0.5215686274509804, 0.5411764705882353, 0.8156862745098039, 1.0, 0.788235294117647, 
0.8392156862745098, 1.0, 0.9019607843137255, 0.027450980392156862, 0.6823529411764706, 1.0, 0.9411764705882353, 0.9333333333333333, 1.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.49411764705882355, 1.0, 0.9137254901960784, 1.0, 0.9725490196078431, 0.9137254901960784, 1.0, 1.0, 0.9411764705882353, 0.9098039215686274, 0.9529411764705882, 0.9529411764705882, 0.9058823529411765, 0.984313725490196, 1.0, 1.0, 0.996078431372549, 0.9529411764705882, 0.9333333333333333, 1.0, 0.011764705882352941, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.5764705882352941, 1.0, 0.9137254901960784, 0.9764705882352941, 0.7098039215686275, 0.9529411764705882, 0.8901960784313725, 0.8784313725490196, 0.9019607843137255, 0.9176470588235294, 0.9019607843137255, 0.9019607843137255, 0.9215686274509803, 0.8941176470588236, 0.9215686274509803, 0.8705882352941177, 0.8117647058823529, 1.0, 0.9254901960784314, 1.0, 0.13725490196078433, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.6392156862745098, 1.0, 0.9607843137254902, 0.8666666666666667, 0.33725490196078434, 1.0, 0.9137254901960784, 0.9137254901960784, 0.9215686274509803, 0.9254901960784314, 0.9176470588235294, 0.9176470588235294, 0.9176470588235294, 0.9098039215686274, 0.9490196078431372, 0.9058823529411765, 0.49019607843137253, 1.0, 0.9254901960784314, 1.0, 0.21568627450980393, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.7098039215686275, 0.996078431372549, 1.0, 0.7843137254901961, 0.27058823529411763, 1.0, 0.8941176470588236, 0.9098039215686274, 0.9176470588235294, 0.9215686274509803, 0.9176470588235294, 0.9176470588235294, 0.9137254901960784, 0.9215686274509803, 0.9450980392156862, 0.9294117647058824, 0.27450980392156865, 1.0, 0.9215686274509803, 0.9647058823529412, 0.2235294117647059, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.7725490196078432, 0.9686274509803922, 1.0, 0.7372549019607844, 0.43137254901960786, 1.0, 0.8784313725490196, 0.9137254901960784, 0.9176470588235294, 0.9176470588235294, 0.9176470588235294, 0.9176470588235294, 0.9176470588235294, 
0.9176470588235294, 0.9411764705882353, 0.9921568627450981, 0.27058823529411763, 1.0, 0.9254901960784314, 0.9725490196078431, 0.30196078431372547, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.7843137254901961, 0.9647058823529412, 1.0, 0.5843137254901961, 0.5686274509803921, 1.0, 0.8745098039215686, 0.9215686274509803, 0.9176470588235294, 0.9215686274509803, 0.9215686274509803, 0.9215686274509803, 0.9176470588235294, 0.9294117647058824, 0.9137254901960784, 1.0, 0.1843137254901961, 1.0, 0.9372549019607843, 0.9764705882352941, 0.3843137254901961, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.8, 0.9529411764705882, 1.0, 0.43529411764705883, 0.6784313725490196, 1.0, 0.8901960784313725, 0.9215686274509803, 0.9215686274509803, 0.9254901960784314, 0.9215686274509803, 0.9215686274509803, 0.9215686274509803, 0.9372549019607843, 0.8980392156862745, 1.0, 0.07450980392156863, 0.8901960784313725, 0.9647058823529412, 0.9764705882352941, 0.43137254901960786, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.7686274509803922, 0.9411764705882353, 1.0, 0.42745098039215684, 0.8352941176470589, 0.9803921568627451, 0.8980392156862745, 0.9215686274509803, 0.9215686274509803, 0.9254901960784314, 0.9215686274509803, 0.9294117647058824, 0.9254901960784314, 0.9294117647058824, 0.8862745098039215, 1.0, 0.21568627450980393, 0.796078431372549, 0.984313725490196, 0.9607843137254902, 0.47058823529411764, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.7529411764705882, 0.9529411764705882, 1.0, 0.4470588235294118, 0.9098039215686274, 0.9411764705882353, 0.9098039215686274, 0.9215686274509803, 0.9215686274509803, 0.9254901960784314, 0.9176470588235294, 0.9294117647058824, 0.9254901960784314, 0.9215686274509803, 0.8980392156862745, 1.0, 0.5254901960784314, 0.6705882352941176, 0.9882352941176471, 0.9568627450980393, 0.5372549019607843, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.7411764705882353, 0.984313725490196, 1.0, 0.6039215686274509, 0.9333333333333333, 0.9137254901960784, 0.9254901960784314, 0.9176470588235294, 0.9215686274509803, 
0.9254901960784314, 0.9215686274509803, 0.9333333333333333, 0.9254901960784314, 0.9215686274509803, 0.9098039215686274, 1.0, 0.6509803921568628, 0.49019607843137253, 1.0, 0.9529411764705882, 0.5568627450980392, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.7176470588235294, 0.9882352941176471, 1.0, 0.6705882352941176, 0.9686274509803922, 0.9098039215686274, 0.9176470588235294, 0.9176470588235294, 0.9137254901960784, 0.9137254901960784, 0.9098039215686274, 0.9176470588235294, 0.9137254901960784, 0.9176470588235294, 0.9137254901960784, 0.9411764705882353, 0.8745098039215686, 0.5019607843137255, 1.0, 0.9490196078431372, 0.592156862745098, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.6980392156862745, 0.9529411764705882, 1.0, 0.2235294117647059, 0.9333333333333333, 0.9450980392156862, 0.9333333333333333, 0.9333333333333333, 0.9333333333333333, 0.9294117647058824, 0.9254901960784314, 0.9294117647058824, 0.9294117647058824, 0.9411764705882353, 0.9294117647058824, 0.996078431372549, 0.6901960784313725, 0.20392156862745098, 1.0, 0.9372549019607843, 0.615686274509804, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.7372549019607844, 0.9411764705882353, 0.9803921568627451, 0.24313725490196078, 0.8549019607843137, 1.0, 0.8627450980392157, 0.8705882352941177, 0.8705882352941177, 0.8705882352941177, 0.8745098039215686, 0.8745098039215686, 0.8784313725490196, 0.8705882352941177, 0.8549019607843137, 1.0, 0.6039215686274509, 0.12549019607843137, 1.0, 0.9254901960784314, 0.7372549019607844, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.5098039215686274, 0.9607843137254902, 0.9490196078431372, 0.09411764705882353, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.13333333333333333, 0.9490196078431372, 0.9568627450980393, 0.5294117647058824, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.2980392156862745, 1.0, 0.9764705882352941, 0.08627450980392157, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.15294117647058825, 0.9764705882352941, 1.0, 0.4823529411764706, 0.0, 0.0, 0.0], [0.0, 
0.0, 0.0, 0.0, 0.19215686274509805, 0.803921568627451, 0.7725490196078432, 0.043137254901960784, 0.0, 0.01568627450980392, 0.00392156862745098, 0.00784313725490196, 0.00784313725490196, 0.00784313725490196, 0.00784313725490196, 0.00784313725490196, 0.00784313725490196, 0.00784313725490196, 0.00784313725490196, 0.011764705882352941, 0.0, 0.011764705882352941, 0.6823529411764706, 0.7411764705882353, 0.2627450980392157, 0.0, 0.0, 0.0]]]
def predict_json(project, model, instances, version=None):
    """Send JSON data to a deployed model for online prediction.

    Args:
        project (str): Google Cloud project where the AI Platform
            (Cloud ML Engine) model is deployed.
        model (str): Model name.
        instances (list): The instances to predict on. They are sent verbatim
            as the ``instances`` field of the request body. Each element is
            either a mapping whose keys are the names of the input tensors
            the deployed model expects, or a value / (potentially nested)
            list of values convertible to a tensor.
        version (str, optional): Model version to target. When ``None``, the
            request is addressed to the model resource itself instead of a
            specific version.

    Returns:
        The ``predictions`` field of the service response, i.e. the
        prediction results in whatever structure the model defines.

    Raises:
        RuntimeError: If the response body contains an ``error`` field.
    """
    # Build the AI Platform ("ml", v1) service client. To authenticate, set
    # the environment variable
    # GOOGLE_APPLICATION_CREDENTIALS=<path_to_service_account_file>
    service = googleapiclient.discovery.build('ml', 'v1')

    # Resource name: projects/<project>/models/<model>[/versions/<version>]
    name = 'projects/{}/models/{}'.format(project, model)
    if version is not None:
        name += '/versions/{}'.format(version)

    response = service.projects().predict(
        name=name,
        body={'instances': instances}
    ).execute()

    # An error reported inside the response body is surfaced as an exception
    # instead of being returned as if it were a prediction.
    if 'error' in response:
        raise RuntimeError(response['error'])

    return response['predictions']
if __name__ == '__main__':
    # Send the sample request only when this file is run as a script (not as
    # a side effect of importing it), and print the predictions instead of
    # silently discarding them.
    print(predict_json(project, model, instances))