test case for adversarial example generation #1036

Draft · wants to merge 11 commits into main
79 changes: 79 additions & 0 deletions flaml/autogen/datagen.py
@@ -0,0 +1,79 @@
import json
from flaml import oai
import regex as re
from itertools import compress
import time
import logging

logger = logging.getLogger(__name__)

def generate_adversarial_examples(data, test_func, eval_func, num_examples=5, **config):
    base_prompt = """
# Instructions
- Generate a complex version of the example in the following task.
- Make sure that the inputs are of the same types that are specified in the examples.
- Maintain the same format as the input examples, but feel free to be creative within that.
- Generate a json with double quotes.
- Do not replace integers with words.
- For mathematical examples use programmatic syntax. For example, use '*' instead of 'x' for multiplication.
<|start|>(example)
{example}
<|end|>
<|start|>(answer)
"""

    # base_settings = {
    #     "max_tokens": 64,
    #     "temperature": 1,
    #     "top_p": 1,
    #     "n": 5,
    #     "model": "gpt-4",
    # }
    max_iter = 10
    iteration = 0
    adv_examples = []

    def group_check(candidate):  # TODO: replace with a loss function
        # A candidate counts as adversarial when the model under test fails on it,
        # i.e., test_func evaluates to 0/False.
        eval_cands = eval_func(candidate)
        test_cands = test_func(candidate, eval_cands)
        return test_cands == 0

    ii = 0
    while len(adv_examples) < num_examples and iteration < max_iter:
        # query = base_settings
        # query["prompt"] = base_prompt.format(examples=str(data))
        print(f"iteration={iteration}")
        sample = data[ii % len(data)]
        response = oai.Completion.create({"example": sample}, prompt=base_prompt, **config)
        # Extract JSON-like candidate examples from the completion text.
        resp_candidates = re.findall(r"(?={).*(?<=})", oai.Completion.extract_text(response)[0])
        if len(resp_candidates) > 0:
            adv_candidates = list(map(eval, resp_candidates))
            time.sleep(30)
            eval_candidates = list(map(group_check, adv_candidates))
            valid_candidates = list(compress(adv_candidates, eval_candidates))
            if len(valid_candidates) > 0:
                # Store the batch of candidates that fooled the model under test.
                adv_examples.append(valid_candidates)
                iteration = 0
            else:
                iteration += 1
        time.sleep(30)
        ii += 1

    return adv_examples


# base_prompt = """
# <|meta_start|>
# # Introduction
# - You are an adversarial example generation assistant
# - Your goal is to generate more complex versions of the examples in the following task.
# - Make sure that the input would result in the same target as specified.
# - Make sure that the inputs are of the same types that are specified in the examples.
# - Generate parsable json with double quotes.
# - Do not replace integers with words.
# <|meta_end|>
# <|start|>(example)
# {examples}
# <|end|>
# <|start|>(answer)
# """
14 changes: 14 additions & 0 deletions test/autogen/configs/config.yaml
@@ -0,0 +1,14 @@
hydra:
  job:
    chdir: false

openai:
  key_path: <key-path>
  adv:
    model: <adversarial-gen-model>
    # api_base:
    # Other override arguments for adv
  eval:
    model: 'text-davinci-003'
    # api_base:
    # other override args
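
A small sketch of how this config is consumed (assuming the file is saved as test/autogen/configs/config.yaml; OmegaConf ships with Hydra):

from omegaconf import OmegaConf

# Load the config outside of @hydra.main to inspect the two override groups.
cfg = OmegaConf.load("test/autogen/configs/config.yaml")
print(cfg.openai.adv.model)   # model used to generate adversarial examples
print(cfg.openai.eval.model)  # model being evaluated, 'text-davinci-003' by default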
207 changes: 207 additions & 0 deletions test/autogen/test_adv_gen.py
@@ -0,0 +1,207 @@
from flaml import oai
from flaml.autogen.datagen import generate_adversarial_examples
import re
import logging
import hydra
import wikipedia

KEY_LOC = "./test/autogen"
logger = logging.getLogger(__name__)


@hydra.main(config_path="configs", config_name="config")
def test_adv_gen(cfg):
    try:
        import openai
    except ImportError:
        return

    # config_list_adv = oai.config_list_gpt4_gpt35(KEY_LOC)
    config_list_adv = oai.config_list_openai_aoai(KEY_LOC)  # [1:]
    config_list_adv[0].update(cfg.openai.adv)
    config_list_eval = oai.config_list_openai_aoai(KEY_LOC)
    config_list_eval[0].update(cfg.openai.eval)

    test_cases = [
        SimpleArith(config_list=config_list_eval),
        # WikipediaQGen(
        #     config_list_adv=config_list_adv,
        #     config_adv=cfg.openai.adv,
        #     config_list_eval=config_list_eval,
        #     config_eval=cfg.openai.eval,
        # )
    ]

    for case in test_cases:
        adv_examples = generate_adversarial_examples(
            data=case.input_examples,
            test_func=case.test_func,
            eval_func=case.eval_func,
Comment on lines +37 to +38
@sonichi (Contributor Author), Jun 16, 2023:
Add comment about the meaning of these two functions? Do we need a better name for them?

            num_examples=5,
            # reduction=np.mean,
            config_list=config_list_adv,
            **cfg.openai.adv,
        )
        print(adv_examples)


class SimpleArith:
    input_examples = [
        {"input": "1 + 4 =", "target": "5"},
        {"input": "4 + 9 =", "target": "13"},
        {"input": "8 + 3 =", "target": "11"},
        {"input": "30 + 89 =", "target": "119"},
        {"input": "486 + 141 =", "target": "627"},
        {"input": "13 + 476 =", "target": "489"},
        {"input": "773 + 546 =", "target": "1319"},
        {"input": "348 + 227 =", "target": "575"},
    ]

    def __init__(self, config_list):
        self.config_list = config_list

    @staticmethod
    def test_func(example, eval_out):
        """Check whether the model's answer matches the ground-truth arithmetic result."""
        logger.info(f"example input = {example['input']}")
        try:
            lhs = eval(re.findall(r"^(.*?)=", example["input"])[0].strip())
            logger.info(f"example={example}, llm_response={eval_out}")
            rhs = float(eval_out)
            return lhs == rhs
        except Exception:
            logger.info("eval was unsuccessful due to errors")
            return -1

    def eval_func(self, example):
        """Query the evaluated model with the raw arithmetic expression and return its answer."""
        base_prompt = "{input}"
        config = {
            "max_tokens": 5,
            "temperature": 0,
            "top_p": 1,
            "n": 1,
            "stream": False,
            "model": "text-davinci-003",
        }
        # query['prompt'] = base_prompt.format(example['input'])
        # resp = oai.Completion.create(**query)
        response = oai.Completion.create(example, prompt=base_prompt, config_list=self.config_list, **config)
        return oai.Completion.extract_text(response)[0].strip()


class WikipediaQGen:
    def __init__(
        self, config_list_adv={}, search_term="Cornell University", config_eval={}, config_adv={}, config_list_eval={}
    ):
        self.config_list_adv = config_list_adv
        self.config_list_eval = config_list_eval
        self.config_eval = config_eval
        self.config_adv = config_adv
        r = wikipedia.search(search_term)
        page = wikipedia.page(r[0])
        self.title = page.title
        self.content = page.content
        example_gen_prompt = f"""<|im_start|>system
You are a question generating assistant. Your objective is to take some context and generate questions together with their corresponding answer or possible answers
<|im_end|>
<|im_start|>user
Context
---
#
{page.title}

{page.content}
<|im_end|>
<|im_start|>user
Generate a series of questions related to {page.title} as follows.

1. Mode = "paragraph"

Write a question for which the answer is a short paragraph.

2. Mode = "few-words"

The answer is at most a few words.

3. Mode = "number"

The answer is a number.

4. Mode = "bool"

Generate a question with a True/False answer.

For each question above, provide the corresponding correct answer. If there is more than one correct answer, provide a list of all possible answers.
<|im_end|>
<|im_start|>assistant
"""
        config = {
            "max_tokens": 512,
            "temperature": 0.7,
            "top_p": 1,
            "n": 1,
            "model": "gpt-4-32k",
        }
        response = oai.Completion.create(prompt=example_gen_prompt, config_list=self.config_list_adv, **config)
        answer = oai.Completion.extract_text(response)[0].strip()
        # Parse each generated question/answer pair from the model output.
        qa_parsed = re.findall(r"(?=Question:)[\s\S]*?(?=[0-9]. Mode|$)", answer)
        self.input_examples = []
        for qa in qa_parsed:
            example = {
                "input": re.findall(r"(?<=Question:)[\s\S]*?(?=Answer:)", qa)[0].strip(),
                "target": re.findall(r"(?<=Answer:)[\s\S]*", qa)[0].strip(),
            }
            self.input_examples.append(example)

    # def add_message(self, content, role="user"):
    #     self.messages.append({"role": role, "content": content})

    def verif_func(self, example):
        """Ask the adversarial model whether the predicted text answers the question."""
        print(example)
        base_prompt = """Respond with Yes or No, does the text below answer the question provided?
Question: {input}
Text: {target}
Answer:
"""
        config = {
            "max_tokens": 512,
            "temperature": 0,
            "top_p": 1,
            "n": 1,
            **self.config_adv,
        }
        response = oai.Completion.create(example, prompt=base_prompt, config_list=self.config_list_adv, **config)
        answer = oai.Completion.extract_text(response)[0].strip()
        return answer == "Yes"

    def test_func(self, example):
        """Answer the question with the evaluated model, then verify the answer against the target."""
        base_prompt = f"""Answer the following question based on the context provided.
Question:
{{input}}
Context:
{self.title}
{self.content}
Answer:
"""
        config = {
            "max_tokens": 512,
            "temperature": 0,
            "top_p": 1,
            "n": 1,
            **self.config_eval,
        }
        response = oai.Completion.create(example, prompt=base_prompt, config_list=self.config_list_eval, **config)
        answer = oai.Completion.extract_text(response)[0]
        pred_example = {"input": example["input"], "target": answer}
        return self.verif_func(pred_example)


if __name__ == "__main__":
    # import openai
    # import os

    # config_list = oai.config_list_openai_aoai(KEY_LOC)
    # assert len(config_list) >= 3, config_list
    # openai.api_key = os.environ["OPENAI_API_KEY"]
    test_adv_gen()
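
As a side note on the question/answer parsing done in WikipediaQGen.__init__, here is a self-contained sketch with a made-up model response showing what the regular expressions extract:

import re

sample = """1. Mode = "number"
Question: In what year was Cornell University founded?
Answer: 1865
2. Mode = "bool"
Question: Is Cornell University located in Ithaca?
Answer: True"""

for qa in re.findall(r"(?=Question:)[\s\S]*?(?=[0-9]. Mode|$)", sample):
    question = re.findall(r"(?<=Question:)[\s\S]*?(?=Answer:)", qa)[0].strip()
    answer = re.findall(r"(?<=Answer:)[\s\S]*", qa)[0].strip()
    print({"input": question, "target": answer})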