From 695dfa8df04ec5cd0e0796235e16c30cd5f13fde Mon Sep 17 00:00:00 2001 From: mdciri Date: Tue, 10 Dec 2024 16:09:00 +0100 Subject: [PATCH] Update embedding model --- apps/chatbot/.env.example | 3 ++- apps/chatbot/src/modules/models.py | 7 ++++--- apps/chatbot/src/modules/vector_database.py | 3 ++- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/apps/chatbot/.env.example b/apps/chatbot/.env.example index 5d1f2ed1ca..2ca70aad6c 100644 --- a/apps/chatbot/.env.example +++ b/apps/chatbot/.env.example @@ -5,7 +5,8 @@ AUTH_COGNITO_ISSUER=https://cognito-idp.eu-south-1.amazonaws.com/eu-south-1_xxxx AUTH_DISABLE_SIGNUP=false AUTH_DISABLE_USERNAME_PASSWORD=true CHB_AWS_ACCESS_KEY_ID=... -CHB_AWS_BEDROCK_REGION=eu-west-3 +CHB_AWS_BEDROCK_EMBED_REGION=eu-central-1 +CHB_AWS_BEDROCK_LLM_REGION=eu-west-3 CHB_AWS_DEFAULT_REGION=eu-south-1 CHB_AWS_GUARDRAIL_ID=... CHB_AWS_GUARDRAIL_VERSION=... diff --git a/apps/chatbot/src/modules/models.py b/apps/chatbot/src/modules/models.py index d9734cd7de..41abfc642e 100644 --- a/apps/chatbot/src/modules/models.py +++ b/apps/chatbot/src/modules/models.py @@ -21,7 +21,8 @@ GOOGLE_API_KEY = get_ssm_parameter(name=os.getenv("CHB_GOOGLE_API_KEY")) AWS_ACCESS_KEY_ID = os.getenv("CHB_AWS_ACCESS_KEY_ID") AWS_SECRET_ACCESS_KEY = os.getenv("CHB_AWS_SECRET_ACCESS_KEY") -AWS_BEDROCK_REGION = os.getenv("CHB_AWS_BEDROCK_REGION") +AWS_BEDROCK_LLM_REGION = os.getenv("CHB_AWS_BEDROCK_LLM_REGION") +AWS_BEDROCK_EMBED_REGION = os.getenv("CHB_AWS_BEDROCK_EMBED_REGION") AWS_GUARDRAIL_ID = os.getenv("CHB_AWS_GUARDRAIL_ID") AWS_GUARDRAIL_VERSION = os.getenv("CHB_AWS_GUARDRAIL_VERSION") @@ -41,7 +42,7 @@ def get_llm(): max_tokens=int(MODEL_MAXTOKENS), aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY, - region_name=AWS_BEDROCK_REGION + region_name=AWS_BEDROCK_LLM_REGION ) else: @@ -71,7 +72,7 @@ def get_embed_model(): model_name = EMBED_MODEL_ID, aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY, - region_name=AWS_BEDROCK_REGION + region_name=AWS_BEDROCK_EMBED_REGION ) else: embed_model = GeminiEmbedding( diff --git a/apps/chatbot/src/modules/vector_database.py b/apps/chatbot/src/modules/vector_database.py index 8ce403b118..c4995b26b5 100644 --- a/apps/chatbot/src/modules/vector_database.py +++ b/apps/chatbot/src/modules/vector_database.py @@ -55,7 +55,8 @@ EMBED_MODEL_ID = os.getenv("CHB_EMBED_MODEL_ID") EMBEDDING_DIMS = { "models/text-embedding-004": 768, - "cohere.embed-multilingual-v3": 1024 + "cohere.embed-multilingual-v3": 1024, + "amazon.titan-embed-text-v2:0": 1024 } REDIS_SCHEMA = IndexSchema.from_dict({ "index": {"name": f"{INDEX_ID}", "prefix": f"{INDEX_ID}/vector"},