Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add beartype checking for public functions #133

Merged
merged 4 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ types-requests = "^2.31.0"
torch = ">=2.1.0 <2.3.0"
transformers = "^4.40.0"
litellm = "^1.0.0"
beartype = "^0.18.5"

[tool.poetry.group.dev.dependencies]
pre-commit = "*"
Expand Down
11 changes: 11 additions & 0 deletions research_town/agents/agent_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from datetime import datetime
from typing import Any, Dict, List, Tuple

from beartype import beartype

from ..dbs import (
AgentAgentDiscussionLog,
AgentPaperMetaReviewLog,
Expand Down Expand Up @@ -32,6 +34,7 @@ def __init__(self,
self.memory: Dict[str, str] = {}
self.model_name: str = model_name

@beartype
def get_profile(self, author_name: str) -> AgentProfile:
# TODO: db get based on name
agent_profile = AgentProfile(
Expand All @@ -40,6 +43,7 @@ def get_profile(self, author_name: str) -> AgentProfile:
)
return agent_profile

@beartype
def communicate(
self,
message: AgentAgentDiscussionLog
Expand All @@ -60,6 +64,7 @@ def communicate(
)
return discussion_log

@beartype
def read_paper(
self,
papers: List[PaperProfile],
Expand All @@ -86,6 +91,7 @@ def read_paper(
trend_output = trend[0]
return trend_output

@beartype
def find_collaborators(
self,
paper: PaperProfile,
Expand Down Expand Up @@ -120,6 +126,7 @@ def find_collaborators(
collaborators_list.append(self.get_profile(collaborator))
return collaborators_list

@beartype
def get_co_author_relationships(
self,
agent_profile: AgentProfile,
Expand Down Expand Up @@ -165,6 +172,7 @@ def generate_idea(

return ideas

@beartype
def write_paper(
self,
research_ideas: List[str],
Expand All @@ -186,6 +194,7 @@ def write_paper(
paper_profile = PaperProfile(abstract=paper_abstract)
return paper_profile

@beartype
def review_paper(
self,
paper: PaperProfile
Expand All @@ -209,6 +218,7 @@ def review_paper(
review_score=review_score
)

@beartype
def make_review_decision(
self,
paper: PaperProfile,
Expand Down Expand Up @@ -237,6 +247,7 @@ def make_review_decision(
meta_review=meta_review[0],
)

@beartype
def rebut_review(
self,
paper: PaperProfile,
Expand Down
5 changes: 5 additions & 0 deletions research_town/envs/env_paper_rebuttal.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Dict, List, Tuple

from beartype import beartype

from ..dbs import (
AgentPaperMetaReviewLog,
AgentPaperRebuttalLog,
Expand Down Expand Up @@ -34,14 +36,17 @@ def __init__(self,
self.paper_db = paper_db
self.env_db = env_db

@beartype
def assign_roles(self, role_dict: Dict[str, str]) -> None:
    """Mark which agents act as reviewers in this environment.

    Args:
        role_dict: Maps an agent profile's primary key to its role name;
            agents mapped to "reviewer" are flagged in ``self.reviewer_mask``.
    """
    for idx, profile in enumerate(self.agent_profiles):
        is_reviewer = role_dict[profile.pk] == "reviewer"
        if is_reviewer:
            self.reviewer_mask[idx] = True

@beartype
def initialize_submission(self, paper_profile: PaperProfile) -> None:
    """Record the given paper profile as this environment's submission under review."""
    self.submission = paper_profile

@beartype
def submit_decision(self, decision_dict: Dict[str, Tuple[bool, str]]) -> None:
decision_count = {"accept": 0, "reject": 0}
for _, decision in decision_dict.items():
Expand Down
3 changes: 3 additions & 0 deletions research_town/envs/env_paper_submission.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Dict, List

from beartype import beartype

from ..agents.agent_base import BaseResearchAgent
from ..dbs import (
AgentProfile,
Expand Down Expand Up @@ -70,6 +72,7 @@ def step(self) -> None:
self.submit_paper(abstracts)
self.terminated = True

@beartype
def submit_paper(self, paper_dict: Dict[str, PaperProfile]) -> None:
# TODO: clarify paper submission
for _, paper in paper_dict.items():
Expand Down
20 changes: 12 additions & 8 deletions research_town/utils/agent_prompter.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import Dict, List, Tuple

from beartype import beartype

from .model_prompting import model_prompting
from .paper_collector import get_related_papers


@beartype
def summarize_research_field_prompting(
profile: Dict[str, str],
keywords: List[str],
Expand Down Expand Up @@ -45,6 +48,7 @@ def summarize_research_field_prompting(
return model_prompting(model_name, prompt)


@beartype
def find_collaborators_prompting(
input: Dict[str, str],
self_profile: Dict[str, str],
Expand Down Expand Up @@ -77,7 +81,7 @@ def find_collaborators_prompting(
prompt = prompt_qa.format_map(input)
return model_prompting(model_name, prompt)


@beartype
def generate_ideas_prompting(
trend: str,
model_name: str,
Expand All @@ -94,7 +98,7 @@ def generate_ideas_prompting(
prompt = prompt_template.format_map(template_input)
return model_prompting(model_name, prompt)


@beartype
def summarize_research_direction_prompting(
personal_info: str,
model_name: str,
Expand All @@ -111,7 +115,7 @@ def summarize_research_direction_prompting(
prompt = prompt_template.format_map(template_input)
return model_prompting(model_name, prompt)


@beartype
def write_paper_abstract_prompting(
ideas: List[str],
papers: Dict[str, Dict[str, List[str]]],
Expand Down Expand Up @@ -145,7 +149,7 @@ def write_paper_abstract_prompting(
prompt = prompt_template.format_map(template_input)
return model_prompting(model_name, prompt)


@beartype
def review_score_prompting(paper_review: str, model_name: str) -> int:
prompt_qa = (
"Please provide a score for the following reviews. The score should be between 1 and 10, where 1 is the lowest and 10 is the highest. Only returns one number score."
Expand All @@ -159,7 +163,7 @@ def review_score_prompting(paper_review: str, model_name: str) -> int:
else:
return 0


@beartype
def review_paper_prompting(paper: Dict[str, str], model_name: str,) -> List[str]:
"""
Review paper from using list, and external data (published papers)
Expand All @@ -182,7 +186,7 @@ def review_paper_prompting(paper: Dict[str, str], model_name: str,) -> List[str
prompt = prompt_qa.format_map(input)
return model_prompting(model_name, prompt)


@beartype
def make_review_decision_prompting(
paper: Dict[str, str],
review: Dict[str, Tuple[int, str]],
Expand Down Expand Up @@ -212,7 +216,7 @@ def make_review_decision_prompting(
prompt = prompt_template.format_map(template_input)
return model_prompting(model_name, prompt)


@beartype
def rebut_review_prompting(
paper: Dict[str, str],
review: Dict[str, Tuple[int, str]],
Expand Down Expand Up @@ -251,7 +255,7 @@ def rebut_review_prompting(
prompt = prompt_template.format_map(template_input)
return model_prompting(model_name, prompt)


@beartype
def communicate_with_multiple_researchers_prompting(
messages: Dict[str, str],
model_name: str,
Expand Down
5 changes: 4 additions & 1 deletion research_town/utils/eval_prompter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import Dict

from beartype import beartype

from .model_prompting import model_prompting


@beartype
def idea_quality_eval_prompting(
idea: str,
trend: str,
Expand Down Expand Up @@ -87,7 +90,7 @@ def idea_quality_eval_prompting(

return combined_result


@beartype
def paper_quality_eval_prompting(
idea: str,
paper: Dict[str,str],
Expand Down
2 changes: 2 additions & 0 deletions research_town/utils/model_prompting.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import List, Optional

import litellm
from beartype import beartype

from .decorator import exponential_backoff


@beartype
@exponential_backoff(retries=5, base_wait_time=1)
def model_prompting(
llm_model: str,
Expand Down
Loading