From 39d4199976c5b088be8266f7b5d7b679a64688b8 Mon Sep 17 00:00:00 2001
From: Joongi Kim
Date: Thu, 5 Dec 2024 22:47:33 +0900
Subject: [PATCH] feat: Add routing config schema for app proxy

Add a routing configuration schema to the wsproxy config module:

* RoutingAlgorithmType is a StrEnum naming two routing algorithms,
  "weighted-uniform-random" and "shuffled-weighted-round-robin".
* RoutingConfigForUniform and RoutingConfigForWRR are the per-algorithm
  config models.  Both are declared with extra="forbid", and the WRR
  variant carries an "options" field of type WeightedRoundRobinParams
  (multiplier=10, shuffle_period=5 by default).
* RoutingConfig is the union of the two models, discriminated by their
  "type" field.
* WSProxyConfig gains a "default_routing_config" field whose default is
  RoutingConfigForWRR().

Also declare LogDriver and LogstashProtocol as enum.StrEnum instead of
(str, enum.Enum), and import typing.Literal for the discriminator fields.

---
Review note: the WRR_SHUFFLED value is spelled "shuffled-weighted-round-robin"
here (it was "suffled-..." in the first revision of this patch), so the
post-image blob id on the "index" line below predates that spelling fix.

 src/ai/backend/wsproxy/config.py | 34 +++++++++++++++++++++++++++++---
 1 file changed, 31 insertions(+), 3 deletions(-)

diff --git a/src/ai/backend/wsproxy/config.py b/src/ai/backend/wsproxy/config.py
index 4638403ca0..fb011cbc6d 100644
--- a/src/ai/backend/wsproxy/config.py
+++ b/src/ai/backend/wsproxy/config.py
@@ -8,7 +8,7 @@
 from dataclasses import dataclass
 from pathlib import Path
 from pprint import pformat
-from typing import Annotated, Any
+from typing import Annotated, Any, Literal
 
 import click
 from pydantic import (
@@ -201,14 +201,14 @@ def __get_pydantic_json_schema__(
         )
 
 
-class LogDriver(str, enum.Enum):
+class LogDriver(enum.StrEnum):
     CONSOLE = "console"
     LOGSTASH = "logstash"
     FILE = "file"
     GRAYLOG = "graylog"
 
 
-class LogstashProtocol(str, enum.Enum):
+class LogstashProtocol(enum.StrEnum):
     ZMQ_PUSH = "zmq.push"
     ZMQ_PUB = "zmq.pub"
     TCP = "tcp"
@@ -333,6 +333,31 @@ class DebugConfig(BaseSchema):
     log_events: Annotated[bool, Field(default=False)]
 
 
+class RoutingAlgorithmType(enum.StrEnum):
+    UNIFORM_WEIGHTED = "weighted-uniform-random"
+    WRR_SHUFFLED = "shuffled-weighted-round-robin"
+
+
+class WeightedRoundRobinParams(BaseSchema):
+    multiplier: int = Field(default=10)
+    shuffle_period: int = Field(default=5)
+
+
+class RoutingConfigForUniform(BaseSchema, extra="forbid"):
+    type: Literal[RoutingAlgorithmType.UNIFORM_WEIGHTED] = RoutingAlgorithmType.UNIFORM_WEIGHTED
+
+
+class RoutingConfigForWRR(BaseSchema, extra="forbid"):
+    type: Literal[RoutingAlgorithmType.WRR_SHUFFLED] = RoutingAlgorithmType.WRR_SHUFFLED
+    options: WeightedRoundRobinParams = Field(default=WeightedRoundRobinParams())
+
+
+RoutingConfig = Annotated[
+    RoutingConfigForUniform | RoutingConfigForWRR,
+    Field(discriminator="type"),
+]
+
+
 class WSProxyConfig(BaseSchema):
     ipc_base_path: Annotated[
         Path,
@@ -404,6 +429,9 @@ class WSProxyConfig(BaseSchema):
     protocol: Annotated[
         ProxyProtocol, Field(default=ProxyProtocol.HTTP, description="Proxy protocol")
     ]
+    default_routing_config: RoutingConfig = Field(
+        default=RoutingConfigForWRR(), description="The default routing configuration"
+    )
     jwt_encrypt_key: Annotated[
         str,
         Field(examples=["50M3G00DL00KING53CR3T"], description="JWT encryption key")