
Lack of Gemini Model Integration Without LiteLLM(vertexAI) #81

Closed
aniketqw opened this issue Jan 6, 2025 · 7 comments
Comments

@aniketqw
aniketqw commented Jan 6, 2025

Issue Summary:

The current implementation of smolagents does not support direct integration with the Gemini language model. This limitation hinders developers who wish to leverage the advanced capabilities of the Gemini model in their applications. Gemini's advanced features, such as configurable parameters (temperature, top_p, top_k), high token limits, and robust chat session management, make it a valuable addition for modern AI-driven workflows.

Proposed Solution:
The smolagents framework can integrate the Gemini model by defining a specialized GeminiLLM class and updating the existing infrastructure to support this new model. Below is an outline of the integration steps:

GeminiLLM Class Definition: The GeminiLLM class encapsulates the configuration and interaction with the Gemini API. It includes methods for initializing the model, sending prompts, and handling stop sequences.

Updating the Agent: Modify the smolagents architecture to include Gemini as a supported model. This involves adding conditional logic or a plugin-based system to initialize the appropriate LLM based on user preference (a sketch of this selection logic follows the code snippet below).
Code snippet for the GeminiLLM class:

import logging
from typing import List, Optional

import google.generativeai as genai

logger = logging.getLogger(__name__)


class GeminiLLM:
    model_name: str = "gemini-2.0-flash-exp"
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 40
    max_tokens: int = 2048

    def __call__(self, prompt: str, stop: Optional[List[str]] = None, max_tokens: int = 1500) -> str:
        generation_config = {
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "max_output_tokens": max_tokens,
        }

        try:
            logger.debug(f"Initializing GenerativeModel with config: {generation_config}")
            model = genai.GenerativeModel(
                model_name=self.model_name,
                generation_config=generation_config,
            )
            logger.debug("GenerativeModel initialized successfully.")

            chat_session = model.start_chat(history=[])
            logger.debug("Chat session started.")

            response = chat_session.send_message(prompt)
            logger.debug(f"Prompt sent to model: {prompt}")
            logger.debug(f"Raw response received: {response.text}")

            # response.text is read-only, so truncate into a local variable instead of assigning back to it.
            text = response.text
            if stop:
                for stop_seq in stop:
                    if stop_seq in text:
                        text = text.split(stop_seq)[0]
                        break

            return text.strip()
        except Exception as e:
            logger.error(f"Error generating response with GeminiLLM: {e}")
            logger.debug("Exception details:", exc_info=True)
            raise
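
To illustrate the second step (updating the agent), here is a minimal sketch of what model selection could look like. The select_model helper below is hypothetical, and it assumes the google-generativeai package is installed and GEMINI_API_KEY is set in the environment:

import os

import google.generativeai as genai

# Hypothetical factory: pick the LLM backend based on user preference.
def select_model(backend: str = "gemini"):
    if backend == "gemini":
        # Assumes GEMINI_API_KEY is exported in the environment.
        genai.configure(api_key=os.environ["GEMINI_API_KEY"])
        return GeminiLLM()
    raise ValueError(f"Unsupported backend: {backend}")

llm = select_model("gemini")
# Direct call with an optional stop sequence, as handled in GeminiLLM.__call__ above.
print(llm("Explain what a code agent is in one sentence.", stop=["Observation:"]))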

If the community agrees, I can work on implementing this integration and submit a pull request. Feedback and suggestions on this proposal are welcome!

@aymeric-roucher
Collaborator

Have you tried using Gemini with LiteLLMModel? Gemini should be supported in LiteLLM.

@aniketqw
Author

aniketqw commented Jan 6, 2025

Have you tried using Gemini with LiteLLMModel? Gemini should be supported in LiteLLM.

I wanted to use it directly, as we automatically do with HfApiModel. It would be very convenient if we could call the Gemini model directly. Moreover, with HfApiModel the Inference API has a limit for most of the other models.
I tried to do it with LiteLLM (https://docs.litellm.ai/docs/providers/gemini) but did not make any progress. If you have an idea of how to do it, I would be grateful if you could share it with me.

@aniketqw
Author

aniketqw commented Jan 6, 2025

Is this the way to use the LiteLLM function for Gemini?

import os
import json
from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent
import litellm
from typing import List, Dict, Optional
import getpass

if not os.environ.get("GEMINI_API_KEY"):
    os.environ["GEMINI_API_KEY"] = getpass.getpass("Enter your Gemini API Key: ")

class LiteLLMModel:
    def __init__(
        self,
        model_id="gemini/gemini-pro",
        api_base=None,
        api_key=None,
        **kwargs,
    ):
        self.model_id = model_id
        self.api_base = api_base
        self.api_key = api_key
        self.kwargs = kwargs

    def __call__(
        self,
        messages: List[Dict[str, str]],
        stop_sequences: Optional[List[str]] = None,
        max_tokens: int = 1500,
    ) -> str:
        response = litellm.completion(
            model=self.model_id,
            messages=messages,
            stop=stop_sequences,
            max_tokens=max_tokens,
            api_base=self.api_base,
            api_key=self.api_key,
            **self.kwargs,
        )
        self.last_input_token_count = response.usage.prompt_tokens
        self.last_output_token_count = response.usage.completion_tokens
        return response.choices[0].message.content

    def get_tool_call(
        self,
        messages: List[Dict[str, str]],
        available_tools: List[Dict],
        stop_sequences: Optional[List[str]] = None,
        max_tokens: int = 1500,
    ):
        response = litellm.completion(
            model=self.model_id,
            messages=messages,
            tools=[tool['json_schema'] for tool in available_tools],
            tool_choice="required",
            stop=stop_sequences,
            max_tokens=max_tokens,
            api_base=self.api_base,
            api_key=self.api_key,
            **self.kwargs,
        )
        tool_calls = response.choices[0].message.tool_calls[0]
        self.last_input_token_count = response.usage.prompt_tokens
        self.last_output_token_count = response.usage.completion_tokens
        arguments = json.loads(tool_calls.function.arguments)
        return tool_calls.function.name, arguments, tool_calls.id

model = LiteLLMModel(model_id="gemini/gemini-pro")

tools = [DuckDuckGoSearchTool()]  # instantiate the tool rather than passing the class

web_agent = CodeAgent(
    tools=tools,
    model=model,
)

managed_web_agent = ManagedAgent(
    agent=web_agent,
    name="web_search",
    description="Runs web search for you. Give it your query as an argument."
)

manager_agent = CodeAgent(
    tools=[],
    model=model,
    managed_agents=[managed_web_agent],
)

def run_with_gemini(query):
    response = model(
        messages=[{"role": "user", "content": query}],
        max_tokens=1500
    )
    return response

query = "How to become a warrior?"
response = run_with_gemini(query)

print("Response from Gemini Model:", response)

@aymeric-roucher
Collaborator

To use Gemini, just initializing your model like model = LiteLLMModel("gemini/gemini-pro", api_key="YOUR_GEMINI_API_KEY") should work directly.
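
For reference, a minimal end-to-end sketch of that suggestion, assuming the built-in smolagents LiteLLMModel and CodeAgent and a GEMINI_API_KEY environment variable:

import os

from smolagents import CodeAgent, DuckDuckGoSearchTool, LiteLLMModel

# Built-in LiteLLMModel instead of a custom re-implementation; LiteLLM routes "gemini/..." model ids to Google's API.
model = LiteLLMModel("gemini/gemini-pro", api_key=os.environ["GEMINI_API_KEY"])

agent = CodeAgent(tools=[DuckDuckGoSearchTool()], model=model)
print(agent.run("How to become a warrior?"))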

@aniketqw
Author

aniketqw commented Jan 7, 2025

To use Gemini, just initializing your model like model = LiteLLMModel("gemini/gemini-pro", api_key="YOUR_GEMINI_API_KEY") should work directly.

Thanks for your assistance, it is working correctly. I have just one question: how is my code different (if at all), functionality-wise, from the current API call?

@aymeric-roucher
Collaborator

Your code might be re-implementing the base functionalities correctly! I'd advise you to check against the LiteLLMModel source code to compare!

@aniketqw
Author

aniketqw commented Jan 8, 2025

Your code might be re-implementing the base functionalities correctly! I'd advise you to check against the LiteLLMModel source code to compare!

Thanks for the help
