diff --git a/flaml/autogen/datagen.py b/flaml/autogen/datagen.py
new file mode 100644
index 0000000000..d83bf5c68c
--- /dev/null
+++ b/flaml/autogen/datagen.py
@@ -0,0 +1,79 @@
+import json
+from flaml import oai
+import regex as re
+from itertools import compress
+import time
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def generate_adversarial_examples(data, test_func, eval_func, num_examples=5, **config):
+    base_prompt = """
+# Instructions
+- Generate a complex version of the example in the following task.
+- Make sure that the inputs are of the same types that are specified in the examples.
+- Maintain the same format as the input examples, but feel free to be creative within that.
+- Generate parsable JSON with double quotes.
+- Do not replace integers with words.
+- For mathematical examples use programmatic syntax. For example, use '*' instead of 'x' for multiplication.
+<|start|>(example)
+{example}
+<|end|>
+<|start|>(answer)
+    """
+
+    # base_settings = {
+    #     "max_tokens": 64,
+    #     "temperature": 1,
+    #     "top_p": 1,
+    #     "n": 5,
+    #     "model": "gpt-4",
+    # }
+    max_iter = 10
+    iteration = 0
+    adv_examples = []
+
+    def group_check(candidate):  # TODO: replace with a loss function
+        # a candidate counts as adversarial when the eval model's answer fails the test
+        eval_cands = eval_func(candidate)
+        test_cands = test_func(candidate, eval_cands)
+        return test_cands == 0
+
+    ii = 0
+    while len(adv_examples) < num_examples and iteration < max_iter:
+        # query = base_settings
+        # query["prompt"] = base_prompt.format(examples=str(data))
+        logger.info(f"iteration={iteration}")
+        sample = data[ii % len(data)]
+        response = oai.Completion.create({"example": sample}, prompt=base_prompt, **config)
+        # pull the outermost {...} span(s) out of the completion
+        resp_candidates = re.findall(r"(?={).*(?<=})", oai.Completion.extract_text(response)[0])
+        if len(resp_candidates) > 0:
+            # NOTE: eval executes model-provided text verbatim; a json.loads-based
+            # parse (sketched after this file) is a safer alternative
+            adv_candidates = list(map(eval, resp_candidates))
+            time.sleep(30)  # crude rate limiting between API calls
+            eval_candidates = list(map(group_check, adv_candidates))
+            valid_candidates = list(compress(adv_candidates, eval_candidates))
+            if len(valid_candidates) > 0:
+                adv_examples.extend(valid_candidates)  # extend so len() counts examples, not batches
+                iteration = 0
+            else:
+                iteration += 1
+        time.sleep(30)
+        ii += 1
+
+    return adv_examples
+
+
+# base_prompt = """
+# <|meta_start|>
+# # Introduction
+# - You are an adversarial example generation assistant
+# - Your goal is to generate more complex versions of the examples in the following task.
+# - Make sure that the input would result in the same target as specified.
+# - Make sure that the inputs are of the same types that are specified in the examples.
+# - Generate parsable json with double quotes.
+# - Do not replace integers with words.
+# <|meta_end|>
+# <|start|>(example)
+# {examples}
+# <|end|>
+# <|start|>(answer)
+# """
\ No newline at end of file
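The eval-based candidate parsing above executes whatever text the model returns. A minimal sketch of a stricter alternative, assuming the model honors the prompt's double-quoted-JSON instruction (`safe_parse_candidates` is a hypothetical helper, not part of this diff):

    import json
    import logging

    import regex as re

    logger = logging.getLogger(__name__)


    def safe_parse_candidates(completion_text):
        """Parse {...} spans from a completion into dicts, discarding invalid JSON."""
        candidates = []
        for span in re.findall(r"(?={).*(?<=})", completion_text):
            try:
                candidates.append(json.loads(span))  # accepts only double-quoted JSON
            except json.JSONDecodeError:
                logger.warning("discarding non-JSON candidate: %s", span)
        return candidates

Swapping this in for `list(map(eval, resp_candidates))` would also put the currently unused `import json` to work.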
diff --git a/test/autogen/configs/config.yaml b/test/autogen/configs/config.yaml
new file mode 100644
index 0000000000..62b687f8fc
--- /dev/null
+++ b/test/autogen/configs/config.yaml
@@ -0,0 +1,14 @@
+hydra:
+  job:
+    chdir: false
+
+openai:
+  key_path:
+  adv:
+    model:
+    # api_base:
+    # Other override arguments for adv
+  eval:
+    model: 'text-davinci-003'
+    # api_base:
+    # other override args
\ No newline at end of file
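At runtime this file is resolved by Hydra. A quick sketch of inspecting it with Hydra's compose API (`initialize` and `compose` are standard Hydra calls; the empty `adv.model` field is expected to be filled by a command-line override):

    from hydra import compose, initialize

    with initialize(config_path="configs"):
        cfg = compose(config_name="config")
        print(cfg.openai.eval.model)  # 'text-davinci-003', from the YAML
        print(cfg.openai.adv.model)   # None until overridden, e.g. openai.adv.model=gpt-4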
diff --git a/test/autogen/test_adv_gen.py b/test/autogen/test_adv_gen.py
new file mode 100644
index 0000000000..be9edd39ca
--- /dev/null
+++ b/test/autogen/test_adv_gen.py
@@ -0,0 +1,207 @@
+from flaml import oai
+from flaml.autogen.datagen import generate_adversarial_examples
+import re
+import logging
+import hydra
+import wikipedia
+
+KEY_LOC = "./test/autogen"
+logger = logging.getLogger(__name__)
+
+
+@hydra.main(config_path="configs", config_name="config")
+def test_adv_gen(cfg):
+    try:
+        import openai  # skip the test when openai is not installed
+    except ImportError:
+        return
+
+    # config_list_adv = oai.config_list_gpt4_gpt35(KEY_LOC)
+    config_list_adv = oai.config_list_openai_aoai(KEY_LOC)  # [1:]
+    config_list_adv[0].update(cfg.openai.adv)
+    config_list_eval = oai.config_list_openai_aoai(KEY_LOC)
+    config_list_eval[0].update(cfg.openai.eval)
+
+    test_cases = [
+        SimpleArith(config_list=config_list_eval),
+        # WikipediaQGen(
+        #     config_list_adv=config_list_adv,
+        #     config_adv=cfg.openai.adv,
+        #     config_list_eval=config_list_eval,
+        #     config_eval=cfg.openai.eval,
+        # ),
+    ]
+
+    for case in test_cases:
+        adv_examples = generate_adversarial_examples(
+            data=case.input_examples,
+            test_func=case.test_func,
+            eval_func=case.eval_func,
+            num_examples=5,
+            # reduction=np.mean,
+            config_list=config_list_adv,
+            **cfg.openai.adv,
+        )
+        print(adv_examples)
+
+
+class SimpleArith:
+    input_examples = [
+        {"input": "1 + 4 =", "target": "5"},
+        {"input": "4 + 9 =", "target": "13"},
+        {"input": "8 + 3 =", "target": "11"},
+        {"input": "30 + 89 =", "target": "119"},
+        {"input": "486 + 141 =", "target": "627"},
+        {"input": "13 + 476 =", "target": "489"},
+        {"input": "773 + 546 =", "target": "1319"},
+        {"input": "348 + 227 =", "target": "575"},
+    ]
+
+    def __init__(self, config_list):
+        self.config_list = config_list
+
+    @staticmethod
+    def test_func(example, eval_out):
+        logger.info(f"example input = {example['input']}")
+        try:
+            # recompute the ground truth from the (possibly model-generated) input
+            lhs = eval(re.findall(r"^(.*?)=", example["input"])[0].strip())
+            logger.info(f"example={example}, llm_response={eval_out}")
+            rhs = float(eval_out)
+            return lhs == rhs
+        except Exception:
+            logger.info("eval was unsuccessful due to errors")
+            return -1  # signals an unparsable example; group_check does not count it as adversarial
+
+    def eval_func(self, example):
+        base_prompt = "{input}"
+        config = {
+            "max_tokens": 5,
+            "temperature": 0,
+            "top_p": 1,
+            "n": 1,
+            "stream": False,
+            "model": "text-davinci-003",
+        }
+        # query['prompt'] = base_prompt.format(example['input'])
+        # resp = oai.Completion.create(**query)
+        response = oai.Completion.create(example, prompt=base_prompt, config_list=self.config_list, **config)
+        return oai.Completion.extract_text(response)[0].strip()
+
+
+class WikipediaQGen:
+    def __init__(
+        self, config_list_adv={}, search_term="Cornell University", config_eval={}, config_adv={}, config_list_eval={}
+    ):
+        self.config_list_adv = config_list_adv
+        self.config_list_eval = config_list_eval
+        self.config_eval = config_eval
+        self.config_adv = config_adv
+        r = wikipedia.search(search_term)
+        page = wikipedia.page(r[0])
+        self.title = page.title
+        self.content = page.content
+        example_gen_prompt = f"""<|im_start|>system
+You are a question generating assistant. Your objective is to take some context and generate questions together with their corresponding answer or possible answers.
+<|im_end|>
+<|im_start|>user
+Context
+---
+# {page.title}
+
+{page.content}
+<|im_end|>
+<|im_start|>user
+Generate a series of questions related to {page.title} as follows.
+
+1. Mode = "paragraph"
+
+Write a question for which the answer is a short paragraph.
+
+2. Mode = "few-words"
+
+The answer is at most a few words.
+
+3. Mode = "number"
+
+The answer is a number.
+
+4. Mode = "bool"
+
+Generate a question with a True/False answer.
+
+For each question above, provide the corresponding correct answer. If there is more than one correct answer, provide a list of all possible answers.
+<|im_end|>
+<|im_start|>assistant
+"""
+        config = {
+            "max_tokens": 512,
+            "temperature": 0.7,
+            "top_p": 1,
+            "n": 1,
+            "model": "gpt-4-32k",
+        }
+        response = oai.Completion.create(prompt=example_gen_prompt, config_list=self.config_list_adv, **config)
+        answer = oai.Completion.extract_text(response)[0].strip()
+        # split the response into question/answer blocks, one per mode
+        qa_parsed = re.findall(r"(?=Question:)[\s\S]*?(?=[0-9]\. Mode|$)", answer)
+        self.input_examples = []
+        for qa in qa_parsed:
+            example = {
+                "input": re.findall(r"(?<=Question:)[\s\S]*?(?=Answer:)", qa)[0].strip(),
+                "target": re.findall(r"(?<=Answer:)[\s\S]*", qa)[0].strip(),
+            }
+            self.input_examples.append(example)
+
+    # def add_message(self, content, role="user"):
+    #     self.messages.append({"role": role, "content": content})
+
+    def verif_func(self, example):
+        print(example)
+        base_prompt = """Respond with Yes or No, does the text below answer the question provided?
+        Question: {input}
+        Text: {target}
+        Answer:
+        """
+        config = {
+            "max_tokens": 512,
+            "temperature": 0,
+            "top_p": 1,
+            "n": 1,
+            **self.config_adv,
+        }
+        response = oai.Completion.create(example, prompt=base_prompt, config_list=self.config_list_adv, **config)
+        answer = oai.Completion.extract_text(response)[0].strip()
+        return answer == "Yes"
+
+    def test_func(self, example, eval_out=None):
+        # eval_out is accepted for interface compatibility with generate_adversarial_examples;
+        # this case re-queries the eval model itself and checks the answer with verif_func
+        base_prompt = f"""Answer the following question based on the context provided.
+        Question:
+        {{input}}
+        Context:
+        {self.title}
+        {self.content}
+        Answer:
+        """
+        config = {
+            "max_tokens": 512,
+            "temperature": 0,
+            "top_p": 1,
+            "n": 1,
+            **self.config_eval,
+        }
+        response = oai.Completion.create(example, prompt=base_prompt, config_list=self.config_list_eval, **config)
+        answer = oai.Completion.extract_text(response)[0]
+        pred_example = {"input": example["input"], "target": answer}
+        return self.verif_func(pred_example)
+
+
+if __name__ == "__main__":
+    # import openai
+    # import os
+
+    # config_list = oai.config_list_openai_aoai(KEY_LOC)
+    # assert len(config_list) >= 3, config_list
+    # openai.api_key = os.environ["OPENAI_API_KEY"]
+    test_adv_gen()
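For a run outside the Hydra harness, a minimal driver sketch, assuming a valid key file under ./test/autogen (the inline test/eval functions mirror SimpleArith and are illustrative, not part of this diff):

    from flaml import oai
    from flaml.autogen.datagen import generate_adversarial_examples

    config_list = oai.config_list_openai_aoai("./test/autogen")
    examples = [{"input": "1 + 4 =", "target": "5"}, {"input": "30 + 89 =", "target": "119"}]


    def eval_func(example):
        # ask the eval model to answer the (possibly adversarial) arithmetic example
        response = oai.Completion.create(
            example, prompt="{input}", config_list=config_list, max_tokens=5, temperature=0
        )
        return oai.Completion.extract_text(response)[0].strip()


    def test_func(example, eval_out):
        # True when the model's answer matches the ground truth recomputed from the input
        try:
            return float(eval_out) == eval(example["input"].rstrip("= "))
        except Exception:
            return -1


    print(generate_adversarial_examples(examples, test_func, eval_func, num_examples=3, config_list=config_list))

The same loop can also be exercised through the Hydra entry point, overriding YAML fields on the command line, e.g. python test/autogen/test_adv_gen.py openai.adv.model=gpt-4.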