Skip to content

Commit

Permalink
feat(abstractions): extra functionalities in dialogue manipulation
Browse files Browse the repository at this point in the history
  • Loading branch information
TianyiQ committed Dec 4, 2024
1 parent fd24631 commit 5ff60d9
Showing 1 changed file with 69 additions and 16 deletions.
85 changes: 69 additions & 16 deletions src/abstractions/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,10 +314,13 @@ def inv_map_key_fields_fn(sample_dict: Dict) -> Dict:
result.key_fields = self.key_fields.copy()
return result

def move_current_to_history(self):
def move_current_to_history(self, out_of_place: bool = False) -> "Data":
"""
Move the current dialogue turn in the prompt/question field and the response/predict field to the history field.
:param out_of_place: Whether to perform the operation out-of-place. If out_of_place is True, the original data will not be modified, and a new Data instance with an annotated name will be returned. Otherwise, the original data will be modified in-place, and the same Data instance will be returned.
:type out_of_place: bool = False
:return: The data after the operation.
:rtype: Data.
"""
Expand All @@ -339,68 +342,118 @@ def move_to_history_fn(sample_dict: Dict) -> Dict:

return sample_dict

return self.transform(move_to_history_fn, self.data_name, forced_rewrite=True, map_key_fields=True)
new_data_name = (self.data_name + "_moved") if out_of_place else self.data_name
return self.transform(move_to_history_fn, new_data_name, forced_rewrite=True, map_key_fields=True)

def switch_role_to_user(self, user_system_prompt: str = None, dialogue_starter: str = None):
def switch_role_to_user(self, user_system_prompt: Union[str, Iterable[str]] = None, dialogue_starter: Union[str, Iterable[str]] = None, out_of_place: bool = False) -> "Data":
"""
Switch the prompt/question field and the response/predict field, thereby shifting the dialogue turn from the assistant to the user.
:param user_system_prompt: The system prompt of the user role.
:type user_system_prompt: str = None
:param user_system_prompt: The system prompt of the user role. Can be a single string or an iterable of strings, where each string corresponds to the prompt for a different sample in the dataset. If None, a default prompt will be used.
:type user_system_prompt: Union[str, Iterable[str]] = None
:param dialogue_starter: Placeholder message for the "zeroth" dialogue turn by the assistant that prompts the user to start the conversation.
:type dialogue_starter: str = None
:param out_of_place: Whether to perform the operation out-of-place. If out_of_place is True, the original data will not be modified, and a new Data instance with an annotated name will be returned. Otherwise, the original data will be modified in-place, and the same Data instance will be returned.
:type out_of_place: bool = False
:return: The data after the operation.
:rtype: Data.
"""
if user_system_prompt is None:
user_system_prompt = "You are an assistant tasked with questioning the user, aka your partner. Ask informed questions to guide the conversation, follow up on the user's responses, and generally follow a natural conversation flow. Don't be too courteous; be concise."
elif isinstance(user_system_prompt, list):
user_system_prompt = iter(user_system_prompt)

if dialogue_starter is None:
dialogue_starter = "I am your partner. Please directly ask your first question."
dialogue_starter = "I am your partner. Please start the conversation."
elif isinstance(dialogue_starter, list):
dialogue_starter = iter(dialogue_starter)

moved_to_history = self.move_current_to_history()
moved_to_history = self.move_current_to_history(out_of_place)

def switch_role_to_user_fn(sample_dict: Dict) -> Dict:
assert not (sample_dict.get("instruction", "") or sample_dict.get("input", "") or sample_dict.get("output", "") or sample_dict.get("predict", ""))

current_user_system_prompt = user_system_prompt if isinstance(user_system_prompt, str) else next(user_system_prompt)
current_dialogue_starter = dialogue_starter if isinstance(dialogue_starter, str) else next(dialogue_starter)

all_histories = [h[i] for h in sample_dict.get("history", []) for i in range(2)]
all_histories = [dialogue_starter] + all_histories
all_histories = [current_dialogue_starter] + all_histories
assert len(all_histories) % 2 == 1
sample_dict["history"] = [[all_histories[i], all_histories[i + 1]] for i in range(0, len(all_histories)-1, 2)]
sample_dict["instruction"] = all_histories[-1]
sample_dict["system"] = user_system_prompt
sample_dict["system"] = current_user_system_prompt
return sample_dict

return moved_to_history.transform(switch_role_to_user_fn, self.data_name, forced_rewrite=True, map_key_fields=True)
new_data_name = (self.data_name + "_user") if out_of_place else self.data_name
return moved_to_history.transform(switch_role_to_user_fn, new_data_name, forced_rewrite=True, map_key_fields=True)

def switch_role_to_assistant(self, assistant_system_prompt: str = None):
def switch_role_to_assistant(self, assistant_system_prompt: Union[str, Iterable[str]] = None, out_of_place: bool = False) -> "Data":
"""
Switch the prompt/question field and the response/predict field, thereby shifting the dialogue turn from the user to the assistant.
:param assistant_system_prompt: The system prompt of the assistant role.
:type assistant_system_prompt: str = None
:param assistant_system_prompt: The system prompt of the assistant role. Can be a single string or an iterable of strings, where each string corresponds to the prompt for a different sample in the dataset. If None, a default prompt will be used.
:type assistant_system_prompt: Union[str, Iterable[str]] = None
:param out_of_place: Whether to perform the operation out-of-place. If out_of_place is True, the original data will not be modified, and a new Data instance with an annotated name will be returned. Otherwise, the original data will be modified in-place, and the same Data instance will be returned.
:type out_of_place: bool = False
:return: The data after the operation.
:rtype: Data.
"""
if assistant_system_prompt is None:
assistant_system_prompt = "Please answer the user's questions. Be concise and not overly courteous, but be informative and provide all necessary details."
elif isinstance(assistant_system_prompt, list):
assistant_system_prompt = iter(assistant_system_prompt)

moved_to_history = self.move_current_to_history()
moved_to_history = self.move_current_to_history(out_of_place)

def switch_role_to_assistant_fn(sample_dict: Dict) -> Dict:
assert not (sample_dict.get("instruction", "") or sample_dict.get("input", "") or sample_dict.get("output", "") or sample_dict.get("predict", ""))

current_assistant_system_prompt = assistant_system_prompt if isinstance(assistant_system_prompt, str) else next(assistant_system_prompt)

all_histories = [h[i] for h in sample_dict.get("history", []) for i in range(2)]
assert len(all_histories) % 2 == 0
sample_dict["history"] = [[all_histories[i], all_histories[i + 1]] for i in range(1, len(all_histories)-1, 2)]
sample_dict["instruction"] = all_histories[-1]
sample_dict["system"] = assistant_system_prompt
sample_dict["system"] = current_assistant_system_prompt
return sample_dict

return moved_to_history.transform(switch_role_to_assistant_fn, self.data_name, forced_rewrite=True, map_key_fields=True)
new_data_name = (self.data_name + "_assistant") if out_of_place else self.data_name
return moved_to_history.transform(switch_role_to_assistant_fn, new_data_name, forced_rewrite=True, map_key_fields=True)

def append_content(self, field_key: str, content: Union[str, Iterable[str]], out_of_place: bool = False, map_key_fields: bool = False) -> "Data":
"""
Append content to a specified field in the dataset.
:param field_key: The key of the field to append content to.
:type field_key: str
:param content: The content to append. Can be a single string or an iterable of strings, where each string corresponds to the content to append for a different sample in the dataset.
:type content: Union[str, Iterable[str]]
:param out_of_place: Whether to perform the operation out-of-place. If out_of_place is True, the original data will not be modified, and a new Data instance with an annotated name will be returned. Otherwise, the original data will be modified in-place, and the same Data instance will be returned.
:type out_of_place: bool = False
:param map_key_fields: Whether to map the key fields to the default key fields before appending content.
:type map_key_fields: bool = False
:return: The data after the operation.
:rtype: Data.
"""
if isinstance(content, list):
content = iter(content)

def append_content_fn(sample_dict: Dict) -> Dict:
current_content = content if isinstance(content, str) else next(content)
sample_dict[field_key] = sample_dict.get(field_key, "") + current_content
return sample_dict

new_data_name = (self.data_name + "_appended") if out_of_place else self.data_name
return self.transform(append_content_fn, new_data_name, forced_rewrite=True, map_key_fields=map_key_fields)

def manage_llama_factory_registration(
self, operation: Literal["add", "remove", "query"], forced_update: bool = True
Expand Down

0 comments on commit 5ff60d9

Please sign in to comment.