Skip to content

Commit

Permalink
Introduce raftify and RaftContext
Browse files Browse the repository at this point in the history
  • Loading branch information
jopemachine committed Jan 12, 2024
1 parent 4a19001 commit 87e17da
Show file tree
Hide file tree
Showing 21 changed files with 1,179 additions and 569 deletions.
1 change: 1 addition & 0 deletions changes/1506.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add Raft-based leader election process to manager group in HA condition in order to make their states consistent.
23 changes: 22 additions & 1 deletion configs/manager/halfstack.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ password = "develove"


[manager]
num-proc = 4
num-proc = 3
service-addr = { host = "0.0.0.0", port = 8081 }
#user = "nobody"
#group = "nobody"
Expand All @@ -33,6 +33,27 @@ hide-agents = true
# The order of agent selection.
agent-selection-resource-priority = ["cuda", "rocm", "tpu", "cpu", "mem"]

[raft]
heartbeat-tick = 3
election-tick = 10
log-dir = "./logs"
log-level = "debug"

[[raft.peers]]
host = "127.0.0.1"
port = 60151
node-id = 1

[[raft.peers]]
host = "127.0.0.1"
port = 60152
node-id = 2

[[raft.peers]]
host = "127.0.0.1"
port = 60153
node-id = 3

[docker-registry]
ssl-verify = false

Expand Down
749 changes: 368 additions & 381 deletions python.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,5 @@ types-tabulate

backend.ai-krunner-alpine==5.1.0
backend.ai-krunner-static-gnu==4.1.0

raftify==0.1.32
74 changes: 73 additions & 1 deletion src/ai/backend/common/distributed.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import abc
import asyncio
import logging
from typing import TYPE_CHECKING, Callable, Final

from aiomonitor.task import preserve_termination_log
from raftify import RaftNode

from .logging import BraceStyleAdapter

Expand All @@ -16,7 +18,77 @@
log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore[name-defined]


class GlobalTimer:
class AbstractGlobalTimer(metaclass=abc.ABCMeta):
@abc.abstractmethod
async def generate_tick(self) -> None:
raise NotImplementedError

@abc.abstractmethod
async def join(self) -> None:
raise NotImplementedError

@abc.abstractmethod
async def leave(self) -> None:
raise NotImplementedError


class RaftGlobalTimer(AbstractGlobalTimer):
"""
Executes the given async function only once in the given interval,
uniquely among multiple manager instances across multiple nodes.
"""

_event_producer: Final[EventProducer]

def __init__(
self,
raft_node: RaftNode,
event_producer: EventProducer,
event_factory: Callable[[], AbstractEvent],
interval: float = 10.0,
initial_delay: float = 0.0,
) -> None:
self._event_producer = event_producer
self._event_factory = event_factory
self._stopped = False
self.interval = interval
self.initial_delay = initial_delay
self.raft_node = raft_node

async def generate_tick(self) -> None:
try:
await asyncio.sleep(self.initial_delay)
if self._stopped:
return
while True:
try:
if self._stopped:
return
if await self.raft_node.is_leader():
await self._event_producer.produce_event(self._event_factory())
if self._stopped:
return
await asyncio.sleep(self.interval)
except asyncio.TimeoutError: # timeout raised from etcd lock
log.warn("timeout raised while trying to acquire lock. retrying...")
except asyncio.CancelledError:
pass

async def join(self) -> None:
self._tick_task = asyncio.create_task(self.generate_tick())

async def leave(self) -> None:
self._stopped = True
await asyncio.sleep(0)
if not self._tick_task.done():
try:
self._tick_task.cancel()
await self._tick_task
except asyncio.CancelledError:
pass


class DistributedLockGlobalTimer(AbstractGlobalTimer):
"""
Executes the given async function only once in the given interval,
uniquely among multiple manager instances across multiple nodes.
Expand Down
33 changes: 32 additions & 1 deletion src/ai/backend/manager/api/context.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional, cast

import attrs
from raftify import Raft, RaftNode

if TYPE_CHECKING:
from ai.backend.common.bgtask import BackgroundTaskManager
Expand All @@ -26,6 +27,35 @@ class BaseContext:
pass


class RaftClusterContext:
_cluster: Optional[Raft] = None
bootstrap_done: bool
node_id_start: int

def __init__(
self,
bootstrap_done: bool = False,
node_id_start: int = 1,
) -> None:
self.bootstrap_done = bootstrap_done
self.node_id_start = node_id_start

def use_raft(self) -> bool:
return self._cluster is not None

@property
def cluster(self) -> Raft:
return cast(Raft, self._cluster)

@cluster.setter
def cluster(self, rhs: Raft) -> None:
self._cluster = rhs

@property
def raft_node(self) -> RaftNode:
return self.cluster.get_raft_node()


@attrs.define(slots=True, auto_attribs=True, init=False)
class RootContext(BaseContext):
pidx: int
Expand Down Expand Up @@ -53,3 +83,4 @@ class RootContext(BaseContext):
error_monitor: ErrorPluginContext
stats_monitor: StatsPluginContext
background_task_manager: BackgroundTaskManager
raft_ctx: RaftClusterContext
34 changes: 24 additions & 10 deletions src/ai/backend/manager/api/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
from dateutil.relativedelta import relativedelta

from ai.backend.common import validators as tx
from ai.backend.common.distributed import GlobalTimer
from ai.backend.common.distributed import (
AbstractGlobalTimer,
DistributedLockGlobalTimer,
RaftGlobalTimer,
)
from ai.backend.common.events import AbstractEvent, EmptyEventArgs, EventHandler
from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.types import AgentId, LogSeverity
Expand Down Expand Up @@ -234,7 +238,7 @@ async def log_cleanup_task(app: web.Application, src: AgentId, event: DoLogClean

@attrs.define(slots=True, auto_attribs=True, init=False)
class PrivateContext:
log_cleanup_timer: GlobalTimer
log_cleanup_timer: AbstractGlobalTimer
log_cleanup_timer_evh: EventHandler[web.Application, DoLogCleanupEvent]


Expand All @@ -246,14 +250,24 @@ async def init(app: web.Application) -> None:
app,
log_cleanup_task,
)
app_ctx.log_cleanup_timer = GlobalTimer(
root_ctx.distributed_lock_factory(LockID.LOCKID_LOG_CLEANUP_TIMER, 20.0),
root_ctx.event_producer,
lambda: DoLogCleanupEvent(),
20.0,
initial_delay=17.0,
task_name="log_cleanup_task",
)

if root_ctx.raft_ctx.use_raft():
app_ctx.log_cleanup_timer = RaftGlobalTimer(
root_ctx.raft_ctx.raft_node,
root_ctx.event_producer,
lambda: DoLogCleanupEvent(),
20.0,
initial_delay=17.0,
)
else:
app_ctx.log_cleanup_timer = DistributedLockGlobalTimer(
root_ctx.distributed_lock_factory(LockID.LOCKID_LOG_CLEANUP_TIMER, 20.0),
root_ctx.event_producer,
lambda: DoLogCleanupEvent(),
20.0,
initial_delay=17.0,
task_name="log_cleanup_task",
)
await app_ctx.log_cleanup_timer.join()


Expand Down
72 changes: 72 additions & 0 deletions src/ai/backend/manager/cli/__main__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
from __future__ import annotations

import asyncio
import json
import logging
import pathlib
import subprocess
import sys
from datetime import datetime
from functools import partial
from typing import Any

import click
from more_itertools import chunked
from raftify import Peers, RaftServiceClient, cli_main
from setproctitle import setproctitle
from tabulate import tabulate

from ai.backend.cli.params import BoolExprType, OptionalType
from ai.backend.cli.types import ExitCode
Expand All @@ -19,6 +23,7 @@
from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.types import LogSeverity
from ai.backend.common.validators import TimeDuration
from ai.backend.manager.raft.utils import register_custom_deserializer

from .context import CLIContext, redis_ctx

Expand Down Expand Up @@ -326,6 +331,73 @@ async def _clear_terminated_sessions():
asyncio.run(_clear_terminated_sessions())


async def inspect_node_status(cli_ctx: CLIContext) -> None:
raft_configs = cli_ctx.local_config["raft"]
table = []
headers = ["ENDPOINT", "NODE ID", "IS LEADER", "RAFT TERM", "RAFT APPLIED INDEX"]

if raft_configs is not None:
initial_peers = Peers({
int(entry["node-id"]): f"{entry['host']}:{entry['port']}"
for entry in raft_configs["peers"]
})

peers: dict[str, Any] | None = None
for _, peer_addr in initial_peers.items():
raft_client = await RaftServiceClient.build(peer_addr)
try:
resp = await raft_client.get_peers()
peers = json.loads(resp)
except Exception as e:
print(f"Failed to get peers from {peer_addr}: {e}")
continue

if peers is None:
print("No peers are available!")
return

for peer in peers.values():
raft_client = await RaftServiceClient.build(peer["addr"])
node_debugging_info = json.loads(await raft_client.debug_node())

is_leader = node_debugging_info["node_id"] == node_debugging_info["leader_id"]
table.append([
peer["addr"],
node_debugging_info["node_id"],
is_leader,
node_debugging_info["term"],
node_debugging_info["raft_log"]["applied"],
])

table = [headers, *sorted(table, key=lambda x: str(x[0]))]
print(
tabulate(table, headers="firstrow", tablefmt="grid", stralign="center", numalign="center")
)


@main.command()
@click.pass_obj
def status(cli_ctx: CLIContext) -> None:
"""
Collect and print each manager process's status.
"""
asyncio.run(inspect_node_status(cli_ctx))


@main.command()
@click.pass_obj
@click.argument("args", nargs=-1, type=click.UNPROCESSED)
def raft(cli_ctx: CLIContext, args) -> None:
register_custom_deserializer()

argv = sys.argv
# Remove "backend.ai", "mgr", "raft" from the argv
argv[:3] = []
argv.insert(0, "raftify_cli")

asyncio.run(cli_main(argv))


@main.group(cls=LazyGroup, import_name="ai.backend.manager.cli.dbschema:cli")
def schema():
"""Command set for managing the database schema."""
Expand Down
Loading

0 comments on commit 87e17da

Please sign in to comment.