From e533201252e40f45938570f85263950868c52fd9 Mon Sep 17 00:00:00 2001 From: Joongi Kim Date: Thu, 19 Sep 2024 18:14:43 +0900 Subject: [PATCH 1/3] refactor: Speed up and deduplicate scheduler test codes --- src/ai/backend/manager/models/utils.py | 19 + .../backend/manager/scheduler/dispatcher.py | 326 +++-- src/ai/backend/manager/scheduler/types.py | 13 + tests/manager/BUILD | 1 + tests/manager/conftest.py | 17 +- tests/manager/models/test_dbutils.py | 23 +- tests/manager/scheduler_utils.py | 207 +++ tests/manager/test_agent_selector.py | 189 +++ tests/manager/test_priority_scheduler.py | 80 ++ tests/manager/test_scheduler.py | 1178 +++-------------- 10 files changed, 903 insertions(+), 1150 deletions(-) create mode 100644 tests/manager/scheduler_utils.py create mode 100644 tests/manager/test_agent_selector.py create mode 100644 tests/manager/test_priority_scheduler.py diff --git a/src/ai/backend/manager/models/utils.py b/src/ai/backend/manager/models/utils.py index 2b58e4f845..0b76afd848 100644 --- a/src/ai/backend/manager/models/utils.py +++ b/src/ai/backend/manager/models/utils.py @@ -33,6 +33,7 @@ from sqlalchemy.orm import sessionmaker from tenacity import ( AsyncRetrying, + AttemptManager, RetryError, TryAgain, retry_if_exception_type, @@ -414,6 +415,24 @@ async def execute_with_retry(txn_func: Callable[[], Awaitable[TQueryResult]]) -> return result +async def retry_txn(max_attempts: int = 20) -> AsyncIterator[AttemptManager]: + try: + async for attempt in AsyncRetrying( + wait=wait_exponential(multiplier=0.02, min=0.02, max=5.0), + stop=stop_after_attempt(max_attempts), + retry=retry_if_exception_type(TryAgain) | retry_if_exception_type(DBAPIError), + ): + # The caller of a Python generator cannot catch the exceptions thrown in the block, + # so we need to pass AttemptManager. + yield attempt + assert attempt.retry_state.outcome is not None + exc = attempt.retry_state.outcome.exception() + if isinstance(exc, DBAPIError) and not is_db_retry_error(exc): + raise exc + except RetryError: + raise RuntimeError(f"DB serialization failed after {max_attempts} retries") + + JSONCoalesceExpr: TypeAlias = sa.sql.elements.BinaryExpression diff --git a/src/ai/backend/manager/scheduler/dispatcher.py b/src/ai/backend/manager/scheduler/dispatcher.py index 201d7f3106..f8e1b63759 100644 --- a/src/ai/backend/manager/scheduler/dispatcher.py +++ b/src/ai/backend/manager/scheduler/dispatcher.py @@ -5,7 +5,12 @@ import json import logging import uuid -from collections.abc import Awaitable, Mapping, Sequence +from collections.abc import ( + Awaitable, + Callable, + Mapping, + Sequence, +) from contextvars import ContextVar from datetime import datetime, timedelta from decimal import Decimal @@ -88,7 +93,12 @@ recalc_concurrency_used, ) from ..models.utils import ExtendedAsyncSAEngine as SAEngine -from ..models.utils import execute_with_retry, sql_json_increment, sql_json_merge +from ..models.utils import ( + execute_with_retry, + retry_txn, + sql_json_increment, + sql_json_merge, +) from .predicates import ( check_concurrency, check_dependencies, @@ -425,82 +435,31 @@ async def _schedule_in_sgroup( sched_ctx: SchedulingContext, sgroup_name: str, ) -> None: - async def _apply_cancellation( - db_sess: SASession, session_ids: list[SessionId], reason="pending-timeout" - ): - now = datetime.now(tzutc()) - kernel_query = ( - sa.update(KernelRow) - .values( - status=KernelStatus.CANCELLED, - status_info=reason, - terminated_at=now, - status_history=sql_json_merge( - KernelRow.status_history, - (), - { - KernelStatus.CANCELLED.name: now.isoformat(), - }, - ), - ) - .where(KernelRow.session_id.in_(session_ids)) - ) - await db_sess.execute(kernel_query) - query = ( - sa.update(SessionRow) - .values( - status=SessionStatus.CANCELLED, - status_info=reason, - terminated_at=now, - status_history=sql_json_merge( - SessionRow.status_history, - (), - { - SessionStatus.CANCELLED.name: now.isoformat(), - }, - ), - ) - .where(SessionRow.id.in_(session_ids)) - ) - await db_sess.execute(query) + # Part 0: Load the scheduler and the agent selector. async with self.db.begin_readonly_session() as db_sess: scheduler, agent_selector = await self._load_scheduler(db_sess, sgroup_name) existing_sessions, pending_sessions, cancelled_sessions = await _list_managed_sessions( db_sess, sgroup_name, scheduler.sgroup_opts.pending_timeout ) - - if cancelled_sessions: - session_ids = [item.id for item in cancelled_sessions] - - async def _update(): - async with self.db.begin_session() as db_sess: - await _apply_cancellation(db_sess, session_ids) - - await execute_with_retry(_update) - for item in cancelled_sessions: - await self.event_producer.produce_event( - SessionCancelledEvent( - item.id, - item.creation_id, - reason=KernelLifecycleEventReason.PENDING_TIMEOUT, - ), - ) + await self.flush_cancelled_sessions(cancelled_sessions) + current_priority, pending_sessions = scheduler.prioritize(pending_sessions) log.debug( - "running scheduler (sgroup:{}, pending:{}, existing:{}, cancelled:{})", + "running scheduler (sgroup:{}, pending:{} at prio:{}, existing:{}, cancelled:{})", sgroup_name, len(pending_sessions), + current_priority, len(existing_sessions), len(cancelled_sessions), ) - zero = ResourceSlot() num_scheduled = 0 while len(pending_sessions) > 0: + # Part 1: Choose the pending session to try scheduling. + async with self.db.begin_readonly_session() as db_sess: candidate_agents = await list_schedulable_agents_by_sgroup(db_sess, sgroup_name) - total_capacity = sum((ag.available_slots for ag in candidate_agents), zero) - + total_capacity = sum((ag.available_slots for ag in candidate_agents), ResourceSlot()) picked_session_id = scheduler.pick_session( total_capacity, pending_sessions, @@ -510,76 +469,38 @@ async def _update(): # no session is picked. # continue to next sgroup. return - for picked_idx, sess_ctx in enumerate(pending_sessions): - if sess_ctx.id == picked_session_id: + for picked_idx, pending_sess in enumerate(pending_sessions): + if pending_sess.id == picked_session_id: break else: # no matching entry for picked session? raise RuntimeError("should not reach here") - sess_ctx = pending_sessions.pop(picked_idx) - log_fmt = "schedule(s:{}, type:{}, name:{}, ak:{}, cluster_mode:{}): " + pending_sess = pending_sessions.pop(picked_idx) + log_fmt = "schedule(s:{}, prio:{}, type:{}, name:{}, ak:{}, cluster_mode:{}): " log_args = ( - sess_ctx.id, - sess_ctx.session_type, - sess_ctx.name, - sess_ctx.access_key, - sess_ctx.cluster_mode, + pending_sess.id, + pending_sess.priority, + pending_sess.session_type, + pending_sess.name, + pending_sess.access_key, + pending_sess.cluster_mode, ) _log_fmt.set(log_fmt) _log_args.set(log_args) log.debug(log_fmt + "try-scheduling", *log_args) - async def _check_predicates() -> list[tuple[str, Union[Exception, PredicateResult]]]: - check_results: list[tuple[str, Union[Exception, PredicateResult]]] = [] - async with self.db.begin_session() as db_sess: - predicates: list[tuple[str, Awaitable[PredicateResult]]] = [ - ( - "reserved_time", - check_reserved_batch_session(db_sess, sched_ctx, sess_ctx), - ), - ("dependencies", check_dependencies(db_sess, sched_ctx, sess_ctx)), - ("concurrency", check_concurrency(db_sess, sched_ctx, sess_ctx)), - ] - if not sess_ctx.is_private: - predicates += [ - ( - "pending_session_resource_limit", - check_pending_session_resource_limit(db_sess, sched_ctx, sess_ctx), - ), - ( - "pending_session_count_limit", - check_pending_session_count_limit(db_sess, sched_ctx, sess_ctx), - ), - ( - "keypair_resource_limit", - check_keypair_resource_limit(db_sess, sched_ctx, sess_ctx), - ), - ( - "user_resource_limit", - check_user_resource_limit(db_sess, sched_ctx, sess_ctx), - ), - ( - "user_group_resource_limit", - check_group_resource_limit(db_sess, sched_ctx, sess_ctx), - ), - ( - "domain_resource_limit", - check_domain_resource_limit(db_sess, sched_ctx, sess_ctx), - ), - ] - for predicate_name, check_coro in predicates: - try: - check_results.append((predicate_name, await check_coro)) - except DBAPIError: - raise - except Exception as e: - log.exception(log_fmt + "predicate-error", *log_args) - check_results.append((predicate_name, e)) - return check_results - - check_results = await execute_with_retry(_check_predicates) + # Part 2: Predicate checks with predicate hook plugins + + check_results = [] failed_predicates = [] passed_predicates = [] + async for attempt in retry_txn(): + with attempt: + check_results = await self.check_predicates( + sched_ctx, + pending_sess, + exc_handler=lambda _: log.exception(log_fmt + "predicate-error", *log_args), + ) for predicate_name, result in check_results: if isinstance(result, Exception): failed_predicates.append({ @@ -597,18 +518,10 @@ async def _check_predicates() -> list[tuple[str, Union[Exception, PredicateResul "msg": result.message or "", }) - async def _check_predicates_hook() -> HookResult: - async with self.db.begin_readonly_session() as db_sess: - return await self.registry.hook_plugin_ctx.dispatch( - "PREDICATE", - ( - db_sess, - sched_ctx, - sess_ctx, - ), - ) - - hook_result = await execute_with_retry(_check_predicates_hook) + hook_result = HookResult(status=PASSED, src_plugin=[], result=[]) + async for attempt in retry_txn(): + with attempt: + hook_result = await self.check_predicates_hook(sched_ctx, pending_sess) match hook_result.src_plugin: case str(): hook_name = hook_result.src_plugin @@ -616,7 +529,6 @@ async def _check_predicates_hook() -> HookResult: hook_name = f"({", ".join(hook_result.src_plugin)})" case _: hook_name = "" - if hook_result.status == PASSED: if hook_result.src_plugin: # Append result only when plugin exists. @@ -627,6 +539,8 @@ async def _check_predicates_hook() -> HookResult: "msg": hook_result.reason or "", }) + # Part 3: Interpret the predicate check results + status_update_data = { "last_try": datetime.now(tzutc()).isoformat(), "failed_predicates": failed_predicates, @@ -640,7 +554,7 @@ async def _cancel_failed_system_session() -> None: await _rollback_predicate_mutations( db_sess, sched_ctx, - sess_ctx, + pending_sess, ) query = ( sa.update(SessionRow) @@ -652,15 +566,15 @@ async def _cancel_failed_system_session() -> None: parent_updates=status_update_data, ), ) - .where(SessionRow.id == sess_ctx.id) + .where(SessionRow.id == pending_sess.id) ) await db_sess.execute(query) - if sess_ctx.is_private: - await _apply_cancellation(db_sess, [sess_ctx.id]) + if pending_sess.is_private: + await _apply_cancellation(db_sess, [pending_sess.id]) await self.event_producer.produce_event( SessionCancelledEvent( - sess_ctx.id, - sess_ctx.creation_id, + pending_sess.id, + pending_sess.creation_id, reason=KernelLifecycleEventReason.PENDING_TIMEOUT, ) ) @@ -675,7 +589,7 @@ async def _update_session_status_data() -> None: async with self.db.begin_session() as db_sess: kernel_query = ( sa.update(KernelRow) - .where(KernelRow.session_id == sess_ctx.id) + .where(KernelRow.session_id == pending_sess.id) .values( status_data=sql_json_merge( KernelRow.status_data, @@ -687,7 +601,7 @@ async def _update_session_status_data() -> None: await db_sess.execute(kernel_query) session_query = ( sa.update(SessionRow) - .where(SessionRow.id == sess_ctx.id) + .where(SessionRow.id == pending_sess.id) .values( status_data=sql_json_merge( SessionRow.status_data, @@ -700,10 +614,12 @@ async def _update_session_status_data() -> None: await execute_with_retry(_update_session_status_data) + # Part 4: Assign agent(s) via the agent selector. + async with self.db.begin_readonly_session() as db_sess: schedulable_sess = await SessionRow.get_session_by_id( db_sess, - sess_ctx.id, + pending_sess.id, eager_loading_op=( noload("*"), selectinload(SessionRow.kernels).options( @@ -1696,6 +1612,134 @@ async def _mark_session_cancelled() -> None: else: log.info(log_fmt + "started", *log_args) + async def flush_cancelled_sessions(self, cancelled_sessions: Sequence[SessionRow]) -> None: + if not cancelled_sessions: + return + session_ids = [item.id for item in cancelled_sessions] + + async for attempt in retry_txn(): + with attempt: + async with self.db.begin_session() as db_sess: + await _apply_cancellation(db_sess, session_ids) + for item in cancelled_sessions: + await self.event_producer.produce_event( + SessionCancelledEvent( + item.id, + item.creation_id, + reason=KernelLifecycleEventReason.PENDING_TIMEOUT, + ), + ) + + async def check_predicates( + self, + sched_ctx: SchedulingContext, + pending_sess: SessionRow, + *, + exc_handler: Callable[[Exception], None] | None = None, + ) -> list[tuple[str, Union[Exception, PredicateResult]]]: + check_results: list[tuple[str, Union[Exception, PredicateResult]]] = [] + async with self.db.begin_session() as db_sess: + predicates: list[tuple[str, Awaitable[PredicateResult]]] = [ + ( + "reserved_time", + check_reserved_batch_session(db_sess, sched_ctx, pending_sess), + ), + ("dependencies", check_dependencies(db_sess, sched_ctx, pending_sess)), + ("concurrency", check_concurrency(db_sess, sched_ctx, pending_sess)), + ] + if not pending_sess.is_private: + predicates += [ + ( + "pending_session_resource_limit", + check_pending_session_resource_limit(db_sess, sched_ctx, pending_sess), + ), + ( + "pending_session_count_limit", + check_pending_session_count_limit(db_sess, sched_ctx, pending_sess), + ), + ( + "keypair_resource_limit", + check_keypair_resource_limit(db_sess, sched_ctx, pending_sess), + ), + ( + "user_resource_limit", + check_user_resource_limit(db_sess, sched_ctx, pending_sess), + ), + ( + "user_group_resource_limit", + check_group_resource_limit(db_sess, sched_ctx, pending_sess), + ), + ( + "domain_resource_limit", + check_domain_resource_limit(db_sess, sched_ctx, pending_sess), + ), + ] + for predicate_name, check_coro in predicates: + try: + check_results.append((predicate_name, await check_coro)) + except DBAPIError: + raise + except Exception as e: + if exc_handler is not None: + exc_handler(e) + check_results.append((predicate_name, e)) + return check_results + + async def check_predicates_hook( + self, + sched_ctx: SchedulingContext, + pending_sess: SessionRow, + ) -> HookResult: + async with self.db.begin_readonly_session() as db_sess: + return await self.registry.hook_plugin_ctx.dispatch( + "PREDICATE", + ( + db_sess, + sched_ctx, + pending_sess, + ), + ) + + +async def _apply_cancellation( + db_sess: SASession, session_ids: list[SessionId], reason="pending-timeout" +) -> None: + now = datetime.now(tzutc()) + kernel_query = ( + sa.update(KernelRow) + .values( + status=KernelStatus.CANCELLED, + status_info=reason, + terminated_at=now, + status_history=sql_json_merge( + KernelRow.status_history, + (), + { + KernelStatus.CANCELLED.name: now.isoformat(), + }, + ), + ) + .where(KernelRow.session_id.in_(session_ids)) + ) + await db_sess.execute(kernel_query) + query = ( + sa.update(SessionRow) + .values( + status=SessionStatus.CANCELLED, + status_info=reason, + terminated_at=now, + status_history=sql_json_merge( + SessionRow.status_history, + (), + { + SessionStatus.CANCELLED.name: now.isoformat(), + }, + ), + ) + .where(SessionRow.id.in_(session_ids)) + ) + await db_sess.execute(query) + async def _list_managed_sessions( db_sess: SASession, diff --git a/src/ai/backend/manager/scheduler/types.py b/src/ai/backend/manager/scheduler/types.py index 8b9120b3e5..bc11cb2c84 100644 --- a/src/ai/backend/manager/scheduler/types.py +++ b/src/ai/backend/manager/scheduler/types.py @@ -112,6 +112,19 @@ def config_iv(self) -> t.Dict: """ raise NotImplementedError + @staticmethod + def prioritize(pending_sessions: Sequence[SessionRow]) -> tuple[int, list[SessionRow]]: + """ + Filter the pending session list by the top priority among the observed priorities of the + given pending sessions. + """ + if not pending_sessions: + return -1, [] + priorities = {s.priority for s in pending_sessions} + assert len(priorities) > 0 + top_priority = sorted(priorities, reverse=True)[0] + return top_priority, [*filter(lambda s: s.priority == top_priority, pending_sessions)] + @abstractmethod def pick_session( self, diff --git a/tests/manager/BUILD b/tests/manager/BUILD index c1c79ca1d7..30f52d20fb 100644 --- a/tests/manager/BUILD +++ b/tests/manager/BUILD @@ -2,6 +2,7 @@ python_test_utils( sources=[ "conftest.py", "model_factory.py", + "scheduler_utils.py", ], dependencies=[ ":fixtures", diff --git a/tests/manager/conftest.py b/tests/manager/conftest.py index cd16a3454f..ef6db12490 100644 --- a/tests/manager/conftest.py +++ b/tests/manager/conftest.py @@ -149,17 +149,23 @@ def logging_config(): yield config +@pytest.fixture(scope="session") +def ipc_base_path() -> Path: + ipc_base_path = Path.cwd() / f"tmp/backend.ai/manager-testing/ipc-{test_id}" + ipc_base_path.mkdir(parents=True, exist_ok=True) + return ipc_base_path + + @pytest.fixture(scope="session") def local_config( test_id, + ipc_base_path: Path, logging_config, etcd_container, # noqa: F811 redis_container, # noqa: F811 postgres_container, # noqa: F811 test_db, ) -> Iterator[LocalConfig]: - ipc_base_path = Path.cwd() / f"tmp/backend.ai/manager-testing/ipc-{test_id}" - ipc_base_path.mkdir(parents=True, exist_ok=True) etcd_addr = etcd_container[1] redis_addr = redis_container[1] postgres_addr = postgres_container[1] @@ -480,9 +486,12 @@ async def clean_fixture() -> None: @pytest.fixture -def file_lock_factory(local_config, request) -> Callable[[str], FileLock]: +def file_lock_factory( + ipc_base_path: Path, + request: pytest.FixtureRequest, +) -> Callable[[str], FileLock]: def _make_lock(lock_id: str) -> FileLock: - lock_path = local_config["manager"]["ipc-base-path"] / f"testing.{lock_id}.lock" + lock_path = ipc_base_path / f"testing.{lock_id}.lock" lock = FileLock(lock_path, timeout=0) request.addfinalizer(partial(lock_path.unlink, missing_ok=True)) return lock diff --git a/tests/manager/models/test_dbutils.py b/tests/manager/models/test_dbutils.py index 9ade5140e5..d2fadd3544 100644 --- a/tests/manager/models/test_dbutils.py +++ b/tests/manager/models/test_dbutils.py @@ -1,10 +1,18 @@ +from __future__ import annotations + import asyncio +from unittest.mock import MagicMock import aiotools import pytest import sqlalchemy as sa +from sqlalchemy.exc import DBAPIError -from ai.backend.manager.models.utils import execute_with_retry, execute_with_txn_retry +from ai.backend.manager.models.utils import ( + execute_with_retry, + execute_with_txn_retry, + retry_txn, +) @pytest.mark.asyncio @@ -95,3 +103,16 @@ async def txn_func_temporary_serialization_failure(db_session): txn_func_temporary_serialization_failure, database_engine.begin_session, conn ) assert ret == 1234 + + +@pytest.mark.asyncio +async def test_retry_txn_as_code_block() -> None: + orig = MagicMock() + orig.pgcode = "40001" + retry_count = 0 + with pytest.raises(RuntimeError): + async for attempt in retry_txn(4): + with attempt: + retry_count += 1 + raise DBAPIError(None, None, orig) + assert retry_count > 3 diff --git a/tests/manager/scheduler_utils.py b/tests/manager/scheduler_utils.py new file mode 100644 index 0000000000..2044d6061b --- /dev/null +++ b/tests/manager/scheduler_utils.py @@ -0,0 +1,207 @@ +from __future__ import annotations + +import secrets +from collections.abc import ( + Iterator, + Mapping, + Sequence, +) +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from typing import ( + Any, + Final, +) +from uuid import uuid4 + +from dateutil.parser import parse as dtparse + +from ai.backend.common.docker import ImageRef +from ai.backend.common.types import ( + AccessKey, + AgentId, + KernelId, + ResourceSlot, + SessionId, + SessionTypes, +) +from ai.backend.manager.defs import DEFAULT_ROLE +from ai.backend.manager.models.agent import AgentRow +from ai.backend.manager.models.kernel import KernelRow, KernelStatus +from ai.backend.manager.models.session import SESSION_PRIORITY_DEFUALT, SessionRow, SessionStatus + +ARCH_FOR_TEST: Final = "x86_64" + +agent_selection_resource_priority: Final = ["cuda", "rocm", "tpu", "cpu", "mem"] + +common_image_ref: Final = ImageRef( + "lablup/python:3.6-ubunt18.04", + ["*"], + architecture=ARCH_FOR_TEST, +) + +example_group_id = uuid4() +example_total_capacity = ResourceSlot({"cpu": "4.0", "mem": "4096"}) +example_sgroup_name1: Final = "sg01" +example_sgroup_name2: Final = "sg02" + +_common_dummy_for_pending_session: Mapping[str, Any] = dict( + domain_name="default", + group_id=example_group_id, + vfolder_mounts=[], + environ={}, + bootstrap_script=None, + startup_command=None, + use_host_network=False, +) + +_timestamp_count = 0 + + +def generate_timestamp() -> datetime: + global _timestamp_count + val = dtparse("2021-12-28T23:59:59+00:00") + val += timedelta(_timestamp_count * 10) + _timestamp_count += 1 + return val + + +def generate_role() -> Iterator[tuple[str, int, int]]: + yield ("main", 1, 0) + sub_idx = 1 + while True: + yield ("sub", sub_idx, sub_idx) + sub_idx += 1 + + +@dataclass +class KernelOpt: + requested_slots: ResourceSlot + kernel_id: KernelId = field(default_factory=lambda: KernelId(uuid4())) + image: ImageRef = common_image_ref + + +_sess_kern_status_map = { + SessionStatus.PENDING: KernelStatus.PENDING, + SessionStatus.SCHEDULED: KernelStatus.SCHEDULED, + SessionStatus.RUNNING: KernelStatus.RUNNING, + SessionStatus.TERMINATED: KernelStatus.TERMINATED, +} + + +def find_and_pop_picked_session(pending_sessions, picked_session_id) -> SessionRow: + for picked_idx, pending_sess in enumerate(pending_sessions): + if pending_sess.id == picked_session_id: + break + else: + # no matching entry for picked session? + raise RuntimeError("should not reach here") + return pending_sessions.pop(picked_idx) + + +def update_agent_assignment( + agents: Sequence[AgentRow], + picked_agent_id: AgentId, + occupied_slots: ResourceSlot, +) -> None: + for ag in agents: + if ag.id == picked_agent_id: + ag.occupied_slots += occupied_slots + + +def create_mock_kernel( + session_id: SessionId, + kernel_id: KernelId, + requested_slots: ResourceSlot, + *, + status: KernelStatus = KernelStatus.PENDING, + cluster_role: str = DEFAULT_ROLE, + cluster_idx: int = 1, + local_rank: int = 0, +) -> KernelRow: + return KernelRow( + id=session_id, + session_id=kernel_id, + status=status, + access_key="dummy-access-key", + agent=None, + agent_addr=None, + cluster_role=cluster_role, + cluster_idx=cluster_idx, + local_rank=local_rank, + cluster_hostname=f"{cluster_role}{cluster_idx}", + architecture=common_image_ref.architecture, + registry=common_image_ref.registry, + image=common_image_ref.name, + requested_slots=requested_slots, + bootstrap_script=None, + startup_command=None, + created_at=generate_timestamp(), + ) + + +def create_mock_session( + session_id: SessionId, + requested_slots: ResourceSlot, + *, + access_key: AccessKey = AccessKey("user01"), + status: SessionStatus = SessionStatus.PENDING, + status_data: dict[str, Any] | None = None, + kernel_opts: Sequence[KernelOpt] | None = None, + priority: int = SESSION_PRIORITY_DEFUALT, +) -> SessionRow: + """Create a simple single-kernel pending session.""" + if kernel_opts is None: + # Create a single pending kernel as a default + kernel_opts = [KernelOpt(requested_slots=requested_slots)] + return SessionRow( + kernels=[ + create_mock_kernel( + session_id, + kopt.kernel_id, + kopt.requested_slots, + status=_sess_kern_status_map[status], + cluster_role=role_name, + cluster_idx=role_idx, + local_rank=local_rank, + ) + for kopt, (role_name, role_idx, local_rank) in zip(kernel_opts, generate_role()) + ], + priority=priority, + access_key=access_key, + id=session_id, + creation_id=secrets.token_hex(8), + name=f"session-{secrets.token_hex(4)}", + session_type=SessionTypes.BATCH, + status=status, + status_data=status_data, + cluster_mode="single-node", + cluster_size=len(kernel_opts), + scaling_group_name=example_sgroup_name1, + requested_slots=requested_slots, + occupying_slots=( + requested_slots + if status not in (SessionStatus.PENDING, SessionStatus.SCHEDULED) + else ResourceSlot() + ), + target_sgroup_names=[], + **_common_dummy_for_pending_session, + created_at=generate_timestamp(), + ) + + +def create_mock_agent( + id: AgentId, + *, + scaling_group: str = example_sgroup_name1, + available_slots: ResourceSlot, + occupied_slots: ResourceSlot = ResourceSlot(), +) -> AgentRow: + return AgentRow( + id=id, + addr="10.0.1.1:6001", + architecture=ARCH_FOR_TEST, + scaling_group=scaling_group, + available_slots=available_slots, + occupied_slots=occupied_slots, + ) diff --git a/tests/manager/test_agent_selector.py b/tests/manager/test_agent_selector.py new file mode 100644 index 0000000000..6b544d45f9 --- /dev/null +++ b/tests/manager/test_agent_selector.py @@ -0,0 +1,189 @@ +from __future__ import annotations + +from collections.abc import Sequence +from decimal import Decimal +from uuid import uuid4 + +import pytest + +from ai.backend.common.types import ( + AgentId, + AgentSelectionStrategy, + ResourceSlot, + SessionId, +) +from ai.backend.manager.models.agent import AgentRow +from ai.backend.manager.models.scaling_group import ScalingGroupOpts +from ai.backend.manager.models.session import SessionRow +from ai.backend.manager.scheduler.agent_selector import RoundRobinAgentSelector +from ai.backend.manager.scheduler.fifo import FIFOSlotScheduler +from ai.backend.manager.scheduler.types import InMemoryResourceGroupStateStore + +from .scheduler_utils import ( + agent_selection_resource_priority, + create_mock_agent, + create_mock_session, + find_and_pop_picked_session, + update_agent_assignment, +) + + +@pytest.mark.asyncio +async def test_agent_selection_strategy_rr() -> None: + example_homogeneous_agents = [ + create_mock_agent( + AgentId(f"i-{idx:03d}"), + available_slots=ResourceSlot({ + "cpu": Decimal("4.0"), + "mem": Decimal("4096"), + "cuda.shares": Decimal("4.0"), + "rocm.devices": Decimal("2"), + }), + ) + for idx in range(10) + ] + example_homogeneous_pending_sessions = [ + create_mock_session( + SessionId(uuid4()), + requested_slots=ResourceSlot({"cpu": Decimal("2"), "mem": Decimal("1024")}), + ) + for _ in range(20) + ] + + sgroup_opts = ScalingGroupOpts( + agent_selection_strategy=AgentSelectionStrategy.ROUNDROBIN, + ) + scheduler = FIFOSlotScheduler( + sgroup_opts, + {}, + ) + + agstate_cls = RoundRobinAgentSelector.get_state_cls() + agselector = RoundRobinAgentSelector( + sgroup_opts, + {}, + agent_selection_resource_priority, + state_store=InMemoryResourceGroupStateStore(agstate_cls), + ) + + num_agents = len(example_homogeneous_agents) + total_capacity = sum((ag.available_slots for ag in example_homogeneous_agents), ResourceSlot()) + agent_ids = [] + # Repeat the allocation for two iterations + for _ in range(num_agents * 2): + picked_session_id = scheduler.pick_session( + total_capacity, + example_homogeneous_pending_sessions, + [], + ) + assert picked_session_id == example_homogeneous_pending_sessions[0].id + picked_session = find_and_pop_picked_session( + example_homogeneous_pending_sessions, + picked_session_id, + ) + agent_ids.append( + await agselector.assign_agent_for_session( + example_homogeneous_agents, + picked_session, + ) + ) + assert agent_ids == [AgentId(f"i-{idx:03d}") for idx in range(num_agents)] * 2 + + +@pytest.mark.asyncio +async def test_agent_selection_strategy_rr_skip_unacceptable_agents() -> None: + agents: Sequence[AgentRow] = [ + create_mock_agent( + AgentId("i-001"), + available_slots=ResourceSlot({"cpu": Decimal("8"), "mem": Decimal("4096")}), + ), + create_mock_agent( + AgentId("i-002"), + available_slots=ResourceSlot({"cpu": Decimal("4"), "mem": Decimal("2048")}), + ), + create_mock_agent( + AgentId("i-003"), + available_slots=ResourceSlot({"cpu": Decimal("2"), "mem": Decimal("1024")}), + ), + create_mock_agent( + AgentId("i-004"), + available_slots=ResourceSlot({"cpu": Decimal("1"), "mem": Decimal("512")}), + ), + ] + pending_sessions = [ + create_mock_session( + SessionId(uuid4()), + ResourceSlot({"cpu": Decimal("2"), "mem": Decimal("500")}), + ) + for _ in range(8) + ] + # Expected result: + # i-001 can accommodate 4 sessions. + # i-002 can accommodate 2 sessions. + # i-003 can accommodate 1 sessions. + # i-004 can accommodate 0 sessions. + + sgroup_opts = ScalingGroupOpts( + agent_selection_strategy=AgentSelectionStrategy.ROUNDROBIN, + ) + scheduler = FIFOSlotScheduler( + sgroup_opts, + {}, + ) + + agstate_cls = RoundRobinAgentSelector.get_state_cls() + agselector = RoundRobinAgentSelector( + sgroup_opts, + {}, + agent_selection_resource_priority, + state_store=InMemoryResourceGroupStateStore(agstate_cls), + ) + + total_capacity = sum((ag.available_slots for ag in agents), ResourceSlot()) + + results: list[tuple[AgentId | None, SessionId]] = [] + scheduled_sessions: list[SessionRow] = [] + + for _ in range(8): + picked_session_id = scheduler.pick_session( + total_capacity, + pending_sessions, + scheduled_sessions, + ) + assert picked_session_id is not None + picked_session = find_and_pop_picked_session( + pending_sessions, + picked_session_id, + ) + # Bookkeeping picked_session in scheduled_sessions should be skipped if we fail to + # assign an agent, but we keep it here for the validation step of this test case. + scheduled_sessions.append(picked_session) + picked_agent_id = await agselector.assign_agent_for_session( + agents, + picked_session, + ) + if picked_agent_id is not None: + update_agent_assignment(agents, picked_agent_id, picked_session.requested_slots) + results.append((picked_agent_id, picked_session_id)) + + print() + for ag in agents: + print( + ag.id, + f"{ag.occupied_slots["cpu"]}/{ag.available_slots["cpu"]}", + f"{ag.occupied_slots["mem"]}/{ag.available_slots["mem"]}", + ) + # As more sessions have the assigned agents, the remaining capacity diminishes + # and the range of round-robin also becomes limited. + # When there is no assignable agent, it should return None. + assert len(results) == 8 + assert results == [ + ("i-001", scheduled_sessions[0].id), + ("i-002", scheduled_sessions[1].id), + ("i-003", scheduled_sessions[2].id), + ("i-001", scheduled_sessions[3].id), + ("i-002", scheduled_sessions[4].id), + ("i-001", scheduled_sessions[5].id), + ("i-001", scheduled_sessions[6].id), + (None, scheduled_sessions[7].id), + ] diff --git a/tests/manager/test_priority_scheduler.py b/tests/manager/test_priority_scheduler.py new file mode 100644 index 0000000000..e46a2a2870 --- /dev/null +++ b/tests/manager/test_priority_scheduler.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +from decimal import Decimal +from uuid import uuid4 + +import pytest + +from ai.backend.common.types import ( + ResourceSlot, + SessionId, +) +from ai.backend.manager.models.scaling_group import ScalingGroupOpts +from ai.backend.manager.scheduler.fifo import FIFOSlotScheduler, LIFOSlotScheduler + +from .scheduler_utils import ( + create_mock_session, + find_and_pop_picked_session, +) + + +@pytest.mark.asyncio +async def test_priority_scheduler_fifo() -> None: + sid = lambda: SessionId(uuid4()) + rs = ResourceSlot({"cpu": Decimal(1), "mem": Decimal(1024)}) + total_capacity = ResourceSlot({"cpu": Decimal(8), "mem": Decimal(8192)}) + pending_sessions = [ + create_mock_session(sid(), rs, priority=10), + create_mock_session(sid(), rs, priority=8), + create_mock_session(sid(), rs, priority=10), + create_mock_session(sid(), rs, priority=12), + create_mock_session(sid(), rs, priority=10), + ] + session_ids: list[SessionId] = [s.id for s in pending_sessions] + picked_session_ids: list[SessionId] = [] + scheduler = FIFOSlotScheduler(ScalingGroupOpts(), {}) + while pending_sessions: + _, prioritized_pending_sessions = scheduler.prioritize(pending_sessions) + picked_session_id = scheduler.pick_session(total_capacity, prioritized_pending_sessions, []) + assert picked_session_id is not None + find_and_pop_picked_session(pending_sessions, picked_session_id) + picked_session_ids.append(picked_session_id) + + assert picked_session_ids == [ + session_ids[3], # priority 12 + session_ids[0], # priority 10 (oldest) + session_ids[2], # priority 10 + session_ids[4], # priority 10 (newest) + session_ids[1], # priority 8 + ] + + +@pytest.mark.asyncio +async def test_priority_scheduler_lifo() -> None: + sid = lambda: SessionId(uuid4()) + rs = ResourceSlot({"cpu": Decimal(1), "mem": Decimal(1024)}) + total_capacity = ResourceSlot({"cpu": Decimal(8), "mem": Decimal(8192)}) + pending_sessions = [ + create_mock_session(sid(), rs, priority=10), + create_mock_session(sid(), rs, priority=8), + create_mock_session(sid(), rs, priority=10), + create_mock_session(sid(), rs, priority=12), + create_mock_session(sid(), rs, priority=10), + ] + session_ids: list[SessionId] = [s.id for s in pending_sessions] + picked_session_ids: list[SessionId] = [] + scheduler = LIFOSlotScheduler(ScalingGroupOpts(), {}) + while pending_sessions: + _, prioritized_pending_sessions = scheduler.prioritize(pending_sessions) + picked_session_id = scheduler.pick_session(total_capacity, prioritized_pending_sessions, []) + assert picked_session_id is not None + find_and_pop_picked_session(pending_sessions, picked_session_id) + picked_session_ids.append(picked_session_id) + + assert picked_session_ids == [ + session_ids[3], # priority 12 + session_ids[4], # priority 10 (newest) + session_ids[2], # priority 10 + session_ids[0], # priority 10 (oldest) + session_ids[1], # priority 8 + ] diff --git a/tests/manager/test_scheduler.py b/tests/manager/test_scheduler.py index 57d00ce94f..eaeec156c6 100644 --- a/tests/manager/test_scheduler.py +++ b/tests/manager/test_scheduler.py @@ -1,43 +1,34 @@ from __future__ import annotations -import secrets from collections.abc import Mapping, Sequence from datetime import datetime, timedelta from decimal import Decimal from pprint import pprint -from typing import Any, Generator +from typing import Any from unittest import mock from unittest.mock import AsyncMock, MagicMock -from uuid import UUID, uuid4 +from uuid import uuid4 -import attrs import pytest import pytest_mock import trafaret as t from dateutil.parser import parse as dtparse from dateutil.tz import tzutc -from ai.backend.common.docker import ImageRef from ai.backend.common.types import ( AccessKey, AgentId, AgentSelectionStrategy, - ClusterMode, - KernelId, ResourceSlot, SessionId, SessionTypes, ) -from ai.backend.manager.defs import DEFAULT_ROLE from ai.backend.manager.models.agent import AgentRow -from ai.backend.manager.models.image import ImageRow -from ai.backend.manager.models.kernel import KernelRow from ai.backend.manager.models.scaling_group import ScalingGroupOpts from ai.backend.manager.models.session import SessionRow, SessionStatus from ai.backend.manager.registry import AgentRegistry from ai.backend.manager.scheduler.agent_selector import ( DispersedAgentSelector, - RoundRobinAgentSelector, ) from ai.backend.manager.scheduler.dispatcher import ( SchedulerDispatcher, @@ -49,9 +40,15 @@ from ai.backend.manager.scheduler.predicates import check_reserved_batch_session from ai.backend.manager.scheduler.types import InMemoryResourceGroupStateStore -ARCH_FOR_TEST = "x86_64" - -agent_selection_resource_priority = ["cuda", "rocm", "tpu", "cpu", "mem"] +from .scheduler_utils import ( + KernelOpt, + agent_selection_resource_priority, + create_mock_agent, + create_mock_session, + example_sgroup_name1, + example_sgroup_name2, + find_and_pop_picked_session, +) def test_load_intrinsic() -> None: @@ -95,847 +92,150 @@ def test_scheduler_configs() -> None: ) -example_group_id = uuid4() - -example_total_capacity = ResourceSlot({"cpu": "4.0", "mem": "4096"}) -example_sgroup_name1 = "sg01" -example_sgroup_name2 = "sg02" - - -@pytest.fixture -def example_agents() -> Sequence[AgentRow]: - return [ - AgentRow( - id=AgentId("i-001"), - addr="10.0.1.1:6001", - architecture=ARCH_FOR_TEST, - scaling_group=example_sgroup_name1, - available_slots=ResourceSlot({ - "cpu": Decimal("4.0"), - "mem": Decimal("4096"), - "cuda.shares": Decimal("4.0"), - "rocm.devices": Decimal("2"), - }), - occupied_slots=ResourceSlot({ - "cpu": Decimal("0"), - "mem": Decimal("0"), - "cuda.shares": Decimal("0"), - "rocm.devices": Decimal("0"), - }), - ), - AgentRow( - id=AgentId("i-101"), - addr="10.0.2.1:6001", - architecture=ARCH_FOR_TEST, - scaling_group=example_sgroup_name2, - available_slots=ResourceSlot({ - "cpu": Decimal("3.0"), - "mem": Decimal("2560"), - "cuda.shares": Decimal("1.0"), - "rocm.devices": Decimal("8"), - }), - occupied_slots=ResourceSlot({ - "cpu": Decimal("0"), - "mem": Decimal("0"), - "cuda.shares": Decimal("0"), - "rocm.devices": Decimal("0"), - }), - ), - ] - - -@pytest.fixture -def example_agents_many() -> Sequence[AgentRow]: - return [ - AgentRow( - id=AgentId("i-001"), - addr="10.0.1.1:6001", - architecture=ARCH_FOR_TEST, - scaling_group=example_sgroup_name1, - available_slots=ResourceSlot({ - "cpu": Decimal("8"), - "mem": Decimal("4096"), - "cuda.shares": Decimal("4.0"), - }), - occupied_slots=ResourceSlot({ - "cpu": Decimal("0"), - "mem": Decimal("0"), - "cuda.shares": Decimal("0"), - }), - ), - AgentRow( - id=AgentId("i-002"), - addr="10.0.2.1:6001", - architecture=ARCH_FOR_TEST, - scaling_group=example_sgroup_name2, - available_slots=ResourceSlot({ - "cpu": Decimal("4"), - "mem": Decimal("2048"), - "cuda.shares": Decimal("1.0"), - }), - occupied_slots=ResourceSlot({ - "cpu": Decimal("0"), - "mem": Decimal("0"), - "cuda.shares": Decimal("0"), - }), - ), - AgentRow( - id=AgentId("i-003"), - addr="10.0.3.1:6001", - architecture=ARCH_FOR_TEST, - scaling_group=example_sgroup_name2, - available_slots=ResourceSlot({ - "cpu": Decimal("2"), - "mem": Decimal("1024"), - "cuda.shares": Decimal("1.0"), - }), - occupied_slots=ResourceSlot({ - "cpu": Decimal("0"), - "mem": Decimal("0"), - "cuda.shares": Decimal("0"), - }), - ), - AgentRow( - id=AgentId("i-004"), - addr="10.0.4.1:6001", - architecture=ARCH_FOR_TEST, - scaling_group=example_sgroup_name2, - available_slots=ResourceSlot({ - "cpu": Decimal("1"), - "mem": Decimal("512"), - "cuda.shares": Decimal("0.5"), - }), - occupied_slots=ResourceSlot({ - "cpu": Decimal("0"), - "mem": Decimal("0"), - "cuda.shares": Decimal("0"), - }), - ), - ] - - -@pytest.fixture -def example_agents_multi_homogeneous( - request: pytest.FixtureRequest, -) -> Generator[Sequence[AgentRow], None, None]: - repeat = request.param.get("repeat", 10) - - yield [ - AgentRow( - id=AgentId(f"i-{idx:03d}"), - addr=f"10.0.1.{idx}:6001", - architecture=ARCH_FOR_TEST, - scaling_group=example_sgroup_name1, - available_slots=ResourceSlot({ - "cpu": Decimal("4.0"), - "mem": Decimal("4096"), - "cuda.shares": Decimal("4.0"), - "rocm.devices": Decimal("2"), - }), - occupied_slots=ResourceSlot({ - "cpu": Decimal("0"), - "mem": Decimal("0"), - "cuda.shares": Decimal("0"), - "rocm.devices": Decimal("0"), - }), - ) - for idx in range(repeat) - ] - - -@pytest.fixture -def example_mixed_agents() -> Sequence[AgentRow]: +def create_example_agents() -> Sequence[AgentRow]: return [ - AgentRow( - id=AgentId("i-gpu"), - addr="10.0.1.1:6001", - architecture=ARCH_FOR_TEST, + create_mock_agent( + AgentId("i-001"), scaling_group=example_sgroup_name1, available_slots=ResourceSlot({ - "cpu": Decimal("4.0"), - "mem": Decimal("4096"), - "cuda.shares": Decimal("4.0"), - }), - occupied_slots=ResourceSlot({ - "cpu": Decimal("0"), - "mem": Decimal("0"), - "cuda.shares": Decimal("0"), + "cpu": Decimal(4), + "mem": Decimal(4096), + "cuda.shares": Decimal(4), + "rocm.devices": Decimal(2), }), ), - AgentRow( - id=AgentId("i-cpu"), - addr="10.0.2.1:6001", - architecture=ARCH_FOR_TEST, + create_mock_agent( + AgentId("i-101"), scaling_group=example_sgroup_name2, available_slots=ResourceSlot({ - "cpu": Decimal("3.0"), - "mem": Decimal("2560"), - "cuda.shares": Decimal("0"), - }), - occupied_slots=ResourceSlot({ - "cpu": Decimal("0"), - "mem": Decimal("0"), - "cuda.shares": Decimal("0"), + "cpu": Decimal(3), + "mem": Decimal(2560), + "cuda.shares": Decimal(1), + "rocm.devices": Decimal(8), }), ), ] -@pytest.fixture -def example_agents_first_one_assigned() -> Sequence[AgentRow]: +def create_example_mixed_agents() -> Sequence[AgentRow]: return [ - AgentRow( - id=AgentId("i-001"), - addr="10.0.1.1:6001", - architecture=ARCH_FOR_TEST, + create_mock_agent( + AgentId("i-gpu"), scaling_group=example_sgroup_name1, available_slots=ResourceSlot({ - "cpu": Decimal("2.0"), - "mem": Decimal("2048"), - "cuda.shares": Decimal("2.0"), - "rocm.devices": Decimal("1"), - }), - occupied_slots=ResourceSlot({ - "cpu": Decimal("2.0"), - "mem": Decimal("2048"), - "cuda.shares": Decimal("2.0"), - "rocm.devices": Decimal("1"), + "cpu": Decimal(4), + "mem": Decimal(4096), + "cuda.shares": Decimal(4), }), ), - AgentRow( - id=AgentId("i-101"), - addr="10.0.2.1:6001", - architecture=ARCH_FOR_TEST, + create_mock_agent( + AgentId("i-cpu"), scaling_group=example_sgroup_name2, available_slots=ResourceSlot({ - "cpu": Decimal("3.0"), - "mem": Decimal("2560"), - "cuda.shares": Decimal("1.0"), - "rocm.devices": Decimal("8"), - }), - occupied_slots=ResourceSlot({ - "cpu": Decimal("0"), - "mem": Decimal("0"), - "cuda.shares": Decimal("0"), - "rocm.devices": Decimal("0"), + "cpu": Decimal(3), + "mem": Decimal(2560), + "cuda.shares": Decimal(0), }), ), ] -@pytest.fixture -def example_agents_no_valid() -> Sequence[AgentRow]: +def create_example_pending_sessions() -> Sequence[SessionRow]: return [ - AgentRow( - id=AgentId("i-001"), - addr="10.0.1.1:6001", - architecture=ARCH_FOR_TEST, - scaling_group=example_sgroup_name1, - available_slots=ResourceSlot({ - "cpu": Decimal("0"), - "mem": Decimal("0"), - "cuda.shares": Decimal("0"), - "rocm.devices": Decimal("0"), - }), - occupied_slots=ResourceSlot({ - "cpu": Decimal("4.0"), - "mem": Decimal("4096"), - "cuda.shares": Decimal("4.0"), - "rocm.devices": Decimal("2"), - }), - ), - AgentRow( - id=AgentId("i-101"), - addr="10.0.2.1:6001", - architecture=ARCH_FOR_TEST, - scaling_group=example_sgroup_name2, - available_slots=ResourceSlot({ - "cpu": Decimal("0"), - "mem": Decimal("0"), - "cuda.shares": Decimal("0"), - "rocm.devices": Decimal("0"), - }), - occupied_slots=ResourceSlot({ - "cpu": Decimal("3.0"), - "mem": Decimal("2560"), - "cuda.shares": Decimal("1.0"), - "rocm.devices": Decimal("8"), - }), - ), - ] - - -@attrs.define(auto_attribs=True, slots=True) -class SessionKernelIdPair: - session_id: SessionId - kernel_ids: Sequence[KernelId] - - -cancelled_session_ids = [ - SessionId(UUID("251907d9-1290-4126-bc6c-000000000999")), -] - -pending_session_kernel_ids = [ - SessionKernelIdPair( - session_id=SessionId(UUID("251907d9-1290-4126-bc6c-000000000100")), - kernel_ids=[KernelId(UUID("251907d9-1290-4126-bc6c-000000000100"))], - ), - SessionKernelIdPair( - session_id=SessionId(UUID("251907d9-1290-4126-bc6c-000000000200")), - kernel_ids=[KernelId(UUID("251907d9-1290-4126-bc6c-000000000200"))], - ), - SessionKernelIdPair( - # single-node mode multi-container session - session_id=SessionId(UUID("251907d9-1290-4126-bc6c-000000000300")), - kernel_ids=[ - KernelId(UUID("251907d9-1290-4126-bc6c-000000000300")), - KernelId(UUID("251907d9-1290-4126-bc6c-000000000301")), - KernelId(UUID("251907d9-1290-4126-bc6c-000000000302")), - ], - ), - SessionKernelIdPair( - session_id=SessionId(UUID("251907d9-1290-4126-bc6c-000000000400")), - kernel_ids=[KernelId(UUID("251907d9-1290-4126-bc6c-000000000400"))], - ), -] - -existing_session_kernel_ids = [ - SessionKernelIdPair( - session_id=SessionId(UUID("251907d9-1290-4126-bc6c-100000000100")), - kernel_ids=[ - KernelId(UUID("251907d9-1290-4126-bc6c-100000000100")), - KernelId(UUID("251907d9-1290-4126-bc6c-100000000101")), - ], - ), - SessionKernelIdPair( - session_id=SessionId(UUID("251907d9-1290-4126-bc6c-100000000200")), - kernel_ids=[KernelId(UUID("251907d9-1290-4126-bc6c-100000000200"))], - ), - SessionKernelIdPair( - # single-node mode multi-container session - session_id=SessionId(UUID("251907d9-1290-4126-bc6c-100000000300")), - kernel_ids=[KernelId(UUID("251907d9-1290-4126-bc6c-100000000300"))], - ), -] - -common_image_ref = ImageRef("lablup/python:3.6-ubunt18.04", ["*"], architecture=ARCH_FOR_TEST) -common_image = ImageRow( - name=common_image_ref.canonical, - image=common_image_ref.name, - tag=common_image_ref.tag, - registry=common_image_ref.registry, - architecture=ARCH_FOR_TEST, -) - -_common_dummy_for_pending_session: Mapping[str, Any] = dict( - domain_name="default", - group_id=example_group_id, - vfolder_mounts=[], - environ={}, - bootstrap_script=None, - startup_command=None, - use_host_network=False, -) - -_common_dummy_for_existing_session: Mapping[str, Any] = dict( - domain_name="default", - group_id=example_group_id, -) - - -@pytest.fixture -def example_homogeneous_pending_sessions( - request: pytest.FixtureRequest, -) -> Generator[Sequence[SessionRow], None, None]: - repeat = request.param.get("repeat", 10) - yield [ - SessionRow( - kernels=[ - KernelRow( - id=pending_session_kernel_ids[2].kernel_ids[0], - session_id=pending_session_kernel_ids[2].session_id, - access_key="dummy-access-key", - agent=None, - agent_addr=None, - cluster_role=DEFAULT_ROLE, - cluster_idx=1, - local_rank=0, - cluster_hostname=f"{DEFAULT_ROLE}0", - architecture=common_image_ref.architecture, - registry=common_image_ref.registry, - image=common_image_ref.name, - requested_slots=ResourceSlot({ - "cpu": Decimal("2.0"), - "mem": Decimal("1024"), - }), - bootstrap_script=None, - startup_command=None, - created_at=dtparse("2021-12-01T23:59:59+00:00"), - ), - ], - access_key=AccessKey("user01"), - id=UUID(f"251907d9-1290-4126-bc6c-{idx:012x}"), - creation_id=f"{idx:012x}", - name=f"session-{idx}", - session_type=SessionTypes.BATCH, - status=SessionStatus.PENDING, - cluster_mode="single-node", - cluster_size=1, - scaling_group_name=example_sgroup_name1, - requested_slots=ResourceSlot({ - "cpu": Decimal("2.0"), - "mem": Decimal("1024"), - }), - target_sgroup_names=[], - **_common_dummy_for_pending_session, - created_at=dtparse("2021-12-28T23:59:59+00:00"), - ) - for idx in range(repeat) - ] - - -@pytest.fixture -def example_cancelled_sessions() -> Sequence[SessionRow]: - return [ - SessionRow( - access_key=AccessKey("user01"), - id=cancelled_session_ids[0], - creation_id="aaa100", - name="ecs01", - session_type=SessionTypes.BATCH, - status=SessionStatus.PENDING, - cluster_mode="single-node", - cluster_size=1, - scaling_group_name=example_sgroup_name1, - requested_slots=ResourceSlot({ - "cpu": Decimal("2.0"), - "mem": Decimal("1024"), - "cuda.shares": Decimal("0"), - "rocm.devices": Decimal("1"), - }), - target_sgroup_names=[], - **_common_dummy_for_pending_session, - created_at=dtparse("2021-12-28T23:59:59+00:00"), - ), - ] - - -def create_pending_session( - session_id: SessionId, kernel_id: KernelId, requested_slots: ResourceSlot -) -> SessionRow: - """Create a simple single-kernel pending session.""" - return SessionRow( - kernels=[ - KernelRow( - id=session_id, - session_id=kernel_id, - access_key="dummy-access-key", - agent=None, - agent_addr=None, - cluster_role=DEFAULT_ROLE, - cluster_idx=1, - local_rank=0, - cluster_hostname=f"{DEFAULT_ROLE}0", - architecture=common_image_ref.architecture, - registry=common_image_ref.registry, - image=common_image_ref.name, - requested_slots=ResourceSlot({ - "cpu": Decimal("2.0"), - "mem": Decimal("1024"), - "cuda.shares": Decimal("0"), - "rocm.devices": Decimal("1"), - }), - bootstrap_script=None, - startup_command=None, - created_at=dtparse("2021-12-28T23:59:59+00:00"), - ), - ], - access_key=AccessKey("user01"), - id=pending_session_kernel_ids[0].session_id, - creation_id="aaa100", - name="eps01", - session_type=SessionTypes.BATCH, - status=SessionStatus.PENDING, - cluster_mode="single-node", - cluster_size=1, - scaling_group_name=example_sgroup_name1, - requested_slots=requested_slots, - target_sgroup_names=[], - **_common_dummy_for_pending_session, - created_at=dtparse("2021-12-28T23:59:59+00:00"), - ) - - -@pytest.fixture -def example_pending_sessions() -> Sequence[SessionRow]: - # lower indicies are enqueued first. - return [ - SessionRow( # rocm - kernels=[ - KernelRow( - id=pending_session_kernel_ids[0].kernel_ids[0], - session_id=pending_session_kernel_ids[0].session_id, - access_key="dummy-access-key", - agent=None, - agent_addr=None, - cluster_role=DEFAULT_ROLE, - cluster_idx=1, - local_rank=0, - cluster_hostname=f"{DEFAULT_ROLE}0", - architecture=common_image_ref.architecture, - registry=common_image_ref.registry, - image=common_image_ref.name, - requested_slots=ResourceSlot({ - "cpu": Decimal("2.0"), - "mem": Decimal("1024"), - "cuda.shares": Decimal("0"), - "rocm.devices": Decimal("1"), - }), - bootstrap_script=None, - startup_command=None, - created_at=dtparse("2021-12-28T23:59:59+00:00"), - ), - ], + create_mock_session( # rocm + SessionId(uuid4()), access_key=AccessKey("user01"), - id=pending_session_kernel_ids[0].session_id, - creation_id="aaa100", - name="eps01", - session_type=SessionTypes.BATCH, - status=SessionStatus.PENDING, - cluster_mode="single-node", - cluster_size=1, - scaling_group_name=example_sgroup_name1, requested_slots=ResourceSlot({ - "cpu": Decimal("2.0"), - "mem": Decimal("1024"), - "cuda.shares": Decimal("0"), - "rocm.devices": Decimal("1"), + "cpu": Decimal(2), + "mem": Decimal(1024), + "cuda.shares": Decimal(0), + "rocm.devices": Decimal(1), }), - target_sgroup_names=[], - **_common_dummy_for_pending_session, - created_at=dtparse("2021-12-28T23:59:59+00:00"), ), - SessionRow( # cuda - kernels=[ - KernelRow( - id=pending_session_kernel_ids[1].kernel_ids[0], - session_id=pending_session_kernel_ids[1].session_id, - access_key="dummy-access-key", - agent=None, - agent_addr=None, - cluster_role=DEFAULT_ROLE, - cluster_idx=1, - local_rank=0, - cluster_hostname=f"{DEFAULT_ROLE}0", - architecture=common_image_ref.architecture, - registry=common_image_ref.registry, - image=common_image_ref.name, - requested_slots=ResourceSlot({ - "cpu": Decimal("1.0"), - "mem": Decimal("2048"), - "cuda.shares": Decimal("0.5"), - "rocm.devices": Decimal("0"), - }), - bootstrap_script=None, - startup_command=None, - created_at=dtparse("2022-02-01T23:59:59+00:00"), - ), - ], + create_mock_session( # cuda + SessionId(uuid4()), access_key=AccessKey("user02"), - id=pending_session_kernel_ids[1].session_id, - creation_id="aaa101", - name="eps02", - session_type=SessionTypes.BATCH, - status=SessionStatus.PENDING, - cluster_mode="single-node", - cluster_size=1, - scaling_group_name=example_sgroup_name1, requested_slots=ResourceSlot({ - "cpu": Decimal("1.0"), - "mem": Decimal("2048"), + "cpu": Decimal(1), + "mem": Decimal(2048), "cuda.shares": Decimal("0.5"), - "rocm.devices": Decimal("0"), + "rocm.devices": Decimal(0), }), - target_sgroup_names=[], - **_common_dummy_for_pending_session, - created_at=dtparse("2022-02-01T23:59:59+00:00"), ), - SessionRow( # cpu-only - kernels=[ - KernelRow( - id=pending_session_kernel_ids[2].kernel_ids[0], - session_id=pending_session_kernel_ids[2].session_id, - access_key="dummy-access-key", - agent=None, - agent_addr=None, - cluster_role=DEFAULT_ROLE, - cluster_idx=1, - local_rank=0, - cluster_hostname=f"{DEFAULT_ROLE}0", - architecture=common_image_ref.architecture, - registry=common_image_ref.registry, - image=common_image_ref.name, - requested_slots=ResourceSlot({ - "cpu": Decimal("0.4"), - "mem": Decimal("512"), - "cuda.shares": Decimal("0"), - "rocm.devices": Decimal("0"), - }), - bootstrap_script=None, - startup_command=None, - created_at=dtparse("2021-12-01T23:59:59+00:00"), - ), - KernelRow( - id=pending_session_kernel_ids[2].kernel_ids[1], - session_id=pending_session_kernel_ids[2].session_id, - access_key="dummy-access-key", - agent=None, - agent_addr=None, - cluster_role="sub", - cluster_idx=2, - local_rank=1, - cluster_hostname="sub1", - architecture=common_image_ref.architecture, - registry=common_image_ref.registry, - image=common_image_ref.name, - requested_slots=ResourceSlot({ - "cpu": Decimal("0.3"), - "mem": Decimal("256"), - "cuda.shares": Decimal("0"), - "rocm.devices": Decimal("0"), - }), - bootstrap_script=None, - startup_command=None, - created_at=dtparse("2021-12-01T23:59:59+00:00"), - ), - KernelRow( - id=pending_session_kernel_ids[2].kernel_ids[2], - session_id=pending_session_kernel_ids[2].session_id, - access_key="dummy-access-key", - agent=None, - agent_addr=None, - cluster_role="sub", - cluster_idx=3, - local_rank=2, - cluster_hostname="sub2", - architecture=common_image_ref.architecture, - registry=common_image_ref.registry, - image=common_image_ref.name, - requested_slots=ResourceSlot({ - "cpu": Decimal("0.3"), - "mem": Decimal("256"), - "cuda.shares": Decimal("0"), - "rocm.devices": Decimal("0"), - }), - bootstrap_script=None, - startup_command=None, - created_at=dtparse("2021-12-01T23:59:59+00:00"), - ), - ], + create_mock_session( # cpu-only, single-node cluster + SessionId(uuid4()), access_key=AccessKey("user03"), - status_data={}, - id=pending_session_kernel_ids[2].session_id, - creation_id="aaa102", - name="eps03", - session_type=SessionTypes.BATCH, - status=SessionStatus.PENDING, - cluster_mode="single-node", - cluster_size=3, - scaling_group_name=example_sgroup_name1, - requested_slots=ResourceSlot({ - "cpu": Decimal("1.0"), - "mem": Decimal("1024"), - "cuda.shares": Decimal("0"), - "rocm.devices": Decimal("0"), - }), - target_sgroup_names=[], - **_common_dummy_for_pending_session, - created_at=dtparse("2021-12-01T23:59:59+00:00"), + requested_slots=ResourceSlot({"cpu": Decimal("1.0"), "mem": Decimal(1024)}), + kernel_opts=[ + KernelOpt(ResourceSlot({"cpu": Decimal("0.4"), "mem": Decimal(512)})), + KernelOpt(ResourceSlot({"cpu": Decimal("0.3"), "mem": Decimal(256)})), + KernelOpt(ResourceSlot({"cpu": Decimal("0.3"), "mem": Decimal(256)})), + ], ), ] -@pytest.fixture -def example_existing_sessions() -> Sequence[SessionRow]: +def create_example_existing_sessions() -> Sequence[SessionRow]: return [ - SessionRow( - kernels=[ - KernelRow( - id=existing_session_kernel_ids[0].kernel_ids[0], - session_id=existing_session_kernel_ids[0].session_id, - access_key="dummy-access-key", - agent=None, - agent_addr=None, - cluster_role=DEFAULT_ROLE, - cluster_idx=1, - local_rank=0, - cluster_hostname=f"{DEFAULT_ROLE}0", - architecture=common_image_ref.architecture, - registry=common_image_ref.registry, - image=common_image_ref.name, - requested_slots=ResourceSlot({ - "cpu": Decimal("1.0"), - "mem": Decimal("512"), - "cuda.shares": Decimal("0"), - "rocm.devices": Decimal("0"), - }), - bootstrap_script=None, - startup_command=None, - created_at=dtparse("2022-02-05T00:00:00+00:00"), + create_mock_session( + SessionId(uuid4()), + status=SessionStatus.RUNNING, + access_key=AccessKey("user01"), + requested_slots=ResourceSlot({ + "cpu": Decimal(3), + "mem": Decimal(1024), + "cuda.shares": Decimal(0), + "rocm.devices": Decimal(1), + }), + kernel_opts=[ + KernelOpt( + ResourceSlot({ + "cpu": Decimal(1), + "mem": Decimal(512), + "cuda.shares": Decimal(0), + "rocm.devices": Decimal(0), + }) ), - KernelRow( - id=existing_session_kernel_ids[0].kernel_ids[1], - session_id=existing_session_kernel_ids[0].session_id, - access_key="dummy-access-key", - agent=None, - agent_addr=None, - cluster_role="sub", - cluster_idx=2, - local_rank=1, - cluster_hostname="sub1", - architecture=common_image_ref.architecture, - registry=common_image_ref.registry, - image=common_image_ref.name, - requested_slots=ResourceSlot({ - "cpu": Decimal("2.0"), - "mem": Decimal("512"), - "cuda.shares": Decimal("0"), - "rocm.devices": Decimal("1"), - }), - bootstrap_script=None, - startup_command=None, - created_at=dtparse("2022-02-05T00:00:00+00:00"), + KernelOpt( + ResourceSlot({ + "cpu": Decimal(2), + "mem": Decimal(512), + "cuda.shares": Decimal(0), + "rocm.devices": Decimal(1), + }) ), ], - access_key=AccessKey("user01"), - id=existing_session_kernel_ids[0].session_id, - name="ees01", - session_type=SessionTypes.BATCH, - status=SessionStatus.RUNNING, - cluster_mode="single-node", - cluster_size=2, - occupying_slots=ResourceSlot({ - "cpu": Decimal("3.0"), - "mem": Decimal("1024"), - "cuda.shares": Decimal("0"), - "rocm.devices": Decimal("1"), - }), - scaling_group_name=example_sgroup_name1, - **_common_dummy_for_existing_session, ), - SessionRow( - kernels=[ - KernelRow( - id=existing_session_kernel_ids[1].kernel_ids[0], - session_id=existing_session_kernel_ids[1].session_id, - access_key="dummy-access-key", - agent=None, - agent_addr=None, - cluster_role=DEFAULT_ROLE, - cluster_idx=1, - local_rank=0, - cluster_hostname=f"{DEFAULT_ROLE}0", - architecture=common_image_ref.architecture, - registry=common_image_ref.registry, - image=common_image_ref.name, - requested_slots=ResourceSlot({ - "cpu": Decimal("1.0"), - "mem": Decimal("2048"), - "cuda.shares": Decimal("0.5"), - "rocm.devices": Decimal("0"), - }), - bootstrap_script=None, - startup_command=None, - created_at=dtparse("2021-09-03T00:00:00+00:00"), - ), - ], - access_key=AccessKey("user02"), - id=existing_session_kernel_ids[1].session_id, - session_type=SessionTypes.BATCH, + create_mock_session( + SessionId(uuid4()), status=SessionStatus.RUNNING, - name="ees02", - cluster_mode="single-node", - cluster_size=1, - occupying_slots=ResourceSlot({ - "cpu": Decimal("1.0"), - "mem": Decimal("2048"), + access_key=AccessKey("user02"), + requested_slots=ResourceSlot({ + "cpu": Decimal(1), + "mem": Decimal(2048), "cuda.shares": Decimal("0.5"), - "rocm.devices": Decimal("0"), + "rocm.devices": Decimal(0), }), - scaling_group_name=example_sgroup_name1, - **_common_dummy_for_existing_session, ), - SessionRow( - kernels=[ - KernelRow( - id=existing_session_kernel_ids[2].kernel_ids[0], - session_id=existing_session_kernel_ids[2].session_id, - access_key="dummy-access-key", - agent=None, - agent_addr=None, - cluster_role=DEFAULT_ROLE, - cluster_idx=1, - local_rank=0, - cluster_hostname=f"{DEFAULT_ROLE}0", - architecture=common_image_ref.architecture, - registry=common_image_ref.registry, - image=common_image_ref.name, - requested_slots=ResourceSlot({ - "cpu": Decimal("4.0"), - "mem": Decimal("4096"), - "cuda.shares": Decimal("0"), - "rocm.devices": Decimal("0"), - }), - bootstrap_script=None, - startup_command=None, - created_at=dtparse("2022-01-15T00:00:00+00:00"), - ), - ], - access_key=AccessKey("user03"), - id=existing_session_kernel_ids[2].session_id, - session_type=SessionTypes.BATCH, + create_mock_session( + SessionId(uuid4()), status=SessionStatus.RUNNING, - name="ees03", - cluster_mode="single-node", - cluster_size=1, - occupying_slots=ResourceSlot({ - "cpu": Decimal("4.0"), - "mem": Decimal("4096"), - "cuda.shares": Decimal("0"), - "rocm.devices": Decimal("0"), + access_key=AccessKey("user03"), + requested_slots=ResourceSlot({ + "cpu": Decimal(4), + "mem": Decimal(4096), + "cuda.shares": Decimal(0), + "rocm.devices": Decimal(0), }), - scaling_group_name=example_sgroup_name1, - **_common_dummy_for_existing_session, ), ] -def _find_and_pop_picked_session(pending_sessions, picked_session_id) -> SessionRow: - for picked_idx, pending_sess in enumerate(pending_sessions): - if pending_sess.id == picked_session_id: - break - else: - # no matching entry for picked session? - raise RuntimeError("should not reach here") - return pending_sessions.pop(picked_idx) - - -def _update_agent_assignment( - agents: list[AgentRow], - picked_agent_id: AgentId, - occupied_slots: ResourceSlot, -) -> None: - for ag in agents: - if ag.id == picked_agent_id: - ag.occupied_slots += occupied_slots - - @pytest.mark.asyncio -async def test_fifo_scheduler( - example_agents: Sequence[AgentRow], - example_pending_sessions: Sequence[SessionRow], - example_existing_sessions: Sequence[SessionRow], -) -> None: +async def test_fifo_scheduler() -> None: + example_agents = create_example_agents() + example_pending_sessions = create_example_pending_sessions() + example_existing_sessions = create_example_existing_sessions() scheduler = FIFOSlotScheduler(ScalingGroupOpts(), {}) agstate_cls = DispersedAgentSelector.get_state_cls() agselector = DispersedAgentSelector( @@ -950,7 +250,7 @@ async def test_fifo_scheduler( example_existing_sessions, ) assert picked_session_id == example_pending_sessions[0].id - picked_session = _find_and_pop_picked_session( + picked_session = find_and_pop_picked_session( example_pending_sessions, picked_session_id, ) @@ -962,11 +262,10 @@ async def test_fifo_scheduler( @pytest.mark.asyncio -async def test_lifo_scheduler( - example_agents: Sequence[AgentRow], - example_pending_sessions: Sequence[SessionRow], - example_existing_sessions: Sequence[SessionRow], -) -> None: +async def test_lifo_scheduler() -> None: + example_agents = create_example_agents() + example_pending_sessions = create_example_pending_sessions() + example_existing_sessions = create_example_existing_sessions() scheduler = LIFOSlotScheduler(ScalingGroupOpts(), {}) agstate_cls = DispersedAgentSelector.get_state_cls() agselector = DispersedAgentSelector( @@ -981,7 +280,7 @@ async def test_lifo_scheduler( example_existing_sessions, ) assert picked_session_id == example_pending_sessions[2].id - picked_session = _find_and_pop_picked_session( + picked_session = find_and_pop_picked_session( example_pending_sessions, picked_session_id, ) @@ -993,10 +292,9 @@ async def test_lifo_scheduler( @pytest.mark.asyncio -async def test_fifo_scheduler_favor_cpu_for_requests_without_accelerators( - example_mixed_agents: Sequence[AgentRow], - example_pending_sessions: Sequence[SessionRow], -) -> None: +async def test_fifo_scheduler_favor_cpu_for_requests_without_accelerators() -> None: + example_mixed_agents = create_example_mixed_agents() + example_pending_sessions = create_example_pending_sessions() scheduler = FIFOSlotScheduler(ScalingGroupOpts(), {}) agstate_cls = DispersedAgentSelector.get_state_cls() agselector = DispersedAgentSelector( @@ -1013,7 +311,7 @@ async def test_fifo_scheduler_favor_cpu_for_requests_without_accelerators( [], ) assert picked_session_id == example_pending_sessions[0].id - picked_session = _find_and_pop_picked_session( + picked_session = find_and_pop_picked_session( example_pending_sessions, picked_session_id, ) @@ -1032,22 +330,24 @@ async def test_fifo_scheduler_favor_cpu_for_requests_without_accelerators( assert agent_id == AgentId("i-cpu") -def gen_pending_for_holb_tests(session_id: str, status_data: Mapping[str, Any]) -> SessionRow: - return SessionRow( - id=SessionId(session_id), # type: ignore +_holb_capacity = ResourceSlot({"cpu": Decimal(4), "mem": Decimal(4096)}) +_holb_session_ids = { + "s0": SessionId(uuid4()), + "s1": SessionId(uuid4()), + "s2": SessionId(uuid4()), +} + + +def create_pending_session_holb( + session_id: SessionId, + status_data: dict[str, Any], +) -> SessionRow: + return create_mock_session( + session_id, + status=SessionStatus.PENDING, status_data=status_data, - name=secrets.token_hex(8), access_key=AccessKey("ak1"), - creation_id=secrets.token_urlsafe(8), - kernels=[], - session_type=SessionTypes.INTERACTIVE, - cluster_mode=ClusterMode.SINGLE_NODE, - cluster_size=1, - scaling_group_name=example_sgroup_name1, requested_slots=ResourceSlot({"cpu": Decimal(1), "mem": Decimal(1024)}), - target_sgroup_names=[], - **_common_dummy_for_pending_session, - created_at=dtparse("2020-03-21T00:00:00+00:00"), ) @@ -1057,12 +357,12 @@ def test_fifo_scheduler_hol_blocking_avoidance_empty_status_data() -> None: """ scheduler = FIFOSlotScheduler(ScalingGroupOpts(), {"num_retries_to_skip": 5}) pending_sessions = [ - gen_pending_for_holb_tests("s0", {}), - gen_pending_for_holb_tests("s1", {}), - gen_pending_for_holb_tests("s2", {}), + create_pending_session_holb(_holb_session_ids["s0"], {}), + create_pending_session_holb(_holb_session_ids["s1"], {}), + create_pending_session_holb(_holb_session_ids["s2"], {}), ] - picked_session_id = scheduler.pick_session(example_total_capacity, pending_sessions, []) - assert picked_session_id == "s0" + picked_session_id = scheduler.pick_session(_holb_capacity, pending_sessions, []) + assert picked_session_id == _holb_session_ids["s0"] def test_fifo_scheduler_hol_blocking_avoidance_config() -> None: @@ -1072,21 +372,21 @@ def test_fifo_scheduler_hol_blocking_avoidance_config() -> None: """ scheduler = FIFOSlotScheduler(ScalingGroupOpts(), {"num_retries_to_skip": 0}) pending_sessions = [ - gen_pending_for_holb_tests("s0", {"scheduler": {"retries": 5}}), - gen_pending_for_holb_tests("s1", {}), - gen_pending_for_holb_tests("s2", {}), + create_pending_session_holb(_holb_session_ids["s0"], {"scheduler": {"retries": 5}}), + create_pending_session_holb(_holb_session_ids["s1"], {}), + create_pending_session_holb(_holb_session_ids["s2"], {}), ] - picked_session_id = scheduler.pick_session(example_total_capacity, pending_sessions, []) - assert picked_session_id == "s0" + picked_session_id = scheduler.pick_session(_holb_capacity, pending_sessions, []) + assert picked_session_id == _holb_session_ids["s0"] scheduler = FIFOSlotScheduler(ScalingGroupOpts(), {"num_retries_to_skip": 5}) pending_sessions = [ - gen_pending_for_holb_tests("s0", {"scheduler": {"retries": 5}}), - gen_pending_for_holb_tests("s1", {"scheduler": {"retries": 4}}), - gen_pending_for_holb_tests("s2", {"scheduler": {"retries": 3}}), + create_pending_session_holb(_holb_session_ids["s0"], {"scheduler": {"retries": 5}}), + create_pending_session_holb(_holb_session_ids["s1"], {"scheduler": {"retries": 4}}), + create_pending_session_holb(_holb_session_ids["s2"], {"scheduler": {"retries": 3}}), ] - picked_session_id = scheduler.pick_session(example_total_capacity, pending_sessions, []) - assert picked_session_id == "s1" + picked_session_id = scheduler.pick_session(_holb_capacity, pending_sessions, []) + assert picked_session_id == _holb_session_ids["s1"] def test_fifo_scheduler_hol_blocking_avoidance_skips() -> None: @@ -1096,20 +396,20 @@ def test_fifo_scheduler_hol_blocking_avoidance_skips() -> None: """ scheduler = FIFOSlotScheduler(ScalingGroupOpts(), {"num_retries_to_skip": 5}) pending_sessions = [ - gen_pending_for_holb_tests("s0", {"scheduler": {"retries": 5}}), - gen_pending_for_holb_tests("s1", {}), - gen_pending_for_holb_tests("s2", {}), + create_pending_session_holb(_holb_session_ids["s0"], {"scheduler": {"retries": 5}}), + create_pending_session_holb(_holb_session_ids["s1"], {}), + create_pending_session_holb(_holb_session_ids["s2"], {}), ] - picked_session_id = scheduler.pick_session(example_total_capacity, pending_sessions, []) - assert picked_session_id == "s1" + picked_session_id = scheduler.pick_session(_holb_capacity, pending_sessions, []) + assert picked_session_id == _holb_session_ids["s1"] pending_sessions = [ - gen_pending_for_holb_tests("s0", {"scheduler": {"retries": 5}}), - gen_pending_for_holb_tests("s1", {"scheduler": {"retries": 10}}), - gen_pending_for_holb_tests("s2", {}), + create_pending_session_holb(_holb_session_ids["s0"], {"scheduler": {"retries": 5}}), + create_pending_session_holb(_holb_session_ids["s1"], {"scheduler": {"retries": 10}}), + create_pending_session_holb(_holb_session_ids["s2"], {}), ] - picked_session_id = scheduler.pick_session(example_total_capacity, pending_sessions, []) - assert picked_session_id == "s2" + picked_session_id = scheduler.pick_session(_holb_capacity, pending_sessions, []) + assert picked_session_id == _holb_session_ids["s2"] def test_fifo_scheduler_hol_blocking_avoidance_all_skipped() -> None: @@ -1119,12 +419,12 @@ def test_fifo_scheduler_hol_blocking_avoidance_all_skipped() -> None: """ scheduler = FIFOSlotScheduler(ScalingGroupOpts(), {"num_retries_to_skip": 5}) pending_sessions = [ - gen_pending_for_holb_tests("s0", {"scheduler": {"retries": 5}}), - gen_pending_for_holb_tests("s1", {"scheduler": {"retries": 5}}), - gen_pending_for_holb_tests("s2", {"scheduler": {"retries": 5}}), + create_pending_session_holb(_holb_session_ids["s0"], {"scheduler": {"retries": 5}}), + create_pending_session_holb(_holb_session_ids["s1"], {"scheduler": {"retries": 5}}), + create_pending_session_holb(_holb_session_ids["s2"], {"scheduler": {"retries": 5}}), ] - picked_session_id = scheduler.pick_session(example_total_capacity, pending_sessions, []) - assert picked_session_id == "s0" + picked_session_id = scheduler.pick_session(_holb_capacity, pending_sessions, []) + assert picked_session_id == _holb_session_ids["s0"] def test_fifo_scheduler_hol_blocking_avoidance_no_skip() -> None: @@ -1134,19 +434,18 @@ def test_fifo_scheduler_hol_blocking_avoidance_no_skip() -> None: """ scheduler = FIFOSlotScheduler(ScalingGroupOpts(), {"num_retries_to_skip": 5}) pending_sessions = [ - gen_pending_for_holb_tests("s0", {}), - gen_pending_for_holb_tests("s1", {"scheduler": {"retries": 10}}), - gen_pending_for_holb_tests("s2", {}), + create_pending_session_holb(_holb_session_ids["s0"], {}), + create_pending_session_holb(_holb_session_ids["s1"], {"scheduler": {"retries": 10}}), + create_pending_session_holb(_holb_session_ids["s2"], {}), ] - picked_session_id = scheduler.pick_session(example_total_capacity, pending_sessions, []) - assert picked_session_id == "s0" + picked_session_id = scheduler.pick_session(_holb_capacity, pending_sessions, []) + assert picked_session_id == _holb_session_ids["s0"] @pytest.mark.asyncio -async def test_lifo_scheduler_favor_cpu_for_requests_without_accelerators( - example_mixed_agents: Sequence[AgentRow], - example_pending_sessions: Sequence[SessionRow], -) -> None: +async def test_lifo_scheduler_favor_cpu_for_requests_without_accelerators() -> None: + example_mixed_agents = create_example_mixed_agents() + example_pending_sessions = create_example_pending_sessions() # Check the reverse with the LIFO scheduler. # The result must be same. sgroup_opts = ScalingGroupOpts(agent_selection_strategy=AgentSelectionStrategy.DISPERSED) @@ -1162,7 +461,7 @@ async def test_lifo_scheduler_favor_cpu_for_requests_without_accelerators( for idx in range(3): picked_session_id = scheduler.pick_session(total_capacity, example_pending_sessions, []) assert picked_session_id == example_pending_sessions[-1].id - picked_session = _find_and_pop_picked_session(example_pending_sessions, picked_session_id) + picked_session = find_and_pop_picked_session(example_pending_sessions, picked_session_id) agent_id = await agselector.assign_agent_for_session( example_mixed_agents, picked_session, @@ -1179,11 +478,10 @@ async def test_lifo_scheduler_favor_cpu_for_requests_without_accelerators( @pytest.mark.asyncio -async def test_drf_scheduler( - example_agents: Sequence[AgentRow], - example_pending_sessions: Sequence[SessionRow], - example_existing_sessions: Sequence[SessionRow], -) -> None: +async def test_drf_scheduler() -> None: + example_agents = create_example_agents() + example_pending_sessions = create_example_pending_sessions() + example_existing_sessions = create_example_existing_sessions() sgroup_opts = ScalingGroupOpts(agent_selection_strategy=AgentSelectionStrategy.DISPERSED) scheduler = DRFScheduler(sgroup_opts, {}) agstate_cls = DispersedAgentSelector.get_state_cls() @@ -1200,7 +498,7 @@ async def test_drf_scheduler( ) pprint(example_pending_sessions) assert picked_session_id == example_pending_sessions[1].id - picked_session = _find_and_pop_picked_session( + picked_session = find_and_pop_picked_session( example_pending_sessions, picked_session_id, ) @@ -1285,9 +583,9 @@ async def test_manually_assign_agent_available( AgentRegistry, MagicMock, MagicMock, MagicMock, MagicMock, MagicMock, MagicMock ], mocker: pytest_mock.MockerFixture, - example_agents: Sequence[AgentRow], - example_pending_sessions: Sequence[SessionRow], ) -> None: + example_agents = create_example_agents() + example_pending_sessions = create_example_pending_sessions() mock_local_config = MagicMock() ( @@ -1397,7 +695,9 @@ async def test_manually_assign_agent_available( @pytest.mark.asyncio @mock.patch("ai.backend.manager.scheduler.predicates.datetime") -async def test_multiple_timezones_for_reserved_batch_session_predicate(mock_dt: MagicMock) -> None: +async def test_multiple_timezones_for_reserved_batch_session_predicate( + mock_dt: MagicMock, +) -> None: mock_db_conn = MagicMock() mock_sched_ctx = MagicMock() mock_sess_ctx = MagicMock() @@ -1439,133 +739,3 @@ async def test_multiple_timezones_for_reserved_batch_session_predicate(mock_dt: mock_db_conn.scalar = AsyncMock(return_value=None) result = await check_reserved_batch_session(mock_db_conn, mock_sched_ctx, mock_sess_ctx) assert result.passed - - -@pytest.mark.asyncio -@pytest.mark.parametrize("example_agents_multi_homogeneous", [{"repeat": 10}], indirect=True) -@pytest.mark.parametrize("example_homogeneous_pending_sessions", [{"repeat": 20}], indirect=True) -async def test_agent_selection_strategy_rr( - example_agents_multi_homogeneous: Sequence[AgentRow], - example_homogeneous_pending_sessions: Sequence[SessionRow], - example_existing_sessions: Sequence[SessionRow], -) -> None: - sgroup_opts = ScalingGroupOpts( - agent_selection_strategy=AgentSelectionStrategy.ROUNDROBIN, - ) - scheduler = FIFOSlotScheduler( - sgroup_opts, - {}, - ) - - agstate_cls = RoundRobinAgentSelector.get_state_cls() - agselector = RoundRobinAgentSelector( - sgroup_opts, - {}, - agent_selection_resource_priority, - state_store=InMemoryResourceGroupStateStore(agstate_cls), - ) - - num_agents = len(example_agents_multi_homogeneous) - total_capacity = sum( - (ag.available_slots for ag in example_agents_multi_homogeneous), ResourceSlot() - ) - agent_ids = [] - # Repeat the allocation for two iterations - for _ in range(num_agents * 2): - picked_session_id = scheduler.pick_session( - total_capacity, - example_homogeneous_pending_sessions, - example_existing_sessions, - ) - assert picked_session_id == example_homogeneous_pending_sessions[0].id - picked_session = _find_and_pop_picked_session( - example_homogeneous_pending_sessions, - picked_session_id, - ) - agent_ids.append( - await agselector.assign_agent_for_session( - example_agents_multi_homogeneous, - picked_session, - ) - ) - assert agent_ids == [AgentId(f"i-{idx:03d}") for idx in range(num_agents)] * 2 - - -@pytest.mark.asyncio -async def test_agent_selection_strategy_rr_skip_unacceptable_agents( - example_agents_many: Sequence[AgentRow], -) -> None: - # example_agents_many: - # i-001: cpu=8, mem=4096, cuda.shares=4.0 - # i-002: cpu=4, mem=2048, cuda.shares=2.0 - # i-003: cpu=2, mem=1024, cuda.shares=1.0 - # i-004: cpu=1, mem=512, cuda.shares=0.5 - agents: list[AgentRow] = [*example_agents_many] - - # all pending sessions: - # cpu=2, mem=500 - pending_sessions = [ - create_pending_session( - SessionId(uuid4()), - KernelId(uuid4()), - ResourceSlot({ - "cpu": Decimal("2"), - "mem": Decimal("500"), - }), - ) - for _ in range(8) - ] - - sgroup_opts = ScalingGroupOpts( - agent_selection_strategy=AgentSelectionStrategy.ROUNDROBIN, - ) - scheduler = FIFOSlotScheduler( - sgroup_opts, - {}, - ) - - agstate_cls = RoundRobinAgentSelector.get_state_cls() - agselector = RoundRobinAgentSelector( - sgroup_opts, - {}, - agent_selection_resource_priority, - state_store=InMemoryResourceGroupStateStore(agstate_cls), - ) - - total_capacity = sum((ag.available_slots for ag in agents), ResourceSlot()) - - results: list[AgentId | None] = [] - scheduled_sessions: list[SessionRow] = [] - - for _ in range(8): - picked_session_id = scheduler.pick_session( - total_capacity, - pending_sessions, - scheduled_sessions, - ) - assert picked_session_id is not None - picked_session = _find_and_pop_picked_session( - pending_sessions, - picked_session_id, - ) - scheduled_sessions.append(picked_session) - result = await agselector.assign_agent_for_session( - agents, - picked_session, - ) - if result is not None: - _update_agent_assignment(agents, result, picked_session.requested_slots) - results.append(result) - - print() - for ag in agents: - print( - ag.id, - f"{ag.occupied_slots["cpu"]}/{ag.available_slots["cpu"]}", - f"{ag.occupied_slots["mem"]}/{ag.available_slots["mem"]}", - ) - # As more sessions have the assigned agents, the remaining capacity diminishes - # and the range of round-robin also becomes limited. - # When there is no assignable agent, it should return None. - assert len(results) == 8 - assert results == ["i-001", "i-002", "i-003", "i-001", "i-002", "i-001", "i-001", None] From 4987a9e38d8b0c0fa0647f6aafe70bcf1e2eec00 Mon Sep 17 00:00:00 2001 From: Joongi Kim Date: Fri, 20 Sep 2024 01:02:04 +0900 Subject: [PATCH 2/3] doc: Add news fragment --- changes/2848.feature.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changes/2848.feature.md diff --git a/changes/2848.feature.md b/changes/2848.feature.md new file mode 100644 index 0000000000..51636bcefb --- /dev/null +++ b/changes/2848.feature.md @@ -0,0 +1 @@ +Implement the priority-aware scheduler that applies to any arbitrary scheduler plugin From db992f3915b733b8567996af4d34875b1fab33d9 Mon Sep 17 00:00:00 2001 From: Joongi Kim Date: Fri, 20 Sep 2024 14:19:49 +0900 Subject: [PATCH 3/3] refactor: Simplify and update comments --- src/ai/backend/manager/models/utils.py | 5 +++-- src/ai/backend/manager/scheduler/types.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/ai/backend/manager/models/utils.py b/src/ai/backend/manager/models/utils.py index 0b76afd848..0dc596e8ea 100644 --- a/src/ai/backend/manager/models/utils.py +++ b/src/ai/backend/manager/models/utils.py @@ -422,8 +422,9 @@ async def retry_txn(max_attempts: int = 20) -> AsyncIterator[AttemptManager]: stop=stop_after_attempt(max_attempts), retry=retry_if_exception_type(TryAgain) | retry_if_exception_type(DBAPIError), ): - # The caller of a Python generator cannot catch the exceptions thrown in the block, - # so we need to pass AttemptManager. + # Since Python generators cannot catch the exceptions thrown in the code block executed + # when yielded because stack frames are switched, we should pass AttemptManager to + # provide a shared exception handling mechanism like the original execute_with_retry(). yield attempt assert attempt.retry_state.outcome is not None exc = attempt.retry_state.outcome.exception() diff --git a/src/ai/backend/manager/scheduler/types.py b/src/ai/backend/manager/scheduler/types.py index bc11cb2c84..f02adfe690 100644 --- a/src/ai/backend/manager/scheduler/types.py +++ b/src/ai/backend/manager/scheduler/types.py @@ -122,7 +122,7 @@ def prioritize(pending_sessions: Sequence[SessionRow]) -> tuple[int, list[Sessio return -1, [] priorities = {s.priority for s in pending_sessions} assert len(priorities) > 0 - top_priority = sorted(priorities, reverse=True)[0] + top_priority = max(priorities) return top_priority, [*filter(lambda s: s.priority == top_priority, pending_sessions)] @abstractmethod