Skip to content

Commit

Permalink
feat: add support for Gemini Pro
Browse files Browse the repository at this point in the history
  • Loading branch information
gventuri committed Dec 13, 2023
1 parent e7a22e4 commit bf1c3f7
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 224 deletions.
20 changes: 19 additions & 1 deletion pandasai/llm/google_vertexai.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ class GoogleVertexAI(BaseGoogle):
"text-bison@002",
"text-unicorn@001",
]
_supported_generative_models = [
"gemini-pro",
]

def __init__(
self, project_id: str, location: str, model: Optional[str] = None, **kwargs
Expand All @@ -52,7 +55,8 @@ def __init__(
**kwargs: Arguments to control the Model Parameters
"""

self.model = "text-bison@001" if model is None else model
self.model = model or "text-bison@001"

self._configure(project_id, location)
self.project_id = project_id
self.location = location
Expand Down Expand Up @@ -108,6 +112,7 @@ def _generate_text(self, prompt: str) -> str:
CodeGenerationModel,
TextGenerationModel,
)
from vertexai.preview.generative_models import GenerativeModel

if self.model in self._supported_code_models:
code_generation = CodeGenerationModel.from_pretrained(self.model)
Expand All @@ -127,6 +132,19 @@ def _generate_text(self, prompt: str) -> str:
top_k=self.top_k,
max_output_tokens=self.max_output_tokens,
)
elif self.model in self._supported_generative_models:
model = GenerativeModel(self.model)
responses = model.generate_content(
[prompt],
generation_config={
"max_output_tokens": self.max_output_tokens,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
},
)

completion = responses.candidates[0].content.parts[0]
else:
raise UnsupportedModelError(self.model)

Expand Down
Loading

0 comments on commit bf1c3f7

Please sign in to comment.