Skip to content

Commit

Permalink
support database testing and fix pytest warning (#134)
Browse files Browse the repository at this point in the history
* add test_db

* fix pre-commit error

* rename test files

* fix test_db_base.py bugs

* fix test_db load safe files bug

* add a empty folder for db saving

* fix mypy errors

* fix pytest warning and change saving position

* fix pre-commit errors

---------

Co-authored-by: Haofei Yu <[email protected]>
  • Loading branch information
Kunlun-Zhu and lwaekfjlk authored May 30, 2024
1 parent 8cba2f1 commit 7c8d1c0
Show file tree
Hide file tree
Showing 25 changed files with 283 additions and 33 deletions.
Empty file added data/dbs/.gitkeep
Empty file.
2 changes: 1 addition & 1 deletion examples/minimal_demo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List
from beartype.typing import Dict, List

from research_town.dbs import (
AgentProfile,
Expand Down
2 changes: 1 addition & 1 deletion research_town/agents/agent_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import datetime
from typing import Any, Dict, List, Tuple

from beartype import beartype
from beartype.typing import Any, Dict, List, Tuple

from ..dbs import (
AgentAgentDiscussionLog,
Expand Down
2 changes: 1 addition & 1 deletion research_town/dbs/agent_db.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import uuid
from typing import Any, Dict, List, Optional

from beartype.typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field


Expand Down
2 changes: 1 addition & 1 deletion research_town/dbs/env_db.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import uuid
from typing import Any, Dict, List, Optional, Type, TypeVar

from beartype.typing import Any, Dict, List, Optional, Type, TypeVar
from pydantic import BaseModel, Field

T = TypeVar('T', bound=BaseModel)
Expand Down
4 changes: 2 additions & 2 deletions research_town/dbs/paper_db.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import uuid
from typing import Any, Dict, List, Optional

from beartype.typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field

from ..utils.paper_collector import get_daily_papers
Expand Down Expand Up @@ -33,7 +33,7 @@ def __init__(self) -> None:
def add_paper(self, paper: PaperProfile) -> None:
self.data[paper.pk] = paper

def update_paper(self, paper_pk: str, updates: Dict[str, Optional[str]]) -> bool:
def update_paper(self, paper_pk: str, updates: Dict[str, Any]) -> bool:
if paper_pk in self.data:
for key, value in updates.items():
if value is not None:
Expand Down
2 changes: 1 addition & 1 deletion research_town/dbs/progress_db.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import uuid
from typing import Any, Dict, List, Optional, Type, TypeVar

from beartype.typing import Any, Dict, List, Optional, Type, TypeVar
from pydantic import BaseModel, Field

T = TypeVar('T', bound=BaseModel)
Expand Down
2 changes: 1 addition & 1 deletion research_town/envs/env_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from beartype.typing import List

from ..agents.agent_base import BaseResearchAgent
from ..dbs import AgentProfile, EnvLogDB
Expand Down
3 changes: 1 addition & 2 deletions research_town/envs/env_paper_rebuttal.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Dict, List, Tuple

from beartype import beartype
from beartype.typing import Dict, List, Tuple

from ..dbs import (
AgentPaperMetaReviewLog,
Expand Down
3 changes: 1 addition & 2 deletions research_town/envs/env_paper_submission.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Dict, List

from beartype import beartype
from beartype.typing import Dict, List

from ..agents.agent_base import BaseResearchAgent
from ..dbs import (
Expand Down
3 changes: 1 addition & 2 deletions research_town/evaluators/output_format.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@

from typing import Type, TypeVar

from beartype.typing import Type, TypeVar
from pydantic import BaseModel, Extra, Field, validator

T = TypeVar('T', bound=BaseModel)
Expand Down
3 changes: 2 additions & 1 deletion research_town/evaluators/quality_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

import re
from typing import Any

from beartype.typing import Any

from ..utils.decorator import parsing_error_exponential_backoff
from ..utils.eval_prompter import (
Expand Down
3 changes: 1 addition & 2 deletions research_town/utils/agent_collector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Any, Dict, List, Tuple

from arxiv import Client, Search
from beartype.typing import Any, Dict, List, Tuple
from tqdm import tqdm


Expand Down
3 changes: 1 addition & 2 deletions research_town/utils/agent_prompter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Dict, List, Union

from beartype import beartype
from beartype.typing import Dict, List, Union

from .model_prompting import model_prompting
from .paper_collector import get_related_papers
Expand Down
9 changes: 8 additions & 1 deletion research_town/utils/decorator.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import math
import time
from functools import wraps
from typing import Any, Callable, List, Optional, TypeVar, cast

from beartype.typing import (
Any,
Callable,
List,
Optional,
TypeVar,
cast,
)
from pydantic import BaseModel

INF = float(math.inf)
Expand Down
3 changes: 1 addition & 2 deletions research_town/utils/eval_prompter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Dict

from beartype import beartype
from beartype.typing import Dict

from .model_prompting import model_prompting

Expand Down
3 changes: 2 additions & 1 deletion research_town/utils/logging.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from typing import Dict, List, Union

from beartype.typing import Dict, List, Union

logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
def logging_callback(messages: Union[List[Dict[str, str]], None] = None) -> None:
Expand Down
3 changes: 1 addition & 2 deletions research_town/utils/model_prompting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import List, Optional

import litellm
from beartype import beartype
from beartype.typing import List, Optional

from .decorator import api_calling_error_exponential_backoff

Expand Down
2 changes: 1 addition & 1 deletion research_town/utils/paper_collector.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import datetime
from typing import Any, Dict, List, Tuple
from xml.etree import ElementTree

import arxiv
import faiss
import requests
import torch
from beartype.typing import Any, Dict, List, Tuple
from transformers import BertModel, BertTokenizer

ATOM_NAMESPACE = "{http://www.w3.org/2005/Atom}"
Expand Down
2 changes: 1 addition & 1 deletion research_town/utils/serializer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import importlib
from typing import Any, Dict, List, Set, Tuple, Union

from beartype.typing import Any, Dict, List, Set, Tuple, Union
from pydantic import BaseModel


Expand Down
2 changes: 1 addition & 1 deletion research_town/utils/string_mapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Union
from beartype.typing import Dict, List, Union


def map_idea_list_to_str(ideas: List[Dict[str, str]]) -> str:
Expand Down
3 changes: 2 additions & 1 deletion research_town/utils/tools.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import datetime
import json
import os
from typing import Any, Dict

from beartype.typing import Any, Dict


def show_time() -> str:
Expand Down
Loading

0 comments on commit 7c8d1c0

Please sign in to comment.