
Merge pull request #1512 from mito-ds/mito-ai-type-checking
mito-ai: add mypy tests
aarondr77 authored Feb 7, 2025
2 parents 1513feb + a7cb0f5 commit d1be563
Showing 14 changed files with 238 additions and 115 deletions.
37 changes: 37 additions & 0 deletions .github/workflows/test-mito-ai-mypy.yml
@@ -0,0 +1,37 @@
name: Test - mito-ai mypy

on:
  push:
    branches: [ dev ]
    paths:
      - 'mito-ai/**'
  pull_request:
    paths:
      - 'mito-ai/**'

jobs:
  test-mito-ai-mypy:
    runs-on: ubuntu-20.04
    strategy:
      matrix:
        python-version: ["3.10"]

    steps:
      - name: Cancel Previous Runs
        uses: styfle/[email protected]
        with:
          access_token: ${{ github.token }}
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
          cache: pip
          cache-dependency-path: mito-ai/setup.py
      - name: Install dependencies
        run: |
          cd mito-ai
          pip install -e ".[test]"
      - name: Check types with MyPY
        run: |
          mypy mito-ai/mito_ai/ --ignore-missing-imports
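
The job type-checks only the mito_ai package, and `--ignore-missing-imports` suppresses errors about third-party dependencies that ship without type stubs while still checking the package's own annotations. A minimal sketch of what the flag does and does not silence (the file and package names here are hypothetical, and the snippet exists only to be type-checked):

# example.py (hypothetical)
import some_untyped_package  # "missing library stubs" error: suppressed by the flag

def add(a: int, b: int) -> int:
    return a + b

total: str = add(1, 2)  # still reported: int is not assignable to str

Running `mypy example.py --ignore-missing-imports` reports only the assignment error.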
13 changes: 7 additions & 6 deletions mito-ai/mito_ai/__init__.py
@@ -1,9 +1,10 @@
from typing import List, Dict
from jupyter_server.utils import url_path_join
from .handlers import CompletionHandler
from .providers import OpenAIProvider
from mito_ai.handlers import CompletionHandler
from mito_ai.providers import OpenAIProvider

try:
    from _version import __version__
    from _version import __version__  # type: ignore
except ImportError:
    # Fallback when using the package in dev mode without installing
    # in editable mode with pip. It is highly recommended to install
@@ -14,11 +15,11 @@
    __version__ = "dev"


def _jupyter_labextension_paths():
def _jupyter_labextension_paths() -> List[Dict[str, str]]:
    return [{"src": "labextension", "dest": "mito-ai"}]


def _jupyter_server_extension_points():
def _jupyter_server_extension_points() -> List[Dict[str, str]]:
    """
    Returns a list of dictionaries with metadata describing
    where to find the `_load_jupyter_server_extension` function.
@@ -33,7 +34,7 @@ def _jupyter_server_extension_points():

# For a further explanation of the Jupyter architecture watch the first 35 minutes
# of this video: https://www.youtube.com/watch?v=9_-siU-_XoI
def _load_jupyter_server_extension(server_app):
def _load_jupyter_server_extension(server_app) -> None:  # type: ignore
    host_pattern = ".*$"
    web_app = server_app.web_app
    base_url = web_app.settings["base_url"]
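The trailing `# type: ignore` on `_load_jupyter_server_extension` is presumably needed because the `server_app` parameter stays unannotated, leaving the function only partially typed. A sketch of a fully annotated alternative, assuming the standard `ServerApp` class from jupyter_server (this is not what the commit does, just an illustration):

from jupyter_server.serverapp import ServerApp

def _load_jupyter_server_extension(server_app: ServerApp) -> None:
    # Same registration logic as above, but with the parameter typed
    # so no ignore comment is required.
    web_app = server_app.web_app
    base_url = web_app.settings["base_url"]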
28 changes: 15 additions & 13 deletions mito-ai/mito_ai/handlers.py
@@ -3,7 +3,7 @@
import time
from dataclasses import asdict
from http import HTTPStatus
from typing import Any, Awaitable, Dict, Optional, Literal, Type
from typing import Any, Awaitable, Dict, Optional, Literal, Type, Union

import tornado
import tornado.ioloop
@@ -16,7 +16,7 @@

from mito_ai.logger import get_logger
from mito_ai.models import (
    AllIncomingMessageTypes,
    IncomingMessageTypes,
    CodeExplainMessageBuilder,
    CompletionError,
    CompletionItem,
@@ -47,7 +47,7 @@ def initialize(self, llm: OpenAIProvider) -> None:
        super().initialize()
        self.log.debug("Initializing websocket connection %s", self.request.path)
        self._llm = llm
        self.full_message_history = []
        self.full_message_history: list[ChatCompletionMessageParam] = []
        self.is_pro = is_pro()

    @property
@@ -75,10 +75,10 @@ async def pre_get(self) -> None:
        ):
            raise tornado.web.HTTPError(HTTPStatus.FORBIDDEN)

    async def get(self, *args, **kwargs) -> None:
    async def get(self, *args: Any, **kwargs: dict[str, Any]) -> None:
        """Get an event to open a socket."""
        # This method ensures `pre_get` is called before opening the socket.
        await ensure_async(self.pre_get())
        await ensure_async(self.pre_get())  # type: ignore

        initialize_user()

@@ -99,7 +99,7 @@ def on_close(self) -> None:
        # Clear the message history
        self.full_message_history = []

    async def on_message(self, message: str) -> None:
    async def on_message(self, message: str) -> None:  # type: ignore
        """Handle incoming messages on the WebSocket.

        Args:
@@ -111,7 +111,7 @@ async def on_message(self, message: str) -> None:
            parsed_message = json.loads(message)

            metadata_dict = parsed_message.get('metadata', {})
            type: AllIncomingMessageTypes = parsed_message.get('type')
            type: IncomingMessageTypes = parsed_message.get('type')
        except ValueError as e:
            self.log.error("Invalid completion request.", exc_info=e)
            return
@@ -121,7 +121,6 @@ async def on_message(self, message: str) -> None:
            self.full_message_history = []
            return

        messages = []
        response_format = None

        # Generate new message based on message type
@@ -154,7 +153,7 @@
        else:
            raise ValueError(f"Invalid message type: {type}")

        new_message = {
        new_message: ChatCompletionMessageParam = {
            "role": "user",
            "content": prompt
        }
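
`ChatCompletionMessageParam` is a union of per-role TypedDicts exported by the openai package, so annotating the message history and each dict literal lets mypy validate both the keys and the literal `role` values. A small self-contained sketch:

from openai.types.chat import ChatCompletionMessageParam

# Without an annotation, mypy cannot infer the element type of an empty list.
history: list[ChatCompletionMessageParam] = []

new_message: ChatCompletionMessageParam = {
    "role": "user",
    "content": "Explain this traceback",
}
history.append(new_message)

# A mistyped role or an unknown key now fails type checking rather than
# surfacing later as an OpenAI API error.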
@@ -186,7 +185,7 @@
        except Exception as e:
            await self.handle_exception(e, request)

    def open(self, *args: str, **kwargs: str) -> Optional[Awaitable[None]]:
    def open(self, *args: str, **kwargs: str) -> None:
        """Invoked when a new WebSocket is opened.

        The arguments to `open` are extracted from the `tornado.web.URLSpec`
@@ -203,7 +202,7 @@ def open(self, *args: str, **kwargs: str) -> Optional[Awaitable[None]]:
        # Send the server capabilities to the client.
        self.reply(self._llm.capabilities)

    async def handle_exception(self, e: Exception, request: CompletionRequest):
    async def handle_exception(self, e: Exception, request: CompletionRequest) -> None:
        """
        Handles an exception raised in either ``handle_request`` or
        ``handle_stream_request``.
@@ -219,8 +218,11 @@ async def handle_exception(self, e: Exception, request: CompletionRequest):
hint = "There was an error communicating with OpenAI. This might be due to a temporary OpenAI outage, a problem with your internet connection, or an incorrect API key. Please try again."
else:
hint = "There was an error communicating with Mito server. This might be due to a temporary server outage or a problem with your internet connection. Please try again."
error = CompletionError.from_exception(e, hint=hint)

error: CompletionError = CompletionError.from_exception(e, hint=hint)
self._send_error({"new": error})

reply: Union[CompletionStreamChunk, CompletionReply]
if request.stream:
reply = CompletionStreamChunk(
chunk=CompletionItem(content="", isIncomplete=True),
@@ -282,7 +284,7 @@ async def _handle_stream_request(self, request: CompletionRequest, prompt_type:
        self.full_message_history.append(
            {
                "role": "assistant",
                "content": reply.items[0].content
                "content": reply.items[0].content  # type: ignore
            }
        )
        latency_ms = round((time.time() - start) * 1000)
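The `reply: Union[CompletionStreamChunk, CompletionReply]` line added in `handle_exception` pre-declares the variable's type so mypy accepts an assignment from either branch of the conditional. A self-contained sketch of the pattern, using hypothetical stand-in classes rather than the repo's own:

from dataclasses import dataclass
from typing import Union

@dataclass(frozen=True)
class StreamChunk:
    content: str

@dataclass(frozen=True)
class Reply:
    items: list

def build_reply(stream: bool) -> Union[StreamChunk, Reply]:
    # Declare the union up front; each branch may then assign its own member type,
    # and anything outside the union is still rejected.
    reply: Union[StreamChunk, Reply]
    if stream:
        reply = StreamChunk(content="")
    else:
        reply = Reply(items=[])
    return reply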
120 changes: 80 additions & 40 deletions mito-ai/mito_ai/models.py
@@ -1,13 +1,13 @@
from __future__ import annotations

import traceback
from dataclasses import dataclass
from typing import List, Literal, Optional, Type
from dataclasses import dataclass, field
from typing import List, Literal, Optional, Type, Union

from pydantic import BaseModel
from openai.types.chat import ChatCompletionMessageParam

from .prompt_builders import (
from mito_ai.prompt_builders import (
    create_chat_prompt,
    create_inline_prompt,
    create_explain_code_prompt,
@@ -16,18 +16,22 @@
)

CompletionIncomingMessageTypes = Literal['chat', 'inline_completion', 'codeExplain', 'smartDebug', 'agent:planning']
AllIncomingMessageTypes = Literal['clear_history', CompletionIncomingMessageTypes]
IncomingMessageTypes = Union[Literal['clear_history'], CompletionIncomingMessageTypes]
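
Both spellings name the same set of strings; the `Union` form simply composes the existing alias instead of re-wrapping it inside `Literal[...]`. A sketch of how the combined alias narrows in a handler:

from typing import Literal, Union

CompletionIncomingMessageTypes = Literal[
    'chat', 'inline_completion', 'codeExplain', 'smartDebug', 'agent:planning'
]
IncomingMessageTypes = Union[Literal['clear_history'], CompletionIncomingMessageTypes]

def describe(message_type: IncomingMessageTypes) -> str:
    if message_type == 'clear_history':
        return 'reset the conversation'
    # mypy narrows message_type to CompletionIncomingMessageTypes here
    return f'completion request: {message_type}'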

@dataclass(frozen=True)
class AICapabilities:
    """AI provider capabilities"""
    """
    AI provider capabilities
    """

    # Configuration schema.
    configuration: dict
    """Configuration schema."""

    # AI provider name.
    provider: str
    """AI provider name."""

    # Message type.
    type: str = "ai_capabilities"
    """Message type."""

@dataclass(frozen=True)
class ChatMessageBuilder:
@@ -130,56 +134,79 @@ class PlanOfAttack(BaseModel):

@dataclass(frozen=True)
class CompletionRequest:
    """Message sent by the client to request an AI chat response."""
    """
    Message sent by the client to request an AI chat response.
    """

    # Message type.
    type: IncomingMessageTypes
    """Message type."""

    # Message UID generated by the client.
    message_id: str
    """Message UID generated by the client."""
    messages: List[dict] = None
    """Chat messages."""

    # Chat messages.
    messages: List[ChatCompletionMessageParam] = field(default_factory=list)

    # Whether to stream the response (if supported by the model).
    stream: bool = False
    """Whether to stream the response (if supported by the model)."""


@dataclass(frozen=True)
class CompletionItemError:
    """Completion item error information."""
    """
    Completion item error information.
    """

    # Error message.
    message: Optional[str] = None
    """Error message."""


@dataclass(frozen=True)
class CompletionItem:
    """A completion suggestion."""
    """
    A completion suggestion.
    """

    # The completion.
    content: str
    """The completion."""

    # Whether the completion is incomplete or not.
    isIncomplete: Optional[bool] = None
    """Whether the completion is incomplete or not."""

    # Unique token identifying the completion request in the frontend.
    token: Optional[str] = None
    """Unique token identifying the completion request in the frontend."""

    # Error information for the completion item.
    error: Optional[CompletionItemError] = None
    """Error information for the completion item."""


@dataclass(frozen=True)
class CompletionError:
    """Completion error description"""
    """
    Completion error description.
    """

    # Error type.
    error_type: str
    """Error type"""

    # Error title.
    title: str
    """Error title"""

    # Error traceback.
    traceback: str
    """Error traceback"""

    # Hint to resolve the error.
    hint: str = ""
    """Hint to resolve the error"""

    @staticmethod
    def from_exception(exception: BaseException, hint: str = "") -> CompletionError:
        """Create a completion error from an exception."""
        """
        Create a completion error from an exception.

        Note: OpenAI exceptions can include a 'body' attribute with detailed error information.
        While mypy doesn't know about this attribute on BaseException, we need to handle it
        to properly extract error messages from OpenAI API responses.
        """
        error_type = type(exception)
        error_module = getattr(error_type, "__module__", "")
        return CompletionError(
@@ -196,37 +223,50 @@ def from_exception(exception: BaseException, hint: str = "") -> CompletionError:

@dataclass(frozen=True)
class ErrorMessage(CompletionError):
    """Error message."""
    """
    Error message.
    """

    # Message type.
    type: Literal["error"] = "error"
    """Message type."""



@dataclass(frozen=True)
class CompletionReply:
    """Message sent from model to client with the completion suggestions."""
    """
    Message sent from model to client with the completion suggestions.
    """

    # List of completion items.
    items: List[CompletionItem]
    """List of completion items."""

    # Parent message UID.
    parent_id: str
    """Parent message UID."""

    # Message type.
    type: Literal["reply"] = "reply"
    """Message type."""

    # Completion error.
    error: Optional[CompletionError] = None
    """Completion error."""


@dataclass(frozen=True)
class CompletionStreamChunk:
    """Message sent from model to client with the infill suggestions"""
    """
    Message sent from model to client with the infill suggestions.
    """

    chunk: CompletionItem
    """Completion item."""

    # Parent message UID.
    parent_id: str
    """Parent message UID."""

    # Whether the completion is done or not.
    done: bool
    """Whether the completion is done or not."""

    # Message type.
    type: Literal["chunk"] = "chunk"
    """Message type."""

    # Completion error.
    error: Optional[CompletionError] = None
    """Completion error."""
