From dc09aa1890859f3ad0548ed053f8a6acfeffdd46 Mon Sep 17 00:00:00 2001 From: Aaron Diamond-Reivich Date: Thu, 6 Feb 2025 15:35:43 -0500 Subject: [PATCH 01/15] mito-ai: add mypy tests --- mito-ai/mito_ai/models.py | 30 +++++++++++++----------- mito-ai/mito_ai/utils/db.py | 7 +++--- mito-ai/mito_ai/utils/open_ai_utils.py | 12 +++++----- mito-ai/mito_ai/utils/telemetry_utils.py | 7 ++++-- mito-ai/mito_ai/utils/version_utils.py | 30 ++++++++++++++---------- mito-ai/mypy.ini | 26 ++++++++++++++++++++ mito-ai/setup.py | 6 ++++- 7 files changed, 80 insertions(+), 38 deletions(-) create mode 100644 mito-ai/mypy.ini diff --git a/mito-ai/mito_ai/models.py b/mito-ai/mito_ai/models.py index 6e66decfe..5b6d5fc00 100644 --- a/mito-ai/mito_ai/models.py +++ b/mito-ai/mito_ai/models.py @@ -1,8 +1,9 @@ from __future__ import annotations import traceback -from dataclasses import dataclass -from typing import List, Literal, Optional, Type +from dataclasses import dataclass, field +from typing import List, Literal, Optional, Type, Dict, Any, Union, cast, runtime_checkable, Protocol, get_args +import json from pydantic import BaseModel from openai.types.chat import ChatCompletionMessageParam @@ -16,7 +17,7 @@ ) CompletionIncomingMessageTypes = Literal['chat', 'inline_completion', 'codeExplain', 'smartDebug', 'agent:planning'] -AllIncomingMessageTypes = Literal['clear_history', CompletionIncomingMessageTypes] +AllIncomingMessageTypes = Union[Literal['clear_history'], CompletionIncomingMessageTypes] @dataclass(frozen=True) class AICapabilities: @@ -132,11 +133,11 @@ class PlanOfAttack(BaseModel): class CompletionRequest: """Message send by the client to request an AI chat response.""" - type: IncomingMessageTypes + type: AllIncomingMessageTypes """Message type.""" message_id: str """Message UID generated by the client.""" - messages: List[dict] = None + messages: List[Dict[str, Any]] = field(default_factory=list) """Chat messages.""" stream: bool = False """Whether to stream the response (if supported by the model).""" @@ -152,16 +153,12 @@ class CompletionItemError: @dataclass(frozen=True) class CompletionItem: - """A completion suggestion.""" + """Completion item information.""" content: str - """The completion.""" - isIncomplete: Optional[bool] = None - """Whether the completion is incomplete or not.""" - token: Optional[str] = None - """Unique token identifying the completion request in the frontend.""" + """Content of the completion.""" error: Optional[CompletionItemError] = None - """Error information for the completion item.""" + """Error information.""" @dataclass(frozen=True) @@ -179,14 +176,19 @@ class CompletionError: @staticmethod def from_exception(exception: BaseException, hint: str = "") -> CompletionError: - """Create a completion error from an exception.""" + """Create a completion error from an exception. + + Note: OpenAI exceptions can include a 'body' attribute with detailed error information. + While mypy doesn't know about this attribute on BaseException, we need to handle it + to properly extract error messages from OpenAI API responses. + """ error_type = type(exception) error_module = getattr(error_type, "__module__", "") return CompletionError( error_type=f"{error_module}.{error_type.__name__}" if error_module else error_type.__name__, - title=exception.body.get("message") + title=exception.body.get("message") # type: ignore[attr-defined] if hasattr(exception, "body") else (exception.args[0] if exception.args else "Exception"), traceback=traceback.format_exc(), diff --git a/mito-ai/mito_ai/utils/db.py b/mito-ai/mito_ai/utils/db.py index baf747b0a..bf0909ea9 100644 --- a/mito-ai/mito_ai/utils/db.py +++ b/mito-ai/mito_ai/utils/db.py @@ -3,11 +3,12 @@ """ import os import json -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union, Final, TypeVar, cast from .schema import MITO_FOLDER # The path of the user.json file -USER_JSON_PATH = os.path.join(MITO_FOLDER, 'user.json') +USER_JSON_PATH: Final[str] = os.path.join(MITO_FOLDER, 'user.json') + def get_user_field(field: str) -> Optional[Any]: """ @@ -22,7 +23,7 @@ def get_user_field(field: str) -> Optional[Any]: def set_user_field(field: str, value: Any) -> None: """ - Updates the value of a specific feild in user.json + Updates the value of a specific field in user.json """ with open(USER_JSON_PATH, 'r') as user_file_old: old_user_json = json.load(user_file_old) diff --git a/mito-ai/mito_ai/utils/open_ai_utils.py b/mito-ai/mito_ai/utils/open_ai_utils.py index fa85c3c61..34ef12675 100644 --- a/mito-ai/mito_ai/utils/open_ai_utils.py +++ b/mito-ai/mito_ai/utils/open_ai_utils.py @@ -4,7 +4,7 @@ # Copyright (c) Saga Inc. import json -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional, Type, Final, cast from datetime import datetime, timedelta from pydantic import BaseModel @@ -18,13 +18,13 @@ ) from .version_utils import is_pro -MITO_AI_URL = "https://ogtzairktg.execute-api.us-east-1.amazonaws.com/Prod/completions/" +MITO_AI_URL: Final[str] = "https://ogtzairktg.execute-api.us-east-1.amazonaws.com/Prod/completions/" -OPEN_SOURCE_AI_COMPLETIONS_LIMIT = 500 -OPEN_SOURCE_INLINE_COMPLETIONS_LIMIT = 30 # days +OPEN_SOURCE_AI_COMPLETIONS_LIMIT: Final[int] = 500 +OPEN_SOURCE_INLINE_COMPLETIONS_LIMIT: Final[int] = 30 # days -__user_email = None -__user_id = None +__user_email: Optional[str] = None +__user_id: Optional[str] = None def check_mito_server_quota(n_counts: int, first_usage_date: str) -> None: diff --git a/mito-ai/mito_ai/utils/telemetry_utils.py b/mito-ai/mito_ai/utils/telemetry_utils.py index 4580c56b2..04f201c1e 100644 --- a/mito-ai/mito_ai/utils/telemetry_utils.py +++ b/mito-ai/mito_ai/utils/telemetry_utils.py @@ -50,8 +50,11 @@ def telemetry_turned_on() -> bool: if is_pro(): return False - telemetry = get_user_field(UJ_MITOSHEET_TELEMETRY) - return telemetry if telemetry is not None else False + telemetry = get_user_field(UJ_MITOSHEET_TELEMETRY) + if telemetry is None: + return False + + return bool(telemetry) def identify() -> None: """ diff --git a/mito-ai/mito_ai/utils/version_utils.py b/mito-ai/mito_ai/utils/version_utils.py index dae401f26..0adf6da6e 100644 --- a/mito-ai/mito_ai/utils/version_utils.py +++ b/mito-ai/mito_ai/utils/version_utils.py @@ -1,21 +1,22 @@ import os +from typing import cast from .schema import UJ_MITOSHEET_ENTERPRISE, UJ_MITOSHEET_PRO from .db import get_user_field # Check if helper packages are installed try: - import mitosheet_helper_pro + import mitosheet_helper_pro # type: ignore MITOSHEET_HELPER_PRO = True except ImportError: MITOSHEET_HELPER_PRO = False try: - import mitosheet_helper_enterprise + import mitosheet_helper_enterprise # type: ignore MITOSHEET_HELPER_ENTERPRISE = True except ImportError: MITOSHEET_HELPER_ENTERPRISE = False try: - import mitosheet_private + import mitosheet_private # type: ignore MITOSHEET_PRIVATE = True except ImportError: MITOSHEET_PRIVATE = False @@ -23,7 +24,7 @@ # This is a legacy helper that we don't use anymore, however, we're keeping it for now # for backwards compatibility, since I'm not 100% confident that nobody is currently using it. try: - import mitosheet_helper_private + import mitosheet_helper_private # type: ignore MITOSHEET_HELPER_PRIVATE = True except ImportError: MITOSHEET_HELPER_PRIVATE = False @@ -37,11 +38,11 @@ def is_pro() -> bool: # This package overides the user.json if MITOSHEET_HELPER_PRO: - return MITOSHEET_HELPER_PRO + return True # This package overides the user.json if MITOSHEET_PRIVATE: - return MITOSHEET_PRIVATE + return True # Check if the config is set # TODO: Check if the mito config pro is set to true. @@ -52,23 +53,28 @@ def is_pro() -> bool: return True pro = get_user_field(UJ_MITOSHEET_PRO) - - return pro if pro is not None else False + if pro is None: + return False + + return bool(pro) def is_enterprise() -> bool: """ Helper function for returning if this is a Mito Enterprise users """ - is_enterprise = get_user_field(UJ_MITOSHEET_ENTERPRISE) # This package overides the user.json if MITOSHEET_HELPER_ENTERPRISE: - return MITOSHEET_HELPER_ENTERPRISE + return True # TODO: Check if the mito config enterprise is set to true. # I don't think that any user is on enterprise via this method + + is_enterprise = get_user_field(UJ_MITOSHEET_ENTERPRISE) + if is_enterprise is None: + return False - # TODO: heck if someone has a temp enterprise license set + # TODO: Check if someone has a temp enterprise license set - return is_enterprise if is_enterprise is not None else False \ No newline at end of file + return bool(is_enterprise) \ No newline at end of file diff --git a/mito-ai/mypy.ini b/mito-ai/mypy.ini new file mode 100644 index 000000000..58c670e7a --- /dev/null +++ b/mito-ai/mypy.ini @@ -0,0 +1,26 @@ +[mypy] +python_version = 3.9 +warn_return_any = True +warn_unused_configs = True +disallow_untyped_defs = True +disallow_incomplete_defs = True +check_untyped_defs = True +disallow_untyped_decorators = False +no_implicit_optional = True +warn_redundant_casts = True +warn_unused_ignores = True +warn_no_return = True +warn_unreachable = True + +# Ignore missing imports for certain third-party libraries +[mypy-openai.*] +ignore_missing_imports = True + +[mypy-traitlets.*] +ignore_missing_imports = True + +[mypy-pydantic.*] +ignore_missing_imports = True + +[mypy-analytics.*] +ignore_missing_imports = True \ No newline at end of file diff --git a/mito-ai/setup.py b/mito-ai/setup.py index 1201aa7ed..199547e83 100644 --- a/mito-ai/setup.py +++ b/mito-ai/setup.py @@ -96,7 +96,11 @@ def get_data_files_from_data_files_spec( 'wheel==0.42.0', 'twine==5.1.1', 'setuptools==68.0.0' - + ], + 'dev': [ + 'mypy>=1.8.0', + 'types-setuptools>=69.0.0', + 'types-tornado>=5.1.1', ], 'test': [ 'pytest==8.3.4', From 07c54942544a21f15bbd1b8e17b8b216958011e1 Mon Sep 17 00:00:00 2001 From: Aaron Diamond-Reivich Date: Thu, 6 Feb 2025 15:45:40 -0500 Subject: [PATCH 02/15] mito-ai: update types and comments in models.py --- mito-ai/mito_ai/models.py | 97 ++++++++++++++++++++++++++------------- 1 file changed, 66 insertions(+), 31 deletions(-) diff --git a/mito-ai/mito_ai/models.py b/mito-ai/mito_ai/models.py index 5b6d5fc00..88d1b21fc 100644 --- a/mito-ai/mito_ai/models.py +++ b/mito-ai/mito_ai/models.py @@ -17,7 +17,7 @@ ) CompletionIncomingMessageTypes = Literal['chat', 'inline_completion', 'codeExplain', 'smartDebug', 'agent:planning'] -AllIncomingMessageTypes = Union[Literal['clear_history'], CompletionIncomingMessageTypes] +IncomingMessageTypes = Union[Literal['clear_history'], CompletionIncomingMessageTypes] @dataclass(frozen=True) class AICapabilities: @@ -131,52 +131,74 @@ class PlanOfAttack(BaseModel): @dataclass(frozen=True) class CompletionRequest: - """Message send by the client to request an AI chat response.""" + """ + Message send by the client to request an AI chat response. + """ - type: AllIncomingMessageTypes - """Message type.""" + # Message type. + type: IncomingMessageTypes + + # Message UID generated by the client. message_id: str - """Message UID generated by the client.""" + + # Chat messages. messages: List[Dict[str, Any]] = field(default_factory=list) - """Chat messages.""" + + # Whether to stream the response (if supported by the model). stream: bool = False - """Whether to stream the response (if supported by the model).""" @dataclass(frozen=True) class CompletionItemError: - """Completion item error information.""" + """ + Completion item error information. + """ + # Error message. message: Optional[str] = None - """Error message.""" @dataclass(frozen=True) class CompletionItem: - """Completion item information.""" + """ + A completion suggestion. + """ + # The completion. content: str - """Content of the completion.""" + + # Whether the completion is incomplete or not. + isIncomplete: Optional[bool] = None + + # Unique token identifying the completion request in the frontend. + token: Optional[str] = None + + # Error information for the completion item. error: Optional[CompletionItemError] = None - """Error information.""" @dataclass(frozen=True) class CompletionError: - """Completion error description""" + """ + Completion error description. + """ + # Error type. error_type: str - """Error type""" + + # Error title. title: str - """Error title""" + + # Error traceback. traceback: str - """Error traceback""" + + # Hint to resolve the error. hint: str = "" - """Hint to resolve the error""" @staticmethod def from_exception(exception: BaseException, hint: str = "") -> CompletionError: - """Create a completion error from an exception. + """ + Create a completion error from an exception. Note: OpenAI exceptions can include a 'body' attribute with detailed error information. While mypy doesn't know about this attribute on BaseException, we need to handle it @@ -198,37 +220,50 @@ def from_exception(exception: BaseException, hint: str = "") -> CompletionError: @dataclass(frozen=True) class ErrorMessage(CompletionError): - """Error message.""" + """ + Error message. + """ + # Message type. type: Literal["error"] = "error" - """Message type.""" + @dataclass(frozen=True) class CompletionReply: - """Message sent from model to client with the completion suggestions.""" + """ + Message sent from model to client with the completion suggestions. + """ + # List of completion items. items: List[CompletionItem] - """List of completion items.""" + + # Parent message UID. parent_id: str - """Parent message UID.""" + + # Message type. type: Literal["reply"] = "reply" - """Message type.""" + + # Completion error. error: Optional[CompletionError] = None - """Completion error.""" @dataclass(frozen=True) class CompletionStreamChunk: - """Message sent from model to client with the infill suggestions""" + """ + Message sent from model to client with the infill suggestions + """ chunk: CompletionItem - """Completion item.""" + + # Parent message UID. parent_id: str - """Parent message UID.""" + + # Whether the completion is done or not. done: bool - """Whether the completion is done or not.""" + + # Message type. type: Literal["chunk"] = "chunk" - """Message type.""" + + # Completion error. error: Optional[CompletionError] = None - """Completion error.""" From 48896e44a954b61a9a7834f12a7b63c4677d594f Mon Sep 17 00:00:00 2001 From: Aaron Diamond-Reivich Date: Thu, 6 Feb 2025 15:48:57 -0500 Subject: [PATCH 03/15] mito-ai: more models.py cleanup --- mito-ai/mito_ai/models.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/mito-ai/mito_ai/models.py b/mito-ai/mito_ai/models.py index 88d1b21fc..1d03970a2 100644 --- a/mito-ai/mito_ai/models.py +++ b/mito-ai/mito_ai/models.py @@ -2,13 +2,12 @@ import traceback from dataclasses import dataclass, field -from typing import List, Literal, Optional, Type, Dict, Any, Union, cast, runtime_checkable, Protocol, get_args -import json +from typing import List, Literal, Optional, Type, Union from pydantic import BaseModel from openai.types.chat import ChatCompletionMessageParam -from .prompt_builders import ( +from mito_ai.prompt_builders import ( create_chat_prompt, create_inline_prompt, create_explain_code_prompt, @@ -21,14 +20,18 @@ @dataclass(frozen=True) class AICapabilities: - """AI provider capabilities""" + """ + AI provider capabilities + """ + # Configuration schema. configuration: dict - """Configuration schema.""" + + # AI provider name. provider: str - """AI provider name.""" + + # Message type. type: str = "ai_capabilities" - """Message type.""" @dataclass(frozen=True) class ChatMessageBuilder: @@ -142,7 +145,7 @@ class CompletionRequest: message_id: str # Chat messages. - messages: List[Dict[str, Any]] = field(default_factory=list) + messages: List[ChatCompletionMessageParam] = field(default_factory=list) # Whether to stream the response (if supported by the model). stream: bool = False From 5a47dd8c77d74c38b332171c3c99b56910c44d1b Mon Sep 17 00:00:00 2001 From: Aaron Diamond-Reivich Date: Thu, 6 Feb 2025 15:49:10 -0500 Subject: [PATCH 04/15] mito-ai: db.py cleanup --- mito-ai/mito_ai/utils/db.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mito-ai/mito_ai/utils/db.py b/mito-ai/mito_ai/utils/db.py index bf0909ea9..f02cb6910 100644 --- a/mito-ai/mito_ai/utils/db.py +++ b/mito-ai/mito_ai/utils/db.py @@ -3,7 +3,7 @@ """ import os import json -from typing import Any, Dict, Optional, Union, Final, TypeVar, cast +from typing import Any, Optional, Final from .schema import MITO_FOLDER # The path of the user.json file From 16dd2711362d7361acdb46ac10887b4f918097ae Mon Sep 17 00:00:00 2001 From: Aaron Diamond-Reivich Date: Thu, 6 Feb 2025 15:52:27 -0500 Subject: [PATCH 05/15] mito-ai: add type annotations to prompt_builders --- mito-ai/mito_ai/prompt_builders.py | 14 +++++++------- mito-ai/mito_ai/utils/open_ai_utils.py | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/mito-ai/mito_ai/prompt_builders.py b/mito-ai/mito_ai/prompt_builders.py index f5509c7c6..769f9e4f1 100644 --- a/mito-ai/mito_ai/prompt_builders.py +++ b/mito-ai/mito_ai/prompt_builders.py @@ -63,7 +63,7 @@ def create_chat_prompt( return prompt -def create_explain_code_prompt(active_cell_code: str): +def create_explain_code_prompt(active_cell_code: str) -> str: prompt = f"""Explain the code in the active code cell to me like I have a basic understanding of Python. Don't explain each line, but instead explain the overall logic of the code. @@ -92,10 +92,10 @@ def multiply(x, y): return prompt def create_inline_prompt( - prefix: str, - suffix: str, - variables: List[str], -): + prefix: str, + suffix: str, + variables: List[str], +) -> str: variables_str = '\n'.join([f"{variable}" for variable in variables]) prompt = f"""You are a coding assistant that lives inside of JupyterLab. Your job is to help the user write code. @@ -250,7 +250,7 @@ def create_error_prompt( errorMessage: str, active_cell_code: str, variables: List[str], -): +) -> str: variables_str = '\n'.join([f"{variable}" for variable in variables]) prompt = f"""You are debugging code in a JupyterLab 4 notebook. Analyze the error and provide a solution that maintains the original intent. @@ -381,7 +381,7 @@ def parse_date(date_str): """ return prompt -def create_agent_prompt(file_type: str, columnSamples: List[str], input: str): +def create_agent_prompt(file_type: str, columnSamples: List[str], input: str) -> str: if file_type: file_sample_snippet = f"""You will be working with the following dataset (sample rows shown) from a {file_type} file: {columnSamples} diff --git a/mito-ai/mito_ai/utils/open_ai_utils.py b/mito-ai/mito_ai/utils/open_ai_utils.py index 34ef12675..e35e25af7 100644 --- a/mito-ai/mito_ai/utils/open_ai_utils.py +++ b/mito-ai/mito_ai/utils/open_ai_utils.py @@ -4,7 +4,7 @@ # Copyright (c) Saga Inc. import json -from typing import Any, Dict, List, Optional, Type, Final, cast +from typing import Any, Dict, List, Optional, Type, Final from datetime import datetime, timedelta from pydantic import BaseModel From 884608d7acf9aef63b5d03096862204881500715 Mon Sep 17 00:00:00 2001 From: Aaron Diamond-Reivich Date: Thu, 6 Feb 2025 15:54:15 -0500 Subject: [PATCH 06/15] mito-ai: add type checking to pytests --- mito-ai/mito_ai/tests/open_ai_utils_test.py | 4 ++-- mito-ai/mito_ai/tests/providers_test.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/mito-ai/mito_ai/tests/open_ai_utils_test.py b/mito-ai/mito_ai/tests/open_ai_utils_test.py index 988fa9996..64d55f6f4 100644 --- a/mito-ai/mito_ai/tests/open_ai_utils_test.py +++ b/mito-ai/mito_ai/tests/open_ai_utils_test.py @@ -11,7 +11,7 @@ TODAY = datetime.now().strftime("%Y-%m-%d") -def test_check_mito_server_quota_open_source_user(): +def test_check_mito_server_quota_open_source_user() -> None: # Under both limits check_mito_server_quota(1, TODAY) @@ -28,7 +28,7 @@ def test_check_mito_server_quota_open_source_user(): check_mito_server_quota(1, REALLY_OLD_DATE) -def test_check_mito_server_quota_pro_user(): +def test_check_mito_server_quota_pro_user() -> None: # No error should be thrown since pro users don't have limits with patch("mito_ai.utils.open_ai_utils.is_pro", return_value=True): check_mito_server_quota(1, TODAY) diff --git a/mito-ai/mito_ai/tests/providers_test.py b/mito-ai/mito_ai/tests/providers_test.py index 0b677e57c..2913d7d1c 100644 --- a/mito-ai/mito_ai/tests/providers_test.py +++ b/mito-ai/mito_ai/tests/providers_test.py @@ -8,7 +8,7 @@ TODAY = datetime.now().strftime("%Y-%m-%d") -def test_os_user_mito_server_below_limit(): +def test_os_user_mito_server_below_limit() -> None: """ Open source user, with no OpenAI API key set (Mito server), below both limits. No error should be thrown. @@ -27,7 +27,7 @@ def test_os_user_mito_server_below_limit(): assert llm.last_error is None -def test_os_user_mito_server_above_limit(): +def test_os_user_mito_server_above_limit() -> None: """ Open source user, with no OpenAI API key set (Mito server), above both limits. An error should be thrown. @@ -60,7 +60,7 @@ def test_os_user_mito_server_above_limit(): assert llm.last_error.title == "mito_server_free_tier_limit_reached" -def test_os_user_openai_key_set_below_limit(): +def test_os_user_openai_key_set_below_limit() -> None: """ Open source user, with OpenAI API key set, below both limits. No error should be thrown. @@ -78,7 +78,7 @@ def test_os_user_openai_key_set_below_limit(): assert llm.last_error is None -def test_os_user_openai_key_set_above_limit(): +def test_os_user_openai_key_set_above_limit() -> None: """ Open source user, with OpenAI API key set, above both limits. No error should be thrown, since the user is using their own key. @@ -107,7 +107,7 @@ def test_os_user_openai_key_set_above_limit(): assert llm.last_error is None -def test_pro_user_mito_server_set_below_limit(): +def test_pro_user_mito_server_set_below_limit() -> None: """ Pro user, with no OpenAI API key set (Mito server), below chat limit. No error should be thrown. @@ -127,7 +127,7 @@ def test_pro_user_mito_server_set_below_limit(): assert llm.last_error is None -def test_pro_user_mito_server_above_limit(): +def test_pro_user_mito_server_above_limit() -> None: """ Pro user, with no OpenAI API key set (Mito server), with usage above both limits. No error should be thrown since pro users don't have limits. @@ -160,7 +160,7 @@ def test_pro_user_mito_server_above_limit(): assert llm.last_error is None -def test_pro_user_openai_key_set_below_limit(): +def test_pro_user_openai_key_set_below_limit() -> None: """ Pro user, with OpenAI API key set, below chat limit. No error should be thrown. @@ -179,7 +179,7 @@ def test_pro_user_openai_key_set_below_limit(): assert llm.last_error is None -def test_pro_user_openai_key_set_above_limit(): +def test_pro_user_openai_key_set_above_limit() -> None: """ Pro user, with OpenAI API key set, above both limits. No error should be thrown since pro users don't have limits. From d4b9c2f2e339643638ce3d392847e51076013cad Mon Sep 17 00:00:00 2001 From: Aaron Diamond-Reivich Date: Thu, 6 Feb 2025 16:00:30 -0500 Subject: [PATCH 07/15] mito-ai: add __init__.py testing --- mito-ai/mito_ai/__init__.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/mito-ai/mito_ai/__init__.py b/mito-ai/mito_ai/__init__.py index 1ebdbbb64..17c5be8a7 100644 --- a/mito-ai/mito_ai/__init__.py +++ b/mito-ai/mito_ai/__init__.py @@ -1,9 +1,10 @@ +from typing import List, Dict from jupyter_server.utils import url_path_join -from .handlers import CompletionHandler -from .providers import OpenAIProvider +from mito_ai.handlers import CompletionHandler +from mito_ai.providers import OpenAIProvider try: - from _version import __version__ + from _version import __version__ # type: ignore except ImportError: # Fallback when using the package in dev mode without installing # in editable mode with pip. It is highly recommended to install @@ -14,11 +15,11 @@ __version__ = "dev" -def _jupyter_labextension_paths(): +def _jupyter_labextension_paths() -> List[Dict[str, str]]: return [{"src": "labextension", "dest": "mito-ai"}] -def _jupyter_server_extension_points(): +def _jupyter_server_extension_points() -> List[Dict[str, str]]: """ Returns a list of dictionaries with metadata describing where to find the `_load_jupyter_server_extension` function. @@ -33,7 +34,7 @@ def _jupyter_server_extension_points(): # For a further explanation of the Jupyter architecture watch the first 35 minutes # of this video: https://www.youtube.com/watch?v=9_-siU-_XoI -def _load_jupyter_server_extension(server_app): +def _load_jupyter_server_extension(server_app) -> None: # type: ignore host_pattern = ".*$" web_app = server_app.web_app base_url = web_app.settings["base_url"] From 06a7f418c9f6e3067a131520b2d255746debcb24 Mon Sep 17 00:00:00 2001 From: Aaron Diamond-Reivich Date: Thu, 6 Feb 2025 16:24:08 -0500 Subject: [PATCH 08/15] mito-ai: add type handling for handlers.py --- mito-ai/mito_ai/handlers.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/mito-ai/mito_ai/handlers.py b/mito-ai/mito_ai/handlers.py index f8e633ef9..c8645a6b7 100644 --- a/mito-ai/mito_ai/handlers.py +++ b/mito-ai/mito_ai/handlers.py @@ -3,7 +3,7 @@ import time from dataclasses import asdict from http import HTTPStatus -from typing import Any, Awaitable, Dict, Optional, Literal, Type +from typing import Any, Awaitable, Dict, Optional, Literal, Type, Union import tornado import tornado.ioloop @@ -16,7 +16,7 @@ from mito_ai.logger import get_logger from mito_ai.models import ( - AllIncomingMessageTypes, + IncomingMessageTypes, CodeExplainMessageBuilder, CompletionError, CompletionItem, @@ -47,7 +47,7 @@ def initialize(self, llm: OpenAIProvider) -> None: super().initialize() self.log.debug("Initializing websocket connection %s", self.request.path) self._llm = llm - self.full_message_history = [] + self.full_message_history: list[ChatCompletionMessageParam] = [] self.is_pro = is_pro() @property @@ -75,10 +75,10 @@ async def pre_get(self) -> None: ): raise tornado.web.HTTPError(HTTPStatus.FORBIDDEN) - async def get(self, *args, **kwargs) -> None: + async def get(self, *args: Any, **kwargs: dict[str, Any]) -> None: """Get an event to open a socket.""" # This method ensure to call `pre_get` before opening the socket. - await ensure_async(self.pre_get()) + await ensure_async(self.pre_get()) # type: ignore initialize_user() @@ -99,7 +99,7 @@ def on_close(self) -> None: # Clear the message history self.full_message_history = [] - async def on_message(self, message: str) -> None: + async def on_message(self, message: str) -> None: # type: ignore """Handle incoming messages on the WebSocket. Args: @@ -111,7 +111,7 @@ async def on_message(self, message: str) -> None: parsed_message = json.loads(message) metadata_dict = parsed_message.get('metadata', {}) - type: AllIncomingMessageTypes = parsed_message.get('type') + type: IncomingMessageTypes = parsed_message.get('type') except ValueError as e: self.log.error("Invalid completion request.", exc_info=e) return @@ -121,7 +121,6 @@ async def on_message(self, message: str) -> None: self.full_message_history = [] return - messages = [] response_format = None # Generate new message based on message type @@ -154,7 +153,7 @@ async def on_message(self, message: str) -> None: else: raise ValueError(f"Invalid message type: {type}") - new_message = { + new_message: ChatCompletionMessageParam = { "role": "user", "content": prompt } @@ -186,7 +185,7 @@ async def on_message(self, message: str) -> None: except Exception as e: await self.handle_exception(e, request) - def open(self, *args: str, **kwargs: str) -> Optional[Awaitable[None]]: + def open(self, *args: str, **kwargs: str) -> None: """Invoked when a new WebSocket is opened. The arguments to `open` are extracted from the `tornado.web.URLSpec` @@ -203,7 +202,7 @@ def open(self, *args: str, **kwargs: str) -> Optional[Awaitable[None]]: # Send the server capabilities to the client. self.reply(self._llm.capabilities) - async def handle_exception(self, e: Exception, request: CompletionRequest): + async def handle_exception(self, e: Exception, request: CompletionRequest) -> None: """ Handles an exception raised in either ``handle_request`` or ``handle_stream_request``. @@ -219,8 +218,11 @@ async def handle_exception(self, e: Exception, request: CompletionRequest): hint = "There was an error communicating with OpenAI. This might be due to a temporary OpenAI outage, a problem with your internet connection, or an incorrect API key. Please try again." else: hint = "There was an error communicating with Mito server. This might be due to a temporary server outage or a problem with your internet connection. Please try again." - error = CompletionError.from_exception(e, hint=hint) + + error: CompletionError = CompletionError.from_exception(e, hint=hint) self._send_error({"new": error}) + + reply: Union[CompletionStreamChunk, CompletionReply] if request.stream: reply = CompletionStreamChunk( chunk=CompletionItem(content="", isIncomplete=True), @@ -282,7 +284,7 @@ async def _handle_stream_request(self, request: CompletionRequest, prompt_type: self.full_message_history.append( { "role": "assistant", - "content": reply.items[0].content + "content": reply.items[0].content # type: ignore } ) latency_ms = round((time.time() - start) * 1000) From b2331ee5f4a3f3c5305a6abe1783a27757512469 Mon Sep 17 00:00:00 2001 From: Aaron Diamond-Reivich Date: Thu, 6 Feb 2025 16:35:24 -0500 Subject: [PATCH 09/15] mito-ai: add types to providers.py --- mito-ai/mito_ai/providers.py | 25 ++++++++++++++----------- mito-ai/mito_ai/utils/open_ai_utils.py | 8 +++++--- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/mito-ai/mito_ai/providers.py b/mito-ai/mito_ai/providers.py index 3966ba305..ddd19c8d1 100644 --- a/mito-ai/mito_ai/providers.py +++ b/mito-ai/mito_ai/providers.py @@ -6,7 +6,7 @@ import openai from openai._streaming import AsyncStream from openai.types.chat import ChatCompletionChunk -from traitlets import Instance, Unicode, default, validate, List +from traitlets import Instance, Unicode, default, validate from pydantic import BaseModel from traitlets.config import LoggingConfigurable @@ -50,7 +50,7 @@ class OpenAIProvider(LoggingConfigurable): help="OpenAI API key. Default value is read from the OPENAI_API_KEY environment variable.", ) - models = List(['gpt-4o-mini', 'o3-mini']) + models: List[str] = ['gpt-4o-mini', 'o3-mini'] last_error = Instance( CompletionError, @@ -60,7 +60,7 @@ class OpenAIProvider(LoggingConfigurable): This attribute is observed by the websocket provider to push the error to the client.""", ) - def __init__(self, **kwargs) -> None: + def __init__(self, **kwargs: Dict[str, Any]) -> None: super().__init__(log=get_logger(), **kwargs) self.last_error = None self._async_client: Optional[openai.AsyncOpenAI] = None @@ -68,14 +68,12 @@ def __init__(self, **kwargs) -> None: self._models: Optional[List[str]] = None @default("api_key") - def _api_key_default(self): + def _api_key_default(self) -> Optional[str]: default_key = os.environ.get("OPENAI_API_KEY") - return self._validate_api_key({"value": default_key}) + return self._validate_api_key(default_key) @validate("api_key") - def _validate_api_key(self, changes: Dict[str, Any]) -> Optional[str]: - """""" - api_key = changes["value"] + def _validate_api_key(self, api_key: Optional[str]) -> Optional[str]: if not api_key: self.log.debug( "No OpenAI API key provided; following back to Mito server API." @@ -145,7 +143,7 @@ def capabilities(self) -> AICapabilities: The provider capabilities. """ if self._models is None: - self._validate_api_key({"value": self.api_key}) + self._validate_api_key(self.api_key) # If the user has an OpenAI API key, then we don't need to check the Mito server quota. if self.api_key: @@ -261,8 +259,10 @@ async def request_completions( _num_usages = get_user_field(UJ_AI_MITO_API_NUM_USAGES) completion_function_params = get_open_ai_completion_function_params(model, request.messages, False, response_format) + + last_message_content = str(request.messages[-1].get("content", "")) if request.messages else None ai_response = await get_ai_completion_from_mito_server( - request.messages[-1].get("content", ""), + last_message_content, completion_function_params, _num_usages or 0, _first_usage_date or "", @@ -331,7 +331,10 @@ async def stream_completions( # Send the completion request to the OpenAI API and returns a stream of completion chunks try: completion_function_params = get_open_ai_completion_function_params(model, request.messages, stream=True) - stream: AsyncStream[ChatCompletionChunk] = await self._openAI_async_client.chat.completions.create(**completion_function_params) + client = self._openAI_async_client + if client is None: + raise ValueError("OpenAI client not initialized") + stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(**completion_function_params) # Log the successful completion log_ai_completion_success( diff --git a/mito-ai/mito_ai/utils/open_ai_utils.py b/mito-ai/mito_ai/utils/open_ai_utils.py index e35e25af7..dbc429295 100644 --- a/mito-ai/mito_ai/utils/open_ai_utils.py +++ b/mito-ai/mito_ai/utils/open_ai_utils.py @@ -17,6 +17,8 @@ log, ) from .version_utils import is_pro +from openai.types.chat import ChatCompletionMessageParam + MITO_AI_URL: Final[str] = "https://ogtzairktg.execute-api.us-east-1.amazonaws.com/Prod/completions/" @@ -54,7 +56,7 @@ def check_mito_server_quota(n_counts: int, first_usage_date: str) -> None: async def get_ai_completion_from_mito_server( - last_message_content: str, + last_message_content: str | None, ai_completion_data: Dict[str, Any], n_counts: int, first_usage_date: str, @@ -72,7 +74,7 @@ async def get_ai_completion_from_mito_server( "email": __user_email, "user_id": __user_id, "data": ai_completion_data, - "user_input": last_message_content, # We add this just for logging purposes + "user_input": last_message_content or "", # We add this just for logging purposes } headers = { @@ -96,7 +98,7 @@ async def get_ai_completion_from_mito_server( def get_open_ai_completion_function_params( model: str, - messages: List[Dict[str, Any]], + messages: List[ChatCompletionMessageParam], stream: bool, response_format: Optional[Type[BaseModel]] = None ) -> Dict[str, Any]: From baa67153abcf175754fc0b0323a7f660faeeb379 Mon Sep 17 00:00:00 2001 From: Aaron Diamond-Reivich Date: Thu, 6 Feb 2025 16:36:57 -0500 Subject: [PATCH 10/15] mito_ai: add rest of types --- mito-ai/mito_ai/models.py | 2 +- mito-ai/mito_ai/utils/open_ai_utils.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mito-ai/mito_ai/models.py b/mito-ai/mito_ai/models.py index 1d03970a2..4fadcae88 100644 --- a/mito-ai/mito_ai/models.py +++ b/mito-ai/mito_ai/models.py @@ -213,7 +213,7 @@ def from_exception(exception: BaseException, hint: str = "") -> CompletionError: error_type=f"{error_module}.{error_type.__name__}" if error_module else error_type.__name__, - title=exception.body.get("message") # type: ignore[attr-defined] + title=exception.body.get("message") if hasattr(exception, "body") else (exception.args[0] if exception.args else "Exception"), traceback=traceback.format_exc(), diff --git a/mito-ai/mito_ai/utils/open_ai_utils.py b/mito-ai/mito_ai/utils/open_ai_utils.py index dbc429295..95d37e212 100644 --- a/mito-ai/mito_ai/utils/open_ai_utils.py +++ b/mito-ai/mito_ai/utils/open_ai_utils.py @@ -4,7 +4,7 @@ # Copyright (c) Saga Inc. import json -from typing import Any, Dict, List, Optional, Type, Final +from typing import Any, Dict, List, Optional, Type, Final, Union from datetime import datetime, timedelta from pydantic import BaseModel @@ -56,7 +56,7 @@ def check_mito_server_quota(n_counts: int, first_usage_date: str) -> None: async def get_ai_completion_from_mito_server( - last_message_content: str | None, + last_message_content: Union[str, None], ai_completion_data: Dict[str, Any], n_counts: int, first_usage_date: str, @@ -93,7 +93,7 @@ async def get_ai_completion_from_mito_server( # so we just return that. content = json.loads(res.body) - return content.get("completion", "") + return str(content.get("completion", "")) def get_open_ai_completion_function_params( From ef51ef678ee78972c4df61bc2fe8f711ca5d8ed8 Mon Sep 17 00:00:00 2001 From: Aaron Diamond-Reivich Date: Thu, 6 Feb 2025 16:39:33 -0500 Subject: [PATCH 11/15] .github: add mito-ai mypy workflow --- .github/workflows/test-mito-ai-mypy.yml | 37 +++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 .github/workflows/test-mito-ai-mypy.yml diff --git a/.github/workflows/test-mito-ai-mypy.yml b/.github/workflows/test-mito-ai-mypy.yml new file mode 100644 index 000000000..455d36a30 --- /dev/null +++ b/.github/workflows/test-mito-ai-mypy.yml @@ -0,0 +1,37 @@ +name: Test - mito-ai mypy + +on: + push: + branches: [ dev ] + paths: + - 'mito-ai/**' + pull_request: + paths: + - 'mito-ai/**' + +jobs: + test-mito-ai-mypy: + runs-on: ubuntu-20.04 + strategy: + matrix: + python-version: ["3.10"] + + steps: + - name: Cancel Previous Runs + uses: styfle/cancel-workflow-action@0.7.0 + with: + access_token: ${{ github.token }} + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: pip + cache-dependency-path: mito-ai/setup.py + - name: Install dependencies + run: | + cd mito-ai + pip install -e ".[test]" + - name: Check types with MyPY + run: | + mypy mito_ai/mito_ai/ \ No newline at end of file From ad234fd474f4d1dc58428024ee1ee6d4defe6707 Mon Sep 17 00:00:00 2001 From: Aaron Diamond-Reivich Date: Thu, 6 Feb 2025 16:43:04 -0500 Subject: [PATCH 12/15] mito-ai: update setup.py --- mito-ai/setup.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/mito-ai/setup.py b/mito-ai/setup.py index 199547e83..01b7d779d 100644 --- a/mito-ai/setup.py +++ b/mito-ai/setup.py @@ -97,14 +97,12 @@ def get_data_files_from_data_files_spec( 'twine==5.1.1', 'setuptools==68.0.0' ], - 'dev': [ + 'test': [ + 'pytest==8.3.4', 'mypy>=1.8.0', 'types-setuptools>=69.0.0', 'types-tornado>=5.1.1', ], - 'test': [ - 'pytest==8.3.4', - ], }, keywords=["AI", "Jupyter", "Mito"], entry_points={ From b6222564102d309008728ad5d17f2c47ef299125 Mon Sep 17 00:00:00 2001 From: Aaron Diamond-Reivich Date: Thu, 6 Feb 2025 16:45:15 -0500 Subject: [PATCH 13/15] .github: update mypy command in workflow --- .github/workflows/test-mito-ai-mypy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-mito-ai-mypy.yml b/.github/workflows/test-mito-ai-mypy.yml index 455d36a30..ca0ba201a 100644 --- a/.github/workflows/test-mito-ai-mypy.yml +++ b/.github/workflows/test-mito-ai-mypy.yml @@ -34,4 +34,4 @@ jobs: pip install -e ".[test]" - name: Check types with MyPY run: | - mypy mito_ai/mito_ai/ \ No newline at end of file + mypy mito_ai/ \ No newline at end of file From 84447b30f1bc5d5a81b06503371f086e0ef2edbf Mon Sep 17 00:00:00 2001 From: Aaron Diamond-Reivich Date: Thu, 6 Feb 2025 16:48:01 -0500 Subject: [PATCH 14/15] .github: update mypy command again --- .github/workflows/test-mito-ai-mypy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-mito-ai-mypy.yml b/.github/workflows/test-mito-ai-mypy.yml index ca0ba201a..37077785b 100644 --- a/.github/workflows/test-mito-ai-mypy.yml +++ b/.github/workflows/test-mito-ai-mypy.yml @@ -34,4 +34,4 @@ jobs: pip install -e ".[test]" - name: Check types with MyPY run: | - mypy mito_ai/ \ No newline at end of file + mypy mito-ai/mito_ai/ \ No newline at end of file From a7cb0f5cd6348c2c94393a04f2f4cecdc17996dc Mon Sep 17 00:00:00 2001 From: Aaron Diamond-Reivich Date: Thu, 6 Feb 2025 16:52:02 -0500 Subject: [PATCH 15/15] .github: update command --- .github/workflows/test-mito-ai-mypy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-mito-ai-mypy.yml b/.github/workflows/test-mito-ai-mypy.yml index 37077785b..00725c3da 100644 --- a/.github/workflows/test-mito-ai-mypy.yml +++ b/.github/workflows/test-mito-ai-mypy.yml @@ -34,4 +34,4 @@ jobs: pip install -e ".[test]" - name: Check types with MyPY run: | - mypy mito-ai/mito_ai/ \ No newline at end of file + mypy mito-ai/mito_ai/ --ignore-missing-imports \ No newline at end of file