Skip to content

Commit

Permalink
tests, sampling logic
Browse files Browse the repository at this point in the history
Signed-off-by: Ruiyang Wang <[email protected]>
  • Loading branch information
rynewang committed Jan 18, 2025
1 parent 45101dc commit 4ffae76
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 51 deletions.
45 changes: 28 additions & 17 deletions python/ray/dashboard/modules/job/job_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@
import json
import logging
import traceback
from random import sample
from random import choice
from typing import AsyncIterator, Dict, List, Optional, Tuple

import aiohttp.web
from aiohttp.client import ClientResponse
from aiohttp.web import Request, Response

import ray
from ray import NodeID
import ray.dashboard.consts as dashboard_consts
from ray.dashboard.consts import (
GCS_RPC_TIMEOUT_SECONDS,
DASHBOARD_AGENT_ADDR_PREFIX,
CANDIDATE_AGENT_NUMBER,
TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS,
WAIT_AVAILABLE_AGENT_TIMEOUT,
)
Expand Down Expand Up @@ -171,8 +172,8 @@ def __init__(self, config: dashboard_utils.DashboardHeadModuleConfig):
# `JobHead` has ever used, and will not be deleted
# from it unless `JobAgentSubmissionClient` is no
# longer available (the corresponding agent process is dead)
# {node_id_hex: JobAgentSubmissionClient}
self._agents: Dict[str, JobAgentSubmissionClient] = dict()
# {node_id: JobAgentSubmissionClient}
self._agents: Dict[NodeID, JobAgentSubmissionClient] = dict()

async def get_target_agent(self) -> Optional[JobAgentSubmissionClient]:
if RAY_JOB_AGENT_USE_HEAD_NODE_ONLY:
Expand Down Expand Up @@ -204,13 +205,13 @@ async def _pick_random_agent(self) -> Optional[JobAgentSubmissionClient]:
client = self._agents.pop(dead_node)
await client.close()

if len(self._agents) >= CANDIDATE_AGENT_NUMBER:
node_id = sample(list(set(self._agents)), 1)[0]
if len(self._agents) >= dashboard_consts.CANDIDATE_AGENT_NUMBER:
node_id = choice(list(self._agents))
return self._agents[node_id]
else:
# Randomly select one from among all agents, it is possible that
# the selected one already exists in `self._agents`
node_id = sample(sorted(agent_infos), 1)[0]
node_id = choice(list(agent_infos))
agent_info = agent_infos[node_id]

if node_id not in self._agents:
Expand All @@ -223,16 +224,20 @@ async def _pick_random_agent(self) -> Optional[JobAgentSubmissionClient]:
async def _get_head_node_agent(self) -> Optional[JobAgentSubmissionClient]:
"""Retrieves HTTP client for `JobAgent` running on the Head node"""

head_node_id = await get_head_node_id(self.gcs_aio_client)
head_node_id_binary = await get_head_node_id(self.gcs_aio_client)

if not head_node_id:
if not head_node_id_binary:
logger.warning("Head node id has not yet been persisted in GCS")
return None

head_node_id = NodeID.from_hex(head_node_id_binary)

if head_node_id not in self._agents:
agent_infos = await self._fetch_agent_infos([head_node_id])
if head_node_id not in agent_infos:
logger.error("Head node agent's information was not found")
logger.error(
f"Head node agent's information was not found: {head_node_id} not in {agent_infos}"
)
return None

ip, http_port, grpc_port = agent_infos[head_node_id]
Expand All @@ -242,7 +247,7 @@ async def _get_head_node_agent(self) -> Optional[JobAgentSubmissionClient]:

return self._agents[head_node_id]

async def _fetch_all_agent_infos(self) -> Dict[str, Tuple[str, int, int]]:
async def _fetch_all_agent_infos(self) -> Dict[NodeID, Tuple[str, int, int]]:
"""
Fetches all agent infos for all nodes in the cluster.
Expand All @@ -268,8 +273,11 @@ async def _fetch_all_agent_infos(self) -> Dict[str, Tuple[str, int, int]]:
namespace=KV_NAMESPACE_DASHBOARD,
timeout=GCS_RPC_TIMEOUT_SECONDS,
)
prefix_len = len(DASHBOARD_AGENT_ADDR_PREFIX)
return {
key.decode(): json.loads(value.decode())
NodeID.from_hex(key[prefix_len:].decode()): json.loads(
value.decode()
)
for key, value in values.items()
}

Expand All @@ -280,8 +288,8 @@ async def _fetch_all_agent_infos(self) -> Dict[str, Tuple[str, int, int]]:
await asyncio.sleep(TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS)

async def _fetch_agent_infos(
self, target_node_ids: List[str]
) -> Dict[str, Tuple[str, int, int]]:
self, target_node_ids: List[NodeID]
) -> Dict[NodeID, Tuple[str, int, int]]:
"""
Fetches agent infos for nodes identified by provided node-ids.
Expand All @@ -294,8 +302,8 @@ async def _fetch_agent_infos(
while True:
try:
keys = [
f"{DASHBOARD_AGENT_ADDR_PREFIX}{node_id_hex}"
for node_id_hex in target_node_ids
f"{DASHBOARD_AGENT_ADDR_PREFIX}{node_id.hex()}"
for node_id in target_node_ids
]
values: Dict[
bytes, bytes
Expand All @@ -307,8 +315,11 @@ async def _fetch_agent_infos(
if not values or len(values) != len(target_node_ids):
# Not all agent infos found, retry
raise Exception()
prefix_len = len(DASHBOARD_AGENT_ADDR_PREFIX)
return {
key.decode(): json.loads(value.decode())
NodeID.from_hex(key[prefix_len:].decode()): json.loads(
value.decode()
)
for key, value in values.items()
}
except Exception:
Expand Down
81 changes: 51 additions & 30 deletions python/ray/dashboard/modules/job/tests/test_http_job_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,22 @@
import tempfile
import time
from pathlib import Path
from typing import Optional
from typing import Optional, List
from unittest.mock import patch

import pytest
import yaml

import ray
from ray import NodeID
from ray._private.test_utils import (
chdir,
format_web_url,
ray_constants,
wait_for_condition,
wait_until_server_available,
)
from ray.dashboard.consts import DASHBOARD_AGENT_ADDR_PREFIX
from ray.dashboard.modules.dashboard_sdk import ClusterInfo, parse_cluster_info
from ray.dashboard.modules.job.job_head import JobHead
from ray.dashboard.modules.job.pydantic_models import JobDetails
Expand Down Expand Up @@ -736,29 +738,58 @@ async def test_job_head_pick_random_job_agent(monkeypatch):

importlib.reload(ray.dashboard.consts)

from ray.dashboard.datacenter import DataSource
# Fake GCS client
class _FakeGcsClient:
def __init__(self):
self._kv = {}

async def internal_kv_put(self, key: bytes, value: bytes, **kwargs):
self._kv[key] = value

async def internal_kv_get(self, key: bytes, **kwargs):
return self._kv.get(key, None)

async def internal_kv_multi_get(self, keys: List[bytes], **kwargs):
return {key: self._kv.get(key, None) for key in keys}

async def internal_kv_del(self, key: bytes, **kwargs):
self._kv.pop(key)

async def internal_kv_keys(self, prefix: bytes, **kwargs):
return [key for key in self._kv.keys() if key.startswith(prefix)]

class MockJobHead(JobHead):
def __init__(self):
self._agents = dict()

DataSource.nodes = {}
job_head = MockJobHead()
job_head._gcs_aio_client = _FakeGcsClient()

def add_agent(agent):
async def add_agent(agent):
node_id = agent[0]
node_ip = agent[1]["ipAddress"]
http_port = agent[1]["httpPort"]
grpc_port = agent[1]["grpcPort"]
DataSource.nodes[node_id] = {"nodeManagerAddress": node_ip}
# DO NOT SUBMIT: changed to internal kv on node_id
DataSource.agents[node_id] = (node_ip, http_port, grpc_port)

def del_agent(agent):
node_id = agent[0]
DataSource.nodes.pop(node_id)
await job_head._gcs_aio_client.internal_kv_put(
f"{DASHBOARD_AGENT_ADDR_PREFIX}{node_id.hex()}".encode(),
json.dumps([node_ip, http_port, grpc_port]).encode(),
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
)

head_node_id = "node1"
async def del_agent(agent):
node_id = agent[0]
await job_head._gcs_aio_client.internal_kv_del(
f"{DASHBOARD_AGENT_ADDR_PREFIX}{node_id.hex()}".encode(),
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
)

head_node_id = NodeID.from_random()
await job_head._gcs_aio_client.internal_kv_put(
ray_constants.KV_HEAD_NODE_ID_KEY,
head_node_id.hex().encode(),
namespace=ray_constants.KV_NAMESPACE_JOB,
)

agent_1 = (
head_node_id,
Expand All @@ -770,7 +801,7 @@ def del_agent(agent):
),
)
agent_2 = (
"node2",
NodeID.from_random(),
dict(
ipAddress="2.2.2.2",
httpPort=2,
Expand All @@ -779,7 +810,7 @@ def del_agent(agent):
),
)
agent_3 = (
"node3",
NodeID.from_random(),
dict(
ipAddress="3.3.3.3",
httpPort=3,
Expand All @@ -795,12 +826,12 @@ def del_agent(agent):
)

# Check only 1 agent present, only agent being returned
add_agent(agent_1)
await add_agent(agent_1)
job_agent_client = await job_head.get_target_agent()
assert job_agent_client._agent_address == "http://1.1.1.1:1"

# Remove only agent, no agents present, should time out
del_agent(agent_1)
await del_agent(agent_1)
with pytest.raises(asyncio.TimeoutError):
await asyncio.wait_for(job_head.get_target_agent(), timeout=3)

Expand All @@ -811,19 +842,9 @@ def del_agent(agent):
)

# Add 3 agents
add_agent(agent_1)
add_agent(agent_2)
add_agent(agent_3)

# Mock GCS client
class _MockedGCSClient:
async def internal_kv_get(self, key: bytes, **kwargs):
if key == ray_constants.KV_HEAD_NODE_ID_KEY:
return head_node_id.encode()

return None

job_head._gcs_aio_client = _MockedGCSClient()
await add_agent(agent_1)
await add_agent(agent_2)
await add_agent(agent_3)

# Make sure returned agent is a head-node
# NOTE: We run 3 tims to make sure we're not hitting branch probabilistically
Expand Down Expand Up @@ -852,7 +873,7 @@ async def internal_kv_get(self, key: bytes, **kwargs):
for agent in [agent_1, agent_2, agent_3]:
if f"http://{agent[1]['httpAddress']}" in addresses_2:
break
del_agent(agent)
await del_agent(agent)

# Theoretically, the probability of failure is 1/2^100
addresses_3 = set()
Expand All @@ -870,7 +891,7 @@ async def internal_kv_get(self, key: bytes, **kwargs):
for agent in [agent_1, agent_2, agent_3]:
if f"http://{agent[1]['httpAddress']}" in addresses_4:
break
del_agent(agent)
await del_agent(agent)
address = None
for _ in range(3):
job_agent_client = await job_head.get_target_agent()
Expand Down
2 changes: 1 addition & 1 deletion python/ray/dashboard/state_api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ async def get_agent_address(
returns a tuple of (ip, http_port, grpc_port).
If either of them are not found, return None.
If not found, return None.
"""
agent_addr_json = await gcs_aio_client.internal_kv_get(
f"{DASHBOARD_AGENT_ADDR_PREFIX}{node_id.hex()}".encode(),
Expand Down
15 changes: 12 additions & 3 deletions python/ray/dashboard/tests/test_dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,10 +382,19 @@ def test_http_get(enable_test_module, ray_start_with_dashboard):
raise ex
assert dump_info["result"] is True
dump_data = dump_info["data"]
assert len(dump_data["agents"]) == 1
node_id, (node_ip, http_port, grpc_port) = next(
iter(dump_data["agents"].items())
assert len(dump_data["nodes"]) == 1

node_id_hex = ray_start_with_dashboard["node_id"]

# Not using state_api_utils.get_agent_address because that's async and we
# need a sync one here.
gcs_client = make_gcs_client(ray_start_with_dashboard)
agent_addr_json = gcs_client.internal_kv_get(
f"{dashboard_consts.DASHBOARD_AGENT_ADDR_PREFIX}{node_id_hex}",
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
)
assert agent_addr_json is not None
node_ip, http_port, grpc_port = json.loads(agent_addr_json)

response = requests.get(
f"http://{node_ip}:{http_port}"
Expand Down

0 comments on commit 4ffae76

Please sign in to comment.