Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement priority scheduler #2848

Merged
merged 3 commits into from
Sep 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/2848.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implement the priority-aware scheduler that applies to any arbitrary scheduler plugin
20 changes: 20 additions & 0 deletions src/ai/backend/manager/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from sqlalchemy.orm import sessionmaker
from tenacity import (
AsyncRetrying,
AttemptManager,
RetryError,
TryAgain,
retry_if_exception_type,
Expand Down Expand Up @@ -414,6 +415,25 @@ async def execute_with_retry(txn_func: Callable[[], Awaitable[TQueryResult]]) ->
return result


async def retry_txn(max_attempts: int = 20) -> AsyncIterator[AttemptManager]:
    """
    Yield tenacity ``AttemptManager`` objects so callers can wrap a DB
    transaction body in ``with attempt:`` and have it retried with
    exponential backoff on ``TryAgain`` or retryable ``DBAPIError``.

    :param max_attempts: Maximum number of attempts before giving up.
    :raises RuntimeError: When all attempts are exhausted (chained from
        the underlying :class:`tenacity.RetryError`).
    """
    try:
        async for attempt in AsyncRetrying(
            wait=wait_exponential(multiplier=0.02, min=0.02, max=5.0),
            stop=stop_after_attempt(max_attempts),
            retry=retry_if_exception_type(TryAgain) | retry_if_exception_type(DBAPIError),
        ):
            # Since Python generators cannot catch the exceptions thrown in the code block executed
            # when yielded because stack frames are switched, we should pass AttemptManager to
            # provide a shared exception handling mechanism like the original execute_with_retry().
            yield attempt
            assert attempt.retry_state.outcome is not None
            exc = attempt.retry_state.outcome.exception()
            # Non-retryable DB errors must propagate immediately instead of
            # burning the remaining attempts.
            if isinstance(exc, DBAPIError) and not is_db_retry_error(exc):
                raise exc
    except RetryError as e:
        # Chain the RetryError so the final traceback keeps the retry history.
        raise RuntimeError(f"DB serialization failed after {max_attempts} retries") from e


JSONCoalesceExpr: TypeAlias = sa.sql.elements.BinaryExpression


Expand Down
326 changes: 185 additions & 141 deletions src/ai/backend/manager/scheduler/dispatcher.py

Large diffs are not rendered by default.

13 changes: 13 additions & 0 deletions src/ai/backend/manager/scheduler/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,19 @@ def config_iv(self) -> t.Dict:
"""
raise NotImplementedError

@staticmethod
def prioritize(pending_sessions: Sequence[SessionRow]) -> tuple[int, list[SessionRow]]:
"""
Filter the pending session list by the top priority among the observed priorities of the
given pending sessions.
"""
if not pending_sessions:
return -1, []
priorities = {s.priority for s in pending_sessions}
assert len(priorities) > 0
top_priority = max(priorities)
return top_priority, [*filter(lambda s: s.priority == top_priority, pending_sessions)]

@abstractmethod
def pick_session(
self,
Expand Down
1 change: 1 addition & 0 deletions tests/manager/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ python_test_utils(
sources=[
"conftest.py",
"model_factory.py",
"scheduler_utils.py",
],
dependencies=[
":fixtures",
Expand Down
17 changes: 13 additions & 4 deletions tests/manager/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,17 +149,23 @@ def logging_config():
yield config


@pytest.fixture(scope="session")
def ipc_base_path() -> Path:
    """Session-scoped temp directory for manager IPC sockets during tests."""
    path = Path.cwd() / f"tmp/backend.ai/manager-testing/ipc-{test_id}"
    path.mkdir(parents=True, exist_ok=True)
    return path


@pytest.fixture(scope="session")
def local_config(
test_id,
ipc_base_path: Path,
logging_config,
etcd_container, # noqa: F811
redis_container, # noqa: F811
postgres_container, # noqa: F811
test_db,
) -> Iterator[LocalConfig]:
ipc_base_path = Path.cwd() / f"tmp/backend.ai/manager-testing/ipc-{test_id}"
ipc_base_path.mkdir(parents=True, exist_ok=True)
etcd_addr = etcd_container[1]
redis_addr = redis_container[1]
postgres_addr = postgres_container[1]
Expand Down Expand Up @@ -480,9 +486,12 @@ async def clean_fixture() -> None:


@pytest.fixture
def file_lock_factory(local_config, request) -> Callable[[str], FileLock]:
def file_lock_factory(
ipc_base_path: Path,
request: pytest.FixtureRequest,
) -> Callable[[str], FileLock]:
def _make_lock(lock_id: str) -> FileLock:
lock_path = local_config["manager"]["ipc-base-path"] / f"testing.{lock_id}.lock"
lock_path = ipc_base_path / f"testing.{lock_id}.lock"
lock = FileLock(lock_path, timeout=0)
request.addfinalizer(partial(lock_path.unlink, missing_ok=True))
return lock
Expand Down
23 changes: 22 additions & 1 deletion tests/manager/models/test_dbutils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
from __future__ import annotations

import asyncio
from unittest.mock import MagicMock

import aiotools
import pytest
import sqlalchemy as sa
from sqlalchemy.exc import DBAPIError

from ai.backend.manager.models.utils import execute_with_retry, execute_with_txn_retry
from ai.backend.manager.models.utils import (
execute_with_retry,
execute_with_txn_retry,
retry_txn,
)


@pytest.mark.asyncio
Expand Down Expand Up @@ -95,3 +103,16 @@ async def txn_func_temporary_serialization_failure(db_session):
txn_func_temporary_serialization_failure, database_engine.begin_session, conn
)
assert ret == 1234


@pytest.mark.asyncio
async def test_retry_txn_as_code_block() -> None:
    # PostgreSQL error code 40001 = serialization_failure (a retryable error).
    fake_orig = MagicMock()
    fake_orig.pgcode = "40001"
    attempts_made = 0
    # Every attempt raises, so retry_txn(4) must exhaust its 4 attempts and
    # surface a RuntimeError.
    with pytest.raises(RuntimeError):
        async for attempt in retry_txn(4):
            with attempt:
                attempts_made += 1
                raise DBAPIError(None, None, fake_orig)
    assert attempts_made > 3
207 changes: 207 additions & 0 deletions tests/manager/scheduler_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
from __future__ import annotations

import secrets
from collections.abc import (
Iterator,
Mapping,
Sequence,
)
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import (
Any,
Final,
)
from uuid import uuid4

from dateutil.parser import parse as dtparse

from ai.backend.common.docker import ImageRef
from ai.backend.common.types import (
AccessKey,
AgentId,
KernelId,
ResourceSlot,
SessionId,
SessionTypes,
)
from ai.backend.manager.defs import DEFAULT_ROLE
from ai.backend.manager.models.agent import AgentRow
from ai.backend.manager.models.kernel import KernelRow, KernelStatus
from ai.backend.manager.models.session import SESSION_PRIORITY_DEFUALT, SessionRow, SessionStatus

# Architecture assigned to every mock image/kernel/agent in these tests.
ARCH_FOR_TEST: Final = "x86_64"

# Device-name preference order used when tests exercise agent selection.
agent_selection_resource_priority: Final = ["cuda", "rocm", "tpu", "cpu", "mem"]

# Shared mock image reference used by all mock kernels.
# NOTE(review): "ubunt18.04" looks like a typo for "ubuntu18.04" — likely
# harmless for mocks, but confirm ImageRef does not validate the tag.
common_image_ref: Final = ImageRef(
    "lablup/python:3.6-ubunt18.04",
    ["*"],
    architecture=ARCH_FOR_TEST,
)

# Shared fixture values referenced across the scheduler test modules.
example_group_id = uuid4()
example_total_capacity = ResourceSlot({"cpu": "4.0", "mem": "4096"})
example_sgroup_name1: Final = "sg01"
example_sgroup_name2: Final = "sg02"

# Common keyword arguments spliced into every mock pending SessionRow.
_common_dummy_for_pending_session: Mapping[str, Any] = dict(
    domain_name="default",
    group_id=example_group_id,
    vfolder_mounts=[],
    environ={},
    bootstrap_script=None,
    startup_command=None,
    use_host_network=False,
)

# Monotonic counter so each generated timestamp is strictly later than the last.
_timestamp_count = 0


def generate_timestamp() -> datetime:
    """Return a deterministic, strictly increasing timestamp for mock rows."""
    global _timestamp_count
    step = _timestamp_count
    _timestamp_count += 1
    # timedelta's first positional argument counts days: each call advances
    # the clock by 10 days from the fixed base instant.
    return dtparse("2021-12-28T23:59:59+00:00") + timedelta(days=step * 10)


def generate_role() -> Iterator[tuple[str, int, int]]:
    """Yield (cluster_role, cluster_idx, local_rank): one "main" then endless "sub"s."""
    yield ("main", 1, 0)
    rank = 0
    while True:
        rank += 1
        yield ("sub", rank, rank)


@dataclass
class KernelOpt:
    """Per-kernel options used to shape the kernels of a mock session."""

    # Resource request for this kernel.
    requested_slots: ResourceSlot
    # Fresh random id per instance (default_factory avoids a shared default).
    kernel_id: KernelId = field(default_factory=lambda: KernelId(uuid4()))
    # Shared mock image reference used unless overridden.
    image: ImageRef = common_image_ref


# Maps a session-level status to the kernel-level status applied to the mock
# kernels created for a session in that state.
_sess_kern_status_map = {
    SessionStatus.PENDING: KernelStatus.PENDING,
    SessionStatus.SCHEDULED: KernelStatus.SCHEDULED,
    SessionStatus.RUNNING: KernelStatus.RUNNING,
    SessionStatus.TERMINATED: KernelStatus.TERMINATED,
}


def find_and_pop_picked_session(pending_sessions, picked_session_id) -> SessionRow:
    """Remove and return the pending session matching the picked id.

    Raises RuntimeError if the scheduler picked a session that is not in the
    pending list (which would indicate a test/scheduler bug).
    """
    for idx, sess in enumerate(pending_sessions):
        if sess.id == picked_session_id:
            return pending_sessions.pop(idx)
    # no matching entry for picked session?
    raise RuntimeError("should not reach here")


def update_agent_assignment(
    agents: Sequence[AgentRow],
    picked_agent_id: AgentId,
    occupied_slots: ResourceSlot,
) -> None:
    """Accumulate the given slots onto every agent whose id matches the pick."""
    matching = (agent for agent in agents if agent.id == picked_agent_id)
    for agent in matching:
        agent.occupied_slots += occupied_slots


def create_mock_kernel(
    session_id: SessionId,
    kernel_id: KernelId,
    requested_slots: ResourceSlot,
    *,
    status: KernelStatus = KernelStatus.PENDING,
    cluster_role: str = DEFAULT_ROLE,
    cluster_idx: int = 1,
    local_rank: int = 0,
) -> KernelRow:
    """Build an in-memory KernelRow mock belonging to the given session."""
    return KernelRow(
        # BUGFIX: id/session_id were swapped (id was set to session_id and
        # session_id to kernel_id), so kernels carried the session's id as
        # their own and any lookup by kernel id would miss.
        id=kernel_id,
        session_id=session_id,
        status=status,
        access_key="dummy-access-key",
        agent=None,
        agent_addr=None,
        cluster_role=cluster_role,
        cluster_idx=cluster_idx,
        local_rank=local_rank,
        cluster_hostname=f"{cluster_role}{cluster_idx}",
        architecture=common_image_ref.architecture,
        registry=common_image_ref.registry,
        image=common_image_ref.name,
        requested_slots=requested_slots,
        bootstrap_script=None,
        startup_command=None,
        created_at=generate_timestamp(),
    )


def create_mock_session(
    session_id: SessionId,
    requested_slots: ResourceSlot,
    *,
    access_key: AccessKey = AccessKey("user01"),
    status: SessionStatus = SessionStatus.PENDING,
    status_data: dict[str, Any] | None = None,
    kernel_opts: Sequence[KernelOpt] | None = None,
    priority: int = SESSION_PRIORITY_DEFUALT,
) -> SessionRow:
    """Create a simple single-kernel pending session.

    Pass ``kernel_opts`` to create a multi-kernel (cluster) session instead;
    the first kernel becomes the "main" role and the rest become "sub" roles
    via generate_role().  Kernel statuses mirror the session ``status`` through
    ``_sess_kern_status_map``.
    """
    if kernel_opts is None:
        # Create a single pending kernel as a default
        kernel_opts = [KernelOpt(requested_slots=requested_slots)]
    return SessionRow(
        kernels=[
            create_mock_kernel(
                session_id,
                kopt.kernel_id,
                kopt.requested_slots,
                status=_sess_kern_status_map[status],
                cluster_role=role_name,
                cluster_idx=role_idx,
                local_rank=local_rank,
            )
            # zip with the infinite role generator assigns main/sub roles in order.
            for kopt, (role_name, role_idx, local_rank) in zip(kernel_opts, generate_role())
        ],
        priority=priority,
        access_key=access_key,
        id=session_id,
        creation_id=secrets.token_hex(8),
        name=f"session-{secrets.token_hex(4)}",
        session_type=SessionTypes.BATCH,
        status=status,
        status_data=status_data,
        cluster_mode="single-node",
        cluster_size=len(kernel_opts),
        scaling_group_name=example_sgroup_name1,
        requested_slots=requested_slots,
        # Only sessions past scheduling actually occupy their requested slots;
        # pending/scheduled sessions occupy nothing yet.
        occupying_slots=(
            requested_slots
            if status not in (SessionStatus.PENDING, SessionStatus.SCHEDULED)
            else ResourceSlot()
        ),
        target_sgroup_names=[],
        **_common_dummy_for_pending_session,
        created_at=generate_timestamp(),
    )


def create_mock_agent(
    id: AgentId,
    *,
    scaling_group: str = example_sgroup_name1,
    available_slots: ResourceSlot,
    occupied_slots: ResourceSlot | None = None,
) -> AgentRow:
    """Build an in-memory AgentRow mock for scheduler tests.

    ``occupied_slots`` defaults to an empty ResourceSlot created per call.
    The previous ``= ResourceSlot()`` default was a shared mutable default:
    update_agent_assignment() applies ``+=`` to ``occupied_slots``, so if
    ResourceSlot implements in-place addition, occupancy would leak between
    every agent created with the default.
    """
    return AgentRow(
        id=id,
        addr="10.0.1.1:6001",
        architecture=ARCH_FOR_TEST,
        scaling_group=scaling_group,
        available_slots=available_slots,
        occupied_slots=ResourceSlot() if occupied_slots is None else occupied_slots,
    )
Loading
Loading