diff --git a/spacy_llm/models/rest/bedrock/model.py b/spacy_llm/models/rest/bedrock/model.py index ea62b1b8..149e292a 100644 --- a/spacy_llm/models/rest/bedrock/model.py +++ b/spacy_llm/models/rest/bedrock/model.py @@ -13,6 +13,7 @@ class Models(str, Enum): TITAN_LITE = "amazon.titan-text-lite-v1" AI21_JURASSIC_ULTRA = "ai21.j2-ultra-v1" AI21_JURASSIC_MID = "ai21.j2-mid-v1" + COHERE_COMMAND = "cohere.command-text-v14" TITAN_PARAMS = ["maxTokenCount", "stopSequences", "temperature", "topP"] @@ -24,6 +25,7 @@ class Models(str, Enum): "presencePenalty", "frequencyPenalty", ] +COHERE_PARAMS = ["max_tokens", "temperature"] class Bedrock(REST): @@ -45,6 +47,8 @@ def __init__( config_params = TITAN_PARAMS if self._model_id in [Models.AI21_JURASSIC_ULTRA, Models.AI21_JURASSIC_MID]: config_params = AI21_JURASSIC_PARAMS + if self._model_id in [Models.COHERE_COMMAND]: + config_params = COHERE_PARAMS for i in config_params: self._config[i] = config[i] @@ -141,6 +145,10 @@ def _request(json_data: str) -> str: responses = json.loads(r["body"].read().decode())["completions"][0][ "data" ]["text"] + elif self._model_id in [Models.COHERE_COMMAND]: + responses = json.loads(r["body"].read().decode())["generations"][0][ + "text" + ] return responses @@ -151,7 +159,12 @@ def _request(json_data: str) -> str: {"inputText": prompt, "textGenerationConfig": self._config} ) ) - if self._model_id in [Models.AI21_JURASSIC_ULTRA, Models.AI21_JURASSIC_MID]: + elif self._model_id in [ + Models.AI21_JURASSIC_ULTRA, + Models.AI21_JURASSIC_MID, + ]: + responses = _request(json.dumps({"prompt": prompt, **self._config})) + elif self._model_id in [Models.COHERE_COMMAND]: responses = _request(json.dumps({"prompt": prompt, **self._config})) api_responses.append(responses) @@ -181,4 +194,5 @@ def get_model_names(self) -> Tuple[str, ...]: "amazon.titan-text-lite-v1", "ai21.j2-ultra-v1", "ai21.j2-mid-v1", + "cohere.command-text-v14", ) diff --git a/spacy_llm/models/rest/bedrock/registry.py b/spacy_llm/models/rest/bedrock/registry.py index ed9197b7..9ca8f958 100644 --- a/spacy_llm/models/rest/bedrock/registry.py +++ b/spacy_llm/models/rest/bedrock/registry.py @@ -31,6 +31,8 @@ def bedrock( presencePenalty=_DEFAULT_PRESENCE_PENALTY, frequencyPenalty=_DEFAULT_FREQUENCY_PENALTY, stop_sequences=_DEFAULT_STOP_SEQUENCES, + # Params for Cohere models + max_tokens=_DEFAULT_MAX_TOKEN_COUNT, ), max_tries: int = _DEFAULT_RETRIES, ) -> Callable[[Iterable[str]], Iterable[str]]: