From 2ed57ce2bdaadee71470dd909c355404a43ee577 Mon Sep 17 00:00:00 2001
From: Haofei Yu <1125027232@qq.com>
Date: Sat, 25 May 2024 11:26:19 -0400
Subject: [PATCH] rebuild data type (#70)

* rebuild data type
* add research paper class
* change dict to list class
* prompter should stay the same, without any class in or out
* rebuild data type in rebuttal-related functions
* update agent list
* rebuild data type for test functions for rebuttal-related functions
* add constants for testing
* add a rebuttal instance
* update get profile
* rebuild data type for submission-related functions
* fix test errors
* rebuild data type for test functions for submission-related functions
* fix test errors
* rebuild the database class
* fix some bugs
* add some optional params
* update db class content
* fix id to pk
* fix pre-commit bug
* fix pre-commit bug
* fix some mypy errors
* add a TODO comment
* fix part of mypy errors
* fix part of mypy errors
* fix mypy.ini format
* recover test functions
* fix pre-commit
* add default

---------

Co-authored-by: chengzr01
---
 mypy.ini                                      |  11 +
 research_town/agents/agent_base.py            | 247 ++++++++++++++----
 research_town/dbs/__init__.py                 |  29 ++
 research_town/dbs/agent_db.py                 |  57 ++++
 research_town/dbs/env_db.py                   | 104 ++++++++
 research_town/dbs/paper_db.py                 |  72 +++++
 research_town/dbs/progress_db.py              |  75 ++++++
 research_town/envs/env_base.py                |  15 +-
 research_town/envs/env_paper_rebuttal.py      |  80 +++---
 research_town/envs/env_paper_submission.py    |   5 +-
 research_town/kbs/kb_base.py                  |  29 --
 ...author_collector.py => agent_collector.py} |   3 +-
 research_town/utils/agent_prompter.py         |  43 ++-
 scripts/run.py                                |   3 -
 tests/constants.py                            |  72 +++++
 tests/test_agent_base.py                      | 187 +++++++------
 tests/test_db_base.py                         |   4 +
 tests/test_envs.py                            |  28 +-
 tests/test_kb_base.py                         |   8 -
 tests/utils.py                                |   4 +
 20 files changed, 798 insertions(+), 278 deletions(-)
 create mode 100644 mypy.ini
 create mode 100644 research_town/dbs/__init__.py
 create mode 100644 research_town/dbs/agent_db.py
 create mode 100644 research_town/dbs/env_db.py
 create mode 100644 research_town/dbs/paper_db.py
 create mode 100644 research_town/dbs/progress_db.py
 delete mode 100644 research_town/kbs/kb_base.py
 rename research_town/utils/{author_collector.py => agent_collector.py} (96%)
 create mode 100644 tests/constants.py
 create mode 100644 tests/test_db_base.py
 delete mode 100644 tests/test_kb_base.py

diff --git a/mypy.ini b/mypy.ini
new file mode 100644
index 00000000..276a8974
--- /dev/null
+++ b/mypy.ini
@@ -0,0 +1,11 @@
+[mypy]
+plugins = pydantic.mypy
+
+[mypy-arxiv.*]
+ignore_missing_imports = True
+
+[mypy-faiss.*]
+ignore_missing_imports = True
+
+[mypy-transformers.*]
+ignore_missing_imports = True
diff --git a/research_town/agents/agent_base.py b/research_town/agents/agent_base.py
index 4f53843c..eb488eac 100644
--- a/research_town/agents/agent_base.py
+++ b/research_town/agents/agent_base.py
@@ -1,5 +1,15 @@
+from datetime import datetime
 from typing import Any, Dict, List, Tuple

+from ..dbs import (
+    AgentAgentDiscussionLog,
+    AgentPaperMetaReviewLog,
+    AgentPaperRebuttalLog,
+    AgentPaperReviewLog,
+    AgentProfile,
+    PaperProfile,
+)
+from ..utils.agent_collector import bfs
 from ..utils.agent_prompter import (
     communicate_with_multiple_researchers_prompting,
     find_collaborators_prompting,
@@ -8,69 +18,138 @@
     rebut_review_prompting,
     review_paper_prompting,
     review_score_prompting,
-    summarize_research_direction_prompting,
     summarize_research_field_prompting,
     write_paper_abstract_prompting,
 )
-from ..utils.author_collector import bfs
-from ..utils.paper_collector import get_paper_list


 class BaseResearchAgent(object):
-    def __init__(self, name: str) -> None:
-        self.profile = self.get_profile(name)
-        self.name = name
+    def __init__(self,
+                 agent_profile: AgentProfile,
+                 ) -> None:
+        self.profile: AgentProfile = agent_profile
         self.memory: Dict[str, str] = {}

-    def get_profile(self, author_name: str) -> Dict[str, Any]:
-        papers_list = get_paper_list(author_name)
-        if papers_list:
-            personal_info = "; ".join(
-                [f"{details['Title & Abstract']}" for details in papers_list]
-            )
-            profile_info = summarize_research_direction_prompting(personal_info)
-            return {"name": author_name, "profile": profile_info[0]}
-        else:
-            return {"info": "fail!"}
+    def get_profile(self, author_name: str) -> AgentProfile:
+        # TODO: db get based on name
+        agent_profile = AgentProfile(
+            name='Geoffrey Hinton',
+            bio="A researcher in the field of neural networks.",
+        )
+        return agent_profile

-    def communicate(self, message: Dict[str, str]) -> str:
-        return communicate_with_multiple_researchers_prompting(message)[0]
+    def communicate(
+        self,
+        message: AgentAgentDiscussionLog
+    ) -> AgentAgentDiscussionLog:
+        # TODO: find a meaningful key
+        message_dict: Dict[str, str] = {}
+        if message.message is not None:
+            message_dict[message.agent_from_pk] = message.message
+        message_content = communicate_with_multiple_researchers_prompting(
+            message_dict
+        )[0]
+        discussion_log = AgentAgentDiscussionLog(
+            timestep=int(datetime.now().timestamp()),
+            agent_from_pk=message.agent_from_pk,
+            agent_to_pk=message.agent_to_pk,
+            message=message_content
+        )
+        return discussion_log

     def read_paper(
-        self, papers: Dict[str, Dict[str, List[str]]], domain: str
+        self,
+        papers: List[PaperProfile],
+        domain: str
     ) -> str:
+        papers_dict: Dict[str, Dict[str, List[str]]] = {}
+        for paper in papers:
+            papers_dict[paper.pk] = {}
+            if paper.abstract is not None:
+                papers_dict[paper.pk]["abstract"] = [paper.abstract]
+            if paper.title is not None:
+                papers_dict[paper.pk]["title"] = [paper.title]
+        profile: Dict[str, str] = {}
+        if self.profile.name is not None:
+            profile["name"] = self.profile.name
+        if self.profile.bio is not None:
+            profile["profile"] = self.profile.bio
         trend = summarize_research_field_prompting(
-            profile=self.profile,
+            profile=profile,
             keywords=[domain],
-            papers=papers,
+            papers=papers_dict
         )
         trend_output = trend[0]
         return trend_output

-    def find_collaborators(self, input: Dict[str, str], parameter: float = 0.5, max_number: int = 3) -> List[str]:
-        start_author = [self.name]
+    def find_collaborators(
+        self,
+        paper: PaperProfile,
+        parameter: float = 0.5,
+        max_number: int = 3
+    ) -> List[AgentProfile]:
+        start_author: List[str] = [
+            self.profile.name] if self.profile.name is not None else []
         graph, _, _ = bfs(
             author_list=start_author, node_limit=max_number)
         collaborators = list(
-            {name for pair in graph for name in pair if name != self.name})
-        self_profile = {self.name: self.profile["profile"]}
-        collaborator_profiles = {author: self.get_profile(
-            author)["profile"] for author in collaborators}
+            {name for pair in graph for name in pair if name != self.profile.name})
+        self_profile: Dict[str, str] = {
+            self.profile.name: self.profile.bio} if self.profile.name is not None and self.profile.bio is not None else {}
+        collaborator_profiles: Dict[str, str] = {}
+        for author in collaborators:
+            author_bio = self.get_profile(author).bio
+            if author_bio is not None:
+                collaborator_profiles[author] = author_bio
+        paper_serialize: Dict[str, str] = {
+            paper.title: paper.abstract} if paper.title is not None and paper.abstract is not None else {}
         result = find_collaborators_prompting(
-            input, self_profile, collaborator_profiles, parameter, max_number)
-        collaborators_list = [
-            collaborator for collaborator in collaborators if collaborator in result[0]]
+            paper_serialize,
+            self_profile,
+            collaborator_profiles,
+            parameter,
+            max_number
+        )
+        collaborators_list = []
+        for collaborator in collaborators:
+            if collaborator in result:
+                collaborators_list.append(self.get_profile(collaborator))
         return collaborators_list

-    def get_co_author_relationships(self, name: str, max_node: int) -> Tuple[List[Tuple[str, str]], Dict[str, List[Dict[str, Any]]], Dict[str, List[Dict[str, Any]]]]:
-        start_author = [name]
+    def get_co_author_relationships(
+        self,
+        agent_profile: AgentProfile,
+        max_node: int
+    ) -> Tuple[List[Tuple[str, str]], Dict[str, List[Dict[str, Any]]], Dict[str, List[Dict[str, Any]]]]:
+        start_author: List[str] = [
+            agent_profile.name] if agent_profile.name is not None else []
         graph, node_feat, edge_feat = bfs(
             author_list=start_author, node_limit=max_node)
         return graph, node_feat, edge_feat

     def generate_idea(
-        self, trends: List[str], domain: str
+        self,
+        papers: List[PaperProfile],
+        domain: str
     ) -> List[str]:
+        papers_dict: Dict[str, Dict[str, List[str]]] = {}
+        for paper_profile in papers:
+            papers_dict[paper_profile.pk] = {}
+            if paper_profile.abstract is not None:
+                papers_dict[paper_profile.pk]["abstract"] = [
+                    paper_profile.abstract]
+            if paper_profile.title is not None:
+                papers_dict[paper_profile.pk]["title"] = [paper_profile.title]
+        profile: Dict[str, str] = {}
+        if self.profile.name is not None:
+            profile["name"] = self.profile.name
+        if self.profile.bio is not None:
+            profile["profile"] = self.profile.bio
+        trends = summarize_research_field_prompting(
+            profile=profile,
+            keywords=[domain],
+            papers=papers_dict
+        )
         ideas: List[str] = []
         for trend in trends:
             idea = generate_ideas_prompting(trend)[0]
@@ -78,25 +157,91 @@ def generate_idea(

         return ideas

-    def write_paper(self, input: List[str], papers: Dict[str, Dict[str, List[str]]]) -> str:
-        paper_abstract = write_paper_abstract_prompting(input, papers)
-        return paper_abstract[0]
+    def write_paper(
+        self,
+        research_ideas: List[str],
+        papers: List[PaperProfile]
+    ) -> PaperProfile:
+        papers_dict: Dict[str, Dict[str, List[str]]] = {}
+        for paper_profile in papers:
+            papers_dict[paper_profile.pk] = {}
+            if paper_profile.abstract is not None:
+                papers_dict[paper_profile.pk]["abstract"] = [
+                    paper_profile.abstract]
+            if paper_profile.title is not None:
+                papers_dict[paper_profile.pk]["title"] = [paper_profile.title]
+        paper_abstract = write_paper_abstract_prompting(
+            research_ideas, papers_dict)[0]
+        paper_profile = PaperProfile(abstract=paper_abstract)
+        return paper_profile

-    def review_paper(self, paper: Dict[str, str]) -> Tuple[int, str]:
-        paper_review = review_paper_prompting(paper)[0]
+    def review_paper(
+        self,
+        paper: PaperProfile
+    ) -> AgentPaperReviewLog:
+        paper_dict: Dict[str, str] = {
+            paper.title: paper.abstract} if paper.title is not None and paper.abstract is not None else {}
+        paper_review = review_paper_prompting(paper_dict)[0]
         review_score = review_score_prompting(paper_review)
-        return review_score, paper_review
+
+        return AgentPaperReviewLog(
+            timestep=int(datetime.now().timestamp()),
+            paper_pk=paper.pk,
+            agent_pk=self.profile.pk,
+            review_content=paper_review,
+            review_score=review_score
+        )

     def make_review_decision(
-        self, submission: Dict[str, str], review: Dict[str, Tuple[int, str]]
-    ) -> Tuple[bool, str]:
-        meta_review = make_review_decision_prompting(submission, review)
-        if "accept" in meta_review[0].lower():
-            review_decision = True
-        else:
-            review_decision = False
-        return review_decision, meta_review[0]
-
-    def rebut_review(self, submission: Dict[str, str], review: Dict[str, Tuple[int, str]], decision: Dict[str, Tuple[bool, str]]) -> str:
-        rebut_review = rebut_review_prompting(submission, review, decision)
-        return rebut_review[0]
+        self,
+        paper: PaperProfile,
+        review: List[AgentPaperReviewLog]
+    ) -> AgentPaperMetaReviewLog:
+        paper_dict: Dict[str, str] = {
+            paper.title: paper.abstract} if paper.title is not None and paper.abstract is not None else {}
+        review_dict: Dict[str, Tuple[int, str]] = {}
+        for agent_review_log in review:
+            if agent_review_log.review_score is not None and agent_review_log.review_content is not None:
+                review_dict[agent_review_log.pk] = (
+                    agent_review_log.review_score, agent_review_log.review_content)
+
+        meta_review = make_review_decision_prompting(paper_dict, review_dict)
+        review_decision = "accept" in meta_review[0].lower()
+
+        return AgentPaperMetaReviewLog(
+            timestep=int(datetime.now().timestamp()),
+            paper_pk=paper.pk,
+            agent_pk=self.profile.pk,
+            decision=review_decision,
+            meta_review=meta_review[0],
+        )
+
+    def rebut_review(
+        self,
+        paper: PaperProfile,
+        review: List[AgentPaperReviewLog],
+        decision: List[AgentPaperMetaReviewLog]
+    ) -> AgentPaperRebuttalLog:
+        paper_dict: Dict[str, str] = {
+            paper.title: paper.abstract} if paper.title is not None and paper.abstract is not None else {}
+        review_dict: Dict[str, Tuple[int, str]] = {}
+        for agent_review_log in review:
+            if agent_review_log.review_score is not None and agent_review_log.review_content is not None:
+                review_dict[agent_review_log.pk] = (
+                    agent_review_log.review_score, agent_review_log.review_content)
+
+        decision_dict: Dict[str, Tuple[bool, str]] = {}
+        for agent_meta_review_log in decision:
+            if agent_meta_review_log.decision is not None and agent_meta_review_log.meta_review is not None:
+                decision_dict[agent_meta_review_log.pk] = (
+                    agent_meta_review_log.decision, agent_meta_review_log.meta_review)
+
+        rebut_review = rebut_review_prompting(
+            paper_dict, review_dict, decision_dict)[0]
+
+        return AgentPaperRebuttalLog(
+            timestep=int(datetime.now().timestamp()),
+            paper_pk=paper.pk,
+            agent_pk=self.profile.pk,
+            rebuttal_content=rebut_review
+        )
diff --git a/research_town/dbs/__init__.py b/research_town/dbs/__init__.py
new file mode 100644
index 00000000..3506579e
--- /dev/null
+++ b/research_town/dbs/__init__.py
@@ -0,0 +1,29 @@
+from .agent_db import AgentProfile, AgentProfileDB
+from .env_db import (
+    AgentAgentDiscussionLog,
+    AgentPaperMetaReviewLog,
+    AgentPaperRebuttalLog,
+    AgentPaperReviewLog,
+    EnvLogDB,
+)
+from .paper_db import PaperProfile, PaperProfileDB
+from .progress_db import (
+    ResearchIdea,
+    ResearchPaperDraft,
+    ResearchProgressDB,
+)
+
+__all__ = [
+    "AgentAgentDiscussionLog",
+    "AgentPaperMetaReviewLog",
+    "AgentPaperRebuttalLog",
+    "AgentPaperReviewLog",
+    "PaperProfile",
+    "AgentProfile",
+    "ResearchIdea",
+    "ResearchPaperDraft",
+    "EnvLogDB",
+    "PaperProfileDB",
+    "AgentProfileDB",
+    "ResearchProgressDB"
+]
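With research_town/dbs/__init__.py in place, every profile, log, and progress type is importable from a single package. A minimal sketch of constructing the two profile types (the names and field values below are illustrative, not taken from the repo):

    from research_town.dbs import AgentProfile, PaperProfile

    # Both models auto-generate a UUID string primary key in `pk`.
    agent = AgentProfile(name="Ada Lovelace", bio="A researcher in numerical computing.")
    paper = PaperProfile(title="A Survey on Machine Learning",
                         abstract="This paper surveys the field of machine learning.")
    print(agent.pk, paper.pk)  # two distinct UUID strings

Because every field except pk is Optional, partially-known records such as the abstract-only PaperProfile built in write_paper validate cleanly.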
diff --git a/research_town/dbs/agent_db.py b/research_town/dbs/agent_db.py
new file mode 100644
index 00000000..5e09ecf5
--- /dev/null
+++ b/research_town/dbs/agent_db.py
@@ -0,0 +1,57 @@
+import json
+import uuid
+from typing import Any, Dict, List, Optional
+
+from pydantic import BaseModel, Field
+
+
+class AgentProfile(BaseModel):
+    pk: str = Field(default_factory=lambda: str(uuid.uuid4()))
+    name: Optional[str] = Field(default=None)
+    bio: Optional[str] = Field(default=None)
+
+
+class AgentProfileDB(object):
+    def __init__(self) -> None:
+        self.data: Dict[str, AgentProfile] = {}
+
+    def add(self, agent: AgentProfile) -> None:
+        self.data[agent.pk] = agent
+
+    def update(self, agent_pk: str, updates: Dict[str, Optional[str]]) -> bool:
+        if agent_pk in self.data:
+            for key, value in updates.items():
+                if value is not None:
+                    setattr(self.data[agent_pk], key, value)
+            return True
+        return False
+
+    def delete(self, agent_pk: str) -> bool:
+        if agent_pk in self.data:
+            del self.data[agent_pk]
+            return True
+        return False
+
+    def get(self, **conditions: Dict[str, Any]) -> List[AgentProfile]:
+        result = []
+        for agent in self.data.values():
+            if all(getattr(agent, key) == value for key, value in conditions.items()):
+                result.append(agent)
+        return result
+
+    def save_to_file(self, file_name: str) -> None:
+        with open(file_name, "w") as f:
+            json.dump({aid: agent.dict()
+                       for aid, agent in self.data.items()}, f, indent=2)
+
+    def load_from_file(self, file_name: str) -> None:
+        with open(file_name, "r") as f:
+            data = json.load(f)
+        self.data = {aid: AgentProfile(**agent_data)
+                     for aid, agent_data in data.items()}
+
+    def update_db(self, data: Dict[str, List[Dict[str, Any]]]) -> None:
+        for date, agents in data.items():
+            for agent_data in agents:
+                agent = AgentProfile(**agent_data)
+                self.add(agent)
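A usage sketch for the new AgentProfileDB, assuming only the methods defined above (the profile values and file name are illustrative):

    from research_town.dbs import AgentProfile, AgentProfileDB

    db = AgentProfileDB()
    agent = AgentProfile(name="Grace Hopper", bio="A researcher in compilers.")
    db.add(agent)
    db.update(agent.pk, {"bio": "A researcher in programming languages."})
    matches = db.get(name="Grace Hopper")  # keyword conditions are attribute-equality filters
    db.save_to_file("agents.json")         # serializes each profile via AgentProfile.dict()

Note that update() skips keys whose value is None, so it cannot be used to clear a field back to None.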
diff --git a/research_town/dbs/env_db.py b/research_town/dbs/env_db.py
new file mode 100644
index 00000000..3603206a
--- /dev/null
+++ b/research_town/dbs/env_db.py
@@ -0,0 +1,104 @@
+import json
+import uuid
+from typing import Any, Dict, List, Optional, Type, TypeVar
+
+from pydantic import BaseModel, Field
+
+T = TypeVar('T', bound=BaseModel)
+
+
+class EnvLogDB:
+    def __init__(self) -> None:
+        self.data: Dict[str, List[Any]] = {
+            "PaperProfile": [],
+            "AgentPaperReviewLog": [],
+            "AgentPaperRebuttalLog": [],
+            "AgentPaperMetaReviewLog": [],
+            "AgentAgentDiscussionLog": []
+        }
+
+    def add(self, obj: T) -> None:
+        class_name = obj.__class__.__name__
+        if class_name in self.data:
+            self.data[class_name].append(obj.dict())
+        else:
+            raise ValueError(f"Unsupported log type: {class_name}")
+
+    def get(self, cls: Type[T], **conditions: Dict[str, Any]) -> List[T]:
+        class_name = cls.__name__
+        if class_name not in self.data:
+            raise ValueError(f"Unsupported log type: {class_name}")
+        result = []
+        for data in self.data[class_name]:
+            instance = cls(**data)
+            if all(getattr(instance, key) == value for key, value in conditions.items()):
+                result.append(instance)
+        return result
+
+    def update(self, cls: Type[T], conditions: Dict[str, Any], updates: Dict[str, Any]) -> int:
+        class_name = cls.__name__
+        if class_name not in self.data:
+            raise ValueError(f"Unsupported log type: {class_name}")
+        updated_count = 0
+        # Iterate over a copy: removing entries from the list while
+        # iterating it directly would skip elements.
+        for data in self.data[class_name][:]:
+            instance = cls(**data)
+            if all(getattr(instance, key) == value for key, value in conditions.items()):
+                for key, value in updates.items():
+                    setattr(instance, key, value)
+                self.data[class_name].remove(data)
+                self.data[class_name].append(instance.dict())
+                updated_count += 1
+        return updated_count
+
+    def delete(self, cls: Type[T], **conditions: Dict[str, Any]) -> int:
+        class_name = cls.__name__
+        if class_name not in self.data:
+            raise ValueError(f"Unsupported log type: {class_name}")
+        initial_count = len(self.data[class_name])
+        self.data[class_name] = [
+            data for data in self.data[class_name]
+            if not all(getattr(cls(**data), key) == value for key, value in conditions.items())
+        ]
+        return initial_count - len(self.data[class_name])
+
+    def save_to_file(self, file_name: str) -> None:
+        with open(file_name, "w") as f:
+            json.dump(self.data, f, indent=2)
+
+    def load_from_file(self, file_name: str) -> None:
+        with open(file_name, "r") as f:
+            self.data = json.load(f)
+
+
+class AgentPaperReviewLog(BaseModel):
+    pk: str = Field(default_factory=lambda: str(uuid.uuid4()))
+    timestep: int = Field(default=0)
+    paper_pk: str
+    agent_pk: str
+    review_score: Optional[int] = Field(default=0)
+    review_content: Optional[str] = Field(default=None)
+
+
+class AgentPaperRebuttalLog(BaseModel):
+    pk: str = Field(default_factory=lambda: str(uuid.uuid4()))
+    timestep: int = Field(default=0)
+    paper_pk: str
+    agent_pk: str
+    rebuttal_content: Optional[str] = Field(default=None)
+
+
+class AgentPaperMetaReviewLog(BaseModel):
+    pk: str = Field(default_factory=lambda: str(uuid.uuid4()))
+    timestep: int = Field(default=0)
+    paper_pk: str
+    agent_pk: str
+    decision: Optional[bool] = Field(default=False)
+    meta_review: Optional[str] = Field(default=None)
+
+
+class AgentAgentDiscussionLog(BaseModel):
+    pk: str = Field(default_factory=lambda: str(uuid.uuid4()))
+    timestep: int = Field(default=0)
+    agent_from_pk: str
+    agent_to_pk: str
+    message: Optional[str] = Field(default=None)
diff --git a/research_town/dbs/paper_db.py b/research_town/dbs/paper_db.py
new file mode 100644
index 00000000..6b40b328
--- /dev/null
+++ b/research_town/dbs/paper_db.py
@@ -0,0 +1,72 @@
+import json
+import uuid
+from typing import Any, Dict, List, Optional
+
+from pydantic import BaseModel, Field
+
+from ..utils.paper_collector import get_daily_papers
+
+
+class PaperProfile(BaseModel):
+    pk: str = Field(default_factory=lambda: str(uuid.uuid4()))
+    title: Optional[str] = Field(default=None)
+    abstract: Optional[str] = Field(default=None)
+
+
+class PaperProfileDB:
+    def __init__(self) -> None:
+        self.data: Dict[str, PaperProfile] = {}
+
+    def add_paper(self, paper: PaperProfile) -> None:
+        self.data[paper.pk] = paper
+
+    def update_paper(self, paper_pk: str, updates: Dict[str, Optional[str]]) -> bool:
+        if paper_pk in self.data:
+            for key, value in updates.items():
+                if value is not None:
+                    setattr(self.data[paper_pk], key, value)
+            return True
+        return False
+
+    def get_paper(self, paper_pk: str) -> Optional[PaperProfile]:
+        return self.data.get(paper_pk)
+
+    def delete_paper(self, paper_pk: str) -> bool:
+        if paper_pk in self.data:
+            del self.data[paper_pk]
+            return True
+        return False
+
+    def query_papers(self, **conditions: Dict[str, Any]) -> List[PaperProfile]:
+        result = []
+        for paper in self.data.values():
+            if all(getattr(paper, key) == value for key, value in conditions.items()):
+                result.append(paper)
+        return result
+
+    def save_to_file(self, file_name: str) -> None:
+        with open(file_name, "w") as f:
+            json.dump({pk: paper.dict()
+                       for pk, paper in self.data.items()}, f, indent=2)
+
+    def load_from_file(self, file_name: str) -> None:
+        with open(file_name, "r") as f:
+            data = json.load(f)
+        self.data = {pk: PaperProfile(**paper_data)
+                     for pk, paper_data in data.items()}
+
+    def update_db(self, data: Dict[str, List[Dict[str, Any]]]) -> None:
+        for date, papers in data.items():
+            for paper_data in papers:
+                paper = PaperProfile(**paper_data)
+                self.add_paper(paper)
+
+    def fetch_and_add_papers(self, num: int, domain: str) -> None:
+        data, _ = get_daily_papers(domain, query=domain, max_results=num)
+        transformed_data = {}
+        for date, value in data.items():
+            papers = []
+            papers.append({"abstract": value["abstract"]})
+            papers.append({"info": value["info"]})
+            transformed_data[date] = papers
+        self.update_db(transformed_data)
diff --git a/research_town/dbs/progress_db.py b/research_town/dbs/progress_db.py
new file mode 100644
index 00000000..2d5efdb6
--- /dev/null
+++ b/research_town/dbs/progress_db.py
@@ -0,0 +1,75 @@
+import json
+import uuid
+from typing import Any, Dict, List, Optional, Type, TypeVar
+
+from pydantic import BaseModel, Field
+
+T = TypeVar('T', bound=BaseModel)
+
+
+class ResearchProgressDB:
+    def __init__(self) -> None:
+        self.data: Dict[str, List[Any]] = {
+            "ResearchIdea": [],
+            "ResearchPaperDraft": []
+        }
+
+    def add(self, obj: T) -> None:
+        class_name = obj.__class__.__name__
+        if class_name in self.data:
+            self.data[class_name].append(obj.dict())
+        else:
+            raise ValueError(f"Unsupported type: {class_name}")
+
+    def get(self, cls: Type[T], **conditions: Dict[str, Any]) -> List[T]:
+        class_name = cls.__name__
+        if class_name not in self.data:
+            raise ValueError(f"Unsupported type: {class_name}")
+        result = []
+        for data in self.data[class_name]:
+            instance = cls(**data)
+            if all(getattr(instance, key) == value for key, value in conditions.items()):
+                result.append(instance)
+        return result
+
+    def update(self, cls: Type[T], conditions: Dict[str, Any], updates: Dict[str, Any]) -> int:
+        class_name = cls.__name__
+        if class_name not in self.data:
+            raise ValueError(f"Unsupported type: {class_name}")
+        updated_count = 0
+        # Iterate over a copy so removal during iteration does not skip entries.
+        for data in self.data[class_name][:]:
+            instance = cls(**data)
+            if all(getattr(instance, key) == value for key, value in conditions.items()):
+                for key, value in updates.items():
+                    setattr(instance, key, value)
+                self.data[class_name].remove(data)
+                self.data[class_name].append(instance.dict())
+                updated_count += 1
+        return updated_count
+
+    def delete(self, cls: Type[T], **conditions: Dict[str, Any]) -> int:
+        class_name = cls.__name__
+        if class_name not in self.data:
+            raise ValueError(f"Unsupported type: {class_name}")
+        initial_count = len(self.data[class_name])
+        self.data[class_name] = [
+            data for data in self.data[class_name]
+            if not all(getattr(cls(**data), key) == value for key, value in conditions.items())
+        ]
+        return initial_count - len(self.data[class_name])
+
+    def save_to_file(self, file_name: str) -> None:
+        with open(file_name, "w") as f:
+            json.dump(self.data, f, indent=2)
+
+    def load_from_file(self, file_name: str) -> None:
+        with open(file_name, "r") as f:
+            self.data = json.load(f)
+
+
+class ResearchIdea(BaseModel):
+    pk: str = Field(default_factory=lambda: str(uuid.uuid4()))
+    content: Optional[str] = Field(default=None)
+
+
+class ResearchPaperDraft(BaseModel):
+    pk: str = Field(default_factory=lambda: str(uuid.uuid4()))
+    title: Optional[str] = Field(default=None)
+    abstract: Optional[str] = Field(default=None)
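EnvLogDB and ResearchProgressDB share the same class-name-keyed storage and condition-based API. A short sketch against EnvLogDB as defined above (the pk values are illustrative; in practice they come from PaperProfile.pk and AgentProfile.pk):

    from research_town.dbs import AgentPaperReviewLog, EnvLogDB

    db = EnvLogDB()
    log = AgentPaperReviewLog(paper_pk="paper-1", agent_pk="agent-1",
                              review_score=4, review_content="Solid baselines.")
    db.add(log)  # stored as a dict under the "AgentPaperReviewLog" key
    found = db.get(AgentPaperReviewLog, paper_pk="paper-1")
    db.update(AgentPaperReviewLog, {"paper_pk": "paper-1"}, {"review_score": 5})
    removed = db.delete(AgentPaperReviewLog, paper_pk="paper-1")  # returns the number deleted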
diff --git a/research_town/envs/env_base.py b/research_town/envs/env_base.py
index eab9b9d3..d87e5baf 100644
--- a/research_town/envs/env_base.py
+++ b/research_town/envs/env_base.py
@@ -1,15 +1,16 @@
-from typing import Dict
+from typing import List

 from ..agents.agent_base import BaseResearchAgent
-from ..kbs.kb_base import BaseKnowledgeBase
+from ..dbs import AgentProfile, EnvLogDB


 class BaseMultiAgentEnv(object):
-    def __init__(self, agent_dict: Dict[str, str]) -> None:
-        self.agents: Dict[str, BaseResearchAgent] = {}
-        self.kb = BaseKnowledgeBase()
-        for agent_name, agent in agent_dict.items():
-            self.agents[agent_name] = BaseResearchAgent(agent)
+    def __init__(self, agent_profiles: List[AgentProfile]) -> None:
+        self.agent_profiles: List[AgentProfile] = agent_profiles
+        self.db = EnvLogDB()
+        self.agents: List[BaseResearchAgent] = []
+        for agent_profile in agent_profiles:
+            self.agents.append(BaseResearchAgent(agent_profile))

     def step(self) -> None:
         raise NotImplementedError
diff --git a/research_town/envs/env_paper_rebuttal.py b/research_town/envs/env_paper_rebuttal.py
index 2c398460..3a9e6989 100644
--- a/research_town/envs/env_paper_rebuttal.py
+++ b/research_town/envs/env_paper_rebuttal.py
@@ -1,30 +1,35 @@
-from typing import Dict, Tuple
+from typing import Dict, List, Tuple

+from ..dbs import (
+    AgentPaperMetaReviewLog,
+    AgentPaperRebuttalLog,
+    AgentPaperReviewLog,
+    AgentProfile,
+    PaperProfile,
+)
 from .env_base import BaseMultiAgentEnv


 class PaperRebuttalMultiAgentEnv(BaseMultiAgentEnv):
-    def __init__(self, agent_dict: Dict[str, str]) -> None:
-        super().__init__(agent_dict)
+    def __init__(self, agent_profiles: List[AgentProfile]) -> None:
+        super().__init__(agent_profiles)
         self.turn_number = 0
         self.turn_max = 1
         self.terminated = False
-        self.roles: Dict[str, str] = {}
-        self.submission: Dict[str, str] = {}
-        self.review = ""
-        self.decision = ""
-        self.rebuttal = ""
+        self.decision = "reject"
+        self.submission = PaperProfile()
+        self.reviewer_mask = [False] * len(agent_profiles)
+        self.review: List[AgentPaperReviewLog] = []
+        self.rebuttal: List[AgentPaperRebuttalLog] = []
+        self.meta_review: List[AgentPaperMetaReviewLog] = []

     def assign_roles(self, role_dict: Dict[str, str]) -> None:
-        self.roles = role_dict
+        for index, agent_profile in enumerate(self.agent_profiles):
+            if role_dict[agent_profile.pk] == "reviewer":
+                self.reviewer_mask[index] = True

-    def initialize_submission(self, external_data: Dict[str, str]) -> None:
-        self.submission = external_data
-
-    def submit_review(self, review_dict: Dict[str, Tuple[int, str]]) -> None:
-        review_serialize = [
-            f"Reviewer: {name}\nScore: {review[0]}\nReview: {review[1]}" for name, review in review_dict.items()]
-        self.review = "\n\n".join(review_serialize)
+    def initialize_submission(self, paper_profile: PaperProfile) -> None:
+        self.submission = paper_profile

     def submit_decision(self, decision_dict: Dict[str, Tuple[bool, str]]) -> None:
         decision_count = {"accept": 0, "reject": 0}
@@ -39,40 +44,27 @@ def submit_decision(self, decision_dict: Dict[str, Tuple[bool, str]]) -> None:
                 count_max = count
                 self.decision = d

-    def submit_rebuttal(self, rebuttal_dict: Dict[str, str]) -> None:
-        rebuttal_serialize = [
-            f"Author: {name}\nRebuttal: {rebuttal}" for name, rebuttal in rebuttal_dict.items()]
-        self.rebuttal = "\n\n".join(rebuttal_serialize)
-
     def step(self) -> None:
         # Paper Reviewing
-        review_dict: Dict[str, Tuple[int, str]] = {}
-        for name, role in self.roles.items():
-            if role == "reviewer":
-                review_dict[name] = self.agents[name].review_paper(
-                    paper=self.submission)
-        self.submit_review(review_dict)
+        for index, agent in enumerate(self.agents):
+            if self.reviewer_mask[index]:
+                self.review.append(agent.review_paper(
+                    paper=self.submission))

-        # Decision Making
-        decision_dict: Dict[str, Tuple[bool, str]] = {}
-        for name, role in self.roles.items():
-            if role == "reviewer":
-                decision_dict[name] = self.agents[name].make_review_decision(
-                    submission=self.submission, review=review_dict)
-        self.submit_decision(decision_dict)
+        # Paper Meta Reviewing
+        for index, agent in enumerate(self.agents):
+            if self.reviewer_mask[index]:
+                self.meta_review.append(agent.make_review_decision(
+                    paper=self.submission, review=self.review))

         # Rebuttal Submitting
-        rebuttal_dict: Dict[str, str] = {}
-        for name, role in self.roles.items():
-            if role == "author":
-                rebuttal_dict[name] = self.agents[name].rebut_review(
-                    submission=self.submission,
-                    review=review_dict,
-                    decision=decision_dict)
-        self.submit_rebuttal(rebuttal_dict)
+        for index, agent in enumerate(self.agents):
+            if not self.reviewer_mask[index]:
+                self.rebuttal.append(agent.rebut_review(
+                    paper=self.submission,
+                    review=self.review,
+                    decision=self.meta_review))

         self.turn_number += 1
-        if self.decision == "accept":
-            self.terminated = True
         if self.turn_number >= self.turn_max:
             self.terminated = True
diff --git a/research_town/envs/env_paper_submission.py b/research_town/envs/env_paper_submission.py
index 379e4bf9..9556e389 100644
--- a/research_town/envs/env_paper_submission.py
+++ b/research_town/envs/env_paper_submission.py
@@ -1,7 +1,6 @@
-from typing import Dict
-
-from research_town.agents.agent_base import BaseResearchAgent
+from typing import List

+from ..dbs import AgentProfile, PaperProfile
 from .env_base import BaseMultiAgentEnv
diff --git a/research_town/kbs/kb_base.py b/research_town/kbs/kb_base.py
deleted file mode 100644
index f6bc3ce9..00000000
--- a/research_town/kbs/kb_base.py
+++ /dev/null
@@ -1,29 +0,0 @@
-from typing import Dict, List
-
-from ..utils.paper_collector import get_daily_papers
-
-
-class BaseKnowledgeBase(object):
-    def __init__(self) -> None:
-        self.data: Dict[str, str] = {}
-
-    def update_kb(self, data: Dict[str, str]) -> None:
-        self.data.update(data)
-
-    def add_data(self, data: Dict[str, str]) -> None:
-        self.data.update(data)
-
-    def get_data(self, num: int, domain: str) -> Dict[str, Dict[str, List[str]]]:
-        data_collector = []
-        keywords = dict()
-        keywords[domain] = domain
-
-        for topic, keyword in keywords.items():
-            data, _ = get_daily_papers(topic, query=keyword, max_results=num)
-            data_collector.append(data)
-        data_dict = {}
-        for data in data_collector:
-            for time in data.keys():
-                papers = data[time]
-                data_dict[time] = papers
-        return data_dict
diff --git a/research_town/utils/author_collector.py b/research_town/utils/agent_collector.py
similarity index 96%
rename from research_town/utils/author_collector.py
rename to research_town/utils/agent_collector.py
index 04b24865..0c26e418 100644
--- a/research_town/utils/author_collector.py
+++ b/research_town/utils/agent_collector.py
@@ -32,7 +32,8 @@ def co_author_frequency(


 def co_author_filter(co_authors: Dict[str, int], limit: int = 5) -> List[str]:
-    co_author_list = sorted(co_authors.items(), key=lambda p: p[1], reverse=True)
+    co_author_list = sorted(
+        co_authors.items(), key=lambda p: p[1], reverse=True)
     return [name for name, _ in co_author_list[:limit]]
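The renamed agent_collector module keeps the bfs() co-author crawl that find_collaborators and get_co_author_relationships call. A sketch of invoking it directly (this may perform live paper-source lookups, so results vary; the author name is illustrative):

    from research_town.utils.agent_collector import bfs

    graph, node_feat, edge_feat = bfs(author_list=["Jiaxuan You"], node_limit=3)
    coauthors = {name for pair in graph for name in pair}  # edges are (author, co-author) pairs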
diff --git a/research_town/utils/agent_prompter.py b/research_town/utils/agent_prompter.py
index a4f1125b..7ab4d820 100644
--- a/research_town/utils/agent_prompter.py
+++ b/research_town/utils/agent_prompter.py
@@ -1,25 +1,10 @@
 from typing import Dict, List, Optional, Tuple
-
-# ## SET MAX TOKENS - via completion()
-# response = litellm.completion(
-#     model="gpt-3.5-turbo",
-#     messages=[{ "content": "Hello, how are you?","role": "user"}],
-#     max_tokens=10
-# )

 import litellm

 from .decorator import exponential_backoff
 from .paper_collector import get_related_papers

-# use litellm as our model router. Supported Provider List: https://docs.litellm.ai/docs/providers .
-# Example of litellm usage:
-# # set env variables
-# os.environ["OPENAI_API_KEY"] = "your-openai-key"
-
-
-
-# Todo(jinwei): we could add more selections of input params. CHECK the input params supported here: https://docs.litellm.ai/docs/completion/input
 @exponential_backoff(retries=5, base_wait_time=1)
 def model_prompting(
     llm_model: str,
@@ -61,7 +46,7 @@ def summarize_research_field_prompting(
     query = query_template.format_map(template_input)

     corpus = [abstract for papers in papers.values()
-             for abstract in papers["abstract"]]
+              for abstract in papers["abstract"]]

     related_papers = get_related_papers(corpus, query, num=10)

@@ -174,6 +159,7 @@ def write_paper_abstract_prompting(
     prompt = prompt_template.format_map(template_input)
     return model_prompting(llm_model, prompt)

+
 def review_score_prompting(paper_review: str, llm_model: Optional[str] = "mistralai/Mixtral-8x7B-Instruct-v0.1") -> int:
     prompt_qa = (
         "Please provide a score for the following reviews. The score should be between 1 and 10, where 1 is the lowest and 10 is the highest. Only returns one number score."
@@ -187,14 +173,15 @@ def review_score_prompting(paper_review: str, llm_model: Optional[str] = "mistra
     else:
         return 0

-def review_paper_prompting(external_data: Dict[str, str], llm_model: Optional[str] = "mistralai/Mixtral-8x7B-Instruct-v0.1") -> List[str]:
+
+def review_paper_prompting(paper: Dict[str, str], llm_model: Optional[str] = "mistralai/Mixtral-8x7B-Instruct-v0.1") -> List[str]:
     """
     Review paper from using list, and external data (published papers)
     """

     papers_serialize = []
-    for _, timestamp in enumerate(external_data.keys()):
-        paper_entry = f"Title: {timestamp}\nPaper: {external_data[timestamp]}"
+    for _, title in enumerate(paper.keys()):
+        paper_entry = f"Title: {title}\nPaper: {paper[title]}"
         papers_serialize.append(paper_entry)
     papers_serialize_all = "\n\n".join(papers_serialize)

@@ -210,13 +197,13 @@
     return model_prompting(llm_model, prompt)


-def make_review_decision_prompting(submission: Dict[str, str], review: Dict[str, Tuple[int,str]], llm_model: Optional[str] = "mistralai/Mixtral-8x7B-Instruct-v0.1") -> List[str]:
-    submission_serialize = []
-    for _, title in enumerate(submission.keys()):
-        abstract = submission[title]
-        submission_entry = f"Title: {title}\nAbstract:{abstract}\n"
-        submission_serialize.append(submission_entry)
-    submission_serialize_all = "\n\n".join(submission_serialize)
+def make_review_decision_prompting(paper: Dict[str, str], review: Dict[str, Tuple[int, str]], llm_model: Optional[str] = "mistralai/Mixtral-8x7B-Instruct-v0.1") -> List[str]:
+    paper_serialize = []
+    for _, title in enumerate(paper.keys()):
+        abstract = paper[title]
+        paper_entry = f"Title: {title}\nAbstract:{abstract}\n"
+        paper_serialize.append(paper_entry)
+    paper_serialize_all = "\n\n".join(paper_serialize)

     review_serialize = []
     for _, name in enumerate(review.keys()):
@@ -227,10 +214,10 @@
     prompt_template = (
         "Please make an review decision to decide whether the following submission should be accepted or rejected by an academic conference. Here are several reviews from reviewers for this submission. Please indicate your review decision as accept or reject."
- "Here is the submission: {submission_serialize_all}" + "Here is the submission: {paper_serialize_all}" "Here are the reviews: {review_serialize_all}" ) - template_input = {"submission_serialize_all": submission_serialize_all, + template_input = {"paper_serialize_all": paper_serialize_all, "review_serialize_all": review_serialize_all} prompt = prompt_template.format_map(template_input) return model_prompting(llm_model, prompt) diff --git a/scripts/run.py b/scripts/run.py index 7fd24d5b..e69de29b 100644 --- a/scripts/run.py +++ b/scripts/run.py @@ -1,3 +0,0 @@ -from research_town.agents.agent_base import BaseResearchAgent - -agent = BaseResearchAgent("Jiaxuan You") diff --git a/tests/constants.py b/tests/constants.py new file mode 100644 index 00000000..cbce54d0 --- /dev/null +++ b/tests/constants.py @@ -0,0 +1,72 @@ + +from research_town.dbs import ( + AgentAgentDiscussionLog, + AgentPaperMetaReviewLog, + AgentPaperRebuttalLog, + AgentPaperReviewLog, + AgentProfile, + PaperProfile, + ResearchIdea, +) + +paper_profile_A = PaperProfile( + title="A Survey on Machine Learning", + abstract="This paper surveys the field of machine learning.", +) + +paper_profile_B = PaperProfile( + title="A Survey on Graph Neural Networks", + abstract="This paper surveys the field of graph neural networks.", +) + + +agent_profile_A = AgentProfile( + name="Jiaxuan You", + bio="A researcher in the field of machine learning.", +) + +agent_profile_B = AgentProfile( + name="Rex Ying", + bio="A researcher in the field of GNN.", +) + +research_idea = ResearchIdea( + content="A new idea", +) + +agent_agent_discussion_log = AgentAgentDiscussionLog( + timestep=0, + agent_from_pk=agent_profile_A.pk, + agent_to_pk=agent_profile_B.pk, + message="good morning", +) + +agent_paper_review_log = AgentPaperReviewLog( + timestep=0, + paper_pk=paper_profile_A.pk, + agent_pk=agent_profile_A.pk, + review_score=5, + review_content="This paper is well-written.", +) + +agent_paper_meta_review_log = AgentPaperMetaReviewLog( + timestep=0, + paper_pk=paper_profile_B.pk, + agent_pk=agent_profile_B.pk, + decision=True, + meta_review="This paper is well-written.", +) + +agent_paper_rebuttal_log = AgentPaperRebuttalLog( + timestep=0, + paper_pk=paper_profile_A.pk, + agent_pk=agent_profile_A.pk, + rebuttal_content="I have revised the paper.", +) + +agent_agent_discussion_log = AgentAgentDiscussionLog( + timestep=0, + agent_from_pk=agent_profile_A.pk, + agent_to_pk=agent_profile_B.pk, + message="How about the idea of building a research town with language agents?" 
+)
diff --git a/tests/test_agent_base.py b/tests/test_agent_base.py
index f1e1e50c..b4512cb5 100644
--- a/tests/test_agent_base.py
+++ b/tests/test_agent_base.py
@@ -1,112 +1,78 @@
 from unittest.mock import MagicMock, patch

 from research_town.agents.agent_base import BaseResearchAgent
+from research_town.dbs import AgentPaperRebuttalLog
+from tests.constants import (
+    agent_agent_discussion_log,
+    agent_profile_A,
+    agent_profile_B,
+    paper_profile_A,
+    paper_profile_B,
+)
 from tests.utils import mock_papers, mock_prompting


-@patch("research_town.utils.agent_prompter.model_prompting")
-def test_get_profile(mock_model_prompting: MagicMock) -> None:
-    mock_prompting = MagicMock()
-    mock_prompting.return_value = [
-        "I am a research agent who is interested in machine learning."]
-
-    mock_model_prompting.return_value = mock_prompting
-
-    research_agent = BaseResearchAgent("Jiaxuan You")
-    profile = research_agent.profile
-    assert profile["name"] == "Jiaxuan You"
-    assert "profile" in profile.keys()
-
-@patch("research_town.utils.agent_prompter.model_prompting")
-def test_make_review_decision(mock_model_prompting: MagicMock) -> None:
-    mock_model_prompting.return_value = [
-        "Accept. This is a good paper."]
+def test_get_profile() -> None:
+    research_agent = BaseResearchAgent(agent_profile=agent_profile_A)
+    assert research_agent.profile.name == "Jiaxuan You"
+    assert research_agent.profile.bio == "A researcher in the field of machine learning."

-    research_agent = BaseResearchAgent("Jiaxuan You")
-    submission = {"MambaOut: Do We Really Need Mamba for Vision?": "Mamba, an architecture with RNN-like token mixer of state space model (SSM), was recently introduced to address the quadratic complexity of the attention mechanism and subsequently applied to vision tasks. Nevertheless, the performance of Mamba for vision is often underwhelming when compared with convolutional and attention-based models. In this paper, we delve into the essence of Mamba, and conceptually conclude that Mamba is ideally suited for tasks with long-sequence and autoregressive characteristics. For vision tasks, as image classification does not align with either characteristic, we hypothesize that Mamba is not necessary for this task; Detection and segmentation tasks are also not autoregressive, yet they adhere to the long-sequence characteristic, so we believe it is still worthwhile to explore Mamba's potential for these tasks. To empirically verify our hypotheses, we construct a series of models named \\emph{MambaOut} through stacking Mamba blocks while removing their core token mixer, SSM. Experimental results strongly support our hypotheses. Specifically, our MambaOut model surpasses all visual Mamba models on ImageNet image classification, indicating that Mamba is indeed unnecessary for this task. As for detection and segmentation, MambaOut cannot match the performance of state-of-the-art visual Mamba models, demonstrating the potential of Mamba for long-sequence visual tasks."}
-    review = research_agent.review_paper(paper=submission)
-    review_decision, meta_review = research_agent.make_review_decision(
-        submission=submission, review={"Jiaxuan You": review})
-    assert review_decision is True
-    assert meta_review == "Accept. This is a good paper."
-@patch("research_town.utils.agent_prompter.model_prompting") -def test_review_paper(mock_model_prompting: MagicMock) -> None: - - mock_model_prompting.side_effect = mock_prompting - - research_agent = BaseResearchAgent("Jiaxuan You") - score, review = research_agent.review_paper(paper={"MambaOut: Do We Really Need Mamba for Vision?": "Mamba, an architecture with RNN-like token mixer of state space model (SSM), was recently introduced to address the quadratic complexity of the attention mechanism and subsequently applied to vision tasks. Nevertheless, the performance of Mamba for vision is often underwhelming when compared with convolutional and attention-based models. In this paper, we delve into the essence of Mamba, and conceptually conclude that Mamba is ideally suited for tasks with long-sequence and autoregressive characteristics. For vision tasks, as image classification does not align with either characteristic, we hypothesize that Mamba is not necessary for this task; Detection and segmentation tasks are also not autoregressive, yet they adhere to the long-sequence characteristic, so we believe it is still worthwhile to explore Mamba's potential for these tasks. To empirically verify our hypotheses, we construct a series of models named \\emph{MambaOut} through stacking Mamba blocks while removing their core token mixer, SSM. Experimental results strongly support our hypotheses. Specifically, our MambaOut model surpasses all visual Mamba models on ImageNet image classification, indicating that Mamba is indeed unnecessary for this task. As for detection and segmentation, MambaOut cannot match the performance of state-of-the-art visual Mamba models, demonstrating the potential of Mamba for long-sequence visual tasks."}) - print(score, review) - assert score == 2 - assert review == "This is a paper review for MambaOut." - - -# ========================================================= -# !IMPORTANT! -# patch should not add path that it comes from -# patch should add path that the function is used -# ========================================================= -@patch("research_town.utils.agent_prompter.model_prompting") +@patch("research_town.utils.agent_prompter.openai_prompting") @patch("research_town.utils.agent_prompter.get_related_papers") def test_generate_idea( mock_get_related_papers: MagicMock, mock_model_prompting: MagicMock, ) -> None: - - - # Configure the mocks mock_get_related_papers.side_effect = mock_papers mock_model_prompting.side_effect = mock_prompting - research_agent = BaseResearchAgent("Jiaxuan You") - trend = research_agent.read_paper({"2024-04": {"abstract": ["Believable proxies of human behavior can empower interactive applications ranging from immersive environments to rehearsal spaces for interpersonal communication to prototyping tools. In this paper, we introduce generative agents--computational software agents that simulate believable human behavior. Generative agents wake up, cook breakfast, and head to work; artists paint, while authors write; they form opinions, notice each other, and initiate conversations; they remember and reflect on days past as they plan the next day. To enable generative agents, we describe an architecture that extends a large language model to store a complete record of the agent's experiences using natural language, synthesize those memories over time into higher-level reflections, and retrieve them dynamically to plan behavior. 
-    trends = [trend]
-    ideas = research_agent.generate_idea(trends, domain="machine learning")
+    research_agent = BaseResearchAgent(agent_profile=agent_profile_A)
+    ideas = research_agent.generate_idea(
+        papers=[paper_profile_A, paper_profile_B],
+        domain="machine learning"
+    )
+    assert ideas == ["This is a research idea."]

-    assert isinstance(ideas, list)
-    assert len(ideas) > 0

 @patch("research_town.utils.agent_prompter.model_prompting")
 def test_communicate(mock_model_prompting: MagicMock) -> None:
-    mock_model_prompting.return_value = [
-        "I believe in the potential of using automous agents to simulate the current research pipeline."]
+    mock_model_prompting.return_value = [
+        "I believe in the potential of using autonomous agents to simulate the current research pipeline."
+    ]

-    research_agent = BaseResearchAgent("Jiaxuan You")
-    response = research_agent.communicate(
-        {"Alice": "I believe in the potential of using automous agents to simulate the current research pipeline."})
-    assert isinstance(response, str)
-    assert response != ""
+    research_agent = BaseResearchAgent(agent_profile=agent_profile_A)
+    response = research_agent.communicate(agent_agent_discussion_log)
+    assert response.message == "I believe in the potential of using autonomous agents to simulate the current research pipeline."
+    assert response.agent_to_pk is not None
+    assert response.agent_from_pk is not None
+    assert response.timestep >= 0
+    assert response.pk is not None


-@patch("research_town.utils.agent_prompter.model_prompting")
-def test_write_paper_abstract(mock_model_prompting: MagicMock) -> None:
-    mock_model_prompting.return_value = ["Believable proxies of human behavior can empower interactive applications ranging from immersive environments to rehearsal spaces for interpersonal communication to prototyping tools. In this paper, we introduce generative agents--computational software agents that simulate believable human behavior. Generative agents wake up, cook breakfast, and head to work; artists paint, while authors write; they form opinions, notice each other, and initiate conversations; they remember and reflect on days past as they plan the next day. To enable generative agents, we describe an architecture that extends a large language model to store a complete record of the agent's experiences using natural language, synthesize those memories over time into higher-level reflections, and retrieve them dynamically to plan behavior. We instantiate generative agents to populate an interactive sandbox environment inspired by The Sims, where end users can interact with a small town of twenty five agents using natural language. In an evaluation, these generative agents produce believable individual and emergent social behaviors: for example, starting with only a single user-specified notion that one agent wants to throw a Valentine's Day party, the agents autonomously spread invitations to the party over the next two days, make new acquaintances, ask each other out on dates to the party, and coordinate to show up for the party together at the right time. We demonstrate through ablation that the components of our agent architecture--observation, planning, and reflection--each contribute critically to the believability of agent behavior. By fusing large language models with computational, interactive agents, this work introduces architectural and interaction patterns for enabling believable simulations of human behavior. "]
-
-    research_agent = BaseResearchAgent("Jiaxuan You")
-    abstract = research_agent.write_paper(["We can simulate the scientific research pipeline with agents."], {"2024-04": {"abstract": ["Believable proxies of human behavior can empower interactive applications ranging from immersive environments to rehearsal spaces for interpersonal communication to prototyping tools. In this paper, we introduce generative agents--computational software agents that simulate believable human behavior. Generative agents wake up, cook breakfast, and head to work; artists paint, while authors write; they form opinions, notice each other, and initiate conversations; they remember and reflect on days past as they plan the next day. To enable generative agents, we describe an architecture that extends a large language model to store a complete record of the agent's experiences using natural language, synthesize those memories over time into higher-level reflections, and retrieve them dynamically to plan behavior. We instantiate generative agents to populate an interactive sandbox environment inspired by The Sims, where end users can interact with a small town of twenty five agents using natural language. In an evaluation, these generative agents produce believable individual and emergent social behaviors: for example, starting with only a single user-specified notion that one agent wants to throw a Valentine's Day party, the agents autonomously spread invitations to the party over the next two days, make new acquaintances, ask each other out on dates to the party, and coordinate to show up for the party together at the right time. We demonstrate through ablation that the components of our agent architecture--observation, planning, and reflection--each contribute critically to the believability of agent behavior. By fusing large language models with computational, interactive agents, this work introduces architectural and interaction patterns for enabling believable simulations of human behavior. "]}})
-    assert isinstance(abstract, str)
-    assert abstract != ""
-
-# =========================================================
-# !IMPORTANT!
-# patch should not add path that it comes from
-# patch should add path that the function is used
-# =========================================================
+@patch("research_town.utils.agent_prompter.model_prompting")
+def test_write_paper(mock_model_prompting: MagicMock) -> None:
+    mock_model_prompting.return_value = ["This is a paper abstract."]
+
+    research_agent = BaseResearchAgent(agent_profile=agent_profile_B)
+    paper = research_agent.write_paper(
+        ["We can simulate the scientific research pipeline with agents."], [paper_profile_A])
+    assert paper.abstract == "This is a paper abstract."
+    assert paper.pk is not None
+
+
 @patch("research_town.utils.agent_prompter.model_prompting")
 @patch("research_town.utils.agent_prompter.get_related_papers")
 def test_read_paper(
     mock_get_related_papers: MagicMock,
     mock_model_prompting: MagicMock,
 ) -> None:
     mock_get_related_papers.side_effect = mock_papers
     mock_model_prompting.side_effect = mock_prompting
-
-    papers = {"2021-01-01": {"abstract": ["This is a paper"]}}
     domain = "machine learning"
-    research_agent = BaseResearchAgent("Jiaxuan You")
-    summary = research_agent.read_paper(papers, domain)
-    assert isinstance(summary, str)
+    research_agent = BaseResearchAgent(agent_profile=agent_profile_A)
+    summary = research_agent.read_paper([paper_profile_A], domain)
+    assert summary == "Graph Neural Network"


 @patch("research_town.utils.agent_prompter.model_prompting")
@@ -114,21 +80,50 @@ def test_find_collaborators(mock_model_prompting: MagicMock) -> None:
     mock_model_prompting.return_value = [
         "These are collaborators including Jure Leskovec, Rex Ying, Saining Xie, Kaiming He."]

-    research_agent = BaseResearchAgent("Jiaxuan You")
+    research_agent = BaseResearchAgent(agent_profile=agent_profile_A)
     collaborators = research_agent.find_collaborators(
-        input={"11 May 2024": "Organize a workshop on how far are we from AGI (artificial general intelligence) at ICLR 2024. This workshop aims to become a melting pot for ideas, discussions, and debates regarding our proximity to AGI."}, parameter=0.5, max_number=3)
+        paper=paper_profile_A, parameter=0.5, max_number=3)
     assert isinstance(collaborators, list)
+    assert len(collaborators) <= 3

-@patch("research_town.utils.agent_prompter.model_prompting")
-def test_rebut_review(mock_model_prompting: MagicMock) -> None:
-    mock_model_prompting.return_value = [
-        "This is a paper rebuttal"]
-
-    research_agent = BaseResearchAgent("Jiaxuan You")
-    submission = {"MambaOut: Do We Really Need Mamba for Vision?": "Mamba, an architecture with RNN-like token mixer of state space model (SSM), was recently introduced to address the quadratic complexity of the attention mechanism and subsequently applied to vision tasks. Nevertheless, the performance of Mamba for vision is often underwhelming when compared with convolutional and attention-based models. In this paper, we delve into the essence of Mamba, and conceptually conclude that Mamba is ideally suited for tasks with long-sequence and autoregressive characteristics. For vision tasks, as image classification does not align with either characteristic, we hypothesize that Mamba is not necessary for this task; Detection and segmentation tasks are also not autoregressive, yet they adhere to the long-sequence characteristic, so we believe it is still worthwhile to explore Mamba's potential for these tasks.
To empirically verify our hypotheses, we construct a series of models named \\emph{MambaOut} through stacking Mamba blocks while removing their core token mixer, SSM. Experimental results strongly support our hypotheses. Specifically, our MambaOut model surpasses all visual Mamba models on ImageNet image classification, indicating that Mamba is indeed unnecessary for this task. As for detection and segmentation, MambaOut cannot match the performance of state-of-the-art visual Mamba models, demonstrating the potential of Mamba for long-sequence visual tasks."}
-    review = research_agent.review_paper(paper=submission)
-    review_decision = research_agent.make_review_decision(
-        submission=submission, review={"Jiaxuan You": review})
-    rebut_review = research_agent.rebut_review(submission=submission, review={
-        "Jiaxuan You": review}, decision={"Jiaxuan You": review_decision})
-    assert isinstance(rebut_review, str)
+
+@patch("research_town.utils.agent_prompter.model_prompting")
+def test_make_review_decision(mock_model_prompting: MagicMock) -> None:
+    mock_model_prompting.return_value = [
+        "Accept. This is a good paper."]
+
+    research_agent = BaseResearchAgent(agent_profile=agent_profile_A)
+    review = research_agent.review_paper(paper=paper_profile_A)
+    decision = research_agent.make_review_decision(
+        paper=paper_profile_A, review=[review])
+    assert decision.decision is True
+    assert decision.meta_review == "Accept. This is a good paper."
+    assert decision.timestep >= 0
+    assert decision.pk is not None
+
+
+@patch("research_town.utils.agent_prompter.model_prompting")
+def test_review_paper(mock_model_prompting: MagicMock) -> None:
+    mock_model_prompting.side_effect = mock_prompting
+
+    research_agent = BaseResearchAgent(agent_profile=agent_profile_A)
+    review = research_agent.review_paper(paper=paper_profile_A)
+    assert review.review_score == 2
+    assert review.review_content == "This is a paper review for MambaOut."
+
+
+@patch("research_town.utils.agent_prompter.model_prompting")
+def test_rebut_review(mock_model_prompting: MagicMock) -> None:
+    mock_model_prompting.return_value = [
+        "This is a paper rebuttal."]
+
+    research_agent = BaseResearchAgent(agent_profile=agent_profile_A)
+    review = research_agent.review_paper(paper=paper_profile_A)
+    decision = research_agent.make_review_decision(
+        paper=paper_profile_A, review=[review])
+    rebuttal = research_agent.rebut_review(
+        paper=paper_profile_A, review=[review], decision=[decision])
+    assert isinstance(rebuttal, AgentPaperRebuttalLog)
+    if rebuttal.rebuttal_content is not None:
+        assert len(rebuttal.rebuttal_content) > 0
+    assert rebuttal.rebuttal_content == "This is a paper rebuttal."
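Taken together, the reworked agent methods pass typed log objects through the whole review pipeline. A sketch of the flow the tests above exercise (profile and paper values are illustrative; this triggers real litellm calls unless model_prompting is mocked as in the tests):

    from research_town.agents.agent_base import BaseResearchAgent
    from research_town.dbs import AgentProfile, PaperProfile

    agent = BaseResearchAgent(agent_profile=AgentProfile(name="Reviewer", bio="A researcher."))
    paper = PaperProfile(title="A Survey on Machine Learning",
                         abstract="This paper surveys the field of machine learning.")
    review = agent.review_paper(paper=paper)                         # AgentPaperReviewLog
    meta = agent.make_review_decision(paper=paper, review=[review])  # AgentPaperMetaReviewLog
    rebuttal = agent.rebut_review(paper=paper, review=[review], decision=[meta])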
diff --git a/tests/test_db_base.py b/tests/test_db_base.py
new file mode 100644
index 00000000..3df75a61
--- /dev/null
+++ b/tests/test_db_base.py
@@ -0,0 +1,4 @@
+
+
+def test_get_data() -> None:
+    pass
diff --git a/tests/test_envs.py b/tests/test_envs.py
index 96b1c981..94a12cd9 100644
--- a/tests/test_envs.py
+++ b/tests/test_envs.py
@@ -3,8 +3,10 @@
 from research_town.envs.env_paper_rebuttal import (
     PaperRebuttalMultiAgentEnv,
 )
-from research_town.envs.env_paper_submission import (
-    PaperSubmissionMultiAgentEnvironment,
+from tests.constants import (
+    agent_profile_A,
+    agent_profile_B,
+    paper_profile_A,
 )


@@ -12,15 +14,25 @@ def test_paper_rebuttal_env(mock_model_prompting: MagicMock) -> None:
     mock_model_prompting.return_value = [
         "Paper Rebuttal Environment."]
-    env = PaperRebuttalMultiAgentEnv(agent_dict={"Jiaxuan You": "Jiaxuan You", "Rex Ying": "Rex Ying", "Jure Leskovec": "Jure Leskovec", "Christos Faloutsos": "Christos Faloutsos"})
-    env.assign_roles({"Jiaxuan You": "author", "Rex Ying": "author",
-                      "Jure Leskovec": "reviewer", "Christos Faloutsos": "reviewer"})
-    env.initialize_submission({"MambaOut: Do We Really Need Mamba for Vision?": "Mamba, an architecture with RNN-like token mixer of state space model (SSM), was recently introduced to address the quadratic complexity of the attention mechanism and subsequently applied to vision tasks. Nevertheless, the performance of Mamba for vision is often underwhelming when compared with convolutional and attention-based models. In this paper, we delve into the essence of Mamba, and conceptually conclude that Mamba is ideally suited for tasks with long-sequence and autoregressive characteristics. For vision tasks, as image classification does not align with either characteristic, we hypothesize that Mamba is not necessary for this task; Detection and segmentation tasks are also not autoregressive, yet they adhere to the long-sequence characteristic, so we believe it is still worthwhile to explore Mamba's potential for these tasks. To empirically verify our hypotheses, we construct a series of models named \\emph{MambaOut} through stacking Mamba blocks while removing their core token mixer, SSM. Experimental results strongly support our hypotheses. Specifically, our MambaOut model surpasses all visual Mamba models on ImageNet image classification, indicating that Mamba is indeed unnecessary for this task.
As for detection and segmentation, MambaOut cannot match the performance of state-of-the-art visual Mamba models, demonstrating the potential of Mamba for long-sequence visual tasks."})
+    env = PaperRebuttalMultiAgentEnv(
+        agent_profiles=[agent_profile_A, agent_profile_B]
+    )
+
+    submission = paper_profile_A
+    env.initialize_submission(submission)
+    env.assign_roles({agent_profile_A.pk: "author",
+                      agent_profile_B.pk: "reviewer"})
+
     while not env.terminated:
         env.step()
-    assert isinstance(env.review, str)
+
+    assert isinstance(env.review, list)
+    assert len(env.review) > 0
     assert isinstance(env.decision, str)
-    assert isinstance(env.rebuttal, str)
+    assert env.decision in ["accept", "reject", "borderline"]
+    assert isinstance(env.rebuttal, list)
+    assert len(env.rebuttal) > 0
+

 @patch("research_town.utils.agent_prompter.model_prompting")
 def test_paper_submission_env(mock_model_prompting: MagicMock) -> None:
diff --git a/tests/test_kb_base.py b/tests/test_kb_base.py
deleted file mode 100644
index 9aedcf2f..00000000
--- a/tests/test_kb_base.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from research_town.kbs.kb_base import BaseKnowledgeBase
-
-
-def test_get_data() -> None:
-    kb = BaseKnowledgeBase()
-    data = kb.get_data(10, "Machine Learning")
-    assert data is not None
-    assert len(data) <= 10 and len(data) > 0
diff --git a/tests/utils.py b/tests/utils.py
index ebdec583..7c16f280 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -14,4 +14,8 @@ def mock_prompting(
         return ["This is a paper review for MambaOut."]
     elif "Please provide a score for the following reviews." in prompt:
         return ["2"]
+    elif "Please give me 3 to 5 novel ideas and insights" in prompt:
+        return ["This is a research idea."]
+    elif "summarize the keywords" in prompt:
+        return ["Graph Neural Network"]
     return ["Default response"]
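End to end, the rebuttal environment now runs on profiles and pk-keyed roles. A minimal driver sketch mirroring test_paper_rebuttal_env (all values illustrative; mock model_prompting as in the tests to avoid live LLM calls):

    from research_town.dbs import AgentProfile, PaperProfile
    from research_town.envs.env_paper_rebuttal import PaperRebuttalMultiAgentEnv

    author = AgentProfile(name="Author A", bio="A researcher.")
    reviewer = AgentProfile(name="Reviewer B", bio="A researcher.")
    env = PaperRebuttalMultiAgentEnv(agent_profiles=[author, reviewer])
    env.initialize_submission(PaperProfile(title="A Survey on Machine Learning",
                                           abstract="This paper surveys the field of machine learning."))
    env.assign_roles({author.pk: "author", reviewer.pk: "reviewer"})
    while not env.terminated:
        env.step()
    print(env.decision, len(env.review), len(env.rebuttal))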