Skip to content

Commit

Permalink
implement model allow/blocklists in config manager
Browse files Browse the repository at this point in the history
  • Loading branch information
dlqqq committed Nov 8, 2023
1 parent d4ef987 commit 3bae8cf
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 38 deletions.
110 changes: 90 additions & 20 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import shutil
import time
from typing import Optional, Union
from typing import List, Optional, Union

from deepmerge import always_merger as Merger
from jsonschema import Draft202012Validator as Validator
Expand All @@ -12,10 +12,8 @@
AnyProvider,
EmProvidersDict,
LmProvidersDict,
ProviderRestrictions,
get_em_provider,
get_lm_provider,
is_provider_allowed,
)
from jupyter_core.paths import jupyter_data_dir
from traitlets import Integer, Unicode
Expand Down Expand Up @@ -57,6 +55,10 @@ class KeyEmptyError(Exception):
pass


class BlockedModelError(Exception):
    """Raised when a model is rejected by the configured allow/blocklists."""


def _validate_provider_authn(config: GlobalConfig, provider: AnyProvider):
# TODO: handle non-env auth strategies
if not provider.auth_strategy or provider.auth_strategy.type != "env":
Expand Down Expand Up @@ -99,27 +101,34 @@ def __init__(
log: Logger,
lm_providers: LmProvidersDict,
em_providers: EmProvidersDict,
restrictions: ProviderRestrictions,
allowed_providers: Optional[List[str]],
blocked_providers: Optional[List[str]],
allowed_models: Optional[List[str]],
blocked_models: Optional[List[str]],
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.log = log
"""List of LM providers."""

self._lm_providers = lm_providers
"""List of EM providers."""
"""List of LM providers."""
self._em_providers = em_providers
"""Provider restrictions."""
self._restrictions = restrictions
"""List of EM providers."""

self._allowed_providers = allowed_providers
self._blocked_providers = blocked_providers
self._allowed_models = allowed_models
self._blocked_models = blocked_models

self._last_read: Optional[int] = None
"""When the server last read the config file. If the file was not
modified after this time, then we can return the cached
`self._config`."""
self._last_read: Optional[int] = None

self._config: Optional[GlobalConfig] = None
"""In-memory cache of the `GlobalConfig` object parsed from the config
file."""
self._config: Optional[GlobalConfig] = None

self._init_config_schema()
self._init_validator()
Expand All @@ -140,6 +149,26 @@ def _init_config(self):
if os.path.exists(self.config_path):
with open(self.config_path, encoding="utf-8") as f:
config = GlobalConfig(**json.loads(f.read()))
lm_id = config.model_provider_id
em_id = config.embeddings_provider_id

# if the currently selected language or embedding model are
# forbidden, set them to `None` and log a warning.
if lm_id is not None and not self._validate_model(
lm_id, raise_exc=False
):
self.log.warning(
f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None."
)
config.model_provider_id = None
if em_id is not None and not self._validate_model(
em_id, raise_exc=False
):
self.log.warning(
f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None."
)
config.embeddings_provider_id = None

# re-write to the file to validate the config and apply any
# updates to the config file immediately
self._write_config(config)
Expand Down Expand Up @@ -181,33 +210,74 @@ def _validate_config(self, config: GlobalConfig):
_, lm_provider = get_lm_provider(
config.model_provider_id, self._lm_providers
)
# do not check config for blocked providers
if not is_provider_allowed(config.model_provider_id, self._restrictions):
assert not lm_provider
return

# verify model is declared by some provider
if not lm_provider:
raise ValueError(
f"No language model is associated with '{config.model_provider_id}'."
)

# verify model is not blocked
self._validate_model(config.model_provider_id)

# verify model is authenticated
_validate_provider_authn(config, lm_provider)

# validate embedding model config
if config.embeddings_provider_id:
_, em_provider = get_em_provider(
config.embeddings_provider_id, self._em_providers
)
# do not check config for blocked providers
if not is_provider_allowed(
config.embeddings_provider_id, self._restrictions
):
assert not em_provider
return

# verify model is declared by some provider
if not em_provider:
raise ValueError(
f"No embedding model is associated with '{config.embeddings_provider_id}'."
)

# verify model is not blocked
self._validate_model(config.embeddings_provider_id)

# verify model is authenticated
_validate_provider_authn(config, em_provider)

def _validate_model(self, model_id: str, raise_exc=True):
    """
    Validates a model against the provider/model allowlists and blocklists
    stored on this manager.

    The provider ID is the part of `model_id` before the first colon. The
    checks run in this order, and the first failing check decides the error
    message: provider allowlist, provider blocklist, model allowlist, model
    blocklist. A list that is `None` or empty imposes no restriction.

    Args:
        model_id: Global model ID in the format
            `<provider>:<local-model-id>`.
        raise_exc: If `True` (the default), a forbidden model raises
            `BlockedModelError`. If `False`, `False` is returned instead.

    Returns:
        `True` if the model is allowed; `False` if it is not allowed and
        `raise_exc=False`.

    Raises:
        BlockedModelError: If the model is not allowed and `raise_exc=True`.
        ValueError: If `model_id` is `None` or contains no `:` separator.
    """
    # Validate explicitly rather than with `assert`: assertions are removed
    # when Python runs with `-O`, which would let a malformed ID fail later
    # with an unrelated error.
    if model_id is None:
        raise ValueError("Model ID must not be None.")

    provider_id, separator, _ = model_id.partition(":")
    if not separator:
        raise ValueError(
            f"Model ID '{model_id}' is not in the format "
            "`<provider>:<local-model-id>`."
        )

    # Determine why the model is forbidden, if it is. The order of these
    # checks decides which message is reported when several lists apply.
    if self._allowed_providers and provider_id not in self._allowed_providers:
        reason = "Model provider not included in the provider allowlist."
    elif self._blocked_providers and provider_id in self._blocked_providers:
        reason = "Model provider included in the provider blocklist."
    elif self._allowed_models and model_id not in self._allowed_models:
        reason = "Model not included in the model allowlist."
    elif self._blocked_models and model_id in self._blocked_models:
        reason = "Model included in the model blocklist."
    else:
        return True

    if raise_exc:
        raise BlockedModelError(reason)
    return False

def _write_config(self, new_config: GlobalConfig):
"""Updates configuration and persists it to disk. This accepts a
complete `GlobalConfig` object, and should not be called publicly."""
Expand Down
21 changes: 18 additions & 3 deletions packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,27 @@ class AiExtension(ExtensionApp):
allowed_models = List(
Unicode(),
default_value=None,
help="Language models to allow, as a list of global model IDs in the format `<provider>:<local-model-id>`. If `None`, all are allowed. Defaults to `None`.",
help="""
Language models to allow, as a list of global model IDs in the format
`<provider>:<local-model-id>`. If `None`, all are allowed. Defaults to
`None`.
Note: Currently, if `allowed_providers` is also set, then this field is
ignored. This is subject to change in a future non-major release. Using
both traits is considered to be undefined behavior at this time.
""",
allow_none=True,
config=True,
)

blocked_models = List(
Unicode(),
default_value=None,
help="Language models to block, as a list of global model IDs in the format `<provider>:<local-model-id>`. If `None`, none are blocked. Defaults to `None`.",
help="""
Language models to block, as a list of global model IDs in the format
`<provider>:<local-model-id>`. If `None`, none are blocked. Defaults to
`None`.
""",
allow_none=True,
config=True,
)
Expand Down Expand Up @@ -98,7 +110,10 @@ def initialize_settings(self):
log=self.log,
lm_providers=self.settings["lm_providers"],
em_providers=self.settings["em_providers"],
restrictions=restrictions,
allowed_providers=self.allowed_providers,
blocked_providers=self.blocked_providers,
allowed_models=self.allowed_models,
blocked_models=self.blocked_models,
)

self.log.info("Registered providers.")
Expand Down
73 changes: 58 additions & 15 deletions packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ def common_cm_kwargs(config_path, schema_path):
"em_providers": em_providers,
"config_path": config_path,
"schema_path": schema_path,
"restrictions": {"allowed_providers": None, "blocked_providers": None},
"allowed_providers": None,
"blocked_providers": None,
"allowed_models": None,
"blocked_models": None,
}


Expand All @@ -46,6 +49,26 @@ def cm(common_cm_kwargs):
return ConfigManager(**common_cm_kwargs)


@pytest.fixture
def cm_with_blocklists(common_cm_kwargs):
    """Fixture returning a `ConfigManager` built from the common kwargs, with
    the `ai21` provider and the `cohere:medium` model placed on the
    blocklists."""
    blocklist_overrides = {
        "blocked_providers": ["ai21"],
        "blocked_models": ["cohere:medium"],
    }
    return ConfigManager(**{**common_cm_kwargs, **blocklist_overrides})


@pytest.fixture
def cm_with_allowlists(common_cm_kwargs):
    """Fixture returning a `ConfigManager` built from the common kwargs, with
    the provider allowlist set to `ai21` and the model allowlist set to
    `cohere:medium`."""
    allowlist_overrides = {
        "allowed_providers": ["ai21"],
        "allowed_models": ["cohere:medium"],
    }
    return ConfigManager(**{**common_cm_kwargs, **allowlist_overrides})


@pytest.fixture(autouse=True)
def reset(config_path, schema_path):
"""Fixture that deletes the config and config schema after each test."""
Expand Down Expand Up @@ -98,23 +121,43 @@ def test_snapshot_default_config(cm: ConfigManager, snapshot):
assert config_from_cm == snapshot(exclude=lambda prop, path: prop == "last_read")


def test_init_with_existing_config(
cm: ConfigManager, config_path: str, schema_path: str
):
def test_init_with_existing_config(cm: ConfigManager, common_cm_kwargs):
configure_to_cohere(cm)
del cm

log = logging.getLogger()
lm_providers = get_lm_providers()
em_providers = get_em_providers()
ConfigManager(
log=log,
lm_providers=lm_providers,
em_providers=em_providers,
config_path=config_path,
schema_path=schema_path,
restrictions={"allowed_providers": None, "blocked_providers": None},
)
ConfigManager(**common_cm_kwargs)


def test_init_with_blocklists(cm: ConfigManager, common_cm_kwargs):
    """A `ConfigManager` initialized with blocklists over an existing config
    should store those blocklists and report no selected language or
    embedding model when the previously selected ones are blocked."""
    # Write a config with the first manager, then drop the local reference
    # before a second manager is built with the same kwargs.
    # NOTE(review): presumably `configure_to_openai` selects OpenAI models for
    # both the LM and the EM; confirm against the helper's definition.
    configure_to_openai(cm)
    del cm

    blocked_providers = ["openai"]  # blocks EM
    blocked_models = ["openai-chat-new:gpt-3.5-turbo"]  # blocks LM
    kwargs = {
        **common_cm_kwargs,
        "blocked_providers": blocked_providers,
        "blocked_models": blocked_models,
    }
    test_cm = ConfigManager(**kwargs)
    assert test_cm._blocked_providers == blocked_providers
    assert test_cm._blocked_models == blocked_models
    # Compare to the `None` singleton by identity, not equality (PEP 8).
    assert test_cm.lm_gid is None
    assert test_cm.em_gid is None


def test_init_with_allowlists(cm: ConfigManager, common_cm_kwargs):
    """A `ConfigManager` initialized with a provider allowlist over an
    existing config should store that allowlist and report no selected
    language or embedding model when the previously selected ones fall
    outside it."""
    # Write a config with the first manager, then drop the local reference
    # before a second manager is built with the same kwargs.
    # NOTE(review): presumably `configure_to_cohere` selects Cohere models;
    # confirm against the helper's definition.
    configure_to_cohere(cm)
    del cm

    allowed_providers = ["openai"]  # blocks both LM & EM

    kwargs = {**common_cm_kwargs, "allowed_providers": allowed_providers}
    test_cm = ConfigManager(**kwargs)
    assert test_cm._allowed_providers == allowed_providers
    # Compare to the `None` singleton by identity, not equality (PEP 8).
    assert test_cm._allowed_models is None
    assert test_cm.lm_gid is None
    assert test_cm.em_gid is None


def test_property_access_on_default_config(cm: ConfigManager):
Expand Down

0 comments on commit 3bae8cf

Please sign in to comment.