diff --git a/src/ai/backend/common/bgtask.py b/src/ai/backend/common/bgtask.py index 8dadaa3d7ec..0711638f5bf 100644 --- a/src/ai/backend/common/bgtask.py +++ b/src/ai/backend/common/bgtask.py @@ -27,6 +27,7 @@ from redis.asyncio.client import Pipeline from . import redis_helper +from .defs import BackgroundTaskLogLevel as LogLevel from .events import ( BgtaskCancelledEvent, BgtaskDoneEvent, @@ -70,6 +71,7 @@ async def update( self, increment: Union[int, float] = 0, message: str | None = None, + log_level: LogLevel = LogLevel.INFO, ) -> None: self.current_progress += increment # keep the state as local variables because they might be changed @@ -87,6 +89,7 @@ async def _pipe_builder(r: Redis) -> Pipeline: "total": str(total), "msg": message or "", "last_update": str(time.time()), + "log_level": str(log_level), }, ) await pipe.expire(tracker_key, MAX_BGTASK_ARCHIVE_PERIOD) @@ -99,6 +102,7 @@ async def _pipe_builder(r: Redis) -> Pipeline: message=message, current_progress=current, total_progress=total, + log_level=log_level, ), ) @@ -157,6 +161,7 @@ async def push_bgtask_events( case BgtaskUpdatedEvent(): body["current_progress"] = event.current_progress body["total_progress"] = event.total_progress + body["log_level"] = event.log_level await resp.send(json.dumps(body), event=event.name, retry=5) case BgtaskDoneEvent(): if extra_data: diff --git a/src/ai/backend/common/defs.py b/src/ai/backend/common/defs.py index 9dadda4c90b..0881d88c7b0 100644 --- a/src/ai/backend/common/defs.py +++ b/src/ai/backend/common/defs.py @@ -1,3 +1,4 @@ +import enum from typing import Final # Redis database IDs depending on purposes @@ -10,3 +11,10 @@ DEFAULT_FILE_IO_TIMEOUT: Final = 10 + + +class BackgroundTaskLogLevel(enum.StrEnum): + INFO = enum.auto() + WARNING = enum.auto() + ERROR = enum.auto() + DEBUG = enum.auto() diff --git a/src/ai/backend/common/events.py b/src/ai/backend/common/events.py index abdd3b7685e..36339db5c9a 100644 --- a/src/ai/backend/common/events.py +++ b/src/ai/backend/common/events.py @@ -34,6 +34,7 @@ from redis.asyncio import ConnectionPool from . import msgpack, redis_helper +from .defs import BackgroundTaskLogLevel from .logging import BraceStyleAdapter from .types import ( AgentId, @@ -559,6 +560,7 @@ class BgtaskUpdatedEvent(AbstractEvent): current_progress: float = attrs.field() total_progress: float = attrs.field() message: Optional[str] = attrs.field(default=None) + log_level: BackgroundTaskLogLevel = attrs.field(default=BackgroundTaskLogLevel.INFO) def serialize(self) -> tuple: return ( @@ -566,6 +568,7 @@ def serialize(self) -> tuple: self.current_progress, self.total_progress, self.message, + str(self.log_level), ) @classmethod @@ -575,6 +578,7 @@ def deserialize(cls, value: tuple): value[1], value[2], value[3], + BackgroundTaskLogLevel(value[4]), )