Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Model parameters option to pass in model tuning, arbitrary parameters #430

Merged
merged 8 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 113 additions & 5 deletions docs/source/users/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -855,30 +855,138 @@ The `--response-path` option is a [JSONPath](https://goessner.net/articles/JsonP

## Configuration

You can specify an allowlist, to only allow only a certain list of providers, or a blocklist, to block some providers.
You can specify an allowlist, to only allow only a certain list of providers, or
a blocklist, to block some providers.

### Blocklisting providers
This configuration allows for blocking specific providers in the settings panel. This list takes precedence over the allowlist in the next section.

This configuration allows for blocking specific providers in the settings panel.
This list takes precedence over the allowlist in the next section.

```
jupyter lab --AiExtension.blocked_providers=openai
```

To block more than one provider in the block-list, repeat the runtime configuration.
To block more than one provider in the block-list, repeat the runtime
configuration.

```
jupyter lab --AiExtension.blocked_providers=openai --AiExtension.blocked_providers=ai21
```

### Allowlisting providers
This configuration allows for filtering the list of providers in the settings panel to only an allowlisted set of providers.

This configuration allows for filtering the list of providers in the settings
panel to only an allowlisted set of providers.

```
jupyter lab --AiExtension.allowed_providers=openai
```

To allow more than one provider in the allowlist, repeat the runtime configuration.
To allow more than one provider in the allowlist, repeat the runtime
configuration.

```
jupyter lab --AiExtension.allowed_providers=openai --AiExtension.allowed_providers=ai21
```

### Model parameters

This configuration allows specifying arbitrary parameters that are unpacked and
passed to the provider class. This is useful for passing parameters such as
model tuning that affect the response generation by the model. This is also an
appropriate place to pass in custom attributes required by certain
providers/models.

The accepted value is a dictionary, with top level keys as the model id
(provider:model_id), and value should be any arbitrary dictionary which is
unpacked and passed as-is to the provider class.

#### Configuring as a startup option

In this sample, the `bedrock` provider will be created with the value for
`model_kwargs` when `ai21.j2-mid-v1` model is selected.

```bash
jupyter lab --AiExtension.model_parameters bedrock:ai21.j2-mid-v1='{"model_kwargs":{"maxTokens":200}}'
```

Note the usage of single quotes surrounding the dictionary to escape the double
quotes. This is required in some shells. The above will result in the following
LLM class to be generated.

```python
BedrockProvider(model_kwargs={"maxTokens":200}, ...)
```

Here is another example, where `anthropic` provider will be created with the
values for `max_tokens` and `temperature`, when `claude-2` model is selected.


```bash
jupyter lab --AiExtension.model_parameters anthropic:claude-2='{"max_tokens":1024,"temperature":0.9}'
```

The above will result in the following LLM class to be generated.

```python
AnthropicProvider(max_tokens=1024, temperature=0.9, ...)
```

To pass multiple sets of model parameters for multiple models in the
command-line, you can append them as additional arguments to
`--AiExtension.model_parameters`, as shown below.

```bash
jupyter lab \
--AiExtension.model_parameters bedrock:ai21.j2-mid-v1='{"model_kwargs":{"maxTokens":200}}' \
--AiExtension.model_parameters anthropic:claude-2='{"max_tokens":1024,"temperature":0.9}'
```

However, for more complex configuration, we highly recommend that you specify
this in a dedicated configuration file. We will describe how to do so in the
following section.

#### Configuring as a config file

This configuration can also be specified in a config file in json format. The
file should be named `jupyter_jupyter_ai_config.json` and saved in a path that
JupyterLab can pick from. You can find this path by running `jupyter --paths`
command, and picking one of the paths from the `config` section.

Here is an example of running the `jupyter --paths` command.

```bash
(jupyter-ai-lab4) ➜ jupyter --paths
config:
/opt/anaconda3/envs/jupyter-ai-lab4/etc/jupyter
/Users/3coins/.jupyter
/Users/3coins/.local/etc/jupyter
/usr/3coins/etc/jupyter
/etc/jupyter
data:
/opt/anaconda3/envs/jupyter-ai-lab4/share/jupyter
/Users/3coins/Library/Jupyter
/Users/3coins/.local/share/jupyter
/usr/local/share/jupyter
/usr/share/jupyter
runtime:
/Users/3coins/Library/Jupyter/runtime
```

Here is an example for configuring the `bedrock` provider for `ai21.j2-mid-v1`
model.

```json
{
"AiExtension": {
"model_parameters": {
"bedrock:ai21.j2-mid-v1": {
"model_kwargs": {
"maxTokens": 200
}
}
}
}
}
```
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import ClassVar, List, Type
from typing import ClassVar, List

from jupyter_ai_magics.providers import (
AuthStrategy,
Expand All @@ -12,7 +12,6 @@
HuggingFaceHubEmbeddings,
OpenAIEmbeddings,
)
from langchain.embeddings.base import Embeddings
from pydantic import BaseModel, Extra


Expand Down
19 changes: 7 additions & 12 deletions packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,9 +392,11 @@ def handle_error(self, args: ErrorArgs):

prompt = f"Explain the following error:\n\n{last_error}"
# Set CellArgs based on ErrorArgs
cell_args = CellArgs(
type="root", model_id=args.model_id, format=args.format, reset=False
)
values = args.dict()
values["type"] = "root"
values["reset"] = False
cell_args = CellArgs(**values)

return self.run_ai_cell(cell_args, prompt)

def _append_exchange_openai(self, prompt: str, output: str):
Expand Down Expand Up @@ -518,16 +520,9 @@ def run_ai_cell(self, args: CellArgs, prompt: str):
provider_params["request_schema"] = args.request_schema
provider_params["response_path"] = args.response_path

# Validate that the request schema is well-formed JSON
try:
json.loads(args.request_schema)
except json.JSONDecodeError as e:
raise ValueError(
"request-schema must be valid JSON. "
f"Error at line {e.lineno}, column {e.colno}: {e.msg}"
) from None
model_parameters = json.loads(args.model_parameters)

provider = Provider(**provider_params)
provider = Provider(**provider_params, **model_parameters)

# Apply a prompt template.
prompt = provider.get_prompt_template(args.format).format(prompt=prompt)
Expand Down
42 changes: 42 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from typing import Literal, Optional, get_args

import click
Expand Down Expand Up @@ -32,12 +33,21 @@
+ "does nothing with other providers."
)

MODEL_PARAMETERS_SHORT_OPTION = "-m"
MODEL_PARAMETERS_LONG_OPTION = "--model-parameters"
MODEL_PARAMETERS_HELP = (
"A JSON value that specifies extra values that will be passed "
"to the model. The accepted value parsed to a dict, unpacked "
"and passed as-is to the provider class."
)


class CellArgs(BaseModel):
type: Literal["root"] = "root"
model_id: str
format: FORMAT_CHOICES_TYPE
reset: bool
model_parameters: Optional[str]
# The following parameters are required only for SageMaker models
region_name: Optional[str]
request_schema: Optional[str]
Expand All @@ -49,6 +59,7 @@ class ErrorArgs(BaseModel):
type: Literal["error"] = "error"
model_id: str
format: FORMAT_CHOICES_TYPE
model_parameters: Optional[str]
# The following parameters are required only for SageMaker models
region_name: Optional[str]
request_schema: Optional[str]
Expand Down Expand Up @@ -93,6 +104,19 @@ def get_help(self, ctx):
click.echo(super().get_help(ctx))


def verify_json_value(ctx, param, value):
if not value:
return value
try:
json.loads(value)
except json.JSONDecodeError as e:
raise ValueError(
f"{param.get_error_hint(ctx)} must be valid JSON. "
f"Error at line {e.lineno}, column {e.colno}: {e.msg}"
)
return value


@click.command()
@click.argument("model_id")
@click.option(
Expand Down Expand Up @@ -120,13 +144,22 @@ def get_help(self, ctx):
REQUEST_SCHEMA_LONG_OPTION,
required=False,
help=REQUEST_SCHEMA_HELP,
callback=verify_json_value,
)
@click.option(
RESPONSE_PATH_SHORT_OPTION,
RESPONSE_PATH_LONG_OPTION,
required=False,
help=RESPONSE_PATH_HELP,
)
@click.option(
MODEL_PARAMETERS_SHORT_OPTION,
MODEL_PARAMETERS_LONG_OPTION,
required=False,
help=MODEL_PARAMETERS_HELP,
callback=verify_json_value,
default="{}",
)
def cell_magic_parser(**kwargs):
"""
Invokes a language model identified by MODEL_ID, with the prompt being
Expand Down Expand Up @@ -166,13 +199,22 @@ def line_magic_parser():
REQUEST_SCHEMA_LONG_OPTION,
required=False,
help=REQUEST_SCHEMA_HELP,
callback=verify_json_value,
)
@click.option(
RESPONSE_PATH_SHORT_OPTION,
RESPONSE_PATH_LONG_OPTION,
required=False,
help=RESPONSE_PATH_HELP,
)
@click.option(
MODEL_PARAMETERS_SHORT_OPTION,
MODEL_PARAMETERS_LONG_OPTION,
required=False,
help=MODEL_PARAMETERS_HELP,
callback=verify_json_value,
default="{}",
)
def error_subparser(**kwargs):
"""
Explains the most recent error. Takes the same options (except -r) as
Expand Down
13 changes: 12 additions & 1 deletion packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,17 @@
import io
import json
from concurrent.futures import ThreadPoolExecutor
from typing import Any, ClassVar, Coroutine, Dict, List, Literal, Optional, Union
from typing import (
Any,
ClassVar,
Coroutine,
Dict,
List,
Literal,
Mapping,
Optional,
Union,
)

from jsonpath_ng import parse
from langchain.chat_models import (
Expand Down Expand Up @@ -621,6 +631,7 @@ def __init__(self, *args, **kwargs):
content_handler = JsonContentHandler(
request_schema=request_schema, response_path=response_path
)

super().__init__(*args, **kwargs, content_handler=content_handler)

async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
Expand Down
3 changes: 2 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def __init__(self, retriever, *args, **kwargs):
def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
self.llm = provider(**provider_params)
model_parameters = self.get_model_parameters(provider, provider_params)
self.llm = provider(**provider_params, **model_parameters)
memory = ConversationBufferWindowMemory(
memory_key="chat_history", return_messages=True, k=2
)
Expand Down
11 changes: 10 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import traceback

# necessary to prevent circular import
from typing import TYPE_CHECKING, Dict, Optional, Type
from typing import TYPE_CHECKING, Any, Dict, Optional, Type
from uuid import uuid4

from jupyter_ai.config_manager import ConfigManager, Logger
Expand All @@ -23,10 +23,12 @@ def __init__(
log: Logger,
config_manager: ConfigManager,
root_chat_handlers: Dict[str, "RootChatHandler"],
model_parameters: Dict[str, Dict],
):
self.log = log
self.config_manager = config_manager
self._root_chat_handlers = root_chat_handlers
self.model_parameters = model_parameters
self.parser = argparse.ArgumentParser()
self.llm = None
self.llm_params = None
Expand Down Expand Up @@ -122,6 +124,13 @@ def get_llm_chain(self):
self.llm_params = lm_provider_params
return self.llm_chain

def get_model_parameters(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
return self.model_parameters.get(
f"{provider.id}:{provider_params['model_id']}", {}
)

def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
Expand Down
3 changes: 2 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def __init__(self, chat_history: List[ChatMessage], *args, **kwargs):
def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
llm = provider(**provider_params)
model_parameters = self.get_model_parameters(provider, provider_params)
llm = provider(**provider_params, **model_parameters)

if llm.is_chat_provider:
prompt_template = ChatPromptTemplate.from_messages(
Expand Down
4 changes: 3 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,9 @@ def __init__(self, root_dir: str, *args, **kwargs):
def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
llm = provider(**provider_params)
model_parameters = self.get_model_parameters(provider, provider_params)
llm = provider(**provider_params, **model_parameters)

self.llm = llm
return llm

Expand Down
Loading
Loading