Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Michaelvll committed Aug 1, 2024
1 parent fe009bc commit 215e4b5
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 81 deletions.
5 changes: 5 additions & 0 deletions sky/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]):
from sky.api.sdk import cost_report
from sky.api.sdk import down
from sky.api.sdk import exec # pylint: disable=redefined-builtin
from sky.api.sdk import get
from sky.api.sdk import job_status
from sky.api.sdk import launch
from sky.api.sdk import queue
Expand All @@ -96,6 +97,7 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]):
from sky.api.sdk import stop
from sky.api.sdk import storage_delete
from sky.api.sdk import storage_ls
from sky.api.sdk import stream_and_get
from sky.api.sdk import tail_logs
from sky.clouds.service_catalog import list_accelerators
from sky.dag import Dag
Expand Down Expand Up @@ -185,4 +187,7 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]):
# core APIs Storage Management
'storage_ls',
'storage_delete',
# Request APIs
'get',
'stream_and_get',
]
53 changes: 29 additions & 24 deletions sky/api/rest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""REST API for SkyPilot."""
import argparse
import asyncio
import json
import multiprocessing
import os
import pathlib
Expand Down Expand Up @@ -143,13 +144,15 @@ def restore_output(original_stdout, original_stderr):
return return_value


def _start_background_request(request_id: str,
request_name: str,
request_body: Dict[str, Any],
func: Callable[P, Any],
ignore_return_value: bool = False,
*args: P.args,
**kwargs: P.kwargs):
def _start_background_request(
request_id: str,
request_name: str,
request_body: Dict[str, Any],
func: Callable[P, Any],
ignore_return_value: bool = False,
# pylint: disable=keyword-arg-before-vararg
*args: P.args,
**kwargs: P.kwargs):
"""Start a task."""
request_task = tasks.RequestTask(request_id=request_id,
name=request_name,
Expand Down Expand Up @@ -186,7 +189,7 @@ async def optimize(optimize_body: payloads.OptimizeBody,
_start_background_request(
request_id,
request_name='optimize',
request_body=optimize_body.model_dump(),
request_body=json.loads(optimize_body.model_dump_json()),
ignore_return_value=True,
func=optimizer.Optimizer.optimize,
dag=dag,
Expand Down Expand Up @@ -301,7 +304,7 @@ async def launch(launch_body: payloads.LaunchBody, request: fastapi.Request):
_start_background_request(
request_id,
request_name='launch',
request_body=launch_body.model_dump(),
request_body=json.loads(launch_body.model_dump_json()),
func=execution.launch,
task=dag,
cluster_name=launch_body.cluster_name,
Expand Down Expand Up @@ -341,7 +344,7 @@ async def exec(request: fastapi.Request, exec_body: payloads.ExecBody):
_start_background_request(
request_id=request.state.request_id,
request_name='exec',
request_body=exec_body.model_dump(),
request_body=json.loads(exec_body.model_dump_json()),
func=execution.exec,
task=task,
cluster_name=exec_body.cluster_name,
Expand All @@ -357,7 +360,7 @@ async def stop(request: fastapi.Request, stop_body: payloads.StopOrDownBody):
_start_background_request(
request_id=request.state.request_id,
request_name='stop',
request_body=stop_body.model_dump(),
request_body=json.loads(stop_body.model_dump_json()),
func=core.stop,
cluster_name=stop_body.cluster_name,
purge=stop_body.purge,
Expand All @@ -372,7 +375,7 @@ async def status(
_start_background_request(
request_id=request.state.request_id,
request_name='status',
request_body=status_body.model_dump(),
request_body=json.loads(status_body.model_dump_json()),
func=core.status,
cluster_names=status_body.cluster_names,
refresh=status_body.refresh,
Expand All @@ -385,7 +388,7 @@ async def endpoints(request: fastapi.Request,
_start_background_request(
request_id=request.state.request_id,
request_name='endpoints',
request_body=endpoint_body.model_dump(),
request_body=json.loads(endpoint_body.model_dump_json()),
func=core.endpoints,
cluster_name=endpoint_body.cluster_name,
port=endpoint_body.port,
Expand All @@ -397,7 +400,7 @@ async def down(request: fastapi.Request, down_body: payloads.StopOrDownBody):
_start_background_request(
request_id=request.state.request_id,
request_name='down',
request_body=down_body.model_dump(),
request_body=json.loads(down_body.model_dump_json()),
func=core.down,
cluster_name=down_body.cluster_name,
purge=down_body.purge,
Expand All @@ -410,7 +413,7 @@ async def start(request: fastapi.Request, start_body: payloads.StartBody):
_start_background_request(
request_id=request.state.request_id,
request_name='start',
request_body=start_body.model_dump(),
request_body=json.loads(start_body.model_dump_json()),
func=core.start,
cluster_name=start_body.cluster_name,
idle_minutes_to_autostop=start_body.idle_minutes_to_autostop,
Expand All @@ -427,7 +430,7 @@ async def autostop(request: fastapi.Request,
_start_background_request(
request_id=request.state.request_id,
request_name='autostop',
request_body=autostop_body.model_dump(),
request_body=json.loads(autostop_body.model_dump_json()),
func=core.autostop,
cluster_name=autostop_body.cluster_name,
idle_minutes_to_autostop=autostop_body.idle_minutes_to_autostop,
Expand All @@ -441,7 +444,7 @@ async def queue(request: fastapi.Request, queue_body: payloads.QueueBody):
_start_background_request(
request_id=request.state.request_id,
request_name='queue',
request_body=queue_body.model_dump(),
request_body=json.loads(queue_body.model_dump_json()),
func=core.queue,
cluster_name=queue_body.cluster_name,
skip_finished=queue_body.skip_finished,
Expand All @@ -456,7 +459,7 @@ async def job_status(request: fastapi.Request,
_start_background_request(
request_id=request.state.request_id,
request_name='job_status',
request_body=job_status_body.model_dump(),
request_body=json.loads(job_status_body.model_dump_json()),
func=core.job_status,
cluster_name=job_status_body.cluster_name,
job_ids=job_status_body.job_ids,
Expand All @@ -469,7 +472,7 @@ async def cancel(request: fastapi.Request,
_start_background_request(
request_id=request.state.request_id,
request_name='cancel',
request_body=cancel_body.model_dump(),
request_body=json.loads(cancel_body.model_dump_json()),
func=core.cancel,
cluster_name=cancel_body.cluster_name,
job_ids=cancel_body.job_ids,
Expand All @@ -483,7 +486,7 @@ async def logs(request: fastapi.Request,
_start_background_request(
request_id=request.state.request_id,
request_name='logs',
request_body=cluster_job_body.model_dump(),
request_body=json.loads(cluster_job_body.model_dump_json()),
func=core.tail_logs,
cluster_name=cluster_job_body.cluster_name,
job_id=cluster_job_body.job_id,
Expand All @@ -494,8 +497,10 @@ async def logs(request: fastapi.Request,
# TODO(zhwu): expose download_logs
# @app.get('/download_logs')
# async def download_logs(request: fastapi.Request,
# cluster_jobs_body: payloads.ClusterJobsBody) -> Dict[str, str]:
# """Download logs to API server and returns the job id to log dir mapping."""
# cluster_jobs_body: payloads.ClusterJobsBody,
# ) -> Dict[str, str]:
# """Download logs to API server and returns the job id to log dir
# mapping."""
# # Call the function directly to download the logs to the API server first.
# log_dirs = core.download_logs(cluster_name=cluster_jobs_body.cluster_name,
# job_ids=cluster_jobs_body.job_ids)
Expand All @@ -509,7 +514,7 @@ async def cost_report(request: fastapi.Request,
_start_background_request(
request_id=request.state.request_id,
request_name='cost_report',
request_body=cost_report_body.model_dump(),
request_body=json.loads(cost_report_body.model_dump_json()),
func=core.cost_report,
all=cost_report_body.all,
)
Expand All @@ -531,7 +536,7 @@ async def storage_delete(request: fastapi.Request,
_start_background_request(
request_id=request.state.request_id,
request_name='storage_delete',
request_body=storage_body.model_dump(),
request_body=json.loads(storage_body.model_dump_json()),
func=core.storage_delete,
name=storage_body.name,
)
Expand Down
Loading

0 comments on commit 215e4b5

Please sign in to comment.