Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[17/N][VirtualCluster] Support exclusive virtual cluster in Job Submission #435

Merged
merged 3 commits into from
Dec 24, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions python/ray/dashboard/modules/job/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,8 @@ class JobSubmitRequest:
entrypoint_resources: Optional[Dict[str, float]] = None
# Optional virtual cluster ID for job.
virtual_cluster_id: Optional[str] = None
# Optional replica sets for job
replica_sets: Optional[Dict[str, int]] = None

def __post_init__(self):
if not isinstance(self.entrypoint, str):
Expand Down Expand Up @@ -521,6 +523,23 @@ def __post_init__(self):
f"got {type(self.virtual_cluster_id)}"
)

if self.replica_sets is not None:
if not isinstance(self.replica_sets, dict):
raise TypeError(
"replica_sets must be a dict, " f"got {type(self.replica_sets)}"
)
else:
for k in self.replica_sets.keys():
if not isinstance(k, str):
raise TypeError(
"replica_sets keys must be strings, " f"got {type(k)}"
)
for v in self.replica_sets.values():
if not isinstance(v, int):
raise TypeError(
"replica_sets values must be integers, " f"got {type(v)}"
)


@dataclass
class JobSubmitResponse:
Expand Down
48 changes: 48 additions & 0 deletions python/ray/dashboard/modules/job/job_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
upload_package_to_gcs,
)
from ray._private.utils import get_or_create_event_loop
from ray.core.generated import gcs_service_pb2_grpc
from ray.core.generated.gcs_service_pb2 import CreateJobClusterRequest
from ray.dashboard.datacenter import DataOrganizer
from ray.dashboard.modules.job.common import (
JobDeleteResponse,
Expand All @@ -31,6 +33,7 @@
JobSubmitResponse,
http_uri_components_to_uri,
)
from ray.dashboard.modules.job.job_manager import generate_job_id
from ray.dashboard.modules.job.pydantic_models import JobDetails, JobType
from ray.dashboard.modules.job.utils import (
find_job_by_ids,
Expand Down Expand Up @@ -163,6 +166,12 @@ def __init__(self, dashboard_head):
self._gcs_aio_client = dashboard_head.gcs_aio_client
self._job_info_client = None

self._gcs_virtual_cluster_info_stub = (
gcs_service_pb2_grpc.VirtualClusterInfoGcsServiceStub(
dashboard_head.aiogrpc_gcs_channel
)
)

# It contains all `JobAgentSubmissionClient` that
# `JobHead` has ever used, and will not be deleted
# from it unless `JobAgentSubmissionClient` is no
Expand Down Expand Up @@ -340,6 +349,30 @@ async def submit_job(self, req: Request) -> Response:
self.get_target_agent(),
timeout=dashboard_consts.WAIT_AVAILABLE_AGENT_TIMEOUT,
)

if (
submit_request.virtual_cluster_id is not None
and submit_request.replica_sets is not None
and len(submit_request.replica_sets) > 0
):
# Use the submission ID or generate a new one
submission_id = submit_request.submission_id or submit_request.job_id
if submission_id is None:
submit_request.submission_id = generate_job_id()
job_cluster_id = await self._create_job_cluster(
submit_request.submission_id,
submit_request.virtual_cluster_id,
submit_request.replica_sets,
)
# If cluster creation fails
if job_cluster_id is None:
return Response(
text="Create Job Cluster Failed.",
status=aiohttp.web.HTTPInternalServerError.status_code,
)
# Overwrite the virtual cluster ID in submit request
submit_request.virtual_cluster_id = job_cluster_id

resp = await job_agent_client.submit_job_internal(submit_request)
except asyncio.TimeoutError:
return Response(
Expand Down Expand Up @@ -580,6 +613,21 @@ def get_job_driver_agent_client(

return self._agents[driver_node_id]

async def _create_job_cluster(self, job_id, virtual_cluster_id, replica_sets):
    """Ask the GCS to create a job cluster inside a virtual cluster.

    Args:
        job_id: Submission ID of the job the cluster is created for.
        virtual_cluster_id: ID of the parent virtual cluster.
        replica_sets: Mapping from node-template name to replica count.

    Returns:
        The new job cluster's ID on success, or ``None`` when the GCS
        reports a failure (the failure is logged as a warning).
    """
    req = CreateJobClusterRequest(
        job_id=job_id,
        virtual_cluster_id=virtual_cluster_id,
        replica_sets=replica_sets,
    )
    reply = await self._gcs_virtual_cluster_info_stub.CreateJobCluster(req)
    # A zero status code is the GCS convention for success.
    if reply.status.code == 0:
        return reply.job_cluster_id
    logger.warning(
        f"failed to create job cluster for {job_id} in"
        f" {virtual_cluster_id}, message: {reply.status.message}"
    )
    return None

async def run(self, server):
if not self._job_info_client:
self._job_info_client = JobInfoStorageClient(self._gcs_aio_client)
Expand Down
2 changes: 2 additions & 0 deletions python/ray/dashboard/modules/job/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def submit_job(
metadata: Optional[Dict[str, str]] = None,
submission_id: Optional[str] = None,
virtual_cluster_id: Optional[str] = None,
replica_sets: Optional[Dict[str, int]] = None,
entrypoint_num_cpus: Optional[Union[int, float]] = None,
entrypoint_num_gpus: Optional[Union[int, float]] = None,
entrypoint_memory: Optional[int] = None,
Expand Down Expand Up @@ -231,6 +232,7 @@ def submit_job(
entrypoint=entrypoint,
submission_id=submission_id,
virtual_cluster_id=virtual_cluster_id,
replica_sets=replica_sets,
runtime_env=runtime_env,
metadata=metadata,
entrypoint_num_cpus=entrypoint_num_cpus,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from ray.tests.conftest import get_default_fixture_ray_kwargs

# Prefix used to build node-template IDs when creating virtual clusters
# in these tests (e.g. "template_id_0").
TEMPLATE_ID_PREFIX = "template_id_"
# ID used for nodes that belong to no virtual cluster, i.e. the primary
# cluster.  NOTE(review): the C++-style "k" prefix suggests this mirrors a
# GCS-side constant of the same name — confirm against the GCS sources.
kPrimaryClusterID = "kPrimaryClusterID"

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -131,11 +132,11 @@ async def create_virtual_cluster(
[
{
"_system_config": {"gcs_actor_scheduling_enabled": False},
"ntemplates": 5,
"ntemplates": 3,
},
{
"_system_config": {"gcs_actor_scheduling_enabled": True},
"ntemplates": 5,
"ntemplates": 3,
},
],
indirect=True,
Expand All @@ -145,7 +146,7 @@ async def test_mixed_virtual_cluster(job_sdk_client):
head_client, gcs_address, cluster = job_sdk_client
virtual_cluster_id_prefix = "VIRTUAL_CLUSTER_"
node_to_virtual_cluster = {}
ntemplates = 5
ntemplates = 3
for i in range(ntemplates):
virtual_cluster_id = virtual_cluster_id_prefix + str(i)
nodes = await create_virtual_cluster(
Expand Down Expand Up @@ -340,5 +341,229 @@ def _check_recover(
head_client.stop_job(job_id)


@pytest.mark.parametrize(
    "job_sdk_client",
    [
        {
            "_system_config": {"gcs_actor_scheduling_enabled": False},
            "ntemplates": 4,
        },
        {
            "_system_config": {"gcs_actor_scheduling_enabled": True},
            "ntemplates": 4,
        },
    ],
    indirect=True,
)
@pytest.mark.asyncio
async def test_exclusive_virtual_cluster(job_sdk_client):
    """Jobs submitted with ``replica_sets`` stay inside their target cluster.

    For each of three exclusive virtual clusters -- plus the primary cluster
    on the final iteration -- submit a job whose driver runs a task, a
    detached actor and a detached placement group.  Every node any of those
    workloads (and the job driver itself) lands on must belong to the target
    cluster, and after the touched worker nodes are removed, the restarted
    actor and rescheduled placement group must again land on in-cluster nodes.
    """
    head_client, gcs_address, cluster = job_sdk_client
    virtual_cluster_id_prefix = "VIRTUAL_CLUSTER_"
    # Maps node_id -> ID of the (virtual or primary) cluster owning the node.
    node_to_virtual_cluster = {}
    # The fixture is parametrized with ntemplates=4 but only 3 exclusive
    # virtual clusters are created here; the remaining template's nodes stay
    # in the primary cluster and are exercised on the last loop iteration.
    ntemplates = 3
    for i in range(ntemplates):
        virtual_cluster_id = virtual_cluster_id_prefix + str(i)
        # Each exclusive virtual cluster gets 2 nodes of its own template.
        nodes = await create_virtual_cluster(
            gcs_address,
            virtual_cluster_id,
            {TEMPLATE_ID_PREFIX + str(i): 2},
            AllocationMode.EXCLUSIVE,
        )
        for node_id in nodes:
            # Exclusive virtual clusters must never share a node.
            assert node_id not in node_to_virtual_cluster
            node_to_virtual_cluster[node_id] = virtual_cluster_id

    # Every node not claimed by a virtual cluster belongs to the primary one.
    for node in cluster.worker_nodes:
        if node.node_id not in node_to_virtual_cluster:
            node_to_virtual_cluster[node.node_id] = kPrimaryClusterID

    @ray.remote
    class ControlActor:
        # Records the node IDs a job's workloads run on, plus a ready flag
        # the driver flips once all of its workloads are scheduled.
        def __init__(self):
            self._nodes = set()
            self._ready = False

        def ready(self):
            self._ready = True

        def is_ready(self):
            return self._ready

        def add_node(self, node_id):
            self._nodes.add(node_id)

        def nodes(self):
            return self._nodes

    # One iteration per exclusive virtual cluster, plus one for the primary.
    for i in range(ntemplates + 1):
        actor_name = f"test_actors_{i}"
        pg_name = f"test_pgs_{i}"
        control_actor_name = f"control_{i}"
        virtual_cluster_id = virtual_cluster_id_prefix + str(i)
        if i == ntemplates:
            virtual_cluster_id = kPrimaryClusterID
        # The control actor lives outside the job so it survives both job
        # teardown and the node removals performed below.
        control_actor = ControlActor.options(
            name=control_actor_name, namespace="control"
        ).remote()
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = Path(tmp_dir)
            # Driver script: runs a task, a detached actor and a detached
            # placement group, reports each workload's node ID to the control
            # actor, then signals readiness and blocks forever.  The actor's
            # get_node_id() is polled later to observe where the actor/pg
            # land after recovery.
            driver_script = """
import ray
import time
import asyncio

ray.init(address="auto")

control = ray.get_actor(name="{control_actor_name}", namespace="control")


@ray.remote(max_restarts=10)
class Actor:
    def __init__(self, control, pg):
        node_id = ray.get_runtime_context().get_node_id()
        ray.get(control.add_node.remote(node_id))
        self._pg = pg

    async def run(self, control):
        node_id = ray.get_runtime_context().get_node_id()
        await control.add_node.remote(node_id)

        while True:
            node_id = ray.util.placement_group_table(self._pg)["bundles_to_node_id"][0]
            if node_id == "":
                await asyncio.sleep(1)
                continue
            break

        await control.add_node.remote(node_id)

        await control.ready.remote()
        while True:
            await asyncio.sleep(1)

    async def get_node_id(self):
        while True:
            node_id = ray.util.placement_group_table(pg)["bundles_to_node_id"][0]
            if node_id == "":
                await asyncio.sleep(1)
                continue
            break
        return (ray.get_runtime_context().get_node_id(), node_id)


pg = ray.util.placement_group(
    bundles=[{{"CPU": 1}}], name="{pg_name}", lifetime="detached"
)


@ray.remote
def hello(control):
    node_id = ray.get_runtime_context().get_node_id()
    ray.get(control.add_node.remote(node_id))


ray.get(hello.remote(control))
a = Actor.options(name="{actor_name}",
    namespace="control",
    num_cpus=1,
    lifetime="detached").remote(
    control, pg
)
ray.get(a.run.remote(control))
"""
            driver_script = driver_script.format(
                actor_name=actor_name,
                pg_name=pg_name,
                control_actor_name=control_actor_name,
            )
            test_script_file = path / "test_script.py"
            with open(test_script_file, "w+") as file:
                file.write(driver_script)

            # Ship the script to the job via a working_dir runtime env.
            runtime_env = {"working_dir": tmp_dir}
            runtime_env = upload_working_dir_if_needed(
                runtime_env, tmp_dir, logger=logger
            )
            runtime_env = RuntimeEnv(**runtime_env).to_dict()

            job_id = head_client.submit_job(
                entrypoint="python test_script.py",
                entrypoint_memory=1,
                runtime_env=runtime_env,
                virtual_cluster_id=virtual_cluster_id,
                # Request 2 replicas of this iteration's node template; this
                # is the new replica_sets path exercised by the PR.
                replica_sets={TEMPLATE_ID_PREFIX + str(i): 2},
            )

            def _check_ready(control_actor):
                return ray.get(control_actor.is_ready.remote())

            # Wait until the driver reports all its workloads are scheduled.
            wait_for_condition(partial(_check_ready, control_actor), timeout=20)

            def _check_virtual_cluster(
                control_actor, node_to_virtual_cluster, virtual_cluster_id
            ):
                nodes = ray.get(control_actor.nodes.remote())
                assert len(nodes) > 0
                # Every node touched by the job must belong to its cluster.
                for node in nodes:
                    assert node_to_virtual_cluster[node] == virtual_cluster_id
                return True

            wait_for_condition(
                partial(
                    _check_virtual_cluster,
                    control_actor,
                    node_to_virtual_cluster,
                    virtual_cluster_id,
                ),
                timeout=20,
            )

            # The job supervisor actor (the driver) must also run on a node
            # inside the target cluster.
            supervisor_actor = ray.get_actor(
                name=JOB_ACTOR_NAME_TEMPLATE.format(job_id=job_id),
                namespace=SUPERVISOR_ACTOR_RAY_NAMESPACE,
            )
            actor_info = ray.state.actors(supervisor_actor._actor_id.hex())
            driver_node_id = actor_info["Address"]["NodeID"]
            assert node_to_virtual_cluster[driver_node_id] == virtual_cluster_id

            # The dashboard's job info must agree on the driver's node.
            job_info = head_client.get_job_info(job_id)
            assert (
                node_to_virtual_cluster[job_info.driver_node_id] == virtual_cluster_id
            )

            # Remove every worker node the job touched except the driver's,
            # forcing the detached actor and placement group to reschedule.
            nodes_to_remove = ray.get(control_actor.nodes.remote())
            if driver_node_id in nodes_to_remove:
                nodes_to_remove.remove(driver_node_id)

            to_remove = []
            for node in cluster.worker_nodes:
                if node.node_id in nodes_to_remove:
                    to_remove.append(node)
            for node in to_remove:
                cluster.remove_node(node)

            def _check_recover(
                nodes_to_remove, actor_name, node_to_virtual_cluster, virtual_cluster_id
            ):
                # The restarted actor and rescheduled placement group must
                # land on surviving nodes of the same cluster.
                actor = ray.get_actor(actor_name, namespace="control")
                nodes = ray.get(actor.get_node_id.remote())
                for node_id in nodes:
                    assert node_id not in nodes_to_remove
                    assert node_to_virtual_cluster[node_id] == virtual_cluster_id
                return True

            wait_for_condition(
                partial(
                    _check_recover,
                    nodes_to_remove,
                    actor_name,
                    node_to_virtual_cluster,
                    virtual_cluster_id,
                ),
                timeout=120,
            )
            head_client.stop_job(job_id)


if __name__ == "__main__":
    # Allow running this test module directly; forward pytest's exit code.
    sys.exit(pytest.main(["-v", __file__]))
Loading