Added VQA as an evaluator. #52

Open
wants to merge 16 commits into base: dev
Changes from 8 of 16 commits
50 changes: 50 additions & 0 deletions test/evaluators.py
@@ -9,13 +9,19 @@
from test.test_utils import evaluate_fuzzy_match
from test.test_utils import evaluate_must_include
from test.test_utils import evaluate_ua_match
from test.test_utils import list_items_in_folder
from test.test_utils import compress_png
from typing import Any

from ae.utils.logger import logger
from playwright.sync_api import CDPSession
from playwright.sync_api import Page
from termcolor import colored

import os
from .validation_agent.validator import validate_task_vqa
from ae.config import PROJECT_ROOT, PROJECT_TEST_ROOT
TEST_LOGS = os.path.join(PROJECT_TEST_ROOT, 'logs')

class Evaluator:
    """Base class for evaluation strategies.
@@ -396,6 +402,47 @@ async def __call__(
        return {"score": score, "reason": reason}  # type: ignore


class VQAEvaluator(Evaluator):
    async def __call__(
        self,
        task_config: dict[str, Any],
        page: Page,
        client: CDPSession,
        answer: str
    ) -> float:
        """Evaluates the current task using a VQA model.

        Parameters:
            task_config (dict[str, Any]): The task configuration containing evaluation criteria.
            page (Page): The Playwright page object for the current webpage.
            client (CDPSession): The Chrome DevTools Protocol session object (not used by this evaluator).
            answer (str): Not used by this evaluator.

        Returns:
            dict[str, Any]: A dictionary with "score" (1.0 if the VQA model judges the task complete,
            0.0 otherwise) and "reason" (the model's rationale).
        """
        task_id = task_config["task_id"]
        task = task_config["intent"]
        state_seq: list[Any] = []

        # Get the path to the screenshots for the given task
        test_folder = list_items_in_folder(TEST_LOGS)[-1]  # Get the most recent log folder
        path_to_screenshots = f"{TEST_LOGS}/{test_folder}/logs_for_task_{task_id}/snapshots"
        screenshot_names = list_items_in_folder(path_to_screenshots)  # type: ignore

        # Compress the screenshots and build the state sequence for the VQA model
        for screenshot_name in screenshot_names:
            screenshot_path = f"{path_to_screenshots}/{screenshot_name}"
            compress_png(screenshot_path)
            state_seq.append({"id": task_id, "path_to_screenshot": screenshot_path})

        # Calculate the VQA score
        score_dict = validate_task_vqa(state_seq, task)  # type: ignore
        score = score_dict["pred_task_completed"]
        reason = score_dict["pred_rationale"]

        print(f"VQA score is {score} because {reason}\n")
        return {"score": score, "reason": reason}  # type: ignore

def evaluator_router(task_config: dict[str, Any]) -> EvaluatorComb:
    """Creates and configures a composite evaluator based on the evaluation types specified in the configuration file.

@@ -425,6 +472,9 @@ def evaluator_router(task_config: dict[str, Any]) -> EvaluatorComb:
            case "manual":
                logger.info("Adding manual evaluator")
                evaluators.append(ManualContentEvaluator())
            case "vqa":
                logger.info("Adding vqa evaluator")
                evaluators.append(VQAEvaluator())
            case _:
                raise ValueError(f"eval_type {eval_type} is not supported")

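For context, this is roughly how the new branch is exercised from an async test run. The eval/eval_types key names and the EvaluatorComb call signature are assumed to mirror the existing evaluators and are illustrative if the actual schema differs:

task_config = {
    "task_id": 42,                                 # illustrative task id
    "intent": "Add a wireless mouse to the cart",  # illustrative task description
    "eval": {"eval_types": ["vqa"]},               # assumed key names for selecting evaluators
}

evaluator = evaluator_router(task_config)  # composite evaluator containing VQAEvaluator
result = await evaluator(task_config, page, client, answer="")  # inside an async test function
print(result["score"], result["reason"])
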
56 changes: 56 additions & 0 deletions test/test_utils.py
@@ -4,10 +4,12 @@
from datetime import datetime
from pathlib import Path
from typing import Any
from PIL import Image

from dotenv import load_dotenv
from nltk.tokenize import word_tokenize # type: ignore
from openai import OpenAI

load_dotenv()
client = OpenAI()
@@ -261,3 +263,57 @@ def get_formatted_current_timestamp(format: str = "%Y-%m-%d %H:%M:%S") -> str:
    # Format the timestamp as a human-readable string
    timestamp_str = current_time.strftime(format)
    return timestamp_str

def list_items_in_folder(path: str) -> list[str] | None:
    '''Returns the names of all items inside a given directory, sorted by modification time (oldest first).

    Parameters:
        path (str): Path to a directory.

    Returns:
        list[str] | None: Names of all items found in the given directory, or None if the directory cannot be read.
    '''
    try:
        items = os.listdir(path)
        # Sort items by modification time so the most recent entry is last
        items_with_mtime = [(item, os.path.getmtime(os.path.join(path, item))) for item in items]
        items_with_mtime.sort(key=lambda x: x[1])
        sorted_items = [item for item, mtime in items_with_mtime]
        return sorted_items
    except FileNotFoundError:
        print(f"The path {path} does not exist.")
        return None
    except NotADirectoryError:
        print(f"The path {path} is not a directory.")
        return None
    except PermissionError:
        print(f"Permission denied to access {path}.")
        return None

def compress_png(file_path: str, max_size_mb: float = 20, reduce_factor: float = 0.9) -> bool:
    '''Compresses a PNG file by repeatedly downscaling it until it is below a target size.

    Parameters:
        file_path (str): Path to a PNG file.
        max_size_mb (float): The maximum size (in MB) allowed after compression.
        reduce_factor (float): Factor by which the width and height are reduced on each iteration.

    Returns:
        bool: True if the PNG is at or below max_size_mb after compression. False otherwise.
    '''
    try:
        file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
        while file_size_mb > max_size_mb:
            print(f"Compressing {file_path} (Initial Size: {file_size_mb:.2f} MB)")
            with Image.open(file_path) as img:
                # Downscale the image and overwrite the original file
                width, height = img.size
                new_width = int(width * reduce_factor)
                new_height = int(height * reduce_factor)
                img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
                img.save(file_path, optimize=True)
            file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
            print(f"Resized to: {new_width}x{new_height}, Size: {file_size_mb:.2f} MB")
        print(f"Final Size of {file_path}: {file_size_mb:.2f} MB")
        return file_size_mb <= max_size_mb
    except Exception as e:
        print(f"Error compressing {file_path}: {e}")
        return False
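
Taken together, the VQA evaluator consumes these two helpers roughly as follows; the snapshot directory path below is illustrative:

import os

from test.test_utils import compress_png, list_items_in_folder

snapshots_dir = "test/logs/some_run/logs_for_task_1/snapshots"  # illustrative path
for name in list_items_in_folder(snapshots_dir) or []:
    # Keep each screenshot under the vision API's payload limit before it is sent
    compress_png(os.path.join(snapshots_dir, name), max_size_mb=20)
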
Empty file.
73 changes: 73 additions & 0 deletions test/validation_agent/prompts.py
@@ -0,0 +1,73 @@
prompt__validate_action: str = lambda task_action: f"""# Task
You are an RPA bot that navigates digital UIs like a human. Your job is to validate that a certain action was successfully taken.

# Action
The action that was supposed to be taken was: "{task_action}"

# Question

The first screenshot shows the digital UI BEFORE the action was supposedly taken.
The second screenshot shows the digital UI AFTER the action was supposedly taken.

Given the change between the screenshots, was the action successfully taken? Be lenient and assume that the action was taken if the UI is "close enough" to the expected UI.

Answer in the JSON format:
{{
"rationale": <rationale>,
"was_taken": <true/false>
}}

Answer:"""

prompt__validate_task__intro: str = lambda task_descrip: f"""# Task
Your job is to decide whether the workflow was successfully completed, as depicted by the following sequence of screenshots.

# Workflow

The workflow is: "{task_descrip}"

# User Interface

The workflow was executed within the web application shown in the screenshots.

# Workflow Demonstration

You are given the following sequence of screenshots which were sourced from a demonstration of the workflow.
The screenshots are presented in chronological order.

Here are the screenshots of the workflow:"""

prompt__validate_task__close: str = lambda: f"""
# Instructions

Given what you observe in the previous sequence of screenshots, was the workflow successfully completed?
If the workflow is asking a question, consider it completed successfully if you could deduce the answer to the question by viewing the screenshots.
If the workflow was completed successfully, then set `was_completed` to `true`.

Provide your answer as a JSON dictionary with the following format:
{{
"rationale": <rationale>,
"was_completed": <true/false>
}}

Please write your JSON below:
"""

prompt__validate_VQA_task__close: str = lambda: f"""
# Instructions

Given what you observe in the previous sequence of screenshots, was the workflow successfully completed?
To determine this, derive a few visual questions from the task description that, upon answering, will help decide whether the workflow was successfully completed.
If the workflow is asking a question, consider it completed successfully if you could deduce the answer to the question by viewing the screenshots.
If the workflow was completed successfully, then set `was_completed` to `true`.
Also, provide the visual questions and their answers as part of the response.

Provide your answer as a JSON dictionary with the following format:
{{
"visual_questions": <list of visual questions and their answers>,
"rationale": <rationale>,
"was_completed": <true/false>
}}

Please write your JSON below:
"""
83 changes: 83 additions & 0 deletions test/validation_agent/utils.py
@@ -0,0 +1,83 @@
### Subset of helper functions from eclair-agents
import base64
import sys
import time
import traceback
from typing import Any, Dict, List, Tuple

import openai

SYSTEM_PROMPT: str = "You are a helpful assistant that automates digital workflows."

def encode_image(path_to_img: str) -> str:
    """Base64 encode an image"""
    with open(path_to_img, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


def load_screenshot_for_state(state: Dict[str, Any]) -> Tuple[str, str]:
    """Return the screenshot path and its base64 encoding for a given state"""
    path_to_screenshot: str = state["path_to_screenshot"]
    encoded_image: str = encode_image(path_to_screenshot)
    return path_to_screenshot, encoded_image


def fetch_openai_vision_completion(
    prompt: str, base64_images: List[str], **kwargs
) -> str:
    """Helper function to call OpenAI's vision-capable chat completions API with a text prompt and a list of base64-encoded images"""
    messages: List[Any] = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{img}"},
                }
                for img in base64_images
            ]
            + [{"type": "text", "text": prompt}],
        },
    ]
    return _fetch_openai_completion(messages, model="gpt-4-vision-preview", **kwargs)


def _fetch_openai_completion(messages: List[Any], model: str, **kwargs) -> str:
    """Helper function to call OpenAI's chat completions API. Retries on rate limit errors and exits on other errors"""
    client = openai.OpenAI()
    try:
        response = client.chat.completions.create(
            messages=[{"role": "system", "content": SYSTEM_PROMPT}] + messages,
            model=model,
            max_tokens=4096,
            **kwargs,
        )
    except openai.RateLimitError:
        print("Rate limit exceeded -- waiting 1 min before retrying")
        time.sleep(60)
        return _fetch_openai_completion(messages, model, **kwargs)
    except openai.APIError as e:
        traceback.print_exc()
        print(f"OpenAI API error: {e}")
        sys.exit(1)
    except Exception as e:
        traceback.print_exc()
        print(f"Unknown error: {e}")
        sys.exit(1)
    return response.choices[0].message.content


def build_prompt_sequence(state_seq: List[Any]) -> List[Dict[str, Any]]:
    """Convert a sequence of states into a sequence of image messages for the vision model"""
    prompt_sequence: List[Dict[str, Any]] = []
    for item in state_seq:
        path_to_screenshot, encoded_image = load_screenshot_for_state(item)
        prompt_sequence.append({
            "role": "user",
            "content": [{
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{encoded_image}"
                },
            }],
        })
    return prompt_sequence
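
The validator module that evaluators.py imports (test/validation_agent/validator.py) is not included in the 8 commits shown here; the only visible contract is that validate_task_vqa(state_seq, task) returns a dict with pred_task_completed and pred_rationale. Below is a minimal sketch of how the prompts and helpers above could be composed to satisfy that contract, assuming the eclair-agents pattern of one intro message, one image message per screenshot, and one closing instruction; the code-fence stripping and the key mapping are assumptions:

import json

from test.validation_agent.prompts import prompt__validate_task__intro, prompt__validate_VQA_task__close
from test.validation_agent.utils import _fetch_openai_completion, build_prompt_sequence


def validate_task_vqa(state_seq, task):
    # One text message introducing the workflow, one image message per screenshot,
    # and one closing message asking for the JSON verdict.
    messages = [{"role": "user", "content": [{"type": "text", "text": prompt__validate_task__intro(task)}]}]
    messages += build_prompt_sequence(state_seq)
    messages += [{"role": "user", "content": [{"type": "text", "text": prompt__validate_VQA_task__close()}]}]

    raw = _fetch_openai_completion(messages, model="gpt-4-vision-preview")

    # The prompt asks for a JSON dictionary; strip optional markdown fences before parsing (assumed handling).
    pred = json.loads(raw.replace("```json", "").replace("```", "").strip())
    return {
        "pred_task_completed": 1.0 if pred["was_completed"] else 0.0,
        "pred_rationale": pred["rationale"],
    }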