From 86eeb0e12f0a1a39ddeea45e10bfb682c9145890 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Thu, 1 Aug 2024 19:03:28 +0000 Subject: [PATCH] Allow request prefix --- sky/api/cli.py | 16 +++++----------- sky/api/requests/constants.py | 1 + sky/api/requests/tasks.py | 7 ++++--- sky/api/sdk.py | 17 ++++++++++++++++- 4 files changed, 26 insertions(+), 15 deletions(-) create mode 100644 sky/api/requests/constants.py diff --git a/sky/api/cli.py b/sky/api/cli.py index d9341b0c78d..e55b71ed781 100644 --- a/sky/api/cli.py +++ b/sky/api/cli.py @@ -1211,11 +1211,6 @@ def exec(cluster: Optional[str], cluster_option: Optional[str], env = _merge_env_vars(env_file, env) controller_utils.check_cluster_name_not_controller( cluster, operation_str='Executing task on it') - handle = global_user_state.get_handle_from_cluster_name(cluster) - if handle is None: - raise click.BadParameter(f'Cluster {cluster!r} not found. ' - 'Use `sky launch` to provision first.') - backend = backend_utils.get_backend_from_handle(handle) task_or_dag = _make_task_or_dag_from_entrypoint_with_overrides( entrypoint=entrypoint, @@ -1245,7 +1240,6 @@ def exec(cluster: Optional[str], cluster_option: Optional[str], click.secho(f'Executing task on cluster {cluster}...', fg='yellow') request_id = sdk.exec(task, - backend=backend, cluster_name=cluster, detach_run=detach_run) if not async_call: @@ -1965,7 +1959,7 @@ def logs( # return assert job_ids is None or len(job_ids) <= 1, job_ids - job_id: int + job_id: Optional[int] = None job_ids_to_query: Optional[List[int]] = None if job_ids: # Already check that len(job_ids) <= 1. This variable is used later @@ -5220,10 +5214,10 @@ def api_stop(): sdk.api_stop() -@api.command('stream', cls=_DocumentedCodeCommand) +@api.command('get', cls=_DocumentedCodeCommand) @click.argument('request_id', required=False, type=str) @usage_lib.entrypoint -def api_stream(request_id: Optional[str]): +def api_get(request_id: Optional[str]): """Stream the logs of a request running on API server.""" if request_id is None: # TODO(zhwu): get the latest request ID. @@ -5231,7 +5225,7 @@ def api_stream(request_id: Optional[str]): sdk.stream_and_get(request_id) -@api.command('logs', cls=_DocumentedCodeCommand) +@api.command('server_logs', cls=_DocumentedCodeCommand) @click.option('--follow', '-f', is_flag=True, @@ -5245,7 +5239,7 @@ def api_stream(request_id: Optional[str]): '(default "all")')) # Follow the arguments of `docker logs` command. @usage_lib.entrypoint -def api_logs(follow: bool, tail: str): +def api_server_logs(follow: bool, tail: str): """Shows the API server logs.""" sdk.api_logs(follow, tail) diff --git a/sky/api/requests/constants.py b/sky/api/requests/constants.py new file mode 100644 index 00000000000..f0505dee03b --- /dev/null +++ b/sky/api/requests/constants.py @@ -0,0 +1 @@ +API_SERVER_REQUEST_DB_PATH = '~/.sky/api_server/tasks.db' diff --git a/sky/api/requests/tasks.py b/sky/api/requests/tasks.py index 15f043f14b5..a5764e21d90 100644 --- a/sky/api/requests/tasks.py +++ b/sky/api/requests/tasks.py @@ -13,6 +13,7 @@ from sky.api.requests import decoders from sky.api.requests import encoders +from sky.api.requests import constants from sky.utils import common_utils from sky.utils import db_utils @@ -150,7 +151,7 @@ def decode(cls, payload: RequestTaskPayload) -> 'RequestTask': ) -_DB_PATH = os.path.expanduser('~/.sky/api_server/tasks.db') +_DB_PATH = os.path.expanduser(constants.API_SERVER_REQUEST_DB_PATH) pathlib.Path(_DB_PATH).parents[0].mkdir(parents=True, exist_ok=True) @@ -225,8 +226,8 @@ def _get_rest_task_no_lock(request_id: str) -> Optional[RequestTask]: assert _DB is not None with _DB.conn: cursor = _DB.conn.cursor() - cursor.execute('SELECT * FROM rest_tasks WHERE request_id=?', - (request_id,)) + cursor.execute('SELECT * FROM rest_tasks WHERE request_id LIKE ?', + (request_id + '%',)) row = cursor.fetchone() if row is None: return None diff --git a/sky/api/sdk.py b/sky/api/sdk.py index 09171676ee8..1e89a3deac7 100644 --- a/sky/api/sdk.py +++ b/sky/api/sdk.py @@ -28,6 +28,7 @@ from sky import sky_logging from sky.api.requests import payloads from sky.api.requests import tasks +from sky.api.requests import constants as requests_constants from sky.backends import backend_utils from sky.data import data_utils from sky.skylet import constants @@ -603,6 +604,11 @@ def get(request_id: str) -> Any: @usage_lib.entrypoint @_check_health def stream_and_get(request_id: str) -> Any: + """Stream the logs of a request and get the final result. + + This will block until the request is finished. The request id can be a + prefix of the full request id. + """ body = payloads.RequestIdBody(request_id=request_id) response = requests.get( f'{_get_server_url()}/stream', @@ -660,6 +666,15 @@ def api_stop(): process.kill() found = True + # Remove the database for requests including any files starting with + # constants.API_SERVER_REQUEST_DB_PATH + db_path = os.path.expanduser(requests_constants.API_SERVER_REQUEST_DB_PATH) + for extension in ['', '-shm', '-wal']: + try: + os.remove(f'{db_path}{extension}') + except FileNotFoundError as e: + logger.info(f'Database file {db_path}{extension} not found.') + if found: logger.info(f'{colorama.Fore.GREEN}SkyPilot API server stopped.' f'{colorama.Style.RESET_ALL}') @@ -669,7 +684,7 @@ def api_stop(): # Use the same args as `docker logs` @usage_lib.entrypoint -def api_logs(follow: bool = True, tail: str = 'all'): +def api_server_logs(follow: bool = True, tail: str = 'all'): """Stream the API server logs.""" server_url = _get_server_url() if server_url != DEFAULT_SERVER_URL: