-
Notifications
You must be signed in to change notification settings - Fork 20
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(bench): support safety attack #851
base: main
Are you sure you want to change the base?
Changes from 10 commits
98de74e
889986c
cc8114c
226c1c2
999c32a
678b967
0b4351e
f5117a5
d230ef6
2f06144
d973f8f
13ec427
d859ca3
9213653
dcf7fe3
fad8ab1
3d584e2
a93153c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import csv
import json


def load_task_domains(csv_path: str) -> dict:
    """Build a mapping from task name to scientific domain from a CSV file.

    :param csv_path: path to a CSV with 'Task' and 'Scientific Domain' columns
    :return: dict mapping stripped task name to stripped domain name
    """
    task_to_domain = {}
    with open(csv_path, 'r', encoding='utf-8') as f:
        reader = csv.DictReader(f)
        for row in reader:
            # Strip stray whitespace so lookups match the adversarial tasks.
            task = row['Task'].strip()
            domain = row['Scientific Domain'].strip()
            task_to_domain[task] = domain
    return task_to_domain


def add_domains(adversarial_data: list, task_to_domain: dict) -> list:
    """Annotate each adversarial entry with its domain, when known.

    Entries whose task is absent from ``task_to_domain`` are left untouched.

    :param adversarial_data: list of dicts, each with a 'task' key
    :param task_to_domain: mapping from task name to scientific domain
    :return: the same list, mutated in place
    """
    for item in adversarial_data:
        task = item['task'].strip()
        if task in task_to_domain:
            item['domain'] = task_to_domain[task]
    return adversarial_data


def main() -> None:
    """Merge task domains into the adversarial prompts and write the result."""
    with open('./adversarial_prompts.json', 'r', encoding='utf-8') as f:
        adversarial_data = json.load(f)

    task_to_domain = load_task_domains('./tasks.csv')
    add_domains(adversarial_data, task_to_domain)

    with open('updated_adversarial.json', 'w', encoding='utf-8') as f:
        json.dump(adversarial_data, f, ensure_ascii=False, indent=4)

    print('New JSON file generated: updated_adversarial.json')


if __name__ == '__main__':
    main()
Large diffs are not rendered by default.
Large diffs are not rendered by default.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please ignore this dataset file. This should not be included. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
{ | ||
"namespaces": [ | ||
"Paper", | ||
"ProposalWritingLog", | ||
"IdeaBrainstormLog", | ||
"Profile", | ||
"Proposal", | ||
"Idea", | ||
"MetaReview", | ||
"Review", | ||
"RebuttalWritingLog", | ||
"ReviewWritingLog", | ||
"LiteratureReviewLog", | ||
"MetaReviewWritingLog", | ||
"Rebuttal", | ||
"Insight" | ||
], | ||
"embedding_namespaces": [ | ||
"Paper", | ||
"Profile" | ||
] | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -48,7 +48,61 @@ def write_proposal_researchtown( | |
|
||
# Exit the environment and retrieve the generated proposal | ||
exit_status, exit_dict = env.on_exit() | ||
proposal = exit_dict.get('proposal') | ||
proposal = exit_dict.get('proposals')[0] | ||
if proposal and proposal.content: | ||
return str(proposal.content) | ||
else: | ||
raise ValueError('Failed to generate proposal') | ||
|
||
|
||
def write_proposal_researchtown_nodb(
    profiles: List[Profile],
    ref_contents: List[str],
    config: Config,
) -> str:
    """Generate a research proposal with freshly initialized databases.

    Unlike ``write_proposal_researchtown``, the log/progress/profile/paper
    databases are created from ``config`` rather than loaded with existing
    data (hence ``nodb`` -- "no pre-populated database").

    :param profiles: agent profiles; the first entry is the team leader
    :param ref_contents: reference texts passed to the environment as contexts
    :param config: research-town configuration object
    :return: the generated proposal content
    :raises ValueError: if no leader profile is given or no proposal is produced
    """
    # Validate before creating any agents: an empty profile list
    # cannot provide a leader. (The original checked after use.)
    if not profiles:
        raise ValueError('Failed to create leader agent')

    log_db = LogDB(config=config.database)
    progress_db = ProgressDB(config=config.database)
    profile_db = ProfileDB(config=config.database)
    paper_db = PaperDB(config=config.database)
    agent_manager = AgentManager(config=config.param, profile_db=profile_db)

    env = ProposalWritingEnv(
        name='proposal_writing',
        log_db=log_db,
        progress_db=progress_db,
        paper_db=paper_db,
        config=config,
        agent_manager=agent_manager,
    )

    leader = agent_manager.create_agent(profiles[0], role='leader')
    members = [
        agent_manager.create_agent(profile, role='member') for profile in profiles[1:]
    ]

    env.on_enter(
        leader=leader,
        contexts=ref_contents,
        members=members,
    )

    # Drain the environment generator so the proposal gets produced.
    run_result = env.run()
    if run_result is not None:
        for _progress, _agent in run_result:
            pass

    # Exit the environment and retrieve the generated proposal.
    exit_status, exit_dict = env.on_exit()
    # Guard against a missing or empty 'proposals' entry instead of
    # indexing the result of .get() directly (which would raise TypeError).
    proposals = exit_dict.get('proposals')
    proposal = proposals[0] if proposals else None
    if proposal and proposal.content:
        return str(proposal.content)
    else:
        raise ValueError('Failed to generate proposal')
|
@@ -307,6 +361,10 @@ def write_proposal( | |
return write_proposal_researchtown( | ||
profiles=profiles, ref_contents=ref_contents, config=config | ||
) | ||
elif mode == 'textgnn_nodb': | ||
return write_proposal_researchtown_nodb( | ||
profiles=profiles, ref_contents=ref_contents, config=config | ||
) | ||
elif mode == 'sakana_ai_scientist': | ||
return write_proposal_sakana_ai_scientist( | ||
ref_contents=ref_contents, config=config, num_reflections=5 | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
import argparse | ||
import json | ||
import os | ||
from typing import Any, Dict | ||
|
||
from tqdm import tqdm | ||
|
||
from research_bench.proposal_writing import write_proposal | ||
from research_town.configs import Config | ||
from research_town.data import Profile | ||
from research_town.utils.logger import logger | ||
|
||
|
||
def load_adversarial_data(
    adversarial_path: str, output_path: str
) -> Dict[str, Dict[str, Any]]:
    """
    Load adversarial data from a JSON file and skip already processed entries based on the output_path.
    Assign a unique attack_id to each entry.

    :param adversarial_path: Path to adversarial.json
    :param output_path: Path to existing output JSONL file
    :return: Dictionary mapping attack_id to data
    """
    with open(adversarial_path, 'r', encoding='utf-8') as f:
        adversarial_list = json.load(f)

    # Assign unique attack_id, e.g., "attack_0", "attack_1", ...
    dataset = {f'attack_{idx}': entry for idx, entry in enumerate(adversarial_list)}

    if os.path.exists(output_path):
        with open(output_path, 'r', encoding='utf-8') as f:
            # Skip blank lines so a trailing newline in the JSONL file
            # does not raise a JSONDecodeError on resume.
            processed_ids = {
                json.loads(line)['attack_id'] for line in f if line.strip()
            }
        dataset = {k: v for k, v in dataset.items() if k not in processed_ids}

    return dataset
|
||
|
||
def load_profiles(profile_path: str) -> Dict[str, Any]:
    """
    Read the profile JSON file.

    :param profile_path: Path to profile.json
    :return: Dictionary mapping domain to scientists' profiles
    """
    with open(profile_path, encoding='utf-8') as fp:
        loaded: Dict[str, Any] = json.load(fp)
    return loaded
|
||
|
||
def save_result(result: Dict[str, Any], output_path: str) -> None:
    """
    Append one result as a single JSON line to the output JSONL file.

    :param result: Result dictionary to persist
    :param output_path: Path to the output JSONL file
    """
    line = json.dumps(result, ensure_ascii=False)
    with open(output_path, 'a', encoding='utf-8') as fp:
        fp.write(line + '\n')
|
||
|
||
def main() -> None:
    """Generate a proposal for each adversarial entry and save the results.

    Reads adversarial prompts and domain profiles, runs ``write_proposal``
    in the selected mode for every unprocessed entry, and appends each
    result to a JSONL file so interrupted runs can resume.
    """
    parser = argparse.ArgumentParser(description='Research Proposal Attack Script')
    parser.add_argument(
        '--adversarial_path',
        type=str,
        required=False,
        default='./attackbench/adversarial.json',
        help='Input adversarial JSON file path',
    )
    parser.add_argument(
        '--profile_path',
        type=str,
        required=False,
        default='./attackbench/profiles.json',
        help='Input profile JSON file path',
    )
    parser.add_argument(
        '--output_path',
        type=str,
        required=False,
        default='./attackbench/attack_results.jsonl',
        help='Output JSONL file path',
    )
    parser.add_argument(
        '--mode',
        type=str,
        required=False,
        default='textgnn_nodb',
        choices=[
            'author_only',
            'citation_only',
            'author_citation',
            'textgnn',
            # The default mode must be listed here too: argparse does not
            # validate defaults against choices, so the omission only broke
            # explicit `--mode textgnn_nodb` invocations.
            'textgnn_nodb',
            'sakana_ai_scientist',
        ],
        help='Processing mode',
    )
    parser.add_argument(
        '--config_path',
        type=str,
        default='../configs',
        help='Path to the configuration directory',
    )
    args = parser.parse_args()

    config = Config(args.config_path)
    dataset = load_adversarial_data(args.adversarial_path, args.output_path)
    profiles_dict = load_profiles(args.profile_path)
    logger.info(f'Processing {len(dataset)} adversarial entries')

    for attack_id, data in tqdm(dataset.items(), desc='Processing attacks'):
        template = data.get('template', '')
        task = data.get('task', '')
        text = data.get('text', '')
        domain = data.get('domain', '')

        text_list = [text]

        # Build Profile objects for every scientist in the entry's domain;
        # an unknown domain yields an empty profile list.
        domain_profiles = profiles_dict.get(domain, {})
        profiles = [
            Profile(name=scientist, bio=info.get('bio', ''))
            for scientist, info in domain_profiles.items()
        ]

        # Generate proposal
        gen_proposal = write_proposal(args.mode, profiles, text_list, config)

        # Prepare result
        result = {
            'attack_id': attack_id,
            'template': template,
            'task': task,
            'domain': domain,
            'gen_proposal': gen_proposal,
        }

        # Persist immediately so a rerun can skip completed attacks.
        save_result(result, args.output_path)

    logger.info(
        f'All adversarial entries have been processed and saved to {args.output_path}'
    )


if __name__ == '__main__':
    main()
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,8 +29,14 @@ def __init__( | |
@beartype | ||
def on_enter(self, **context: Any) -> None: | ||
# Assign leader and members from context or sample them | ||
self.leader = context.get('leader', self.agent_manager.sample_leader()) | ||
self.members = context.get('members', self.agent_manager.sample_members()) | ||
if context.get('leader') is None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you help change all the things in all environment files here? this is a potential bug in research-town |
||
self.leader = self.agent_manager.sample_leader() | ||
else: | ||
self.leader = context['leader'] | ||
if context.get('members') is None: | ||
self.members = self.agent_manager.sample_members() | ||
else: | ||
self.members = context['members'] | ||
|
||
if 'contexts' not in context: | ||
raise ValueError("'contexts' is required in the context.") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this dependency was changed due to @chengzr01 's suggestions. maybe we need to double consider that.
It works well on server deployment. Is the error related to poetry version? Maybe we need to set it as >=2.2.1 < 2.5.0