From 3dcf3ad1b8bd929fa2ca5f615103a560168cfd3d Mon Sep 17 00:00:00 2001 From: Kunlun-Zhu Date: Wed, 29 May 2024 20:31:29 +0800 Subject: [PATCH 1/4] add beartype --- pyproject.toml | 1 + research_town/envs/env_paper_rebuttal.py | 5 ++++- research_town/envs/env_paper_submission.py | 3 ++- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b2a964a0..ec58cedd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ types-requests = "^2.31.0" torch = ">=2.1.0 <2.3.0" transformers = "^4.40.0" litellm = "^1.0.0" +beartype = "0.18.5" [tool.poetry.group.dev.dependencies] pre-commit = "*" diff --git a/research_town/envs/env_paper_rebuttal.py b/research_town/envs/env_paper_rebuttal.py index 0ee6ca54..36c80015 100644 --- a/research_town/envs/env_paper_rebuttal.py +++ b/research_town/envs/env_paper_rebuttal.py @@ -1,5 +1,5 @@ from typing import Dict, List, Tuple - +from beartype import beartype from ..dbs import ( AgentPaperMetaReviewLog, AgentPaperRebuttalLog, @@ -34,14 +34,17 @@ def __init__(self, self.paper_db = paper_db self.env_db = env_db + @beartype def assign_roles(self, role_dict: Dict[str, str]) -> None: for index, agent_profile in enumerate(self.agent_profiles): if role_dict[agent_profile.pk] == "reviewer": self.reviewer_mask[index] = True + @beartype def initialize_submission(self, paper_profile: PaperProfile) -> None: self.submission = paper_profile + @beartype def submit_decision(self, decision_dict: Dict[str, Tuple[bool, str]]) -> None: decision_count = {"accept": 0, "reject": 0} for _, decision in decision_dict.items(): diff --git a/research_town/envs/env_paper_submission.py b/research_town/envs/env_paper_submission.py index 12c49790..c4cd89ac 100644 --- a/research_town/envs/env_paper_submission.py +++ b/research_town/envs/env_paper_submission.py @@ -1,5 +1,5 @@ from typing import Dict, List - +from beartype import beartype from ..agents.agent_base import BaseResearchAgent from ..dbs import ( AgentProfile, @@ -70,6 +70,7 @@ def step(self) -> None: self.submit_paper(abstracts) self.terminated = True + @beartype def submit_paper(self, paper_dict: Dict[str, PaperProfile]) -> None: # TODO: clarify paper submission for _, paper in paper_dict.items(): From f50e6919803899838bd754bf956a170ff9ec8b47 Mon Sep 17 00:00:00 2001 From: Haofei Yu <1125027232@qq.com> Date: Wed, 29 May 2024 12:14:58 -0400 Subject: [PATCH 2/4] fix pre-commit error --- research_town/agents/agent_base.py | 11 +++++++++++ research_town/envs/env_paper_rebuttal.py | 2 ++ research_town/envs/env_paper_submission.py | 2 ++ research_town/utils/agent_prompter.py | 20 ++++++++++++-------- 4 files changed, 27 insertions(+), 8 deletions(-) diff --git a/research_town/agents/agent_base.py b/research_town/agents/agent_base.py index 45d9da64..8a78def9 100644 --- a/research_town/agents/agent_base.py +++ b/research_town/agents/agent_base.py @@ -1,6 +1,8 @@ from datetime import datetime from typing import Any, Dict, List, Tuple +from beartype import beartype + from ..dbs import ( AgentAgentDiscussionLog, AgentPaperMetaReviewLog, @@ -32,6 +34,7 @@ def __init__(self, self.memory: Dict[str, str] = {} self.model_name: str = model_name + @beartype def get_profile(self, author_name: str) -> AgentProfile: # TODO: db get based on name agent_profile = AgentProfile( @@ -40,6 +43,7 @@ def get_profile(self, author_name: str) -> AgentProfile: ) return agent_profile + @beartype def communicate( self, message: AgentAgentDiscussionLog @@ -60,6 +64,7 @@ def communicate( ) return discussion_log + @beartype def read_paper( self, papers: List[PaperProfile], @@ -86,6 +91,7 @@ def read_paper( trend_output = trend[0] return trend_output + @beartype def find_collaborators( self, paper: PaperProfile, @@ -120,6 +126,7 @@ def find_collaborators( collaborators_list.append(self.get_profile(collaborator)) return collaborators_list + @beartype def get_co_author_relationships( self, agent_profile: AgentProfile, @@ -165,6 +172,7 @@ def generate_idea( return ideas + @beartype def write_paper( self, research_ideas: List[str], @@ -186,6 +194,7 @@ def write_paper( paper_profile = PaperProfile(abstract=paper_abstract) return paper_profile + @beartype def review_paper( self, paper: PaperProfile @@ -209,6 +218,7 @@ def review_paper( review_score=review_score ) + @beartype def make_review_decision( self, paper: PaperProfile, @@ -237,6 +247,7 @@ def make_review_decision( meta_review=meta_review[0], ) + @beartype def rebut_review( self, paper: PaperProfile, diff --git a/research_town/envs/env_paper_rebuttal.py b/research_town/envs/env_paper_rebuttal.py index 36c80015..20193b7e 100644 --- a/research_town/envs/env_paper_rebuttal.py +++ b/research_town/envs/env_paper_rebuttal.py @@ -1,5 +1,7 @@ from typing import Dict, List, Tuple + from beartype import beartype + from ..dbs import ( AgentPaperMetaReviewLog, AgentPaperRebuttalLog, diff --git a/research_town/envs/env_paper_submission.py b/research_town/envs/env_paper_submission.py index c4cd89ac..f57b85f5 100644 --- a/research_town/envs/env_paper_submission.py +++ b/research_town/envs/env_paper_submission.py @@ -1,5 +1,7 @@ from typing import Dict, List + from beartype import beartype + from ..agents.agent_base import BaseResearchAgent from ..dbs import ( AgentProfile, diff --git a/research_town/utils/agent_prompter.py b/research_town/utils/agent_prompter.py index e9b2486b..986810e6 100644 --- a/research_town/utils/agent_prompter.py +++ b/research_town/utils/agent_prompter.py @@ -1,9 +1,12 @@ from typing import Dict, List, Tuple +from beartype import beartype + from .model_prompting import model_prompting from .paper_collector import get_related_papers +@beartype def summarize_research_field_prompting( profile: Dict[str, str], keywords: List[str], @@ -45,6 +48,7 @@ def summarize_research_field_prompting( return model_prompting(model_name, prompt) +@beartype def find_collaborators_prompting( input: Dict[str, str], self_profile: Dict[str, str], @@ -77,7 +81,7 @@ def find_collaborators_prompting( prompt = prompt_qa.format_map(input) return model_prompting(model_name, prompt) - +@beartype def generate_ideas_prompting( trend: str, model_name: str, @@ -94,7 +98,7 @@ def generate_ideas_prompting( prompt = prompt_template.format_map(template_input) return model_prompting(model_name, prompt) - +@beartype def summarize_research_direction_prompting( personal_info: str, model_name: str, @@ -111,7 +115,7 @@ def summarize_research_direction_prompting( prompt = prompt_template.format_map(template_input) return model_prompting(model_name, prompt) - +@beartype def write_paper_abstract_prompting( ideas: List[str], papers: Dict[str, Dict[str, List[str]]], @@ -145,7 +149,7 @@ def write_paper_abstract_prompting( prompt = prompt_template.format_map(template_input) return model_prompting(model_name, prompt) - +@beartype def review_score_prompting(paper_review: str, model_name: str) -> int: prompt_qa = ( "Please provide a score for the following reviews. The score should be between 1 and 10, where 1 is the lowest and 10 is the highest. Only returns one number score." @@ -159,7 +163,7 @@ def review_score_prompting(paper_review: str, model_name: str) -> int: else: return 0 - +@beartype def review_paper_prompting(paper: Dict[str, str], model_name: str,) -> List[str]: """ Review paper from using list, and external data (published papers) @@ -182,7 +186,7 @@ def review_paper_prompting(paper: Dict[str, str], model_name: str,) -> List[str prompt = prompt_qa.format_map(input) return model_prompting(model_name, prompt) - +@beartype def make_review_decision_prompting( paper: Dict[str, str], review: Dict[str, Tuple[int, str]], @@ -212,7 +216,7 @@ def make_review_decision_prompting( prompt = prompt_template.format_map(template_input) return model_prompting(model_name, prompt) - +@beartype def rebut_review_prompting( paper: Dict[str, str], review: Dict[str, Tuple[int, str]], @@ -251,7 +255,7 @@ def rebut_review_prompting( prompt = prompt_template.format_map(template_input) return model_prompting(model_name, prompt) - +@beartype def communicate_with_multiple_researchers_prompting( messages: Dict[str, str], model_name: str, From f960039c0d362fc345b8f63d19f571fee426ec6f Mon Sep 17 00:00:00 2001 From: Haofei Yu <1125027232@qq.com> Date: Wed, 29 May 2024 12:16:23 -0400 Subject: [PATCH 3/4] fix pre-commit error --- research_town/utils/eval_prompter.py | 5 ++++- research_town/utils/model_prompting.py | 2 ++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/research_town/utils/eval_prompter.py b/research_town/utils/eval_prompter.py index 62dad190..96f24523 100644 --- a/research_town/utils/eval_prompter.py +++ b/research_town/utils/eval_prompter.py @@ -1,8 +1,11 @@ from typing import Dict +from beartype import beartype + from .model_prompting import model_prompting +@beartype def idea_quality_eval_prompting( idea: str, trend: str, @@ -87,7 +90,7 @@ def idea_quality_eval_prompting( return combined_result - +@beartype def paper_quality_eval_prompting( idea: str, paper: Dict[str,str], diff --git a/research_town/utils/model_prompting.py b/research_town/utils/model_prompting.py index 666e3b3e..7d7bf36a 100644 --- a/research_town/utils/model_prompting.py +++ b/research_town/utils/model_prompting.py @@ -1,10 +1,12 @@ from typing import List, Optional import litellm +from beartype import beartype from .decorator import exponential_backoff +@beartype @exponential_backoff(retries=5, base_wait_time=1) def model_prompting( llm_model: str, From 64fa63e496dcf1d8232dd18d0b54a10bb8026e1d Mon Sep 17 00:00:00 2001 From: Haofei Yu <1125027232@qq.com> Date: Wed, 29 May 2024 12:17:48 -0400 Subject: [PATCH 4/4] update pyproject --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ec58cedd..597482de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ types-requests = "^2.31.0" torch = ">=2.1.0 <2.3.0" transformers = "^4.40.0" litellm = "^1.0.0" -beartype = "0.18.5" +beartype = "^0.18.5" [tool.poetry.group.dev.dependencies] pre-commit = "*"