Skip to content

Commit

Permalink
feat: support history in fill QA template
Browse files Browse the repository at this point in the history
  • Loading branch information
TianyiQ committed Dec 19, 2024
1 parent 37c908d commit ed2676d
Showing 1 changed file with 22 additions and 4 deletions.
26 changes: 22 additions & 4 deletions src/abstractions/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,22 +697,26 @@ def dict_to_dialogue_list(


def fill_in_QA_template(
instruction: str,
instruction: str = "",
input: str = "",
suffix: str = "",
full_dict: dict = None,
model_repoid_or_path: Union[Literal["alpaca", "mistral", "llama3"], str] = "alpaca",
) -> str:
"""Provided with a task instruction and (optionally) supplementary input, fill them into a QA template and return the resulting prompt.
:param instruction: The task instruction.
:type instruction: str
:param instruction: The task instruction, defaults to "". Either this or full_dict must be provided.
:type instruction: str, optional
:param input: Supplementary input to the task, defaults to "".
:type input: str, optional
:param suffix: Suffix to add to the prompt, defaults to "".
:type suffix: str, optional
:param full_dict: The full dictionary containing the instruction and input, defaults to None. Either this or instruction must be provided. If this is provided, instruction, input, and suffix will be ignored.
:type full_dict: dict, optional
:param model_repoid_or_path: The model repo ID or path (e.g., "meta-llama/Meta-Llama-3-8B-Instruct"), or one of the special values "alpaca" or "mistral" or "llama3", defaults to "alpaca".
:type model_repoid_or_path: Union[Literal["alpaca", "mistral", "llama3"], str], optional
Expand All @@ -722,6 +726,18 @@ def fill_in_QA_template(

instruction = instruction.strip()
input = input.strip()

# Convert full_dict to instruction and input
if full_dict and model_repoid_or_path in ["alpaca", "mistral"]:
assert "history" not in full_dict, "History field not supported with alpaca/mistral template."
assert "system" not in full_dict, "System field not supported with alpaca/mistral template."

instruction = full_dict.get("instruction", "")
input = full_dict.get("input", "")

if input and not instruction:
warnings.warn("Swapping instruction and input fields.")
instruction, input = input, instruction

if model_repoid_or_path == "alpaca":
if suffix:
Expand Down Expand Up @@ -757,7 +773,9 @@ def fill_in_QA_template(
f"Suffix not supported except with mistral template. Ignoring suffix."
)

prompt = dict_to_dialogue_list({"instruction": instruction, "input": input})
prompt = dict_to_dialogue_list(
full_dict if full_dict else {"instruction": instruction, "input": input}
)
tokenizer = AutoTokenizer.from_pretrained(model_repoid_or_path)
input_full = tokenizer.apply_chat_template(
prompt, tokenize=False, add_generation_prompt=True
Expand Down

0 comments on commit ed2676d

Please sign in to comment.