Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[17/N][VirtualCluster] Support exclusive virtual cluster in Job Submission #435

Merged
merged 3 commits into from
Dec 24, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
[17/N][VirtualCluster] Enhance Job Submission with Job Cluster and Replica Set Support

Signed-off-by: sule <[email protected]>
xsuler committed Dec 24, 2024
commit 06d2fa631bbe379a251c09f302cc908d4f670646
19 changes: 19 additions & 0 deletions python/ray/dashboard/modules/job/common.py
Original file line number Diff line number Diff line change
@@ -428,6 +428,8 @@ class JobSubmitRequest:
entrypoint_resources: Optional[Dict[str, float]] = None
# Optional virtual cluster ID for job.
virtual_cluster_id: Optional[str] = None
# Optional replica sets for job
replica_sets: Optional[Dict[str, int]] = None

def __post_init__(self):
if not isinstance(self.entrypoint, str):
@@ -521,6 +523,23 @@ def __post_init__(self):
f"got {type(self.virtual_cluster_id)}"
)

if self.replica_sets is not None:
if not isinstance(self.replica_sets, dict):
raise TypeError(
"replica_sets must be a dict, " f"got {type(self.replica_sets)}"
)
else:
for k in self.replica_sets.keys():
if not isinstance(k, str):
raise TypeError(
"replica_sets keys must be strings, " f"got {type(k)}"
)
for v in self.replica_sets.values():
if not isinstance(v, int):
raise TypeError(
"replica_sets values must be integers, " f"got {type(v)}"
)


@dataclass
class JobSubmitResponse:
47 changes: 47 additions & 0 deletions python/ray/dashboard/modules/job/job_head.py
Original file line number Diff line number Diff line change
@@ -21,6 +21,8 @@
upload_package_to_gcs,
)
from ray._private.utils import get_or_create_event_loop
from ray.core.generated import gcs_service_pb2_grpc
from ray.core.generated.gcs_service_pb2 import CreateJobClusterRequest
from ray.dashboard.datacenter import DataOrganizer
from ray.dashboard.modules.job.common import (
JobDeleteResponse,
@@ -31,6 +33,7 @@
JobSubmitResponse,
http_uri_components_to_uri,
)
from ray.dashboard.modules.job.job_manager import generate_job_id
from ray.dashboard.modules.job.pydantic_models import JobDetails, JobType
from ray.dashboard.modules.job.utils import (
find_job_by_ids,
@@ -163,6 +166,12 @@ def __init__(self, dashboard_head):
self._gcs_aio_client = dashboard_head.gcs_aio_client
self._job_info_client = None

self._gcs_virtual_cluster_info_stub = (
gcs_service_pb2_grpc.VirtualClusterInfoGcsServiceStub(
dashboard_head.aiogrpc_gcs_channel
)
)

# It contains all `JobAgentSubmissionClient` that
# `JobHead` has ever used, and will not be deleted
# from it unless `JobAgentSubmissionClient` is no
@@ -340,6 +349,30 @@ async def submit_job(self, req: Request) -> Response:
self.get_target_agent(),
timeout=dashboard_consts.WAIT_AVAILABLE_AGENT_TIMEOUT,
)

if (
submit_request.virtual_cluster_id is not None
and submit_request.replica_sets is not None
and len(submit_request.replica_sets) > 0
):
# Use the submission ID or generate a new one
submission_id = submit_request.submission_id or submit_request.job_id
if submission_id is None:
submit_request.submission_id = generate_job_id()
job_cluster_id = await self._create_job_cluster(
submit_request.submission_id,
submit_request.virtual_cluster_id,
submit_request.replica_sets,
)
# If cluster creation fails
if job_cluster_id is None:
return Response(
text="Create Job Cluster Failed.",
status=aiohttp.web.HTTPInternalServerError.status_code,
)
# Overwrite the virtual cluster ID in submit request
submit_request.virtual_cluster_id = job_cluster_id

resp = await job_agent_client.submit_job_internal(submit_request)
except asyncio.TimeoutError:
return Response(
@@ -580,6 +613,20 @@ def get_job_driver_agent_client(

return self._agents[driver_node_id]

async def _create_job_cluster(self, job_id, virtual_cluster_id, replica_sets):
    """Ask GCS to create a job-scoped cluster inside a virtual cluster.

    Args:
        job_id: Submission id of the job the cluster is created for.
        virtual_cluster_id: Parent virtual cluster to carve the job cluster from.
        replica_sets: Mapping of template id -> desired replica count.

    Returns:
        The id of the newly created job cluster, or ``None`` if GCS
        reported a non-zero status code.
    """
    reply = await self._gcs_virtual_cluster_info_stub.CreateJobCluster(
        CreateJobClusterRequest(
            job_id=job_id,
            virtual_cluster_id=virtual_cluster_id,
            replica_sets=replica_sets,
        )
    )
    # status.code == 0 is the GCS success convention; anything else is a failure.
    if reply.status.code == 0:
        return reply.job_cluster_id
    logger.warning(
        f"failed to create job cluster for {job_id} in {virtual_cluster_id}"
    )
    return None

async def run(self, server):
if not self._job_info_client:
self._job_info_client = JobInfoStorageClient(self._gcs_aio_client)
2 changes: 2 additions & 0 deletions python/ray/dashboard/modules/job/sdk.py
Original file line number Diff line number Diff line change
@@ -131,6 +131,7 @@ def submit_job(
metadata: Optional[Dict[str, str]] = None,
submission_id: Optional[str] = None,
virtual_cluster_id: Optional[str] = None,
replica_sets: Optional[Dict[str, int]] = None,
entrypoint_num_cpus: Optional[Union[int, float]] = None,
entrypoint_num_gpus: Optional[Union[int, float]] = None,
entrypoint_memory: Optional[int] = None,
@@ -231,6 +232,7 @@ def submit_job(
entrypoint=entrypoint,
submission_id=submission_id,
virtual_cluster_id=virtual_cluster_id,
replica_sets=replica_sets,
runtime_env=runtime_env,
metadata=metadata,
entrypoint_num_cpus=entrypoint_num_cpus,
Original file line number Diff line number Diff line change
@@ -31,6 +31,7 @@
from ray.tests.conftest import get_default_fixture_ray_kwargs

TEMPLATE_ID_PREFIX = "template_id_"
kPrimaryClusterID = "kPrimaryClusterID"

logger = logging.getLogger(__name__)

@@ -131,11 +132,11 @@ async def create_virtual_cluster(
[
{
"_system_config": {"gcs_actor_scheduling_enabled": False},
"ntemplates": 5,
"ntemplates": 3,
},
{
"_system_config": {"gcs_actor_scheduling_enabled": True},
"ntemplates": 5,
"ntemplates": 3,
},
],
indirect=True,
@@ -145,7 +146,7 @@ async def test_mixed_virtual_cluster(job_sdk_client):
head_client, gcs_address, cluster = job_sdk_client
virtual_cluster_id_prefix = "VIRTUAL_CLUSTER_"
node_to_virtual_cluster = {}
ntemplates = 5
ntemplates = 3
for i in range(ntemplates):
virtual_cluster_id = virtual_cluster_id_prefix + str(i)
nodes = await create_virtual_cluster(
@@ -340,5 +341,229 @@ def _check_recover(
head_client.stop_job(job_id)


@pytest.mark.parametrize(
    "job_sdk_client",
    [
        {
            "_system_config": {"gcs_actor_scheduling_enabled": False},
            "ntemplates": 4,
        },
        {
            "_system_config": {"gcs_actor_scheduling_enabled": True},
            "ntemplates": 4,
        },
    ],
    indirect=True,
)
@pytest.mark.asyncio
async def test_exclusive_virtual_cluster(job_sdk_client):
    """End-to-end check that jobs submitted with ``replica_sets`` into an
    EXCLUSIVE virtual cluster (or the primary cluster) only ever place their
    driver, tasks, actors and placement groups on nodes belonging to that
    cluster — and that they recover onto in-cluster nodes after node removal.
    """
    head_client, gcs_address, cluster = job_sdk_client
    virtual_cluster_id_prefix = "VIRTUAL_CLUSTER_"
    # Maps node id -> owning virtual cluster id, used for placement assertions.
    node_to_virtual_cluster = {}
    # NOTE(review): the fixture is parametrized with ntemplates=4 but only 3
    # exclusive clusters are created here; the 4th template's nodes fall
    # through to the primary cluster below — presumably intentional, confirm.
    ntemplates = 3
    for i in range(ntemplates):
        virtual_cluster_id = virtual_cluster_id_prefix + str(i)
        nodes = await create_virtual_cluster(
            gcs_address,
            virtual_cluster_id,
            {TEMPLATE_ID_PREFIX + str(i): 2},
            AllocationMode.EXCLUSIVE,
        )
        # Exclusive mode: a node may belong to at most one virtual cluster.
        for node_id in nodes:
            assert node_id not in node_to_virtual_cluster
            node_to_virtual_cluster[node_id] = virtual_cluster_id

    # Any node not claimed by a virtual cluster belongs to the primary cluster.
    for node in cluster.worker_nodes:
        if node.node_id not in node_to_virtual_cluster:
            node_to_virtual_cluster[node.node_id] = kPrimaryClusterID

    # Bookkeeping actor the driver scripts report into: records which nodes
    # the job's workloads actually landed on, and signals readiness.
    @ray.remote
    class ControlActor:
        def __init__(self):
            self._nodes = set()
            self._ready = False

        def ready(self):
            self._ready = True

        def is_ready(self):
            return self._ready

        def add_node(self, node_id):
            self._nodes.add(node_id)

        def nodes(self):
            return self._nodes

    # One iteration per exclusive cluster, plus a final one (i == ntemplates)
    # targeting the primary cluster.
    for i in range(ntemplates + 1):
        actor_name = f"test_actors_{i}"
        pg_name = f"test_pgs_{i}"
        control_actor_name = f"control_{i}"
        virtual_cluster_id = virtual_cluster_id_prefix + str(i)
        if i == ntemplates:
            virtual_cluster_id = kPrimaryClusterID
        control_actor = ControlActor.options(
            name=control_actor_name, namespace="control"
        ).remote()
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = Path(tmp_dir)
            # Driver executed by the submitted job. It spawns a task, a
            # detached actor and a detached placement group, and reports the
            # node id of each into the control actor for verification.
            driver_script = """
import ray
import time
import asyncio

ray.init(address="auto")

control = ray.get_actor(name="{control_actor_name}", namespace="control")


@ray.remote(max_restarts=10)
class Actor:
    def __init__(self, control, pg):
        node_id = ray.get_runtime_context().get_node_id()
        ray.get(control.add_node.remote(node_id))
        self._pg = pg

    async def run(self, control):
        node_id = ray.get_runtime_context().get_node_id()
        await control.add_node.remote(node_id)

        while True:
            node_id = ray.util.placement_group_table(self._pg)["bundles_to_node_id"][0]
            if node_id == "":
                await asyncio.sleep(1)
                continue
            break

        await control.add_node.remote(node_id)

        await control.ready.remote()
        while True:
            await asyncio.sleep(1)

    async def get_node_id(self):
        while True:
            node_id = ray.util.placement_group_table(pg)["bundles_to_node_id"][0]
            if node_id == "":
                await asyncio.sleep(1)
                continue
            break
        return (ray.get_runtime_context().get_node_id(), node_id)


pg = ray.util.placement_group(
    bundles=[{{"CPU": 1}}], name="{pg_name}", lifetime="detached"
)


@ray.remote
def hello(control):
    node_id = ray.get_runtime_context().get_node_id()
    ray.get(control.add_node.remote(node_id))


ray.get(hello.remote(control))
a = Actor.options(name="{actor_name}",
    namespace="control",
    num_cpus=1,
    lifetime="detached").remote(
    control, pg
)
ray.get(a.run.remote(control))
"""
            driver_script = driver_script.format(
                actor_name=actor_name,
                pg_name=pg_name,
                control_actor_name=control_actor_name,
            )
            test_script_file = path / "test_script.py"
            with open(test_script_file, "w+") as file:
                file.write(driver_script)

            runtime_env = {"working_dir": tmp_dir}
            runtime_env = upload_working_dir_if_needed(
                runtime_env, tmp_dir, logger=logger
            )
            runtime_env = RuntimeEnv(**runtime_env).to_dict()

            # Submitting with replica_sets triggers job-cluster creation
            # inside the target virtual cluster on the job-head side.
            job_id = head_client.submit_job(
                entrypoint="python test_script.py",
                entrypoint_memory=1,
                runtime_env=runtime_env,
                virtual_cluster_id=virtual_cluster_id,
                replica_sets={TEMPLATE_ID_PREFIX + str(i): 2},
            )

            def _check_ready(control_actor):
                return ray.get(control_actor.is_ready.remote())

            wait_for_condition(partial(_check_ready, control_actor), timeout=20)

            # Every node the job touched must belong to the target cluster.
            def _check_virtual_cluster(
                control_actor, node_to_virtual_cluster, virtual_cluster_id
            ):
                nodes = ray.get(control_actor.nodes.remote())
                assert len(nodes) > 0
                for node in nodes:
                    assert node_to_virtual_cluster[node] == virtual_cluster_id
                return True

            wait_for_condition(
                partial(
                    _check_virtual_cluster,
                    control_actor,
                    node_to_virtual_cluster,
                    virtual_cluster_id,
                ),
                timeout=20,
            )

            # The job supervisor actor (and hence the driver) must also run
            # inside the target cluster.
            supervisor_actor = ray.get_actor(
                name=JOB_ACTOR_NAME_TEMPLATE.format(job_id=job_id),
                namespace=SUPERVISOR_ACTOR_RAY_NAMESPACE,
            )
            actor_info = ray.state.actors(supervisor_actor._actor_id.hex())
            driver_node_id = actor_info["Address"]["NodeID"]
            assert node_to_virtual_cluster[driver_node_id] == virtual_cluster_id

            job_info = head_client.get_job_info(job_id)
            assert (
                node_to_virtual_cluster[job_info.driver_node_id] == virtual_cluster_id
            )

            # Kill every non-driver node the job used to force actor restart
            # and placement-group rescheduling.
            nodes_to_remove = ray.get(control_actor.nodes.remote())
            if driver_node_id in nodes_to_remove:
                nodes_to_remove.remove(driver_node_id)

            to_remove = []
            for node in cluster.worker_nodes:
                if node.node_id in nodes_to_remove:
                    to_remove.append(node)
            for node in to_remove:
                cluster.remove_node(node)

            # After recovery, the restarted actor and rescheduled pg bundle
            # must land on surviving nodes of the same cluster.
            def _check_recover(
                nodes_to_remove, actor_name, node_to_virtual_cluster, virtual_cluster_id
            ):
                actor = ray.get_actor(actor_name, namespace="control")
                nodes = ray.get(actor.get_node_id.remote())
                for node_id in nodes:
                    assert node_id not in nodes_to_remove
                    assert node_to_virtual_cluster[node_id] == virtual_cluster_id
                return True

            wait_for_condition(
                partial(
                    _check_recover,
                    nodes_to_remove,
                    actor_name,
                    node_to_virtual_cluster,
                    virtual_cluster_id,
                ),
                timeout=120,
            )
            head_client.stop_job(job_id)


if __name__ == "__main__":
    # Allow running this test module directly; delegate to pytest and
    # propagate its exit code to the shell.
    sys.exit(pytest.main(["-v", __file__]))
Loading