Skip to content

Commit

Permalink
Model parameters option to pass in model tuning, arbitrary parameters (
Browse files Browse the repository at this point in the history
…jupyterlab#430)

* Endpoint args for SM endpoints

* Added model and endpoints kwargs options.

* Added configurable option for model parameters.

* Updated magics, added model_parameters, removed model_kwargs and endpoint_kwargs.

* Fixes %ai error for SM endpoints.

* Fixed docs

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* 430 fixes (jupyterlab#2)

* log configured model_parameters

* fix markdown formatting in docs

* fix single quotes and use preferred traitlets CLI syntax

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: david qiu <[email protected]>
  • Loading branch information
3 people authored Nov 8, 2023
1 parent 4ce0607 commit 09a339e
Show file tree
Hide file tree
Showing 10 changed files with 208 additions and 25 deletions.
118 changes: 113 additions & 5 deletions docs/source/users/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -855,30 +855,138 @@ The `--response-path` option is a [JSONPath](https://goessner.net/articles/JsonP

## Configuration

You can specify an allowlist, to only allow only a certain list of providers, or a blocklist, to block some providers.
You can specify an allowlist, to only allow only a certain list of providers, or
a blocklist, to block some providers.

### Blocklisting providers
This configuration allows for blocking specific providers in the settings panel. This list takes precedence over the allowlist in the next section.

This configuration allows for blocking specific providers in the settings panel.
This list takes precedence over the allowlist in the next section.

```
jupyter lab --AiExtension.blocked_providers=openai
```

To block more than one provider in the block-list, repeat the runtime configuration.
To block more than one provider in the block-list, repeat the runtime
configuration.

```
jupyter lab --AiExtension.blocked_providers=openai --AiExtension.blocked_providers=ai21
```

### Allowlisting providers
This configuration allows for filtering the list of providers in the settings panel to only an allowlisted set of providers.

This configuration allows for filtering the list of providers in the settings
panel to only an allowlisted set of providers.

```
jupyter lab --AiExtension.allowed_providers=openai
```

To allow more than one provider in the allowlist, repeat the runtime configuration.
To allow more than one provider in the allowlist, repeat the runtime
configuration.

```
jupyter lab --AiExtension.allowed_providers=openai --AiExtension.allowed_providers=ai21
```

### Model parameters

This configuration allows specifying arbitrary parameters that are unpacked and
passed to the provider class. This is useful for passing parameters such as
model tuning that affect the response generation by the model. This is also an
appropriate place to pass in custom attributes required by certain
providers/models.

The accepted value is a dictionary, with top level keys as the model id
(provider:model_id), and value should be any arbitrary dictionary which is
unpacked and passed as-is to the provider class.

#### Configuring as a startup option

In this sample, the `bedrock` provider will be created with the value for
`model_kwargs` when `ai21.j2-mid-v1` model is selected.

```bash
jupyter lab --AiExtension.model_parameters bedrock:ai21.j2-mid-v1='{"model_kwargs":{"maxTokens":200}}'
```

Note the usage of single quotes surrounding the dictionary to escape the double
quotes. This is required in some shells. The above will result in the following
LLM class to be generated.

```python
BedrockProvider(model_kwargs={"maxTokens":200}, ...)
```

Here is another example, where `anthropic` provider will be created with the
values for `max_tokens` and `temperature`, when `claude-2` model is selected.


```bash
jupyter lab --AiExtension.model_parameters anthropic:claude-2='{"max_tokens":1024,"temperature":0.9}'
```

The above will result in the following LLM class to be generated.

```python
AnthropicProvider(max_tokens=1024, temperature=0.9, ...)
```

To pass multiple sets of model parameters for multiple models in the
command-line, you can append them as additional arguments to
`--AiExtension.model_parameters`, as shown below.

```bash
jupyter lab \
--AiExtension.model_parameters bedrock:ai21.j2-mid-v1='{"model_kwargs":{"maxTokens":200}}' \
--AiExtension.model_parameters anthropic:claude-2='{"max_tokens":1024,"temperature":0.9}'
```

However, for more complex configuration, we highly recommend that you specify
this in a dedicated configuration file. We will describe how to do so in the
following section.

#### Configuring as a config file

This configuration can also be specified in a config file in json format. The
file should be named `jupyter_jupyter_ai_config.json` and saved in a path that
JupyterLab can pick from. You can find this path by running `jupyter --paths`
command, and picking one of the paths from the `config` section.

Here is an example of running the `jupyter --paths` command.

```bash
(jupyter-ai-lab4) ➜ jupyter --paths
config:
/opt/anaconda3/envs/jupyter-ai-lab4/etc/jupyter
/Users/3coins/.jupyter
/Users/3coins/.local/etc/jupyter
/usr/3coins/etc/jupyter
/etc/jupyter
data:
/opt/anaconda3/envs/jupyter-ai-lab4/share/jupyter
/Users/3coins/Library/Jupyter
/Users/3coins/.local/share/jupyter
/usr/local/share/jupyter
/usr/share/jupyter
runtime:
/Users/3coins/Library/Jupyter/runtime
```

Here is an example for configuring the `bedrock` provider for `ai21.j2-mid-v1`
model.

```json
{
"AiExtension": {
"model_parameters": {
"bedrock:ai21.j2-mid-v1": {
"model_kwargs": {
"maxTokens": 200
}
}
}
}
}
```
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import ClassVar, List, Type
from typing import ClassVar, List

from jupyter_ai_magics.providers import (
AuthStrategy,
Expand All @@ -12,7 +12,6 @@
HuggingFaceHubEmbeddings,
OpenAIEmbeddings,
)
from langchain.embeddings.base import Embeddings
from pydantic import BaseModel, Extra


Expand Down
19 changes: 7 additions & 12 deletions packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,9 +392,11 @@ def handle_error(self, args: ErrorArgs):

prompt = f"Explain the following error:\n\n{last_error}"
# Set CellArgs based on ErrorArgs
cell_args = CellArgs(
type="root", model_id=args.model_id, format=args.format, reset=False
)
values = args.dict()
values["type"] = "root"
values["reset"] = False
cell_args = CellArgs(**values)

return self.run_ai_cell(cell_args, prompt)

def _append_exchange_openai(self, prompt: str, output: str):
Expand Down Expand Up @@ -518,16 +520,9 @@ def run_ai_cell(self, args: CellArgs, prompt: str):
provider_params["request_schema"] = args.request_schema
provider_params["response_path"] = args.response_path

# Validate that the request schema is well-formed JSON
try:
json.loads(args.request_schema)
except json.JSONDecodeError as e:
raise ValueError(
"request-schema must be valid JSON. "
f"Error at line {e.lineno}, column {e.colno}: {e.msg}"
) from None
model_parameters = json.loads(args.model_parameters)

provider = Provider(**provider_params)
provider = Provider(**provider_params, **model_parameters)

# Apply a prompt template.
prompt = provider.get_prompt_template(args.format).format(prompt=prompt)
Expand Down
42 changes: 42 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from typing import Literal, Optional, get_args

import click
Expand Down Expand Up @@ -32,12 +33,21 @@
+ "does nothing with other providers."
)

MODEL_PARAMETERS_SHORT_OPTION = "-m"
MODEL_PARAMETERS_LONG_OPTION = "--model-parameters"
MODEL_PARAMETERS_HELP = (
"A JSON value that specifies extra values that will be passed "
"to the model. The accepted value parsed to a dict, unpacked "
"and passed as-is to the provider class."
)


class CellArgs(BaseModel):
type: Literal["root"] = "root"
model_id: str
format: FORMAT_CHOICES_TYPE
reset: bool
model_parameters: Optional[str]
# The following parameters are required only for SageMaker models
region_name: Optional[str]
request_schema: Optional[str]
Expand All @@ -49,6 +59,7 @@ class ErrorArgs(BaseModel):
type: Literal["error"] = "error"
model_id: str
format: FORMAT_CHOICES_TYPE
model_parameters: Optional[str]
# The following parameters are required only for SageMaker models
region_name: Optional[str]
request_schema: Optional[str]
Expand Down Expand Up @@ -93,6 +104,19 @@ def get_help(self, ctx):
click.echo(super().get_help(ctx))


def verify_json_value(ctx, param, value):
if not value:
return value
try:
json.loads(value)
except json.JSONDecodeError as e:
raise ValueError(
f"{param.get_error_hint(ctx)} must be valid JSON. "
f"Error at line {e.lineno}, column {e.colno}: {e.msg}"
)
return value


@click.command()
@click.argument("model_id")
@click.option(
Expand Down Expand Up @@ -120,13 +144,22 @@ def get_help(self, ctx):
REQUEST_SCHEMA_LONG_OPTION,
required=False,
help=REQUEST_SCHEMA_HELP,
callback=verify_json_value,
)
@click.option(
RESPONSE_PATH_SHORT_OPTION,
RESPONSE_PATH_LONG_OPTION,
required=False,
help=RESPONSE_PATH_HELP,
)
@click.option(
MODEL_PARAMETERS_SHORT_OPTION,
MODEL_PARAMETERS_LONG_OPTION,
required=False,
help=MODEL_PARAMETERS_HELP,
callback=verify_json_value,
default="{}",
)
def cell_magic_parser(**kwargs):
"""
Invokes a language model identified by MODEL_ID, with the prompt being
Expand Down Expand Up @@ -166,13 +199,22 @@ def line_magic_parser():
REQUEST_SCHEMA_LONG_OPTION,
required=False,
help=REQUEST_SCHEMA_HELP,
callback=verify_json_value,
)
@click.option(
RESPONSE_PATH_SHORT_OPTION,
RESPONSE_PATH_LONG_OPTION,
required=False,
help=RESPONSE_PATH_HELP,
)
@click.option(
MODEL_PARAMETERS_SHORT_OPTION,
MODEL_PARAMETERS_LONG_OPTION,
required=False,
help=MODEL_PARAMETERS_HELP,
callback=verify_json_value,
default="{}",
)
def error_subparser(**kwargs):
"""
Explains the most recent error. Takes the same options (except -r) as
Expand Down
13 changes: 12 additions & 1 deletion packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,17 @@
import io
import json
from concurrent.futures import ThreadPoolExecutor
from typing import Any, ClassVar, Coroutine, Dict, List, Literal, Optional, Union
from typing import (
Any,
ClassVar,
Coroutine,
Dict,
List,
Literal,
Mapping,
Optional,
Union,
)

from jsonpath_ng import parse
from langchain.chat_models import (
Expand Down Expand Up @@ -621,6 +631,7 @@ def __init__(self, *args, **kwargs):
content_handler = JsonContentHandler(
request_schema=request_schema, response_path=response_path
)

super().__init__(*args, **kwargs, content_handler=content_handler)

async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
Expand Down
3 changes: 2 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def __init__(self, retriever, *args, **kwargs):
def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
self.llm = provider(**provider_params)
model_parameters = self.get_model_parameters(provider, provider_params)
self.llm = provider(**provider_params, **model_parameters)
memory = ConversationBufferWindowMemory(
memory_key="chat_history", return_messages=True, k=2
)
Expand Down
11 changes: 10 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import traceback

# necessary to prevent circular import
from typing import TYPE_CHECKING, Dict, Optional, Type
from typing import TYPE_CHECKING, Any, Dict, Optional, Type
from uuid import uuid4

from jupyter_ai.config_manager import ConfigManager, Logger
Expand All @@ -23,10 +23,12 @@ def __init__(
log: Logger,
config_manager: ConfigManager,
root_chat_handlers: Dict[str, "RootChatHandler"],
model_parameters: Dict[str, Dict],
):
self.log = log
self.config_manager = config_manager
self._root_chat_handlers = root_chat_handlers
self.model_parameters = model_parameters
self.parser = argparse.ArgumentParser()
self.llm = None
self.llm_params = None
Expand Down Expand Up @@ -122,6 +124,13 @@ def get_llm_chain(self):
self.llm_params = lm_provider_params
return self.llm_chain

def get_model_parameters(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
return self.model_parameters.get(
f"{provider.id}:{provider_params['model_id']}", {}
)

def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
Expand Down
3 changes: 2 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def __init__(self, chat_history: List[ChatMessage], *args, **kwargs):
def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
llm = provider(**provider_params)
model_parameters = self.get_model_parameters(provider, provider_params)
llm = provider(**provider_params, **model_parameters)

if llm.is_chat_provider:
prompt_template = ChatPromptTemplate.from_messages(
Expand Down
4 changes: 3 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,9 @@ def __init__(self, root_dir: str, *args, **kwargs):
def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
llm = provider(**provider_params)
model_parameters = self.get_model_parameters(provider, provider_params)
llm = provider(**provider_params, **model_parameters)

self.llm = llm
return llm

Expand Down
Loading

0 comments on commit 09a339e

Please sign in to comment.