From c5b862773370a01a0e468af4ce1672e46a6a42c1 Mon Sep 17 00:00:00 2001 From: Haofei Yu <1125027232@qq.com> Date: Thu, 30 May 2024 04:33:10 -0400 Subject: [PATCH] support and run ruff formatter (#148) * fix ruff format * fix ruff format * fix ruff format * change pyproject * fix ruff format * fix codespell error --- .pre-commit-config.yaml | 2 + data/dbs/test_agent_profile_db.json | 16 ++ data/dbs/test_env_logs_db.json | 43 ++++ data/dbs/test_paper_profile_db.json | 46 ++++ data/dbs/test_research_progress_db.json | 13 ++ examples/minimal_demo.py | 35 +-- pyproject.toml | 17 +- research_town/agents/agent_base.py | 124 ++++------ research_town/dbs/__init__.py | 26 +-- research_town/dbs/agent_db.py | 14 +- research_town/dbs/env_db.py | 41 ++-- research_town/dbs/paper_db.py | 19 +- research_town/dbs/progress_db.py | 40 ++-- research_town/envs/__init__.py | 5 +- research_town/envs/env_base.py | 10 +- research_town/envs/env_paper_rebuttal.py | 35 +-- research_town/envs/env_paper_submission.py | 40 ++-- research_town/evaluators/output_format.py | 21 +- research_town/evaluators/quality_evaluator.py | 67 +++--- research_town/utils/agent_collector.py | 33 +-- research_town/utils/agent_prompter.py | 159 +++++++------ research_town/utils/decorator.py | 32 +-- research_town/utils/eval_prompter.py | 178 +++++++------- research_town/utils/logging.py | 6 +- research_town/utils/model_prompting.py | 2 +- research_town/utils/paper_collector.py | 116 +++++----- research_town/utils/serializer.py | 23 +- research_town/utils/string_mapper.py | 8 +- research_town/utils/tools.py | 22 +- tests/constants.py | 41 ++-- tests/test_agents.py | 88 +++---- tests/test_dbs.py | 219 ++++++++++-------- tests/test_envs.py | 32 +-- tests/test_eval.py | 101 +++++--- tests/test_model_call.py | 15 +- tests/test_serialize.py | 2 +- tests/utils.py | 26 ++- 37 files changed, 966 insertions(+), 751 deletions(-) create mode 100644 data/dbs/test_agent_profile_db.json create mode 100644 data/dbs/test_env_logs_db.json create mode 100644 data/dbs/test_paper_profile_db.json create mode 100644 data/dbs/test_research_progress_db.json diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cd1d13d0..ee9a812c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,6 +17,8 @@ repos: - id: ruff types_or: [python, pyi, jupyter] args: [--fix] + - id: ruff-format + types_or: [python, pyi, jupyter] - repo: https://github.com/pre-commit/mirrors-isort rev: v5.10.1 # Use the latest isort version diff --git a/data/dbs/test_agent_profile_db.json b/data/dbs/test_agent_profile_db.json new file mode 100644 index 00000000..973dade6 --- /dev/null +++ b/data/dbs/test_agent_profile_db.json @@ -0,0 +1,16 @@ +{ + "d544f290-6748-46b5-a82e-fd8f40c1e4cc": { + "pk": "d544f290-6748-46b5-a82e-fd8f40c1e4cc", + "name": "Jane Smith", + "bio": "Expert in NLP", + "collaborators": [], + "institute": "NLP Lab" + }, + "9c581b74-86f6-4577-b400-9221df4c3917": { + "pk": "9c581b74-86f6-4577-b400-9221df4c3917", + "name": "Alice Johnson", + "bio": "Data Scientist", + "collaborators": [], + "institute": "Data Lab" + } +} diff --git a/data/dbs/test_env_logs_db.json b/data/dbs/test_env_logs_db.json new file mode 100644 index 00000000..737ba49b --- /dev/null +++ b/data/dbs/test_env_logs_db.json @@ -0,0 +1,43 @@ +{ + "PaperProfile": [], + "AgentPaperReviewLog": [ + { + "pk": "654935ea-be94-4898-80a4-bb5c7c12f286", + "timestep": 0, + "paper_pk": "paper2", + "agent_pk": "agent2", + "review_score": 4, + "review_content": "Interesting paper" + } + ], + "AgentPaperRebuttalLog": [ + { + "pk": "5387eadb-6a18-44e1-b7a3-55c49c808efd", + "timestep": 0, + "paper_pk": "paper1", + "agent_pk": "agent1", + "rebuttal_content": "I disagree with the review" + } + ], + "AgentPaperMetaReviewLog": [ + { + "pk": "f3bffbbc-c67c-40a5-82f1-200989b2bea9", + "timestep": 0, + "paper_pk": "paper1", + "agent_pk": "agent1", + "decision": true, + "meta_review": "Accept" + } + ], + "AgentAgentDiscussionLog": [ + { + "pk": "67a25e19-2182-4671-9005-a3f95dd3f7c0", + "timestep": 0, + "agent_from_pk": "agent1", + "agent_from_name": "Rex Ying", + "agent_to_pk": "agent2", + "agent_to_name": "John Doe", + "message": "Let's discuss this paper" + } + ] +} diff --git a/data/dbs/test_paper_profile_db.json b/data/dbs/test_paper_profile_db.json new file mode 100644 index 00000000..2125a5ed --- /dev/null +++ b/data/dbs/test_paper_profile_db.json @@ -0,0 +1,46 @@ +{ + "43653097-1230-48e5-ba17-6f616bc93380": { + "pk": "43653097-1230-48e5-ba17-6f616bc93380", + "title": "Updated Sample Paper 1", + "abstract": "This is the abstract for paper 1", + "authors": [ + "Author A", + "Author B" + ], + "url": "http://example.com/paper1", + "timestamp": 1617181723, + "section_contents": null, + "table_captions": null, + "figure_captions": null, + "bibliography": null, + "keywords": [ + "AI", + "ML" + ], + "domain": "Computer Science", + "references": null, + "citation_count": 15, + "award": null + }, + "37e9c697-bd7b-40da-975f-579eddc9508e": { + "pk": "37e9c697-bd7b-40da-975f-579eddc9508e", + "title": "Sample Paper 3", + "abstract": "This is the abstract for paper 3", + "authors": [ + "Author D" + ], + "url": "http://example.com/paper3", + "timestamp": 1617181789, + "section_contents": null, + "table_captions": null, + "figure_captions": null, + "bibliography": null, + "keywords": [ + "Blockchain" + ], + "domain": "Computer Science", + "references": null, + "citation_count": 2, + "award": null + } +} diff --git a/data/dbs/test_research_progress_db.json b/data/dbs/test_research_progress_db.json new file mode 100644 index 00000000..38e4622f --- /dev/null +++ b/data/dbs/test_research_progress_db.json @@ -0,0 +1,13 @@ +{ + "ResearchIdea": [ + { + "pk": "585e0e17-ae53-44a1-a682-e4ee2883655c", + "content": "Blockchain research proposal" + }, + { + "pk": "baf40f3b-f14b-48a0-bc1c-d84eaefa9e58", + "content": "Updated idea content" + } + ], + "ResearchPaper": [] +} diff --git a/examples/minimal_demo.py b/examples/minimal_demo.py index 4f899b46..30be065d 100644 --- a/examples/minimal_demo.py +++ b/examples/minimal_demo.py @@ -1,34 +1,36 @@ from beartype.typing import Dict, List -from research_town.dbs import ( - AgentProfile, - AgentProfileDB, - EnvLogDB, - PaperProfileDB, -) +from research_town.dbs import AgentProfile, AgentProfileDB, EnvLogDB, PaperProfileDB from research_town.envs import ( PaperRebuttalMultiAgentEnv, PaperSubmissionMultiAgentEnvironment, ) -def run_sync_experiment(agent_list: List[str], role_list: List[str], task: Dict[str, str]) -> None: +def run_sync_experiment( + agent_list: List[str], role_list: List[str], task: Dict[str, str] +) -> None: # Create Environment and Agents - agent_profiles = [AgentProfile( - name=agent, bio="A researcher in machine learning.") for agent in agent_list] + agent_profiles = [ + AgentProfile(name=agent, bio='A researcher in machine learning.') + for agent in agent_list + ] agent_db = AgentProfileDB() paper_db = PaperProfileDB() env_db = EnvLogDB() paper_submission_env = PaperSubmissionMultiAgentEnvironment( - agent_profiles=agent_profiles, task=task, + agent_profiles=agent_profiles, + task=task, agent_db=agent_db, paper_db=paper_db, - env_db=env_db) + env_db=env_db, + ) paper_rebuttal_env = PaperRebuttalMultiAgentEnv( agent_profiles=agent_profiles, agent_db=agent_db, paper_db=paper_db, - env_db=env_db) + env_db=env_db, + ) # Paper Submission submission_done = False @@ -51,10 +53,11 @@ def run_sync_experiment(agent_list: List[str], role_list: List[str], task: Dict[ def main() -> None: run_sync_experiment( - agent_list=["Jiaxuan You", "Jure Leskovec"], - role_list=["author", "reviewer"], - task={}) + agent_list=['Jiaxuan You', 'Jure Leskovec'], + role_list=['author', 'reviewer'], + task={}, + ) -if __name__ == "__main__": +if __name__ == '__main__': main() diff --git a/pyproject.toml b/pyproject.toml index 597482de..bf480e96 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,22 @@ skip_gitignore = true multi_line_output = 3 include_trailing_comma = true force_grid_wrap = 0 -line_length = 70 +line_length = 88 + +[tool.black] +line-length = 88 +target-version = ['py37', 'py38', 'py39', 'py310'] + +[tool.ruff] +line-length = 88 +fix = true +target-version = "py310" + +[tool.ruff.format] +quote-style = "single" +indent-style = "space" +docstring-code-format = true +docstring-code-line-length = 88 [tool.mypy-arxiv] ignore_missing_imports = true diff --git a/research_town/agents/agent_base.py b/research_town/agents/agent_base.py index 0b273aa2..bc205222 100644 --- a/research_town/agents/agent_base.py +++ b/research_town/agents/agent_base.py @@ -30,10 +30,7 @@ class BaseResearchAgent(object): - def __init__(self, - agent_profile: AgentProfile, - model_name: str - ) -> None: + def __init__(self, agent_profile: AgentProfile, model_name: str) -> None: self.profile: AgentProfile = agent_profile self.memory: Dict[str, str] = {} self.model_name: str = model_name @@ -45,40 +42,39 @@ def get_profile(self, author_name: str) -> AgentProfile: # TODO: need rebuild agent_profile = AgentProfile( name='Geoffrey Hinton', - bio="A researcher in the field of neural network.", + bio='A researcher in the field of neural network.', ) return agent_profile - @beartype def find_collaborators( - self, - paper: PaperProfile, - parameter: float = 0.5, - max_number: int = 3 + self, paper: PaperProfile, parameter: float = 0.5, max_number: int = 3 ) -> List[AgentProfile]: # TODO: need rebuild - start_author: List[str] = [ - self.profile.name] if self.profile.name is not None else [] - graph, _, _ = bfs( - author_list=start_author, node_limit=max_number) + start_author: List[str] = ( + [self.profile.name] if self.profile.name is not None else [] + ) + graph, _, _ = bfs(author_list=start_author, node_limit=max_number) collaborators = list( - {name for pair in graph for name in pair if name != self.profile.name}) - self_profile: Dict[str, str] = { - self.profile.name: self.profile.bio} if self.profile.name is not None and self.profile.bio is not None else {} + {name for pair in graph for name in pair if name != self.profile.name} + ) + self_profile: Dict[str, str] = ( + {self.profile.name: self.profile.bio} + if self.profile.name is not None and self.profile.bio is not None + else {} + ) collaborator_profiles: Dict[str, str] = {} for author in collaborators: author_bio = self.get_profile(author).bio if author_bio is not None: collaborator_profiles[author] = author_bio - paper_serialize: Dict[str, str] = { - paper.title: paper.abstract} if paper.title is not None and paper.abstract is not None else {} + paper_serialize: Dict[str, str] = ( + {paper.title: paper.abstract} + if paper.title is not None and paper.abstract is not None + else {} + ) result = find_collaborators_prompting( - paper_serialize, - self_profile, - collaborator_profiles, - parameter, - max_number + paper_serialize, self_profile, collaborator_profiles, parameter, max_number ) collaborators_list = [] for collaborator in collaborators: @@ -88,24 +84,24 @@ def find_collaborators( @beartype def get_co_author_relationships( - self, - agent_profile: AgentProfile, - max_node: int - ) -> Tuple[List[Tuple[str, str]], Dict[str, List[Dict[str, Any]]], Dict[str, List[Dict[str, Any]]]]: + self, agent_profile: AgentProfile, max_node: int + ) -> Tuple[ + List[Tuple[str, str]], + Dict[str, List[Dict[str, Any]]], + Dict[str, List[Dict[str, Any]]], + ]: # TODO: need rebuild - start_author: List[str] = [ - self.profile.name] if self.profile.name is not None else [] - graph, node_feat, edge_feat = bfs( - author_list=start_author, node_limit=max_node) + start_author: List[str] = ( + [self.profile.name] if self.profile.name is not None else [] + ) + graph, node_feat, edge_feat = bfs(author_list=start_author, node_limit=max_node) return graph, node_feat, edge_feat -# ======================================= + # ======================================= @beartype def read_paper( - self, - papers: List[PaperProfile], - domains: List[str] + self, papers: List[PaperProfile], domains: List[str] ) -> List[ResearchInsight]: serialized_papers = self.serializer.serialize(papers) serialized_profile = self.serializer.serialize(self.profile) @@ -113,14 +109,13 @@ def read_paper( profile=serialized_profile, papers=serialized_papers, domains=domains, - model_name=self.model_name + model_name=self.model_name, ) insights: List[ResearchInsight] = [] for content in insight_contents: insights.append(ResearchInsight(content=content)) return insights - @beartype def think_idea( self, @@ -129,58 +124,45 @@ def think_idea( serialized_insights = self.serializer.serialize(insights) idea_contents: List[str] = [] for insight in serialized_insights: - idea_contents.append(think_idea_prompting( - insight=insight, - model_name=self.model_name - )[0]) + idea_contents.append( + think_idea_prompting(insight=insight, model_name=self.model_name)[0] + ) ideas: List[ResearchIdea] = [] for content in idea_contents: ideas.append(ResearchIdea(content=content)) return ideas - @beartype def write_paper( - self, - ideas: List[ResearchIdea], - papers: List[PaperProfile] + self, ideas: List[ResearchIdea], papers: List[PaperProfile] ) -> ResearchPaperSubmission: serialized_ideas = self.serializer.serialize(ideas) serialized_papers = self.serializer.serialize(papers) paper_abstract = write_paper_prompting( - ideas=serialized_ideas, - papers=serialized_papers, - model_name=self.model_name + ideas=serialized_ideas, papers=serialized_papers, model_name=self.model_name )[0] return ResearchPaperSubmission(abstract=paper_abstract) @beartype - def write_paper_review( - self, - paper: PaperProfile - ) -> AgentPaperReviewLog: + def write_paper_review(self, paper: PaperProfile) -> AgentPaperReviewLog: serialized_paper = self.serializer.serialize(paper) paper_review = review_paper_prompting( - paper=serialized_paper, - model_name=self.model_name + paper=serialized_paper, model_name=self.model_name )[0] review_score = review_score_prompting( - paper_review=paper_review, - model_name=self.model_name + paper_review=paper_review, model_name=self.model_name ) return AgentPaperReviewLog( timestep=(int)(datetime.now().timestamp()), paper_pk=paper.pk, agent_pk=self.profile.pk, review_content=paper_review, - review_score=review_score + review_score=review_score, ) @beartype def write_paper_meta_review( - self, - paper: PaperProfile, - reviews: List[AgentPaperReviewLog] + self, paper: PaperProfile, reviews: List[AgentPaperReviewLog] ) -> AgentPaperMetaReviewLog: serialized_paper = self.serializer.serialize(paper) serialized_reviews = self.serializer.serialize(reviews) @@ -188,9 +170,9 @@ def write_paper_meta_review( meta_review = write_meta_review_prompting( paper=serialized_paper, reviews=serialized_reviews, - model_name=self.model_name + model_name=self.model_name, ) - review_decision = "accept" in meta_review[0].lower() + review_decision = 'accept' in meta_review[0].lower() return AgentPaperMetaReviewLog( timestep=(int)(datetime.now().timestamp()), @@ -210,27 +192,21 @@ def write_rebuttal( serialized_review = self.serializer.serialize(review) rebuttal_content = write_rebuttal_prompting( - paper=serialized_paper, - review=serialized_review, - model_name=self.model_name + paper=serialized_paper, review=serialized_review, model_name=self.model_name )[0] return AgentPaperRebuttalLog( timestep=(int)(datetime.now().timestamp()), paper_pk=paper.pk, agent_pk=self.profile.pk, - rebuttal_content=rebuttal_content + rebuttal_content=rebuttal_content, ) @beartype - def discuss( - self, - message: AgentAgentDiscussionLog - ) -> AgentAgentDiscussionLog: + def discuss(self, message: AgentAgentDiscussionLog) -> AgentAgentDiscussionLog: serialized_message = self.serializer.serialize(message) message_content = discuss_prompting( - message=serialized_message, - model_name=self.model_name + message=serialized_message, model_name=self.model_name )[0] return AgentAgentDiscussionLog( timestep=(int)(datetime.now().timestamp()), @@ -238,5 +214,5 @@ def discuss( agent_from_name=message.agent_from_name, agent_to_pk=message.agent_to_pk, agent_to_name=message.agent_to_name, - message=message_content + message=message_content, ) diff --git a/research_town/dbs/__init__.py b/research_town/dbs/__init__.py index 0dc43f9f..6051ec94 100644 --- a/research_town/dbs/__init__.py +++ b/research_town/dbs/__init__.py @@ -15,17 +15,17 @@ ) __all__ = [ - "AgentAgentDiscussionLog", - "AgentPaperMetaReviewLog", - "AgentPaperRebuttalLog", - "AgentPaperReviewLog", - "PaperProfile", - "AgentProfile", - "ResearchIdea", - "ResearchInsight", - "ResearchPaperSubmission", - "EnvLogDB", - "PaperProfileDB", - "AgentProfileDB", - "ResearchProgressDB" + 'AgentAgentDiscussionLog', + 'AgentPaperMetaReviewLog', + 'AgentPaperRebuttalLog', + 'AgentPaperReviewLog', + 'PaperProfile', + 'AgentProfile', + 'ResearchIdea', + 'ResearchInsight', + 'ResearchPaperSubmission', + 'EnvLogDB', + 'PaperProfileDB', + 'AgentProfileDB', + 'ResearchProgressDB', ] diff --git a/research_town/dbs/agent_db.py b/research_town/dbs/agent_db.py index 7c993512..df154acf 100644 --- a/research_town/dbs/agent_db.py +++ b/research_town/dbs/agent_db.py @@ -42,15 +42,17 @@ def get(self, **conditions: Dict[str, Any]) -> List[AgentProfile]: return result def save_to_file(self, file_name: str) -> None: - with open(file_name, "w") as f: - json.dump({aid: agent.dict() - for aid, agent in self.data.items()}, f, indent=2) + with open(file_name, 'w') as f: + json.dump( + {aid: agent.dict() for aid, agent in self.data.items()}, f, indent=2 + ) def load_from_file(self, file_name: str) -> None: - with open(file_name, "r") as f: + with open(file_name, 'r') as f: data = json.load(f) - self.data = {aid: AgentProfile(**agent_data) - for aid, agent_data in data.items()} + self.data = { + aid: AgentProfile(**agent_data) for aid, agent_data in data.items() + } def update_db(self, data: Dict[str, List[Dict[str, Any]]]) -> None: for date, agents in data.items(): diff --git a/research_town/dbs/env_db.py b/research_town/dbs/env_db.py index f6cb27c2..d1f8968f 100644 --- a/research_town/dbs/env_db.py +++ b/research_town/dbs/env_db.py @@ -10,11 +10,11 @@ class EnvLogDB: def __init__(self) -> None: self.data: Dict[str, List[Any]] = { - "PaperProfile": [], - "AgentPaperReviewLog": [], - "AgentPaperRebuttalLog": [], - "AgentPaperMetaReviewLog": [], - "AgentAgentDiscussionLog": [] + 'PaperProfile': [], + 'AgentPaperReviewLog': [], + 'AgentPaperRebuttalLog': [], + 'AgentPaperMetaReviewLog': [], + 'AgentAgentDiscussionLog': [], } def add(self, obj: T) -> None: @@ -22,27 +22,33 @@ def add(self, obj: T) -> None: if class_name in self.data: self.data[class_name].append(obj.dict()) else: - raise ValueError(f"Unsupported log type: {class_name}") + raise ValueError(f'Unsupported log type: {class_name}') def get(self, cls: Type[T], **conditions: Dict[str, Any]) -> List[T]: class_name = cls.__name__ if class_name not in self.data: - raise ValueError(f"Unsupported log type: {class_name}") + raise ValueError(f'Unsupported log type: {class_name}') result = [] for data in self.data[class_name]: instance = cls(**data) - if all(getattr(instance, key) == value for key, value in conditions.items()): + if all( + getattr(instance, key) == value for key, value in conditions.items() + ): result.append(instance) return result - def update(self, cls: Type[T], conditions: Dict[str, Any], updates: Dict[str, Any]) -> int: + def update( + self, cls: Type[T], conditions: Dict[str, Any], updates: Dict[str, Any] + ) -> int: class_name = cls.__name__ if class_name not in self.data: - raise ValueError(f"Unsupported log type: {class_name}") + raise ValueError(f'Unsupported log type: {class_name}') updated_count = 0 for data in self.data[class_name]: instance = cls(**data) - if all(getattr(instance, key) == value for key, value in conditions.items()): + if all( + getattr(instance, key) == value for key, value in conditions.items() + ): for key, value in updates.items(): setattr(instance, key, value) self.data[class_name].remove(data) @@ -53,20 +59,23 @@ def update(self, cls: Type[T], conditions: Dict[str, Any], updates: Dict[str, An def delete(self, cls: Type[T], **conditions: Dict[str, Any]) -> int: class_name = cls.__name__ if class_name not in self.data: - raise ValueError(f"Unsupported log type: {class_name}") + raise ValueError(f'Unsupported log type: {class_name}') initial_count = len(self.data[class_name]) self.data[class_name] = [ - data for data in self.data[class_name] - if not all(getattr(cls(**data), key) == value for key, value in conditions.items()) + data + for data in self.data[class_name] + if not all( + getattr(cls(**data), key) == value for key, value in conditions.items() + ) ] return initial_count - len(self.data[class_name]) def save_to_file(self, file_name: str) -> None: - with open(file_name, "w") as f: + with open(file_name, 'w') as f: json.dump(self.data, f, indent=2) def load_from_file(self, file_name: str) -> None: - with open(file_name, "r") as f: + with open(file_name, 'r') as f: self.data = json.load(f) diff --git a/research_town/dbs/paper_db.py b/research_town/dbs/paper_db.py index 7d159efb..7ef22a84 100644 --- a/research_town/dbs/paper_db.py +++ b/research_town/dbs/paper_db.py @@ -25,7 +25,6 @@ class PaperProfile(BaseModel): award: Optional[str] = Field(default=None) - class PaperProfileDB: def __init__(self) -> None: self.data: Dict[str, PaperProfile] = {} @@ -58,15 +57,17 @@ def query_papers(self, **conditions: Dict[str, Any]) -> List[PaperProfile]: return result def save_to_file(self, file_name: str) -> None: - with open(file_name, "w") as f: - json.dump({pk: paper.dict() - for pk, paper in self.data.items()}, f, indent=2) + with open(file_name, 'w') as f: + json.dump( + {pk: paper.dict() for pk, paper in self.data.items()}, f, indent=2 + ) def load_from_file(self, file_name: str) -> None: - with open(file_name, "r") as f: + with open(file_name, 'r') as f: data = json.load(f) - self.data = {pk: PaperProfile(**paper_data) - for pk, paper_data in data.items()} + self.data = { + pk: PaperProfile(**paper_data) for pk, paper_data in data.items() + } def update_db(self, data: Dict[str, List[Dict[str, Any]]]) -> None: for date, papers in data.items(): @@ -79,7 +80,7 @@ def fetch_and_add_papers(self, num: int, domain: str) -> None: transformed_data = {} for date, value in data.items(): papers = [] - papers.append({"abstract": value["abstract"]}) - papers.append({"info": value["info"]}) + papers.append({'abstract': value['abstract']}) + papers.append({'info': value['info']}) transformed_data[date] = papers self.update_db(transformed_data) diff --git a/research_town/dbs/progress_db.py b/research_town/dbs/progress_db.py index fd9bcb30..df8da660 100644 --- a/research_town/dbs/progress_db.py +++ b/research_town/dbs/progress_db.py @@ -6,39 +6,43 @@ T = TypeVar('T', bound=BaseModel) + class ResearchProgressDB: def __init__(self) -> None: - self.data: Dict[str, List[Any]] = { - "ResearchIdea": [], - "ResearchPaper": [] - } + self.data: Dict[str, List[Any]] = {'ResearchIdea': [], 'ResearchPaper': []} def add(self, obj: T) -> None: class_name = obj.__class__.__name__ if class_name in self.data: self.data[class_name].append(obj.dict()) else: - raise ValueError(f"Unsupported type: {class_name}") + raise ValueError(f'Unsupported type: {class_name}') def get(self, cls: Type[T], **conditions: Dict[str, Any]) -> List[T]: class_name = cls.__name__ if class_name not in self.data: - raise ValueError(f"Unsupported type: {class_name}") + raise ValueError(f'Unsupported type: {class_name}') result = [] for data in self.data[class_name]: instance = cls(**data) - if all(getattr(instance, key) == value for key, value in conditions.items()): + if all( + getattr(instance, key) == value for key, value in conditions.items() + ): result.append(instance) return result - def update(self, cls: Type[T], conditions: Dict[str, Any], updates: Dict[str, Any]) -> int: + def update( + self, cls: Type[T], conditions: Dict[str, Any], updates: Dict[str, Any] + ) -> int: class_name = cls.__name__ if class_name not in self.data: - raise ValueError(f"Unsupported type: {class_name}") + raise ValueError(f'Unsupported type: {class_name}') updated_count = 0 for data in self.data[class_name]: instance = cls(**data) - if all(getattr(instance, key) == value for key, value in conditions.items()): + if all( + getattr(instance, key) == value for key, value in conditions.items() + ): for key, value in updates.items(): setattr(instance, key, value) self.data[class_name].remove(data) @@ -49,32 +53,38 @@ def update(self, cls: Type[T], conditions: Dict[str, Any], updates: Dict[str, An def delete(self, cls: Type[T], **conditions: Dict[str, Any]) -> int: class_name = cls.__name__ if class_name not in self.data: - raise ValueError(f"Unsupported type: {class_name}") + raise ValueError(f'Unsupported type: {class_name}') initial_count = len(self.data[class_name]) self.data[class_name] = [ - data for data in self.data[class_name] - if not all(getattr(cls(**data), key) == value for key, value in conditions.items()) + data + for data in self.data[class_name] + if not all( + getattr(cls(**data), key) == value for key, value in conditions.items() + ) ] return initial_count - len(self.data[class_name]) def save_to_file(self, file_name: str) -> None: - with open(file_name, "w") as f: + with open(file_name, 'w') as f: json.dump(self.data, f, indent=2) def load_from_file(self, file_name: str) -> None: - with open(file_name, "r") as f: + with open(file_name, 'r') as f: self.data = json.load(f) + class ResearchIdea(BaseModel): pk: str = Field(default_factory=lambda: str(uuid.uuid4())) content: Optional[str] = Field(default=None) + class ResearchPaperSubmission(BaseModel): pk: str = Field(default_factory=lambda: str(uuid.uuid4())) title: Optional[str] = Field(default=None) abstract: Optional[str] = Field(default=None) conference: Optional[str] = Field(default=None) + class ResearchInsight(BaseModel): pk: str = Field(default_factory=lambda: str(uuid.uuid4())) content: Optional[str] = Field(default=None) diff --git a/research_town/envs/__init__.py b/research_town/envs/__init__.py index a0716e90..7c0be9c4 100644 --- a/research_town/envs/__init__.py +++ b/research_town/envs/__init__.py @@ -1,7 +1,4 @@ from .env_paper_rebuttal import PaperRebuttalMultiAgentEnv from .env_paper_submission import PaperSubmissionMultiAgentEnvironment -__all__ = [ - "PaperRebuttalMultiAgentEnv", - "PaperSubmissionMultiAgentEnvironment" -] +__all__ = ['PaperRebuttalMultiAgentEnv', 'PaperSubmissionMultiAgentEnvironment'] diff --git a/research_town/envs/env_base.py b/research_town/envs/env_base.py index f9a5b315..c1e5758b 100644 --- a/research_town/envs/env_base.py +++ b/research_town/envs/env_base.py @@ -10,10 +10,12 @@ def __init__(self, agent_profiles: List[AgentProfile]) -> None: self.db = EnvLogDB() self.agents: List[BaseResearchAgent] = [] for agent_profile in agent_profiles: - self.agents.append(BaseResearchAgent( - agent_profile=agent_profile, - model_name="together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1" - )) + self.agents.append( + BaseResearchAgent( + agent_profile=agent_profile, + model_name='together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1', + ) + ) def step(self) -> None: raise NotImplementedError diff --git a/research_town/envs/env_paper_rebuttal.py b/research_town/envs/env_paper_rebuttal.py index 3e8305b0..a887c9d7 100644 --- a/research_town/envs/env_paper_rebuttal.py +++ b/research_town/envs/env_paper_rebuttal.py @@ -15,17 +15,18 @@ class PaperRebuttalMultiAgentEnv(BaseMultiAgentEnv): - def __init__(self, + def __init__( + self, agent_profiles: List[AgentProfile], agent_db: AgentProfileDB, paper_db: PaperProfileDB, - env_db: EnvLogDB + env_db: EnvLogDB, ) -> None: super().__init__(agent_profiles) self.turn_number = 0 self.turn_max = 1 self.terminated = False - self.decision = "reject" + self.decision = 'reject' self.submission = PaperProfile() self.reviewer_mask = [False] * len(agent_profiles) self.reviews: List[AgentPaperReviewLog] = [] @@ -38,7 +39,7 @@ def __init__(self, @beartype def assign_roles(self, role_dict: Dict[str, str]) -> None: for index, agent_profile in enumerate(self.agent_profiles): - if role_dict[agent_profile.pk] == "reviewer": + if role_dict[agent_profile.pk] == 'reviewer': self.reviewer_mask[index] = True @beartype @@ -47,12 +48,12 @@ def initialize_submission(self, paper_profile: PaperProfile) -> None: @beartype def submit_decision(self, decision_dict: Dict[str, Tuple[bool, str]]) -> None: - decision_count = {"accept": 0, "reject": 0} + decision_count = {'accept': 0, 'reject': 0} for _, decision in decision_dict.items(): if decision[0]: - decision_count["accept"] += 1 + decision_count['accept'] += 1 else: - decision_count["reject"] += 1 + decision_count['reject'] += 1 count_max = 0 for d, count in decision_count.items(): if count > count_max: @@ -63,23 +64,27 @@ def step(self) -> None: # Paper Reviewing for index, agent in enumerate(self.agents): if self.reviewer_mask[index]: - self.reviews.append(agent.write_paper_review( - paper=self.submission)) + self.reviews.append(agent.write_paper_review(paper=self.submission)) # Paper Meta Reviewing for index, agent in enumerate(self.agents): if self.reviewer_mask[index]: - self.meta_reviews.append(agent.write_paper_meta_review( - paper=self.submission, reviews=self.reviews)) + self.meta_reviews.append( + agent.write_paper_meta_review( + paper=self.submission, reviews=self.reviews + ) + ) # Rebuttal Submitting for index, agent in enumerate(self.agents): for review in self.reviews: if self.reviewer_mask[index]: - self.rebuttals.append(agent.write_rebuttal( - paper=self.submission, - review=review, - )) + self.rebuttals.append( + agent.write_rebuttal( + paper=self.submission, + review=review, + ) + ) self.turn_number += 1 if self.turn_number >= self.turn_max: diff --git a/research_town/envs/env_paper_submission.py b/research_town/envs/env_paper_submission.py index 2b9a0989..9f5e46d1 100644 --- a/research_town/envs/env_paper_submission.py +++ b/research_town/envs/env_paper_submission.py @@ -2,13 +2,7 @@ from beartype.typing import Dict, List from ..agents.agent_base import BaseResearchAgent -from ..dbs import ( - AgentProfile, - AgentProfileDB, - EnvLogDB, - PaperProfile, - PaperProfileDB, -) +from ..dbs import AgentProfile, AgentProfileDB, EnvLogDB, PaperProfile, PaperProfileDB from .env_base import BaseMultiAgentEnv @@ -19,7 +13,7 @@ def __init__( agent_db: AgentProfileDB, paper_db: PaperProfileDB, env_db: EnvLogDB, - task: Dict[str, str] + task: Dict[str, str], ) -> None: super().__init__(agent_profiles) self.turn_number = 0 @@ -34,8 +28,16 @@ def __init__( def step(self) -> None: # TODO: support retrieval from database # external_data = self.db.get(cls=PaperProfile, conditions={}) - papers = [PaperProfile(title="A Survey on Machine Learning", - abstract="This paper surveys the field of machine learning."), PaperProfile(title="A Survey on Natural Language Processing", abstract="This paper surveys the field of natural language processing.")] + papers = [ + PaperProfile( + title='A Survey on Machine Learning', + abstract='This paper surveys the field of machine learning.', + ), + PaperProfile( + title='A Survey on Natural Language Processing', + abstract='This paper surveys the field of natural language processing.', + ), + ] agent_names_to_objs: Dict[str, BaseResearchAgent] = {} for iter_agent in self.agents: if iter_agent.profile.name is not None: @@ -43,26 +45,28 @@ def step(self) -> None: submissions: Dict[str, PaperProfile] = {} for agent in self.agents: # TODO: update find collaborator functions with initial task - collaborators = agent.find_collaborators(PaperProfile(title="A Survey on Machine Learning", - abstract="This paper surveys the field of machine learning.")) + collaborators = agent.find_collaborators( + PaperProfile( + title='A Survey on Machine Learning', + abstract='This paper surveys the field of machine learning.', + ) + ) collaborator_agents: List[BaseResearchAgent] = [] for researcher_profile in collaborators: if researcher_profile.name: if researcher_profile.name not in agent_names_to_objs: new_agent_obj = BaseResearchAgent( agent_profile=researcher_profile, - model_name="together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1" + model_name='together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1', ) collaborator_agents.append(new_agent_obj) agent_names_to_objs[researcher_profile.name] = new_agent_obj else: collaborator_agents.append( - agent_names_to_objs[researcher_profile.name]) + agent_names_to_objs[researcher_profile.name] + ) - insights = agent.read_paper( - papers=papers, - domains=["machine learning"] - ) + insights = agent.read_paper(papers=papers, domains=['machine learning']) ideas = agent.think_idea(insights=insights) for collaborator_agent in collaborator_agents: ideas.extend(collaborator_agent.think_idea(insights=insights)) diff --git a/research_town/evaluators/output_format.py b/research_town/evaluators/output_format.py index a70152a6..6e341b32 100644 --- a/research_town/evaluators/output_format.py +++ b/research_town/evaluators/output_format.py @@ -1,53 +1,58 @@ - from beartype.typing import Type, TypeVar from pydantic import BaseModel, Extra, Field, validator T = TypeVar('T', bound=BaseModel) + class IdeaEvalOutput(BaseModel): overall_score: int = Field(default=-1) pk: str = Field(default='0') + class Config: extra = Extra.allow # Allows extra fields to be stored @validator('overall_score') def validate_overall_score(cls: Type[T], v: int) -> int: if v is None: - raise ValueError("Overall score cannot be None") + raise ValueError('Overall score cannot be None') if not (0 <= v <= 100): - raise ValueError("Overall score must be between 0 and 100") + raise ValueError('Overall score must be between 0 and 100') return v class PaperEvalOutput(BaseModel): overall_score: int = Field(default=-1) pk: str = Field(default='0') + class Config: extra = Extra.allow # Allows extra fields to be stored @validator('overall_score') def validate_overall_score(cls: Type[T], v: int) -> int: if v is None: - raise ValueError("Overall score cannot be None") + raise ValueError('Overall score cannot be None') if not (0 <= v <= 100): - raise ValueError("Overall score must be between 0 and 100") + raise ValueError('Overall score must be between 0 and 100') return v + class ReviewEvalOutput(BaseModel): overall_score: int = Field(default=-1) pk: str = Field(default='0') + class Config: extra = Extra.allow # Allows extra fields to be stored @validator('overall_score') def validate_overall_score(cls: Type[T], v: int) -> int: if v is None: - raise ValueError("Overall score cannot be None") + raise ValueError('Overall score cannot be None') if not (0 <= v <= 100): - raise ValueError("Overall score must be between 0 and 100") + raise ValueError('Overall score must be between 0 and 100') return v + class OutputFormatError(Exception): - def __init__(self, message:str="Output format error")-> None: + def __init__(self, message: str = 'Output format error') -> None: self.message = message super().__init__(self.message) diff --git a/research_town/evaluators/quality_evaluator.py b/research_town/evaluators/quality_evaluator.py index 4aced69b..8c4867c1 100644 --- a/research_town/evaluators/quality_evaluator.py +++ b/research_town/evaluators/quality_evaluator.py @@ -1,4 +1,3 @@ - import re from beartype.typing import Any @@ -18,11 +17,7 @@ class IdeaQualityEvaluator(object): - def __init__(self, - model_name: str, - *args: Any, - **kwargs: Any - )-> None: + def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None: self.model_name = model_name self.parsed_output = IdeaEvalOutput() @@ -31,11 +26,9 @@ def eval( self, *args: Any, **kwargs: Any, - )-> IdeaEvalOutput: + ) -> IdeaEvalOutput: raw_output = idea_quality_eval_prompting( - idea=kwargs['idea'], - trend=kwargs['trend'], - model_name=self.model_name + idea=kwargs['idea'], trend=kwargs['trend'], model_name=self.model_name ) self.parsed_output = self.parse(raw_output) @@ -43,22 +36,19 @@ def eval( setattr(self.parsed_output, key, value) return self.parsed_output - def parse(self, raw_output:str) -> IdeaEvalOutput: - match = re.search(r"Overall\s*Score\s*\W*(\d+)\W*", raw_output, re.IGNORECASE) + def parse(self, raw_output: str) -> IdeaEvalOutput: + match = re.search(r'Overall\s*Score\s*\W*(\d+)\W*', raw_output, re.IGNORECASE) if match: try: return IdeaEvalOutput(overall_score=int(match.group(1))) except ValueError as e: - raise OutputFormatError(f"Invalid overall score: {e}") + raise OutputFormatError(f'Invalid overall score: {e}') else: raise OutputFormatError("Output format error: 'Overall Score' not found") + class PaperQualityEvaluator(object): - def __init__(self, - model_name: str, - *args: Any, - **kwargs: Any - )-> None: + def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None: self.model_name = model_name self.parsed_output = PaperEvalOutput() @@ -67,11 +57,9 @@ def eval( self, *args: Any, **kwargs: Any, - )-> PaperEvalOutput: + ) -> PaperEvalOutput: raw_output = paper_quality_eval_prompting( - idea=kwargs['idea'], - paper=kwargs['paper'], - model_name=self.model_name + idea=kwargs['idea'], paper=kwargs['paper'], model_name=self.model_name ) self.parsed_output = self.parse(raw_output) @@ -80,21 +68,18 @@ def eval( return self.parsed_output def parse(self, raw_output: str) -> PaperEvalOutput: - match = re.search(r"Overall\s*Score\s*\W*(\d+)\W*", raw_output, re.IGNORECASE) + match = re.search(r'Overall\s*Score\s*\W*(\d+)\W*', raw_output, re.IGNORECASE) if match: try: return PaperEvalOutput(overall_score=int(match.group(1))) except ValueError as e: - raise OutputFormatError(f"Invalid overall score: {e}") + raise OutputFormatError(f'Invalid overall score: {e}') else: raise OutputFormatError("Output format error: 'Overall Score' not found") + class ReviewQualityEvaluator(object): - def __init__(self, - model_name: str, - *args: Any, - **kwargs: Any - )-> None: + def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None: self.model_name = model_name self.parsed_output = ReviewEvalOutput() @@ -103,14 +88,14 @@ def eval( self, *args: Any, **kwargs: Any, - )-> ReviewEvalOutput: + ) -> ReviewEvalOutput: raw_output = review_quality_eval_prompting( - idea=kwargs['idea'], # idea: str, - trend=kwargs['trend'], # trend: str, - paper=kwargs['paper'], # paper: Dict[str,str], - review=kwargs['review'], # review: Dict[str,str], - decision=kwargs['decision'], # decision: str, - model_name=self.model_name + idea=kwargs['idea'], # idea: str, + trend=kwargs['trend'], # trend: str, + paper=kwargs['paper'], # paper: Dict[str,str], + review=kwargs['review'], # review: Dict[str,str], + decision=kwargs['decision'], # decision: str, + model_name=self.model_name, ) self.parsed_output = self.parse(raw_output) # Store the input kwargs in parsed_output @@ -118,12 +103,14 @@ def eval( setattr(self.parsed_output, key, value) return self.parsed_output - def parse(self, raw_output:str) -> ReviewEvalOutput: - match = re.search(r"Overall\s*Score\s*\W*(\d+)\W*", raw_output, re.IGNORECASE) + def parse(self, raw_output: str) -> ReviewEvalOutput: + match = re.search(r'Overall\s*Score\s*\W*(\d+)\W*', raw_output, re.IGNORECASE) if match: try: return ReviewEvalOutput(overall_score=int(match.group(1))) except ValueError as e: - raise OutputFormatError(f"Invalid overall score: {e}") + raise OutputFormatError(f'Invalid overall score: {e}') else: - raise OutputFormatError(f"Output format error: 'Overall Score' not found. Raw output is {raw_output}.") + raise OutputFormatError( + f"Output format error: 'Overall Score' not found. Raw output is {raw_output}." + ) diff --git a/research_town/utils/agent_collector.py b/research_town/utils/agent_collector.py index 4f1f0d19..aeaf93f8 100644 --- a/research_town/utils/agent_collector.py +++ b/research_town/utils/agent_collector.py @@ -6,7 +6,7 @@ def get_authors(authors: List[str], first_author: bool = False) -> str: if first_author: return authors[0] - return ", ".join(authors) + return ', '.join(authors) def author_position(author: str, author_list: List[str]) -> int: @@ -31,8 +31,7 @@ def co_author_frequency( def co_author_filter(co_authors: Dict[str, int], limit: int = 5) -> List[str]: - co_author_list = sorted( - co_authors.items(), key=lambda p: p[1], reverse=True) + co_author_list = sorted(co_authors.items(), key=lambda p: p[1], reverse=True) return [name for name, _ in co_author_list[:limit]] @@ -40,23 +39,23 @@ def fetch_author_info(author: str) -> Tuple[List[Dict[str, Any]], List[str]]: client = Client() papers_info = [] co_authors: Dict[str, int] = {} - search = Search(query=f"au:{author}", max_results=10) + search = Search(query=f'au:{author}', max_results=10) for result in tqdm( - client.results(search), desc="Processing Author Papers", unit="Paper" + client.results(search), desc='Processing Author Papers', unit='Paper' ): - if author not in ", ".join(author.name for author in result.authors): + if author not in ', '.join(author.name for author in result.authors): continue author_list = [author.name for author in result.authors] co_authors = co_author_frequency(author, author_list, co_authors) paper_info = { - "url": result.entry_id, - "title": result.title, - "abstract": result.summary, - "authors": ", ".join(author.name for author in result.authors), - "published": str(result.published).split(" ")[0], - "updated": str(result.updated).split(" ")[0], - "primary_cat": result.primary_category, - "cats": result.categories, + 'url': result.entry_id, + 'title': result.title, + 'abstract': result.summary, + 'authors': ', '.join(author.name for author in result.authors), + 'published': str(result.published).split(' ')[0], + 'updated': str(result.updated).split(' ')[0], + 'primary_cat': result.primary_category, + 'cats': result.categories, } papers_info.append(paper_info) co_author_names = co_author_filter(co_authors, limit=5) @@ -65,7 +64,11 @@ def fetch_author_info(author: str) -> Tuple[List[Dict[str, Any]], List[str]]: def bfs( author_list: List[str], node_limit: int = 20 -) -> Tuple[List[Tuple[str, str]], Dict[str, List[Dict[str, Any]]], Dict[str, List[Dict[str, Any]]]]: +) -> Tuple[ + List[Tuple[str, str]], + Dict[str, List[Dict[str, Any]]], + Dict[str, List[Dict[str, Any]]], +]: graph = [] node_feat: Dict[str, List[Dict[str, Any]]] = dict() edge_feat: Dict[str, List[Dict[str, Any]]] = dict() diff --git a/research_town/utils/agent_prompter.py b/research_town/utils/agent_prompter.py index 8b6a4c1c..648dfff7 100644 --- a/research_town/utils/agent_prompter.py +++ b/research_town/utils/agent_prompter.py @@ -14,6 +14,7 @@ # ======================================= + @beartype def summarize_research_direction_prompting( personal_info: str, @@ -24,10 +25,10 @@ def summarize_research_direction_prompting( """ prompt_template = ( "Based on the list of the researcher's first person persona from different times, please write a comprehensive first person persona. " - "Focus more on more recent personas. Be concise and clear (around 300 words). " - "Here are the personas from different times: {personalinfo}" + 'Focus more on more recent personas. Be concise and clear (around 300 words). ' + 'Here are the personas from different times: {personalinfo}' ) - template_input = {"personalinfo": personal_info} + template_input = {'personalinfo': personal_info} prompt = prompt_template.format_map(template_input) return model_prompting(model_name, prompt) @@ -39,34 +40,46 @@ def find_collaborators_prompting( collaborator_profiles: Dict[str, str], parameter: float = 0.5, max_number: int = 3, - model_name: str = "together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1", + model_name: str = 'together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1', ) -> List[str]: self_serialize = [ - f"Name: {name}\nProfile: {self_profile[name]}" for _, name in enumerate(self_profile.keys())] - self_serialize_all = "\n\n".join(self_serialize) + f'Name: {name}\nProfile: {self_profile[name]}' + for _, name in enumerate(self_profile.keys()) + ] + self_serialize_all = '\n\n'.join(self_serialize) - task_serialize = [f"Time: {timestamp}\nAbstract: {input[timestamp]}\n" for _, - timestamp in enumerate(input.keys())] - task_serialize_all = "\n\n".join(task_serialize) + task_serialize = [ + f'Time: {timestamp}\nAbstract: {input[timestamp]}\n' + for _, timestamp in enumerate(input.keys()) + ] + task_serialize_all = '\n\n'.join(task_serialize) collaborator_serialize = [ - f"Name: {name}\nProfile: {collaborator_profiles[name]}" for _, name in enumerate(collaborator_profiles.keys())] - collaborator_serialize_all = "\n\n".join(collaborator_serialize) + f'Name: {name}\nProfile: {collaborator_profiles[name]}' + for _, name in enumerate(collaborator_profiles.keys()) + ] + collaborator_serialize_all = '\n\n'.join(collaborator_serialize) prompt_template = ( - "Given the name and profile of me, could you find {max_number} collaborators for the following collaboration task?" - "Here is my profile: {self_serialize_all}" - "The collaboration task include: {task_serialize_all}" - "Here are a full list of the names and profiles of potential collaborators: {collaborators_serialize_all}" + 'Given the name and profile of me, could you find {max_number} collaborators for the following collaboration task?' + 'Here is my profile: {self_serialize_all}' + 'The collaboration task include: {task_serialize_all}' + 'Here are a full list of the names and profiles of potential collaborators: {collaborators_serialize_all}' "Generate the collaborator in a list separated by '-' for each collaborator" ) - input = {"max_number": str(max_number), "self_serialize_all": self_serialize_all, - "task_serialize_all": task_serialize_all, "collaborators_serialize_all": collaborator_serialize_all} + input = { + 'max_number': str(max_number), + 'self_serialize_all': self_serialize_all, + 'task_serialize_all': task_serialize_all, + 'collaborators_serialize_all': collaborator_serialize_all, + } prompt = prompt_template.format_map(input) return model_prompting(model_name, prompt) + # ======================================= + @beartype def read_paper_prompting( profile: Dict[str, str], @@ -75,30 +88,31 @@ def read_paper_prompting( model_name: str, ) -> List[str]: query_template = ( - "Given the profile of me, keywords, some recent paper titles and abstracts. Could you summarize the keywords of high level research backgrounds and insights in this field (related to my profile if possible)." - "Here is my profile biology: {profile_bio}" - "Here are the domains: {domains}" + 'Given the profile of me, keywords, some recent paper titles and abstracts. Could you summarize the keywords of high level research backgrounds and insights in this field (related to my profile if possible).' + 'Here is my profile biology: {profile_bio}' + 'Here are the domains: {domains}' ) prompt_template = ( - "Given the profile of me, keywords, some recent paper titles and abstracts. Could you summarize the keywords of high level research backgrounds and insights in this field (related to my profile if possible)." - "Here is my profile biology: {profile_bio}" - "Here are the research domains: {domains}" - "Here are some recent paper titles and abstracts: {papers}" + 'Given the profile of me, keywords, some recent paper titles and abstracts. Could you summarize the keywords of high level research backgrounds and insights in this field (related to my profile if possible).' + 'Here is my profile biology: {profile_bio}' + 'Here are the research domains: {domains}' + 'Here are some recent paper titles and abstracts: {papers}' ) - query = query_template.format_map({ - "profile_bio": profile['bio'], - "domains": "; ".join(domains) - }) + query = query_template.format_map( + {'profile_bio': profile['bio'], 'domains': '; '.join(domains)} + ) corpus = [paper['abstract'] for paper in papers] related_papers = get_related_papers(corpus, query, num=1) - prompt = prompt_template.format_map({ - "profile_bio": profile['bio'], - "domains": "; ".join(domains), - "papers": "; ".join(related_papers) - }) + prompt = prompt_template.format_map( + { + 'profile_bio': profile['bio'], + 'domains': '; '.join(domains), + 'papers': '; '.join(related_papers), + } + ) return model_prompting(model_name, prompt) @@ -108,11 +122,11 @@ def think_idea_prompting( model_name: str, ) -> List[str]: prompt_template = ( - "Here is a high-level summarized insight of a research field {insight}. " - "How do you view this field? Do you have any novel ideas or insights? " - "Please give me 3 to 5 novel ideas and insights in bullet points. Each bullet point should be concise, containing 2 or 3 sentences." + 'Here is a high-level summarized insight of a research field {insight}. ' + 'How do you view this field? Do you have any novel ideas or insights? ' + 'Please give me 3 to 5 novel ideas and insights in bullet points. Each bullet point should be concise, containing 2 or 3 sentences.' ) - prompt = prompt_template.format_map({"insight": insight['content']}) + prompt = prompt_template.format_map({'insight': insight['content']}) return model_prompting(model_name, prompt) @@ -126,33 +140,31 @@ def write_paper_prompting( papers_str = map_paper_list_to_str(papers) prompt_template = ( - "Please write a paper based on the following ideas and external data. To save time, you only need to write the abstract. " - "You might use two or more of these ideas if they are related and works well together. " - "Here are the ideas: {ideas}" - "Here are the external data, which is a list abstracts of related papers: {papers}" + 'Please write a paper based on the following ideas and external data. To save time, you only need to write the abstract. ' + 'You might use two or more of these ideas if they are related and works well together. ' + 'Here are the ideas: {ideas}' + 'Here are the external data, which is a list abstracts of related papers: {papers}' ) - prompt = prompt_template.format_map({ - "ideas": ideas_str, - "papers": papers_str - }) + prompt = prompt_template.format_map({'ideas': ideas_str, 'papers': papers_str}) return model_prompting(model_name, prompt) + @beartype -def review_score_prompting( - paper_review: str, - model_name: str -) -> int: +def review_score_prompting(paper_review: str, model_name: str) -> int: prompt_template = ( - "Please provide a score for the following reviews. The score should be between 1 and 10, where 1 is the lowest and 10 is the highest. Only returns one number score." - "Here are the reviews: {paper_review}" + 'Please provide a score for the following reviews. The score should be between 1 and 10, where 1 is the lowest and 10 is the highest. Only returns one number score.' + 'Here are the reviews: {paper_review}' + ) + prompt = prompt_template.format_map( + { + 'paper_review': paper_review, + } ) - prompt = prompt_template.format_map({ - "paper_review": paper_review, - }) score_str = model_prompting(model_name, prompt)[0] return int(score_str[0]) if score_str[0].isdigit() else 0 + @beartype def review_paper_prompting( paper: Dict[str, str], @@ -160,14 +172,15 @@ def review_paper_prompting( ) -> List[str]: papers_str = map_paper_to_str(paper) prompt_template = ( - "Please give some reviews based on the following inputs and external data." - "You might use two or more of these titles if they are related and works well together." - "Here are the external data, which is a list of related papers: {papers}" + 'Please give some reviews based on the following inputs and external data.' + 'You might use two or more of these titles if they are related and works well together.' + 'Here are the external data, which is a list of related papers: {papers}' ) - prompt = prompt_template.format_map({"papers": papers_str}) + prompt = prompt_template.format_map({'papers': papers_str}) return model_prompting(model_name, prompt) + @beartype def write_meta_review_prompting( paper: Dict[str, str], @@ -178,16 +191,14 @@ def write_meta_review_prompting( reviews_str = map_review_list_to_str(reviews) prompt_template = ( - "Please make an review decision to decide whether the following submission should be accepted or rejected by an academic conference. Here are several reviews from reviewers for this submission. Please indicate your review decision as accept or reject." - "Here is the submission: {paper}" - "Here are the reviews: {reviews}" + 'Please make an review decision to decide whether the following submission should be accepted or rejected by an academic conference. Here are several reviews from reviewers for this submission. Please indicate your review decision as accept or reject.' + 'Here is the submission: {paper}' + 'Here are the reviews: {reviews}' ) - prompt = prompt_template.format_map({ - "paper": paper_str, - "reviews": reviews_str - }) + prompt = prompt_template.format_map({'paper': paper_str, 'reviews': reviews_str}) return model_prompting(model_name, prompt) + @beartype def write_rebuttal_prompting( paper: Dict[str, str], @@ -198,16 +209,14 @@ def write_rebuttal_prompting( review_str = map_review_to_str(review) prompt_template = ( - "Please write a rebuttal for the following submission you have made to an academic conference. Here are the reviews and decisions from the reviewers. Your rebuttal should rebut the reviews to convince the reviewers to accept your submission." - "Here is the submission: {paper}" - "Here are the reviews: {review}" + 'Please write a rebuttal for the following submission you have made to an academic conference. Here are the reviews and decisions from the reviewers. Your rebuttal should rebut the reviews to convince the reviewers to accept your submission.' + 'Here is the submission: {paper}' + 'Here are the reviews: {review}' ) - prompt = prompt_template.format_map({ - "paper": paper_str, - "review": review_str - }) + prompt = prompt_template.format_map({'paper': paper_str, 'review': review_str}) return model_prompting(model_name, prompt) + @beartype def discuss_prompting( message: Dict[str, str], @@ -215,8 +224,8 @@ def discuss_prompting( ) -> List[str]: message_str = map_message_to_str(message) prompt_template = ( - "Please continue in a conversation with other fellow researchers for me, where you will address their concerns in a scholarly way. " - "Here are the messages from other researchers: {message}" + 'Please continue in a conversation with other fellow researchers for me, where you will address their concerns in a scholarly way. ' + 'Here are the messages from other researchers: {message}' ) - prompt = prompt_template.format_map({"message": message_str}) + prompt = prompt_template.format_map({'message': message_str}) return model_prompting(model_name, prompt) diff --git a/research_town/utils/decorator.py b/research_town/utils/decorator.py index 513e35a4..83cb8b68 100644 --- a/research_town/utils/decorator.py +++ b/research_town/utils/decorator.py @@ -2,20 +2,14 @@ import time from functools import wraps -from beartype.typing import ( - Any, - Callable, - List, - Optional, - TypeVar, - cast, -) +from beartype.typing import Any, Callable, List, Optional, TypeVar, cast from pydantic import BaseModel INF = float(math.inf) T = TypeVar('T', bound=Callable[..., Optional[List[str]]]) + def api_calling_error_exponential_backoff( retries: int = 5, base_wait_time: int = 1 ) -> Callable[[T], T]: @@ -34,18 +28,22 @@ def wrapper(*args: Any, **kwargs: Any) -> Optional[List[str]]: try: return func(*args, **kwargs) except Exception as e: - wait_time = base_wait_time * (2 ** attempts) - print(f"Attempt {attempts + 1} failed: {e}") - print(f"Waiting {wait_time} seconds before retrying...") + wait_time = base_wait_time * (2**attempts) + print(f'Attempt {attempts + 1} failed: {e}') + print(f'Waiting {wait_time} seconds before retrying...') time.sleep(wait_time) attempts += 1 print(f"Failed to execute '{func.__name__}' after {retries} retries.") return None + return cast(T, wrapper) + return cast(Callable[[T], T], decorator) + TBaseModel = TypeVar('TBaseModel', bound=Callable[..., BaseModel]) + def parsing_error_exponential_backoff( retries: int = 5, base_wait_time: int = 1 ) -> Callable[[TBaseModel], TBaseModel]: @@ -64,12 +62,16 @@ def wrapper(self: Any, *args: Any, **kwargs: Any) -> Optional[BaseModel]: try: return func(self, *args, **kwargs) except Exception as e: - wait_time = base_wait_time * (2 ** attempts) - print(f"Attempt {attempts + 1} failed: {e}") - print(f"Waiting {wait_time} seconds before retrying...") + wait_time = base_wait_time * (2**attempts) + print(f'Attempt {attempts + 1} failed: {e}') + print(f'Waiting {wait_time} seconds before retrying...') time.sleep(wait_time) attempts += 1 - print(f"Failed to get valid input from {func.__name__} after {retries} retries.") + print( + f'Failed to get valid input from {func.__name__} after {retries} retries.' + ) return None + return cast(TBaseModel, wrapper) + return cast(Callable[[TBaseModel], TBaseModel], decorator) diff --git a/research_town/utils/eval_prompter.py b/research_town/utils/eval_prompter.py index 6af8a57f..74bff7ae 100644 --- a/research_town/utils/eval_prompter.py +++ b/research_town/utils/eval_prompter.py @@ -7,93 +7,85 @@ @beartype def idea_quality_eval_prompting( idea: str, - trend: str, + trend: str, model_name: str, ) -> str: prompt_idea = ( - " Please evaluate the idea based on the following dimensions, considering the current research trend within the ML community. If the research trend field is left blank, please use your common knowledge to assess the trend. Finally, give an overall score (0-100) and 10 dimension scores (for each dimension, provide a rating (1-10)) as the evaluation for the idea. The output format should follow these rules: Overall Score of an idea (0-100), with 10 Dimension Scores: [d1, d2, d3, ..., d10], where di is the score of the i-th dimension. An example of output is: 'Overall Score=89. Dimension Scores=[8,9,9,9,9,9,9,9,9,9]'.\n" - " The details of rating are as follow:\n" - "1. Novelty\n" - "Rating (1-10):\n" - "Comments:\n" - "How original and unique is the idea?\n" - "Does it introduce a new perspective or significant advancement compared to existing methods?\n" - "How does it align with or diverge from the innovations highlighted in the trend?\n" - "2. Technical Depth\n" - "Rating (1-10):\n" - "Comments:\n" - "Assess the technical rigor of the idea.\n" - "Does it include solid theoretical foundations, robust algorithms, and detailed methodologies?\n" - "Is the technical depth in line with the state-of-the-art techniques noted in the trend?\n" - "3. Impact and Significance\n" - "Rating (1-10):\n" - "Comments:\n" - "Evaluate the potential impact of the idea on the ML community and beyond.\n" - "How significant is its contribution to advancing the field?\n" - "Does it address high-impact problems or gaps identified in the trend?\n" - "4. Feasibility and Practicality\n" - "Rating (1-10):\n" - "Comments:\n" - "Assess the feasibility of implementing the idea.\n" - "Is it practically applicable in real-world scenarios?\n" - "Does it consider efficiency and scalability, in line with the practical application focus of the trend?\n" - "5. Theoretical Foundation and Conceptual Soundness\n" - "Rating (1-10):\n" - "Comments:\n" - "Evaluate the theoretical foundation and conceptual soundness of the idea.\n" - "Are the underlying principles well-defined and logically consistent?\n" - "Does the idea demonstrate a deep understanding of relevant theories and concepts?\n" - "How does it contribute to advancing theoretical understanding within the field?\n" - "6. Clarity and Presentation\n" - "Rating (1-10):\n" - "Comments:\n" - "Assess the clarity, organization, and presentation quality of the idea.\n" - "Is the idea communicated effectively, adhering to high presentation standards seen in top-tier ML conferences?\n" - "7. Potential for Real-world Applications\n" - "Rating (1-10):\n" - "Comments:\n" - "Evaluate the potential of the idea to be applied in real-world scenarios.\n" - "How applicable is it in practical settings and industry contexts?\n" - "Does it address real-world problems or challenges identified in the trend?\n" - "8. Innovation Potential\n" - "Rating (1-10):\n" - "Comments:\n" - "Assess the potential of the idea to inspire further research and innovation within the ML community.\n" - "Does it open up new avenues for research or provide a novel framework aligning with the emerging trends and future directions of the trend?\n" - "9. Ethical Considerations\n" - "Rating (1-10):\n" - "Comments:\n" - "Consider the ethical implications and societal impact of the idea.\n" - "Does it adhere to the growing emphasis on ethical AI and responsible ML practices as highlighted in the trend?\n" - "10. Interdisciplinary Connections\n" - "Rating (1-10):\n" - "Comments:\n" - "Evaluate the potential for the idea to connect with and contribute to other disciplines beyond ML.\n" - "Does it align with the trend of interdisciplinary research and collaboration, integrating with fields such as data science, neuroscience, or social sciences?\n" - - "Here is the idea to evaluate: {idea}.\n" - "Here is the research trend: {trend}.\n" - + " Please evaluate the idea based on the following dimensions, considering the current research trend within the ML community. If the research trend field is left blank, please use your common knowledge to assess the trend. Finally, give an overall score (0-100) and 10 dimension scores (for each dimension, provide a rating (1-10)) as the evaluation for the idea. The output format should follow these rules: Overall Score of an idea (0-100), with 10 Dimension Scores: [d1, d2, d3, ..., d10], where di is the score of the i-th dimension. An example of output is: 'Overall Score=89. Dimension Scores=[8,9,9,9,9,9,9,9,9,9]'.\n" + ' The details of rating are as follow:\n' + '1. Novelty\n' + 'Rating (1-10):\n' + 'Comments:\n' + 'How original and unique is the idea?\n' + 'Does it introduce a new perspective or significant advancement compared to existing methods?\n' + 'How does it align with or diverge from the innovations highlighted in the trend?\n' + '2. Technical Depth\n' + 'Rating (1-10):\n' + 'Comments:\n' + 'Assess the technical rigor of the idea.\n' + 'Does it include solid theoretical foundations, robust algorithms, and detailed methodologies?\n' + 'Is the technical depth in line with the state-of-the-art techniques noted in the trend?\n' + '3. Impact and Significance\n' + 'Rating (1-10):\n' + 'Comments:\n' + 'Evaluate the potential impact of the idea on the ML community and beyond.\n' + 'How significant is its contribution to advancing the field?\n' + 'Does it address high-impact problems or gaps identified in the trend?\n' + '4. Feasibility and Practicality\n' + 'Rating (1-10):\n' + 'Comments:\n' + 'Assess the feasibility of implementing the idea.\n' + 'Is it practically applicable in real-world scenarios?\n' + 'Does it consider efficiency and scalability, in line with the practical application focus of the trend?\n' + '5. Theoretical Foundation and Conceptual Soundness\n' + 'Rating (1-10):\n' + 'Comments:\n' + 'Evaluate the theoretical foundation and conceptual soundness of the idea.\n' + 'Are the underlying principles well-defined and logically consistent?\n' + 'Does the idea demonstrate a deep understanding of relevant theories and concepts?\n' + 'How does it contribute to advancing theoretical understanding within the field?\n' + '6. Clarity and Presentation\n' + 'Rating (1-10):\n' + 'Comments:\n' + 'Assess the clarity, organization, and presentation quality of the idea.\n' + 'Is the idea communicated effectively, adhering to high presentation standards seen in top-tier ML conferences?\n' + '7. Potential for Real-world Applications\n' + 'Rating (1-10):\n' + 'Comments:\n' + 'Evaluate the potential of the idea to be applied in real-world scenarios.\n' + 'How applicable is it in practical settings and industry contexts?\n' + 'Does it address real-world problems or challenges identified in the trend?\n' + '8. Innovation Potential\n' + 'Rating (1-10):\n' + 'Comments:\n' + 'Assess the potential of the idea to inspire further research and innovation within the ML community.\n' + 'Does it open up new avenues for research or provide a novel framework aligning with the emerging trends and future directions of the trend?\n' + '9. Ethical Considerations\n' + 'Rating (1-10):\n' + 'Comments:\n' + 'Consider the ethical implications and societal impact of the idea.\n' + 'Does it adhere to the growing emphasis on ethical AI and responsible ML practices as highlighted in the trend?\n' + '10. Interdisciplinary Connections\n' + 'Rating (1-10):\n' + 'Comments:\n' + 'Evaluate the potential for the idea to connect with and contribute to other disciplines beyond ML.\n' + 'Does it align with the trend of interdisciplinary research and collaboration, integrating with fields such as data science, neuroscience, or social sciences?\n' + 'Here is the idea to evaluate: {idea}.\n' + 'Here is the research trend: {trend}.\n' ) - - - input_data = { - "idea": idea, - "trend": trend - } + input_data = {'idea': idea, 'trend': trend} prompt = prompt_idea.format_map(input_data) evaluation_result = model_prompting(model_name, prompt) # merge results from List[Str] to Str - combined_result = "\n".join(evaluation_result) + combined_result = '\n'.join(evaluation_result) return combined_result + @beartype def paper_quality_eval_prompting( - idea: str, - paper: Dict[str,str], - model_name: str + idea: str, paper: Dict[str, str], model_name: str ) -> str: paper_prompt = """ Please evaluate the paper draft based on the following dimensions. Finally, give an overall score (0-100) and 10 dimension scores (for each dimension, provide a rating (1-10)) as the evaluation for the draft. The output format should follow these rules: Overall Score of a paper draft (0-100), with 10 Dimension Scores: [d1, d2, d3, ..., d10], where di is the score of the i-th dimension. An example of output is: 'Overall Score=85. Dimension Scores=[7,8,9,7,8,9,8,8,8,9]'. @@ -176,29 +168,26 @@ def paper_quality_eval_prompting( """ - - input_data = { - "idea": idea, - "title": paper["title"], - "abstract": paper["abstract"], + 'idea': idea, + 'title': paper['title'], + 'abstract': paper['abstract'], } prompt = paper_prompt.format_map(input_data) evaluation_result = model_prompting(model_name, prompt) # merge results from List[Str] to Str - combined_result = "\n".join(evaluation_result) + combined_result = '\n'.join(evaluation_result) return combined_result - def review_quality_eval_prompting( idea: str, - trend: str, - paper: Dict[str,str], + trend: str, + paper: Dict[str, str], review: List[str], decision: str, - model_name: str + model_name: str, ) -> str: review_prompt = """ @@ -289,21 +278,22 @@ def review_quality_eval_prompting( - Do the scores reflect a reasonable and unbiased assessment of the paper? """ - # Organize the reviews - organized_reviews = "\n".join([f"Reviewer {i+1}'s comment: {review[i]}" for i in range(len(review))]) + organized_reviews = '\n'.join( + [f"Reviewer {i+1}'s comment: {review[i]}" for i in range(len(review))] + ) input_data = { - "regulations": regulations, - "idea": idea, - "trend": trend, - "title": paper["title"], - "abstract": paper["abstract"], - "review": organized_reviews, - "final_decision": decision, + 'regulations': regulations, + 'idea': idea, + 'trend': trend, + 'title': paper['title'], + 'abstract': paper['abstract'], + 'review': organized_reviews, + 'final_decision': decision, } prompt = review_prompt.format_map(input_data) evaluation_result = model_prompting(model_name, prompt) # merge results from List[Str] to Str - combined_result = "\n".join(evaluation_result) + combined_result = '\n'.join(evaluation_result) return combined_result diff --git a/research_town/utils/logging.py b/research_town/utils/logging.py index df91c2d0..cd761fb6 100644 --- a/research_town/utils/logging.py +++ b/research_town/utils/logging.py @@ -2,7 +2,11 @@ from beartype.typing import Dict, List, Union -logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s') +logging.basicConfig( + level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s' +) + + def logging_callback(messages: Union[List[Dict[str, str]], None] = None) -> None: """ Logs messages using the logging module. diff --git a/research_town/utils/model_prompting.py b/research_town/utils/model_prompting.py index 21daf0ed..49f271b1 100644 --- a/research_town/utils/model_prompting.py +++ b/research_town/utils/model_prompting.py @@ -21,7 +21,7 @@ def model_prompting( """ completion = litellm.completion( model=llm_model, - messages=[{"role": "user", "content": prompt}], + messages=[{'role': 'user', 'content': prompt}], max_tokens=max_token_num, # for some models, 'n'(The number of chat completion choices ) is not supported. n=return_num, diff --git a/research_town/utils/paper_collector.py b/research_town/utils/paper_collector.py index 0ce68c89..b7741058 100644 --- a/research_town/utils/paper_collector.py +++ b/research_town/utils/paper_collector.py @@ -8,24 +8,24 @@ from beartype.typing import Any, Dict, List, Tuple from transformers import BertModel, BertTokenizer -ATOM_NAMESPACE = "{http://www.w3.org/2005/Atom}" +ATOM_NAMESPACE = '{http://www.w3.org/2005/Atom}' -def get_related_papers( - corpus: List[str], query: str, num: int -) -> List[str]: + +def get_related_papers(corpus: List[str], query: str, num: int) -> List[str]: corpus_embedding = get_bert_embedding(corpus) query_embedding = get_bert_embedding([query]) indices = neiborhood_search(corpus_embedding, query_embedding, num) related_papers = [corpus[idx] for idx in indices[0].tolist()] return related_papers + def get_bert_embedding(instructions: List[str]) -> List[torch.Tensor]: - tokenizer = BertTokenizer.from_pretrained("facebook/contriever") - model = BertModel.from_pretrained("facebook/contriever").to(torch.device("cpu")) + tokenizer = BertTokenizer.from_pretrained('facebook/contriever') + model = BertModel.from_pretrained('facebook/contriever').to(torch.device('cpu')) encoded_input_all = [ - tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to( - torch.device("cpu") + tokenizer(text, return_tensors='pt', truncation=True, max_length=512).to( + torch.device('cpu') ) for text in instructions ] @@ -34,7 +34,7 @@ def get_bert_embedding(instructions: List[str]) -> List[torch.Tensor]: emb_list = [] for inter in encoded_input_all: emb = model(**inter) - emb_list.append(emb["last_hidden_state"].mean(1)) + emb_list.append(emb['last_hidden_state'].mean(1)) return emb_list @@ -46,8 +46,8 @@ def neiborhood_search( xq = torch.cat(query_data, 0).cpu().numpy() xb = torch.cat(corpus_data, 0).cpu().numpy() index = faiss.IndexFlatIP(d) - xq = xq.astype("float32") - xb = xb.astype("float32") + xq = xq.astype('float32') + xb = xb.astype('float32') faiss.normalize_L2(xq) faiss.normalize_L2(xb) index.add(xb) # add vectors to the index @@ -59,11 +59,11 @@ def find_text(element: ElementTree.Element, path: str) -> str: found_element = element.find(path) if found_element is not None and found_element.text is not None: return found_element.text.strip() - return "" + return '' def get_daily_papers( - topic: str, query: str = "slam", max_results: int = 2 + topic: str, query: str = 'slam', max_results: int = 2 ) -> Tuple[Dict[str, Dict[str, List[str]]], str]: client = arxiv.Client() search = arxiv.Search( @@ -71,54 +71,56 @@ def get_daily_papers( ) results = client.results(search) content: Dict[str, Dict[str, List[str]]] = {} - newest_day = "" + newest_day = '' for result in results: paper_title = result.title paper_url = result.entry_id - paper_abstract = result.summary.replace("\n", " ") + paper_abstract = result.summary.replace('\n', ' ') publish_time = result.published.date() newest_day = publish_time if publish_time in content: - content[publish_time]['abstract'].append(paper_title + ": " + paper_abstract) - content[publish_time]['info'].append(paper_title + ": " + paper_url) + content[publish_time]['abstract'].append( + paper_title + ': ' + paper_abstract + ) + content[publish_time]['info'].append(paper_title + ': ' + paper_url) else: content[publish_time] = {} - content[publish_time]['abstract'] = [paper_title + ": " + paper_abstract] - content[publish_time]['info'] = [paper_title + ": " + paper_url] + content[publish_time]['abstract'] = [paper_title + ': ' + paper_abstract] + content[publish_time]['info'] = [paper_title + ': ' + paper_url] return content, newest_day -def get_papers(entries: List[ElementTree.Element], author_name: str) -> Tuple[List[Dict[str, Any]], Dict[int, List[ElementTree.Element]]]: + +def get_papers( + entries: List[ElementTree.Element], author_name: str +) -> Tuple[List[Dict[str, Any]], Dict[int, List[ElementTree.Element]]]: papers_list: List[Dict[str, Any]] = [] papers_by_year: Dict[int, List[ElementTree.Element]] = {} for entry in entries: - title = find_text(entry, f"{ATOM_NAMESPACE}title") - published = find_text(entry, f"{ATOM_NAMESPACE}published") - abstract = find_text(entry, f"{ATOM_NAMESPACE}summary") - authors_elements = entry.findall(f"{ATOM_NAMESPACE}author") + title = find_text(entry, f'{ATOM_NAMESPACE}title') + published = find_text(entry, f'{ATOM_NAMESPACE}published') + abstract = find_text(entry, f'{ATOM_NAMESPACE}summary') + authors_elements = entry.findall(f'{ATOM_NAMESPACE}author') authors = [ - find_text(author, f"{ATOM_NAMESPACE}name") - for author in authors_elements + find_text(author, f'{ATOM_NAMESPACE}name') for author in authors_elements ] - link = find_text(entry, f"{ATOM_NAMESPACE}id") + link = find_text(entry, f'{ATOM_NAMESPACE}id') if author_name in authors: coauthors = [author for author in authors if author != author_name] - coauthors_str = ", ".join(coauthors) + coauthors_str = ', '.join(coauthors) papers_list.append( { - "date": published, - "Title & Abstract": f"{title}; {abstract}", - "coauthors": coauthors_str, - "link": link, + 'date': published, + 'Title & Abstract': f'{title}; {abstract}', + 'coauthors': coauthors_str, + 'link': link, } ) published_date = published - date_obj = datetime.datetime.strptime( - published_date, "%Y-%m-%dT%H:%M:%SZ" - ) + date_obj = datetime.datetime.strptime(published_date, '%Y-%m-%dT%H:%M:%SZ') year = date_obj.year if year not in papers_by_year: papers_by_year[year] = [] @@ -126,7 +128,10 @@ def get_papers(entries: List[ElementTree.Element], author_name: str) -> Tuple[Li return papers_list, papers_by_year -def select_papers(papers_by_year: Dict[int, List[ElementTree.Element]], author_name: str) -> List[Dict[str, Any]]: + +def select_papers( + papers_by_year: Dict[int, List[ElementTree.Element]], author_name: str +) -> List[Dict[str, Any]]: papers_list: List[Dict[str, Any]] = [] for cycle_start in range(min(papers_by_year), max(papers_by_year) + 1, 5): @@ -135,48 +140,37 @@ def select_papers(papers_by_year: Dict[int, List[ElementTree.Element]], author_n if year in papers_by_year: selected_papers = papers_by_year[year][:2] for paper in selected_papers: - title = find_text( - paper, f"{ATOM_NAMESPACE}title" - ) - abstract = find_text( - paper, f"{ATOM_NAMESPACE}summary" - ) - authors_elements = paper.findall( - f"{ATOM_NAMESPACE}author" - ) + title = find_text(paper, f'{ATOM_NAMESPACE}title') + abstract = find_text(paper, f'{ATOM_NAMESPACE}summary') + authors_elements = paper.findall(f'{ATOM_NAMESPACE}author') co_authors = [ - find_text( - author, f"{ATOM_NAMESPACE}name" - ) + find_text(author, f'{ATOM_NAMESPACE}name') for author in authors_elements - if find_text( - author, f"{ATOM_NAMESPACE}name" - ) - != author_name + if find_text(author, f'{ATOM_NAMESPACE}name') != author_name ] papers_list.append( { - "Author": author_name, - "Title & Abstract": f"{title}; {abstract}", - "Date Period": f"{year}", - "Cycle": f"{cycle_start}-{cycle_end}", - "Co_author": ", ".join(co_authors), + 'Author': author_name, + 'Title & Abstract': f'{title}; {abstract}', + 'Date Period': f'{year}', + 'Cycle': f'{cycle_start}-{cycle_end}', + 'Co_author': ', '.join(co_authors), } ) return papers_list def get_paper_list(author_name: str) -> List[Dict[str, Any]]: - author_query = author_name.replace(" ", "+") - url = f"http://export.arxiv.org/api/query?search_query=au:{author_query}&start=0&max_results=300" + author_query = author_name.replace(' ', '+') + url = f'http://export.arxiv.org/api/query?search_query=au:{author_query}&start=0&max_results=300' response = requests.get(url) if response.status_code == 200: xml_content = response.content.decode('utf-8', errors='ignore') root = ElementTree.fromstring(xml_content) - entries = root.findall(f"{ATOM_NAMESPACE}entry") + entries = root.findall(f'{ATOM_NAMESPACE}entry') papers_list, papers_by_year = get_papers(entries, author_name) if len(papers_list) > 40: @@ -186,5 +180,5 @@ def get_paper_list(author_name: str) -> List[Dict[str, Any]]: papers_list = papers_list[:10] return papers_list else: - print("Failed to fetch data from arXiv.") + print('Failed to fetch data from arXiv.') return [] diff --git a/research_town/utils/serializer.py b/research_town/utils/serializer.py index 69f66624..82673dd5 100644 --- a/research_town/utils/serializer.py +++ b/research_town/utils/serializer.py @@ -13,17 +13,26 @@ def serialize(cls, obj: Any) -> Any: return {key: cls.serialize(value) for key, value in obj.items()} elif isinstance(obj, (list, tuple, set)): return type(obj)(cls.serialize(item) for item in obj) - elif hasattr(obj, '__dict__'): # custom class + elif hasattr(obj, '__dict__'): # custom class return { '__class__': obj.__class__.__name__, '__module__': obj.__class__.__module__, - **{key: cls.serialize(value) for key, value in obj.__dict__.items() if not callable(value) and key != 'ckpt'} + **{ + key: cls.serialize(value) + for key, value in obj.__dict__.items() + if not callable(value) and key != 'ckpt' + }, } else: - raise TypeError(f"Unsupported data type: {type(obj)}") + raise TypeError(f'Unsupported data type: {type(obj)}') @classmethod - def deserialize(cls, data: Union[Dict[str, Any], List[Any], Tuple[Any, ...], Set[Any], str, int, bool]) -> Any: + def deserialize( + cls, + data: Union[ + Dict[str, Any], List[Any], Tuple[Any, ...], Set[Any], str, int, bool + ], + ) -> Any: if not isinstance(data, dict): if isinstance(data, list): return [cls.deserialize(item) for item in data] @@ -34,7 +43,7 @@ def deserialize(cls, data: Union[Dict[str, Any], List[Any], Tuple[Any, ...], Set if isinstance(data, str) or isinstance(data, int) or isinstance(data, bool): return data else: - raise TypeError(f"Unsupported data type: {type(data)}") + raise TypeError(f'Unsupported data type: {type(data)}') class_name = data.get('__class__') module_name = data.get('__module__') @@ -44,7 +53,9 @@ def deserialize(cls, data: Union[Dict[str, Any], List[Any], Tuple[Any, ...], Set target_class = getattr(module, class_name) obj = target_class.__new__(target_class) - attributes = {k: v for k, v in data.items() if k not in {'__class__', '__module__'}} + attributes = { + k: v for k, v in data.items() if k not in {'__class__', '__module__'} + } if issubclass(target_class, BaseModel): # Use Pydantic's construct method for BaseModel subclasses diff --git a/research_town/utils/string_mapper.py b/research_town/utils/string_mapper.py index 50cc2789..cdeeaf80 100644 --- a/research_town/utils/string_mapper.py +++ b/research_town/utils/string_mapper.py @@ -7,28 +7,34 @@ def map_idea_list_to_str(ideas: List[Dict[str, str]]) -> str: result += map_idea_to_str(idea) return result + def map_idea_to_str(idea: Dict[str, str]) -> str: return f"{idea['content']}" + def map_paper_list_to_str(papers: List[Dict[str, str]]) -> str: result = '' for paper in papers: result += map_paper_to_str(paper) return result + def map_review_list_to_str(reviews: List[Dict[str, Union[int, str]]]) -> str: result = '' for review in reviews: result += map_review_to_str(review) return result + def map_paper_to_str(paper: Dict[str, str]) -> str: return f"Title: {paper['title']}\nAbstract: {paper['abstract']}" + def map_review_to_str(review: Dict[str, Union[int, str]]) -> str: score = review['review_score'] content = review['review_content'] - return f"Score: {score}\nContent: {content}" + return f'Score: {score}\nContent: {content}' + def map_message_to_str(message: Dict[str, str]) -> str: return f"Message from {message['agent_from_name']} to {message['agent_to_name']}\n" diff --git a/research_town/utils/tools.py b/research_town/utils/tools.py index 30b82e23..1ec731d6 100644 --- a/research_town/utils/tools.py +++ b/research_town/utils/tools.py @@ -7,20 +7,20 @@ def show_time() -> str: time_stamp = ( - "\033[1;31;40m[" - + str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")) - + "]\033[0m" + '\033[1;31;40m[' + + str(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')) + + ']\033[0m' ) return time_stamp def text_wrap(text: str) -> str: - return "\033[1;31;40m" + str(text) + "\033[0m" + return '\033[1;31;40m' + str(text) + '\033[0m' def write_to_json(data: Dict[str, Any], output_file: str) -> None: - with open(output_file, "w") as file: + with open(output_file, 'w') as file: json.dump(data, file, indent=4) @@ -30,19 +30,19 @@ def check_path(path: str) -> None: def count_entries_in_json(file_path: str) -> int: - with open(file_path, "r") as file: + with open(file_path, 'r') as file: data = json.load(file) return len(data) def clean_title(title: str) -> str: - cleaned_title = title.replace("\n", " ").strip() + cleaned_title = title.replace('\n', ' ').strip() cleaned_title = os.path.splitext(cleaned_title)[0] cleaned_title = ( - cleaned_title.replace(":", "") - .replace("- ", " ") - .replace("-", " ") - .replace("_", " ") + cleaned_title.replace(':', '') + .replace('- ', ' ') + .replace('-', ' ') + .replace('_', ' ') .title() ) diff --git a/tests/constants.py b/tests/constants.py index e8525fb6..a8fd446e 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -1,4 +1,3 @@ - from research_town.dbs import ( AgentAgentDiscussionLog, AgentPaperMetaReviewLog, @@ -12,50 +11,50 @@ ) paper_profile_A = PaperProfile( - title="A Survey on Machine Learning", - abstract="This paper surveys the field of machine learning.", + title='A Survey on Machine Learning', + abstract='This paper surveys the field of machine learning.', ) paper_profile_B = PaperProfile( - title="A Survey on Graph Neural Networks", - abstract="This paper surveys the field of graph neural networks.", + title='A Survey on Graph Neural Networks', + abstract='This paper surveys the field of graph neural networks.', ) agent_profile_A = AgentProfile( - name="Jiaxuan You", - bio="A researcher in the field of machine learning.", + name='Jiaxuan You', + bio='A researcher in the field of machine learning.', ) agent_profile_B = AgentProfile( - name="Rex Ying", - bio="A researcher in the field of GNN.", + name='Rex Ying', + bio='A researcher in the field of GNN.', ) research_idea_A = ResearchIdea( - content="A new idea", + content='A new idea', ) research_idea_B = ResearchIdea( - content="Another idea", + content='Another idea', ) research_insight_A = ResearchInsight( - content="A new insight", + content='A new insight', ) research_insight_B = ResearchInsight( - content="Another insight", + content='Another insight', ) research_paper_submission_A = ResearchPaperSubmission( - title="A Survey on Machine Learning", - abstract="This paper surveys the field of machine learning.", + title='A Survey on Machine Learning', + abstract='This paper surveys the field of machine learning.', ) research_paper_submission_B = ResearchPaperSubmission( - title="A Survey on Graph Neural Networks", - abstract="This paper surveys the field of graph neural networks.", + title='A Survey on Graph Neural Networks', + abstract='This paper surveys the field of graph neural networks.', ) @@ -64,7 +63,7 @@ paper_pk=paper_profile_A.pk, agent_pk=agent_profile_A.pk, review_score=5, - review_content="This paper is well-written.", + review_content='This paper is well-written.', ) agent_paper_meta_review_log = AgentPaperMetaReviewLog( @@ -72,14 +71,14 @@ paper_pk=paper_profile_B.pk, agent_pk=agent_profile_B.pk, decision=True, - meta_review="This paper is well-written.", + meta_review='This paper is well-written.', ) agent_paper_rebuttal_log = AgentPaperRebuttalLog( timestep=0, paper_pk=paper_profile_A.pk, agent_pk=agent_profile_A.pk, - rebuttal_content="I have revised the paper.", + rebuttal_content='I have revised the paper.', ) agent_agent_discussion_log = AgentAgentDiscussionLog( @@ -88,5 +87,5 @@ agent_from_name=agent_profile_A.name, agent_to_pk=agent_profile_B.pk, agent_to_name=agent_profile_B.name, - message="How about the idea of building a research town with language agents?" + message='How about the idea of building a research town with language agents?', ) diff --git a/tests/test_agents.py b/tests/test_agents.py index b7a7a7de..cbbbb9a4 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -19,28 +19,33 @@ def test_get_profile() -> None: research_agent = BaseResearchAgent( agent_profile=agent_profile_A, - model_name="together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1" + model_name='together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1', ) - assert research_agent.profile.name == "Jiaxuan You" - assert research_agent.profile.bio == "A researcher in the field of machine learning." + assert research_agent.profile.name == 'Jiaxuan You' + assert ( + research_agent.profile.bio == 'A researcher in the field of machine learning.' + ) + -@patch("research_town.utils.agent_prompter.model_prompting") +@patch('research_town.utils.agent_prompter.model_prompting') def test_find_collaborators(mock_model_prompting: MagicMock) -> None: mock_model_prompting.return_value = [ - "These are collaborators including Jure Leskovec, Rex Ying, Saining Xie, Kaiming He."] + 'These are collaborators including Jure Leskovec, Rex Ying, Saining Xie, Kaiming He.' + ] research_agent = BaseResearchAgent( agent_profile=agent_profile_A, - model_name="together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1" + model_name='together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1', ) collaborators = research_agent.find_collaborators( - paper=paper_profile_A, parameter=0.5, max_number=3) + paper=paper_profile_A, parameter=0.5, max_number=3 + ) assert isinstance(collaborators, list) assert len(collaborators) <= 3 -@patch("research_town.utils.agent_prompter.model_prompting") -@patch("research_town.utils.agent_prompter.get_related_papers") +@patch('research_town.utils.agent_prompter.model_prompting') +@patch('research_town.utils.agent_prompter.get_related_papers') def test_read_paper( mock_get_related_papers: MagicMock, mock_model_prompting: MagicMock, @@ -49,19 +54,19 @@ def test_read_paper( mock_model_prompting.side_effect = mock_prompting research_agent = BaseResearchAgent( agent_profile=agent_profile_A, - model_name="together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1" + model_name='together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1', ) research_insight = research_agent.read_paper( papers=[paper_profile_A, paper_profile_B], - domains=["machine learning", "graph neural network"], + domains=['machine learning', 'graph neural network'], ) assert len(research_insight) == 1 assert research_insight[0].pk is not None - assert research_insight[0].content == "Graph Neural Network" + assert research_insight[0].content == 'Graph Neural Network' -@patch("research_town.utils.agent_prompter.model_prompting") -@patch("research_town.utils.agent_prompter.get_related_papers") +@patch('research_town.utils.agent_prompter.model_prompting') +@patch('research_town.utils.agent_prompter.get_related_papers') def test_think_idea( mock_get_related_papers: MagicMock, mock_model_prompting: MagicMock, @@ -71,75 +76,72 @@ def test_think_idea( research_agent = BaseResearchAgent( agent_profile=agent_profile_A, - model_name="together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1" + model_name='together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1', ) research_ideas = research_agent.think_idea( insights=[research_insight_A, research_insight_B], ) assert len(research_ideas) == 2 assert research_ideas[0].pk is not None - assert research_ideas[0].content == "This is a research idea." + assert research_ideas[0].content == 'This is a research idea.' assert research_ideas[1].pk is not None - assert research_ideas[1].content == "This is a research idea." + assert research_ideas[1].content == 'This is a research idea.' -@patch("research_town.utils.agent_prompter.model_prompting") +@patch('research_town.utils.agent_prompter.model_prompting') def test_write_paper(mock_model_prompting: MagicMock) -> None: - mock_model_prompting.return_value = ["This is a paper abstract."] + mock_model_prompting.return_value = ['This is a paper abstract.'] research_agent = BaseResearchAgent( agent_profile=agent_profile_B, - model_name="together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1" + model_name='together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1', ) paper = research_agent.write_paper( ideas=[research_idea_A, research_idea_B], papers=[paper_profile_A, paper_profile_B], ) - assert paper.abstract == "This is a paper abstract." + assert paper.abstract == 'This is a paper abstract.' assert paper.pk is not None -@patch("research_town.utils.agent_prompter.model_prompting") +@patch('research_town.utils.agent_prompter.model_prompting') def test_write_paper_review(mock_model_prompting: MagicMock) -> None: mock_model_prompting.side_effect = mock_prompting research_agent = BaseResearchAgent( agent_profile=agent_profile_A, - model_name="together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1" + model_name='together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1', ) review = research_agent.write_paper_review(paper=paper_profile_A) assert review.review_score == 2 - assert review.review_content == "This is a paper review for MambaOut." + assert review.review_content == 'This is a paper review for MambaOut.' -@patch("research_town.utils.agent_prompter.model_prompting") +@patch('research_town.utils.agent_prompter.model_prompting') def test_write_paper_meta_review(mock_model_prompting: MagicMock) -> None: - mock_model_prompting.return_value = ["Accept. This is a good paper."] + mock_model_prompting.return_value = ['Accept. This is a good paper.'] research_agent = BaseResearchAgent( agent_profile=agent_profile_A, - model_name="together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1" + model_name='together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1', ) reviews = research_agent.write_paper_review(paper=paper_profile_A) - meta_review= research_agent.write_paper_meta_review( - paper=paper_profile_A, - reviews=[reviews] + meta_review = research_agent.write_paper_meta_review( + paper=paper_profile_A, reviews=[reviews] ) assert meta_review.decision is True - assert meta_review.meta_review == "Accept. This is a good paper." + assert meta_review.meta_review == 'Accept. This is a good paper.' assert meta_review.timestep >= 0 assert meta_review.pk is not None -@patch("research_town.utils.agent_prompter.model_prompting") +@patch('research_town.utils.agent_prompter.model_prompting') def test_write_rebuttal(mock_model_prompting: MagicMock) -> None: - mock_model_prompting.return_value = [ - "This is a paper rebuttal." - ] + mock_model_prompting.return_value = ['This is a paper rebuttal.'] research_agent = BaseResearchAgent( agent_profile=agent_profile_A, - model_name="together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1" + model_name='together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1', ) review = research_agent.write_paper_review(paper=paper_profile_A) rebuttal = research_agent.write_rebuttal( @@ -149,20 +151,24 @@ def test_write_rebuttal(mock_model_prompting: MagicMock) -> None: assert isinstance(rebuttal, AgentPaperRebuttalLog) if rebuttal.rebuttal_content is not None: assert len(rebuttal.rebuttal_content) > 0 - assert rebuttal.rebuttal_content == "This is a paper rebuttal." + assert rebuttal.rebuttal_content == 'This is a paper rebuttal.' -@patch("research_town.utils.agent_prompter.model_prompting") + +@patch('research_town.utils.agent_prompter.model_prompting') def test_discuss(mock_model_prompting: MagicMock) -> None: mock_model_prompting.return_value = [ - "I believe in the potential of using automous agents to simulate the current research pipeline." + 'I believe in the potential of using automous agents to simulate the current research pipeline.' ] research_agent = BaseResearchAgent( agent_profile=agent_profile_A, - model_name="together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1" + model_name='together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1', ) response = research_agent.discuss(agent_agent_discussion_log) - assert response.message == "I believe in the potential of using automous agents to simulate the current research pipeline." + assert ( + response.message + == 'I believe in the potential of using automous agents to simulate the current research pipeline.' + ) assert response.agent_to_pk is not None assert response.agent_from_pk is not None assert response.timestep >= 0 diff --git a/tests/test_dbs.py b/tests/test_dbs.py index fb6479b4..0d6742f2 100644 --- a/tests/test_dbs.py +++ b/tests/test_dbs.py @@ -1,4 +1,3 @@ - from beartype.typing import Any, Dict, Optional from research_town.dbs.agent_db import AgentProfile, AgentProfileDB @@ -10,23 +9,31 @@ EnvLogDB, ) from research_town.dbs.paper_db import PaperProfile, PaperProfileDB -from research_town.dbs.progress_db import ( - ResearchIdea, - ResearchProgressDB, -) +from research_town.dbs.progress_db import ResearchIdea, ResearchProgressDB -def test_envlogdb_basic()->None: +def test_envlogdb_basic() -> None: db = EnvLogDB() - review_log = AgentPaperReviewLog(paper_pk="paper1", agent_pk="agent1", review_score=5, review_content="Good paper") - rebuttal_log = AgentPaperRebuttalLog(paper_pk="paper1", agent_pk="agent1", rebuttal_content="I disagree with the review") - meta_review_log = AgentPaperMetaReviewLog(paper_pk="paper1", agent_pk="agent1", decision=True, meta_review="Accept") + review_log = AgentPaperReviewLog( + paper_pk='paper1', + agent_pk='agent1', + review_score=5, + review_content='Good paper', + ) + rebuttal_log = AgentPaperRebuttalLog( + paper_pk='paper1', + agent_pk='agent1', + rebuttal_content='I disagree with the review', + ) + meta_review_log = AgentPaperMetaReviewLog( + paper_pk='paper1', agent_pk='agent1', decision=True, meta_review='Accept' + ) discussion_log = AgentAgentDiscussionLog( - agent_from_pk="agent1", + agent_from_pk='agent1', agent_from_name='Rex Ying', - agent_to_pk="agent2", + agent_to_pk='agent2', agent_to_name='John Doe', - message="Let's discuss this paper" + message="Let's discuss this paper", ) db.add(review_log) @@ -34,172 +41,181 @@ def test_envlogdb_basic()->None: db.add(meta_review_log) db.add(discussion_log) - new_log = AgentPaperReviewLog(paper_pk="paper2", agent_pk="agent2", review_score=4, review_content="Interesting paper") + new_log = AgentPaperReviewLog( + paper_pk='paper2', + agent_pk='agent2', + review_score=4, + review_content='Interesting paper', + ) db.add(new_log) - assert new_log.dict() in db.data["AgentPaperReviewLog"] + assert new_log.dict() in db.data['AgentPaperReviewLog'] - conditions: Dict[str, Any] = {"paper_pk": "paper1"} + conditions: Dict[str, Any] = {'paper_pk': 'paper1'} results = db.get(AgentPaperReviewLog, **conditions) assert len(results) == 1 - assert results[0].review_content == "Good paper" + assert results[0].review_content == 'Good paper' - updates = {"review_score": 3, "review_content": "Decent paper"} - updated_count = db.update(AgentPaperReviewLog, {"paper_pk": "paper1"}, updates) + updates = {'review_score': 3, 'review_content': 'Decent paper'} + updated_count = db.update(AgentPaperReviewLog, {'paper_pk': 'paper1'}, updates) assert updated_count == 2 updated_log = db.get(AgentPaperReviewLog, **conditions)[0] assert updated_log.review_score == 3 - assert updated_log.review_content == "Decent paper" + assert updated_log.review_content == 'Decent paper' deleted_count = db.delete(AgentPaperReviewLog, **conditions) assert deleted_count == 1 results = db.get(AgentPaperReviewLog, **conditions) assert len(results) == 0 - file_name = "./data/dbs/test_env_logs_db.json" + file_name = './data/dbs/test_env_logs_db.json' db.save_to_file(file_name) new_db = EnvLogDB() new_db.load_from_file(file_name) - assert len(new_db.data["AgentPaperReviewLog"]) == 1 - assert len(new_db.data["AgentPaperRebuttalLog"]) == 1 - assert len(new_db.data["AgentPaperMetaReviewLog"]) == 1 - assert len(new_db.data["AgentAgentDiscussionLog"]) == 1 - assert new_db.data["AgentPaperReviewLog"][0]["review_content"] == "Interesting paper" - + assert len(new_db.data['AgentPaperReviewLog']) == 1 + assert len(new_db.data['AgentPaperRebuttalLog']) == 1 + assert len(new_db.data['AgentPaperMetaReviewLog']) == 1 + assert len(new_db.data['AgentAgentDiscussionLog']) == 1 + assert ( + new_db.data['AgentPaperReviewLog'][0]['review_content'] == 'Interesting paper' + ) -def test_agentprofiledb_basic()->None: +def test_agentprofiledb_basic() -> None: db = AgentProfileDB() - agent1 = AgentProfile(name="John Doe", bio="Researcher in AI", institute="AI Institute") - agent2 = AgentProfile(name="Jane Smith", bio="Expert in NLP", institute="NLP Lab") + agent1 = AgentProfile( + name='John Doe', bio='Researcher in AI', institute='AI Institute' + ) + agent2 = AgentProfile(name='Jane Smith', bio='Expert in NLP', institute='NLP Lab') db.add(agent1) db.add(agent2) - agent3 = AgentProfile(name="Alice Johnson", bio="Data Scientist", institute="Data Lab") + agent3 = AgentProfile( + name='Alice Johnson', bio='Data Scientist', institute='Data Lab' + ) db.add(agent3) assert agent3.pk in db.data - assert db.data[agent3.pk].name == "Alice Johnson" + assert db.data[agent3.pk].name == 'Alice Johnson' - - updates = {"bio": "Senior Researcher in AI"} + updates = {'bio': 'Senior Researcher in AI'} updates_with_optional: Dict[str, Optional[str]] = {k: v for k, v in updates.items()} success = db.update(agent1.pk, updates_with_optional) assert success - assert db.data[agent1.pk].bio == "Senior Researcher in AI" + assert db.data[agent1.pk].bio == 'Senior Researcher in AI' - success = db.update("non-existing-pk", {"bio": "New Bio"}) + success = db.update('non-existing-pk', {'bio': 'New Bio'}) assert not success success = db.delete(agent1.pk) assert success assert agent1.pk not in db.data - success = db.delete("non-existing-pk") + success = db.delete('non-existing-pk') assert not success - conditions: Dict[str, Any] = {"name": "Jane Smith"} + conditions: Dict[str, Any] = {'name': 'Jane Smith'} results = db.get(**conditions) assert len(results) == 1 - assert results[0].name == "Jane Smith" + assert results[0].name == 'Jane Smith' - file_name = "./data/dbs/test_agent_profile_db.json" + file_name = './data/dbs/test_agent_profile_db.json' db.save_to_file(file_name) new_db = AgentProfileDB() new_db.load_from_file(file_name) new_data = { - "2024-05-29": [ + '2024-05-29': [ { - "pk": agent1.pk, - "name": "John Doe", - "bio": "Updated bio", - "collaborators": [], - "institute": "AI Institute" + 'pk': agent1.pk, + 'name': 'John Doe', + 'bio': 'Updated bio', + 'collaborators': [], + 'institute': 'AI Institute', }, { - "pk": "new-pk", - "name": "New Agent", - "bio": "New agent bio", - "collaborators": [], - "institute": "New Institute" - } + 'pk': 'new-pk', + 'name': 'New Agent', + 'bio': 'New agent bio', + 'collaborators': [], + 'institute': 'New Institute', + }, ] } db.update_db(new_data) - assert db.data[agent1.pk].bio == "Updated bio" - assert "new-pk" in db.data - assert db.data["new-pk"].name == "New Agent" + assert db.data[agent1.pk].bio == 'Updated bio' + assert 'new-pk' in db.data + assert db.data['new-pk'].name == 'New Agent' -def test_paperprofiledb_basic()->None: +def test_paperprofiledb_basic() -> None: db = PaperProfileDB() paper1 = PaperProfile( - title="Sample Paper 1", - abstract="This is the abstract for paper 1", - authors=["Author A", "Author B"], - url="http://example.com/paper1", + title='Sample Paper 1', + abstract='This is the abstract for paper 1', + authors=['Author A', 'Author B'], + url='http://example.com/paper1', timestamp=1617181723, - keywords=["AI", "ML"], - domain="Computer Science", - citation_count=10 + keywords=['AI', 'ML'], + domain='Computer Science', + citation_count=10, ) paper2 = PaperProfile( - title="Sample Paper 2", - abstract="This is the abstract for paper 2", - authors=["Author C"], - url="http://example.com/paper2", + title='Sample Paper 2', + abstract='This is the abstract for paper 2', + authors=['Author C'], + url='http://example.com/paper2', timestamp=1617181756, - keywords=["Quantum Computing"], - domain="Physics", - citation_count=5 + keywords=['Quantum Computing'], + domain='Physics', + citation_count=5, ) db.add_paper(paper1) db.add_paper(paper2) new_paper = PaperProfile( - title="Sample Paper 3", - abstract="This is the abstract for paper 3", - authors=["Author D"], - url="http://example.com/paper3", + title='Sample Paper 3', + abstract='This is the abstract for paper 3', + authors=['Author D'], + url='http://example.com/paper3', timestamp=1617181789, - keywords=["Blockchain"], - domain="Computer Science", - citation_count=2 + keywords=['Blockchain'], + domain='Computer Science', + citation_count=2, ) db.add_paper(new_paper) assert new_paper.pk in db.data paper = db.get_paper(paper1.pk) assert paper is not None - assert paper.title == "Sample Paper 1" + assert paper.title == 'Sample Paper 1' - updates:Dict[str, Any] = {"title": "Updated Sample Paper 1", "citation_count": 15} + updates: Dict[str, Any] = {'title': 'Updated Sample Paper 1', 'citation_count': 15} result = db.update_paper(paper1.pk, updates) assert result - updated_paper:Optional[PaperProfile] = db.get_paper(paper1.pk) + updated_paper: Optional[PaperProfile] = db.get_paper(paper1.pk) assert updated_paper is not None - assert updated_paper.title == "Updated Sample Paper 1" + assert updated_paper.title == 'Updated Sample Paper 1' assert updated_paper.citation_count == 15 result = db.delete_paper(paper2.pk) assert result assert db.get_paper(paper2.pk) is None - domain:Dict[str, Any] = {"domain": "Computer Science"} + domain: Dict[str, Any] = {'domain': 'Computer Science'} results = db.query_papers(**domain) assert len(results) == 2 - assert results[0].title == "Updated Sample Paper 1" - assert results[1].title == "Sample Paper 3" + assert results[0].title == 'Updated Sample Paper 1' + assert results[1].title == 'Sample Paper 3' - file_name = "./data/dbs/test_paper_profile_db.json" + file_name = './data/dbs/test_paper_profile_db.json' db.save_to_file(file_name) new_db = PaperProfileDB() @@ -207,45 +223,46 @@ def test_paperprofiledb_basic()->None: assert len(new_db.data) == 2 assert paper1.pk in new_db.data - assert new_db.data[paper1.pk].title == "Updated Sample Paper 1" + assert new_db.data[paper1.pk].title == 'Updated Sample Paper 1' -def test_researchprogressdb_basic()->None: + +def test_researchprogressdb_basic() -> None: db = ResearchProgressDB() - idea1 = ResearchIdea(content="Idea for a new AI algorithm") - idea2 = ResearchIdea(content="Quantum computing research plan") + idea1 = ResearchIdea(content='Idea for a new AI algorithm') + idea2 = ResearchIdea(content='Quantum computing research plan') db.add(idea1) db.add(idea2) - new_idea = ResearchIdea(content="Blockchain research proposal") + new_idea = ResearchIdea(content='Blockchain research proposal') db.add(new_idea) - assert new_idea.dict() in db.data["ResearchIdea"] + assert new_idea.dict() in db.data['ResearchIdea'] - content:Dict[str, Any] = {"content": "Idea for a new AI algorithm"} + content: Dict[str, Any] = {'content': 'Idea for a new AI algorithm'} results = db.get(ResearchIdea, **content) assert len(results) == 1 - assert results[0].content == "Idea for a new AI algorithm" - + assert results[0].content == 'Idea for a new AI algorithm' - updates = {"content": "Updated idea content"} - updated_count = db.update(ResearchIdea, {"content": "Idea for a new AI algorithm"}, updates) + updates = {'content': 'Updated idea content'} + updated_count = db.update( + ResearchIdea, {'content': 'Idea for a new AI algorithm'}, updates + ) assert updated_count == 1 - content2:Dict[str, Any] = {"content": "Updated idea content"} + content2: Dict[str, Any] = {'content': 'Updated idea content'} updated_results = db.get(ResearchIdea, **content2) assert len(updated_results) == 1 - assert updated_results[0].content == "Updated idea content" - + assert updated_results[0].content == 'Updated idea content' - content3:Dict[str, Any] = {"content": "Quantum computing research plan"} + content3: Dict[str, Any] = {'content': 'Quantum computing research plan'} deleted_count = db.delete(ResearchIdea, **content3) assert deleted_count == 1 remaining_results = db.get(ResearchIdea, **content3) assert len(remaining_results) == 0 - file_name = "./data/dbs/test_research_progress_db.json" + file_name = './data/dbs/test_research_progress_db.json' db.save_to_file(file_name) new_db = ResearchProgressDB() new_db.load_from_file(file_name) - assert len(new_db.data["ResearchIdea"]) == 2 + assert len(new_db.data['ResearchIdea']) == 2 diff --git a/tests/test_envs.py b/tests/test_envs.py index 2077cd1e..2e0fd9b1 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -5,18 +5,13 @@ PaperRebuttalMultiAgentEnv, PaperSubmissionMultiAgentEnvironment, ) -from tests.constants import ( - agent_profile_A, - agent_profile_B, - paper_profile_A, -) +from tests.constants import agent_profile_A, agent_profile_B, paper_profile_A from tests.utils import mock_papers -@patch("research_town.utils.agent_prompter.model_prompting") +@patch('research_town.utils.agent_prompter.model_prompting') def test_paper_rebuttal_env(mock_model_prompting: MagicMock) -> None: - mock_model_prompting.return_value = [ - "Paper Rebuttal Environment."] + mock_model_prompting.return_value = ['Paper Rebuttal Environment.'] agent_db = AgentProfileDB() paper_db = PaperProfileDB() env_db = EnvLogDB() @@ -24,15 +19,12 @@ def test_paper_rebuttal_env(mock_model_prompting: MagicMock) -> None: agent_profiles=[agent_profile_A, agent_profile_B], agent_db=agent_db, paper_db=paper_db, - env_db=env_db + env_db=env_db, ) submission = paper_profile_A env.initialize_submission(submission) - env.assign_roles({ - agent_profile_A.pk: "author", - agent_profile_B.pk: "reviewer"} - ) + env.assign_roles({agent_profile_A.pk: 'author', agent_profile_B.pk: 'reviewer'}) while not env.terminated: env.step() @@ -40,31 +32,31 @@ def test_paper_rebuttal_env(mock_model_prompting: MagicMock) -> None: assert isinstance(env.reviews, list) assert len(env.reviews) > 0 assert isinstance(env.decision, str) - assert env.decision in ["accept", "reject", "boarderline"] + assert env.decision in ['accept', 'reject', 'boarderline'] assert isinstance(env.rebuttals, list) assert len(env.rebuttals) > 0 -@patch("research_town.utils.agent_prompter.model_prompting") -@patch("research_town.utils.agent_prompter.get_related_papers") +@patch('research_town.utils.agent_prompter.model_prompting') +@patch('research_town.utils.agent_prompter.get_related_papers') def test_paper_submission_env( mock_get_related_papers: MagicMock, mock_model_prompting: MagicMock, ) -> None: mock_get_related_papers.side_effect = mock_papers - mock_model_prompting.return_value = ["This is a paper."] + mock_model_prompting.return_value = ['This is a paper.'] agent_db = AgentProfileDB() paper_db = PaperProfileDB() env_db = EnvLogDB() env = PaperSubmissionMultiAgentEnvironment( agent_profiles=[agent_profile_A], task={ - "Survey on Machine Learning": "This paper surveys the field of machine learning." + 'Survey on Machine Learning': 'This paper surveys the field of machine learning.' }, agent_db=agent_db, paper_db=paper_db, - env_db=env_db + env_db=env_db, ) env.step() assert env.paper.abstract is not None - assert env.paper.abstract == "This is a paper." + assert env.paper.abstract == 'This is a paper.' diff --git a/tests/test_eval.py b/tests/test_eval.py index d5a82cf9..ea89f5b6 100644 --- a/tests/test_eval.py +++ b/tests/test_eval.py @@ -9,11 +9,11 @@ ReviewQualityEvaluator, ) -idea = "The idea behind Mamba is to improve upon existing foundation models in deep learning, which typically rely on the Transformer architecture and its attention mechanism. While subquadratic-time architectures like linear attention, gated convolution, recurrent models, and structured state space models (SSMs) have been developed to address the inefficiency of Transformers on long sequences, they have not matched the performance of attention-based models in key areas such as language processing. Mamba addresses the shortcomings of these models by enabling content-based reasoning and making several key improvements: Adaptive SSM Parameters: By allowing SSM parameters to be functions of the input, Mamba effectively handles discrete modalities. This enables the model to selectively propagate or forget information along the sequence based on the current token.Parallel Recurrent Algorithm: Despite the changes preventing the use of efficient convolutions, Mamba employs a hardware-aware parallel algorithm in recurrent mode to maintain efficiency.Simplified Architecture: Mamba integrates these selective SSMs into a streamlined neural network architecture that does not rely on attention or MLP blocks." +idea = 'The idea behind Mamba is to improve upon existing foundation models in deep learning, which typically rely on the Transformer architecture and its attention mechanism. While subquadratic-time architectures like linear attention, gated convolution, recurrent models, and structured state space models (SSMs) have been developed to address the inefficiency of Transformers on long sequences, they have not matched the performance of attention-based models in key areas such as language processing. Mamba addresses the shortcomings of these models by enabling content-based reasoning and making several key improvements: Adaptive SSM Parameters: By allowing SSM parameters to be functions of the input, Mamba effectively handles discrete modalities. This enables the model to selectively propagate or forget information along the sequence based on the current token.Parallel Recurrent Algorithm: Despite the changes preventing the use of efficient convolutions, Mamba employs a hardware-aware parallel algorithm in recurrent mode to maintain efficiency.Simplified Architecture: Mamba integrates these selective SSMs into a streamlined neural network architecture that does not rely on attention or MLP blocks.' -trend = "The current research trend in foundation models (FMs) involves developing large models that are pretrained on extensive datasets and then adapted for various downstream tasks. These FMs are primarily based on sequence models, which process sequences of inputs across different domains such as language, images, speech, audio, time series, and genomics. The predominant architecture for these models is the Transformer, which utilizes self-attention mechanisms. The strength of self-attention lies in its ability to handle complex data by routing information densely within a context window. However, this comes with significant limitations: difficulty in modeling outside of a finite context window and quadratic scaling with respect to window length.\n Efforts to create more efficient variants of attention have been extensive but often compromise the effectiveness that self-attention provides. As a result, no alternative has yet matched the empirical success of Transformers across various domains.Recently, structured state space models (SSMs) have emerged as a promising alternative. These models combine elements of recurrent neural networks (RNNs) and convolutional neural networks (CNNs), drawing from classical state space models. SSMs can be computed efficiently, either as recurrences or convolutions, and they scale linearly or near-linearly with sequence length. They also have mechanisms for modeling long-range dependencies, particularly excelling in benchmarks like the Long Range Arena.\nDifferent variants of SSMs have been successful in continuous signal data domains such as audio and vision. However, they have not been as effective in handling discrete and information-dense data, such as text, highlighting an area for further research and development." +trend = 'The current research trend in foundation models (FMs) involves developing large models that are pretrained on extensive datasets and then adapted for various downstream tasks. These FMs are primarily based on sequence models, which process sequences of inputs across different domains such as language, images, speech, audio, time series, and genomics. The predominant architecture for these models is the Transformer, which utilizes self-attention mechanisms. The strength of self-attention lies in its ability to handle complex data by routing information densely within a context window. However, this comes with significant limitations: difficulty in modeling outside of a finite context window and quadratic scaling with respect to window length.\n Efforts to create more efficient variants of attention have been extensive but often compromise the effectiveness that self-attention provides. As a result, no alternative has yet matched the empirical success of Transformers across various domains.Recently, structured state space models (SSMs) have emerged as a promising alternative. These models combine elements of recurrent neural networks (RNNs) and convolutional neural networks (CNNs), drawing from classical state space models. SSMs can be computed efficiently, either as recurrences or convolutions, and they scale linearly or near-linearly with sequence length. They also have mechanisms for modeling long-range dependencies, particularly excelling in benchmarks like the Long Range Arena.\nDifferent variants of SSMs have been successful in continuous signal data domains such as audio and vision. However, they have not been as effective in handling discrete and information-dense data, such as text, highlighting an area for further research and development.' -paper_title = "Mamba: Linear-Time Sequence Modeling with Selective State Spaces" +paper_title = 'Mamba: Linear-Time Sequence Modeling with Selective State Spaces' paper_abstract = "Foundation models, now powering most of the exciting applications in deep learning, are almost universally based on the Transformer architecture and its core attention module. Many subquadratic-time architectures such as linear attention, gated convolution and recurrent models, and structured state space models (SSMs) have been developed to address Transformers' computational inefficiency on long sequences, but they have not performed as well as attention on important modalities such as language. We identify that a key weakness of such models is their inability to perform content-based reasoning, and make several improvements. First, simply letting the SSM parameters be functions of the input addresses their weakness with discrete modalities, allowing the model to selectively propagate or forget information along the sequence length dimension depending on the current token. Second, even though this change prevents the use of efficient convolutions, we design a hardware-aware parallel algorithm in recurrent mode. We integrate these selective SSMs into a simplified end-to-end neural network architecture without attention or even MLP blocks (Mamba). Mamba enjoys fast inference (5X higher throughput than Transformers) and linear scaling in sequence length, and its performance improves on real data up to million-length sequences. As a general sequence model backbone, Mamba achieves state-of-the-art performance across several modalities such as language, audio, and genomics. On language modeling, our Mamba-3B model outperforms Transformers of the same size and matches Transformers twice its size, both in pretraining and downstream evaluation." @@ -61,7 +61,7 @@ Confidence: 5: You are absolutely certain about your assessment. You are very familiar with the related work and checked the math/other details carefully. Code Of Conduct: Yes """ -review2 =""" +review2 = """ Summary: The paper proposes a new class of selective state space models (SSMs) for sequence modeling that achieves Transformer-quality performance while scaling linearly in sequence length. The paper addresses the key problem in SSMs for selecting data by selecting particular inputs. The paper presents a hardware-aware algorithm that computes the model recurrently with a scan instead of convolution, avoiding materializing the expanded state to reduce memory usage. This results in faster computation than previous methods. @@ -90,7 +90,7 @@ Code Of Conduct: Yes """ -review3= """ +review3 = """ Summary: This paper upgrades S4 by making the token mixing matrix data-dependent and introduces the Mamba structure. On the other hand, although the use of FFT is not possible, the authors provide a linear algorithm for computation, resulting in linear computational complexity. The effectiveness of the proposed method is validated on multiple datasets. @@ -157,60 +157,95 @@ """ -@pytest.fixture(params=["gpt-4o", "together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1"]) +@pytest.fixture(params=['gpt-4o', 'together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1']) def model_name(request: pytest.FixtureRequest) -> Any: return request.param -@pytest.mark.parametrize("use_mock", [True]) -def test_evaluator_eval_idea(use_mock:bool, model_name: str) -> None: - evaluator = IdeaQualityEvaluator(model_name= model_name) - input_dict = {'idea': idea, 'trend': trend,'pk':0} + +@pytest.mark.parametrize('use_mock', [True]) +def test_evaluator_eval_idea(use_mock: bool, model_name: str) -> None: + evaluator = IdeaQualityEvaluator(model_name=model_name) + input_dict = {'idea': idea, 'trend': trend, 'pk': 0} if use_mock: - with patch("research_town.utils.eval_prompter.model_prompting", MagicMock(return_value=[ - "Overall Score=86. Dimension Scores=[9, 8, 9, 9, 8, 8, 8, 9, 8, 8]." - ])): + with patch( + 'research_town.utils.eval_prompter.model_prompting', + MagicMock( + return_value=[ + 'Overall Score=86. Dimension Scores=[9, 8, 9, 9, 8, 8, 8, 9, 8, 8].' + ] + ), + ): evals_output = evaluator.eval(**input_dict) assert evals_output is not None - assert evals_output.overall_score == 86,f"overall score of idea (mock) should be 86, but it's {evals_output.overall_score}" + assert ( + evals_output.overall_score == 86 + ), f"overall score of idea (mock) should be 86, but it's {evals_output.overall_score}" else: evals_output = evaluator.eval(**input_dict) assert evals_output is not None - assert evals_output.overall_score>=0 and evals_output.overall_score<=100,f"overall score of idea should be an Int between 0 and 100, but it's {evals_output.overall_score}" + assert ( + evals_output.overall_score >= 0 and evals_output.overall_score <= 100 + ), f"overall score of idea should be an Int between 0 and 100, but it's {evals_output.overall_score}" -@pytest.mark.parametrize("use_mock", [True]) -def test_evaluator_eval_paper(use_mock:bool,model_name: str) -> None: - paper = {'title': paper_title, 'abstract':paper_abstract} +@pytest.mark.parametrize('use_mock', [True]) +def test_evaluator_eval_paper(use_mock: bool, model_name: str) -> None: + paper = {'title': paper_title, 'abstract': paper_abstract} - input_dict = {'idea': idea, 'paper': paper,'pk':0} + input_dict = {'idea': idea, 'paper': paper, 'pk': 0} evaluator = PaperQualityEvaluator(model_name=model_name) if use_mock: - with patch("research_town.utils.eval_prompter.model_prompting", MagicMock(return_value=[ - "Overall Score=86. Dimension Scores=[9, 8, 9, 9, 8, 8, 8, 9, 8, 8]." - ])): + with patch( + 'research_town.utils.eval_prompter.model_prompting', + MagicMock( + return_value=[ + 'Overall Score=86. Dimension Scores=[9, 8, 9, 9, 8, 8, 8, 9, 8, 8].' + ] + ), + ): evals_output = evaluator.eval(**input_dict) assert evals_output is not None - assert evals_output.overall_score == 86,f"overall score of paper (mock) should be 86, but it's {evals_output.overall_score}" + assert ( + evals_output.overall_score == 86 + ), f"overall score of paper (mock) should be 86, but it's {evals_output.overall_score}" else: evals_output = evaluator.eval(**input_dict) assert evals_output is not None - assert evals_output.overall_score>=0 and evals_output.overall_score<=100,f"overall score of paper should be an Int between 0 and 100, but it's {evals_output.overall_score}" + assert ( + evals_output.overall_score >= 0 and evals_output.overall_score <= 100 + ), f"overall score of paper should be an Int between 0 and 100, but it's {evals_output.overall_score}" -@pytest.mark.parametrize("use_mock", [True]) -def test_evaluator_eval_review(use_mock:bool,model_name: str) -> None: - paper = {'title': paper_title, 'abstract':paper_abstract} +@pytest.mark.parametrize('use_mock', [True]) +def test_evaluator_eval_review(use_mock: bool, model_name: str) -> None: + paper = {'title': paper_title, 'abstract': paper_abstract} reviews = [review1, review2, review3, review4] - input_dict = {'idea': idea, 'trend': trend, 'paper': paper,'pk':0,'review':reviews,'decision':'Reject'} + input_dict = { + 'idea': idea, + 'trend': trend, + 'paper': paper, + 'pk': 0, + 'review': reviews, + 'decision': 'Reject', + } evaluator = ReviewQualityEvaluator(model_name=model_name) if use_mock: - with patch("research_town.utils.eval_prompter.model_prompting", MagicMock(return_value=[ - "Overall Score=86. Dimension Scores=[9, 8, 9, 9, 8, 8, 8, 9, 8, 8]." - ])): + with patch( + 'research_town.utils.eval_prompter.model_prompting', + MagicMock( + return_value=[ + 'Overall Score=86. Dimension Scores=[9, 8, 9, 9, 8, 8, 8, 9, 8, 8].' + ] + ), + ): evals_output = evaluator.eval(**input_dict) assert evals_output is not None - assert evals_output.overall_score == 86,f"overall score of paper (mock) shoud be 86, but it's {evals_output.overall_score}" + assert ( + evals_output.overall_score == 86 + ), f"overall score of paper (mock) should be 86, but it's {evals_output.overall_score}" else: evals_output = evaluator.eval(**input_dict) assert evals_output is not None - assert evals_output.overall_score>=0 and evals_output.overall_score<=100,f"overall score of paper should be an Int between 0 and 100, but it's {evals_output.overall_score}" + assert ( + evals_output.overall_score >= 0 and evals_output.overall_score <= 100 + ), f"overall score of paper should be an Int between 0 and 100, but it's {evals_output.overall_score}" diff --git a/tests/test_model_call.py b/tests/test_model_call.py index 03c1f8a7..33820fd1 100644 --- a/tests/test_model_call.py +++ b/tests/test_model_call.py @@ -3,16 +3,23 @@ def test_openai_call() -> None: # supported by OPENAI_API_KEY - prompt = "Here is a high-level summarized insight of a research field Machine Learning. " - response = model_prompting("gpt-3.5-turbo", prompt) + prompt = ( + 'Here is a high-level summarized insight of a research field Machine Learning. ' + ) + response = model_prompting('gpt-3.5-turbo', prompt) assert response is not None assert len(response) > 0 assert len(response[0]) > 0 + def test_togetherai_mistral_call() -> None: # supported by TOGETHERAI_API_KEY - prompt = "Here is a high-level summarized insight of a research field Machine Learning. " - response = model_prompting("together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1", prompt) + prompt = ( + 'Here is a high-level summarized insight of a research field Machine Learning. ' + ) + response = model_prompting( + 'together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1', prompt + ) assert response is not None assert len(response) > 0 assert len(response[0]) > 0 diff --git a/tests/test_serialize.py b/tests/test_serialize.py index 8c317624..61876b88 100644 --- a/tests/test_serialize.py +++ b/tests/test_serialize.py @@ -6,7 +6,7 @@ def test_serializer() -> None: research_agent = BaseResearchAgent( agent_profile=agent_profile_A, - model_name="together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1" + model_name='together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1', ) research_agent_serialized = Serializer.serialize(research_agent) research_agent_deserialized = Serializer.deserialize(research_agent_serialized) diff --git a/tests/utils.py b/tests/utils.py index de1eb04f..b028b0c1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,18 +4,22 @@ def mock_papers(corpus: List[str], query: str, num: int) -> List[str]: return corpus[:num] + def mock_prompting( llm_model: str, prompt: str, - return_num: Optional[int]=2, - max_tokens: Optional[int]=512, + return_num: Optional[int] = 2, + max_tokens: Optional[int] = 512, ) -> List[str]: - if "Please give some reviews based on the following inputs and external data." in prompt: - return ["This is a paper review for MambaOut."] - elif "Please provide a score for the following reviews." in prompt: - return ["2"] - elif "Please give me 3 to 5 novel ideas and insights" in prompt: - return ["This is a research idea."] - elif "summarize the keywords" in prompt: - return ["Graph Neural Network"] - return ["Default response"] + if ( + 'Please give some reviews based on the following inputs and external data.' + in prompt + ): + return ['This is a paper review for MambaOut.'] + elif 'Please provide a score for the following reviews.' in prompt: + return ['2'] + elif 'Please give me 3 to 5 novel ideas and insights' in prompt: + return ['This is a research idea.'] + elif 'summarize the keywords' in prompt: + return ['Graph Neural Network'] + return ['Default response']