Skip to content

Commit

Permalink
feat(BA-96): Metric-based model service autoscaling (#3277)
Browse files Browse the repository at this point in the history
Co-authored-by: octodog <[email protected]>
Co-authored-by: Joongi Kim <[email protected]>
  • Loading branch information
3 people authored Jan 17, 2025
1 parent a6584ef commit 9be8899
Show file tree
Hide file tree
Showing 18 changed files with 2,022 additions and 198 deletions.
1 change: 1 addition & 0 deletions changes/3277.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support autoscaling of model services by observing proxy and app-specific metrics, as configured by the autoscaling rules bound to each endpoint.
116 changes: 116 additions & 0 deletions docs/manager/graphql-reference/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,12 @@ type Queries {

"""Added in 24.12.0."""
networks(filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): NetworkConnection

"""Added in 25.1.0."""
endpoint_auto_scaling_rule_node(id: String!): EndpointAutoScalingRuleNode

"""Added in 25.1.0."""
endpoint_auto_scaling_rule_nodes(endpoint: String!, filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): EndpointAutoScalingRuleConnection
}

"""
Expand Down Expand Up @@ -1653,6 +1659,63 @@ type NetworkEdge {
cursor: String!
}

"""Added in 25.1.0."""
# One autoscaling rule bound to a model-service endpoint; implements the `Node` interface.
# Fields without a trailing `!` (min_replicas, max_replicas, last_triggered_at) are nullable.
type EndpointAutoScalingRuleNode implements Node {
"""The ID of the object"""
id: ID!
# NOTE(review): presumably the raw UUID of the rule, distinct from the Node `id` above — confirm.
row_id: UUID!
metric_source: AutoScalingMetricSource!
metric_name: String!
# Declared as String rather than a numeric scalar; clients must send/parse it as text.
threshold: String!
comparator: AutoScalingMetricComparator!
# NOTE(review): presumably the replica-count change applied per trigger — confirm against the scaler.
step_size: Int!
cooldown_seconds: Int!
min_replicas: Int
max_replicas: Int
created_at: DateTime!
last_triggered_at: DateTime
# UUID of the endpoint this rule refers to (a plain UUID, not a nested object type).
endpoint: UUID!
}

"""The source type to fetch metrics. Added in 25.1.0."""
# Type of the `metric_source` field on the autoscaling rule type and on both rule input types.
# NOTE(review): which concrete metrics each value covers is not visible in this schema — confirm in the manager code.
enum AutoScalingMetricSource {
KERNEL
INFERENCE_FRAMEWORK
}

"""
The comparator used to compare the metric value with the threshold. Added in 25.1.0.
"""
# Type of the `comparator` field on the autoscaling rule type and on both rule input types.
# NOTE(review): the operand order (metric vs. threshold) is not stated in this schema — confirm in the evaluator.
enum AutoScalingMetricComparator {
LESS_THAN
LESS_THAN_OR_EQUAL
GREATER_THAN
GREATER_THAN_OR_EQUAL
}

"""Added in 25.1.0."""
# Connection type returned by the `endpoint_auto_scaling_rule_nodes` query.
# `edges` is a non-null list whose items may be null; `count` itself is nullable.
type EndpointAutoScalingRuleConnection {
"""Pagination data for this connection."""
pageInfo: PageInfo!

"""Contains the nodes in this connection."""
edges: [EndpointAutoScalingRuleEdge]!

"""Total count of the GQL nodes of the query."""
count: Int
}

"""
Added in 25.1.0. A Relay edge containing an `EndpointAutoScalingRule` and its cursor.
"""
# Element type of `EndpointAutoScalingRuleConnection.edges`.
# `node` is nullable, while `cursor` is always present.
type EndpointAutoScalingRuleEdge {
"""The item at the end of the edge"""
node: EndpointAutoScalingRuleNode

"""A cursor for use in pagination"""
cursor: String!
}

"""All available GraphQL mutations."""
type Mutations {
modify_agent(id: String!, props: ModifyAgentInput!): ModifyAgent
Expand Down Expand Up @@ -1855,6 +1918,15 @@ type Mutations {
id: String!
): DeleteContainerRegistryNode

"""Added in 25.1.0."""
create_endpoint_auto_scaling_rule_node(endpoint: String!, props: EndpointAutoScalingRuleInput!): CreateEndpointAutoScalingRuleNode

"""Added in 25.1.0."""
modify_endpoint_auto_scaling_rule_node(id: String!, props: ModifyEndpointAutoScalingRuleInput!): ModifyEndpointAutoScalingRuleNode

"""Added in 25.1.0."""
delete_endpoint_auto_scaling_rule_node(id: String!): DeleteEndpointAutoScalingRuleNode

"""Deprecated since 24.09.0. use `CreateContainerRegistryNode` instead"""
create_container_registry(hostname: String!, props: CreateContainerRegistryInput!): CreateContainerRegistry

Expand Down Expand Up @@ -2593,6 +2665,50 @@ type DeleteContainerRegistryNode {
container_registry: ContainerRegistryNode
}

"""Added in 25.1.0."""
# Payload of the `create_endpoint_auto_scaling_rule_node` mutation.
# All three fields are nullable, so callers should check `ok`/`msg` before reading `rule`.
type CreateEndpointAutoScalingRuleNode {
ok: Boolean
msg: String
rule: EndpointAutoScalingRuleNode
}

"""Added in 25.1.0."""
# `props` argument of the `create_endpoint_auto_scaling_rule_node` mutation.
# Every field is required except `min_replicas` and `max_replicas`, which may be omitted or null.
input EndpointAutoScalingRuleInput {
metric_source: AutoScalingMetricSource!
metric_name: String!
# Passed as a String, matching the `threshold` field of the rule type.
threshold: String!
comparator: AutoScalingMetricComparator!
step_size: Int!
cooldown_seconds: Int!
min_replicas: Int
max_replicas: Int
}

"""Added in 25.1.0."""
# Payload of the `modify_endpoint_auto_scaling_rule_node` mutation.
# Same shape as the create payload: nullable `ok`, `msg` and `rule`.
type ModifyEndpointAutoScalingRuleNode {
ok: Boolean
msg: String
rule: EndpointAutoScalingRuleNode
}

"""Added in 25.1.0."""
# `props` argument of the `modify_endpoint_auto_scaling_rule_node` mutation.
# Mirrors `EndpointAutoScalingRuleInput`, but every field is optional.
# NOTE(review): presumably omitted fields leave the stored value unchanged — confirm in the mutation resolver.
input ModifyEndpointAutoScalingRuleInput {
metric_source: AutoScalingMetricSource
metric_name: String
threshold: String
comparator: AutoScalingMetricComparator
step_size: Int
cooldown_seconds: Int
min_replicas: Int
max_replicas: Int
}

"""Added in 25.1.0."""
# Payload of the `delete_endpoint_auto_scaling_rule_node` mutation.
# Unlike the create/modify payloads it carries no `rule` field, only the status pair.
type DeleteEndpointAutoScalingRuleNode {
ok: Boolean
msg: String
}

"""Deprecated since 24.09.0. use `CreateContainerRegistryNode` instead"""
type CreateContainerRegistry {
container_registry: ContainerRegistry
Expand Down
92 changes: 73 additions & 19 deletions src/ai/backend/cli/params.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
import json
import re
from collections.abc import Mapping, Sequence
from decimal import Decimal
from typing import Any, Generic, Mapping, Optional, Protocol, TypeVar, Union
from typing import (
Any,
Generic,
Optional,
Protocol,
TypeVar,
override,
)

import click
import trafaret
Expand All @@ -12,7 +20,13 @@
class BoolExprType(click.ParamType):
name = "boolean"

def convert(self, value, param, ctx):
@override
def convert(
self,
value: str,
param: Optional[click.Parameter],
ctx: Optional[click.Context],
) -> bool:
if isinstance(value, bool):
return value
try:
Expand All @@ -34,7 +48,13 @@ class ByteSizeParamType(click.ParamType):
"e": 2**60,
}

def convert(self, value, param, ctx):
@override
def convert(
self,
value: str,
param: Optional[click.Parameter],
ctx: Optional[click.Context],
) -> Any:
if isinstance(value, int):
return value
if not isinstance(value, str):
Expand All @@ -54,7 +74,13 @@ def convert(self, value, param, ctx):
class ByteSizeParamCheckType(ByteSizeParamType):
name = "byte-check"

def convert(self, value, param, ctx):
@override
def convert(
self,
value: str,
param: Optional[click.Parameter],
ctx: Optional[click.Context],
) -> str:
if isinstance(value, int):
return value
if not isinstance(value, str):
Expand All @@ -72,7 +98,13 @@ def convert(self, value, param, ctx):
class CommaSeparatedKVListParamType(click.ParamType):
name = "comma-seperated-KVList-check"

def convert(self, value: Union[str, Mapping[str, str]], param, ctx) -> Mapping[str, str]:
@override
def convert(
self,
value: str,
param: Optional[click.Parameter],
ctx: Optional[click.Context],
) -> Mapping[str, str]:
if isinstance(value, dict):
return value
if not isinstance(value, str):
Expand Down Expand Up @@ -111,9 +143,10 @@ def __init__(self) -> None:
super().__init__()
self._parsed = False

@override
def convert(
self,
value: Optional[str],
value: str,
param: Optional[click.Parameter],
ctx: Optional[click.Context],
) -> Any:
Expand Down Expand Up @@ -151,8 +184,14 @@ class RangeExprOptionType(click.ParamType):
_rx_range_key = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
name = "Range Expression"

def convert(self, arg, param, ctx):
key, value = arg.split("=", maxsplit=1)
@override
def convert(
self,
value: str,
param: Optional[click.Parameter],
ctx: Optional[click.Context],
) -> Any:
key, value = value.split("=", maxsplit=1)
assert self._rx_range_key.match(key), "The key must be a valid slug string."
try:
if value.startswith("case:"):
Expand All @@ -172,12 +211,18 @@ def convert(self, arg, param, ctx):
class CommaSeparatedListType(click.ParamType):
name = "List Expression"

def convert(self, arg, param, ctx):
@override
def convert(
self,
value: str,
param: Optional[click.Parameter],
ctx: Optional[click.Context],
) -> Sequence[str]:
try:
if isinstance(arg, int):
return arg
elif isinstance(arg, str):
return arg.split(",")
if isinstance(value, int):
return value
elif isinstance(value, str):
return value.split(",")
except ValueError as e:
self.fail(repr(e), param, ctx)

Expand All @@ -189,22 +234,31 @@ class SingleValueConstructorType(Protocol):
def __init__(self, value: Any) -> None: ...


TScalar = TypeVar("TScalar", bound=SingleValueConstructorType)
TScalar = TypeVar("TScalar", bound=SingleValueConstructorType | click.ParamType)


class OptionalType(click.ParamType, Generic[TScalar]):
name = "Optional Type Wrapper"

def __init__(self, type_: type[TScalar] | type[click.ParamType]) -> None:
def __init__(self, type_: type[TScalar] | click.ParamType) -> None:
super().__init__()
self.type_ = type_

def convert(self, value: Any, param, ctx) -> TScalar | Undefined:
def convert(
self,
value: str,
param: Optional[click.Parameter],
ctx: Optional[click.Context],
) -> TScalar | Undefined:
try:
if value is undefined:
return undefined
if issubclass(self.type_, click.ParamType):
return self.type_()(value)
return self.type_(value)
match self.type_:
case click.ParamType():
return self.type_(value)
case type() if issubclass(self.type_, click.ParamType):
return self.type_()(value)
case _:
return self.type_(value)
except ValueError:
self.fail(f"{value!r} is not valid `{self.type_}` or `undefined`", param, ctx)
1 change: 1 addition & 0 deletions src/ai/backend/client/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from . import model # noqa # type: ignore
from . import server_log # noqa # type: ignore
from . import service # noqa # type: ignore
from . import service_auto_scaling_rule # noqa # type: ignore
from . import session # noqa # type: ignore
from . import session_template # noqa # type: ignore
from . import vfolder # noqa # type: ignore
Expand Down
Loading

0 comments on commit 9be8899

Please sign in to comment.