Skip to content

Commit

Permalink
Backport PR #430: Model parameters option to pass in model tuning, ar…
Browse files Browse the repository at this point in the history
…bitrary parameters (#453)

Co-authored-by: Piyush Jain <[email protected]>
  • Loading branch information
JasonWeill and 3coins authored Nov 8, 2023
1 parent 33831e9 commit bf22409
Show file tree
Hide file tree
Showing 9 changed files with 124 additions and 20 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import ClassVar, List, Type
from typing import ClassVar, List

from jupyter_ai_magics.providers import (
AuthStrategy,
Expand All @@ -12,7 +12,6 @@
HuggingFaceHubEmbeddings,
OpenAIEmbeddings,
)
from langchain.embeddings.base import Embeddings
from pydantic import BaseModel, Extra


Expand Down
19 changes: 7 additions & 12 deletions packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,9 +405,11 @@ def handle_error(self, args: ErrorArgs):

prompt = f"Explain the following error:\n\n{last_error}"
# Set CellArgs based on ErrorArgs
cell_args = CellArgs(
type="root", model_id=args.model_id, format=args.format, reset=False
)
values = args.dict()
values["type"] = "root"
values["reset"] = False
cell_args = CellArgs(**values)

return self.run_ai_cell(cell_args, prompt)

def _append_exchange_openai(self, prompt: str, output: str):
Expand Down Expand Up @@ -538,16 +540,9 @@ def run_ai_cell(self, args: CellArgs, prompt: str):
provider_params["request_schema"] = args.request_schema
provider_params["response_path"] = args.response_path

# Validate that the request schema is well-formed JSON
try:
json.loads(args.request_schema)
except json.JSONDecodeError as e:
raise ValueError(
"request-schema must be valid JSON. "
f"Error at line {e.lineno}, column {e.colno}: {e.msg}"
) from None
model_parameters = json.loads(args.model_parameters)

provider = Provider(**provider_params)
provider = Provider(**provider_params, **model_parameters)

# Apply a prompt template.
prompt = provider.get_prompt_template(args.format).format(prompt=prompt)
Expand Down
42 changes: 42 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from typing import Literal, Optional, get_args

import click
Expand Down Expand Up @@ -32,12 +33,21 @@
+ "does nothing with other providers."
)

# CLI flag names and user-facing help text for the model-parameters option,
# which lets users pass arbitrary keyword arguments through to the provider.
MODEL_PARAMETERS_SHORT_OPTION = "-m"
MODEL_PARAMETERS_LONG_OPTION = "--model-parameters"
MODEL_PARAMETERS_HELP = (
    "A JSON value that specifies extra values that will be passed "
    "to the model. The accepted value is parsed into a dict, unpacked, "
    "and passed as-is to the provider class."
)


class CellArgs(BaseModel):
type: Literal["root"] = "root"
model_id: str
format: FORMAT_CHOICES_TYPE
reset: bool
model_parameters: Optional[str]
# The following parameters are required only for SageMaker models
region_name: Optional[str]
request_schema: Optional[str]
Expand All @@ -49,6 +59,7 @@ class ErrorArgs(BaseModel):
type: Literal["error"] = "error"
model_id: str
format: FORMAT_CHOICES_TYPE
model_parameters: Optional[str]
# The following parameters are required only for SageMaker models
region_name: Optional[str]
request_schema: Optional[str]
Expand Down Expand Up @@ -93,6 +104,19 @@ def get_help(self, ctx):
click.echo(super().get_help(ctx))


def verify_json_value(ctx, param, value):
    """Click option callback: ensure *value*, when present, is valid JSON.

    Empty strings and ``None`` pass through untouched. Non-empty values are
    parse-checked with ``json.loads`` and returned as the original raw
    string; a parse failure raises ``ValueError`` with the option's error
    hint and the location of the problem.
    """
    # An absent/empty value means "nothing to validate" — hand it back as-is.
    if value:
        try:
            json.loads(value)
        except json.JSONDecodeError as e:
            raise ValueError(
                f"{param.get_error_hint(ctx)} must be valid JSON. "
                f"Error at line {e.lineno}, column {e.colno}: {e.msg}"
            )
    return value


@click.command()
@click.argument("model_id")
@click.option(
Expand Down Expand Up @@ -120,13 +144,22 @@ def get_help(self, ctx):
REQUEST_SCHEMA_LONG_OPTION,
required=False,
help=REQUEST_SCHEMA_HELP,
callback=verify_json_value,
)
@click.option(
RESPONSE_PATH_SHORT_OPTION,
RESPONSE_PATH_LONG_OPTION,
required=False,
help=RESPONSE_PATH_HELP,
)
@click.option(
MODEL_PARAMETERS_SHORT_OPTION,
MODEL_PARAMETERS_LONG_OPTION,
required=False,
help=MODEL_PARAMETERS_HELP,
callback=verify_json_value,
default="{}",
)
def cell_magic_parser(**kwargs):
"""
Invokes a language model identified by MODEL_ID, with the prompt being
Expand Down Expand Up @@ -166,13 +199,22 @@ def line_magic_parser():
REQUEST_SCHEMA_LONG_OPTION,
required=False,
help=REQUEST_SCHEMA_HELP,
callback=verify_json_value,
)
@click.option(
RESPONSE_PATH_SHORT_OPTION,
RESPONSE_PATH_LONG_OPTION,
required=False,
help=RESPONSE_PATH_HELP,
)
@click.option(
MODEL_PARAMETERS_SHORT_OPTION,
MODEL_PARAMETERS_LONG_OPTION,
required=False,
help=MODEL_PARAMETERS_HELP,
callback=verify_json_value,
default="{}",
)
def error_subparser(**kwargs):
"""
Explains the most recent error. Takes the same options (except -r) as
Expand Down
13 changes: 12 additions & 1 deletion packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,17 @@
import io
import json
from concurrent.futures import ThreadPoolExecutor
from typing import Any, ClassVar, Coroutine, Dict, List, Literal, Optional, Union
from typing import (
Any,
ClassVar,
Coroutine,
Dict,
List,
Literal,
Mapping,
Optional,
Union,
)

from jsonpath_ng import parse
from langchain.chat_models import (
Expand Down Expand Up @@ -621,6 +631,7 @@ def __init__(self, *args, **kwargs):
content_handler = JsonContentHandler(
request_schema=request_schema, response_path=response_path
)

super().__init__(*args, **kwargs, content_handler=content_handler)

async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
Expand Down
3 changes: 2 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def __init__(self, retriever, *args, **kwargs):
def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
self.llm = provider(**provider_params)
model_parameters = self.get_model_parameters(provider, provider_params)
self.llm = provider(**provider_params, **model_parameters)
memory = ConversationBufferWindowMemory(
memory_key="chat_history", return_messages=True, k=2
)
Expand Down
11 changes: 10 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import traceback

# necessary to prevent circular import
from typing import TYPE_CHECKING, Dict, Optional, Type
from typing import TYPE_CHECKING, Any, Dict, Optional, Type
from uuid import uuid4

from jupyter_ai.config_manager import ConfigManager, Logger
Expand All @@ -23,10 +23,12 @@ def __init__(
log: Logger,
config_manager: ConfigManager,
root_chat_handlers: Dict[str, "RootChatHandler"],
model_parameters: Dict[str, Dict],
):
self.log = log
self.config_manager = config_manager
self._root_chat_handlers = root_chat_handlers
self.model_parameters = model_parameters
self.parser = argparse.ArgumentParser()
self.llm = None
self.llm_params = None
Expand Down Expand Up @@ -122,6 +124,13 @@ def get_llm_chain(self):
self.llm_params = lm_provider_params
return self.llm_chain

def get_model_parameters(
    self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
    """Look up configured extra parameters for this provider/model pair.

    Keys of ``self.model_parameters`` are global model IDs of the form
    ``<provider-id>:<model-id>``; returns an empty dict when no entry
    exists for the requested model.
    """
    global_model_id = f"{provider.id}:{provider_params['model_id']}"
    return self.model_parameters.get(global_model_id, {})

def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
Expand Down
3 changes: 2 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def __init__(self, chat_history: List[ChatMessage], *args, **kwargs):
def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
llm = provider(**provider_params)
model_parameters = self.get_model_parameters(provider, provider_params)
llm = provider(**provider_params, **model_parameters)

if llm.is_chat_provider:
prompt_template = ChatPromptTemplate.from_messages(
Expand Down
4 changes: 3 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,9 @@ def __init__(self, root_dir: str, *args, **kwargs):
def create_llm_chain(
    self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
    """Instantiate the provider LLM, merging in any user-configured
    model parameters, and store it on the handler before returning it."""
    extra_params = self.get_model_parameters(provider, provider_params)
    self.llm = provider(**provider_params, **extra_params)
    return self.llm

Expand Down
46 changes: 45 additions & 1 deletion packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from jupyter_ai.chat_handlers.learn import Retriever
from jupyter_ai_magics.utils import get_em_providers, get_lm_providers
from jupyter_server.extension.application import ExtensionApp
from traitlets import List, Unicode
from traitlets import Dict, List, Unicode

from .chat_handlers import (
AskChatHandler,
Expand Down Expand Up @@ -53,13 +53,56 @@ class AiExtension(ExtensionApp):
config=True,
)

allowed_models = List(
Unicode(),
default_value=None,
help="""
Language models to allow, as a list of global model IDs in the format
`<provider>:<local-model-id>`. If `None`, all are allowed. Defaults to
`None`.
Note: Currently, if `allowed_providers` is also set, then this field is
ignored. This is subject to change in a future non-major release. Using
both traits is considered to be undefined behavior at this time.
""",
allow_none=True,
config=True,
)

blocked_models = List(
Unicode(),
default_value=None,
help="""
Language models to block, as a list of global model IDs in the format
`<provider>:<local-model-id>`. If `None`, none are blocked. Defaults to
`None`.
""",
allow_none=True,
config=True,
)

model_parameters = Dict(
key_trait=Unicode(),
value_trait=Dict(),
default_value={},
help="""Key-value pairs for model id and corresponding parameters that
are passed to the provider class. The values are unpacked and passed to
the provider class as-is.""",
allow_none=True,
config=True,
)

def initialize_settings(self):
start = time.time()
restrictions = {
"allowed_providers": self.allowed_providers,
"blocked_providers": self.blocked_providers,
}

self.settings["model_parameters"] = self.model_parameters
self.log.info(f"Configured model parameters: {self.model_parameters}")

# Fetch LM & EM providers
self.settings["lm_providers"] = get_lm_providers(
log=self.log, restrictions=restrictions
)
Expand Down Expand Up @@ -107,6 +150,7 @@ def initialize_settings(self):
"log": self.log,
"config_manager": self.settings["jai_config_manager"],
"root_chat_handlers": self.settings["jai_root_chat_handlers"],
"model_parameters": self.settings["model_parameters"],
}
default_chat_handler = DefaultChatHandler(
**chat_handler_kwargs, chat_history=self.settings["chat_history"]
Expand Down

0 comments on commit bf22409

Please sign in to comment.