Skip to content

Commit

Permalink
Fix api_base undefined bug in Gemini embeddings (#11393)
Browse files Browse the repository at this point in the history
* Support Gemini "transport" configuration

Added support for configuring the Gemini transport method.

* Sync updates in multi_modal_llms\gemini

* Updated Dashscope qwen llm defaults

Setting qwen default num_outputs and temperature

* cr

* support gemini embedding configuration

support configuring api_base, api_key, transport method

* fix gptrepo data connector encoding issue

Reading a file with the system default encoding (e.g. GBK) can produce garbled characters. Added an encoding configuration option.

* sync latest repo

* sync latest repo

* cr

* cr

* Fix api_base undefined bug in Gemini embeddings

* add comments

* fix linter test

* sync fix in integrations/embeddings

* fix unit test

---------

Co-authored-by: Haotian Zhang <[email protected]>
  • Loading branch information
BetterAndBetterII and hatianzhang authored Feb 27, 2024
1 parent e6e9abd commit 27f7691
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Gemini embeddings file."""

import os
from typing import Any, List, Optional

import google.generativeai as gemini
Expand All @@ -19,6 +20,8 @@ class GeminiEmbedding(BaseEmbedding):
Defaults to "models/embedding-001".
api_key (Optional[str]): API key to access the model. Defaults to None.
api_base (Optional[str]): API base to access the model. Defaults to Official Base.
transport (Optional[str]): Transport to access the model.
"""

_model: Any = PrivateAttr()
Expand All @@ -36,12 +39,24 @@ def __init__(
model_name: str = "models/embedding-001",
task_type: Optional[str] = "retrieval_document",
api_key: Optional[str] = None,
api_base: Optional[str] = None,
transport: Optional[str] = None,
title: Optional[str] = None,
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
callback_manager: Optional[CallbackManager] = None,
**kwargs: Any,
):
gemini.configure(api_key=api_key)
# API keys are optional. The API can be authorised via OAuth (detected
# environmentally) or by the GOOGLE_API_KEY environment variable.
config_params: Dict[str, Any] = {
"api_key": api_key or os.getenv("GOOGLE_API_KEY"),
}
if api_base:
config_params["client_options"] = {"api_endpoint": api_base}
if transport:
config_params["transport"] = transport
# transport: A string, one of: [`rest`, `grpc`, `grpc_asyncio`].
gemini.configure(**config_params)
self._model = gemini

super().__init__(
Expand Down
4 changes: 4 additions & 0 deletions llama-index-legacy/llama_index/legacy/embeddings/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ class GeminiEmbedding(BaseEmbedding):
Defaults to "models/embedding-001".
api_key (Optional[str]): API key to access the model. Defaults to None.
api_base (Optional[str]): API base to access the model. Defaults to Official Base.
transport (Optional[str]): Transport to access the model.
"""

_model: Any = PrivateAttr()
Expand All @@ -35,6 +37,8 @@ def __init__(
model_name: str = "models/embedding-001",
task_type: Optional[str] = "retrieval_document",
api_key: Optional[str] = None,
api_base: Optional[str] = None,
transport: Optional[str] = None,
title: Optional[str] = None,
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
callback_manager: Optional[CallbackManager] = None,
Expand Down

0 comments on commit 27f7691

Please sign in to comment.