Skip to content

Commit

Permalink
[Serve] Support customizable readiness probe timeout (#3472)
Browse files Browse the repository at this point in the history
* init

* format

* fix doc and comment

* add comment

* add smoke test. TODO: test it.

* add initial delay

* wait for it to fail

---------

Co-authored-by: Zhanghao Wu <[email protected]>
  • Loading branch information
cblmemo and Michaelvll authored Jun 20, 2024
1 parent 1f418d9 commit 65bbcf5
Show file tree
Hide file tree
Showing 9 changed files with 136 additions and 11 deletions.
7 changes: 7 additions & 0 deletions docs/source/serving/service-yaml-spec.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ Available fields:
# highly related to your service, so it is recommended to set this value
# based on your service's startup time.
initial_delay_seconds: 1200
# The timeout in seconds for a readiness probe request (optional).
# Defaults to 15 seconds. If the readiness probe takes longer than this
# time to respond, the probe will be considered as failed. This is
# useful when your service is slow to respond to readiness probe
# requests. Note that setting the timeout too high will delay the
# detection of a real failure of your service replica.
timeout_seconds: 15
# Simplified version of readiness probe that only contains the readiness
# probe path. If you want to use GET method for readiness probe and the
Expand Down
3 changes: 1 addition & 2 deletions sky/serve/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@
# The default timeout in seconds for a readiness probe request. We set the
# timeout to 15s since using actual generation in LLM services as readiness
# probe is very time-consuming (33B, 70B, ...).
# TODO(tian): Expose this option to users in yaml file.
READINESS_PROBE_TIMEOUT_SECONDS = 15
DEFAULT_READINESS_PROBE_TIMEOUT_SECONDS = 15

# Autoscaler window size in seconds for queries per second (QPS). We calculate
# QPS by dividing the number of queries in the last window by this window size.
Expand Down
21 changes: 12 additions & 9 deletions sky/serve/replica_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,7 @@ def probe(
self,
readiness_path: str,
post_data: Optional[Dict[str, Any]],
timeout: int,
headers: Optional[Dict[str, str]],
) -> Tuple['ReplicaInfo', bool, float]:
"""Probe the readiness of the replica.
Expand All @@ -512,17 +513,15 @@ def probe(
logger.info(f'Probing {replica_identity} with {readiness_path}.')
if post_data is not None:
msg += 'POST'
response = requests.post(
readiness_path,
headers=headers,
json=post_data,
timeout=serve_constants.READINESS_PROBE_TIMEOUT_SECONDS)
response = requests.post(readiness_path,
json=post_data,
headers=headers,
timeout=timeout)
else:
msg += 'GET'
response = requests.get(
readiness_path,
headers=headers,
timeout=serve_constants.READINESS_PROBE_TIMEOUT_SECONDS)
response = requests.get(readiness_path,
headers=headers,
timeout=timeout)
msg += (f' request to {replica_identity} returned status '
f'code {response.status_code}')
if response.status_code == 200:
Expand Down Expand Up @@ -1043,6 +1042,7 @@ def _probe_all_replicas(self) -> None:
(
self._get_readiness_path(info.version),
self._get_post_data(info.version),
self._get_readiness_timeout_seconds(info.version),
self._get_readiness_headers(info.version),
),
),)
Expand Down Expand Up @@ -1230,3 +1230,6 @@ def _get_readiness_headers(self, version: int) -> Optional[Dict[str, str]]:

def _get_initial_delay_seconds(self, version: int) -> int:
return self._get_version_spec(version).initial_delay_seconds

def _get_readiness_timeout_seconds(self, version: int) -> int:
return self._get_version_spec(version).readiness_timeout_seconds
16 changes: 16 additions & 0 deletions sky/serve/service_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(
self,
readiness_path: str,
initial_delay_seconds: int,
readiness_timeout_seconds: int,
min_replicas: int,
max_replicas: Optional[int] = None,
target_qps_per_replica: Optional[float] = None,
Expand Down Expand Up @@ -78,6 +79,7 @@ def __init__(

self._readiness_path: str = readiness_path
self._initial_delay_seconds: int = initial_delay_seconds
self._readiness_timeout_seconds: int = readiness_timeout_seconds
self._min_replicas: int = min_replicas
self._max_replicas: Optional[int] = max_replicas
self._target_qps_per_replica: Optional[float] = target_qps_per_replica
Expand Down Expand Up @@ -113,16 +115,23 @@ def from_yaml_config(config: Dict[str, Any]) -> 'SkyServiceSpec':
service_config['readiness_path'] = readiness_section
initial_delay_seconds = None
post_data = None
readiness_timeout_seconds = None
readiness_headers = None
else:
service_config['readiness_path'] = readiness_section['path']
initial_delay_seconds = readiness_section.get(
'initial_delay_seconds', None)
post_data = readiness_section.get('post_data', None)
readiness_timeout_seconds = readiness_section.get(
'timeout_seconds', None)
readiness_headers = readiness_section.get('headers', None)
if initial_delay_seconds is None:
initial_delay_seconds = constants.DEFAULT_INITIAL_DELAY_SECONDS
service_config['initial_delay_seconds'] = initial_delay_seconds
if readiness_timeout_seconds is None:
readiness_timeout_seconds = (
constants.DEFAULT_READINESS_PROBE_TIMEOUT_SECONDS)
service_config['readiness_timeout_seconds'] = readiness_timeout_seconds
if isinstance(post_data, str):
try:
post_data = json.loads(post_data)
Expand Down Expand Up @@ -209,6 +218,8 @@ def add_if_not_none(section, key, value, no_empty: bool = False):
add_if_not_none('readiness_probe', 'initial_delay_seconds',
self.initial_delay_seconds)
add_if_not_none('readiness_probe', 'post_data', self.post_data)
add_if_not_none('readiness_probe', 'timeout_seconds',
self.readiness_timeout_seconds)
add_if_not_none('readiness_probe', 'headers', self._readiness_headers)
add_if_not_none('replica_policy', 'min_replicas', self.min_replicas)
add_if_not_none('replica_policy', 'max_replicas', self.max_replicas)
Expand Down Expand Up @@ -268,6 +279,7 @@ def __repr__(self) -> str:
return textwrap.dedent(f"""\
Readiness probe method: {self.probe_str()}
Readiness initial delay seconds: {self.initial_delay_seconds}
Readiness probe timeout seconds: {self.readiness_timeout_seconds}
Replica autoscaling policy: {self.autoscaling_policy_str()}
Spot Policy: {self.spot_policy_str()}
""")
Expand All @@ -280,6 +292,10 @@ def readiness_path(self) -> str:
def initial_delay_seconds(self) -> int:
return self._initial_delay_seconds

@property
def readiness_timeout_seconds(self) -> int:
    """Readiness probe timeout (in seconds) stored on this spec."""
    timeout_seconds = self._readiness_timeout_seconds
    return timeout_seconds

@property
def min_replicas(self) -> int:
    """Minimum replica count stored on this spec."""
    replica_floor = self._min_replicas
    return replica_floor
Expand Down
3 changes: 3 additions & 0 deletions sky/utils/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,9 @@ def get_service_schema():
'initial_delay_seconds': {
'type': 'number',
},
'timeout_seconds': {
'type': 'number',
},
'post_data': {
'anyOf': [{
'type': 'string',
Expand Down
27 changes: 27 additions & 0 deletions tests/skyserve/readiness_timeout/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import argparse
import asyncio

import fastapi
import uvicorn

app = fastapi.FastAPI()


@app.get('/')
async def root():
    # Fixed greeting returned for requests to the root path.
    greeting = 'Hi, SkyPilot here!'
    return greeting


@app.get('/health')
async def health():
    # Deliberately slow endpoint: wait 20 seconds before answering so that a
    # readiness probe against this path observes a long processing time.
    simulated_latency_seconds = 20
    await asyncio.sleep(simulated_latency_seconds)
    return {'status': 'ok'}


if __name__ == '__main__':
    # The listening port is a mandatory CLI argument; the app is then served
    # on all interfaces at that port.
    arg_parser = argparse.ArgumentParser(
        description='SkyServe Readiness Timeout Test Server')
    arg_parser.add_argument('--port', type=int, required=True)
    cli_args = arg_parser.parse_args()
    uvicorn.run(app, host='0.0.0.0', port=cli_args.port)
14 changes: 14 additions & 0 deletions tests/skyserve/readiness_timeout/task.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# task.yaml
service:
readiness_probe:
path: /health
initial_delay_seconds: 120
replicas: 1

workdir: tests/skyserve/readiness_timeout

resources:
cpus: 2+
ports: 8081

run: python3 server.py --port 8081
15 changes: 15 additions & 0 deletions tests/skyserve/readiness_timeout/task_large_timeout.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# task_large_timeout.yaml
service:
readiness_probe:
path: /health
initial_delay_seconds: 120
timeout_seconds: 30
replicas: 1

workdir: tests/skyserve/readiness_timeout

resources:
cpus: 2+
ports: 8081

run: python3 server.py --port 8081
41 changes: 41 additions & 0 deletions tests/test_smoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -3630,6 +3630,47 @@ def test_skyserve_streaming(generic_cloud: str):
run_one_test(test)


@pytest.mark.serve
def test_skyserve_readiness_timeout_fail(generic_cloud: str):
    """Test skyserve with large readiness probe latency, expected to fail.

    Launches ``tests/skyserve/readiness_timeout/task.yaml``, which sets no
    ``timeout_seconds`` for its readiness probe, and waits for the replica to
    be reported as ``FAILED_INITIAL_DELAY``.

    Args:
        generic_cloud: Cloud name passed to ``sky serve up --cloud``.
    """
    name = _get_service_name()
    test = Test(
        # Plain string: the test name has no placeholders to interpolate.
        'test-skyserve-readiness-timeout-fail',
        [
            f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/readiness_timeout/task.yaml',
            # None of the readiness probes will pass, so the service will be
            # terminated after the initial delay.
            f's=$(sky serve status {name}); '
            f'until echo "$s" | grep "FAILED_INITIAL_DELAY"; do '
            'echo "Waiting for replica to be failed..."; sleep 5; '
            f's=$(sky serve status {name}); echo "$s"; done;',
            # NOTE(review): presumably this pause lets the status settle so
            # the final check can confirm exactly one failed entry - confirm.
            'sleep 60',
            f'{_SERVE_STATUS_WAIT.format(name=name)}; echo "$s" | grep "{name}" | grep "FAILED_INITIAL_DELAY" | wc -l | grep 1;'
        ],
        _TEARDOWN_SERVICE.format(name=name),
        timeout=20 * 60,
    )
    run_one_test(test)


@pytest.mark.serve
def test_skyserve_large_readiness_timeout(generic_cloud: str):
    """Test skyserve with customized large readiness timeout.

    Launches ``tests/skyserve/readiness_timeout/task_large_timeout.yaml``,
    waits until one replica is ready, then checks that a request to the
    service endpoint returns the expected greeting.

    Args:
        generic_cloud: Cloud name passed to ``sky serve up --cloud``.
    """
    name = _get_service_name()
    test = Test(
        # Plain string: the test name has no placeholders to interpolate.
        'test-skyserve-large-readiness-timeout',
        [
            f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/readiness_timeout/task_large_timeout.yaml',
            _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1),
            f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; '
            'request_output=$(curl http://$endpoint); echo "$request_output"; echo "$request_output" | grep "Hi, SkyPilot here"',
        ],
        _TEARDOWN_SERVICE.format(name=name),
        timeout=20 * 60,
    )
    run_one_test(test)


@pytest.mark.serve
def test_skyserve_update(generic_cloud: str):
"""Test skyserve with update"""
Expand Down

0 comments on commit 65bbcf5

Please sign in to comment.