From 4ffae76e15e6a65f4b796353ecd52f5c43d54f39 Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Fri, 17 Jan 2025 16:38:17 -0800 Subject: [PATCH] tests, sampling logic Signed-off-by: Ruiyang Wang --- python/ray/dashboard/modules/job/job_head.py | 45 +++++++---- .../modules/job/tests/test_http_job_server.py | 81 ++++++++++++------- python/ray/dashboard/state_api_utils.py | 2 +- python/ray/dashboard/tests/test_dashboard.py | 15 +++- 4 files changed, 92 insertions(+), 51 deletions(-) diff --git a/python/ray/dashboard/modules/job/job_head.py b/python/ray/dashboard/modules/job/job_head.py index 46e851daceb98..0fb256c5a9f2c 100644 --- a/python/ray/dashboard/modules/job/job_head.py +++ b/python/ray/dashboard/modules/job/job_head.py @@ -3,7 +3,7 @@ import json import logging import traceback -from random import sample +from random import choice from typing import AsyncIterator, Dict, List, Optional, Tuple import aiohttp.web @@ -11,10 +11,11 @@ from aiohttp.web import Request, Response import ray +from ray import NodeID +import ray.dashboard.consts as dashboard_consts from ray.dashboard.consts import ( GCS_RPC_TIMEOUT_SECONDS, DASHBOARD_AGENT_ADDR_PREFIX, - CANDIDATE_AGENT_NUMBER, TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS, WAIT_AVAILABLE_AGENT_TIMEOUT, ) @@ -171,8 +172,8 @@ def __init__(self, config: dashboard_utils.DashboardHeadModuleConfig): # `JobHead` has ever used, and will not be deleted # from it unless `JobAgentSubmissionClient` is no # longer available (the corresponding agent process is dead) - # {node_id_hex: JobAgentSubmissionClient} - self._agents: Dict[str, JobAgentSubmissionClient] = dict() + # {node_id: JobAgentSubmissionClient} + self._agents: Dict[NodeID, JobAgentSubmissionClient] = dict() async def get_target_agent(self) -> Optional[JobAgentSubmissionClient]: if RAY_JOB_AGENT_USE_HEAD_NODE_ONLY: @@ -204,13 +205,13 @@ async def _pick_random_agent(self) -> Optional[JobAgentSubmissionClient]: client = self._agents.pop(dead_node) await client.close() - if len(self._agents) >= CANDIDATE_AGENT_NUMBER: - node_id = sample(list(set(self._agents)), 1)[0] + if len(self._agents) >= dashboard_consts.CANDIDATE_AGENT_NUMBER: + node_id = choice(list(self._agents)) return self._agents[node_id] else: # Randomly select one from among all agents, it is possible that # the selected one already exists in `self._agents` - node_id = sample(sorted(agent_infos), 1)[0] + node_id = choice(list(agent_infos)) agent_info = agent_infos[node_id] if node_id not in self._agents: @@ -223,16 +224,20 @@ async def _pick_random_agent(self) -> Optional[JobAgentSubmissionClient]: async def _get_head_node_agent(self) -> Optional[JobAgentSubmissionClient]: """Retrieves HTTP client for `JobAgent` running on the Head node""" - head_node_id = await get_head_node_id(self.gcs_aio_client) + head_node_id_binary = await get_head_node_id(self.gcs_aio_client) - if not head_node_id: + if not head_node_id_binary: logger.warning("Head node id has not yet been persisted in GCS") return None + head_node_id = NodeID.from_hex(head_node_id_binary) + if head_node_id not in self._agents: agent_infos = await self._fetch_agent_infos([head_node_id]) if head_node_id not in agent_infos: - logger.error("Head node agent's information was not found") + logger.error( + f"Head node agent's information was not found: {head_node_id} not in {agent_infos}" + ) return None ip, http_port, grpc_port = agent_infos[head_node_id] @@ -242,7 +247,7 @@ async def _get_head_node_agent(self) -> Optional[JobAgentSubmissionClient]: return self._agents[head_node_id] - async def _fetch_all_agent_infos(self) -> Dict[str, Tuple[str, int, int]]: + async def _fetch_all_agent_infos(self) -> Dict[NodeID, Tuple[str, int, int]]: """ Fetches all agent infos for all nodes in the cluster. @@ -268,8 +273,11 @@ async def _fetch_all_agent_infos(self) -> Dict[str, Tuple[str, int, int]]: namespace=KV_NAMESPACE_DASHBOARD, timeout=GCS_RPC_TIMEOUT_SECONDS, ) + prefix_len = len(DASHBOARD_AGENT_ADDR_PREFIX) return { - key.decode(): json.loads(value.decode()) + NodeID.from_hex(key[prefix_len:].decode()): json.loads( + value.decode() + ) for key, value in values.items() } @@ -280,8 +288,8 @@ async def _fetch_all_agent_infos(self) -> Dict[str, Tuple[str, int, int]]: await asyncio.sleep(TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS) async def _fetch_agent_infos( - self, target_node_ids: List[str] - ) -> Dict[str, Tuple[str, int, int]]: + self, target_node_ids: List[NodeID] + ) -> Dict[NodeID, Tuple[str, int, int]]: """ Fetches agent infos for nodes identified by provided node-ids. @@ -294,8 +302,8 @@ async def _fetch_agent_infos( while True: try: keys = [ - f"{DASHBOARD_AGENT_ADDR_PREFIX}{node_id_hex}" - for node_id_hex in target_node_ids + f"{DASHBOARD_AGENT_ADDR_PREFIX}{node_id.hex()}" + for node_id in target_node_ids ] values: Dict[ bytes, bytes @@ -307,8 +315,11 @@ async def _fetch_agent_infos( if not values or len(values) != len(target_node_ids): # Not all agent infos found, retry raise Exception() + prefix_len = len(DASHBOARD_AGENT_ADDR_PREFIX) return { - key.decode(): json.loads(value.decode()) + NodeID.from_hex(key[prefix_len:].decode()): json.loads( + value.decode() + ) for key, value in values.items() } except Exception: diff --git a/python/ray/dashboard/modules/job/tests/test_http_job_server.py b/python/ray/dashboard/modules/job/tests/test_http_job_server.py index ecbef0828cd6b..192b2f27cf53b 100644 --- a/python/ray/dashboard/modules/job/tests/test_http_job_server.py +++ b/python/ray/dashboard/modules/job/tests/test_http_job_server.py @@ -8,13 +8,14 @@ import tempfile import time from pathlib import Path -from typing import Optional +from typing import Optional, List from unittest.mock import patch import pytest import yaml import ray +from ray import NodeID from ray._private.test_utils import ( chdir, format_web_url, @@ -22,6 +23,7 @@ wait_for_condition, wait_until_server_available, ) +from ray.dashboard.consts import DASHBOARD_AGENT_ADDR_PREFIX from ray.dashboard.modules.dashboard_sdk import ClusterInfo, parse_cluster_info from ray.dashboard.modules.job.job_head import JobHead from ray.dashboard.modules.job.pydantic_models import JobDetails @@ -736,29 +738,58 @@ async def test_job_head_pick_random_job_agent(monkeypatch): importlib.reload(ray.dashboard.consts) - from ray.dashboard.datacenter import DataSource + # Fake GCS client + class _FakeGcsClient: + def __init__(self): + self._kv = {} + + async def internal_kv_put(self, key: bytes, value: bytes, **kwargs): + self._kv[key] = value + + async def internal_kv_get(self, key: bytes, **kwargs): + return self._kv.get(key, None) + + async def internal_kv_multi_get(self, keys: List[bytes], **kwargs): + return {key: self._kv.get(key, None) for key in keys} + + async def internal_kv_del(self, key: bytes, **kwargs): + self._kv.pop(key) + + async def internal_kv_keys(self, prefix: bytes, **kwargs): + return [key for key in self._kv.keys() if key.startswith(prefix)] class MockJobHead(JobHead): def __init__(self): self._agents = dict() - DataSource.nodes = {} job_head = MockJobHead() + job_head._gcs_aio_client = _FakeGcsClient() - def add_agent(agent): + async def add_agent(agent): node_id = agent[0] node_ip = agent[1]["ipAddress"] http_port = agent[1]["httpPort"] grpc_port = agent[1]["grpcPort"] - DataSource.nodes[node_id] = {"nodeManagerAddress": node_ip} - # DO NOT SUBMIT: changed to internal kv on node_id - DataSource.agents[node_id] = (node_ip, http_port, grpc_port) - def del_agent(agent): - node_id = agent[0] - DataSource.nodes.pop(node_id) + await job_head._gcs_aio_client.internal_kv_put( + f"{DASHBOARD_AGENT_ADDR_PREFIX}{node_id.hex()}".encode(), + json.dumps([node_ip, http_port, grpc_port]).encode(), + namespace=ray_constants.KV_NAMESPACE_DASHBOARD, + ) - head_node_id = "node1" + async def del_agent(agent): + node_id = agent[0] + await job_head._gcs_aio_client.internal_kv_del( + f"{DASHBOARD_AGENT_ADDR_PREFIX}{node_id.hex()}".encode(), + namespace=ray_constants.KV_NAMESPACE_DASHBOARD, + ) + + head_node_id = NodeID.from_random() + await job_head._gcs_aio_client.internal_kv_put( + ray_constants.KV_HEAD_NODE_ID_KEY, + head_node_id.hex().encode(), + namespace=ray_constants.KV_NAMESPACE_JOB, + ) agent_1 = ( head_node_id, @@ -770,7 +801,7 @@ def del_agent(agent): ), ) agent_2 = ( - "node2", + NodeID.from_random(), dict( ipAddress="2.2.2.2", httpPort=2, @@ -779,7 +810,7 @@ def del_agent(agent): ), ) agent_3 = ( - "node3", + NodeID.from_random(), dict( ipAddress="3.3.3.3", httpPort=3, @@ -795,12 +826,12 @@ def del_agent(agent): ) # Check only 1 agent present, only agent being returned - add_agent(agent_1) + await add_agent(agent_1) job_agent_client = await job_head.get_target_agent() assert job_agent_client._agent_address == "http://1.1.1.1:1" # Remove only agent, no agents present, should time out - del_agent(agent_1) + await del_agent(agent_1) with pytest.raises(asyncio.TimeoutError): await asyncio.wait_for(job_head.get_target_agent(), timeout=3) @@ -811,19 +842,9 @@ def del_agent(agent): ) # Add 3 agents - add_agent(agent_1) - add_agent(agent_2) - add_agent(agent_3) - - # Mock GCS client - class _MockedGCSClient: - async def internal_kv_get(self, key: bytes, **kwargs): - if key == ray_constants.KV_HEAD_NODE_ID_KEY: - return head_node_id.encode() - - return None - - job_head._gcs_aio_client = _MockedGCSClient() + await add_agent(agent_1) + await add_agent(agent_2) + await add_agent(agent_3) # Make sure returned agent is a head-node # NOTE: We run 3 tims to make sure we're not hitting branch probabilistically @@ -852,7 +873,7 @@ async def internal_kv_get(self, key: bytes, **kwargs): for agent in [agent_1, agent_2, agent_3]: if f"http://{agent[1]['httpAddress']}" in addresses_2: break - del_agent(agent) + await del_agent(agent) # Theoretically, the probability of failure is 1/2^100 addresses_3 = set() @@ -870,7 +891,7 @@ async def internal_kv_get(self, key: bytes, **kwargs): for agent in [agent_1, agent_2, agent_3]: if f"http://{agent[1]['httpAddress']}" in addresses_4: break - del_agent(agent) + await del_agent(agent) address = None for _ in range(3): job_agent_client = await job_head.get_target_agent() diff --git a/python/ray/dashboard/state_api_utils.py b/python/ray/dashboard/state_api_utils.py index 7e64f3aaec2b2..14f7e5b9e7fe7 100644 --- a/python/ray/dashboard/state_api_utils.py +++ b/python/ray/dashboard/state_api_utils.py @@ -267,7 +267,7 @@ async def get_agent_address( returns a tuple of (ip, http_port, grpc_port). - If either of them are not found, return None. + If not found, return None. """ agent_addr_json = await gcs_aio_client.internal_kv_get( f"{DASHBOARD_AGENT_ADDR_PREFIX}{node_id.hex()}".encode(), diff --git a/python/ray/dashboard/tests/test_dashboard.py b/python/ray/dashboard/tests/test_dashboard.py index 1c51570fc6168..4e1b93d4823cf 100644 --- a/python/ray/dashboard/tests/test_dashboard.py +++ b/python/ray/dashboard/tests/test_dashboard.py @@ -382,10 +382,19 @@ def test_http_get(enable_test_module, ray_start_with_dashboard): raise ex assert dump_info["result"] is True dump_data = dump_info["data"] - assert len(dump_data["agents"]) == 1 - node_id, (node_ip, http_port, grpc_port) = next( - iter(dump_data["agents"].items()) + assert len(dump_data["nodes"]) == 1 + + node_id_hex = ray_start_with_dashboard["node_id"] + + # Not using state_api_utils.get_agent_address because that's async and we + # need a sync one here. + gcs_client = make_gcs_client(ray_start_with_dashboard) + agent_addr_json = gcs_client.internal_kv_get( + f"{dashboard_consts.DASHBOARD_AGENT_ADDR_PREFIX}{node_id_hex}", + namespace=ray_constants.KV_NAMESPACE_DASHBOARD, ) + assert agent_addr_json is not None + node_ip, http_port, grpc_port = json.loads(agent_addr_json) response = requests.get( f"http://{node_ip}:{http_port}"