endpoint_utils.py
# Databricks notebook source
# MAGIC %md
# MAGIC # Utils for working with Endpoints
# COMMAND ----------
# Currently, there is no Python API for the serving endpoints, so we wrap the REST API in a small client class instead
import urllib
import json
import time
import requests
class EndpointApiClient:
    def __init__(self):
        self.base_url = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().get()
        self.token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
        self.headers = {"Authorization": f"Bearer {self.token}", "Content-Type": "application/json"}

    def create_inference_endpoint(self, endpoint_name, served_models):
        data = {"name": endpoint_name, "config": {"served_models": served_models}}
        return self._post("api/2.0/serving-endpoints", data)

    def get_inference_endpoint(self, endpoint_name):
        return self._get(f"api/2.0/serving-endpoints/{endpoint_name}", allow_error=True)

    def inference_endpoint_exists(self, endpoint_name):
        ep = self.get_inference_endpoint(endpoint_name)
        if 'error_code' in ep:
            if ep['error_code'] == 'RESOURCE_DOES_NOT_EXIST':
                return False
            raise Exception(f"Unexpected error while checking whether the endpoint exists: {ep}")
        return True

    def create_endpoint_if_not_exists(self, endpoint_name, model_name, model_version, workload_size, workload_type, scale_to_zero_enabled=True, wait_start=True, environment_vars={}):
        models = [{
            "model_name": model_name,
            "model_version": model_version,
            "workload_size": workload_size,
            "workload_type": workload_type,
            "scale_to_zero_enabled": scale_to_zero_enabled,
            "environment_vars": environment_vars
        }]
        if not self.inference_endpoint_exists(endpoint_name):
            r = self.create_inference_endpoint(endpoint_name, models)
        else:
            # The endpoint already exists: make sure the proper model version is deployed
            ep = self.get_inference_endpoint(endpoint_name)
            if 'pending_config' in ep:
                self.wait_endpoint_start(endpoint_name)
                ep = self.get_inference_endpoint(endpoint_name)
            if 'pending_config' in ep:
                model_deployed = ep['pending_config']['served_models'][0]
                print(f"Error with the model deployed: {model_deployed} - state {ep['state']}")
            else:
                model_deployed = ep['config']['served_models'][0]
            if model_deployed['model_version'] != model_version:
                print(f"Current model is version {model_deployed['model_version']}. Updating to {model_version}...")
                u = self.update_model_endpoint(endpoint_name, {"served_models": models})
        if wait_start:
            self.wait_endpoint_start(endpoint_name)

    def list_inference_endpoints(self):
        return self._get("api/2.0/serving-endpoints")

    def update_model_endpoint(self, endpoint_name, conf):
        return self._put(f"api/2.0/serving-endpoints/{endpoint_name}/config", conf)

    def delete_inference_endpoint(self, endpoint_name):
        return self._delete(f"api/2.0/serving-endpoints/{endpoint_name}")

    def wait_endpoint_start(self, endpoint_name):
        i = 0
        while self.get_inference_endpoint(endpoint_name)['state']['config_update'] == "IN_PROGRESS" and i < 500:
            if i % 10 == 0:
                print("waiting for endpoint to build model image and start...")
            time.sleep(10)
            i += 1
        ep = self.get_inference_endpoint(endpoint_name)
        if ep['state'].get("ready", None) != "READY":
            print(f"Error creating the endpoint: {ep}")

    # Making predictions
    def query_inference_endpoint(self, endpoint_name, data):
        return self._post(f"realtime-inference/{endpoint_name}/invocations", data)

    # Debugging
    def get_served_model_build_logs(self, endpoint_name, served_model_name):
        return self._get(f"api/2.0/serving-endpoints/{endpoint_name}/served-models/{served_model_name}/build-logs")

    def get_served_model_server_logs(self, endpoint_name, served_model_name):
        return self._get(f"api/2.0/serving-endpoints/{endpoint_name}/served-models/{served_model_name}/logs")

    def get_inference_endpoint_events(self, endpoint_name):
        return self._get(f"api/2.0/serving-endpoints/{endpoint_name}/events")

    def _get(self, uri, data={}, allow_error=False):
        r = requests.get(f"{self.base_url}/{uri}", params=data, headers=self.headers)
        return self._process(r, allow_error)

    def _post(self, uri, data={}, allow_error=False):
        return self._process(requests.post(f"{self.base_url}/{uri}", json=data, headers=self.headers), allow_error)

    def _put(self, uri, data={}, allow_error=False):
        return self._process(requests.put(f"{self.base_url}/{uri}", json=data, headers=self.headers), allow_error)

    def _delete(self, uri, data={}, allow_error=False):
        return self._process(requests.delete(f"{self.base_url}/{uri}", json=data, headers=self.headers), allow_error)

    def _process(self, r, allow_error=False):
        if r.status_code == 500 or r.status_code == 403 or not allow_error:
            print(r.text)
            r.raise_for_status()
        return r.json()
# COMMAND ----------
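
# MAGIC %md
# MAGIC ## Example usage
# MAGIC A minimal usage sketch for the client defined above (typically pulled into another notebook with `%run`). The endpoint name, model name and version, workload settings, and the request payload below are illustrative placeholders, not values defined in this repo; the payload shape must match your model's signature.
# MAGIC
# MAGIC ```python
# MAGIC client = EndpointApiClient()
# MAGIC
# MAGIC # Deploy (or update) a small CPU endpoint for a registered model -- names are placeholders.
# MAGIC client.create_endpoint_if_not_exists(
# MAGIC     endpoint_name="my_demo_endpoint",
# MAGIC     model_name="my_registered_model",
# MAGIC     model_version="1",
# MAGIC     workload_size="Small",
# MAGIC     workload_type="CPU",
# MAGIC     scale_to_zero_enabled=True,
# MAGIC )
# MAGIC
# MAGIC # Once the endpoint is READY, send a scoring request.
# MAGIC response = client.query_inference_endpoint(
# MAGIC     "my_demo_endpoint",
# MAGIC     {"dataframe_records": [{"feature_a": 1.0, "feature_b": 2.0}]},
# MAGIC )
# MAGIC print(response)
# MAGIC
# MAGIC # Inspect endpoint events if the deployment fails.
# MAGIC print(client.get_inference_endpoint_events("my_demo_endpoint"))
# MAGIC ```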