Merge branch 'main' into reorg/fix-agent-base-func-logic
lwaekfjlk authored May 29, 2024
2 parents 3149fc4 + b9cff0c commit 1c2d855
Showing 15 changed files with 253 additions and 63 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -20,6 +20,7 @@ types-requests = "^2.31.0"
torch = ">=2.1.0 <2.3.0"
transformers = "^4.40.0"
litellm = "^1.0.0"
beartype = "^0.18.5"

[tool.poetry.group.dev.dependencies]
pre-commit = "*"
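
The new beartype dependency provides runtime enforcement of the type hints on the functions decorated throughout this commit. A minimal, self-contained sketch of the mechanism, using an illustrative function rather than code from this repository:

```python
from typing import List

from beartype import beartype
from beartype.roar import BeartypeCallHintParamViolation


@beartype
def add_scores(scores: List[int]) -> int:
    """Sum a list of integer scores."""
    return sum(scores)


add_scores([1, 2, 3])  # passes: the argument matches List[int]

try:
    add_scores("not a list")  # rejected before the function body runs
except BeartypeCallHintParamViolation as err:
    print(err)
```
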
11 changes: 11 additions & 0 deletions research_town/agents/agent_base.py
@@ -1,6 +1,8 @@
from datetime import datetime
from typing import Any, Dict, List, Tuple

from beartype import beartype

from ..dbs import (
AgentAgentDiscussionLog,
AgentPaperMetaReviewLog,
@@ -33,6 +35,7 @@ def __init__(self,
self.memory: Dict[str, str] = {}
self.model_name: str = model_name

@beartype
def get_profile(self, author_name: str) -> AgentProfile:
# TODO: db get based on name
agent_profile = AgentProfile(
@@ -41,6 +44,7 @@ def get_profile(self, author_name: str) -> AgentProfile:
)
return agent_profile

@beartype
def communicate(
self,
message: AgentAgentDiscussionLog
@@ -61,6 +65,7 @@ def communicate(
)
return discussion_log

@beartype
def read_paper(
self,
papers: List[PaperProfile],
@@ -75,6 +80,7 @@ def read_paper(
)
return trend_output

@beartype
def find_collaborators(
self,
paper: PaperProfile,
@@ -109,6 +115,7 @@ def find_collaborators(
collaborators_list.append(self.get_profile(collaborator))
return collaborators_list

@beartype
def get_co_author_relationships(
self,
agent_profile: AgentProfile,
@@ -154,6 +161,7 @@ def generate_idea(

return ideas

@beartype
def write_paper(
self,
research_ideas: List[str],
@@ -175,6 +183,7 @@ def write_paper(
paper_profile = PaperProfile(abstract=paper_abstract)
return paper_profile

@beartype
def review_paper(
self,
paper: PaperProfile
@@ -198,6 +207,7 @@ def review_paper(
review_score=review_score
)

@beartype
def make_review_decision(
self,
paper: PaperProfile,
@@ -226,6 +236,7 @@ def make_review_decision(
meta_review=meta_review[0],
)

@beartype
def rebut_review(
self,
paper: PaperProfile,
5 changes: 5 additions & 0 deletions research_town/envs/env_paper_rebuttal.py
@@ -1,5 +1,7 @@
from typing import Dict, List, Tuple

from beartype import beartype

from ..dbs import (
AgentPaperMetaReviewLog,
AgentPaperRebuttalLog,
@@ -34,14 +36,17 @@ def __init__(self,
self.paper_db = paper_db
self.env_db = env_db

@beartype
def assign_roles(self, role_dict: Dict[str, str]) -> None:
for index, agent_profile in enumerate(self.agent_profiles):
if role_dict[agent_profile.pk] == "reviewer":
self.reviewer_mask[index] = True

@beartype
def initialize_submission(self, paper_profile: PaperProfile) -> None:
self.submission = paper_profile

@beartype
def submit_decision(self, decision_dict: Dict[str, Tuple[bool, str]]) -> None:
decision_count = {"accept": 0, "reject": 0}
for _, decision in decision_dict.items():
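
For reference, a sketch of the argument shapes the newly type-checked environment methods expect, taken from the annotations above; the agent pks and review text are placeholders, and constructing the environment itself (which needs the agent/paper/env databases) is omitted:

```python
from typing import Dict, Tuple

# Dict[str, str]: agent pk -> role
role_dict: Dict[str, str] = {"agent-pk-1": "reviewer", "agent-pk-2": "author"}

# Dict[str, Tuple[bool, str]]: agent pk -> (accept?, review text)
decision_dict: Dict[str, Tuple[bool, str]] = {
    "agent-pk-1": (True, "Sound method; accept."),
}

# With a constructed environment `env`:
# env.assign_roles(role_dict)
# env.initialize_submission(PaperProfile(abstract="..."))
# env.submit_decision(decision_dict)
```
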
3 changes: 3 additions & 0 deletions research_town/envs/env_paper_submission.py
@@ -1,5 +1,7 @@
from typing import Dict, List

from beartype import beartype

from ..agents.agent_base import BaseResearchAgent
from ..dbs import (
AgentProfile,
@@ -70,6 +72,7 @@ def step(self) -> None:
self.submit_paper(abstracts)
self.terminated = True

@beartype
def submit_paper(self, paper_dict: Dict[str, PaperProfile]) -> None:
# TODO: clarify paper submission
for _, paper in paper_dict.items():
24 changes: 23 additions & 1 deletion research_town/evaluators/output_format.py
@@ -1,5 +1,5 @@

from pydantic import BaseModel, Extra, Field
from pydantic import BaseModel, Extra, Field, validator


class IdeaEvalOutput(BaseModel):
@@ -8,8 +8,30 @@ class Config:
class Config:
extra = Extra.allow # Allows extra fields to be stored

@validator('overall_score')
def validate_overall_score(cls, v):
if v is None:
raise ValueError("Overall score cannot be None")
if not (0 <= v <= 100):
raise ValueError("Overall score must be between 0 and 100")
return v


class PaperEvalOutput(BaseModel):
overall_score: int = Field(default=-1)
pk: str = Field(default='0')
class Config:
extra = Extra.allow # Allows extra fields to be stored

@validator('overall_score')
def validate_overall_score(cls, v):
if v is None:
raise ValueError("Overall score cannot be None")
if not (0 <= v <= 100):
raise ValueError("Overall score must be between 0 and 100")
return v

class OutputFormatError(Exception):
def __init__(self, message:str="Output format error")-> None:
self.message = message
super().__init__(self.message)
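
A short sketch of what the new validators enforce, assuming the pydantic v1-style `validator` API imported above; scores outside [0, 100] are now rejected at construction time instead of being stored silently:

```python
from pydantic import ValidationError

from research_town.evaluators.output_format import IdeaEvalOutput

ok = IdeaEvalOutput(overall_score=86)
assert ok.overall_score == 86  # in-range scores pass through unchanged

try:
    IdeaEvalOutput(overall_score=150)  # out of range -> ValidationError
except ValidationError as err:
    print(err)  # includes "Overall score must be between 0 and 100"
```

Note that the unchanged default of -1 is not passed through the validator (validators skip defaults unless declared with `always=True`), so a bare `IdeaEvalOutput()` still constructs; only explicitly supplied scores are range-checked.
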
48 changes: 26 additions & 22 deletions research_town/evaluators/quality_evaluator.py
@@ -1,12 +1,17 @@

import re
from typing import Any, Dict
from typing import Any

from ..utils.decorator import parsing_error_exponential_backoff
from ..utils.eval_prompter import (
idea_quality_eval_prompting,
paper_quality_eval_prompting,
)
from .output_format import IdeaEvalOutput, PaperEvalOutput
from .output_format import (
IdeaEvalOutput,
OutputFormatError,
PaperEvalOutput,
)


class IdeaQualityEvaluator(object):
@@ -18,34 +23,32 @@ def __init__(self,
self.model_name = model_name
self.parsed_output = IdeaEvalOutput()


@parsing_error_exponential_backoff(retries=5, base_wait_time=1)
def eval(
self,
idea: str,
trend: str,
*args: Any,
**kwargs: Any,
)-> IdeaEvalOutput:
raw_output = idea_quality_eval_prompting(
idea=idea,
trend=trend,
idea=kwargs['idea'],
trend=kwargs['trend'],
model_name=self.model_name
)
self.parsed_output = self.parse(raw_output)
# get pk
# self.parsed_output.pk = kwargs.get("pk")
# Store the input kwargs in parsed_output

for key, value in kwargs.items():
setattr(self.parsed_output, key, value)
return self.parsed_output

def parse(self, raw_output:str) -> IdeaEvalOutput:
match = re.search(r"Overall\s*Score\s*\W*(\d+)\W*", raw_output, re.IGNORECASE)
if match:
return IdeaEvalOutput(overall_score=int(match.group(1)))
try:
return IdeaEvalOutput(overall_score=int(match.group(1)))
except ValueError as e:
raise OutputFormatError(f"Invalid overall score: {e}")
else:
return IdeaEvalOutput()

raise OutputFormatError("Output format error: 'Overall Score' not found")

class PaperQualityEvaluator(object):
def __init__(self,
@@ -56,28 +59,29 @@ def __init__(self,
self.model_name = model_name
self.parsed_output = PaperEvalOutput()


@parsing_error_exponential_backoff(retries=5, base_wait_time=1)
def eval(
self,
idea: str,
paper: Dict[str,str],
*args: Any,
**kwargs: Any,
)-> PaperEvalOutput:
raw_output = paper_quality_eval_prompting(
idea=idea,
paper=paper,
idea=kwargs['idea'],
paper=kwargs['paper'],
model_name=self.model_name
)
self.parsed_output = self.parse(raw_output)
# Store the input kwargs in parsed_output

for key, value in kwargs.items():
setattr(self.parsed_output, key, value)
return self.parsed_output

def parse(self, raw_output:str) -> PaperEvalOutput:
def parse(self, raw_output: str) -> PaperEvalOutput:
match = re.search(r"Overall\s*Score\s*\W*(\d+)\W*", raw_output, re.IGNORECASE)
if match:
return PaperEvalOutput(overall_score=int(match.group(1)))
try:
return PaperEvalOutput(overall_score=int(match.group(1)))
except ValueError as e:
raise OutputFormatError(f"Invalid overall score: {e}")
else:
return PaperEvalOutput()
raise OutputFormatError("Output format error: 'Overall Score' not found")
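
With this change, `eval()` reads its inputs from keyword arguments and copies every kwarg onto the parsed output, and `parse()` now raises `OutputFormatError` (presumably so the exponential-backoff decorator can retry) instead of returning an empty result. A sketch of the parsing behaviour, which needs no LLM call; the `model_name`-only construction is an assumption based on the constructor body shown above:

```python
from research_town.evaluators.output_format import OutputFormatError
from research_town.evaluators.quality_evaluator import IdeaQualityEvaluator

evaluator = IdeaQualityEvaluator(model_name="gpt-4o-mini")  # assumed constructor usage

parsed = evaluator.parse("Some review text... Overall Score: 86. Strong idea.")
assert parsed.overall_score == 86

try:
    evaluator.parse("The model returned no score at all.")
except OutputFormatError as err:
    print(err)  # "Output format error: 'Overall Score' not found"

# With LLM access, evaluation is now driven purely by keyword arguments, e.g.:
# evaluator.eval(idea="...", trend="...", pk="idea-pk-1")
```
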
22 changes: 13 additions & 9 deletions research_town/utils/agent_prompter.py
@@ -1,5 +1,7 @@
from typing import Dict, List, Tuple

from beartype import beartype

from .model_prompting import model_prompting
from .paper_collector import get_related_papers
from ..dbs import PaperProfile, AgentProfile
@@ -30,7 +32,8 @@ def prepare_research_trend_prompt_input(
"papers": papers_dict
}

def research_trend_prompting(
@beartype
def summarize_research_field_prompting(
profile: Dict[str, str],
keywords: List[str],
papers: Dict[str, Dict[str, List[str]]],
@@ -71,6 +74,7 @@ def research_trend_prompting(
return model_prompting(model_name, prompt)


@beartype
def find_collaborators_prompting(
input: Dict[str, str],
self_profile: Dict[str, str],
@@ -103,7 +107,7 @@ def find_collaborators_prompting(
prompt = prompt_qa.format_map(input)
return model_prompting(model_name, prompt)


@beartype
def generate_ideas_prompting(
trend: str,
model_name: str,
@@ -120,7 +124,7 @@ def generate_ideas_prompting(
prompt = prompt_template.format_map(template_input)
return model_prompting(model_name, prompt)


@beartype
def summarize_research_direction_prompting(
personal_info: str,
model_name: str,
@@ -137,7 +141,7 @@ def summarize_research_direction_prompting(
prompt = prompt_template.format_map(template_input)
return model_prompting(model_name, prompt)


@beartype
def write_paper_abstract_prompting(
ideas: List[str],
papers: Dict[str, Dict[str, List[str]]],
@@ -171,7 +175,7 @@ def write_paper_abstract_prompting(
prompt = prompt_template.format_map(template_input)
return model_prompting(model_name, prompt)


@beartype
def review_score_prompting(paper_review: str, model_name: str) -> int:
prompt_qa = (
"Please provide a score for the following reviews. The score should be between 1 and 10, where 1 is the lowest and 10 is the highest. Only returns one number score."
@@ -185,7 +189,7 @@ def review_score_prompting(paper_review: str, model_name: str) -> int:
else:
return 0


@beartype
def review_paper_prompting(paper: Dict[str, str], model_name: str,) -> List[str]:
"""
Review paper from using list, and external data (published papers)
@@ -208,7 +212,7 @@ def review_paper_prompting(paper: Dict[str, str], model_name: str,) -> List[str
prompt = prompt_qa.format_map(input)
return model_prompting(model_name, prompt)


@beartype
def make_review_decision_prompting(
paper: Dict[str, str],
review: Dict[str, Tuple[int, str]],
@@ -238,7 +242,7 @@ def make_review_decision_prompting(
prompt = prompt_template.format_map(template_input)
return model_prompting(model_name, prompt)


@beartype
def rebut_review_prompting(
paper: Dict[str, str],
review: Dict[str, Tuple[int, str]],
@@ -277,7 +281,7 @@ def rebut_review_prompting(
prompt = prompt_template.format_map(template_input)
return model_prompting(model_name, prompt)


@beartype
def communicate_with_multiple_researchers_prompting(
messages: Dict[str, str],
model_name: str,
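
Beyond adding `@beartype` to each prompting helper, this file renames `research_trend_prompting` to `summarize_research_field_prompting`, so call sites need updating. A hedged sketch of the new call, based only on the annotations visible above; the trailing `model_name` parameter and the inner `title`/`abstract` keys are assumptions drawn from the other helpers in this file:

```python
from research_town.utils.agent_prompter import summarize_research_field_prompting

profile = {"name": "Alice", "bio": "Works on LLM-based multi-agent simulation."}  # Dict[str, str]; keys illustrative
keywords = ["research agents", "simulation"]                                      # List[str]
papers = {                                                                        # Dict[str, Dict[str, List[str]]]
    "paper-pk-1": {
        "title": ["A Study of LLM Research Agents"],
        "abstract": ["We simulate a research community with LLM agents..."],
    }
}

# Requires LLM access via model_prompting; `model_name` is an assumed parameter name.
trend = summarize_research_field_prompting(
    profile=profile,
    keywords=keywords,
    papers=papers,
    model_name="gpt-4o-mini",
)
```
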
