Made litellm judge backend more robust. (#485)
* Made litellm judge backend more robust.

* Added failed flag to ModelResponse.

---------

Co-authored-by: Clémentine Fourrier <[email protected]>
JoelNiklaus and clefourrier authored Jan 7, 2025
1 parent 2073a29 commit fdb12f4
Showing 2 changed files with 21 additions and 9 deletions.
29 changes: 20 additions & 9 deletions src/lighteval/metrics/llm_as_judge.py
@@ -28,6 +28,7 @@
 
 from tqdm import tqdm
 
+from lighteval.models.model_output import ModelResponse
 from lighteval.utils.imports import is_litellm_available, is_openai_available, is_vllm_available
 
 
@@ -195,20 +196,30 @@ def __call_litellm(self, prompts):
         def __call_api(prompt):
             for _ in range(self.API_MAX_RETRY):
                 try:
-                    response = litellm.completion(
-                        model=self.model,
-                        messages=prompt,
-                        response_format={"type": "text"},
-                        max_tokens=512,
-                        n=1,
-                        caching=True,
-                    )
+                    kwargs = {
+                        "model": self.model,
+                        "messages": prompt,
+                        "response_format": {"type": "text"},
+                        "max_tokens": 512,
+                        "n": 1,
+                        "caching": True,
+                    }
+                    response = litellm.completion(**kwargs)
                     text = response.choices[0].message.content
+                    if not text or response.failed:
+                        kwargs["caching"] = False
+                        response = litellm.completion(**kwargs)
+                        text = response.choices[0].message.content
+                        if not text or response.failed:
+                            # Just return an error response if the second attempt fails too
+                            return ModelResponse(
+                                text="Failed to get response from the API.", model=self.model, failed=True
+                            )
                     return text
                 except Exception as e:
                     logger.warning(f"{type(e), e}")
                     time.sleep(self.API_RETRY_SLEEP)
-            raise Exception("Failed to get response from the API")
+            return ModelResponse(text="Failed to get response from the API.", model=self.model, failed=True)
 
         results = []
         with ThreadPoolExecutor(100) as executor:
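With this change, callers of the judge no longer need a try/except around each prompt: after API_MAX_RETRY attempts, a failure surfaces as a ModelResponse sentinel with failed=True rather than an exception. A minimal caller-side sketch of that contract (the names judge_call and score_sample are hypothetical, not part of this commit):

    from lighteval.models.model_output import ModelResponse

    def score_sample(judge_call, prompt):
        # judge_call is an __call_api-style function: it returns either the
        # judge's text or a ModelResponse sentinel with failed=True.
        result = judge_call(prompt)
        if isinstance(result, ModelResponse) and result.failed:
            # Hypothetical policy: record a null score instead of crashing
            # the whole evaluation run on one bad judge call.
            return None
        return result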
1 change: 1 addition & 0 deletions src/lighteval/models/model_output.py
@@ -33,6 +33,7 @@ class ModelResponse:
     generated_tokens: list[int] = field(default_factory=list) # model generations
     truncated_tokens_count: Optional[int] = 0 # How many tokens truncated
     padded_tokens_count: Optional[int] = 0 # How many tokens of padding
+    failed: bool = False
 
     def get_result_for_eval(self):
         raise NotImplementedError()
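Since failed defaults to False, the new field is backward compatible: existing ModelResponse construction sites are untouched, and only the judge's error path opts in. An illustrative check, assuming the fields omitted from this hunk also carry defaults:

    from lighteval.models.model_output import ModelResponse

    ok = ModelResponse(generated_tokens=[1, 2, 3])  # existing call sites: failed stays False
    err = ModelResponse(failed=True)                # the new error path sets it explicitly
    assert ok.failed is False and err.failed is True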
