From 09a339e60fa1bc74e92d8d4cae598ca2326d288e Mon Sep 17 00:00:00 2001
From: Piyush Jain
Date: Wed, 8 Nov 2023 10:42:22 -0800
Subject: [PATCH] Model parameters option to pass in model tuning, arbitrary
 parameters (#430)

* Endpoint args for SM endpoints

* Added model and endpoints kwargs options.

* Added configurable option for model parameters.

* Updated magics, added model_parameters, removed model_kwargs and
  endpoint_kwargs.

* Fixes %ai error for SM endpoints.

* Fixed docs

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* 430 fixes (#2)

  * log configured model_parameters
  * fix markdown formatting in docs
  * fix single quotes and use preferred traitlets CLI syntax

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: david qiu
---
 docs/source/users/index.md                    | 118 +++++++++++++++++-
 .../jupyter_ai_magics/embedding_providers.py  |   3 +-
 .../jupyter_ai_magics/magics.py               |  19 ++-
 .../jupyter_ai_magics/parsers.py              |  42 +++++++
 .../jupyter_ai_magics/providers.py            |  13 +-
 .../jupyter_ai/chat_handlers/ask.py           |   3 +-
 .../jupyter_ai/chat_handlers/base.py          |  11 +-
 .../jupyter_ai/chat_handlers/default.py       |   3 +-
 .../jupyter_ai/chat_handlers/generate.py      |   4 +-
 packages/jupyter-ai/jupyter_ai/extension.py   |  17 ++-
 10 files changed, 208 insertions(+), 25 deletions(-)

diff --git a/docs/source/users/index.md b/docs/source/users/index.md
index b52fb66bb..fb52e29ba 100644
--- a/docs/source/users/index.md
+++ b/docs/source/users/index.md
@@ -855,30 +855,138 @@ The `--response-path` option is a [JSONPath](https://goessner.net/articles/JsonP
 ## Configuration
 
-You can specify an allowlist, to only allow only a certain list of providers, or a blocklist, to block some providers.
+You can specify an allowlist, to allow only a certain list of providers, or
+a blocklist, to block some providers.
 
 ### Blocklisting providers
-This configuration allows for blocking specific providers in the settings panel. This list takes precedence over the allowlist in the next section.
+
+This configuration allows for blocking specific providers in the settings panel.
+This list takes precedence over the allowlist in the next section.
 
 ```
 jupyter lab --AiExtension.blocked_providers=openai
 ```
 
-To block more than one provider in the block-list, repeat the runtime configuration.
+To block more than one provider in the blocklist, repeat the runtime
+configuration.
 
 ```
 jupyter lab --AiExtension.blocked_providers=openai --AiExtension.blocked_providers=ai21
 ```
 
 ### Allowlisting providers
-This configuration allows for filtering the list of providers in the settings panel to only an allowlisted set of providers.
+
+This configuration allows for filtering the list of providers in the settings
+panel to only an allowlisted set of providers.
 
 ```
 jupyter lab --AiExtension.allowed_providers=openai
 ```
 
-To allow more than one provider in the allowlist, repeat the runtime configuration.
+To allow more than one provider in the allowlist, repeat the runtime
+configuration.
 
 ```
 jupyter lab --AiExtension.allowed_providers=openai --AiExtension.allowed_providers=ai21
 ```
+
+### Model parameters
+
+This configuration allows specifying arbitrary parameters that are unpacked and
+passed to the provider class. This is useful for passing model tuning
+parameters that affect the responses generated by the model. It is also an
+appropriate place to pass in custom attributes required by certain
+providers/models.
+
+The accepted value is a dictionary whose top-level keys have the form
+`<provider_id>:<model_id>`; each value should be an arbitrary dictionary of
+parameters, which is unpacked and passed as-is to the provider class.
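+
+For illustration, the value has the following general shape (the ids and
+parameter names below are placeholders, not a real provider or model):
+
+```python
+# Hypothetical example: "<provider_id>:<model_id>" maps to keyword arguments
+# that are unpacked into the provider class constructor.
+model_parameters = {
+    "some-provider:some-model": {"some_parameter": "some-value"},
+}
+```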
+
+#### Configuring as a startup option
+
+In this sample, the `bedrock` provider will be created with the value for
+`model_kwargs` when the `ai21.j2-mid-v1` model is selected.
+
+```bash
+jupyter lab --AiExtension.model_parameters bedrock:ai21.j2-mid-v1='{"model_kwargs":{"maxTokens":200}}'
+```
+
+Note the single quotes surrounding the dictionary; they preserve the inner
+double quotes, which is required in some shells. The above results in the
+following LLM class being instantiated:
+
+```python
+BedrockProvider(model_kwargs={"maxTokens":200}, ...)
+```
+
+Here is another example, where the `anthropic` provider will be created with
+the values for `max_tokens` and `temperature` when the `claude-2` model is
+selected.
+
+```bash
+jupyter lab --AiExtension.model_parameters anthropic:claude-2='{"max_tokens":1024,"temperature":0.9}'
+```
+
+The above results in the following LLM class being instantiated:
+
+```python
+AnthropicProvider(max_tokens=1024, temperature=0.9, ...)
+```
+
+To pass multiple sets of model parameters for multiple models on the command
+line, you can append them as additional arguments to
+`--AiExtension.model_parameters`, as shown below.
+
+```bash
+jupyter lab \
+--AiExtension.model_parameters bedrock:ai21.j2-mid-v1='{"model_kwargs":{"maxTokens":200}}' \
+--AiExtension.model_parameters anthropic:claude-2='{"max_tokens":1024,"temperature":0.9}'
+```
+
+However, for more complex configuration, we highly recommend that you specify
+it in a dedicated configuration file, as described in the following section.
+
+#### Configuring as a config file
+
+This configuration can also be specified in a config file in JSON format. The
+file should be named `jupyter_jupyter_ai_config.json` and saved in one of the
+paths from which JupyterLab reads configuration. You can find these paths by
+running the `jupyter --paths` command and picking one of the paths listed
+under the `config` section.
+
+Here is an example of running the `jupyter --paths` command.
+
+```bash
+(jupyter-ai-lab4) ➜  jupyter --paths
+config:
+    /opt/anaconda3/envs/jupyter-ai-lab4/etc/jupyter
+    /Users/3coins/.jupyter
+    /Users/3coins/.local/etc/jupyter
+    /usr/3coins/etc/jupyter
+    /etc/jupyter
+data:
+    /opt/anaconda3/envs/jupyter-ai-lab4/share/jupyter
+    /Users/3coins/Library/Jupyter
+    /Users/3coins/.local/share/jupyter
+    /usr/local/share/jupyter
+    /usr/share/jupyter
+runtime:
+    /Users/3coins/Library/Jupyter/runtime
+```
+
+Here is an example of configuring the `bedrock` provider for the
+`ai21.j2-mid-v1` model.
+
+```json
+{
+    "AiExtension": {
+        "model_parameters": {
+            "bedrock:ai21.j2-mid-v1": {
+                "model_kwargs": {
+                    "maxTokens": 200
+                }
+            }
+        }
+    }
+}
+```
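+
+If you keep your Jupyter configuration in Python rather than JSON, the
+standard traitlets syntax should work as well. A minimal sketch, assuming a
+Python config file (e.g. `jupyter_jupyter_ai_config.py`, the Python
+counterpart of the JSON file above) in one of the same `config` paths:
+
+```python
+# jupyter_jupyter_ai_config.py
+# `c` is the config object that Jupyter injects into config files.
+c.AiExtension.model_parameters = {
+    "bedrock:ai21.j2-mid-v1": {
+        "model_kwargs": {"maxTokens": 200},
+    },
+}
+```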
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py
index 5fe522beb..d09907fb0 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py
@@ -1,4 +1,4 @@
-from typing import ClassVar, List, Type
+from typing import ClassVar, List
 
 from jupyter_ai_magics.providers import (
     AuthStrategy,
@@ -12,7 +12,6 @@
     HuggingFaceHubEmbeddings,
     OpenAIEmbeddings,
 )
-from langchain.embeddings.base import Embeddings
 from pydantic import BaseModel, Extra
 
 
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
index b5741a19b..8763566f4 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
@@ -392,9 +392,11 @@ def handle_error(self, args: ErrorArgs):
         prompt = f"Explain the following error:\n\n{last_error}"
 
         # Set CellArgs based on ErrorArgs
-        cell_args = CellArgs(
-            type="root", model_id=args.model_id, format=args.format, reset=False
-        )
+        values = args.dict()
+        values["type"] = "root"
+        values["reset"] = False
+        cell_args = CellArgs(**values)
+
         return self.run_ai_cell(cell_args, prompt)
 
     def _append_exchange_openai(self, prompt: str, output: str):
@@ -518,16 +520,9 @@ def run_ai_cell(self, args: CellArgs, prompt: str):
             provider_params["request_schema"] = args.request_schema
             provider_params["response_path"] = args.response_path
 
-        # Validate that the request schema is well-formed JSON
-        try:
-            json.loads(args.request_schema)
-        except json.JSONDecodeError as e:
-            raise ValueError(
-                "request-schema must be valid JSON. "
-                f"Error at line {e.lineno}, column {e.colno}: {e.msg}"
-            ) from None
+        model_parameters = json.loads(args.model_parameters)
 
-        provider = Provider(**provider_params)
+        provider = Provider(**provider_params, **model_parameters)
 
         # Apply a prompt template.
         prompt = provider.get_prompt_template(args.format).format(prompt=prompt)
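In effect, `run_ai_cell` now merges the JSON passed via `--model-parameters`
into the provider constructor. A condensed sketch of the new flow (the names
below are simplified stand-ins, not the actual method body):

```python
import json


def build_provider(Provider, provider_params: dict, model_parameters_json: str):
    # --model-parameters defaults to "{}" and is validated by a click
    # callback (see parsers.py below), so json.loads is expected to succeed.
    model_parameters = json.loads(model_parameters_json)
    # User-supplied parameters are unpacked alongside the parameters
    # collected by the magic itself.
    return Provider(**provider_params, **model_parameters)
```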
+) + class CellArgs(BaseModel): type: Literal["root"] = "root" model_id: str format: FORMAT_CHOICES_TYPE reset: bool + model_parameters: Optional[str] # The following parameters are required only for SageMaker models region_name: Optional[str] request_schema: Optional[str] @@ -49,6 +59,7 @@ class ErrorArgs(BaseModel): type: Literal["error"] = "error" model_id: str format: FORMAT_CHOICES_TYPE + model_parameters: Optional[str] # The following parameters are required only for SageMaker models region_name: Optional[str] request_schema: Optional[str] @@ -93,6 +104,19 @@ def get_help(self, ctx): click.echo(super().get_help(ctx)) +def verify_json_value(ctx, param, value): + if not value: + return value + try: + json.loads(value) + except json.JSONDecodeError as e: + raise ValueError( + f"{param.get_error_hint(ctx)} must be valid JSON. " + f"Error at line {e.lineno}, column {e.colno}: {e.msg}" + ) + return value + + @click.command() @click.argument("model_id") @click.option( @@ -120,6 +144,7 @@ def get_help(self, ctx): REQUEST_SCHEMA_LONG_OPTION, required=False, help=REQUEST_SCHEMA_HELP, + callback=verify_json_value, ) @click.option( RESPONSE_PATH_SHORT_OPTION, @@ -127,6 +152,14 @@ def get_help(self, ctx): required=False, help=RESPONSE_PATH_HELP, ) +@click.option( + MODEL_PARAMETERS_SHORT_OPTION, + MODEL_PARAMETERS_LONG_OPTION, + required=False, + help=MODEL_PARAMETERS_HELP, + callback=verify_json_value, + default="{}", +) def cell_magic_parser(**kwargs): """ Invokes a language model identified by MODEL_ID, with the prompt being @@ -166,6 +199,7 @@ def line_magic_parser(): REQUEST_SCHEMA_LONG_OPTION, required=False, help=REQUEST_SCHEMA_HELP, + callback=verify_json_value, ) @click.option( RESPONSE_PATH_SHORT_OPTION, @@ -173,6 +207,14 @@ def line_magic_parser(): required=False, help=RESPONSE_PATH_HELP, ) +@click.option( + MODEL_PARAMETERS_SHORT_OPTION, + MODEL_PARAMETERS_LONG_OPTION, + required=False, + help=MODEL_PARAMETERS_HELP, + callback=verify_json_value, + default="{}", +) def error_subparser(**kwargs): """ Explains the most recent error. 
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
index 9fdbffa7a..a17067921 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -5,7 +5,17 @@
 import io
 import json
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, ClassVar, Coroutine, Dict, List, Literal, Optional, Union
+from typing import (
+    Any,
+    ClassVar,
+    Coroutine,
+    Dict,
+    List,
+    Literal,
+    Mapping,
+    Optional,
+    Union,
+)
 
 from jsonpath_ng import parse
 from langchain.chat_models import (
@@ -621,6 +631,7 @@ def __init__(self, *args, **kwargs):
         content_handler = JsonContentHandler(
             request_schema=request_schema, response_path=response_path
         )
+
         super().__init__(*args, **kwargs, content_handler=content_handler)
 
     async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
index 2f3f1388a..e5c852051 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
@@ -36,7 +36,8 @@ def __init__(self, retriever, *args, **kwargs):
     def create_llm_chain(
         self, provider: Type[BaseProvider], provider_params: Dict[str, str]
     ):
-        self.llm = provider(**provider_params)
+        model_parameters = self.get_model_parameters(provider, provider_params)
+        self.llm = provider(**provider_params, **model_parameters)
         memory = ConversationBufferWindowMemory(
             memory_key="chat_history", return_messages=True, k=2
         )
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
index 7894dcfa5..5ffe65c7c 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -3,7 +3,7 @@
 import traceback
 
 # necessary to prevent circular import
-from typing import TYPE_CHECKING, Dict, Optional, Type
+from typing import TYPE_CHECKING, Any, Dict, Optional, Type
 from uuid import uuid4
 
 from jupyter_ai.config_manager import ConfigManager, Logger
@@ -23,10 +23,12 @@ def __init__(
         log: Logger,
         config_manager: ConfigManager,
         root_chat_handlers: Dict[str, "RootChatHandler"],
+        model_parameters: Dict[str, Dict],
     ):
         self.log = log
         self.config_manager = config_manager
         self._root_chat_handlers = root_chat_handlers
+        self.model_parameters = model_parameters
         self.parser = argparse.ArgumentParser()
         self.llm = None
         self.llm_params = None
@@ -122,6 +124,13 @@ def get_llm_chain(self):
             self.llm_params = lm_provider_params
         return self.llm_chain
 
+    def get_model_parameters(
+        self, provider: Type[BaseProvider], provider_params: Dict[str, str]
+    ):
+        return self.model_parameters.get(
+            f"{provider.id}:{provider_params['model_id']}", {}
+        )
+
     def create_llm_chain(
         self, provider: Type[BaseProvider], provider_params: Dict[str, str]
     ):
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
index 3638b6cfb..d329e05e2 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -40,7 +40,8 @@ def __init__(self, chat_history: List[ChatMessage], *args, **kwargs):
     def create_llm_chain(
         self, provider: Type[BaseProvider], provider_params: Dict[str, str]
     ):
-        llm = provider(**provider_params)
+        model_parameters = self.get_model_parameters(provider, provider_params)
+        llm = provider(**provider_params, **model_parameters)
 
         if llm.is_chat_provider:
             prompt_template = ChatPromptTemplate.from_messages(
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
index 5036c8dfb..cdddee92d 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
@@ -226,7 +226,9 @@ def __init__(self, root_dir: str, *args, **kwargs):
     def create_llm_chain(
         self, provider: Type[BaseProvider], provider_params: Dict[str, str]
    ):
-        llm = provider(**provider_params)
+        model_parameters = self.get_model_parameters(provider, provider_params)
+        llm = provider(**provider_params, **model_parameters)
+
         self.llm = llm
         return llm
 
diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py
index a2ecd5245..8ab8c0cc6 100644
--- a/packages/jupyter-ai/jupyter_ai/extension.py
+++ b/packages/jupyter-ai/jupyter_ai/extension.py
@@ -4,7 +4,7 @@
 from jupyter_ai.chat_handlers.learn import Retriever
 from jupyter_ai_magics.utils import get_em_providers, get_lm_providers
 from jupyter_server.extension.application import ExtensionApp
-from traitlets import List, Unicode
+from traitlets import Dict, List, Unicode
 
 from .chat_handlers import (
     AskChatHandler,
@@ -81,6 +81,17 @@ class AiExtension(ExtensionApp):
         config=True,
     )
 
+    model_parameters = Dict(
+        key_trait=Unicode(),
+        value_trait=Dict(),
+        default_value={},
+        help="""Key-value pairs for model id and corresponding parameters that
+        are passed to the provider class. The values are unpacked and passed to
+        the provider class as-is.""",
+        allow_none=True,
+        config=True,
+    )
+
     def initialize_settings(self):
         start = time.time()
 
@@ -96,6 +107,9 @@ def initialize_settings(self):
         self.log.info(f"Configured model allowlist: {self.allowed_models}")
         self.log.info(f"Configured model blocklist: {self.blocked_models}")
 
+        self.settings["model_parameters"] = self.model_parameters
+        self.log.info(f"Configured model parameters: {self.model_parameters}")
+
         # Fetch LM & EM providers
         self.settings["lm_providers"] = get_lm_providers(
             log=self.log, restrictions=restrictions
@@ -147,6 +161,7 @@ def initialize_settings(self):
             "log": self.log,
             "config_manager": self.settings["jai_config_manager"],
             "root_chat_handlers": self.settings["jai_root_chat_handlers"],
+            "model_parameters": self.settings["model_parameters"],
         }
         default_chat_handler = DefaultChatHandler(
             **chat_handler_kwargs, chat_history=self.settings["chat_history"]
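
Taken together, the handler-side changes give a simple lookup-and-unpack flow:
the extension stores the configured dictionary in its settings, each chat
handler receives it, and `get_model_parameters` selects the entry for the
active model. The following standalone sketch mirrors that flow (the class and
variable names are simplified stand-ins for the real classes above, not part
of the diff):

```python
from typing import Dict, Type


class StubProvider:
    """Stand-in for a jupyter-ai provider class such as BedrockProvider."""

    id = "bedrock"

    def __init__(self, **kwargs):
        self.kwargs = kwargs


def get_model_parameters(
    model_parameters: Dict[str, Dict],
    provider: Type[StubProvider],
    provider_params: Dict[str, str],
) -> Dict:
    # Same lookup key the handlers use: "<provider_id>:<model_id>".
    return model_parameters.get(f"{provider.id}:{provider_params['model_id']}", {})


# Mirrors a --AiExtension.model_parameters configuration.
configured = {"bedrock:ai21.j2-mid-v1": {"model_kwargs": {"maxTokens": 200}}}
provider_params = {"model_id": "ai21.j2-mid-v1"}

extra = get_model_parameters(configured, StubProvider, provider_params)
llm = StubProvider(**provider_params, **extra)
assert llm.kwargs["model_kwargs"] == {"maxTokens": 200}
```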