Skip to content

Commit

Permalink
fix mypy error
Browse files Browse the repository at this point in the history
  • Loading branch information
lwaekfjlk committed Jun 1, 2024
1 parent 5819de9 commit c48205b
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 34 deletions.
35 changes: 11 additions & 24 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ target-version = "py310"
[tool.ruff.format]
quote-style = "single"
indent-style = "space"
docstring-code-format = true
docstring-code-line-length = 88

[tool.mypy-arxiv]
Expand Down
9 changes: 7 additions & 2 deletions research_town/envs/env_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

class BaseMultiAgentEnv(object):
def __init__(self, agent_profiles: List[AgentProfile]) -> None:
self.env_run_number = 0
self.max_env_run_number = 1
self.terminated = False
self.agent_profiles: List[AgentProfile] = agent_profiles
self.db = EnvLogDB()
Expand Down Expand Up @@ -36,13 +38,16 @@ def step(self) -> LogType:
try:
return next(self.step_obj)
except Exception:
self.terminated = True
if self.env_run_number < self.max_env_run_number:
self.env_run_number += 1
else:
self.terminated = True
self.step_obj = self._step()
return next(self.step_obj)
else:
return next(
self.log(
f"Call 'step()' on a envionment that has terminated ({self.turn_number} / {self.turn_max}).",
f"Call 'step()' on a environment that has terminated ({self.env_run_number} / {self.max_env_run_number}).",
'ERROR',
)
)
6 changes: 3 additions & 3 deletions research_town/utils/eval_prompter.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def idea_quality_eval_prompting(
3. Significance
Rating (1-10):
Comments:
Evaluate the potential impact of the idea on the specfic domain of research community that the idea belongs to and beyond.
Evaluate the potential impact of the idea on the specific domain of research community that the idea belongs to and beyond.
How significant is its contribution to advancing the field?
Does it address high-impact problems or gaps identified in the trend?
How applicable is it in practical settings and industry contexts?
Expand Down Expand Up @@ -75,7 +75,7 @@ def idea_quality_eval_prompting(
def paper_quality_eval_prompting(
idea: str, paper: Dict[str, str], model_name: str, trend: Optional[str] = None
) -> str:
# refer to idea eval, but replace those not needed, and paraphrase thoese have overlaps.
# refer to idea eval, but replace those not needed, and paraphrase those have overlaps.
paper_prompt = """
<Instruction> Please evaluate the paper draft based on the following dimensions. Finally, give an overall score (0-100) and 10 dimension scores (for each dimension, provide a rating (1-10)) as the evaluation for the draft. .
<Instruction>
Expand Down Expand Up @@ -107,7 +107,7 @@ def paper_quality_eval_prompting(
3. Significance
Rating (1-10):
Comments:
Evaluate the potential contribution and impact of the paper on the specfic domain of research community that the paper belongs to and beyond.
Evaluate the potential contribution and impact of the paper on the specific domain of research community that the paper belongs to and beyond.
How does it compare to existing works in terms of impact?
4. Rigorousness
Rating (1-10):
Expand Down
8 changes: 4 additions & 4 deletions research_town/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@


class ColoredFormatter(logging.Formatter):
def format(self, record):
def format(self: logging.Formatter, record: logging.LogRecord) -> Any:
msg_type = record.__dict__.get('msg_type', None)
if msg_type in LOG_COLORS:
msg_type_color = colored(msg_type, LOG_COLORS[msg_type])
Expand All @@ -44,13 +44,13 @@ def format(self, record):
)
name_str = colored(record.name, LOG_COLORS[msg_type])
level_str = colored(record.levelname, LOG_COLORS[msg_type])
if msg_type in ['ERROR']:
if msg_type == 'ERROR':
return f'{time_str} - {name_str}:{level_str}: {record.filename}:{record.lineno}\n{msg_type_color}\n{msg}'
return f'{time_str} - {msg_type_color}\n{msg}'
elif msg_type == 'STEP':
msg = '\n\n==============\n' + record.msg + '\n'
return f'{msg}'
return super().format(record)
return logging.Formatter.format(self, record)


console_formatter = ColoredFormatter(
Expand All @@ -59,7 +59,7 @@ def format(self, record):
)


def get_console_handler():
def get_console_handler() -> logging.StreamHandler[Any]:
"""
Returns a console handler for logging.
"""
Expand Down

0 comments on commit c48205b

Please sign in to comment.