Skip to content

Commit

Permalink
reorganize
Browse files Browse the repository at this point in the history
  • Loading branch information
Michaelvll committed Oct 26, 2023
1 parent 325083c commit 899f74b
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 31 deletions.
Empty file added sky/api/requests/__init__.py
Empty file.
File renamed without changes.
File renamed without changes.
9 changes: 4 additions & 5 deletions sky/api/request_tasks.py → sky/api/requests/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@

import filelock

from sky.api import request_return_decoders
from sky.api import request_return_encoders
from sky.api.requests import decoders
from sky.api.requests import encoders
from sky.utils import common_utils
from sky.utils import db_utils

Expand Down Expand Up @@ -70,12 +70,11 @@ def set_error(self, error: Exception):

def set_return_value(self, return_value: Any):
"""Set the return value."""
self.return_value = request_return_encoders.get_handler(
self.name)(return_value)
self.return_value = encoders.get_handler(self.name)(return_value)

def get_return_value(self) -> Any:
"""Get the return value."""
return request_return_decoders.get_handler(self.name)(self.return_value)
return decoders.get_handler(self.name)(self.return_value)

@classmethod
def from_row(cls, row: Tuple[Any, ...]) -> 'RequestTask':
Expand Down
46 changes: 22 additions & 24 deletions sky/api/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from sky import core
from sky import execution
from sky import optimizer
from sky.api import request_tasks
from sky.api.requests import tasks
from sky.utils import dag_utils
from sky.utils import registry
from sky.utils import subprocess_utils
Expand Down Expand Up @@ -62,23 +62,23 @@ async def refresh_cluster_status_event():
def wrapper(func: Callable[P, Any], request_id: str, *args: P.args,
**kwargs: P.kwargs):
print(f'Running task {request_id}')
with request_tasks.update_rest_task(request_id) as request_task:
with tasks.update_rest_task(request_id) as request_task:
assert request_task is not None, request_id
request_task.pid = multiprocessing.current_process().pid
request_task.status = request_tasks.RequestStatus.RUNNING
request_task.status = tasks.RequestStatus.RUNNING
try:
return_value = func(*args, **kwargs)
except Exception as e: # pylint: disable=broad-except
with request_tasks.update_rest_task(request_id) as request_task:
with tasks.update_rest_task(request_id) as request_task:
assert request_task is not None, request_id
request_task.status = request_tasks.RequestStatus.FAILED
request_task.status = tasks.RequestStatus.FAILED
request_task.set_error(e)
print(f'Task {request_id} failed')
raise
else:
with request_tasks.update_rest_task(request_id) as request_task:
with tasks.update_rest_task(request_id) as request_task:
assert request_task is not None, request_id
request_task.status = request_tasks.RequestStatus.SUCCEEDED
request_task.status = tasks.RequestStatus.SUCCEEDED
request_task.set_return_value(return_value)
print(f'Task {request_id} finished')
return return_value
Expand All @@ -89,13 +89,12 @@ def _start_background_request(request_id: str, request_name: str,
Any],
*args: P.args, **kwargs: P.kwargs):
"""Start a task."""
rest_task = request_tasks.RequestTask(
request_id=request_id,
name=request_name,
entrypoint=func.__module__,
request_body=request_body,
status=request_tasks.RequestStatus.PENDING)
request_tasks.dump_reqest(rest_task)
rest_task = tasks.RequestTask(request_id=request_id,
name=request_name,
entrypoint=func.__module__,
request_body=request_body,
status=tasks.RequestStatus.PENDING)
tasks.dump_reqest(rest_task)
process = multiprocessing.Process(target=wrapper,
args=(func, request_id, *args),
kwargs=kwargs)
Expand Down Expand Up @@ -213,15 +212,15 @@ class RequestIdBody(pydantic.BaseModel):


@app.get('/get')
async def get(wait_body: RequestIdBody) -> request_tasks.RequestTask:
async def get(wait_body: RequestIdBody) -> tasks.RequestTask:
while True:
request_task = request_tasks.get_request(wait_body.request_id)
request_task = tasks.get_request(wait_body.request_id)
if request_task is None:
print(f'No task with request ID {wait_body.request_id}')
raise fastapi.HTTPException(
status_code=404,
detail=f'Request {wait_body.request_id} not found')
if request_task.status > request_tasks.RequestStatus.RUNNING:
if request_task.status > tasks.RequestStatus.RUNNING:
return request_task
await asyncio.sleep(1)

Expand All @@ -231,25 +230,24 @@ async def get(wait_body: RequestIdBody) -> request_tasks.RequestTask:
@app.post('/abort')
async def abort(abort_body: RequestIdBody):
print(f'Trying to kill request ID {abort_body.request_id}')
with request_tasks.update_rest_task(abort_body.request_id) as rest_task:
with tasks.update_rest_task(abort_body.request_id) as rest_task:
if rest_task is None:
print(f'No task with request ID {abort_body.request_id}')
raise fastapi.HTTPException(
status_code=404,
detail=f'Request {abort_body.request_id} not found')
rest_task.status = request_tasks.RequestStatus.ABORTED
rest_task.status = tasks.RequestStatus.ABORTED
if rest_task.pid is not None:
subprocess_utils.kill_children_processes(parent_pid=rest_task.pid)
print(f'Killed request: {abort_body.request_id}')


@app.get('/requests')
async def requests(
request_id: Optional[str] = None) -> List[request_tasks.RequestTask]:
async def requests(request_id: Optional[str] = None) -> List[tasks.RequestTask]:
if request_id is None:
return request_tasks.get_request_tasks()
return tasks.get_request_tasks()
else:
request_task = request_tasks.get_request(request_id)
request_task = tasks.get_request(request_id)
if request_task is None:
raise fastapi.HTTPException(
status_code=404, detail=f'Request {request_id} not found')
Expand All @@ -268,7 +266,7 @@ async def health() -> str:

if __name__ == '__main__':
import uvicorn
request_tasks.reset_db()
tasks.reset_db()

parser = argparse.ArgumentParser()
parser.add_argument('--host', default='0.0.0.0')
Expand Down
4 changes: 2 additions & 2 deletions sky/api/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from sky import backends
from sky import optimizer
from sky import sky_logging
from sky.api import request_tasks
from sky.api.requests import tasks
from sky.skylet import constants
from sky.usage import usage_lib
from sky.utils import dag_utils
Expand Down Expand Up @@ -172,7 +172,7 @@ def get(request_id: str) -> Any:
json={'request_id': request_id},
timeout=30)
_, return_value = _handle_response(response)
request_task = request_tasks.RequestTask(**return_value)
request_task = tasks.RequestTask(**return_value)
if request_task.error:
# TODO(zhwu): we should have a better way to handle errors.
# Is it possible to raise the original exception?
Expand Down

0 comments on commit 899f74b

Please sign in to comment.