Added VQA as an evaluator. #52

Open
wants to merge 16 commits into dev
47 changes: 47 additions & 0 deletions test/evaluators.py
@@ -9,13 +9,17 @@
import os

from test.test_utils import compress_png
from test.test_utils import evaluate_fuzzy_match
from test.test_utils import evaluate_must_include
from test.test_utils import evaluate_ua_match
from test.test_utils import list_items_in_folder
from typing import Any

from ae.utils.logger import logger
from playwright.sync_api import CDPSession
from playwright.sync_api import Page
from termcolor import colored

from .validation_agent.validator import validate_task_vqa

class Evaluator:
"""Base class for evaluation strategies.
@@ -396,6 +400,46 @@ async def __call__(
return {"score": score, "reason": reason} # type: ignore


class VQAEvaluator(Evaluator):
    async def __call__(
        self,
        task_config: dict[str, Any],
        page: Page,
        client: CDPSession,
        answer: str
    ) -> dict[str, float]:
        """Evaluates the current task using a VQA model.

        Parameters:
            task_config (dict[str, Any]): The task configuration containing evaluation criteria.
            page (Page): The Playwright page object for the current webpage.
            client (CDPSession): The Chrome DevTools Protocol session object. Not used in this evaluator.
            answer (str): Not used in this evaluator.

        Returns:
            dict[str, float]: {"score": 0.0} for failure, {"score": 1.0} if the VQA model judges the task complete.
        """
        task_id = task_config["task_id"]
        task = task_config["intent"]
        state_seq: list[Any] = []

        # Locate the screenshots for the given task. Picking the most recently
        # modified log folder can grab the wrong run; TODO: select the folder deterministically.
        test_folder = list_items_in_folder(f"{os.getcwd()}/test/logs/")[-1]
        path_to_screenshots = f"{os.getcwd()}/test/logs/{test_folder}/logs_for_task_{task_id}/snapshots"
        screenshot_names = list_items_in_folder(path_to_screenshots)  # type: ignore

        # Compress each screenshot in place and record its path in the state sequence.
        for screenshot_name in screenshot_names:
            screenshot_path = f"{path_to_screenshots}/{screenshot_name}"
            compress_png(screenshot_path)
            state_seq.append({"id": task_id, "path_to_screenshot": screenshot_path})

        # Calculate the VQA score over the screenshot sequence.
        score_dict = validate_task_vqa(state_seq, task)  # type: ignore
        score = score_dict["pred_task_completed"]

        logger.info(f"VQA score is {score}")
        return {"score": score}

def evaluator_router(task_config: dict[str, Any]) -> EvaluatorComb:
"""Creates and configures a composite evaluator based on the evaluation types specified in the configuration file.

@@ -425,6 +469,9 @@ def evaluator_router(task_config: dict[str, Any]) -> EvaluatorComb:
case "manual":
logger.info("Adding manual evaluator")
evaluators.append(ManualContentEvaluator())
case "vqa":
logger.info("Adding vqa evaluator")
evaluators.append(VQAEvaluator())
case _:
raise ValueError(f"eval_type {eval_type} is not supported")
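
With the new "vqa" case, a task opts into this evaluator through its configured eval types. A sketch of what such a task entry might look like (only task_id, intent, and the "vqa" eval type appear in this diff; the surrounding schema, including the "eval" block and "eval_types" key, is an assumption):

task_config = {
    "task_id": 42,  # hypothetical id; used to locate logs_for_task_<task_id>/snapshots
    "intent": "Add a laptop to the shopping cart",  # passed to the VQA model as the task
    "eval": {"eval_types": ["vqa"]},  # assumed key names; routes to VQAEvaluator
}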

34 changes: 34 additions & 0 deletions test/test_utils.py
@@ -4,6 +4,7 @@
from datetime import datetime
from pathlib import Path
from typing import Any

from PIL import Image
from dotenv import load_dotenv
from nltk.tokenize import word_tokenize # type: ignore
@@ -261,3 +262,36 @@ def get_formatted_current_timestamp(format: str = "%Y-%m-%d %H:%M:%S") -> str:
# Format the timestamp as a human-readable string
timestamp_str = current_time.strftime(format)
return timestamp_str

def list_items_in_folder(path: str) -> list[str]:
    """Lists the items in a folder, sorted by modification time (oldest first).

    Returns an empty list if the path is missing, not a directory, or inaccessible,
    so callers can safely iterate or index the result.
    """
    try:
        items = os.listdir(path)
        # Sort by modification time so the most recent item is last.
        items_with_mtime = [(item, os.path.getmtime(os.path.join(path, item))) for item in items]
        items_with_mtime.sort(key=lambda x: x[1])
        return [item for item, _mtime in items_with_mtime]
    except FileNotFoundError:
        print(f"The path {path} does not exist.")
    except NotADirectoryError:
        print(f"The path {path} is not a directory.")
    except PermissionError:
        print(f"Permission denied to access {path}.")
    return []

def compress_png(file_path: str, max_size_mb: float = 20, reduce_factor: float = 0.9) -> bool:
    """Repeatedly downscales a PNG in place until it is at most max_size_mb megabytes.

    Each pass shrinks both dimensions by reduce_factor and re-saves the file.
    Returns True if the file ends up within the size limit, False on error.
    """
    try:
        file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
        while file_size_mb > max_size_mb:
            print(f"Compressing {file_path} (Initial Size: {file_size_mb:.2f} MB)")
            with Image.open(file_path) as img:
                width, height = img.size
                new_width = int(width * reduce_factor)
                new_height = int(height * reduce_factor)
                img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
                img.save(file_path, optimize=True)
            file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
            print(f"Resized to: {new_width}x{new_height}, Size: {file_size_mb:.2f} MB")
        print(f"Final Size of {file_path}: {file_size_mb:.2f} MB")
        return file_size_mb <= max_size_mb
    except Exception as e:
        print(f"Error compressing {file_path}: {e}")
        return False
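
A size cap like the 20 MB default here typically mirrors an upload limit on the model or API consuming the screenshots. A quick usage sketch (the path is illustrative):

screenshot = "test/logs/run_1/logs_for_task_42/snapshots/step_3.png"  # hypothetical path
if compress_png(screenshot):  # shrinks the file in place until it fits under 20 MB
    print("screenshot is within the size limit")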