support rebuttal process (#58)
* support paper rebuttal environment (#36)

* fix review paper

* fix review paper test

* fix type errors (#36)

* fix test errors

* fix test errors

* update decision making and paper rebuttal (#36)

* fix test errors

* support return bool decision and int score

* fix ruff

* support full testing

* support full testing

* fix one typo

* fix mypy error

* fix pre-commit error

---------

Co-authored-by: timsanders256 <[email protected]>
Co-authored-by: Haofei Yu <[email protected]>
3 people authored May 22, 2024
1 parent b1e8be1 commit ddba8ba
Showing 5 changed files with 269 additions and 54 deletions.
31 changes: 23 additions & 8 deletions research_town/agents/agent_base.py
@@ -8,7 +8,10 @@
communicate_with_multiple_researchers_prompting,
find_collaborators_prompting,
generate_ideas_prompting,
make_review_decision_prompting,
rebut_review_prompting,
review_paper_prompting,
review_score_prompting,
summarize_research_direction_prompting,
summarize_research_field_prompting,
write_paper_abstract_prompting,
Expand Down Expand Up @@ -44,7 +47,7 @@ def get_profile(self, author_name: str) -> Dict[str, Any]:
papers_list, papers_by_year = self._get_papers(entries, author_name)
if len(papers_list) > 40:
papers_list = self._select_papers(papers_by_year, author_name)

# Trim the list to the 10 most recent papers
papers_list = papers_list[:10]

@@ -159,7 +162,7 @@ def read_paper(
trend_output = trend[0]
return trend_output

def find_collaborators(self, input: Dict[str, str], parameter: float =0.5, max_number: int =3) -> List[str]:
def find_collaborators(self, input: Dict[str, str], parameter: float = 0.5, max_number: int = 3) -> List[str]:
start_author = [self.name]
graph, _, _ = bfs(
author_list=start_author, node_limit=max_number)
@@ -207,11 +210,23 @@ def write_paper(self, input: List[str], external_data: Dict[str, Dict[str, List[
paper_abstract = write_paper_abstract_prompting(input, external_data)
return paper_abstract[0]

def review_paper(self, input: Dict[str, str], external_data: Dict[str, str]) -> str:
paper_review = review_paper_prompting(input, external_data)
return paper_review[0]
def review_paper(self, external_data: Dict[str, str]) -> Tuple[int, str]:
paper_review = review_paper_prompting(external_data)[0]
print(paper_review)
review_score = review_score_prompting(paper_review)
print(review_score, paper_review)
return review_score, paper_review

def make_review_decision(
self, input: Dict[str, str], external_data: Dict[str, str]
) -> str:
return "accept"
self, submission: Dict[str, str], review: Dict[str, Tuple[int, str]]
) -> Tuple[bool, str]:
meta_review = make_review_decision_prompting(submission, review)
if "accept" in meta_review[0].lower():
review_decision = True
else:
review_decision = False
return review_decision, meta_review[0]

def rebut_review(self, submission: Dict[str, str], review: Dict[str, Tuple[int, str]], decision: Dict[str, Tuple[bool, str]]) -> str:
rebut_review = rebut_review_prompting(submission, review, decision)
return rebut_review[0]
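
Note on the API change above: review_paper now returns a (score, review) tuple, make_review_decision returns a (decision, meta_review) tuple instead of a hard-coded "accept", and rebut_review is new. Below is a minimal sketch of how the three methods chain together; the class name BaseResearchAgent and its constructor are assumptions inferred from the file path, not shown in this diff.

# Sketch only: BaseResearchAgent and its constructor signature are
# assumed here, not part of this diff.
from typing import Dict, Tuple

from research_town.agents.agent_base import BaseResearchAgent  # assumed name

reviewer = BaseResearchAgent(name="Reviewer A")  # assumed constructor
author = BaseResearchAgent(name="Author B")

submission: Dict[str, str] = {"Paper Title": "Abstract of the submission..."}

# New return shape: (int score, str review) instead of a bare string.
score, review_text = reviewer.review_paper(external_data=submission)
reviews: Dict[str, Tuple[int, str]] = {"Reviewer A": (score, review_text)}

# New return shape: (bool decision, str meta_review) instead of "accept".
accepted, meta_review = reviewer.make_review_decision(
    submission=submission, review=reviews)

# New method: the author rebuts the collected reviews and decisions.
rebuttal = author.rebut_review(
    submission=submission,
    review=reviews,
    decision={"Reviewer A": (accepted, meta_review)},
)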
76 changes: 67 additions & 9 deletions research_town/envs/env_paper_rebuttal.py
@@ -1,20 +1,78 @@
from typing import Dict
from typing import Dict, Tuple

from .env_base import BaseMultiAgentEnv


class PaperRebuttalMultiAgentEnv(BaseMultiAgentEnv):
def __init__(self, agent_dict: Dict[str, str]) -> None:
super().__init__(agent_dict)
self.turn_number = 0
self.turn_max = 1
self.terminated = False
self.roles: Dict[str, str] = {}
self.submission: Dict[str, str] = {}
self.review = ""
self.decision = ""
self.rebuttal = ""

def assign_roles(self, role_dict: Dict[str, str]) -> None:
self.roles = role_dict

def initialize_submission(self, external_data: Dict[str, str]) -> None:
self.submission = external_data

def submit_review(self, review_dict: Dict[str, Tuple[int, str]]) -> None:
review_serialize = [
f"Reviewer: {name}\nScore: {review[0]}\nReview: {review[1]}" for name, review in review_dict.items()]
self.review = "\n\n".join(review_serialize)

def submit_decision(self, decision_dict: Dict[str, Tuple[bool, str]]) -> None:
decision_count = {"accept": 0, "reject": 0}
for _, decision in decision_dict.items():
if decision[0]:
decision_count["accept"] += 1
else:
decision_count["reject"] += 1
count_max = 0
for d, count in decision_count.items():
if count > count_max:
count_max = count
self.decision = d

def submit_rebuttal(self, rebuttal_dict: Dict[str, str]) -> None:
rebuttal_serialize = [
f"Author: {name}\nRebuttal: {rebuttal}" for name, rebuttal in rebuttal_dict.items()]
self.rebuttal = "\n\n".join(rebuttal_serialize)

def step(self) -> None:
external_data = self.kb.get_data(10, "machine learning")
for agent_name, agent in self.agents.items():
agent.read_paper(external_data=external_data, domain="machine learning")
agent.review_paper({}, {})
agent.make_review_decision({}, {})
# Paper Reviewing
review_dict: Dict[str, Tuple[int, str]] = {}
for name, role in self.roles.items():
if role == "reviewer":
review_dict[name] = self.agents[name].review_paper(
external_data=self.submission)
self.submit_review(review_dict)

# Decision Making
decision_dict: Dict[str, Tuple[bool, str]] = {}
for name, role in self.roles.items():
if role == "reviewer":
decision_dict[name] = self.agents[name].make_review_decision(
submission=self.submission, review=review_dict)
self.submit_decision(decision_dict)

self.submit_rebuttal()
# Rebuttal Submitting
rebuttal_dict: Dict[str, str] = {}
for name, role in self.roles.items():
if role == "author":
rebuttal_dict[name] = self.agents[name].rebut_review(
submission=self.submission,
review=review_dict,
decision=decision_dict)
self.submit_rebuttal(rebuttal_dict)

def submit_rebuttal(self) -> None:
pass
self.turn_number += 1
if self.decision == "accept":
self.terminated = True
if self.turn_number >= self.turn_max:
self.terminated = True
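
Note on the environment above: step() now runs reviewing, decision making, and rebuttal in order over the agents' assigned roles, and terminated flips on an accept decision or once turn_max turns have run. A short usage sketch follows; how agent_dict is turned into agents lives in BaseMultiAgentEnv and is assumed here.

# Sketch only: the agent_dict semantics come from BaseMultiAgentEnv
# and are assumed, not shown in this diff.
from research_town.envs.env_paper_rebuttal import PaperRebuttalMultiAgentEnv

env = PaperRebuttalMultiAgentEnv(
    agent_dict={"Alice": "Alice", "Bob": "Bob", "Carol": "Carol"})
env.assign_roles({"Alice": "author", "Bob": "reviewer", "Carol": "reviewer"})
env.initialize_submission({"Paper Title": "Abstract of the submission..."})

# One call per turn: review -> decision -> rebuttal, then the
# termination check in step() decides whether to continue.
while not env.terminated:
    env.step()

print(env.decision)  # majority label computed by submit_decision
print(env.rebuttal)  # serialized author rebuttals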
115 changes: 95 additions & 20 deletions research_town/utils/agent_prompting.py
@@ -1,5 +1,5 @@
import os
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple

import openai

@@ -33,7 +33,8 @@ def get_query_embedding(query: str) -> Any:


def find_nearest_neighbors(data_embeddings: List[Any], query_embedding: Any, num_neighbors: int) -> Any:
neighbors = neiborhood_search(data_embeddings, query_embedding, num_neighbors)
neighbors = neiborhood_search(
data_embeddings, query_embedding, num_neighbors)
neighbors = neighbors.reshape(-1)

return neighbors.tolist()
@@ -60,18 +61,22 @@ def summarize_research_field_prompting(

query_embedding = get_query_embedding(query)

text_chunks = [abstract for papers in dataset.values() for abstract in papers["abstract"]]
data_embeddings = [embedding for embeddings in data_embedding.values() for embedding in embeddings]
text_chunks = [abstract for papers in dataset.values()
for abstract in papers["abstract"]]
data_embeddings = [embedding for embeddings in data_embedding.values()
for embedding in embeddings]

nearest_indices = find_nearest_neighbors(data_embeddings, query_embedding, num_neighbors=10)
nearest_indices = find_nearest_neighbors(
data_embeddings, query_embedding, num_neighbors=10)
context = [text_chunks[i] for i in nearest_indices]

template_input["papers"] = "; ".join(context)
prompt = query_template.format_map(template_input)

return openai_prompting(llm_model, prompt)

def find_collaborators_prompting(input: Dict[str, str], self_profile: Dict[str, str], collaborator_profiles: Dict[str, str], parameter: float =0.5, max_number: int =3, llm_model: Optional[str] = "mistralai/Mixtral-8x7B-Instruct-v0.1",) -> List[str]:

def find_collaborators_prompting(input: Dict[str, str], self_profile: Dict[str, str], collaborator_profiles: Dict[str, str], parameter: float = 0.5, max_number: int = 3, llm_model: Optional[str] = "mistralai/Mixtral-8x7B-Instruct-v0.1",) -> List[str]:
self_serialize = [
f"Name: {name}\nProfile: {self_profile[name]}" for _, name in enumerate(self_profile.keys())]
self_serialize_all = "\n\n".join(self_serialize)
Expand All @@ -96,6 +101,7 @@ def find_collaborators_prompting(input: Dict[str, str], self_profile: Dict[str,
prompt = prompt_qa.format_map(input)
return openai_prompting(llm_model, prompt)


def generate_ideas_prompting(
trend: str,
llm_model: Optional[str] = "mistralai/Mixtral-8x7B-Instruct-v0.1",
@@ -158,53 +164,122 @@ def write_paper_abstract_prompting(
"Here are the external data, which is a list abstracts of related papers: {papers_serialize_all}"
)

template_input = {"ideas_serialize_all": ideas_serialize_all, "papers_serialize_all": papers_serialize_all}
template_input = {"ideas_serialize_all": ideas_serialize_all,
"papers_serialize_all": papers_serialize_all}
prompt = prompt_template.format_map(template_input)
return openai_prompting(llm_model, prompt)

def review_paper_prompting(titles: Dict[str, str], external_data: Dict[str, str], llm_model: Optional[str] = "mistralai/Mixtral-8x7B-Instruct-v0.1") -> List[str]:
def review_score_prompting(paper_review: str, llm_model: Optional[str] = "mistralai/Mixtral-8x7B-Instruct-v0.1") -> int:
prompt_qa = (
"Please provide a score for the following reviews. The score should be between 1 and 10, where 1 is the lowest and 10 is the highest. Only returns one number score."
"Here are the reviews: {paper_review}"
)
input = {"paper_review": paper_review}
prompt = prompt_qa.format_map(input)
score_str = openai_prompting(llm_model, prompt)
if score_str[0].isdigit():
return int(score_str[0])
else:
return 0

def review_paper_prompting(external_data: Dict[str, str], llm_model: Optional[str] = "mistralai/Mixtral-8x7B-Instruct-v0.1") -> List[str]:
"""
Review paper from using list, and external data (published papers)
"""

titles_serialize = []
for _, timestamp in enumerate(titles.keys()):
title_entry = f"Time: {timestamp}\nPaper: {external_data[timestamp]}"
titles_serialize.append(title_entry)
titles_serialize_all = "\n\n".join(titles_serialize)

papers_serialize = []
for _, timestamp in enumerate(external_data.keys()):
paper_entry = f"Time: {timestamp}\nPaper: {external_data[timestamp]}"
paper_entry = f"Title: {timestamp}\nPaper: {external_data[timestamp]}"
papers_serialize.append(paper_entry)
papers_serialize_all = "\n\n".join(papers_serialize)

prompt_qa = (
"Please give some reviews based on the following inputs and external data."
"You might use two or more of these titles if they are related and works well together."
"Here are the titles: {titles_serialize_all}"
"Here are the external data, which is a list of related papers: {papers_serialize_all}"
)

input = {"titles_serialize_all": titles_serialize_all,
"papers_serialize_all": papers_serialize_all}
input = {"papers_serialize_all": papers_serialize_all}

prompt = prompt_qa.format_map(input)
return openai_prompting(llm_model, prompt)


def make_review_decision_prompting(submission: Dict[str, str], review: Dict[str, Tuple[int,str]], llm_model: Optional[str] = "mistralai/Mixtral-8x7B-Instruct-v0.1") -> List[str]:
submission_serialize = []
for _, title in enumerate(submission.keys()):
abstract = submission[title]
submission_entry = f"Title: {title}\nAbstract:{abstract}\n"
submission_serialize.append(submission_entry)
submission_serialize_all = "\n\n".join(submission_serialize)

review_serialize = []
for _, name in enumerate(review.keys()):
content = review[name]
review_entry = f"Name: {name}\nContent: {content}\n"
review_serialize.append(review_entry)
review_serialize_all = "\n\n".join(review_serialize)

prompt_template = (
"Please make an review decision to decide whether the following submission should be accepted or rejected by an academic conference. Here are several reviews from reviewers for this submission. Please indicate your review decision as accept or reject."
"Here is the submission: {submission_serialize_all}"
"Here are the reviews: {review_serialize_all}"
)
template_input = {"submission_serialize_all": submission_serialize_all,
"review_serialize_all": review_serialize_all}
prompt = prompt_template.format_map(template_input)
return openai_prompting(llm_model, prompt)


def rebut_review_prompting(submission: Dict[str, str], review: Dict[str, Tuple[int, str]], decision: Dict[str, Tuple[bool, str]], llm_model: Optional[str] = "mistralai/Mixtral-8x7B-Instruct-v0.1") -> List[str]:
submission_serialize = []
for _, title in enumerate(submission.keys()):
abstract = submission[title]
submission_entry = f"Title: {title}\nAbstract:{abstract}\n"
submission_serialize.append(submission_entry)
submission_serialize_all = "\n\n".join(submission_serialize)

review_serialize = []
for _, name in enumerate(review.keys()):
content = review[name]
review_entry = f"Name: {name}\nContent: {content}\n"
review_serialize.append(review_entry)
review_serialize_all = "\n\n".join(review_serialize)

decision_serialize = []
for _, name in enumerate(decision.keys()):
content = decision[name]
decision_entry = f"Name: {name}\nDecision: {content}\n"
decision_serialize.append(decision_entry)
decision_serialize_all = "\n\n".join(decision_serialize)

prompt_template = (
"Please write a rebuttal for the following submission you have made to an academic conference. Here are the reviews and decisions from the reviewers. Your rebuttal should rebut the reviews to convince the reviewers to accept your submission."
"Here is the submission: {submission_serialize_all}"
"Here are the reviews: {review_serialize_all}"
"Here are the decisions: {decision_serialize_all}"
)
template_input = {"submission_serialize_all": submission_serialize_all,
"review_serialize_all": review_serialize_all, "decision_serialize_all": decision_serialize_all}
prompt = prompt_template.format_map(template_input)
return openai_prompting(llm_model, prompt)


def communicate_with_multiple_researchers_prompting(
input: Dict[str, str],
llm_model: Optional[str] = "mistralai/Mixtral-8x7B-Instruct-v0.1",
) -> List[str]:
"""
This is a single-round chat method. One that contains a chat history can better enable
"""
single_round_chat_serialize = [f"Message from researcher named {name}: {message}" for name, message in input.items()]
single_round_chat_serialize = [
f"Message from researcher named {name}: {message}" for name, message in input.items()]
single_round_chat_serialize_all = "\n".join(single_round_chat_serialize)
prompt_template = (
"Please continue in a conversation with other fellow researchers for me, where you will address their concerns in a scholarly way. "
"Here are the messages from other researchers: {single_round_chat_serialize_all}"
)
template_input = {"single_round_chat_serialize_all": single_round_chat_serialize_all}
template_input = {
"single_round_chat_serialize_all": single_round_chat_serialize_all}
prompt = prompt_template.format_map(template_input)
return openai_prompting(llm_model, prompt)
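
One behavior worth noting in review_score_prompting above: score_str[0] is the model's whole first completion, so the function returns 0 unless that entire reply is a digit string. A more tolerant parser, shown only as a sketch and not part of this commit, could pull the first in-range integer out of a free-form reply.

import re

def parse_review_score(reply: str) -> int:
    # Sketch only, not part of this commit: accepts replies such as
    # "Score: 8" by taking the first integer in the 1-10 range, and
    # falls back to 0 like the original.
    match = re.search(r"\b(10|[1-9])\b", reply)
    return int(match.group(1)) if match else 0

assert parse_review_score("8") == 8
assert parse_review_score("Score: 10") == 10
assert parse_review_score("no numeric score") == 0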