diff --git a/src/abstractions/backends.py b/src/abstractions/backends.py index 9e4c369..a62f07d 100644 --- a/src/abstractions/backends.py +++ b/src/abstractions/backends.py @@ -697,15 +697,16 @@ def dict_to_dialogue_list( def fill_in_QA_template( - instruction: str, + instruction: str = "", input: str = "", suffix: str = "", + full_dict: dict = None, model_repoid_or_path: Union[Literal["alpaca", "mistral", "llama3"], str] = "alpaca", ) -> str: """Provided with a task instruction and (optionally) supplementary input, fill them into a QA template and return the resulting prompt. - :param instruction: The task instruction. - :type instruction: str + :param instruction: The task instruction, defaults to "". Either this or full_dict must be provided. + :type instruction: str, optional :param input: Supplementary input to the task, defaults to "". :type input: str, optional @@ -713,6 +714,9 @@ def fill_in_QA_template( :param suffix: Suffix to add to the prompt, defaults to "". :type suffix: str, optional + :param full_dict: The full dictionary containing the instruction and input, defaults to None. Either this or instruction must be provided. If this is provided, instruction, input, and suffix will be ignored. + :type full_dict: dict, optional + :param model_repoid_or_path: The model repo ID or path (e.g., "meta-llama/Meta-Llama-3-8B-Instruct"), or one of the special values "alpaca" or "mistral" or "llama3", defaults to "alpaca". :type model_repoid_or_path: Union[Literal["alpaca", "mistral", "llama3"], str], optional @@ -722,6 +726,18 @@ def fill_in_QA_template( instruction = instruction.strip() input = input.strip() + + # Convert full_dict to instruction and input + if full_dict and model_repoid_or_path in ["alpaca", "mistral"]: + assert "history" not in full_dict, "History field not supported with alpaca/mistral template." + assert "system" not in full_dict, "System field not supported with alpaca/mistral template." + + instruction = full_dict.get("instruction", "") + input = full_dict.get("input", "") + + if input and not instruction: + warnings.warn("Swapping instruction and input fields.") + instruction, input = input, instruction if model_repoid_or_path == "alpaca": if suffix: @@ -757,7 +773,9 @@ def fill_in_QA_template( f"Suffix not supported except with mistral template. Ignoring suffix." ) - prompt = dict_to_dialogue_list({"instruction": instruction, "input": input}) + prompt = dict_to_dialogue_list( + full_dict if full_dict else {"instruction": instruction, "input": input} + ) tokenizer = AutoTokenizer.from_pretrained(model_repoid_or_path) input_full = tokenizer.apply_chat_template( prompt, tokenize=False, add_generation_prompt=True