Skip to content

Commit

Permalink
Improve weblinx towards single turn loading (#32)
Browse files Browse the repository at this point in the history
* Add tests for processing prompt

* add option to use find_turns_with_instructor_chat without using replay's method

* Add candidate scores for unittesting

* Update `weblinx.processing.dom` to be simpler, by only keeping the list of uids to keep.

* Add new test for building prompt records

* add function for converting element dict to str dmr

* add new functionality to llama modeling code
  • Loading branch information
xhluca authored Apr 27, 2024
1 parent 548e512 commit 026be27
Show file tree
Hide file tree
Showing 7 changed files with 2,213 additions and 20 deletions.
22 changes: 19 additions & 3 deletions modeling/llama/processing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
from functools import partial
from typing import Callable

Expand Down Expand Up @@ -139,6 +140,8 @@ def build_prompt_records_for_llama_truncated(
max_candidates_tokens=65 * 10,
add_unused_len_to_cands=True,
allow_iterative_reduction=False,
use_tokenizer_template=False,
template_tokenizer=None,
parser=None,
):
"""
Expand Down Expand Up @@ -221,9 +224,22 @@ def build_prompt_records_for_llama_truncated(
# Add the unused length to the candidates
num_html_tokens = len(tokenizer.tokenize(html))
num_utter_tokens = len(tokenizer.tokenize(utterance_context))
num_prev_turns_tokens = len(
tokenizer.tokenize(" ".join(prev_turns_text_list))
)
if use_tokenizer_template:
if template_tokenizer is None:
raise ValueError(
"template_tokenizer must be provided when use_tokenizer_template is True."
)
prev_turns_merged_copy = deepcopy(prev_turns_merged)
if prev_turns_merged[0]['role'] == 'assistant':
# insert a dummy user turn
prev_turns_merged_copy.insert(0, {'role': 'user', 'content': ''})
num_prev_turns_tokens = len(template_tokenizer.apply_chat_template(
[{'role': 'system', 'content': ''}, *prev_turns_merged_copy], tokenize=True
))
else:
num_prev_turns_tokens = len(
tokenizer.tokenize(" ".join(prev_turns_text_list))
)
remain_html_tokens = max_html_tokens - num_html_tokens
remain_utter_tokens = max_utterance_tokens - num_utter_tokens
remain_prev_turns_tokens = max_prev_turns_tokens - num_prev_turns_tokens
Expand Down
2,055 changes: 2,055 additions & 0 deletions tests/demonstrations/candidates_unittest.jsonl

Large diffs are not rendered by default.

51 changes: 51 additions & 0 deletions tests/test_build_prompt_records.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import sys
import unittest

import weblinx as wl
from weblinx.processing import load_candidate_elements

# import llama's processing code to test
from modeling.llama.processing import (
format_candidates,
format_utterances,
format_utterances_truncated,
get_speaker,
multi_attempt_format_prev_turns_truncated,
)


# if needed, run this function to get candidates:
def create_candidates_unittest_jsonl():
    """
    Regenerate ``tests/demonstrations/candidates_unittest.jsonl``.

    Loads all candidate elements from a full candidates file, keeps only the
    rows belonging to the unittest demonstration (``demo_name``), and writes
    them back out as one JSON object per line. Run manually when the fixture
    needs to be rebuilt; the paths below may need to be adjusted first.
    """
    import json

    demo_name = "aaabtsd"  # change if needed
    candidate_path = "wl_data/candidates/test_geo.jsonl"  # change if needed
    save_path = "tests/demonstrations/candidates_unittest.jsonl"
    # load_candidate_elements is already imported at module level.
    # group_keys=None returns a flat list of element dicts rather than a
    # grouped mapping, which is what we want for line-by-line filtering.
    candidate_elements = load_candidate_elements(candidate_path, group_keys=None)
    filt_elems = [e for e in candidate_elements if e["demo_name"] == demo_name]

    with open(save_path, "w") as f:
        for elem in filt_elems:
            f.write(json.dumps(elem) + "\n")


class TestBuildPromptRecords(unittest.TestCase):
    """Tests for the llama prompt-record building helpers."""

    def setUp(self):
        """Load the shared demonstration and its pre-computed candidates."""
        base_dir = "./tests/demonstrations"
        self.demo = wl.Demonstration("aaabtsd", base_dir=base_dir)
        # Candidate elements fixture generated by
        # create_candidates_unittest_jsonl() above.
        candidates_path = "tests/demonstrations/candidates_unittest.jsonl"
        self.candidates = load_candidate_elements(candidates_path)

    def test_format_candidates(self):
        """
        Tests the format_candidates function to ensure it returns the expected
        string representation of a list of candidates, including the candidate
        index and the candidate's intent.
        """
        # NOTE(review): this test body has not been written yet, so the test
        # currently passes vacuously — fill in real assertions.


if __name__ == "__main__":
    unittest.main()
36 changes: 36 additions & 0 deletions tests/test_processing_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import unittest
from weblinx.processing.prompt import find_turns_with_instructor_chat
import weblinx as wl


class TestProcessingPrompt(unittest.TestCase):
    """Tests for helpers in ``weblinx.processing.prompt``."""

    def setUp(self):
        """Load the demonstration used by every test in this class."""
        self.demo = wl.Demonstration("aaabtsd", base_dir="./tests/demonstrations")

    def test_find_turns_with_instructor_chat(self):
        """
        Test the find_turns_with_instructor_chat function to ensure it correctly
        filters out turns that contain instructor chat. It checks that the output
        list contains only turns with instructor chat.
        """
        replay = wl.Replay.from_demonstration(self.demo)
        turn = replay[15]

        result = find_turns_with_instructor_chat(
            replay, turn, speaker="instructor", num_prev_turns=5
        )

        # In this demo, we checked that there are 3 turns with instructor chat
        # so it should return a list of 3 turns
        self.assertEqual(len(result), 3)

        # Cross-check against a hand-rolled selection that applies the same
        # criteria: instructor turns strictly before the previous-turns window.
        start_index = max(0, turn.index - 5)
        result_filter = [
            t
            for t in replay
            if t.get("speaker") == "instructor" and t.index < start_index
        ]
        self.assertEqual(len(result_filter), 3)
        self.assertEqual(result, result_filter)
39 changes: 25 additions & 14 deletions weblinx/processing/dom.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def prune_tree(


def clean_and_prune_tree(
dom_tree, cands_turn, max_depth=1, max_children=5, max_sibling=2
dom_tree, cands_turn=None, candidate_uids=None, max_depth=1, max_children=5, max_sibling=2,
):
"""
This function will clean and prune the tree based on the candidates in the cands_turn. This
Expand All @@ -341,7 +341,12 @@ def clean_and_prune_tree(
The tree to clean and prune.
cands_turn : list
The list of candidates for the turn.
The list of candidates for the turn. If this is None, we are expected to pass in the
`candidate_uids`; otherwise an error will be raised.
candidate_uids : list, optional
The list of candidate uids to keep. If this is None, we are expected to pass in the
`cands_turn`; otherwise an error will be raised.
max_depth : int, optional
The maximum depth to prune the tree. Defaults to 1.
Expand All @@ -360,23 +365,29 @@ def clean_and_prune_tree(
Raises
------
ValueError
If cands_turn is None.
If cands_turn is None and candidate_uids is None. Alternatively, if both
cands_turn and candidate_uids are passed in, an error will be raised.
"""
if cands_turn is None:
if cands_turn is None and candidate_uids is None:
raise ValueError(
"cands_turn cannot be None. The dom_tree cannot be pruned this way."
"cands_turn or candidate_uids must be provided. The dom_tree cannot be pruned this way."
)

if cands_turn is not None:
candidate_uids = [cand["uid"] for cand in cands_turn]
dom_tree = prune_tree(
dom_tree,
set(candidate_uids),
max_depth=max_depth,
max_children=max_children,
max_sibling=max_sibling,
if cands_turn is not None and candidate_uids is not None:
raise ValueError(
"cands_turn and candidate_uids cannot both be provided. Please provide only one."
)
remove_uid_when_not_candidate(dom_tree, candidate_uids=candidate_uids)
if candidate_uids is None:
candidate_uids = [cand["uid"] for cand in cands_turn]

dom_tree = prune_tree(
dom_tree,
set(candidate_uids),
max_depth=max_depth,
max_children=max_children,
max_sibling=max_sibling,
)
remove_uid_when_not_candidate(dom_tree, candidate_uids=candidate_uids)

remove_html_comments(dom_tree)
sanitize_elem_attributes(dom_tree)
Expand Down
11 changes: 8 additions & 3 deletions weblinx/processing/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,9 +353,13 @@ def find_turns_with_instructor_chat(
This output of this function should be used by format_utterances to display the utterances.
"""
start_index = max(0, turn.index - num_prev_turns)
instructor_chat_turns = replay.filter_turns(
lambda_func = (
lambda turn: turn.get("speaker") == speaker and turn.index < start_index
)
if isinstance(replay, list):
instructor_chat_turns = list(filter(lambda_func, replay))
else:
instructor_chat_turns = replay.filter_turns(lambda_func)
return instructor_chat_turns


Expand Down Expand Up @@ -638,7 +642,7 @@ def select_turns_and_candidates_for_prompts(
remove_turns_without_elements : bool, optional
Whether to remove turns that do not have elements. Defaults to True.
Returns
-------
list
Expand Down Expand Up @@ -672,7 +676,8 @@ def select_turns_and_candidates_for_prompts(
turns = filter_turns(
turns,
lambda turn: not (
turn.intent in ("click", "change", "textinput", "submit") and turn.element is None
turn.intent in ("click", "change", "textinput", "submit")
and turn.element is None
),
)

Expand Down
19 changes: 19 additions & 0 deletions weblinx/processing/truncation.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,25 @@ def convert_elem_dict_to_str(elem_dict: dict, remove_empty=False):

return element_str

def convert_elem_dict_to_str_dmr(elem_dict: dict) -> str:
    """
    Convert an element dictionary to a multi-line string in the DMR format.

    The known keys (``tag``, ``xpath``, ``text``, ``bbox``, ``attributes``,
    ``children``) are rendered first, in that fixed order, each on its own
    line as ``[[key]] value``. Any remaining keys are appended afterwards in
    the dict's insertion order.

    Parameters
    ----------
    elem_dict : dict
        Element dictionary. Must contain all six known keys, otherwise a
        ``KeyError`` is raised. The input is not modified.

    Returns
    -------
    str
        The formatted element string (no trailing newline).
    """
    # A shallow copy is sufficient: we only pop top-level keys and never
    # mutate the values, so the original deepcopy was wasted work.
    remaining = dict(elem_dict)

    known_keys = ("tag", "xpath", "text", "bbox", "attributes", "children")
    lines = [f"[[{key}]] {remaining.pop(key)}" for key in known_keys]

    # For other keys, we just add them to the end.
    lines.extend(f"[[{key}]] {value}" for key, value in remaining.items())

    return "\n".join(lines)

def truncate_cands_turn(
cands_turn: list,
Expand Down

0 comments on commit 026be27

Please sign in to comment.