Add support for google gemini #53

Open · wants to merge 3 commits into `main`
21 changes: 16 additions & 5 deletions README.md
@@ -1,8 +1,19 @@
# Needle In A Haystack - Pressure Testing LLMs

This repository is a fork of Greg Kamradt's code (https://twitter.com/GregKamradt).

Original Repository: https://github.com/gkamradt/LLMTest_NeedleInAHaystack

This fork adds support for using **Google Gemini** models as the provider and/or evaluator.

```zsh
needlehaystack.run_test --provider google --evaluator google --model_name "gemini-1.5-pro" --evaluator_model_name "gemini-1.5-pro" --document_depth_percents "[50]" --context_lengths "[200]"
```

### Original readme below
A simple 'needle in a haystack' analysis to test in-context retrieval ability of long context LLMs.

-Supported model providers: OpenAI, Anthropic, Cohere
+Supported model providers: **Google**, OpenAI, Anthropic, Cohere

Get the behind the scenes on the [overview video](https://youtu.be/KwRRuiCCdmc).

@@ -32,7 +43,7 @@ source venv/bin/activate
### Environment Variables

- `NIAH_MODEL_API_KEY` - API key for interacting with the model. Depending on the provider, this gets used appropriately with the correct sdk.
-- `NIAH_EVALUATOR_API_KEY` - API key to use if `openai` evaluation strategy is used.
+- `NIAH_EVALUATOR_API_KEY` - API key to use if the `openai` or `google` evaluation strategy is used.
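
For example, when using Google as both the provider and the evaluator, both variables can carry the same Gemini API key (placeholder value below):

```zsh
export NIAH_MODEL_API_KEY="<your-gemini-api-key>"
export NIAH_EVALUATOR_API_KEY="<your-gemini-api-key>"
```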

### Install Package

@@ -46,10 +57,10 @@ pip install needlehaystack

Start using the package by calling the entry point `needlehaystack.run_test` from command line.

-You can then run the analysis on OpenAI, Anthropic, or Cohere models with the following command line arguments:
+You can then run the analysis on Google Gemini, OpenAI, Anthropic, or Cohere models with the following command line arguments:

-- `provider` - The provider of the model, available options are `openai`, `anthropic`, and `cohere`. Defaults to `openai`
-- `evaluator` - The evaluator, which can either be a `model` or `LangSmith`. See more on `LangSmith` below. If using a `model`, only `openai` is currently supported. Defaults to `openai`.
+- `provider` - The provider of the model, available options are `openai`, `anthropic`, `google`, and `cohere`. Defaults to `openai`
+- `evaluator` - The evaluator, which can be `openai`, `google`, or `langsmith`. See more on `LangSmith` below. Defaults to `openai`.
- `model_name` - Model name of the language model accessible by the provider. Defaults to `gpt-3.5-turbo-0125`
- `evaluator_model_name` - Model name of the language model accessible by the evaluator. Defaults to `gpt-3.5-turbo-0125`
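
As an illustration, a hypothetical mixed run that tests a Gemini model while keeping the default OpenAI evaluator could look like this (model names are illustrative):

```zsh
needlehaystack.run_test --provider google --model_name "gemini-1.5-pro" --evaluator openai --evaluator_model_name "gpt-3.5-turbo-0125" --document_depth_percents "[50]" --context_lengths "[2000]"
```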

1 change: 1 addition & 0 deletions needlehaystack/evaluators/__init__.py
@@ -1,3 +1,4 @@
from .evaluator import Evaluator
from .openai import OpenAIEvaluator
from .langsmith import LangSmithEvaluator
from .google import GoogleEvaluator
66 changes: 66 additions & 0 deletions needlehaystack/evaluators/google.py
@@ -0,0 +1,66 @@
import os

from .evaluator import Evaluator

from langchain.evaluation import load_evaluator
from langchain_google_genai import ChatGoogleGenerativeAI


class GoogleEvaluator(Evaluator):
    DEFAULT_MODEL_KWARGS: dict = dict(temperature=0)
    CRITERIA = {"accuracy": """
                Score 1: The answer is completely unrelated to the reference.
                Score 3: The answer has minor relevance but does not align with the reference.
                Score 5: The answer has moderate relevance but contains inaccuracies.
                Score 7: The answer aligns with the reference but has minor omissions.
                Score 10: The answer is completely accurate and aligns perfectly with the reference.
                Only respond with a numerical score"""}

    def __init__(self,
                 model_name: str = "gemini-1.5-pro",
                 model_kwargs: dict = DEFAULT_MODEL_KWARGS,
                 true_answer: str = None,
                 question_asked: str = None):
        """
        :param model_name: The name of the model.
        :param model_kwargs: Model configuration. Default is {temperature: 0}
        :param true_answer: The true answer to the question asked.
        :param question_asked: The question asked to the model.
        """
        if (not true_answer) or (not question_asked):
            raise ValueError("true_answer and question_asked must be supplied with init.")

        self.model_name = model_name
        self.model_kwargs = model_kwargs
        self.true_answer = true_answer
        self.question_asked = question_asked

        api_key = os.getenv('NIAH_EVALUATOR_API_KEY')
        if not api_key:
            raise ValueError("NIAH_EVALUATOR_API_KEY must be in env for using the google evaluator.")

        self.api_key = api_key

        self.evaluator = ChatGoogleGenerativeAI(model=self.model_name,
                                                google_api_key=self.api_key,
                                                **self.model_kwargs)

    def evaluate_response(self, response: str) -> int:
        evaluator = load_evaluator(
            "labeled_score_string",
            criteria=self.CRITERIA,
            llm=self.evaluator,
        )

        eval_result = evaluator.evaluate_strings(
            # The model's response
            prediction=response,

            # The actual answer
            reference=self.true_answer,

            # The question asked
            input=self.question_asked,
        )

        return int(eval_result['score'])
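
For reference, a minimal usage sketch of the new evaluator, not part of the PR itself. It assumes `NIAH_EVALUATOR_API_KEY` is set in the environment; the needle and question strings are illustrative:

```python
# Minimal sketch; assumes NIAH_EVALUATOR_API_KEY is set.
from needlehaystack.evaluators import GoogleEvaluator

evaluator = GoogleEvaluator(
    model_name="gemini-1.5-pro",
    true_answer="The secret ingredient is saffron.",
    question_asked="What is the secret ingredient?",
)

score = evaluator.evaluate_response("The document says the secret ingredient is saffron.")
print(score)  # an integer from 1 to 10, per the CRITERIA rubric above
```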
1 change: 1 addition & 0 deletions needlehaystack/providers/__init__.py
@@ -1,4 +1,5 @@
from .anthropic import Anthropic
from .cohere import Cohere
from .google import Google
from .model import ModelProvider
from .openai import OpenAI
171 changes: 171 additions & 0 deletions needlehaystack/providers/google.py
@@ -0,0 +1,171 @@
import os
from operator import itemgetter
import pkg_resources
import requests
from typing import Optional

import google.generativeai as genai
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.prompts import PromptTemplate
import sentencepiece

from .model import ModelProvider


class Google(ModelProvider):
    """
    A wrapper class for interacting with Google's Gemini API, providing methods to encode text, generate prompts,
    evaluate models, and create LangChain runnables for language model interactions.

    Attributes:
        model_name (str): The name of the Google model to use for evaluations and interactions.
        model: An instance of the Google Gemini client for API calls.
        tokenizer: A tokenizer instance for encoding and decoding text to and from token representations.
    """

    DEFAULT_MODEL_KWARGS: dict = dict(max_output_tokens=300,
                                      temperature=0)
    VOCAB_FILE_URL = "https://raw.githubusercontent.com/google/gemma_pytorch/33b652c465537c6158f9a472ea5700e5e770ad3f/tokenizer/tokenizer.model"

    def __init__(self,
                 model_name: str = "gemini-1.5-pro",
                 model_kwargs: dict = DEFAULT_MODEL_KWARGS,
                 vocab_file_url: str = VOCAB_FILE_URL):
        """
        Initializes the Google model provider with a specific model.

        Args:
            model_name (str): The name of the Google model to use. Defaults to 'gemini-1.5-pro'.
            model_kwargs (dict): Model configuration. Defaults to {max_output_tokens: 300, temperature: 0}.
            vocab_file_url (str): SentencePiece model file that defines the tokenization vocabulary. Defaults to the
                Gemma tokenizer: https://github.com/google/gemma_pytorch/blob/main/tokenizer/tokenizer.model

        Raises:
            ValueError: If NIAH_MODEL_API_KEY is not found in the environment.
        """
        api_key = os.getenv('NIAH_MODEL_API_KEY')
        if not api_key:
            raise ValueError("NIAH_MODEL_API_KEY must be in env.")

        self.model_name = model_name
        self.model_kwargs = model_kwargs
        self.api_key = api_key
        genai.configure(api_key=self.api_key)
        self.model = genai.GenerativeModel(self.model_name)

        local_vocab_file = 'tokenizer.model'
        if not os.path.exists(local_vocab_file):
            response = requests.get(vocab_file_url)  # Download tokenizer vocab file (~4 MB)
            response.raise_for_status()

            with open(local_vocab_file, 'wb') as f:
                # Write in 8 KB chunks; iter_content() with no argument yields one byte at a time.
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        self.tokenizer = sentencepiece.SentencePieceProcessor(local_vocab_file)

        resource_path = pkg_resources.resource_filename('needlehaystack', 'providers/Anthropic_prompt.txt')

        # Generate the prompt structure for the model
        # Replace the following file with the appropriate prompt structure
        with open(resource_path, 'r') as file:
            self.prompt_structure = file.read()

    async def evaluate_model(self, prompt: str) -> str:
        """
        Evaluates a given prompt using the Google model and retrieves the model's response.

        Args:
            prompt (str): The prompt to send to the model.

        Returns:
            str: The content of the model's response to the prompt.
        """
        response = await self.model.generate_content_async(
            prompt,
            generation_config=self.model_kwargs
        )

        return response.text

    def generate_prompt(self, context: str, retrieval_question: str) -> str:
        """
        Generates a structured prompt for querying the model, based on a given context and retrieval question.

        Args:
            context (str): The context or background information relevant to the question.
            retrieval_question (str): The specific question to be answered by the model.

        Returns:
            str: The text prompt
        """
        return self.prompt_structure.format(
            retrieval_question=retrieval_question,
            context=context)

    def encode_text_to_tokens(self, text: str) -> list[int]:
        """
        Encodes a given text string to a sequence of tokens using the model's tokenizer.

        Args:
            text (str): The text to encode.

        Returns:
            list[int]: A list of token IDs representing the encoded text.
        """
        return self.tokenizer.encode(text)

    def decode_tokens(self, tokens: list[int], context_length: Optional[int] = None) -> str:
        """
        Decodes a sequence of tokens back into a text string using the model's tokenizer.

        Args:
            tokens (list[int]): The sequence of token IDs to decode.
            context_length (Optional[int], optional): An optional length specifying the number of tokens to decode. If not provided, decodes all tokens.

        Returns:
            str: The decoded text string.
        """
        return self.tokenizer.decode(tokens[:context_length])

    def get_langchain_runnable(self, context: str) -> str:
        """
        Creates a LangChain runnable that constructs a prompt based on a given context and a question,
        queries the Google model, and returns the model's response. This method leverages the LangChain
        library to build a sequence of operations: extracting input variables, generating a prompt,
        querying the model, and processing the response.

        Args:
            context (str): The context or background information relevant to the user's question.
                This context is provided to the model to aid in generating relevant and accurate responses.

        Returns:
            str: A LangChain runnable object that can be executed to obtain the model's response to a
                dynamically provided question. The runnable encapsulates the entire process from prompt
                generation to response retrieval.

        Example:
            To use the runnable:
            - Define the context and question.
            - Execute the runnable with these parameters to get the model's response.
        """

        template = """You are a helpful AI bot that answers questions for a user. Keep your response short and direct \n
        \n ------- \n
        {context}
        \n ------- \n
        Here is the user question: \n --- --- --- \n {question} \n Don't give information outside the document or repeat your findings."""

        prompt = PromptTemplate(
            template=template,
            input_variables=["context", "question"],
        )
        # Create a LangChain runnable.
        # Pass the NIAH key explicitly; ChatGoogleGenerativeAI would otherwise
        # look for a GOOGLE_API_KEY environment variable.
        model = ChatGoogleGenerativeAI(temperature=0, model=self.model_name,
                                       google_api_key=self.api_key)
        chain = ({"context": lambda x: context,
                  "question": itemgetter("question")}
                 | prompt
                 | model
                 )
        return chain
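
For reference, a minimal usage sketch of the provider, not part of the PR itself. It assumes `NIAH_MODEL_API_KEY` is set; the first instantiation downloads the ~4 MB Gemma tokenizer into the working directory, and the context and question strings are illustrative:

```python
# Minimal sketch; assumes NIAH_MODEL_API_KEY is set.
import asyncio

from needlehaystack.providers import Google

google = Google(model_name="gemini-1.5-pro")

# Tokenizer round-trip via the downloaded Gemma SentencePiece model
tokens = google.encode_text_to_tokens("The quick brown fox jumps over the lazy dog")
print(len(tokens), google.decode_tokens(tokens, context_length=4))

# Async completion against the Gemini API
prompt = google.generate_prompt(context="The secret ingredient is saffron.",
                                retrieval_question="What is the secret ingredient?")
print(asyncio.run(google.evaluate_model(prompt)))

# The LangChain runnable variant
chain = google.get_langchain_runnable("The secret ingredient is saffron.")
print(chain.invoke({"question": "What is the secret ingredient?"}).content)
```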


10 changes: 8 additions & 2 deletions needlehaystack/run.py
@@ -5,8 +5,8 @@
from jsonargparse import CLI

from . import LLMNeedleHaystackTester, LLMMultiNeedleHaystackTester
-from .evaluators import Evaluator, LangSmithEvaluator, OpenAIEvaluator
-from .providers import Anthropic, ModelProvider, OpenAI, Cohere
+from .evaluators import Evaluator, LangSmithEvaluator, OpenAIEvaluator, GoogleEvaluator
+from .providers import Anthropic, ModelProvider, OpenAI, Cohere, Google

load_dotenv()

@@ -65,6 +65,8 @@ def get_model_to_test(args: CommandArgs) -> ModelProvider:
            return Anthropic(model_name=args.model_name)
        case "cohere":
            return Cohere(model_name=args.model_name)
        case "google":
            return Google(model_name=args.model_name)
        case _:
            raise ValueError(f"Invalid provider: {args.provider}")

@@ -86,6 +88,10 @@ def get_evaluator(args: CommandArgs) -> Evaluator:
            return OpenAIEvaluator(model_name=args.evaluator_model_name,
                                   question_asked=args.retrieval_question,
                                   true_answer=args.needle)
        case "google":
            return GoogleEvaluator(model_name=args.evaluator_model_name,
                                   question_asked=args.retrieval_question,
                                   true_answer=args.needle)
        case "langsmith":
            return LangSmithEvaluator()
        case _:
15 changes: 9 additions & 6 deletions requirements.txt
@@ -12,6 +12,7 @@ distro==1.8.0
filelock==3.13.1
frozenlist==1.4.0
fsspec==2023.10.0
+google-generativeai==0.7.2
h11==0.14.0
httpcore==1.0.2
httpx==0.25.2
@@ -20,13 +21,14 @@ idna==3.6
jsonargparse==4.27.5
jsonpatch==1.33
jsonpointer==2.4
-langchain==0.1.9
-langchain-community>=0.0.24
-langchain-core>=0.1.26
+langchain==0.2.9
+langchain-anthropic==0.1.13
+langchain-cohere==0.1.5
+langchain-community==0.2.7
+langchain-core==0.2.21
+langchain-google-genai==1.0.7
+langchain-openai==0.1.7
langsmith>=0.1.8
-langchain_openai
-langchain_anthropic
-langchain_cohere
marshmallow==3.20.1
multidict==6.0.4
mypy-extensions==1.0.0
@@ -50,3 +52,4 @@ typing_extensions==4.8.0
urllib3==2.1.0
yarl==1.9.3
pytest==8.1.1
+sentencepiece==0.2.0
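
If you are updating an existing environment rather than reinstalling from `requirements.txt`, the three dependencies newly added by this PR can be installed directly. A sketch using the pins from the diff above; note the `langchain*` packages are also upgraded to the 0.2.x line:

```zsh
pip install google-generativeai==0.7.2 langchain-google-genai==1.0.7 sentencepiece==0.2.0
```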