Skip to content

Commit

Permalink
refactor: Enhance typing of some modules (#1924)
Browse files Browse the repository at this point in the history
  • Loading branch information
achimnol authored Feb 28, 2024
1 parent 719ff6d commit 29e940b
Show file tree
Hide file tree
Showing 42 changed files with 302 additions and 205 deletions.
2 changes: 1 addition & 1 deletion src/ai/backend/accelerator/mock/defs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import enum


class AllocationModes(str, enum.Enum):
class AllocationModes(enum.StrEnum):
DISCRETE = "discrete"
FRACTIONAL = "fractional"
14 changes: 7 additions & 7 deletions src/ai/backend/agent/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,15 +868,15 @@ async def server_main(
)
@click.option(
"--log-level",
type=click.Choice([*LogSeverity.__members__.keys()], case_sensitive=False),
default="INFO",
type=click.Choice([*LogSeverity], case_sensitive=False),
default=LogSeverity.INFO,
help="Set the logging verbosity level",
)
@click.pass_context
def main(
cli_ctx: click.Context,
config_path: Path,
log_level: str,
log_level: LogSeverity,
debug: bool = False,
) -> int:
"""Start the agent service as a foreground process."""
Expand Down Expand Up @@ -907,10 +907,10 @@ def main(
config.override_with_env(raw_cfg, ("container", "scratch-root"), "BACKEND_SCRATCH_ROOT")

if debug:
log_level = "DEBUG"
config.override_key(raw_cfg, ("debug", "enabled"), log_level == "DEBUG")
config.override_key(raw_cfg, ("logging", "level"), log_level.upper())
config.override_key(raw_cfg, ("logging", "pkg-ns", "ai.backend"), log_level.upper())
log_level = LogSeverity.DEBUG
config.override_key(raw_cfg, ("debug", "enabled"), log_level == LogSeverity.DEBUG)
config.override_key(raw_cfg, ("logging", "level"), log_level)
config.override_key(raw_cfg, ("logging", "pkg-ns", "ai.backend"), log_level)

# Validate and fill configurations
# (allow_extra will make configs to be forward-copmatible)
Expand Down
6 changes: 3 additions & 3 deletions src/ai/backend/agent/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ai.backend.common.types import ContainerId, KernelId, MountTypes, SessionId


class AgentBackend(enum.Enum):
class AgentBackend(enum.StrEnum):
# The list of importable backend names under "ai.backend.agent" pkg namespace.
DOCKER = "docker"
KUBERNETES = "kubernetes"
Expand Down Expand Up @@ -45,7 +45,7 @@ class AgentEventData:
data: dict[str, Any]


class ContainerStatus(str, enum.Enum):
class ContainerStatus(enum.StrEnum):
RUNNING = "running"
RESTARTING = "restarting"
PAUSED = "paused"
Expand All @@ -64,7 +64,7 @@ class Container:
backend_obj: Any # used to keep the backend-specific data


class LifecycleEvent(int, enum.Enum):
class LifecycleEvent(enum.IntEnum):
DESTROY = 0
CLEAN = 1
START = 2
Expand Down
19 changes: 12 additions & 7 deletions src/ai/backend/agent/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,12 +332,17 @@ async def watcher_server(loop, pidx, args):
)
@click.option(
"--log-level",
type=click.Choice([*LogSeverity.__members__.keys()], case_sensitive=False),
default="INFO",
type=click.Choice([*LogSeverity], case_sensitive=False),
default=LogSeverity.INFO,
help="Set the logging verbosity level",
)
@click.pass_context
def main(ctx: click.Context, config_path: str, log_level: str, debug: bool) -> None:
def main(
ctx: click.Context,
config_path: str,
log_level: LogSeverity,
debug: bool,
) -> None:
watcher_config_iv = (
t.Dict({
t.Key("watcher"): t.Dict({
Expand Down Expand Up @@ -370,10 +375,10 @@ def main(ctx: click.Context, config_path: str, log_level: str, debug: bool) -> N
raw_cfg, ("watcher", "service-addr", "port"), "BACKEND_WATCHER_SERVICE_PORT"
)
if debug:
log_level = "DEBUG"
config.override_key(raw_cfg, ("debug", "enabled"), log_level == "DEBUG")
config.override_key(raw_cfg, ("logging", "level"), log_level.upper())
config.override_key(raw_cfg, ("logging", "pkg-ns", "ai.backend"), log_level.upper())
log_level = LogSeverity.DEBUG
config.override_key(raw_cfg, ("debug", "enabled"), log_level == LogSeverity.DEBUG)
config.override_key(raw_cfg, ("logging", "level"), log_level)
config.override_key(raw_cfg, ("logging", "pkg-ns", "ai.backend"), log_level)

try:
cfg = config.check(raw_cfg, watcher_config_iv)
Expand Down
9 changes: 5 additions & 4 deletions src/ai/backend/client/cli/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ai.backend.client.compat import asyncio_run
from ai.backend.client.session import AsyncSession, Session
from ai.backend.common.arch import DEFAULT_IMAGE_ARCH
from ai.backend.common.types import ClusterMode

from ..output.fields import routing_fields, service_fields
from ..output.types import FieldSpec
Expand Down Expand Up @@ -181,8 +182,8 @@ def info(ctx: CLIContext, service_name_or_id: str):
@click.option(
"--cluster-mode",
metavar="MODE",
type=click.Choice(["single-node", "multi-node"]),
default="single-node",
type=click.Choice([*ClusterMode], case_sensitive=False),
default=ClusterMode.SINGLE_NODE,
help="The mode of clustering.",
)
@click.option("-d", "--domain", type=str, default="default")
Expand Down Expand Up @@ -369,8 +370,8 @@ def create(
@click.option(
"--cluster-mode",
metavar="MODE",
type=click.Choice(["single-node", "multi-node"]),
default="single-node",
type=click.Choice([*ClusterMode], case_sensitive=False),
default=ClusterMode.SINGLE_NODE,
help="The mode of clustering.",
)
@click.option("-d", "--domain", type=str, default="default")
Expand Down
6 changes: 4 additions & 2 deletions src/ai/backend/client/cli/session/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import click

from ai.backend.common.types import SessionTypes

START_OPTION = [
click.option(
"-t",
Expand All @@ -16,8 +18,8 @@
click.option(
"--type",
metavar="SESSTYPE",
type=click.Choice(["batch", "interactive"]),
default="interactive",
type=click.Choice([*SessionTypes], case_sensitive=False),
default=SessionTypes.INTERACTIVE,
help="Either batch or interactive",
),
click.option(
Expand Down
7 changes: 4 additions & 3 deletions src/ai/backend/client/cli/session/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ai.backend.cli.params import CommaSeparatedListType, RangeExprOptionType
from ai.backend.cli.types import ExitCode
from ai.backend.common.arch import DEFAULT_IMAGE_ARCH
from ai.backend.common.types import ClusterMode

from ...compat import asyncio_run, current_loop
from ...config import local_cache_path
Expand Down Expand Up @@ -346,8 +347,8 @@ def prepare_mount_arg(
@click.option(
"--cluster-mode",
metavar="MODE",
type=click.Choice(["single-node", "multi-node"]),
default="single-node",
type=click.Choice([*ClusterMode], case_sensitive=False),
default=ClusterMode.SINGLE_NODE,
help="The mode of clustering.",
)
@click.option(
Expand Down Expand Up @@ -409,7 +410,7 @@ def run(
scaling_group, # click_start_option
resources, # click_start_option
cluster_size, # click_start_option
cluster_mode,
cluster_mode: ClusterMode,
resource_opts, # click_start_option
architecture,
domain, # click_start_option
Expand Down
5 changes: 3 additions & 2 deletions src/ai/backend/client/cli/session/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from ai.backend.cli.params import CommaSeparatedListType, OptionalType
from ai.backend.cli.types import ExitCode, Undefined, undefined
from ai.backend.common.arch import DEFAULT_IMAGE_ARCH
from ai.backend.common.types import ClusterMode

from ...compat import asyncio_run
from ...exceptions import BackendAPIError
Expand Down Expand Up @@ -97,8 +98,8 @@ def _create_cmd(docs: str = None):
@click.option(
"--cluster-mode",
metavar="MODE",
type=click.Choice(["single-node", "multi-node"]),
default="single-node",
type=click.Choice([*ClusterMode], case_sensitive=False),
default=ClusterMode.SINGLE_NODE,
help="The mode of clustering.",
)
@click.option("--preopen", default=None, type=list_expr, help="Pre-open service ports")
Expand Down
25 changes: 12 additions & 13 deletions src/ai/backend/client/func/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
Mapping,
Optional,
Sequence,
Union,
cast,
)
from uuid import UUID
Expand All @@ -29,7 +28,7 @@
from ai.backend.client.output.fields import session_fields
from ai.backend.client.output.types import FieldSpec, PaginatedResult
from ai.backend.common.arch import DEFAULT_IMAGE_ARCH
from ai.backend.common.types import SessionTypes
from ai.backend.common.types import ClusterMode, SessionTypes

from ...cli.types import Undefined, undefined
from ..compat import current_loop
Expand Down Expand Up @@ -182,7 +181,7 @@ async def get_or_create(
resources: Mapping[str, str | int] = None,
resource_opts: Mapping[str, str | int] = None,
cluster_size: int = 1,
cluster_mode: Literal["single-node", "multi-node"] = "single-node",
cluster_mode: ClusterMode = ClusterMode.SINGLE_NODE,
domain_name: str = None,
group_name: str = None,
bootstrap_script: str = None,
Expand Down Expand Up @@ -348,21 +347,21 @@ async def create_from_template(
*,
name: str | Undefined = undefined,
type_: str | Undefined = undefined,
starts_at: str = None,
starts_at: str | None = None, # not included in templates
enqueue_only: bool | Undefined = undefined,
max_wait: int | Undefined = undefined,
dependencies: Sequence[str] = None, # cannot be stored in templates
dependencies: Sequence[str] | None = None, # cannot be stored in templates
callback_url: str | Undefined = undefined,
no_reuse: bool | Undefined = undefined,
image: str | Undefined = undefined,
mounts: Union[List[str], Undefined] = undefined,
mount_map: Union[Mapping[str, str], Undefined] = undefined,
envs: Union[Mapping[str, str], Undefined] = undefined,
mounts: List[str] | Undefined = undefined,
mount_map: Mapping[str, str] | Undefined = undefined,
envs: Mapping[str, str] | Undefined = undefined,
startup_command: str | Undefined = undefined,
resources: Union[Mapping[str, str | int], Undefined] = undefined,
resource_opts: Union[Mapping[str, str | int], Undefined] = undefined,
resources: Mapping[str, str | int] | Undefined = undefined,
resource_opts: Mapping[str, str | int] | Undefined = undefined,
cluster_size: int | Undefined = undefined,
cluster_mode: Union[Literal["single-node", "multi-node"], Undefined] = undefined,
cluster_mode: ClusterMode | Undefined = undefined,
domain_name: str | Undefined = undefined,
group_name: str | Undefined = undefined,
bootstrap_script: str | Undefined = undefined,
Expand Down Expand Up @@ -1207,7 +1206,7 @@ async def get_or_create(
resources: Optional[Mapping[str, str]] = None,
resource_opts: Optional[Mapping[str, str]] = None,
cluster_size: int = 1,
cluster_mode: Literal["single-node", "multi-node"] = "single-node",
cluster_mode: ClusterMode = ClusterMode.SINGLE_NODE,
domain_name: Optional[str] = None,
group_name: Optional[str] = None,
bootstrap_script: Optional[str] = None,
Expand Down Expand Up @@ -1244,7 +1243,7 @@ async def create_from_template(
resources: Mapping[str, int] | Undefined = undefined,
resource_opts: Mapping[str, int] | Undefined = undefined,
cluster_size: int | Undefined = undefined,
cluster_mode: Literal["single-node", "multi-node"] | Undefined = undefined,
cluster_mode: ClusterMode | Undefined = undefined,
domain_name: str | Undefined = undefined,
group_name: str | Undefined = undefined,
bootstrap_script: str | Undefined = undefined,
Expand Down
4 changes: 2 additions & 2 deletions src/ai/backend/client/func/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
)


class UserRole(str, enum.Enum):
class UserRole(enum.StrEnum):
"""
The role (privilege level) of users.
"""
Expand All @@ -69,7 +69,7 @@ class UserRole(str, enum.Enum):
MONITOR = "monitor"


class UserStatus(enum.Enum):
class UserStatus(enum.StrEnum):
"""
The detailed status of users to represent the signup process and account lifecycles.
"""
Expand Down
4 changes: 3 additions & 1 deletion src/ai/backend/common/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def deserialize(cls, value: tuple):
)


class KernelLifecycleEventReason(str, enum.Enum):
class KernelLifecycleEventReason(enum.StrEnum):
AGENT_TERMINATION = "agent-termination"
ALREADY_TERMINATED = "already-terminated"
ANOMALY_DETECTED = "anomaly-detected"
Expand Down Expand Up @@ -238,6 +238,8 @@ class KernelLifecycleEventReason(str, enum.Enum):

@classmethod
def from_value(cls, value: Optional[str]) -> Optional[KernelLifecycleEventReason]:
if value is None:
return None
try:
return cls(value)
except ValueError:
Expand Down
Loading

0 comments on commit 29e940b

Please sign in to comment.