-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Improve weblinx towards single turn loading (#32)
* Add tests for processing prompt * Add option to use find_turns_with_instructor_chat without using replay's method * Add candidate scores for unit testing * Update `weblinx.processing.dom` to be simpler, by only keeping the list of uids to keep. * Add new test for building prompt records * Add function for converting element dict to str dmr * Add new functionality to llama modeling code
- Loading branch information
Showing
7 changed files
with
2,213 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import sys | ||
import unittest | ||
|
||
import weblinx as wl | ||
from weblinx.processing import load_candidate_elements | ||
|
||
# import llama's processing code to test | ||
from modeling.llama.processing import ( | ||
format_candidates, | ||
format_utterances, | ||
format_utterances_truncated, | ||
get_speaker, | ||
multi_attempt_format_prev_turns_truncated, | ||
) | ||
|
||
|
||
# if needed, run this function to get candidates: | ||
def create_candidates_unittest_jsonl(
    demo_name="aaabtsd",
    candidate_path="wl_data/candidates/test_geo.jsonl",
    save_path="tests/demonstrations/candidates_unittest.jsonl",
):
    """
    Write the candidate elements of a single demonstration to a JSONL file.

    This is a manual helper, not a test: run it only when the fixture read by
    the unit tests has to be (re)generated.

    Args:
        demo_name: Only candidate elements whose "demo_name" field equals
            this value are written.
        candidate_path: Path of the candidates file passed to
            `load_candidate_elements`.
        save_path: Destination file; it is overwritten, one JSON object per
            line.
    """
    import json
    from weblinx.processing import load_candidate_elements

    # group_keys=None is passed so that the result can be iterated element by
    # element and filtered on each element's "demo_name" field.
    candidate_elements = load_candidate_elements(candidate_path, group_keys=None)
    filt_elems = [e for e in candidate_elements if e["demo_name"] == demo_name]

    # Explicit encoding keeps the fixture identical regardless of the
    # platform's default locale.
    with open(save_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(elem) + "\n" for elem in filt_elems)
|
||
|
||
class TestBuildPromptRecords(unittest.TestCase):
    """Tests for the llama prompt-formatting helpers, run on a stored demo."""

    def setUp(self):
        """Load the demonstration and its candidate elements from disk."""
        demo_dir = "./tests/demonstrations"
        # fixture stored as tests/demonstrations/candidates_unittest.jsonl
        candidates_file = "tests/demonstrations/candidates_unittest.jsonl"

        self.demo = wl.Demonstration("aaabtsd", base_dir=demo_dir)
        self.candidates = load_candidate_elements(candidates_file)

    def test_format_candidates(self):
        """
        Intended to check that format_candidates returns the expected string
        representation of a list of candidates, including each candidate's
        index and intent.

        NOTE(review): the body holds no assertions yet, so this test passes
        without exercising format_candidates.
        """
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import unittest | ||
from weblinx.processing.prompt import find_turns_with_instructor_chat | ||
import weblinx as wl | ||
|
||
|
||
class TestProcessingPrompt(unittest.TestCase):
    """Tests for weblinx.processing.prompt, run on a stored demonstration."""

    def setUp(self):
        """Load the demonstration used by every test in this class."""
        self.demo = wl.Demonstration("aaabtsd", base_dir="./tests/demonstrations")

    def test_find_turns_with_instructor_chat(self):
        """
        Check find_turns_with_instructor_chat against a reference computed
        directly from the replay: the turns whose speaker is "instructor" and
        whose index lies before the window of previous turns that ends at the
        selected turn.
        """
        num_prev_turns = 5
        replay = wl.Replay.from_demonstration(self.demo)
        turn = replay[15]

        result = find_turns_with_instructor_chat(
            replay, turn, speaker="instructor", num_prev_turns=num_prev_turns
        )

        # In this demo, we checked that there are 3 turns with instructor chat
        # so it should return a list of 3 turns
        self.assertEqual(len(result), 3)

        # Reference: select the same turns by iterating the replay directly;
        # the outcome must match the function's output.
        start_index = max(0, turn.index - num_prev_turns)
        result_filter = [
            candidate
            for candidate in replay
            if candidate.get("speaker") == "instructor"
            and candidate.index < start_index
        ]
        self.assertEqual(len(result_filter), 3)
        self.assertEqual(result, result_filter)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters