Skip to content

Commit

Permalink
Fix env var
Browse files Browse the repository at this point in the history
  • Loading branch information
Michaelvll committed Jun 2, 2024
1 parent b1671b9 commit 5b91fd3
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 57 deletions.
23 changes: 10 additions & 13 deletions docs/source/reference/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ SkyPilot offers a programmatic API in Python, which is used under the hood by th
Your feedback is much appreciated in evolving this API!


Core API
Cluster API
-----------

sky.launch
Expand Down Expand Up @@ -53,9 +53,6 @@ sky.autostop
.. autofunction:: sky.autostop


Job Queue API
-----------------

sky.queue
~~~~~~~~~~

Expand Down Expand Up @@ -84,29 +81,29 @@ sky.cancel
.. autofunction:: sky.cancel


Managed Spot Jobs API
Managed (Spot) Jobs API
-----------------------

sky.spot_launch
sky.jobs.launch
~~~~~~~~~~~~~~~~~

.. autofunction:: sky.spot_launch
.. autofunction:: sky.jobs.launch

sky.spot_queue
sky.jobs.queue
~~~~~~~~~~~~~~~

.. autofunction:: sky.spot_queue
.. autofunction:: sky.jobs.queue

sky.spot_cancel
sky.jobs.cancel
~~~~~~~~~~~~~~~~~

.. autofunction:: sky.spot_cancel
.. autofunction:: sky.jobs.cancel


sky.spot_tail_logs
sky.jobs.tail_logs
~~~~~~~~~~~~~~~~~~

.. autofunction:: sky.spot_tail_logs
.. autofunction:: sky.jobs.tail_logs

.. _sky-dag-ref:

Expand Down
19 changes: 15 additions & 4 deletions sky/api/requests/decoders.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
"""Handlers for the REST API return values."""
import base64
import pickle
from typing import Any, Dict, List

from sky import backends
from sky.utils import status_lib

handlers: Dict[str, Any] = {}


def _decode_and_unpickle(obj: str) -> Any:
    """Decode a base64 string and unpickle the resulting bytes.

    Args:
        obj: Base64 text wrapping a pickled payload.

    Returns:
        The object reconstructed by ``pickle.loads``.
    """
    # SECURITY: pickle.loads can execute arbitrary code while deserializing.
    # Only feed this function payloads that come from a trusted source; a
    # safer wire format should replace pickle if that cannot be guaranteed.
    payload = base64.b64decode(obj.encode('utf-8'))
    return pickle.loads(payload)


def register_handler(name: str):
"""Decorator to register a handler."""

Expand All @@ -32,9 +37,15 @@ def default_decode_handler(return_value: Any) -> Any:
def decode_status(return_value: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
clusters = return_value
for cluster in clusters:
# TODO(zhwu): We should make backends.ResourceHandle serializable.
cluster['handle'] = backends.CloudVmRayResourceHandle.from_config(
cluster['handle'])
cluster['handle'] = _decode_and_unpickle(cluster['handle'])
cluster['status'] = status_lib.ClusterStatus(cluster['status'])

return clusters


@register_handler('launch')
def decode_launch(return_value: Dict[str, Any]) -> Dict[str, Any]:
    """Rebuild the ``launch`` result from its serialized form.

    ``job_id`` is passed through unchanged, while ``handle`` is base64-decoded
    and unpickled back into an object.
    """
    job_id = return_value['job_id']
    handle = _decode_and_unpickle(return_value['handle'])
    return {'job_id': job_id, 'handle': handle}
25 changes: 23 additions & 2 deletions sky/api/requests/encoders.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
"""Handlers for the REST API return values."""
from typing import Any, Dict, List
import base64
import pickle
import typing
from typing import Any, Dict, List, Optional, Tuple

if typing.TYPE_CHECKING:
from sky import backends

handlers: Dict[str, Any] = {}


def _pickle_and_encode(obj: Any) -> str:
    """Pickle ``obj`` and return the bytes as base64-encoded text."""
    pickled = pickle.dumps(obj)
    encoded = base64.b64encode(pickled)
    return encoded.decode('utf-8')


def register_handler(name: str):
"""Decorator to register a handler."""

Expand All @@ -29,5 +39,16 @@ def default_handler(return_value: Any) -> Any:
def encode_status(clusters: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
for cluster in clusters:
cluster['status'] = cluster['status'].value
cluster['handle'] = cluster['handle'].to_config()
cluster['handle'] = _pickle_and_encode(cluster['handle'])
return clusters


@register_handler('launch')
def encode_launch(
    job_id_handle: Tuple[Optional[int], Optional['backends.ResourceHandle']]
) -> Dict[str, Any]:
    """Turn the ``launch`` result into a dict of serializable values.

    The input is a ``(job_id, handle)`` pair. ``job_id`` is kept as is and
    ``handle`` is pickled and base64-encoded into a string.
    """
    job_id, handle = job_id_handle
    encoded: Dict[str, Any] = {'job_id': job_id}
    encoded['handle'] = _pickle_and_encode(handle)
    return encoded
5 changes: 2 additions & 3 deletions sky/api/requests/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,16 +93,15 @@ def from_row(cls, row: Tuple[Any, ...]) -> 'RequestTask':
entrypoint=row[2],
request_body=json.loads(row[3]),
status=RequestStatus(row[4]),
return_value=pickle.loads(row[5]),
return_value=json.loads(row[5]),
error=json.loads(row[6]),
pid=row[7],
)

def to_row(self) -> Tuple[Any, ...]:
return (self.request_id, self.name, self.entrypoint,
json.dumps(self.request_body), self.status.value,
pickle.dumps(self.return_value), json.dumps(self.error),
self.pid)
json.dumps(self.return_value), json.dumps(self.error), self.pid)


_DB_PATH = os.path.expanduser('~/.sky/api_server/tasks.db')
Expand Down
16 changes: 13 additions & 3 deletions sky/api/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,11 @@ async def refresh_cluster_status_event():
refresh_cluster_status_event,
]

class RequestBody(pydantic.BaseModel):
    """Base model for request bodies that carry environment variables."""
    # Environment variables sent along with the request; empty by default.
    env_vars: Dict[str, str] = {}

def wrapper(func: Callable[P, Any], request_id: str, *args: P.args,

def wrapper(func: Callable[P, Any], request_id: str, env_vars: Dict[str, str], *args: P.args,
**kwargs: P.kwargs):
"""Wrapper for a request task."""

Expand Down Expand Up @@ -99,6 +102,11 @@ def restore_output(original_stdout, original_stderr):
# Store copies of the original stdout and stderr file descriptors
original_stdout, original_stderr = redirect_output(f)
try:
os.environ.update(env_vars)
# Make sure the logger takes the new environment variables. This is
# necessary because the logger is initialized before the environment
# variables are set, such as SKYPILOT_DEBUG.
sky_logging.reload_logger()
return_value = func(*args, **kwargs)
except Exception as e: # pylint: disable=broad-except
with tasks.update_rest_task(request_id) as request_task:
Expand Down Expand Up @@ -130,6 +138,7 @@ def _start_background_request(request_id: str, request_name: str,
status=tasks.RequestStatus.PENDING)
tasks.dump_reqest(request_task)
request_task.log_path.touch()
kwargs['env_vars'] = request_body.get('env_vars', {})
process = multiprocessing.Process(target=wrapper,
args=(func, request_id, *args),
kwargs=kwargs)
Expand All @@ -142,7 +151,7 @@ async def startup():
asyncio.create_task(event())


class LaunchBody(pydantic.BaseModel):
class LaunchBody(RequestBody):
"""The request body for the launch endpoint."""
task: str
cluster_name: Optional[str] = None
Expand All @@ -153,6 +162,7 @@ class LaunchBody(pydantic.BaseModel):
backend: Optional[str] = None
optimize_target: optimizer.OptimizeTarget = optimizer.OptimizeTarget.COST
detach_setup: bool = False
detach_run: bool = False
no_setup: bool = False
clone_disk_from: Optional[str] = None
# Internal only:
Expand Down Expand Up @@ -190,7 +200,7 @@ async def launch(launch_body: LaunchBody, request: fastapi.Request):
backend=backend,
optimize_target=launch_body.optimize_target,
detach_setup=launch_body.detach_setup,
detach_run=True,
detach_run=launch_body.detach_run,
no_setup=launch_body.no_setup,
clone_disk_from=launch_body.clone_disk_from,
_is_launched_by_jobs_controller=launch_body.
Expand Down
12 changes: 11 additions & 1 deletion sky/api/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,15 @@ def wrapper(*args, **kwargs):
return wrapper


def _add_env_vars_to_body(body: Dict[str, Any]) -> None:
    """Attach the local ``SKYPILOT_*`` environment variables to ``body``.

    Every variable in ``os.environ`` whose name starts with ``SKYPILOT_`` is
    copied into a fresh dict stored under ``body['env_vars']``, replacing any
    value already present under that key.

    Args:
        body: Request payload to extend. It is mutated in place.
    """
    body['env_vars'] = {
        name: value
        for name, value in os.environ.items()
        if name.startswith('SKYPILOT_')
    }



@usage_lib.entrypoint
@_check_health
def launch(
Expand Down Expand Up @@ -146,6 +155,7 @@ def launch(
'_is_launched_by_sky_serve_controller': _is_launched_by_sky_serve_controller,
'_disable_controller_check': _disable_controller_check,
}
_add_env_vars_to_body(body)
response = requests.post(
f'{_get_server_url()}/launch',
json=body,
Expand Down Expand Up @@ -308,4 +318,4 @@ def api_logs(follow: bool = True, tail: str = 'all'):
except ValueError as e:
raise ValueError(f'Invalid tail argument: {tail}') from e
log_path = os.path.expanduser(constants.API_SERVER_LOGS)
subprocess.run(['tail', *tail_args, log_path], check=False)
subprocess.run(['tail', *tail_args, f'{log_path}'], check=False)
31 changes: 0 additions & 31 deletions sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2585,37 +2585,6 @@ def __setstate__(self, state):
# so the head IP in the database is not updated.
pass

def to_config(self) -> Dict[str, Any]:
result: Dict[str, Any] = {}
result['cluster_name'] = self.cluster_name
result['cluster_name_on_cloud'] = self.cluster_name_on_cloud
result['cluster_yaml'] = self.cluster_yaml
result[
'stable_internal_external_ips'] = self.stable_internal_external_ips
result['stable_ssh_ports'] = self.stable_ssh_ports
result['launched_nodes'] = self.launched_nodes
result['launched_resources'] = self.launched_resources.to_yaml_config()
result['docker_user'] = self.docker_user
result['tpu_create_script'] = self.tpu_create_script
result['tpu_delete_script'] = self.tpu_delete_script
return result

@classmethod
def from_config(cls, config: dict) -> 'CloudVmRayResourceHandle':
result = cls(
cluster_name=config['cluster_name'],
cluster_name_on_cloud=config['cluster_name_on_cloud'],
cluster_yaml=config['cluster_yaml'],
launched_nodes=config['launched_nodes'],
launched_resources=resources_lib.Resources.from_yaml_config(
config['launched_resources']),
stable_internal_external_ips=config['stable_internal_external_ips'],
stable_ssh_ports=config['stable_ssh_ports'],
tpu_create_script=config['tpu_create_script'],
tpu_delete_script=config['tpu_delete_script'])
result.docker_user = config['docker_user']
return result


@registry.BACKEND_REGISTRY.register
class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']):
Expand Down

0 comments on commit 5b91fd3

Please sign in to comment.