diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index f3d855d479b..7f490743f8b 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -150,6 +150,18 @@ _MAX_INLINE_SCRIPT_LENGTH = 120 * 1024 +def _is_command_length_over_limit(command: str) -> bool: + """Check if the length of the command exceeds the limit. + + We calculate the length of the command after quoting the command twice as + when it is executed by the CommandRunner, the command will be quoted twice + to ensure the correctness, which will add significant length to the command. + """ + + quoted_length = len(shlex.quote(shlex.quote(command))) + return quoted_length > _MAX_INLINE_SCRIPT_LENGTH + + def _get_cluster_config_template(cloud): cloud_to_template = { clouds.AWS: 'aws-ray.yml.j2', @@ -3159,8 +3171,7 @@ def _setup_node(node_id: int) -> None: setup_script = log_lib.make_task_bash_script(setup, env_vars=setup_envs) encoded_script = shlex.quote(setup_script) - if (detach_setup or - len(encoded_script) > _MAX_INLINE_SCRIPT_LENGTH): + if detach_setup or _is_command_length_over_limit(encoded_script): with tempfile.NamedTemporaryFile('w', prefix='sky_setup_') as f: f.write(setup_script) f.flush() @@ -3271,7 +3282,7 @@ def _exec_code_on_head( code = job_lib.JobLibCodeGen.queue_job(job_id, job_submit_cmd) job_submit_cmd = ' && '.join([mkdir_code, create_script_code, code]) - if len(job_submit_cmd) > _MAX_INLINE_SCRIPT_LENGTH: + if _is_command_length_over_limit(job_submit_cmd): runners = handle.get_command_runners() head_runner = runners[0] with tempfile.NamedTemporaryFile('w', prefix='sky_app_') as fp: