@@ -25,7 +25,7 @@ class EmbeddingsProvider(Enum):
2525 HUGGINGFACETEI = "huggingfacetei"
2626 OLLAMA = "ollama"
2727 BEDROCK = "bedrock"
28-
28+ MISTRAL = "mistral"
2929
3030def get_env_variable (
3131 var_name : str , default_value : str = None , required : bool = False
@@ -176,6 +176,7 @@ async def dispatch(self, request, call_next):
176176OLLAMA_BASE_URL = get_env_variable ("OLLAMA_BASE_URL" , "http://ollama:11434" )
177177AWS_ACCESS_KEY_ID = get_env_variable ("AWS_ACCESS_KEY_ID" , "" )
178178AWS_SECRET_ACCESS_KEY = get_env_variable ("AWS_SECRET_ACCESS_KEY" , "" )
179+ MISTRAL_API_KEY = get_env_variable ("MISTRAL_API_KEY" , "" )
179180
180181## Embeddings
181182
@@ -226,6 +227,13 @@ def init_embeddings(provider, model):
226227 model_id = model ,
227228 region_name = AWS_DEFAULT_REGION ,
228229 )
230+ elif provider == EmbeddingsProvider .MISTRAL :
231+ from langchain_mistralai import MistralAIEmbeddings
232+
233+ return MistralAIEmbeddings (
234+ model = model ,
235+ api_key = MISTRAL_API_KEY ,
236+ )
229237 else :
230238 raise ValueError (f"Unsupported embeddings provider: { provider } " )
231239
@@ -253,6 +261,10 @@ def init_embeddings(provider, model):
253261 "EMBEDDINGS_MODEL" , "amazon.titan-embed-text-v1"
254262 )
255263 AWS_DEFAULT_REGION = get_env_variable ("AWS_DEFAULT_REGION" , "us-east-1" )
264+ elif EMBEDDINGS_PROVIDER == EmbeddingsProvider .MISTRAL :
265+ EMBEDDINGS_MODEL = get_env_variable (
266+ "EMBEDDINGS_MODEL" , "mistral-embed"
267+ )
256268else :
257269 raise ValueError (f"Unsupported embeddings provider: { EMBEDDINGS_PROVIDER } " )
258270
0 commit comments