Skip to content

Commit

Permalink
rebuild data type (#70)
Browse files Browse the repository at this point in the history
* rebuild data type

* add research paper class

* change dict to list class

* prompter should stay the same, with no class types in its inputs or outputs

* rebuild data type in rebuttal-related functions

* update agent list

* rebuild data type for test functions for rebuttal-related functions

* add constants for testing

* add a rebuttal instance

* update get profile

* rebuild data type for submission-related functions

* fix test errors

* rebuild data type for test functions for submission-related functions

* fix test errors

* rebuild the database class

* fix some bugs

* add some optional param

* update db class content

* fix id to pk

* fix pre-commit bug

* fix pre-commit bug

* fix some mypy error

* add a TODO comment

* fix part of mypy errors

* fix part of mypy errors

* fix mypy.ini format

* recover test functions

* fix pre-commit

* add default

---------

Co-authored-by: chengzr01 <[email protected]>
  • Loading branch information
lwaekfjlk and chengzr01 authored May 25, 2024
1 parent d77c821 commit 2ed57ce
Show file tree
Hide file tree
Showing 20 changed files with 798 additions and 278 deletions.
11 changes: 11 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[mypy]
plugins = pydantic.mypy

[mypy-arxiv.*]
ignore_missing_imports = True

[mypy-faiss.*]
ignore_missing_imports = True

[mypy-transformers.*]
ignore_missing_imports = True
247 changes: 196 additions & 51 deletions research_town/agents/agent_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
from datetime import datetime
from typing import Any, Dict, List, Tuple

from ..dbs import (
AgentAgentDiscussionLog,
AgentPaperMetaReviewLog,
AgentPaperRebuttalLog,
AgentPaperReviewLog,
AgentProfile,
PaperProfile,
)
from ..utils.agent_collector import bfs
from ..utils.agent_prompter import (
communicate_with_multiple_researchers_prompting,
find_collaborators_prompting,
Expand All @@ -8,95 +18,230 @@
rebut_review_prompting,
review_paper_prompting,
review_score_prompting,
summarize_research_direction_prompting,
summarize_research_field_prompting,
write_paper_abstract_prompting,
)
from ..utils.author_collector import bfs
from ..utils.paper_collector import get_paper_list


class BaseResearchAgent(object):
    """Research agent built around an ``AgentProfile``.

    Each public method converts the typed profile/log objects into the plain
    dictionaries that the prompting helpers take, calls the helper, and wraps
    the answer back into a typed object (or a plain string / list of strings
    where the method is annotated that way).
    """

    def __init__(self,
                 agent_profile: AgentProfile,
                 ) -> None:
        """Store the agent's profile and start with an empty memory.

        Args:
            agent_profile: Profile describing this agent.
        """
        self.profile: AgentProfile = agent_profile
        self.memory: Dict[str, str] = {}

    # ------------------------------------------------------------------
    # Private serialisation helpers (the prompting helpers take dicts).
    # ------------------------------------------------------------------

    @staticmethod
    def _current_timestep() -> int:
        """Return the current wall-clock time as whole seconds since the epoch."""
        return int(datetime.now().timestamp())

    @staticmethod
    def _serialize_papers(
        papers: List[PaperProfile]
    ) -> Dict[str, Dict[str, List[str]]]:
        """Map each paper pk to its available ``abstract`` / ``title`` lists."""
        papers_dict: Dict[str, Dict[str, List[str]]] = {}
        for paper in papers:
            entry: Dict[str, List[str]] = {}
            if paper.abstract is not None:
                entry["abstract"] = [paper.abstract]
            if paper.title is not None:
                entry["title"] = [paper.title]
            papers_dict[paper.pk] = entry
        return papers_dict

    @staticmethod
    def _serialize_paper(paper: PaperProfile) -> Dict[str, str]:
        """Return ``{title: abstract}``, or ``{}`` when either field is missing."""
        if paper.title is not None and paper.abstract is not None:
            return {paper.title: paper.abstract}
        return {}

    @staticmethod
    def _serialize_reviews(
        reviews: List[AgentPaperReviewLog]
    ) -> Dict[str, Tuple[int, str]]:
        """Map review pk to ``(score, content)``, skipping incomplete reviews."""
        review_dict: Dict[str, Tuple[int, str]] = {}
        for review_log in reviews:
            if review_log.review_score is not None and review_log.review_content is not None:
                review_dict[review_log.pk] = (
                    review_log.review_score, review_log.review_content)
        return review_dict

    def _serialize_profile(self) -> Dict[str, str]:
        """Return this agent's ``name`` / ``profile`` entries that are set."""
        profile: Dict[str, str] = {}
        if self.profile.name is not None:
            profile["name"] = self.profile.name
        if self.profile.bio is not None:
            profile["profile"] = self.profile.bio
        return profile

    # ------------------------------------------------------------------
    # Public API.
    # ------------------------------------------------------------------

    def get_profile(self, author_name: str) -> AgentProfile:
        """Return a profile for ``author_name``.

        NOTE: this is still a placeholder. ``author_name`` is ignored and the
        same hard-coded profile is returned for every call.
        """
        # TODO: db get based on name
        agent_profile = AgentProfile(
            name='Geoffrey Hinton',
            bio="A researcher in the field of neural network.",
        )
        return agent_profile

    def communicate(
        self,
        message: AgentAgentDiscussionLog
    ) -> AgentAgentDiscussionLog:
        """Generate a reply to ``message`` and return it as a new discussion log.

        The returned log keeps the sender/receiver pks of the incoming
        message and carries the generated text as its ``message``.
        """
        # TODO: find a meaningful key
        message_dict: Dict[str, str] = {}
        if message.message is not None:
            message_dict[message.agent_from_pk] = message.message
        message_content = communicate_with_multiple_researchers_prompting(
            message_dict
        )[0]
        return AgentAgentDiscussionLog(
            timestep=self._current_timestep(),
            agent_from_pk=message.agent_from_pk,
            agent_to_pk=message.agent_to_pk,
            message=message_content
        )

    def read_paper(
        self,
        papers: List[PaperProfile],
        domain: str
    ) -> str:
        """Summarise the research trend of ``papers`` within ``domain``.

        Returns:
            The first element returned by the field-summarisation prompt.
        """
        trend = summarize_research_field_prompting(
            profile=self._serialize_profile(),
            keywords=[domain],
            papers=self._serialize_papers(papers)
        )
        return trend[0]

    def find_collaborators(
        self,
        paper: PaperProfile,
        parameter: float = 0.5,
        max_number: int = 3
    ) -> List[AgentProfile]:
        """Pick collaborators for ``paper`` among this agent's co-authors.

        Candidate names come from ``bfs`` seeded with this agent's name
        (``max_number`` is passed as its node limit); a candidate is kept
        when its name occurs in the prompting helper's answer.

        Returns:
            Profiles (via ``get_profile``) of the selected collaborators.
        """
        own_name = self.profile.name
        start_author: List[str] = [own_name] if own_name is not None else []
        graph, _, _ = bfs(
            author_list=start_author, node_limit=max_number)
        # Every distinct name appearing in a co-author pair, except ourselves.
        collaborators = list(
            {name for pair in graph for name in pair if name != own_name})

        self_profile: Dict[str, str] = {}
        if own_name is not None and self.profile.bio is not None:
            self_profile[own_name] = self.profile.bio

        collaborator_profiles: Dict[str, str] = {}
        for author in collaborators:
            author_bio = self.get_profile(author).bio
            if author_bio is not None:
                collaborator_profiles[author] = author_bio

        result = find_collaborators_prompting(
            self._serialize_paper(paper),
            self_profile,
            collaborator_profiles,
            parameter,
            max_number
        )
        # Names are searched inside the first response text, like every other
        # prompting call in this class. Testing membership in the whole return
        # value would only match when a response equals the bare name.
        # NOTE(review): assumes the helper returns a list of strings - confirm.
        response_text = result[0] if result else ""
        return [
            self.get_profile(collaborator)
            for collaborator in collaborators
            if collaborator in response_text
        ]

    def get_co_author_relationships(
        self,
        agent_profile: AgentProfile,
        max_node: int
    ) -> Tuple[List[Tuple[str, str]], Dict[str, List[Dict[str, Any]]], Dict[str, List[Dict[str, Any]]]]:
        """Collect the co-author graph around ``agent_profile``.

        Args:
            agent_profile: Profile whose name seeds the search.
            max_node: Passed to ``bfs`` as ``node_limit``.

        Returns:
            The ``(graph, node_feat, edge_feat)`` triple produced by ``bfs``.
        """
        # Seed with the profile that was passed in, not with self.profile,
        # so the parameter is actually honoured.
        start_author: List[str] = [
            agent_profile.name] if agent_profile.name is not None else []
        graph, node_feat, edge_feat = bfs(
            author_list=start_author, node_limit=max_node)
        return graph, node_feat, edge_feat

    def generate_idea(
        self,
        papers: List[PaperProfile],
        domain: str
    ) -> List[str]:
        """Generate one research idea per trend found in ``papers``.

        The trends come from the field-summarisation prompt; for each trend
        the first element returned by the idea-generation prompt is kept.
        """
        trends = summarize_research_field_prompting(
            profile=self._serialize_profile(),
            keywords=[domain],
            papers=self._serialize_papers(papers)
        )
        return [generate_ideas_prompting(trend)[0] for trend in trends]

    def write_paper(
        self,
        research_ideas: List[str],
        papers: List[PaperProfile]
    ) -> PaperProfile:
        """Write an abstract from ``research_ideas`` and the reference ``papers``.

        Returns:
            A new ``PaperProfile`` whose ``abstract`` is the generated text.
        """
        paper_abstract = write_paper_abstract_prompting(
            research_ideas, self._serialize_papers(papers))[0]
        return PaperProfile(abstract=paper_abstract)

    def review_paper(
        self,
        paper: PaperProfile
    ) -> AgentPaperReviewLog:
        """Review ``paper`` and return the review text and score as a log."""
        paper_review = review_paper_prompting(self._serialize_paper(paper))[0]
        review_score = review_score_prompting(paper_review)

        return AgentPaperReviewLog(
            timestep=self._current_timestep(),
            paper_pk=paper.pk,
            agent_pk=self.profile.pk,
            review_content=paper_review,
            review_score=review_score
        )

    def make_review_decision(
        self,
        paper: PaperProfile,
        review: List[AgentPaperReviewLog]
    ) -> AgentPaperMetaReviewLog:
        """Produce a meta-review and accept/reject decision for ``paper``.

        Reviews missing a score or content are not forwarded to the prompt.
        """
        meta_review = make_review_decision_prompting(
            self._serialize_paper(paper), self._serialize_reviews(review))
        # Plain substring test: any occurrence of "accept" counts as acceptance.
        review_decision = "accept" in meta_review[0].lower()

        return AgentPaperMetaReviewLog(
            timestep=self._current_timestep(),
            paper_pk=paper.pk,
            agent_pk=self.profile.pk,
            decision=review_decision,
            meta_review=meta_review[0],
        )

    def rebut_review(
        self,
        paper: PaperProfile,
        review: List[AgentPaperReviewLog],
        decision: List[AgentPaperMetaReviewLog]
    ) -> AgentPaperRebuttalLog:
        """Write a rebuttal to the reviews and meta-reviews of ``paper``.

        Reviews and decisions with missing fields are not forwarded to the
        prompt.
        """
        decision_dict: Dict[str, Tuple[bool, str]] = {}
        for meta_review_log in decision:
            if meta_review_log.decision is not None and meta_review_log.meta_review is not None:
                decision_dict[meta_review_log.pk] = (
                    meta_review_log.decision, meta_review_log.meta_review)

        rebuttal = rebut_review_prompting(
            self._serialize_paper(paper),
            self._serialize_reviews(review),
            decision_dict)[0]

        return AgentPaperRebuttalLog(
            timestep=self._current_timestep(),
            paper_pk=paper.pk,
            agent_pk=self.profile.pk,
            rebuttal_content=rebuttal
        )
29 changes: 29 additions & 0 deletions research_town/dbs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from .agent_db import AgentProfile, AgentProfileDB
from .env_db import (
AgentAgentDiscussionLog,
AgentPaperMetaReviewLog,
AgentPaperRebuttalLog,
AgentPaperReviewLog,
EnvLogDB,
)
from .paper_db import PaperProfile, PaperProfileDB
from .progress_db import (
ResearchIdea,
ResearchPaperDraft,
ResearchProgressDB,
)

# Public names re-exported by the package, grouped by the submodule that
# defines them.
__all__ = [
    # agent_db
    "AgentProfile",
    "AgentProfileDB",
    # env_db
    "AgentAgentDiscussionLog",
    "AgentPaperMetaReviewLog",
    "AgentPaperRebuttalLog",
    "AgentPaperReviewLog",
    "EnvLogDB",
    # paper_db
    "PaperProfile",
    "PaperProfileDB",
    # progress_db
    "ResearchIdea",
    "ResearchPaperDraft",
    "ResearchProgressDB",
]
57 changes: 57 additions & 0 deletions research_town/dbs/agent_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import json
import uuid
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field


class AgentProfile(BaseModel):
    """Profile record of a research agent."""

    # Primary key; a fresh UUID4 string is generated when none is supplied.
    pk: str = Field(default_factory=lambda: str(uuid.uuid4()))
    # Agent's name; optional, defaults to None.
    name: Optional[str] = None
    # Free-text biography; optional, defaults to None.
    bio: Optional[str] = None


class AgentProfileDB(object):
    """In-memory collection of ``AgentProfile`` records keyed by their ``pk``."""

    def __init__(self) -> None:
        """Create an empty database."""
        self.data: Dict[str, AgentProfile] = {}

    def add(self, agent: AgentProfile) -> None:
        """Insert ``agent``, replacing any record stored under the same pk."""
        self.data[agent.pk] = agent

    def update(self, agent_pk: str, updates: Dict[str, Optional[str]]) -> bool:
        """Apply ``updates`` to the record stored under ``agent_pk``.

        Entries whose value is ``None`` are skipped, so a field cannot be
        reset to ``None`` through this method.

        Returns:
            ``True`` if the record exists (and was updated), else ``False``.
        """
        if agent_pk not in self.data:
            return False
        agent = self.data[agent_pk]
        for key, value in updates.items():
            if value is not None:
                setattr(agent, key, value)
        return True

    def delete(self, agent_pk: str) -> bool:
        """Remove the record stored under ``agent_pk``.

        Returns:
            ``True`` if a record was removed, ``False`` if the pk is unknown.
        """
        if agent_pk not in self.data:
            return False
        del self.data[agent_pk]
        return True

    def get(self, **conditions: Any) -> List[AgentProfile]:
        """Return every record whose attributes equal all given ``conditions``.

        Each keyword is an attribute name and its value the required value.
        With no conditions, all records are returned in insertion order.

        Raises:
            AttributeError: If a keyword is not an attribute of a record.
        """
        return [
            agent
            for agent in self.data.values()
            if all(getattr(agent, key) == value
                   for key, value in conditions.items())
        ]

    def save_to_file(self, file_name: str) -> None:
        """Write all records to ``file_name`` as a JSON object keyed by pk."""
        with open(file_name, "w", encoding="utf-8") as f:
            json.dump({aid: agent.dict()
                       for aid, agent in self.data.items()}, f, indent=2)

    def load_from_file(self, file_name: str) -> None:
        """Replace the current contents with the records stored in ``file_name``."""
        with open(file_name, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.data = {aid: AgentProfile(**agent_data)
                     for aid, agent_data in data.items()}

    def update_db(self, data: Dict[str, List[Dict[str, Any]]]) -> None:
        """Add every agent described in ``data``.

        Args:
            data: Mapping whose values are lists of ``AgentProfile`` keyword
                dictionaries; the keys are not used.
        """
        for agents in data.values():
            for agent_data in agents:
                self.add(AgentProfile(**agent_data))
Loading

0 comments on commit 2ed57ce

Please sign in to comment.