From bf22409e807549cc0b4238ca651bef2f028b1c69 Mon Sep 17 00:00:00 2001 From: Jason Weill <93281816+JasonWeill@users.noreply.github.com> Date: Wed, 8 Nov 2023 14:28:53 -0800 Subject: [PATCH] Backport PR #430: Model parameters option to pass in model tuning, arbitrary parameters (#453) Co-authored-by: Piyush Jain --- .../jupyter_ai_magics/embedding_providers.py | 3 +- .../jupyter_ai_magics/magics.py | 19 +++----- .../jupyter_ai_magics/parsers.py | 42 +++++++++++++++++ .../jupyter_ai_magics/providers.py | 13 +++++- .../jupyter_ai/chat_handlers/ask.py | 3 +- .../jupyter_ai/chat_handlers/base.py | 11 ++++- .../jupyter_ai/chat_handlers/default.py | 3 +- .../jupyter_ai/chat_handlers/generate.py | 4 +- packages/jupyter-ai/jupyter_ai/extension.py | 46 ++++++++++++++++++- 9 files changed, 124 insertions(+), 20 deletions(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py index 5fe522beb..d09907fb0 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py @@ -1,4 +1,4 @@ -from typing import ClassVar, List, Type +from typing import ClassVar, List from jupyter_ai_magics.providers import ( AuthStrategy, @@ -12,7 +12,6 @@ HuggingFaceHubEmbeddings, OpenAIEmbeddings, ) -from langchain.embeddings.base import Embeddings from pydantic import BaseModel, Extra diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py index 6f1d29546..1e853da67 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py @@ -405,9 +405,11 @@ def handle_error(self, args: ErrorArgs): prompt = f"Explain the following error:\n\n{last_error}" # Set CellArgs based on ErrorArgs - cell_args = CellArgs( - type="root", model_id=args.model_id, format=args.format, reset=False - ) + values = args.dict() + values["type"] = "root" + values["reset"] = False + cell_args = CellArgs(**values) + return self.run_ai_cell(cell_args, prompt) def _append_exchange_openai(self, prompt: str, output: str): @@ -538,16 +540,9 @@ def run_ai_cell(self, args: CellArgs, prompt: str): provider_params["request_schema"] = args.request_schema provider_params["response_path"] = args.response_path - # Validate that the request schema is well-formed JSON - try: - json.loads(args.request_schema) - except json.JSONDecodeError as e: - raise ValueError( - "request-schema must be valid JSON. " - f"Error at line {e.lineno}, column {e.colno}: {e.msg}" - ) from None + model_parameters = json.loads(args.model_parameters) - provider = Provider(**provider_params) + provider = Provider(**provider_params, **model_parameters) # Apply a prompt template. prompt = provider.get_prompt_template(args.format).format(prompt=prompt) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py index a6acf3525..cadd41f4a 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py @@ -1,3 +1,4 @@ +import json from typing import Literal, Optional, get_args import click @@ -32,12 +33,21 @@ + "does nothing with other providers." ) +MODEL_PARAMETERS_SHORT_OPTION = "-m" +MODEL_PARAMETERS_LONG_OPTION = "--model-parameters" +MODEL_PARAMETERS_HELP = ( + "A JSON value that specifies extra values that will be passed " + "to the model. The accepted value parsed to a dict, unpacked " + "and passed as-is to the provider class." +) + class CellArgs(BaseModel): type: Literal["root"] = "root" model_id: str format: FORMAT_CHOICES_TYPE reset: bool + model_parameters: Optional[str] # The following parameters are required only for SageMaker models region_name: Optional[str] request_schema: Optional[str] @@ -49,6 +59,7 @@ class ErrorArgs(BaseModel): type: Literal["error"] = "error" model_id: str format: FORMAT_CHOICES_TYPE + model_parameters: Optional[str] # The following parameters are required only for SageMaker models region_name: Optional[str] request_schema: Optional[str] @@ -93,6 +104,19 @@ def get_help(self, ctx): click.echo(super().get_help(ctx)) +def verify_json_value(ctx, param, value): + if not value: + return value + try: + json.loads(value) + except json.JSONDecodeError as e: + raise ValueError( + f"{param.get_error_hint(ctx)} must be valid JSON. " + f"Error at line {e.lineno}, column {e.colno}: {e.msg}" + ) + return value + + @click.command() @click.argument("model_id") @click.option( @@ -120,6 +144,7 @@ def get_help(self, ctx): REQUEST_SCHEMA_LONG_OPTION, required=False, help=REQUEST_SCHEMA_HELP, + callback=verify_json_value, ) @click.option( RESPONSE_PATH_SHORT_OPTION, @@ -127,6 +152,14 @@ def get_help(self, ctx): required=False, help=RESPONSE_PATH_HELP, ) +@click.option( + MODEL_PARAMETERS_SHORT_OPTION, + MODEL_PARAMETERS_LONG_OPTION, + required=False, + help=MODEL_PARAMETERS_HELP, + callback=verify_json_value, + default="{}", +) def cell_magic_parser(**kwargs): """ Invokes a language model identified by MODEL_ID, with the prompt being @@ -166,6 +199,7 @@ def line_magic_parser(): REQUEST_SCHEMA_LONG_OPTION, required=False, help=REQUEST_SCHEMA_HELP, + callback=verify_json_value, ) @click.option( RESPONSE_PATH_SHORT_OPTION, @@ -173,6 +207,14 @@ def line_magic_parser(): required=False, help=RESPONSE_PATH_HELP, ) +@click.option( + MODEL_PARAMETERS_SHORT_OPTION, + MODEL_PARAMETERS_LONG_OPTION, + required=False, + help=MODEL_PARAMETERS_HELP, + callback=verify_json_value, + default="{}", +) def error_subparser(**kwargs): """ Explains the most recent error. Takes the same options (except -r) as diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index 9fdbffa7a..a17067921 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -5,7 +5,17 @@ import io import json from concurrent.futures import ThreadPoolExecutor -from typing import Any, ClassVar, Coroutine, Dict, List, Literal, Optional, Union +from typing import ( + Any, + ClassVar, + Coroutine, + Dict, + List, + Literal, + Mapping, + Optional, + Union, +) from jsonpath_ng import parse from langchain.chat_models import ( @@ -621,6 +631,7 @@ def __init__(self, *args, **kwargs): content_handler = JsonContentHandler( request_schema=request_schema, response_path=response_path ) + super().__init__(*args, **kwargs, content_handler=content_handler) async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]: diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py index 2f3f1388a..e5c852051 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py @@ -36,7 +36,8 @@ def __init__(self, retriever, *args, **kwargs): def create_llm_chain( self, provider: Type[BaseProvider], provider_params: Dict[str, str] ): - self.llm = provider(**provider_params) + model_parameters = self.get_model_parameters(provider, provider_params) + self.llm = provider(**provider_params, **model_parameters) memory = ConversationBufferWindowMemory( memory_key="chat_history", return_messages=True, k=2 ) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index 7894dcfa5..5ffe65c7c 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -3,7 +3,7 @@ import traceback # necessary to prevent circular import -from typing import TYPE_CHECKING, Dict, Optional, Type +from typing import TYPE_CHECKING, Any, Dict, Optional, Type from uuid import uuid4 from jupyter_ai.config_manager import ConfigManager, Logger @@ -23,10 +23,12 @@ def __init__( log: Logger, config_manager: ConfigManager, root_chat_handlers: Dict[str, "RootChatHandler"], + model_parameters: Dict[str, Dict], ): self.log = log self.config_manager = config_manager self._root_chat_handlers = root_chat_handlers + self.model_parameters = model_parameters self.parser = argparse.ArgumentParser() self.llm = None self.llm_params = None @@ -122,6 +124,13 @@ def get_llm_chain(self): self.llm_params = lm_provider_params return self.llm_chain + def get_model_parameters( + self, provider: Type[BaseProvider], provider_params: Dict[str, str] + ): + return self.model_parameters.get( + f"{provider.id}:{provider_params['model_id']}", {} + ) + def create_llm_chain( self, provider: Type[BaseProvider], provider_params: Dict[str, str] ): diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index 3638b6cfb..d329e05e2 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -40,7 +40,8 @@ def __init__(self, chat_history: List[ChatMessage], *args, **kwargs): def create_llm_chain( self, provider: Type[BaseProvider], provider_params: Dict[str, str] ): - llm = provider(**provider_params) + model_parameters = self.get_model_parameters(provider, provider_params) + llm = provider(**provider_params, **model_parameters) if llm.is_chat_provider: prompt_template = ChatPromptTemplate.from_messages( diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py index e3f84a924..9b66cfae2 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py @@ -226,7 +226,9 @@ def __init__(self, root_dir: str, *args, **kwargs): def create_llm_chain( self, provider: Type[BaseProvider], provider_params: Dict[str, str] ): - llm = provider(**provider_params) + model_parameters = self.get_model_parameters(provider, provider_params) + llm = provider(**provider_params, **model_parameters) + self.llm = llm return llm diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 50865ed96..f9254a581 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -4,7 +4,7 @@ from jupyter_ai.chat_handlers.learn import Retriever from jupyter_ai_magics.utils import get_em_providers, get_lm_providers from jupyter_server.extension.application import ExtensionApp -from traitlets import List, Unicode +from traitlets import Dict, List, Unicode from .chat_handlers import ( AskChatHandler, @@ -53,6 +53,45 @@ class AiExtension(ExtensionApp): config=True, ) + allowed_models = List( + Unicode(), + default_value=None, + help=""" + Language models to allow, as a list of global model IDs in the format + `:`. If `None`, all are allowed. Defaults to + `None`. + + Note: Currently, if `allowed_providers` is also set, then this field is + ignored. This is subject to change in a future non-major release. Using + both traits is considered to be undefined behavior at this time. + """, + allow_none=True, + config=True, + ) + + blocked_models = List( + Unicode(), + default_value=None, + help=""" + Language models to block, as a list of global model IDs in the format + `:`. If `None`, none are blocked. Defaults to + `None`. + """, + allow_none=True, + config=True, + ) + + model_parameters = Dict( + key_trait=Unicode(), + value_trait=Dict(), + default_value={}, + help="""Key-value pairs for model id and corresponding parameters that + are passed to the provider class. The values are unpacked and passed to + the provider class as-is.""", + allow_none=True, + config=True, + ) + def initialize_settings(self): start = time.time() restrictions = { @@ -60,6 +99,10 @@ def initialize_settings(self): "blocked_providers": self.blocked_providers, } + self.settings["model_parameters"] = self.model_parameters + self.log.info(f"Configured model parameters: {self.model_parameters}") + + # Fetch LM & EM providers self.settings["lm_providers"] = get_lm_providers( log=self.log, restrictions=restrictions ) @@ -107,6 +150,7 @@ def initialize_settings(self): "log": self.log, "config_manager": self.settings["jai_config_manager"], "root_chat_handlers": self.settings["jai_root_chat_handlers"], + "model_parameters": self.settings["model_parameters"], } default_chat_handler = DefaultChatHandler( **chat_handler_kwargs, chat_history=self.settings["chat_history"]