From 298fd81b9d0f9839b8eb97aff58906c0370122e9 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 13 Nov 2023 22:50:30 +0000 Subject: [PATCH] Add stream operation --- sky/api/cli.py | 6 ++- sky/api/requests/tasks.py | 19 ++++--- sky/api/rest.py | 78 ++++++++++++++++++++-------- sky/api/sdk.py | 52 ++++++++++++------- sky/backends/cloud_vm_ray_backend.py | 4 +- sky/sky_logging.py | 11 ++++ 6 files changed, 121 insertions(+), 49 deletions(-) diff --git a/sky/api/cli.py b/sky/api/cli.py index 1c95fbd59d2..efbc3ed94ad 100644 --- a/sky/api/cli.py +++ b/sky/api/cli.py @@ -843,10 +843,9 @@ def _launch_with_confirm( if node_type is None or maybe_status != status_lib.ClusterStatus.UP: # No need to sky.launch again when interactive node is already up. - sky.launch( + request_id = sdk.launch( dag, dryrun=dryrun, - stream_logs=True, cluster_name=cluster, detach_setup=detach_setup, detach_run=detach_run, @@ -857,6 +856,9 @@ def _launch_with_confirm( no_setup=no_setup, clone_disk_from=clone_disk_from, ) + sdk.stream_and_get( + request_id + ) # TODO: skip installing ray to speed up provisioning. diff --git a/sky/api/requests/tasks.py b/sky/api/requests/tasks.py index c2cb4dc6b66..5c81d46dcd2 100644 --- a/sky/api/requests/tasks.py +++ b/sky/api/requests/tasks.py @@ -17,6 +17,8 @@ from sky.utils import common_utils from sky.utils import db_utils +TASK_LOG_PATH_PREFIX = '~/sky_logs/requests' + class RequestStatus(enum.Enum): """The status of a task.""" @@ -40,7 +42,6 @@ def __gt__(self, other): 'status', 'return_value', 'error', - 'log_path', 'pid', ] @@ -56,9 +57,16 @@ class RequestTask: status: RequestStatus return_value: Any = None error: Optional[Dict[str, Any]] = None - log_path: Optional[str] = None pid: Optional[int] = None + @property + def log_path(self) -> pathlib.Path: + log_path_prefix = pathlib.Path( + TASK_LOG_PATH_PREFIX).expanduser().absolute() + log_path_prefix.mkdir(parents=True, exist_ok=True) + log_path = (log_path_prefix / self.request_id).with_suffix('.log') + return log_path + def set_error(self, error: Exception): """Set the error.""" # TODO(zhwu): handle other members of Exception. @@ -86,15 +94,13 @@ def from_row(cls, row: Tuple[Any, ...]) -> 'RequestTask': status=RequestStatus(row[4]), return_value=json.loads(row[5]), error=json.loads(row[6]), - log_path=row[7], - pid=row[8], + pid=row[7], ) def to_row(self) -> Tuple[Any, ...]: return (self.request_id, self.name, self.entrypoint, json.dumps(self.request_body), self.status.value, - json.dumps(self.return_value), json.dumps(self.error), - self.log_path, self.pid) + json.dumps(self.return_value), json.dumps(self.error), self.pid) _DB_PATH = os.path.expanduser('~/.sky/api_server/tasks.db') @@ -127,7 +133,6 @@ def create_table(cursor, conn): status TEXT, return_value TEXT, error BLOB, - log_path TEXT, pid INTEGER)""") diff --git a/sky/api/rest.py b/sky/api/rest.py index 1876873d288..e1fb88f2b83 100644 --- a/sky/api/rest.py +++ b/sky/api/rest.py @@ -3,6 +3,7 @@ import asyncio import json import multiprocessing +import pathlib import sys import tempfile from typing import Any, Callable, Dict, List, Optional @@ -16,6 +17,7 @@ from sky import core from sky import execution from sky import optimizer +from sky import sky_logging from sky.api.requests import tasks from sky.utils import dag_utils from sky.utils import registry @@ -61,27 +63,35 @@ async def refresh_cluster_status_event(): def wrapper(func: Callable[P, Any], request_id: str, *args: P.args, **kwargs: P.kwargs): + """Wrapper for a request task.""" print(f'Running task {request_id}') with tasks.update_rest_task(request_id) as request_task: assert request_task is not None, request_id + log_path = request_task.log_path request_task.pid = multiprocessing.current_process().pid request_task.status = tasks.RequestStatus.RUNNING - try: - return_value = func(*args, **kwargs) - except Exception as e: # pylint: disable=broad-except - with tasks.update_rest_task(request_id) as request_task: - assert request_task is not None, request_id - request_task.status = tasks.RequestStatus.FAILED - request_task.set_error(e) - print(f'Task {request_id} failed') - raise - else: - with tasks.update_rest_task(request_id) as request_task: - assert request_task is not None, request_id - request_task.status = tasks.RequestStatus.SUCCEEDED - request_task.set_return_value(return_value) - print(f'Task {request_id} finished') - return return_value + with log_path.open('w') as f: + # Redirect stdout and stderr to the log file. + sys.stdout = sys.stderr = f + # reconfigure logger since the logger is initialized before + # with previous stdout/stderr + sky_logging.reload_logger() + try: + return_value = func(*args, **kwargs) + except Exception as e: # pylint: disable=broad-except + with tasks.update_rest_task(request_id) as request_task: + assert request_task is not None, request_id + request_task.status = tasks.RequestStatus.FAILED + request_task.set_error(e) + print(f'Task {request_id} failed') + raise + else: + with tasks.update_rest_task(request_id) as request_task: + assert request_task is not None, request_id + request_task.status = tasks.RequestStatus.SUCCEEDED + request_task.set_return_value(return_value) + print(f'Task {request_id} finished') + return return_value def _start_background_request(request_id: str, request_name: str, @@ -212,19 +222,45 @@ class RequestIdBody(pydantic.BaseModel): @app.get('/get') -async def get(wait_body: RequestIdBody) -> tasks.RequestTask: +async def get(get_body: RequestIdBody) -> tasks.RequestTask: while True: - request_task = tasks.get_request(wait_body.request_id) + request_task = tasks.get_request(get_body.request_id) if request_task is None: - print(f'No task with request ID {wait_body.request_id}') + print(f'No task with request ID {get_body.request_id}') raise fastapi.HTTPException( status_code=404, - detail=f'Request {wait_body.request_id} not found') + detail=f'Request {get_body.request_id} not found') if request_task.status > tasks.RequestStatus.RUNNING: return request_task await asyncio.sleep(1) - # TODO(zhwu): stream the logs and handle errors. + +async def log_streamer(request_id: str, log_path: pathlib.Path): + with log_path.open('rb') as f: + while True: + line = f.readline() + if not line: + request_task = tasks.get_request(request_id) + if request_task.status > tasks.RequestStatus.RUNNING: + break + await asyncio.sleep(1) + continue + yield line + + +@app.get('/stream') +async def stream( + stream_body: RequestIdBody) -> fastapi.responses.StreamingResponse: + request_id = stream_body.request_id + request_task = tasks.get_request(request_id) + if request_task is None: + print(f'No task with request ID {request_id}') + raise fastapi.HTTPException(status_code=404, + detail=f'Request {request_id} not found') + log_path = request_task.log_path + return fastapi.responses.StreamingResponse(log_streamer( + request_id, log_path), + media_type='text/plain') @app.post('/abort') diff --git a/sky/api/sdk.py b/sky/api/sdk.py index 66f2b6a4997..840beb23e68 100644 --- a/sky/api/sdk.py +++ b/sky/api/sdk.py @@ -127,23 +127,24 @@ def launch( # TODO(zhwu): For all the file_mounts, we need to handle them properly # similarly to how we deal with it for spot_launch. + body = { + 'task': dag_str, + 'cluster_name': cluster_name, + 'retry_until_up': retry_until_up, + 'idle_minutes_to_autostop': idle_minutes_to_autostop, + 'dryrun': dryrun, + 'down': down, + 'backend': backend.NAME if backend else None, + 'optimize_target': optimize_target.value, + 'detach_setup': detach_setup, + 'detach_run': detach_run, + 'no_setup': no_setup, + 'clone_disk_from': clone_disk_from, + '_is_launched_by_spot_controller': _is_launched_by_spot_controller, + } response = requests.post( f'{_get_server_url()}/launch', - json={ - 'task': dag_str, - 'cluster_name': cluster_name, - 'retry_until_up': retry_until_up, - 'idle_minutes_to_autostop': idle_minutes_to_autostop, - 'dryrun': dryrun, - 'down': down, - 'backend': backend.NAME if backend else None, - 'optimize_target': optimize_target, - 'detach_setup': detach_setup, - 'detach_run': detach_run, - 'no_setup': no_setup, - 'clone_disk_from': clone_disk_from, - '_is_launched_by_spot_controller': _is_launched_by_spot_controller, - }, + json=body, timeout=5, ) return response.headers['X-Request-ID'] @@ -159,8 +160,7 @@ def status(cluster_names: Optional[List[str]] = None, json={ 'cluster_names': cluster_names, 'refresh': refresh, - }, - timeout=30) + }) request_id, _ = _handle_response(response) return request_id @@ -170,7 +170,7 @@ def status(cluster_names: Optional[List[str]] = None, def get(request_id: str) -> Any: response = requests.get(f'{_get_server_url()}/get', json={'request_id': request_id}, - timeout=30) + timeout=300) _, return_value = _handle_response(response) request_task = tasks.RequestTask(**return_value) if request_task.error: @@ -180,6 +180,22 @@ def get(request_id: str) -> Any: return request_task.get_return_value() +@usage_lib.entrypoint +@_check_health +def stream_and_get(request_id: str) -> Any: + response = requests.get(f'{_get_server_url()}/stream', + json={'request_id': request_id}, + timeout=300, + stream=True) + + if response.status_code != 200: + return get(request_id) + for line in response.iter_lines(): + if line: + print(line.decode('utf-8')) + return get(request_id) + + @usage_lib.entrypoint @_check_health def down(cluster_name: str, purge: bool = False) -> str: diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index 4f0df471d6f..95410c41ef6 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -51,6 +51,7 @@ from sky.utils import command_runner from sky.utils import common_utils from sky.utils import log_utils +from sky.utils import registry from sky.utils import resources_utils from sky.utils import rich_utils from sky.utils import status_lib @@ -2675,6 +2676,7 @@ def from_config(cls, config: dict) -> 'CloudVmRayResourceHandle': return result +@registry.BACKEND_REGISTRY.register class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']): """Backend: runs on cloud virtual machines, managed by Ray. @@ -2683,7 +2685,7 @@ class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']): * Cloud providers' implementations under clouds/ """ - NAME = 'cloudvmray' + NAME = 'cloudvmraybackend' # Backward compatibility, with the old name of the handle. ResourceHandle = CloudVmRayResourceHandle # pylint: disable=invalid-name diff --git a/sky/sky_logging.py b/sky/sky_logging.py index 04dad68c819..548b17adeca 100644 --- a/sky/sky_logging.py +++ b/sky/sky_logging.py @@ -116,3 +116,14 @@ def is_silent(): # threads. _logging_config.is_silent = False return _logging_config.is_silent + + +def reload_logger(): + """Reload the logger. + This is useful when the logging configuration is changed. + e.g., the logging level is changed or stdout/stderr is reset. + """ + global _default_handler + _root_logger.removeHandler(_default_handler) + _default_handler = None + _setup_logger()