Skip to content

Commit

Permalink
Allow request prefix
Browse files Browse the repository at this point in the history
  • Loading branch information
Michaelvll committed Aug 1, 2024
1 parent d8f5698 commit 86eeb0e
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 15 deletions.
16 changes: 5 additions & 11 deletions sky/api/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1211,11 +1211,6 @@ def exec(cluster: Optional[str], cluster_option: Optional[str],
env = _merge_env_vars(env_file, env)
controller_utils.check_cluster_name_not_controller(
cluster, operation_str='Executing task on it')
handle = global_user_state.get_handle_from_cluster_name(cluster)
if handle is None:
raise click.BadParameter(f'Cluster {cluster!r} not found. '
'Use `sky launch` to provision first.')
backend = backend_utils.get_backend_from_handle(handle)

task_or_dag = _make_task_or_dag_from_entrypoint_with_overrides(
entrypoint=entrypoint,
Expand Down Expand Up @@ -1245,7 +1240,6 @@ def exec(cluster: Optional[str], cluster_option: Optional[str],

click.secho(f'Executing task on cluster {cluster}...', fg='yellow')
request_id = sdk.exec(task,
backend=backend,
cluster_name=cluster,
detach_run=detach_run)
if not async_call:
Expand Down Expand Up @@ -1965,7 +1959,7 @@ def logs(
# return

assert job_ids is None or len(job_ids) <= 1, job_ids
job_id: int
job_id: Optional[int] = None
job_ids_to_query: Optional[List[int]] = None
if job_ids:
# Already check that len(job_ids) <= 1. This variable is used later
Expand Down Expand Up @@ -5220,18 +5214,18 @@ def api_stop():
sdk.api_stop()


@api.command('stream', cls=_DocumentedCodeCommand)
@api.command('get', cls=_DocumentedCodeCommand)
@click.argument('request_id', required=False, type=str)
@usage_lib.entrypoint
def api_stream(request_id: Optional[str]):
def api_get(request_id: Optional[str]):
"""Stream the logs of a request running on API server."""
if request_id is None:
# TODO(zhwu): get the latest request ID.
raise click.BadParameter('Please provide the request ID.')
sdk.stream_and_get(request_id)


@api.command('logs', cls=_DocumentedCodeCommand)
@api.command('server_logs', cls=_DocumentedCodeCommand)
@click.option('--follow',
'-f',
is_flag=True,
Expand All @@ -5245,7 +5239,7 @@ def api_stream(request_id: Optional[str]):
'(default "all")'))
# Follow the arguments of `docker logs` command.
@usage_lib.entrypoint
def api_logs(follow: bool, tail: str):
def api_server_logs(follow: bool, tail: str):
"""Shows the API server logs."""
sdk.api_logs(follow, tail)

Expand Down
1 change: 1 addition & 0 deletions sky/api/requests/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
API_SERVER_REQUEST_DB_PATH = '~/.sky/api_server/tasks.db'
7 changes: 4 additions & 3 deletions sky/api/requests/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from sky.api.requests import decoders
from sky.api.requests import encoders
from sky.api.requests import constants
from sky.utils import common_utils
from sky.utils import db_utils

Expand Down Expand Up @@ -150,7 +151,7 @@ def decode(cls, payload: RequestTaskPayload) -> 'RequestTask':
)


_DB_PATH = os.path.expanduser('~/.sky/api_server/tasks.db')
_DB_PATH = os.path.expanduser(constants.API_SERVER_REQUEST_DB_PATH)
pathlib.Path(_DB_PATH).parents[0].mkdir(parents=True, exist_ok=True)


Expand Down Expand Up @@ -225,8 +226,8 @@ def _get_rest_task_no_lock(request_id: str) -> Optional[RequestTask]:
assert _DB is not None
with _DB.conn:
cursor = _DB.conn.cursor()
cursor.execute('SELECT * FROM rest_tasks WHERE request_id=?',
(request_id,))
cursor.execute('SELECT * FROM rest_tasks WHERE request_id LIKE ?',
(request_id + '%',))
row = cursor.fetchone()
if row is None:
return None
Expand Down
17 changes: 16 additions & 1 deletion sky/api/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from sky import sky_logging
from sky.api.requests import payloads
from sky.api.requests import tasks
from sky.api.requests import constants as requests_constants
from sky.backends import backend_utils
from sky.data import data_utils
from sky.skylet import constants
Expand Down Expand Up @@ -603,6 +604,11 @@ def get(request_id: str) -> Any:
@usage_lib.entrypoint
@_check_health
def stream_and_get(request_id: str) -> Any:
"""Stream the logs of a request and get the final result.
This will block until the request is finished. The request id can be a
prefix of the full request id.
"""
body = payloads.RequestIdBody(request_id=request_id)
response = requests.get(
f'{_get_server_url()}/stream',
Expand Down Expand Up @@ -660,6 +666,15 @@ def api_stop():
process.kill()
found = True

# Remove the database for requests including any files starting with
# constants.API_SERVER_REQUEST_DB_PATH
db_path = os.path.expanduser(requests_constants.API_SERVER_REQUEST_DB_PATH)
for extension in ['', '-shm', '-wal']:
try:
os.remove(f'{db_path}{extension}')
except FileNotFoundError as e:
logger.info(f'Database file {db_path}{extension} not found.')

if found:
logger.info(f'{colorama.Fore.GREEN}SkyPilot API server stopped.'
f'{colorama.Style.RESET_ALL}')
Expand All @@ -669,7 +684,7 @@ def api_stop():

# Use the same args as `docker logs`
@usage_lib.entrypoint
def api_logs(follow: bool = True, tail: str = 'all'):
def api_server_logs(follow: bool = True, tail: str = 'all'):
"""Stream the API server logs."""
server_url = _get_server_url()
if server_url != DEFAULT_SERVER_URL:
Expand Down

0 comments on commit 86eeb0e

Please sign in to comment.