Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query cloud specific env vars in task setup #2347

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions sky/adaptors/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,8 @@ def config_exception():
@import_package
def max_retry_error():
    """Return the ``urllib3.exceptions.MaxRetryError`` exception class."""
    # NOTE(review): presumably @import_package makes sure the required
    # packages are importable before this body runs -- confirm against the
    # decorator's definition.
    return urllib3.exceptions.MaxRetryError


@import_package
def stream():
    """Return the ``kubernetes.stream.stream`` callable."""
    # NOTE(review): presumably @import_package makes sure the kubernetes
    # package is importable before this body runs -- confirm against the
    # decorator's definition.
    return kubernetes.stream.stream
16 changes: 16 additions & 0 deletions sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2822,12 +2822,27 @@ def _sync_file_mounts(
self._execute_file_mounts(handle, all_file_mounts)
self._execute_storage_mounts(handle, storage_mounts)

def _update_envs_for_k8s(self, handle: CloudVmRayResourceHandle,
                         task: task_lib.Task) -> None:
    """Merge Kubernetes-provided env vars into the task's envs.

    Does nothing unless the cluster behind ``handle`` was launched on
    Kubernetes. Env vars already set on the task take priority over the
    ones queried from the cloud.
    """
    cloud = handle.launched_resources.cloud
    if not isinstance(cloud, clouds.Kubernetes):
        return

    # Snapshot the task's own envs before the cloud values are merged in.
    user_envs = copy.deepcopy(task.envs)
    task.update_envs(cloud.query_env_vars(handle.cluster_name))

    # Re-apply the snapshot last so the task's original values win over
    # anything the cloud query returned under the same key.
    task.update_envs(user_envs)

def _setup(self, handle: CloudVmRayResourceHandle, task: task_lib.Task,
detach_setup: bool) -> None:
start = time.time()
style = colorama.Style
fore = colorama.Fore

self._update_envs_for_k8s(handle, task)

if task.setup is None:
return

Expand Down Expand Up @@ -3138,6 +3153,7 @@ def _execute(
# Check the task resources vs the cluster resources. Since `sky exec`
# will not run the provision and _check_existing_cluster
self.check_resources_fit_cluster(handle, task)
self._update_envs_for_k8s(handle, task)

resources_str = backend_utils.get_task_resources_str(task)

Expand Down
23 changes: 23 additions & 0 deletions sky/clouds/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,3 +406,26 @@ def query_status(cls, name: str, tag_filters: Dict[str, str],
cluster_status.append(status_lib.ClusterStatus.INIT)
# If pods are not found, we don't add them to the return list
return cluster_status

@classmethod
def query_env_vars(cls, name: str) -> Dict[str, str]:
    """Return the env vars visible inside the cluster's head pod.

    Runs ``env`` in the pod labelled as the Ray head node of cluster
    ``name`` and parses its ``KEY=VALUE`` output lines.

    Args:
        name: Cluster name, matched against the ``skypilot-cluster`` pod
            label.

    Returns:
        Mapping of env var name to value. Lines without ``=`` and entries
        whose name fails ``common_utils.is_valid_env_var`` are dropped.

    Raises:
        IndexError: If no head pod is found for the cluster.
    """
    namespace = kubernetes_utils.get_current_kube_config_context_namespace()
    head_pods = kubernetes.core_api().list_namespaced_pod(
        namespace,
        label_selector=f'skypilot-cluster={name},ray-node-type=head',
        # Same timeout as the exec call below, so neither request can hang
        # indefinitely.
        _request_timeout=kubernetes.API_TIMEOUT).items
    if not head_pods:
        # Same exception type the bare `.items[0]` lookup used to raise,
        # but with enough context to debug a missing head pod.
        raise IndexError(f'No head pod found for cluster {name!r} in '
                         f'namespace {namespace!r}.')
    response = kubernetes.stream()(
        kubernetes.core_api().connect_get_namespaced_pod_exec,
        head_pods[0].metadata.name,
        namespace,
        command=['env'],
        stderr=True,
        stdin=False,
        stdout=True,
        tty=False,
        # Requested in review: bound the stream call with the shared API
        # timeout, as done for other Kubernetes API calls in the codebase.
        _request_timeout=kubernetes.API_TIMEOUT)
    env_vars: Dict[str, str] = {}
    for line in response.split('\n'):
        # Split on the first '=' only; values may themselves contain '='.
        key, sep, value = line.partition('=')
        if sep and common_utils.is_valid_env_var(key):
            env_vars[key] = value
    return env_vars
9 changes: 2 additions & 7 deletions sky/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from sky.skylet import constants
from sky.utils import schemas
from sky.utils import ux_utils
from sky.utils import common_utils

if typing.TYPE_CHECKING:
from sky import resources as resources_lib
Expand All @@ -27,7 +28,6 @@
CommandOrCommandGen = Union[str, CommandGen]

_VALID_NAME_REGEX = '[a-z0-9]+(?:[._-]{1,2}[a-z0-9]+)*'
_VALID_ENV_VAR_REGEX = '[a-zA-Z_][a-zA-Z0-9_]*'
_VALID_NAME_DESCR = ('ASCII characters and may contain lowercase and'
' uppercase letters, digits, underscores, periods,'
' and dashes. Must start and end with alphanumeric'
Expand Down Expand Up @@ -64,11 +64,6 @@ def _is_valid_name(name: str) -> bool:
return bool(re.fullmatch(_VALID_NAME_REGEX, name))


def _is_valid_env_var(name: str) -> bool:
    """Return True iff ``name`` fully matches the env var name pattern."""
    matched = re.fullmatch(_VALID_ENV_VAR_REGEX, name)
    return matched is not None


def _fill_in_env_vars_in_file_mounts(
file_mounts: Dict[str, Any],
task_envs: Dict[str, str],
Expand Down Expand Up @@ -446,7 +441,7 @@ def update_envs(
if not isinstance(key, str):
with ux_utils.print_exception_no_traceback():
raise ValueError('Env keys must be strings.')
if not _is_valid_env_var(key):
if not common_utils.is_valid_env_var(key):
with ux_utils.print_exception_no_traceback():
raise ValueError(f'Invalid env key: {key}')
else:
Expand Down
7 changes: 7 additions & 0 deletions sky/utils/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
_PAYLOAD_PATTERN = re.compile(r'<sky-payload>(.*)</sky-payload>')
_PAYLOAD_STR = '<sky-payload>{}</sky-payload>'

_VALID_ENV_VAR_REGEX = '[a-zA-Z_][a-zA-Z0-9_]*'

logger = sky_logging.init_logger(__name__)

_usage_run_id = None
Expand Down Expand Up @@ -409,3 +411,8 @@ def find_free_port(start_port: int) -> int:
except OSError:
pass
raise OSError('No free ports available.')


def is_valid_env_var(name: str) -> bool:
    """Return True iff ``name`` fully matches the env var name pattern."""
    matched = re.fullmatch(_VALID_ENV_VAR_REGEX, name)
    return matched is not None