[17/N][VirtualCluster] Support exclusive virtual cluster in Job Submission (#435)

* [17/N][VirtualCluster] Enhance Job Submission with Job Cluster and Replica Set Support

Signed-off-by: sule <[email protected]>

* Fix naming and comments

Signed-off-by: sule <[email protected]>

* Add logs for error cases

Signed-off-by: sule <[email protected]>

---------

Signed-off-by: sule <[email protected]>
xsuler authored Dec 24, 2024
1 parent fa2ae24 commit f924c52
Showing 12 changed files with 385 additions and 8 deletions.
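Taken together, these changes let a client submit a job into an exclusive virtual cluster and ask for a per-template set of nodes (a "job cluster") to be carved out for it. A minimal usage sketch against the new SDK signature (the address, virtual cluster ID, and template ID are illustrative placeholders; the usual ray.job_submission entry point is assumed):

from ray.job_submission import JobSubmissionClient

client = JobSubmissionClient("http://127.0.0.1:8265")
submission_id = client.submit_job(
    entrypoint="python my_script.py",
    virtual_cluster_id="VIRTUAL_CLUSTER_0",
    # Template ID -> number of replicas (nodes) requested for the job cluster.
    replica_sets={"template_id_0": 2},
)
print(submission_id)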
19 changes: 19 additions & 0 deletions python/ray/dashboard/modules/job/common.py
@@ -428,6 +428,8 @@ class JobSubmitRequest:
entrypoint_resources: Optional[Dict[str, float]] = None
# Optional virtual cluster ID for job.
virtual_cluster_id: Optional[str] = None
# Optional replica sets (template ID -> replica count) for the job.
replica_sets: Optional[Dict[str, int]] = None

def __post_init__(self):
if not isinstance(self.entrypoint, str):
@@ -521,6 +523,23 @@ def __post_init__(self):
f"got {type(self.virtual_cluster_id)}"
)

        if self.replica_sets is not None:
            if not isinstance(self.replica_sets, dict):
                raise TypeError(
                    f"replica_sets must be a dict, got {type(self.replica_sets)}"
                )
            for k in self.replica_sets.keys():
                if not isinstance(k, str):
                    raise TypeError(
                        f"replica_sets keys must be strings, got {type(k)}"
                    )
            for v in self.replica_sets.values():
                if not isinstance(v, int):
                    raise TypeError(
                        f"replica_sets values must be integers, got {type(v)}"
                    )


@dataclass
class JobSubmitResponse:
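A quick illustration of the validation above, as a hedged sketch (the entrypoint and template ID are placeholders, and the dataclass's other fields are left at their defaults):

from ray.dashboard.modules.job.common import JobSubmitRequest

# Accepted: string template IDs mapped to integer replica counts.
JobSubmitRequest(
    entrypoint="python my_script.py", replica_sets={"template_id_0": 2}
)

# Rejected in __post_init__: a non-integer replica count.
try:
    JobSubmitRequest(
        entrypoint="python my_script.py", replica_sets={"template_id_0": "2"}
    )
except TypeError as err:
    print(err)  # replica_sets values must be integers, got <class 'str'>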
48 changes: 48 additions & 0 deletions python/ray/dashboard/modules/job/job_head.py
@@ -21,6 +21,8 @@
upload_package_to_gcs,
)
from ray._private.utils import get_or_create_event_loop
from ray.core.generated import gcs_service_pb2_grpc
from ray.core.generated.gcs_service_pb2 import CreateJobClusterRequest
from ray.dashboard.datacenter import DataOrganizer
from ray.dashboard.modules.job.common import (
JobDeleteResponse,
@@ -31,6 +33,7 @@
JobSubmitResponse,
http_uri_components_to_uri,
)
from ray.dashboard.modules.job.job_manager import generate_job_id
from ray.dashboard.modules.job.pydantic_models import JobDetails, JobType
from ray.dashboard.modules.job.utils import (
find_job_by_ids,
@@ -163,6 +166,12 @@ def __init__(self, dashboard_head):
self._gcs_aio_client = dashboard_head.gcs_aio_client
self._job_info_client = None

self._gcs_virtual_cluster_info_stub = (
gcs_service_pb2_grpc.VirtualClusterInfoGcsServiceStub(
dashboard_head.aiogrpc_gcs_channel
)
)

# It contains all `JobAgentSubmissionClient` that
# `JobHead` has ever used, and will not be deleted
# from it unless `JobAgentSubmissionClient` is no
@@ -340,6 +349,30 @@ async def submit_job(self, req: Request) -> Response:
self.get_target_agent(),
timeout=dashboard_consts.WAIT_AVAILABLE_AGENT_TIMEOUT,
)

if (
submit_request.virtual_cluster_id is not None
and submit_request.replica_sets is not None
and len(submit_request.replica_sets) > 0
):
                # Use the provided submission ID (or the job ID); generate
                # one only if both are absent
submission_id = submit_request.submission_id or submit_request.job_id
if submission_id is None:
submit_request.submission_id = generate_job_id()
job_cluster_id = await self._create_job_cluster(
submit_request.submission_id,
submit_request.virtual_cluster_id,
submit_request.replica_sets,
)
                # If job cluster creation failed, return an internal error
if job_cluster_id is None:
return Response(
text="Create Job Cluster Failed.",
status=aiohttp.web.HTTPInternalServerError.status_code,
)
                # Overwrite the virtual cluster ID in the submit request
                # with the newly created job cluster ID
submit_request.virtual_cluster_id = job_cluster_id

resp = await job_agent_client.submit_job_internal(submit_request)
except asyncio.TimeoutError:
return Response(
@@ -580,6 +613,21 @@ def get_job_driver_agent_client(

return self._agents[driver_node_id]

async def _create_job_cluster(self, job_id, virtual_cluster_id, replica_sets):
request = CreateJobClusterRequest(
job_id=job_id,
virtual_cluster_id=virtual_cluster_id,
replica_sets=replica_sets,
)
        reply = await self._gcs_virtual_cluster_info_stub.CreateJobCluster(request)
        if reply.status.code != 0:
            logger.warning(
                f"Failed to create job cluster for {job_id} in"
                f" {virtual_cluster_id}, message: {reply.status.message}"
            )
return None
return reply.job_cluster_id

async def run(self, server):
if not self._job_info_client:
self._job_info_client = JobInfoStorageClient(self._gcs_aio_client)
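Because the handler above builds JobSubmitRequest straight from the HTTP body, the same fields can be supplied over REST as well. A hedged sketch against the standard /api/jobs/ submission endpoint (the address and IDs are placeholders):

import requests

resp = requests.post(
    "http://127.0.0.1:8265/api/jobs/",
    json={
        "entrypoint": "python my_script.py",
        "virtual_cluster_id": "VIRTUAL_CLUSTER_0",
        "replica_sets": {"template_id_0": 2},
    },
)
# On success the reply carries the submission ID; if job cluster creation
# fails, the head answers HTTP 500 with "Create Job Cluster Failed."
print(resp.status_code, resp.text)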
2 changes: 2 additions & 0 deletions python/ray/dashboard/modules/job/sdk.py
@@ -131,6 +131,7 @@ def submit_job(
metadata: Optional[Dict[str, str]] = None,
submission_id: Optional[str] = None,
virtual_cluster_id: Optional[str] = None,
replica_sets: Optional[Dict[str, int]] = None,
entrypoint_num_cpus: Optional[Union[int, float]] = None,
entrypoint_num_gpus: Optional[Union[int, float]] = None,
entrypoint_memory: Optional[int] = None,
@@ -231,6 +232,7 @@
entrypoint=entrypoint,
submission_id=submission_id,
virtual_cluster_id=virtual_cluster_id,
replica_sets=replica_sets,
runtime_env=runtime_env,
metadata=metadata,
entrypoint_num_cpus=entrypoint_num_cpus,
@@ -31,6 +31,7 @@
from ray.tests.conftest import get_default_fixture_ray_kwargs

TEMPLATE_ID_PREFIX = "template_id_"
kPrimaryClusterID = "kPrimaryClusterID"

logger = logging.getLogger(__name__)

@@ -131,11 +132,11 @@ async def create_virtual_cluster(
[
{
"_system_config": {"gcs_actor_scheduling_enabled": False},
"ntemplates": 5,
"ntemplates": 3,
},
{
"_system_config": {"gcs_actor_scheduling_enabled": True},
"ntemplates": 5,
"ntemplates": 3,
},
],
indirect=True,
Expand All @@ -145,7 +146,7 @@ async def test_mixed_virtual_cluster(job_sdk_client):
head_client, gcs_address, cluster = job_sdk_client
virtual_cluster_id_prefix = "VIRTUAL_CLUSTER_"
node_to_virtual_cluster = {}
-    ntemplates = 5
+    ntemplates = 3
for i in range(ntemplates):
virtual_cluster_id = virtual_cluster_id_prefix + str(i)
nodes = await create_virtual_cluster(
@@ -340,5 +341,229 @@ def _check_recover(
head_client.stop_job(job_id)


@pytest.mark.parametrize(
"job_sdk_client",
[
{
"_system_config": {"gcs_actor_scheduling_enabled": False},
"ntemplates": 4,
},
{
"_system_config": {"gcs_actor_scheduling_enabled": True},
"ntemplates": 4,
},
],
indirect=True,
)
@pytest.mark.asyncio
async def test_exclusive_virtual_cluster(job_sdk_client):
head_client, gcs_address, cluster = job_sdk_client
virtual_cluster_id_prefix = "VIRTUAL_CLUSTER_"
node_to_virtual_cluster = {}
ntemplates = 3
for i in range(ntemplates):
virtual_cluster_id = virtual_cluster_id_prefix + str(i)
nodes = await create_virtual_cluster(
gcs_address,
virtual_cluster_id,
{TEMPLATE_ID_PREFIX + str(i): 2},
AllocationMode.EXCLUSIVE,
)
for node_id in nodes:
assert node_id not in node_to_virtual_cluster
node_to_virtual_cluster[node_id] = virtual_cluster_id

for node in cluster.worker_nodes:
if node.node_id not in node_to_virtual_cluster:
node_to_virtual_cluster[node.node_id] = kPrimaryClusterID

@ray.remote
class ControlActor:
def __init__(self):
self._nodes = set()
self._ready = False

def ready(self):
self._ready = True

def is_ready(self):
return self._ready

def add_node(self, node_id):
self._nodes.add(node_id)

def nodes(self):
return self._nodes

for i in range(ntemplates + 1):
actor_name = f"test_actors_{i}"
pg_name = f"test_pgs_{i}"
control_actor_name = f"control_{i}"
virtual_cluster_id = virtual_cluster_id_prefix + str(i)
if i == ntemplates:
virtual_cluster_id = kPrimaryClusterID
control_actor = ControlActor.options(
name=control_actor_name, namespace="control"
).remote()
with tempfile.TemporaryDirectory() as tmp_dir:
path = Path(tmp_dir)
driver_script = """
import ray
import time
import asyncio
ray.init(address="auto")
control = ray.get_actor(name="{control_actor_name}", namespace="control")
@ray.remote(max_restarts=10)
class Actor:
def __init__(self, control, pg):
node_id = ray.get_runtime_context().get_node_id()
ray.get(control.add_node.remote(node_id))
self._pg = pg
async def run(self, control):
node_id = ray.get_runtime_context().get_node_id()
await control.add_node.remote(node_id)
while True:
node_id = ray.util.placement_group_table(self._pg)["bundles_to_node_id"][0]
if node_id == "":
await asyncio.sleep(1)
continue
break
await control.add_node.remote(node_id)
await control.ready.remote()
while True:
await asyncio.sleep(1)
async def get_node_id(self):
while True:
node_id = ray.util.placement_group_table(pg)["bundles_to_node_id"][0]
if node_id == "":
await asyncio.sleep(1)
continue
break
return (ray.get_runtime_context().get_node_id(), node_id)
pg = ray.util.placement_group(
bundles=[{{"CPU": 1}}], name="{pg_name}", lifetime="detached"
)
@ray.remote
def hello(control):
node_id = ray.get_runtime_context().get_node_id()
ray.get(control.add_node.remote(node_id))
ray.get(hello.remote(control))
a = Actor.options(name="{actor_name}",
namespace="control",
num_cpus=1,
lifetime="detached").remote(
control, pg
)
ray.get(a.run.remote(control))
"""
driver_script = driver_script.format(
actor_name=actor_name,
pg_name=pg_name,
control_actor_name=control_actor_name,
)
test_script_file = path / "test_script.py"
with open(test_script_file, "w+") as file:
file.write(driver_script)

runtime_env = {"working_dir": tmp_dir}
runtime_env = upload_working_dir_if_needed(
runtime_env, tmp_dir, logger=logger
)
runtime_env = RuntimeEnv(**runtime_env).to_dict()

job_id = head_client.submit_job(
entrypoint="python test_script.py",
entrypoint_memory=1,
runtime_env=runtime_env,
virtual_cluster_id=virtual_cluster_id,
replica_sets={TEMPLATE_ID_PREFIX + str(i): 2},
)

def _check_ready(control_actor):
return ray.get(control_actor.is_ready.remote())

wait_for_condition(partial(_check_ready, control_actor), timeout=20)

def _check_virtual_cluster(
control_actor, node_to_virtual_cluster, virtual_cluster_id
):
nodes = ray.get(control_actor.nodes.remote())
assert len(nodes) > 0
for node in nodes:
assert node_to_virtual_cluster[node] == virtual_cluster_id
return True

wait_for_condition(
partial(
_check_virtual_cluster,
control_actor,
node_to_virtual_cluster,
virtual_cluster_id,
),
timeout=20,
)

supervisor_actor = ray.get_actor(
name=JOB_ACTOR_NAME_TEMPLATE.format(job_id=job_id),
namespace=SUPERVISOR_ACTOR_RAY_NAMESPACE,
)
actor_info = ray.state.actors(supervisor_actor._actor_id.hex())
driver_node_id = actor_info["Address"]["NodeID"]
assert node_to_virtual_cluster[driver_node_id] == virtual_cluster_id

job_info = head_client.get_job_info(job_id)
assert (
node_to_virtual_cluster[job_info.driver_node_id] == virtual_cluster_id
)

nodes_to_remove = ray.get(control_actor.nodes.remote())
if driver_node_id in nodes_to_remove:
nodes_to_remove.remove(driver_node_id)

to_remove = []
for node in cluster.worker_nodes:
if node.node_id in nodes_to_remove:
to_remove.append(node)
for node in to_remove:
cluster.remove_node(node)

def _check_recover(
nodes_to_remove, actor_name, node_to_virtual_cluster, virtual_cluster_id
):
actor = ray.get_actor(actor_name, namespace="control")
nodes = ray.get(actor.get_node_id.remote())
for node_id in nodes:
assert node_id not in nodes_to_remove
assert node_to_virtual_cluster[node_id] == virtual_cluster_id
return True

wait_for_condition(
partial(
_check_recover,
nodes_to_remove,
actor_name,
node_to_virtual_cluster,
virtual_cluster_id,
),
timeout=120,
)
head_client.stop_job(job_id)


if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__]))