Skip to content

Commit

Permalink
Add stream operation
Browse files Browse the repository at this point in the history
  • Loading branch information
Michaelvll committed Nov 13, 2023
1 parent dcec696 commit 298fd81
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 49 deletions.
6 changes: 4 additions & 2 deletions sky/api/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,10 +843,9 @@ def _launch_with_confirm(

if node_type is None or maybe_status != status_lib.ClusterStatus.UP:
# No need to sky.launch again when interactive node is already up.
sky.launch(
request_id = sdk.launch(
dag,
dryrun=dryrun,
stream_logs=True,
cluster_name=cluster,
detach_setup=detach_setup,
detach_run=detach_run,
Expand All @@ -857,6 +856,9 @@ def _launch_with_confirm(
no_setup=no_setup,
clone_disk_from=clone_disk_from,
)
sdk.stream_and_get(
request_id
)


# TODO: skip installing ray to speed up provisioning.
Expand Down
19 changes: 12 additions & 7 deletions sky/api/requests/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from sky.utils import common_utils
from sky.utils import db_utils

TASK_LOG_PATH_PREFIX = '~/sky_logs/requests'


class RequestStatus(enum.Enum):
"""The status of a task."""
Expand All @@ -40,7 +42,6 @@ def __gt__(self, other):
'status',
'return_value',
'error',
'log_path',
'pid',
]

Expand All @@ -56,9 +57,16 @@ class RequestTask:
status: RequestStatus
return_value: Any = None
error: Optional[Dict[str, Any]] = None
log_path: Optional[str] = None
pid: Optional[int] = None

@property
def log_path(self) -> pathlib.Path:
    """Return the absolute path of this request's log file.

    The path is ``<TASK_LOG_PATH_PREFIX>/<request_id>.log`` with ``~``
    expanded. The parent directory is created on access if it does not
    exist yet; the log file itself is not created here.
    """
    log_dir = pathlib.Path(TASK_LOG_PATH_PREFIX).expanduser().absolute()
    log_dir.mkdir(parents=True, exist_ok=True)
    # Append the extension instead of calling with_suffix(): with_suffix()
    # replaces everything after the last '.' in the name, so request IDs
    # such as 'a.1' and 'a.2' would both collapse to 'a.log'.
    return log_dir / f'{self.request_id}.log'

def set_error(self, error: Exception):
"""Set the error."""
# TODO(zhwu): handle other members of Exception.
Expand Down Expand Up @@ -86,15 +94,13 @@ def from_row(cls, row: Tuple[Any, ...]) -> 'RequestTask':
status=RequestStatus(row[4]),
return_value=json.loads(row[5]),
error=json.loads(row[6]),
log_path=row[7],
pid=row[8],
pid=row[7],
)

def to_row(self) -> Tuple[Any, ...]:
return (self.request_id, self.name, self.entrypoint,
json.dumps(self.request_body), self.status.value,
json.dumps(self.return_value), json.dumps(self.error),
self.log_path, self.pid)
json.dumps(self.return_value), json.dumps(self.error), self.pid)


_DB_PATH = os.path.expanduser('~/.sky/api_server/tasks.db')
Expand Down Expand Up @@ -127,7 +133,6 @@ def create_table(cursor, conn):
status TEXT,
return_value TEXT,
error BLOB,
log_path TEXT,
pid INTEGER)""")


Expand Down
78 changes: 57 additions & 21 deletions sky/api/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import json
import multiprocessing
import pathlib
import sys
import tempfile
from typing import Any, Callable, Dict, List, Optional
Expand All @@ -16,6 +17,7 @@
from sky import core
from sky import execution
from sky import optimizer
from sky import sky_logging
from sky.api.requests import tasks
from sky.utils import dag_utils
from sky.utils import registry
Expand Down Expand Up @@ -61,27 +63,35 @@ async def refresh_cluster_status_event():

def wrapper(func: Callable[P, Any], request_id: str, *args: P.args,
**kwargs: P.kwargs):
"""Wrapper for a request task."""
print(f'Running task {request_id}')
with tasks.update_rest_task(request_id) as request_task:
assert request_task is not None, request_id
log_path = request_task.log_path
request_task.pid = multiprocessing.current_process().pid
request_task.status = tasks.RequestStatus.RUNNING
try:
return_value = func(*args, **kwargs)
except Exception as e: # pylint: disable=broad-except
with tasks.update_rest_task(request_id) as request_task:
assert request_task is not None, request_id
request_task.status = tasks.RequestStatus.FAILED
request_task.set_error(e)
print(f'Task {request_id} failed')
raise
else:
with tasks.update_rest_task(request_id) as request_task:
assert request_task is not None, request_id
request_task.status = tasks.RequestStatus.SUCCEEDED
request_task.set_return_value(return_value)
print(f'Task {request_id} finished')
return return_value
with log_path.open('w') as f:
# Redirect stdout and stderr to the log file.
sys.stdout = sys.stderr = f
# Reconfigure the logger, because it was initialized earlier with
# the previous stdout/stderr.
sky_logging.reload_logger()
try:
return_value = func(*args, **kwargs)
except Exception as e: # pylint: disable=broad-except
with tasks.update_rest_task(request_id) as request_task:
assert request_task is not None, request_id
request_task.status = tasks.RequestStatus.FAILED
request_task.set_error(e)
print(f'Task {request_id} failed')
raise
else:
with tasks.update_rest_task(request_id) as request_task:
assert request_task is not None, request_id
request_task.status = tasks.RequestStatus.SUCCEEDED
request_task.set_return_value(return_value)
print(f'Task {request_id} finished')
return return_value


def _start_background_request(request_id: str, request_name: str,
Expand Down Expand Up @@ -212,19 +222,45 @@ class RequestIdBody(pydantic.BaseModel):


@app.get('/get')
async def get(wait_body: RequestIdBody) -> tasks.RequestTask:
async def get(get_body: RequestIdBody) -> tasks.RequestTask:
while True:
request_task = tasks.get_request(wait_body.request_id)
request_task = tasks.get_request(get_body.request_id)
if request_task is None:
print(f'No task with request ID {wait_body.request_id}')
print(f'No task with request ID {get_body.request_id}')
raise fastapi.HTTPException(
status_code=404,
detail=f'Request {wait_body.request_id} not found')
detail=f'Request {get_body.request_id} not found')
if request_task.status > tasks.RequestStatus.RUNNING:
return request_task
await asyncio.sleep(1)

# TODO(zhwu): stream the logs and handle errors.

async def log_streamer(request_id: str, log_path: pathlib.Path):
    """Yield the contents of a request's log file as it is written.

    Follows the file until the request leaves the RUNNING state, polling
    once per second while no new data is available.

    Args:
        request_id: ID of the request whose status decides when to stop.
        log_path: Path of the log file to follow.

    Yields:
        Raw bytes read from the log file (normally one line per item).
    """

    def _is_finished() -> bool:
        # A request that can no longer be found is treated as finished so
        # the stream ends instead of failing on `None.status`.
        request_task = tasks.get_request(request_id)
        return (request_task is None or
                request_task.status > tasks.RequestStatus.RUNNING)

    # The log file is only created once the worker starts executing the
    # request, so a request that is still queued has no file yet. Wait for
    # it rather than raising FileNotFoundError.
    while not log_path.exists():
        if _is_finished():
            return
        await asyncio.sleep(1)

    with log_path.open('rb') as f:
        while True:
            line = f.readline()
            if line:
                yield line
                continue
            if _is_finished():
                # The task may have written its last output between the
                # readline() above and the status check; drain it so the
                # tail of the log is not lost.
                remaining = f.read()
                if remaining:
                    yield remaining
                break
            await asyncio.sleep(1)


@app.get('/stream')
async def stream(
        stream_body: RequestIdBody) -> fastapi.responses.StreamingResponse:
    """Stream the log file of a request back to the client as plain text.

    Raises:
        fastapi.HTTPException: 404 when no request matches the given ID.
    """
    req_id = stream_body.request_id
    task = tasks.get_request(req_id)
    if task is not None:
        return fastapi.responses.StreamingResponse(
            log_streamer(req_id, task.log_path),
            media_type='text/plain')
    # Unknown request: report it on the server side and answer with 404.
    print(f'No task with request ID {req_id}')
    raise fastapi.HTTPException(status_code=404,
                                detail=f'Request {req_id} not found')


@app.post('/abort')
Expand Down
52 changes: 34 additions & 18 deletions sky/api/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,23 +127,24 @@ def launch(

# TODO(zhwu): For all the file_mounts, we need to handle them properly
# similarly to how we deal with it for spot_launch.
body = {
'task': dag_str,
'cluster_name': cluster_name,
'retry_until_up': retry_until_up,
'idle_minutes_to_autostop': idle_minutes_to_autostop,
'dryrun': dryrun,
'down': down,
'backend': backend.NAME if backend else None,
'optimize_target': optimize_target.value,
'detach_setup': detach_setup,
'detach_run': detach_run,
'no_setup': no_setup,
'clone_disk_from': clone_disk_from,
'_is_launched_by_spot_controller': _is_launched_by_spot_controller,
}
response = requests.post(
f'{_get_server_url()}/launch',
json={
'task': dag_str,
'cluster_name': cluster_name,
'retry_until_up': retry_until_up,
'idle_minutes_to_autostop': idle_minutes_to_autostop,
'dryrun': dryrun,
'down': down,
'backend': backend.NAME if backend else None,
'optimize_target': optimize_target,
'detach_setup': detach_setup,
'detach_run': detach_run,
'no_setup': no_setup,
'clone_disk_from': clone_disk_from,
'_is_launched_by_spot_controller': _is_launched_by_spot_controller,
},
json=body,
timeout=5,
)
return response.headers['X-Request-ID']
Expand All @@ -159,8 +160,7 @@ def status(cluster_names: Optional[List[str]] = None,
json={
'cluster_names': cluster_names,
'refresh': refresh,
},
timeout=30)
})
request_id, _ = _handle_response(response)
return request_id

Expand All @@ -170,7 +170,7 @@ def status(cluster_names: Optional[List[str]] = None,
def get(request_id: str) -> Any:
response = requests.get(f'{_get_server_url()}/get',
json={'request_id': request_id},
timeout=30)
timeout=300)
_, return_value = _handle_response(response)
request_task = tasks.RequestTask(**return_value)
if request_task.error:
Expand All @@ -180,6 +180,22 @@ def get(request_id: str) -> Any:
return request_task.get_return_value()


@usage_lib.entrypoint
@_check_health
def stream_and_get(request_id: str) -> Any:
    """Print a request's logs as they arrive, then return its result.

    Args:
        request_id: ID of the request to follow.

    Returns:
        Whatever `get(request_id)` returns for the request.
    """
    # A streamed response keeps its connection open until the body is fully
    # consumed or the response is closed. Using it as a context manager
    # releases the connection on every path, including the non-200 case
    # where the body is never read.
    with requests.get(f'{_get_server_url()}/stream',
                      json={'request_id': request_id},
                      timeout=300,
                      stream=True) as response:
        if response.status_code == 200:
            for line in response.iter_lines():
                if line:
                    # Logs may contain arbitrary bytes; do not let one bad
                    # byte abort the whole stream.
                    print(line.decode('utf-8', errors='replace'))
    # If streaming was not possible, this still waits for the result.
    return get(request_id)


@usage_lib.entrypoint
@_check_health
def down(cluster_name: str, purge: bool = False) -> str:
Expand Down
4 changes: 3 additions & 1 deletion sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from sky.utils import command_runner
from sky.utils import common_utils
from sky.utils import log_utils
from sky.utils import registry
from sky.utils import resources_utils
from sky.utils import rich_utils
from sky.utils import status_lib
Expand Down Expand Up @@ -2675,6 +2676,7 @@ def from_config(cls, config: dict) -> 'CloudVmRayResourceHandle':
return result


@registry.BACKEND_REGISTRY.register
class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']):
"""Backend: runs on cloud virtual machines, managed by Ray.
Expand All @@ -2683,7 +2685,7 @@ class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']):
* Cloud providers' implementations under clouds/
"""

NAME = 'cloudvmray'
NAME = 'cloudvmraybackend'

# Backward compatibility, with the old name of the handle.
ResourceHandle = CloudVmRayResourceHandle # pylint: disable=invalid-name
Expand Down
11 changes: 11 additions & 0 deletions sky/sky_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,14 @@ def is_silent():
# threads.
_logging_config.is_silent = False
return _logging_config.is_silent


def reload_logger():
    """Reload the logger.

    This is useful when the logging configuration is changed,
    e.g., the logging level is changed or stdout/stderr is reset.
    """
    global _default_handler
    # Detach the current default handler from the root logger.
    _root_logger.removeHandler(_default_handler)
    # Clear the module-level reference before re-running the setup.
    # NOTE(review): presumably _setup_logger() only installs a new handler
    # when _default_handler is None -- confirm against its definition.
    _default_handler = None
    _setup_logger()

0 comments on commit 298fd81

Please sign in to comment.