Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mito-ai: add mypy tests #1512

Merged
merged 15 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions .github/workflows/test-mito-ai-mypy.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
name: Test - mito-ai mypy

on:
push:
branches: [ dev ]
paths:
- 'mito-ai/**'
pull_request:
paths:
- 'mito-ai/**'

jobs:
test-mito-ai-mypy:
runs-on: ubuntu-20.04
strategy:
matrix:
python-version: ["3.10"]

steps:
- name: Cancel Previous Runs
uses: styfle/[email protected]
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
cache: pip
cache-dependency-path: mito-ai/setup.py
- name: Install dependencies
run: |
cd mito-ai
pip install -e ".[test]"
- name: Check types with MyPY
run: |
mypy mito-ai/mito_ai/ --ignore-missing-imports
13 changes: 7 additions & 6 deletions mito-ai/mito_ai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import List, Dict
from jupyter_server.utils import url_path_join
from .handlers import CompletionHandler
from .providers import OpenAIProvider
from mito_ai.handlers import CompletionHandler
from mito_ai.providers import OpenAIProvider

try:
from _version import __version__
from _version import __version__ # type: ignore
except ImportError:
# Fallback when using the package in dev mode without installing
# in editable mode with pip. It is highly recommended to install
Expand All @@ -14,11 +15,11 @@
__version__ = "dev"


def _jupyter_labextension_paths():
def _jupyter_labextension_paths() -> List[Dict[str, str]]:
return [{"src": "labextension", "dest": "mito-ai"}]


def _jupyter_server_extension_points():
def _jupyter_server_extension_points() -> List[Dict[str, str]]:
"""
Returns a list of dictionaries with metadata describing
where to find the `_load_jupyter_server_extension` function.
Expand All @@ -33,7 +34,7 @@ def _jupyter_server_extension_points():

# For a further explanation of the Jupyter architecture watch the first 35 minutes
# of this video: https://www.youtube.com/watch?v=9_-siU-_XoI
def _load_jupyter_server_extension(server_app):
def _load_jupyter_server_extension(server_app) -> None: # type: ignore
host_pattern = ".*$"
web_app = server_app.web_app
base_url = web_app.settings["base_url"]
Expand Down
28 changes: 15 additions & 13 deletions mito-ai/mito_ai/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import time
from dataclasses import asdict
from http import HTTPStatus
from typing import Any, Awaitable, Dict, Optional, Literal, Type
from typing import Any, Awaitable, Dict, Optional, Literal, Type, Union

import tornado
import tornado.ioloop
Expand All @@ -16,7 +16,7 @@

from mito_ai.logger import get_logger
from mito_ai.models import (
AllIncomingMessageTypes,
IncomingMessageTypes,
CodeExplainMessageBuilder,
CompletionError,
CompletionItem,
Expand Down Expand Up @@ -47,7 +47,7 @@ def initialize(self, llm: OpenAIProvider) -> None:
super().initialize()
self.log.debug("Initializing websocket connection %s", self.request.path)
self._llm = llm
self.full_message_history = []
self.full_message_history: list[ChatCompletionMessageParam] = []
self.is_pro = is_pro()

@property
Expand Down Expand Up @@ -75,10 +75,10 @@ async def pre_get(self) -> None:
):
raise tornado.web.HTTPError(HTTPStatus.FORBIDDEN)

async def get(self, *args, **kwargs) -> None:
async def get(self, *args: Any, **kwargs: dict[str, Any]) -> None:
"""Get an event to open a socket."""
# This method ensure to call `pre_get` before opening the socket.
await ensure_async(self.pre_get())
await ensure_async(self.pre_get()) # type: ignore

initialize_user()

Expand All @@ -99,7 +99,7 @@ def on_close(self) -> None:
# Clear the message history
self.full_message_history = []

async def on_message(self, message: str) -> None:
async def on_message(self, message: str) -> None: # type: ignore
"""Handle incoming messages on the WebSocket.

Args:
Expand All @@ -111,7 +111,7 @@ async def on_message(self, message: str) -> None:
parsed_message = json.loads(message)

metadata_dict = parsed_message.get('metadata', {})
type: AllIncomingMessageTypes = parsed_message.get('type')
type: IncomingMessageTypes = parsed_message.get('type')
except ValueError as e:
self.log.error("Invalid completion request.", exc_info=e)
return
Expand All @@ -121,7 +121,6 @@ async def on_message(self, message: str) -> None:
self.full_message_history = []
return

messages = []
response_format = None

# Generate new message based on message type
Expand Down Expand Up @@ -154,7 +153,7 @@ async def on_message(self, message: str) -> None:
else:
raise ValueError(f"Invalid message type: {type}")

new_message = {
new_message: ChatCompletionMessageParam = {
"role": "user",
"content": prompt
}
Expand Down Expand Up @@ -186,7 +185,7 @@ async def on_message(self, message: str) -> None:
except Exception as e:
await self.handle_exception(e, request)

def open(self, *args: str, **kwargs: str) -> Optional[Awaitable[None]]:
def open(self, *args: str, **kwargs: str) -> None:
"""Invoked when a new WebSocket is opened.

The arguments to `open` are extracted from the `tornado.web.URLSpec`
Expand All @@ -203,7 +202,7 @@ def open(self, *args: str, **kwargs: str) -> Optional[Awaitable[None]]:
# Send the server capabilities to the client.
self.reply(self._llm.capabilities)

async def handle_exception(self, e: Exception, request: CompletionRequest):
async def handle_exception(self, e: Exception, request: CompletionRequest) -> None:
"""
Handles an exception raised in either ``handle_request`` or
``handle_stream_request``.
Expand All @@ -219,8 +218,11 @@ async def handle_exception(self, e: Exception, request: CompletionRequest):
hint = "There was an error communicating with OpenAI. This might be due to a temporary OpenAI outage, a problem with your internet connection, or an incorrect API key. Please try again."
else:
hint = "There was an error communicating with Mito server. This might be due to a temporary server outage or a problem with your internet connection. Please try again."
error = CompletionError.from_exception(e, hint=hint)

error: CompletionError = CompletionError.from_exception(e, hint=hint)
self._send_error({"new": error})

reply: Union[CompletionStreamChunk, CompletionReply]
if request.stream:
reply = CompletionStreamChunk(
chunk=CompletionItem(content="", isIncomplete=True),
Expand Down Expand Up @@ -282,7 +284,7 @@ async def _handle_stream_request(self, request: CompletionRequest, prompt_type:
self.full_message_history.append(
{
"role": "assistant",
"content": reply.items[0].content
"content": reply.items[0].content # type: ignore
}
)
latency_ms = round((time.time() - start) * 1000)
Expand Down
120 changes: 80 additions & 40 deletions mito-ai/mito_ai/models.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

import traceback
from dataclasses import dataclass
from typing import List, Literal, Optional, Type
from dataclasses import dataclass, field
from typing import List, Literal, Optional, Type, Union

from pydantic import BaseModel
from openai.types.chat import ChatCompletionMessageParam

from .prompt_builders import (
from mito_ai.prompt_builders import (
create_chat_prompt,
create_inline_prompt,
create_explain_code_prompt,
Expand All @@ -16,18 +16,22 @@
)

CompletionIncomingMessageTypes = Literal['chat', 'inline_completion', 'codeExplain', 'smartDebug', 'agent:planning']
AllIncomingMessageTypes = Literal['clear_history', CompletionIncomingMessageTypes]
IncomingMessageTypes = Union[Literal['clear_history'], CompletionIncomingMessageTypes]

@dataclass(frozen=True)
class AICapabilities:
"""AI provider capabilities"""
"""
AI provider capabilities
"""

# Configuration schema.
configuration: dict
"""Configuration schema."""

# AI provider name.
provider: str
"""AI provider name."""

# Message type.
type: str = "ai_capabilities"
"""Message type."""

@dataclass(frozen=True)
class ChatMessageBuilder:
Expand Down Expand Up @@ -130,56 +134,79 @@ class PlanOfAttack(BaseModel):

@dataclass(frozen=True)
class CompletionRequest:
"""Message send by the client to request an AI chat response."""
"""
Message send by the client to request an AI chat response.
"""

# Message type.
type: IncomingMessageTypes
"""Message type."""

# Message UID generated by the client.
message_id: str
"""Message UID generated by the client."""
messages: List[dict] = None
"""Chat messages."""

# Chat messages.
messages: List[ChatCompletionMessageParam] = field(default_factory=list)

# Whether to stream the response (if supported by the model).
stream: bool = False
"""Whether to stream the response (if supported by the model)."""


@dataclass(frozen=True)
class CompletionItemError:
"""Completion item error information."""
"""
Completion item error information.
"""

# Error message.
message: Optional[str] = None
"""Error message."""


@dataclass(frozen=True)
class CompletionItem:
"""A completion suggestion."""
"""
A completion suggestion.
"""

# The completion.
content: str
"""The completion."""

# Whether the completion is incomplete or not.
isIncomplete: Optional[bool] = None
"""Whether the completion is incomplete or not."""

# Unique token identifying the completion request in the frontend.
token: Optional[str] = None
"""Unique token identifying the completion request in the frontend."""

# Error information for the completion item.
error: Optional[CompletionItemError] = None
"""Error information for the completion item."""


@dataclass(frozen=True)
class CompletionError:
"""Completion error description"""
"""
Completion error description.
"""

# Error type.
error_type: str
"""Error type"""

# Error title.
title: str
"""Error title"""

# Error traceback.
traceback: str
"""Error traceback"""

# Hint to resolve the error.
hint: str = ""
"""Hint to resolve the error"""

@staticmethod
def from_exception(exception: BaseException, hint: str = "") -> CompletionError:
"""Create a completion error from an exception."""
"""
Create a completion error from an exception.

Note: OpenAI exceptions can include a 'body' attribute with detailed error information.
While mypy doesn't know about this attribute on BaseException, we need to handle it
to properly extract error messages from OpenAI API responses.
"""
error_type = type(exception)
error_module = getattr(error_type, "__module__", "")
return CompletionError(
Expand All @@ -196,37 +223,50 @@ def from_exception(exception: BaseException, hint: str = "") -> CompletionError:

@dataclass(frozen=True)
class ErrorMessage(CompletionError):
"""Error message."""
"""
Error message.
"""

# Message type.
type: Literal["error"] = "error"
"""Message type."""



@dataclass(frozen=True)
class CompletionReply:
"""Message sent from model to client with the completion suggestions."""
"""
Message sent from model to client with the completion suggestions.
"""

# List of completion items.
items: List[CompletionItem]
"""List of completion items."""

# Parent message UID.
parent_id: str
"""Parent message UID."""

# Message type.
type: Literal["reply"] = "reply"
"""Message type."""

# Completion error.
error: Optional[CompletionError] = None
"""Completion error."""


@dataclass(frozen=True)
class CompletionStreamChunk:
"""Message sent from model to client with the infill suggestions"""
"""
Message sent from model to client with the infill suggestions
"""

chunk: CompletionItem
"""Completion item."""

# Parent message UID.
parent_id: str
"""Parent message UID."""

# Whether the completion is done or not.
done: bool
"""Whether the completion is done or not."""

# Message type.
type: Literal["chunk"] = "chunk"
"""Message type."""

# Completion error.
error: Optional[CompletionError] = None
"""Completion error."""
Loading
Loading