Skip to content

Commit 601e857

Browse files
committed
rebase
1 parent 8ae9896 commit 601e857

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

app/config.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class EmbeddingsProvider(Enum):
2525
HUGGINGFACETEI = "huggingfacetei"
2626
OLLAMA = "ollama"
2727
BEDROCK = "bedrock"
28-
28+
MISTRAL = "mistral"
2929

3030
def get_env_variable(
3131
var_name: str, default_value: str = None, required: bool = False
@@ -176,6 +176,7 @@ async def dispatch(self, request, call_next):
176176
OLLAMA_BASE_URL = get_env_variable("OLLAMA_BASE_URL", "http://ollama:11434")
177177
AWS_ACCESS_KEY_ID = get_env_variable("AWS_ACCESS_KEY_ID", "")
178178
AWS_SECRET_ACCESS_KEY = get_env_variable("AWS_SECRET_ACCESS_KEY", "")
179+
MISTRAL_API_KEY = get_env_variable("MISTRAL_API_KEY", "")
179180

180181
## Embeddings
181182

@@ -226,6 +227,13 @@ def init_embeddings(provider, model):
226227
model_id=model,
227228
region_name=AWS_DEFAULT_REGION,
228229
)
230+
elif provider == EmbeddingsProvider.MISTRAL:
231+
from langchain_mistralai import MistralAIEmbeddings
232+
233+
return MistralAIEmbeddings(
234+
model=model,
235+
api_key=MISTRAL_API_KEY,
236+
)
229237
else:
230238
raise ValueError(f"Unsupported embeddings provider: {provider}")
231239

@@ -253,6 +261,10 @@ def init_embeddings(provider, model):
253261
"EMBEDDINGS_MODEL", "amazon.titan-embed-text-v1"
254262
)
255263
AWS_DEFAULT_REGION = get_env_variable("AWS_DEFAULT_REGION", "us-east-1")
264+
elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.MISTRAL:
265+
EMBEDDINGS_MODEL = get_env_variable(
266+
"EMBEDDINGS_MODEL", "mistral-embed"
267+
)
256268
else:
257269
raise ValueError(f"Unsupported embeddings provider: {EMBEDDINGS_PROVIDER}")
258270

0 commit comments

Comments
 (0)