feat(models): add YAML configuration for MMMU random validation task and enhance GPT4V_MMMU model response handling
pufanyi committed Dec 16, 2024
1 parent 8abc737 commit 46c4a0c
Showing 6 changed files with 1,146 additions and 47 deletions.
lmms_eval/models/gpt4v_mmmu.py (103 changes: 56 additions & 47 deletions)
@@ -7,10 +7,12 @@
 from pathlib import Path
 from typing import List, Tuple, Union
 
+import crawl4ai
 import numpy as np
 import openai
 import requests as url_requests
 from accelerate import Accelerator, DistributedType
+from crawl4ai import AsyncWebCrawler, BrowserConfig, CacheMode
 from duckduckgo_search import DDGS
 from pydantic import BaseModel
 from tqdm import tqdm
@@ -55,13 +57,14 @@ def __str__(self):
 
 class SearchResponse(BaseModel):
     reason: list[str]
+    need_search: bool
     # use_image_search: bool
     search_text: str
 
 
 class FinalResponse(BaseModel):
     steps: list[str]
-    result: str
+    final_result: str
 
 
 SIMPLE_QA_PROMPT = """\
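
Note (not part of the commit): the renamed final_result field is consumed through OpenAI's structured-output parsing, which this file uses for every model call. A minimal standalone sketch of that pattern, assuming the openai and pydantic packages; the model name and prompt are placeholders:

    from openai import OpenAI
    from pydantic import BaseModel

    class FinalResponse(BaseModel):
        steps: list[str]
        final_result: str

    client = OpenAI()  # reads OPENAI_API_KEY from the environment
    completion = client.beta.chat.completions.parse(
        model="gpt-4o-2024-08-06",  # placeholder; any structured-output-capable model
        messages=[{"role": "user", "content": "What is 2 + 2? Answer step by step."}],
        response_format=FinalResponse,
        max_tokens=1024,
    )
    parsed = completion.choices[0].message.parsed
    print(parsed.steps)         # list of reasoning steps
    print(parsed.final_result)  # concise final answer, e.g. "4"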
@@ -100,6 +103,8 @@ class FinalResponse(BaseModel):
 Note: Your search should address the model's uncertainties rather than simply searching the question directly.
+If you think there is nothing uncertain, then there is no need to search.
 For example:
+If the model's response shows uncertainty about a specific detail or concept, focus your search on clarifying that uncertainty rather than repeating the original question.
@@ -214,11 +219,13 @@ def generate_until(self, requests) -> List[str]:
                 doc_uuid = f"{task}___{split}___{doc_id}"
                 if doc_uuid in self.response_cache:
                     response_text = self.response_cache[doc_uuid]
-                    if response_text and len(response_text) < 16:
+                    if response_text:
                         res.append(response_text)
                         pbar.update(1)
                         continue
 
+            print(f"Generating for {task} {split} {doc_id}")
+
             try:
                 search_dir = Path("temp") / "search" / task / split / str(doc_id)
                 if not search_dir.exists():
@@ -242,7 +249,7 @@ def generate_until(self, requests) -> List[str]:
                             {"role": "user", "content": [{"type": "text", "text": contexts}] + image_contents},
                         ],
                         response_format=SimpleResponse,
-                        max_tokens=4096,
+                        max_tokens=1024,
                     )
                     .choices[0]
                     .message.parsed
@@ -258,7 +265,7 @@ def generate_until(self, requests) -> List[str]:
                             },
                             {"role": "user", "content": [{"type": "text", "text": REVIEW_PROMPT.format(question=contexts, response=str(simple_response))}] + image_contents},
                         ],
-                        max_tokens=4096,
+                        max_tokens=1024,
                     )
                     .choices[0]
                     .message.content
@@ -275,7 +282,7 @@ def generate_until(self, requests) -> List[str]:
                             {"role": "user", "content": [{"type": "text", "text": REQUERY_PROMPT.format(question=contexts, response=str(simple_response), review=review)}] + image_contents},
                         ],
                         response_format=SearchResponse,
-                        max_tokens=4096,
+                        max_tokens=1024,
                     )
                     .choices[0]
                     .message.parsed
@@ -295,63 +302,65 @@ def generate_until(self, requests) -> List[str]:
                         f,
                     )
 
-                # Search using DuckDuckGo and get first result URL
-                ddgs = DDGS(timeout=50)
-                news_results = ddgs.text(keywords=search_content, region="wt-wt", safesearch="off", timelimit="m", max_results=10)
-                urls = [news["href"] for news in news_results]
-
-                if urls:
-                    # Take screenshot of the first 3 webpages
-                    import selenium.webdriver
-                    from selenium.webdriver.chrome.options import Options
-
-                    chrome_options = Options()
-                    chrome_options.add_argument("--headless")
-                    chrome_options.add_argument("--no-sandbox")
-                    chrome_options.add_argument("--disable-dev-shm-usage")
-                    chrome_options.add_argument("--window-size=1024,1024")
-
-                    search_image_contents = []
-
-                    for url_idx, url in enumerate(urls[:3]):
-                        try:
-                            driver = selenium.webdriver.Chrome(options=chrome_options)
-                            driver.get(url)
-
-                            # Take 3 screenshots while scrolling down
-                            for i in range(3):
-                                # Scroll down
-                                if i > 0:
-                                    driver.execute_script(f"window.scrollTo(0, {1024 * i})")
-                                    time.sleep(1)  # Wait for content to load
-
-                                screenshot_path = search_dir / f"search_result_{url_idx}_{i}.png"
-
-                                driver.save_screenshot(screenshot_path)
-
-                                # Load and encode screenshot
-                                with open(screenshot_path, "rb") as f:
-                                    screenshot_bytes = f.read()
-                                    screenshot_b64 = base64.b64encode(screenshot_bytes).decode("utf-8")
-                                    search_image_contents.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{screenshot_b64}"}})
-
-                            driver.quit()
-                        except Exception as e:
-                            print(f"Error occurred: {e}")
-                            continue
+                search_image_contents = []
+                if requery.need_search:
+                    # Search using DuckDuckGo and get first result URL
+                    ddgs = DDGS(timeout=50)
+                    news_results = ddgs.text(keywords=search_content, region="wt-wt", safesearch="off", timelimit="m", max_results=10)
+                    urls = [news["href"] for news in news_results]
+
+                    if urls:
+                        # Take screenshot of the first 3 webpages
+                        for url_idx, url in enumerate(urls[:3]):
+                            try:
+                                # Create and run async screenshot capture
+                                import asyncio
+
+                                async def capture(screenshot_path):
+                                    browser_config = BrowserConfig(viewport_width=768, viewport_height=2048)
+                                    async with AsyncWebCrawler(config=browser_config) as crawler:
+                                        result = await crawler.arun(url=url, screenshot=True, screenshot_wait_for=2.0, simulate_user=True, magic=True, cache_mode=CacheMode.BYPASS)
+
+                                    if result.screenshot is None:
+                                        return
+
+                                    with open(screenshot_path, "wb") as f:
+                                        f.write(base64.b64decode(result.screenshot))
+
+                                    img = Image.open(screenshot_path)
+                                    if img.height > 2048:
+                                        img = img.crop((0, 0, img.width, 2048))
+                                        img.save(screenshot_path)
+
+                                    # Convert PIL Image to base64
+                                    buffered = BytesIO()
+                                    img.save(buffered, format="PNG")
+                                    img_str = base64.b64encode(buffered.getvalue()).decode()
+
+                                    search_image_contents.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_str}"}})
+
+                                screenshot_path = search_dir / f"search_result_{url_idx}.png"
+
+                                asyncio.run(capture(screenshot_path))
+                            except Exception as e:
+                                print(f"Error occurred: {e}")
+                                continue
 
                 final_response = (
                     client.beta.chat.completions.parse(
                         model=self.model_version,
                         messages=[
                             {
                                 "role": "user",
-                                "content": [{"type": "text", "text": FINAL_PROMPT.format(question=contexts, response=str(simple_response), review=review, requery=requery)}]
-                                + [{"type": "text", "text": "Here is the image for the query."}]
-                                + image_contents
-                                + [{"type": "text", "text": "Here are the search results."}]
+                                "content": [{"type": "text", "text": FINAL_PROMPT.format(question=contexts, response=str(simple_response), review=review, requery=requery)}] + [{"type": "text", "text": "Here are the search results."}]
                                 if search_image_contents
-                                else [] + search_image_contents + [{"type": "text", "text": f'In "result", you need to directly answer the question as concise as possible: {contexts}'}],
+                                else []
+                                + search_image_contents
+                                + [{"type": "text", "text": f'In "steps", you need to answer the question step by step. In "final_result", you need to directly answer the question as concise as possible, only one simple phrase.'}]
+                                + [{"type": "text", "text": "[Question]\n\n" + contexts}]
+                                + image_contents,
                             },
                         ],
                         response_format=FinalResponse,
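
Note (not part of the commit): the new need_search branch above queries DuckDuckGo before capturing any screenshots. A minimal sketch of that search step, assuming the duckduckgo_search package; the query string is a placeholder:

    from duckduckgo_search import DDGS

    ddgs = DDGS(timeout=50)
    results = ddgs.text(keywords="MMMU benchmark validation split", region="wt-wt", safesearch="off", timelimit="m", max_results=10)
    urls = [r["href"] for r in results]  # each result dict carries "title", "href", "body"
    print(urls[:3])  # only the first three URLs are screenshotted above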
@@ -361,11 +370,11 @@ def generate_until(self, requests) -> List[str]:
                     .message.parsed
                 )
 
-                res.append(final_response.result)
+                res.append(final_response.final_result)
 
                 if self.continual_mode is True:  # Cache the response
                     doc_uuid = f"{task}___{split}___{doc_id}"
-                    self.response_cache[doc_uuid] = final_response.result
+                    self.response_cache[doc_uuid] = final_response.final_result
                     with open(self.response_persistent_file, "w") as f:
                         json.dump(self.response_cache, f)
 
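Note (not part of the commit): the commit swaps the selenium scrolling-screenshot flow for crawl4ai's async crawler. A minimal standalone version of the capture() helper in the diff above, assuming crawl4ai with its Playwright browsers installed; the URL and output path are placeholders:

    import asyncio
    import base64

    from crawl4ai import AsyncWebCrawler, BrowserConfig, CacheMode

    async def capture_page(url: str, out_path: str) -> bool:
        browser_config = BrowserConfig(viewport_width=768, viewport_height=2048)
        async with AsyncWebCrawler(config=browser_config) as crawler:
            # screenshot=True puts a base64-encoded PNG on result.screenshot
            result = await crawler.arun(url=url, screenshot=True, screenshot_wait_for=2.0, simulate_user=True, magic=True, cache_mode=CacheMode.BYPASS)
        if result.screenshot is None:  # page loaded but no screenshot captured
            return False
        with open(out_path, "wb") as f:
            f.write(base64.b64decode(result.screenshot))
        return True

    if __name__ == "__main__":
        asyncio.run(capture_page("https://example.com", "page.png"))
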
lmms_eval/tasks/mmmu_random_100/_default_template_yaml (3 additions)
@@ -0,0 +1,3 @@
+metadata:
+  version: 0.0
+  interleaved_format: false
Binary file added lmms_eval/tasks/mmmu_random_100/arial.ttf
lmms_eval/tasks/mmmu_random_100/mmmu_val.yaml (16 additions)
@@ -0,0 +1,16 @@
+dataset_path: lmms-lab/MMMU_random_100
+task: "mmmu_val_random_100"
+test_split: validation
+output_type: generate_until
+doc_to_visual: !function utils.mmmu_doc_to_visual
+doc_to_text: !function utils.mmmu_doc_to_text
+doc_to_target: "answer"
+# The return value of process_results will be used by metrics
+process_results: !function utils.mmmu_process_results
+
+metric_list:
+  - metric: mmmu_acc
+    aggregation: !function utils.mmmu_aggregate_results
+    higher_is_better: true
+
+include: _default_template_yaml
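
Note (not part of the commit): the !function tags above reference Python callables in the task's utils module, which lmms-eval resolves with a custom YAML constructor. A minimal sketch of that mechanism with PyYAML; the real loader also merges the include: file and metadata:

    import importlib
    import yaml

    def _function_constructor(loader, node):
        # "!function utils.mmmu_doc_to_text" -> the utils.mmmu_doc_to_text callable
        path = loader.construct_scalar(node)
        module_name, func_name = path.rsplit(".", 1)
        # assumes the task directory is on sys.path so "utils" is importable
        return getattr(importlib.import_module(module_name), func_name)

    yaml.SafeLoader.add_constructor("!function", _function_constructor)

    with open("lmms_eval/tasks/mmmu_random_100/mmmu_val.yaml") as f:
        config = yaml.load(f, Loader=yaml.SafeLoader)

    print(config["task"])  # "mmmu_val_random_100"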