Skip to content

Commit

Permalink
enhance evaluator by polishing prompt and adding info (#181)
Browse files Browse the repository at this point in the history
* outline of alignment criteria

* reorganize evaluator prompts

* fix bug when no trend is provided in paper eval

* minor prompt formatting fixes and eval parser changes

* enable real tests

* ease parser matching

* disable real tests in test_eval

* Fix issues identified by pre-commit hooks

---------

Co-authored-by: chengzr01 <[email protected]>
Co-authored-by: Haofei Yu <[email protected]>
  • Loading branch information
3 people authored Jun 1, 2024
1 parent b90ede0 commit faa8172
Show file tree
Hide file tree
Showing 8 changed files with 219 additions and 155 deletions.
8 changes: 4 additions & 4 deletions data/dbs/test_agent_profile_db.json
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
{
"d544f290-6748-46b5-a82e-fd8f40c1e4cc": {
"pk": "d544f290-6748-46b5-a82e-fd8f40c1e4cc",
"8514bd13-0501-4c4c-bd7c-1a76a6c54c77": {
"pk": "8514bd13-0501-4c4c-bd7c-1a76a6c54c77",
"name": "Jane Smith",
"bio": "Expert in NLP",
"collaborators": [],
"institute": "NLP Lab"
},
"9c581b74-86f6-4577-b400-9221df4c3917": {
"pk": "9c581b74-86f6-4577-b400-9221df4c3917",
"5faa7149-d6ed-46e4-95c5-a61041c6c621": {
"pk": "5faa7149-d6ed-46e4-95c5-a61041c6c621",
"name": "Alice Johnson",
"bio": "Data Scientist",
"collaborators": [],
Expand Down
9 changes: 5 additions & 4 deletions data/dbs/test_env_logs_db.json
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
{
"PaperProfile": [],
"ResearchPaperSubmission": [],
"AgentPaperReviewLog": [
{
"pk": "654935ea-be94-4898-80a4-bb5c7c12f286",
"pk": "6dcd86f3-e1cb-4dc7-9989-1e5483475475",
"timestep": 0,
"paper_pk": "paper2",
"agent_pk": "agent2",
Expand All @@ -12,7 +13,7 @@
],
"AgentPaperRebuttalLog": [
{
"pk": "5387eadb-6a18-44e1-b7a3-55c49c808efd",
"pk": "32760169-07ca-4cb1-9b4d-868ee9f9c04c",
"timestep": 0,
"paper_pk": "paper1",
"agent_pk": "agent1",
Expand All @@ -21,7 +22,7 @@
],
"AgentPaperMetaReviewLog": [
{
"pk": "f3bffbbc-c67c-40a5-82f1-200989b2bea9",
"pk": "0e9bd890-ba49-4cb6-bedd-01f1d21b2586",
"timestep": 0,
"paper_pk": "paper1",
"agent_pk": "agent1",
Expand All @@ -31,7 +32,7 @@
],
"AgentAgentDiscussionLog": [
{
"pk": "67a25e19-2182-4671-9005-a3f95dd3f7c0",
"pk": "7a4f1d87-d466-41d6-a105-f282573da839",
"timestep": 0,
"agent_from_pk": "agent1",
"agent_from_name": "Rex Ying",
Expand Down
8 changes: 4 additions & 4 deletions data/dbs/test_paper_profile_db.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"43653097-1230-48e5-ba17-6f616bc93380": {
"pk": "43653097-1230-48e5-ba17-6f616bc93380",
"d5e1587d-14aa-4ac4-ab72-c2774cedb2cc": {
"pk": "d5e1587d-14aa-4ac4-ab72-c2774cedb2cc",
"title": "Updated Sample Paper 1",
"abstract": "This is the abstract for paper 1",
"authors": [
Expand All @@ -22,8 +22,8 @@
"citation_count": 15,
"award": null
},
"37e9c697-bd7b-40da-975f-579eddc9508e": {
"pk": "37e9c697-bd7b-40da-975f-579eddc9508e",
"dfbec221-2cce-4038-b5b8-7925e6ea916f": {
"pk": "dfbec221-2cce-4038-b5b8-7925e6ea916f",
"title": "Sample Paper 3",
"abstract": "This is the abstract for paper 3",
"authors": [
Expand Down
4 changes: 2 additions & 2 deletions data/dbs/test_research_progress_db.json
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
{
"ResearchIdea": [
{
"pk": "585e0e17-ae53-44a1-a682-e4ee2883655c",
"pk": "177ca3d9-595f-4147-9f2a-562c9e2b08f1",
"content": "Blockchain research proposal"
},
{
"pk": "baf40f3b-f14b-48a0-bc1c-d84eaefa9e58",
"pk": "31d65766-93b6-45a3-87ba-0e9fbcf21855",
"content": "Updated idea content"
}
],
Expand Down
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 22 additions & 1 deletion research_town/evaluators/output_format.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from beartype.typing import Type, TypeVar
from beartype.typing import List, Type, TypeVar
from pydantic import BaseModel, Extra, Field, validator

T = TypeVar('T', bound=BaseModel)
Expand All @@ -7,6 +7,7 @@
class IdeaEvalOutput(BaseModel):
overall_score: int = Field(default=-1)
pk: str = Field(default='0')
dimension_scores: List[int] = Field(default=[])

class Config:
extra = Extra.allow # Allows extra fields to be stored
Expand All @@ -19,10 +20,17 @@ def validate_overall_score(cls: Type[T], v: int) -> int:
raise ValueError('Overall score must be between 0 and 100')
return v

@validator('dimension_scores', each_item=True)
def validate_dimension_scores(cls: Type[T], v: int) -> int:
if not (0 <= v <= 10):
raise ValueError('Each dimension score must be between 0 and 10')
return v


class PaperEvalOutput(BaseModel):
overall_score: int = Field(default=-1)
pk: str = Field(default='0')
dimension_scores: List[int] = Field(default=[])

class Config:
extra = Extra.allow # Allows extra fields to be stored
Expand All @@ -35,10 +43,17 @@ def validate_overall_score(cls: Type[T], v: int) -> int:
raise ValueError('Overall score must be between 0 and 100')
return v

@validator('dimension_scores', each_item=True)
def validate_dimension_scores(cls: Type[T], v: int) -> int:
if not (0 <= v <= 10):
raise ValueError('Each dimension score must be between 0 and 10')
return v


class ReviewEvalOutput(BaseModel):
overall_score: int = Field(default=-1)
pk: str = Field(default='0')
dimension_scores: List[int] = Field(default=[])

class Config:
extra = Extra.allow # Allows extra fields to be stored
Expand All @@ -51,6 +66,12 @@ def validate_overall_score(cls: Type[T], v: int) -> int:
raise ValueError('Overall score must be between 0 and 100')
return v

@validator('dimension_scores', each_item=True)
def validate_dimension_scores(cls: Type[T], v: int) -> int:
if not (0 <= v <= 10):
raise ValueError('Each dimension score must be between 0 and 10')
return v


class OutputFormatError(Exception):
def __init__(self, message: str = 'Output format error') -> None:
Expand Down
90 changes: 79 additions & 11 deletions research_town/evaluators/quality_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,37 @@ def eval(
return self.parsed_output

def parse(self, raw_output: str) -> IdeaEvalOutput:
match = re.search(r'Overall\s*Score\s*\W*(\d+)\W*', raw_output, re.IGNORECASE)
if match:
overall_score_match = re.search(
r'Overall\s*Score\s*\W*(\d+)\W*', raw_output, re.IGNORECASE
)
dimension_scores_match = re.search(
r'Dimension\s*Scores\s*\W*\s*\[([0-9,\s]+)\]', raw_output, re.IGNORECASE
)

if overall_score_match:
try:
return IdeaEvalOutput(overall_score=int(match.group(1)))
overall_score = int(overall_score_match.group(1))
except ValueError as e:
raise OutputFormatError(f'Invalid overall score: {e}')
else:
raise OutputFormatError("Output format error: 'Overall Score' not found")
raise OutputFormatError(
f"Output format error: 'Overall Score' not found. Raw output is {raw_output}."
)

if dimension_scores_match:
try:
dimension_scores = list(
map(int, dimension_scores_match.group(1).split(','))
)
except ValueError as e:
raise OutputFormatError(f'Invalid dimension scores: {e}')
else:
raise OutputFormatError(
f"Output format error: 'Dimension Scores' not found. Raw output is {raw_output}."
)
return IdeaEvalOutput(
overall_score=overall_score, dimension_scores=dimension_scores
)


class PaperQualityEvaluator(object):
Expand All @@ -68,14 +91,37 @@ def eval(
return self.parsed_output

def parse(self, raw_output: str) -> PaperEvalOutput:
match = re.search(r'Overall\s*Score\s*\W*(\d+)\W*', raw_output, re.IGNORECASE)
if match:
overall_score_match = re.search(
r'Overall\s*Score\s*\W*(\d+)\W*', raw_output, re.IGNORECASE
)
dimension_scores_match = re.search(
r'Dimension\s*Scores\s*\W*\s*\[([0-9,\s]+)\]', raw_output, re.IGNORECASE
)

if overall_score_match:
try:
return PaperEvalOutput(overall_score=int(match.group(1)))
overall_score = int(overall_score_match.group(1))
except ValueError as e:
raise OutputFormatError(f'Invalid overall score: {e}')
else:
raise OutputFormatError("Output format error: 'Overall Score' not found")
raise OutputFormatError(
f"Output format error: 'Overall Score' not found. Raw output is {raw_output}."
)

if dimension_scores_match:
try:
dimension_scores = list(
map(int, dimension_scores_match.group(1).split(','))
)
except ValueError as e:
raise OutputFormatError(f'Invalid dimension scores: {e}')
else:
raise OutputFormatError(
f"Output format error: 'Dimension Scores' not found. Raw output is {raw_output}."
)
return PaperEvalOutput(
overall_score=overall_score, dimension_scores=dimension_scores
)


class ReviewQualityEvaluator(object):
Expand Down Expand Up @@ -104,13 +150,35 @@ def eval(
return self.parsed_output

def parse(self, raw_output: str) -> ReviewEvalOutput:
match = re.search(r'Overall\s*Score\s*\W*(\d+)\W*', raw_output, re.IGNORECASE)
if match:
overall_score_match = re.search(
r'Overall\s*Score\s*\W*(\d+)\W*', raw_output, re.IGNORECASE
)
dimension_scores_match = re.search(
r'Dimension\s*Scores\s*\W*\s*\[([0-9,\s]+)\]', raw_output, re.IGNORECASE
)

if overall_score_match:
try:
return ReviewEvalOutput(overall_score=int(match.group(1)))
overall_score = int(overall_score_match.group(1))
except ValueError as e:
raise OutputFormatError(f'Invalid overall score: {e}')
else:
raise OutputFormatError(
f"Output format error: 'Overall Score' not found. Raw output is {raw_output}."
)

if dimension_scores_match:
try:
dimension_scores = list(
map(int, dimension_scores_match.group(1).split(','))
)
except ValueError as e:
raise OutputFormatError(f'Invalid dimension scores: {e}')
else:
raise OutputFormatError(
f"Output format error: 'Dimension Scores' not found. Raw output is {raw_output}."
)

return ReviewEvalOutput(
overall_score=overall_score, dimension_scores=dimension_scores
)
Loading

0 comments on commit faa8172

Please sign in to comment.