Skip to content

Commit

Permalink
fix test errors
Browse files Browse the repository at this point in the history
  • Loading branch information
chengzr01 committed May 19, 2024
1 parent 26e577d commit 530e1da
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 7 deletions.
2 changes: 1 addition & 1 deletion research_town/agents/agent_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,5 +214,5 @@ def make_review_decision(
) -> str:
return "accept"

def rebut_review(self, submission: str, review: str, decision: str) -> str:
def rebut_review(self, submission: Dict[str, str], review: Dict[str, str], decision: Dict[str, str]) -> str:
return "It should be accepted."
10 changes: 5 additions & 5 deletions research_town/envs/env_paper_rebuttal.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Dict

from .env_base import BaseMultiAgentEnv
from ..agents.agent_base import BaseResearchAgent
from .env_base import BaseMultiAgentEnv


class PaperRebuttalMultiAgentEnv(BaseMultiAgentEnv):
def __init__(self, agent_dict: Dict[str, BaseResearchAgent]) -> None:
def __init__(self, agent_dict: Dict[str, str]) -> None:
super().__init__(agent_dict)
self.turn_number = 0
self.turn_max = 1
Expand Down Expand Up @@ -57,7 +57,7 @@ def step(self) -> None:
for name, role in self.roles.items():
if role == "reviewer":
decision_dict[name] = self.agents[name].make_review_decision(
input=self.submission, external_data=self.review)
input=self.submission, external_data=review_dict)
self.submit_decision(decision_dict)
# print("Decision Making", self.decision)

Expand All @@ -67,8 +67,8 @@ def step(self) -> None:
if role == "author":
rebuttal_dict[name] = self.agents[name].rebut_review(
submission=self.submission,
review=self.review,
decision=self.decision)
review=review_dict,
decision=decision_dict)
# print("Paper Rebuttal", self.rebuttal)

self.turn_number += 1
Expand Down
5 changes: 4 additions & 1 deletion tests/test_envs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from research_town.envs.env_paper_rebuttal import PaperRebuttalMultiAgentEnv
from unittest.mock import MagicMock, patch

from research_town.envs.env_paper_rebuttal import (
PaperRebuttalMultiAgentEnv,
)


@patch("research_town.utils.agent_prompting.openai_prompting")
def test_paper_rebuttal_env(mock_openai_prompting: MagicMock) -> None:
Expand Down

0 comments on commit 530e1da

Please sign in to comment.