[Fix, Feat] Solve llava wild issue when task num can't divide, add fuyu ppl (#37)

* Add fuyu ppl, fix llava-bench gather issue
* Add gpt4V
* Black lint
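Context for the title: in distributed evaluation, when the number of samples is not divisible by the number of ranks, the per-rank result lists have unequal lengths and a fixed-size gather truncates or hangs; this is the llava-bench gather issue the commit fixes. That diff is not shown on this page, so the following is only a hedged sketch of the usual remedy, using torch.distributed.all_gather_object, and not necessarily this commit's exact change:

    import torch.distributed as dist

    def gather_uneven(local_results, world_size):
        # all_gather_object handles variable-length Python objects, so ranks
        # holding chunks of different sizes can still be gathered safely.
        gathered = [None] * world_size
        dist.all_gather_object(gathered, local_results)
        # Flatten the per-rank chunks back into one ordered result list.
        return [item for chunk in gathered for item in chunk]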
Showing 5 changed files with 215 additions and 11 deletions.
@@ -2,3 +2,4 @@
 from .otterhd import OtterHD
 from .qwen_vl import Qwen_VL
 from .fuyu import Fuyu
+from .gpt4v import GPT4V
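This one-line import matters because the @register_model("gpt4V") decorator in the new gpt4v.py below only takes effect when the module is imported, and the models package __init__ (evidently the file this hunk edits) is what triggers that import. A minimal sketch of the decorator-registry pattern, with MODEL_REGISTRY as an illustrative name rather than the exact lmms_eval.api.registry internals:

    # Illustrative sketch only; names are assumptions, not lmms_eval internals.
    MODEL_REGISTRY = {}

    def register_model(name):
        def decorate(cls):
            MODEL_REGISTRY[name] = cls  # map the CLI model name to its class
            return cls
        return decorate

    # Once gpt4v.py has been imported, a runner can resolve the name:
    # model_cls = MODEL_REGISTRY["gpt4V"]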
@@ -0,0 +1,108 @@
from io import BytesIO
import os
import base64
from typing import List, Tuple
from tqdm import tqdm
import requests as url_requests
import time
import logging

from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model
from lmms_eval import utils

from PIL import Image

API_TYPE = os.getenv("API_TYPE", "openai")
NUM_SECONDS_TO_SLEEP = 5
eval_logger = logging.getLogger("lmms-eval")

if API_TYPE == "openai":
    API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
    API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
    }
elif API_TYPE == "azure":
    API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken")
    API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
    headers = {
        "api-key": API_KEY,
        "Content-Type": "application/json",
    }


@register_model("gpt4V")
class GPT4V(lmms):
    def __init__(self, **kwargs) -> None:
        super().__init__()

    # Encode a PIL image as a base64 JPEG string for the data URL in the payload
    def encode_image(self, image: Image):
        output_buffer = BytesIO()
        image.save(output_buffer, format="JPEG")
        byte_data = output_buffer.getvalue()
        base64_str = base64.b64encode(byte_data).decode("utf-8")
        return base64_str

    # Flatten a list of lists into a single flat list
    def flatten(self, input):
        new_list = []
        for i in input:
            for j in i:
                new_list.append(j)
        return new_list

    def generate_until(self, requests) -> List[str]:
        res = []
        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")

        for contexts, gen_kwargs, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
            # gather the visuals for this doc and flatten the nested list
            visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
            visuals = self.flatten(visuals)

            payload = {"model": "gpt-4-vision-preview", "messages": [{"role": "user", "content": []}]}
            payload["messages"][0]["content"].append({"type": "text", "text": contexts})

            for visual in visuals:
                img = self.encode_image(visual)
                payload["messages"][0]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img}"}})

            if "max_new_tokens" not in gen_kwargs:
                gen_kwargs["max_new_tokens"] = 1024
            if "temperature" not in gen_kwargs:
                gen_kwargs["temperature"] = 0
            if "top_p" not in gen_kwargs:
                gen_kwargs["top_p"] = None
            if "num_beams" not in gen_kwargs:
                gen_kwargs["num_beams"] = 1

            # payload["max_tokens"] = gen_kwargs["max_new_tokens"]
            # payload["temperature"] = gen_kwargs["temperature"]

            for attempt in range(5):
                try:
                    response = url_requests.post(API_URL, headers=headers, json=payload)
                    response_data = response.json()

                    content = response_data["choices"][0]["message"]["content"].strip()
                    break  # If successful, break out of the loop

                except Exception as e:
                    eval_logger.info(f"Attempt {attempt + 1} failed with error: {str(e)}")
                    if attempt < 5 - 1:  # If we have retries left, sleep and then continue to next attempt
                        time.sleep(NUM_SECONDS_TO_SLEEP)
                    else:  # If this was the last attempt, log and return empty
                        eval_logger.error(f"All 5 attempts failed. Last error message: {str(e)}")
                        content = ""
            res.append(content)
            pbar.update(1)
        pbar.close()
        return res

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        # TODO
        assert False, "GPT4V does not support loglikelihood"

    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
        # TODO
        assert False, "GPT4V does not support loglikelihood_rolling"
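A design note on the retry loop in generate_until: it sleeps a fixed NUM_SECONDS_TO_SLEEP between attempts. A common alternative is exponential backoff, which waits longer as failures accumulate; a minimal sketch under the same five-attempt assumption (not part of this commit):

    import time

    def post_with_backoff(post_fn, retries=5, base_delay=1.0):
        # Retry a callable, doubling the sleep after each failed attempt.
        for attempt in range(retries):
            try:
                return post_fn()
            except Exception:
                if attempt == retries - 1:
                    raise  # out of retries; surface the last error
                time.sleep(base_delay * (2**attempt))

    # e.g. post_with_backoff(lambda: url_requests.post(API_URL, headers=headers, json=payload))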