feat: Introduce Raftify and RaftGlobalTimer to replace DistributedGlobalTimer #1506

Closed
wants to merge 16 commits into from
3 changes: 3 additions & 0 deletions .gitignore
@@ -142,3 +142,6 @@ docs/manager/rest-reference/openapi.json

/DIST-INFO
/INSTALL-INFO

# Raft cluster config
raft-cluster-config.toml
1 change: 1 addition & 0 deletions changes/1506.feature.md
@@ -0,0 +1 @@
Add a Raft-based leader election process to the manager group in HA setups so that manager states stay consistent.
8 changes: 7 additions & 1 deletion configs/manager/halfstack.toml
@@ -15,7 +15,7 @@ pool-recycle = 50


[manager]
num-proc = 4
num-proc = 3
service-addr = { host = "0.0.0.0", port = 8081 }
#user = "nobody"
#group = "nobody"
@@ -34,6 +34,11 @@ hide-agents = true
# The order of agent selection.
agent-selection-resource-priority = ["cuda", "rocm", "tpu", "cpu", "mem"]

[raft]
heartbeat-tick = 3
election-tick = 10
log-dir = "./logs"

[docker-registry]
ssl-verify = false

@@ -47,6 +52,7 @@ drivers = ["console"]
"aiotools" = "INFO"
"aiohttp" = "INFO"
"ai.backend" = "INFO"
"ai.backend.manager.server.raft" = "INFO"
"alembic" = "INFO"
"sqlalchemy" = "WARNING"

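The new `[raft]` table appears to follow raft-rs-style tick semantics (raftify wraps a raft-rs-like core): the leader broadcasts a heartbeat every `heartbeat-tick` logical ticks, and a follower that hears nothing for `election-tick` ticks starts an election, so `election-tick` should be several times `heartbeat-tick`. The peer topology itself lives in the separate, gitignored `raft-cluster-config.toml`; below is a plausible sketch of its shape, inferred from how the CLI's `inspect_node_status` (further down in this diff) consumes it — the keys match the code, the values are illustrative only:

```toml
# raft-cluster-config.toml — hypothetical example, not part of the PR.
[[peers.myself]]
node-id = 1
host = "127.0.0.1"
port = 60151
role = "voter"      # parsed via InitialRole.from_str()

[[peers.other]]
node-id = 2
host = "127.0.0.1"
port = 60152
role = "voter"

[[peers.other]]
node-id = 3
host = "127.0.0.1"
port = 60153
role = "voter"
```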
197 changes: 105 additions & 92 deletions python.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions requirements.txt
@@ -95,3 +95,4 @@ backend.ai-krunner-alpine==5.1.0
backend.ai-krunner-static-gnu==4.1.1

etcd-client-py==0.2.4
raftify==0.1.65
74 changes: 73 additions & 1 deletion src/ai/backend/common/distributed.py
@@ -1,10 +1,12 @@
from __future__ import annotations

import abc
import asyncio
import logging
from typing import TYPE_CHECKING, Callable, Final

from aiomonitor.task import preserve_termination_log
from raftify import RaftNode

from .logging import BraceStyleAdapter

@@ -16,7 +18,77 @@
log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore[name-defined]


class GlobalTimer:
class AbstractGlobalTimer(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    async def generate_tick(self) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    async def join(self) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    async def leave(self) -> None:
        raise NotImplementedError


class RaftGlobalTimer(AbstractGlobalTimer):
    """
    Executes the given async function only once in the given interval,
    uniquely among multiple manager instances across multiple nodes,
    by emitting the tick event only on the current Raft leader.
    """

    _event_producer: Final[EventProducer]

    def __init__(
        self,
        raft_node: RaftNode,
        event_producer: EventProducer,
        event_factory: Callable[[], AbstractEvent],
        interval: float = 10.0,
        initial_delay: float = 0.0,
    ) -> None:
        self._event_producer = event_producer
        self._event_factory = event_factory
        self._stopped = False
        self.interval = interval
        self.initial_delay = initial_delay
        self.raft_node = raft_node

    async def generate_tick(self) -> None:
        try:
            await asyncio.sleep(self.initial_delay)
            if self._stopped:
                return
            while True:
                try:
                    if self._stopped:
                        return
                    if await self.raft_node.is_leader():
                        await self._event_producer.produce_event(self._event_factory())
                    if self._stopped:
                        return
                    await asyncio.sleep(self.interval)
                except asyncio.TimeoutError:
                    # Defensive: keep the tick loop alive on transient timeouts.
                    log.warning("timeout raised while generating tick. retrying...")
        except asyncio.CancelledError:
            pass

    async def join(self) -> None:
        self._tick_task = asyncio.create_task(self.generate_tick())

    async def leave(self) -> None:
        self._stopped = True
        await asyncio.sleep(0)
        if not self._tick_task.done():
            try:
                self._tick_task.cancel()
                await self._tick_task
            except asyncio.CancelledError:
                pass


class DistributedLockGlobalTimer(AbstractGlobalTimer):
    """
    Executes the given async function only once in the given interval,
    uniquely among multiple manager instances across multiple nodes.
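For reviewers, a minimal lifecycle sketch of the new timer, assuming a bootstrapped `raftify.RaftNode` and an `EventProducer` are already at hand (`run_timer` and `make_event` are illustrative names, not part of the PR):

```python
import asyncio

from ai.backend.common.distributed import RaftGlobalTimer


async def run_timer(raft_node, event_producer, make_event) -> None:
    # make_event: a zero-arg factory returning an AbstractEvent,
    # e.g. `lambda: DoLogCleanupEvent()` as in api/logs.py below.
    timer = RaftGlobalTimer(
        raft_node,
        event_producer,
        make_event,
        interval=20.0,
        initial_delay=17.0,
    )
    await timer.join()        # spawns the background tick task
    try:
        ...                   # application runs; only the Raft leader emits ticks
    finally:
        await timer.leave()   # flags the loop to stop, then cancels the task
```

Unlike the lock-based timer, there is no `task_name` argument because no distributed lock is ever acquired; Raft leadership alone decides which node produces the event.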
24 changes: 23 additions & 1 deletion src/ai/backend/manager/api/context.py
@@ -1,8 +1,9 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional, cast

import attrs
from raftify import Raft, RaftNode

if TYPE_CHECKING:
    from ai.backend.common.bgtask import BackgroundTaskManager
@@ -26,6 +27,25 @@ class BaseContext:
    pass


class RaftClusterContext:
    _cluster: Optional[Raft] = None

    def use_raft(self) -> bool:
        return self._cluster is not None

    @property
    def cluster(self) -> Raft:
        return cast(Raft, self._cluster)

    @cluster.setter
    def cluster(self, rhs: Raft) -> None:
        self._cluster = rhs

    @property
    def raft_node(self) -> RaftNode:
        return self.cluster.get_raft_node()


@attrs.define(slots=True, auto_attribs=True, init=False)
class RootContext(BaseContext):
    pidx: int
@@ -40,6 +60,7 @@ class RootContext(BaseContext):
    redis_lock: RedisConnectionInfo
    shared_config: SharedConfig
    local_config: LocalConfig
    raft_cluster_config: Optional[LocalConfig]
    cors_options: CORSOptions

    webapp_plugin_ctx: WebappPluginContext
@@ -53,3 +74,4 @@ class RootContext(BaseContext):
    error_monitor: ErrorPluginContext
    stats_monitor: StatsPluginContext
    background_task_manager: BackgroundTaskManager
    raft_ctx: RaftClusterContext
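A short sketch of the intended access pattern, using only the accessors added above (`raft` stands in for a `raftify.Raft` instance built during manager startup):

```python
raft_ctx = RaftClusterContext()
assert not raft_ctx.use_raft()   # no cluster attached yet

raft_ctx.cluster = raft          # attach once the Raft cluster is bootstrapped
if raft_ctx.use_raft():
    node = raft_ctx.raft_node    # shorthand for raft_ctx.cluster.get_raft_node()
```

The `Optional` plus `cast` combination lets `RootContext` be constructed before the Raft cluster exists; call sites are expected to gate on `use_raft()` first, as `api/logs.py` does below.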
34 changes: 24 additions & 10 deletions src/ai/backend/manager/api/logs.py
@@ -14,7 +14,11 @@
from dateutil.relativedelta import relativedelta

from ai.backend.common import validators as tx
from ai.backend.common.distributed import GlobalTimer
from ai.backend.common.distributed import (
    AbstractGlobalTimer,
    DistributedLockGlobalTimer,
    RaftGlobalTimer,
)
from ai.backend.common.events import AbstractEvent, EmptyEventArgs, EventHandler
from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.types import AgentId, LogSeverity
@@ -234,7 +238,7 @@ async def log_cleanup_task(app: web.Application, src: AgentId, event: DoLogClean

@attrs.define(slots=True, auto_attribs=True, init=False)
class PrivateContext:
    log_cleanup_timer: AbstractGlobalTimer
    log_cleanup_timer_evh: EventHandler[web.Application, DoLogCleanupEvent]
log_cleanup_timer_evh: EventHandler[web.Application, DoLogCleanupEvent]


@@ -246,14 +250,24 @@ async def init(app: web.Application) -> None:
        app,
        log_cleanup_task,
    )
    app_ctx.log_cleanup_timer = GlobalTimer(
        root_ctx.distributed_lock_factory(LockID.LOCKID_LOG_CLEANUP_TIMER, 20.0),
        root_ctx.event_producer,
        lambda: DoLogCleanupEvent(),
        20.0,
        initial_delay=17.0,
        task_name="log_cleanup_task",
    )
    if root_ctx.raft_ctx.use_raft():
        app_ctx.log_cleanup_timer = RaftGlobalTimer(
            root_ctx.raft_ctx.raft_node,
            root_ctx.event_producer,
            lambda: DoLogCleanupEvent(),
            20.0,
            initial_delay=17.0,
        )
    else:
        app_ctx.log_cleanup_timer = DistributedLockGlobalTimer(
            root_ctx.distributed_lock_factory(LockID.LOCKID_LOG_CLEANUP_TIMER, 20.0),
            root_ctx.event_producer,
            lambda: DoLogCleanupEvent(),
            20.0,
            initial_delay=17.0,
            task_name="log_cleanup_task",
        )
    await app_ctx.log_cleanup_timer.join()


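With this change, Raft leadership replaces lock-based mutual exclusion as the mechanism that guarantees a single tick per interval; because both timers implement `AbstractGlobalTimer`, the rest of the handler code is untouched.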
98 changes: 98 additions & 0 deletions src/ai/backend/manager/cli/__main__.py
@@ -1,16 +1,26 @@
from __future__ import annotations

import asyncio
import json
import logging
import pathlib
import subprocess
import sys
from datetime import datetime
from functools import partial
from typing import Any

import click
from more_itertools import chunked
from raftify import (
    InitialRole,
    Peer,
    Peers,
    RaftServiceClient,
    cli_main,
)
from setproctitle import setproctitle
from tabulate import tabulate

from ai.backend.cli.params import BoolExprType, OptionalType
from ai.backend.cli.types import ExitCode
@@ -19,6 +29,7 @@
from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.types import LogSeverity
from ai.backend.common.validators import TimeDuration
from ai.backend.manager.raft.utils import register_custom_deserializer

from .context import CLIContext, redis_ctx

@@ -326,6 +337,93 @@ async def _clear_terminated_sessions():
    asyncio.run(_clear_terminated_sessions())


async def inspect_node_status(cli_ctx: CLIContext) -> None:
    raft_configs = cli_ctx.local_config["raft"]
    table = []
    headers = ["ENDPOINT", "NODE ID", "IS LEADER", "RAFT TERM", "RAFT APPLIED INDEX"]

    if raft_configs is not None:
        raft_cluster_configs = cli_ctx.raft_cluster_config
        assert raft_cluster_configs is not None

        other_peers = [{**peer, "myself": False} for peer in raft_cluster_configs["peers"]["other"]]
        my_peers = [{**peer, "myself": True} for peer in raft_cluster_configs["peers"]["myself"]]
        all_peers = sorted([*other_peers, *my_peers], key=lambda x: x["node-id"])

        initial_peers = Peers({
            int(peer_config["node-id"]): Peer(
                addr=f"{peer_config['host']}:{peer_config['port']}",
                role=InitialRole.from_str(peer_config["role"]),
            )
            for peer_config in all_peers
        })

        peers: dict[str, Any] | None = None
        for initial_peer in initial_peers.to_dict().values():
            raft_client = await RaftServiceClient.build(initial_peer.get_addr())
            try:
                resp = await raft_client.get_peers()
                peers = json.loads(resp)
            except Exception as e:
                print(f"Failed to get peers from {initial_peer.get_addr()}: {e}")
                continue

        if peers is None:
            print("No peers are available!")
            return

        for node_id in sorted(peers.keys()):
            peer = peers[node_id]
            raft_client = await RaftServiceClient.build(peer["addr"])

            try:
                node_debugging_info = json.loads(await raft_client.debug_node())
            except Exception as e:
                print(f"Failed to get debugging info from {peer['addr']}: {e}")
                table.append([peer["addr"], "(Invalid response)"])
                continue

            is_leader = node_debugging_info["node_id"] == node_debugging_info["leader_id"]
            table.append([
                peer["addr"],
                node_debugging_info["node_id"],
                is_leader,
                node_debugging_info["term"],
                node_debugging_info["raft_log"]["applied"],
            ])

    table = [headers, *sorted(table, key=lambda x: str(x[0]))]
    print(
        tabulate(table, headers="firstrow", tablefmt="grid", stralign="center", numalign="center")
    )


@main.command()
@click.pass_obj
def status(cli_ctx: CLIContext) -> None:
    """
    Collect and print each manager process's status.
    """
    asyncio.run(inspect_node_status(cli_ctx))
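Assuming a healthy two-node cluster, `backend.ai mgr status` prints a `tabulate` grid along these lines (endpoints and numbers are hypothetical):

```text
+-----------------+-----------+-------------+-------------+----------------------+
|    ENDPOINT     |  NODE ID  |  IS LEADER  |  RAFT TERM  |  RAFT APPLIED INDEX  |
+=================+===========+=============+=============+======================+
| 127.0.0.1:60151 |     1     |    True     |      1      |          5           |
+-----------------+-----------+-------------+-------------+----------------------+
| 127.0.0.1:60152 |     2     |    False    |      1      |          5           |
+-----------------+-----------+-------------+-------------+----------------------+
```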


async def handle_raft_cli_main(argv: list[str]):
    await cli_main(argv)


@main.command()
@click.pass_obj
@click.argument("args", nargs=-1, type=click.UNPROCESSED)
def raft(cli_ctx: CLIContext, args) -> None:
    register_custom_deserializer()

    argv = sys.argv
    # Remove "backend.ai", "mgr", "raft" from the argv.
    argv[:3] = []
    argv.insert(0, "raftify-cli")

    asyncio.run(handle_raft_cli_main(argv))
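The argv rewrite makes `backend.ai mgr raft <args>` behave as if `raftify-cli <args>` had been invoked directly, delegating all argument parsing to raftify's bundled `cli_main`.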


@main.group(cls=LazyGroup, import_name="ai.backend.manager.cli.dbschema:cli")
def schema():
    """Command set for managing the database schema."""