Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support database class testing #134

Merged
merged 12 commits into from
May 30, 2024
Empty file added data/dbs/.gitkeep
Empty file.
2 changes: 1 addition & 1 deletion examples/minimal_demo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List
from beartype.typing import Dict, List

from research_town.dbs import (
AgentProfile,
Expand Down
2 changes: 1 addition & 1 deletion research_town/agents/agent_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import datetime
from typing import Any, Dict, List, Tuple

from beartype import beartype
from beartype.typing import Any, Dict, List, Tuple

from ..dbs import (
AgentAgentDiscussionLog,
Expand Down
2 changes: 1 addition & 1 deletion research_town/dbs/agent_db.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import uuid
from typing import Any, Dict, List, Optional

from beartype.typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field


Expand Down
2 changes: 1 addition & 1 deletion research_town/dbs/env_db.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import uuid
from typing import Any, Dict, List, Optional, Type, TypeVar

from beartype.typing import Any, Dict, List, Optional, Type, TypeVar
from pydantic import BaseModel, Field

T = TypeVar('T', bound=BaseModel)
Expand Down
4 changes: 2 additions & 2 deletions research_town/dbs/paper_db.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import uuid
from typing import Any, Dict, List, Optional

from beartype.typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field

from ..utils.paper_collector import get_daily_papers
Expand Down Expand Up @@ -33,7 +33,7 @@ def __init__(self) -> None:
def add_paper(self, paper: PaperProfile) -> None:
self.data[paper.pk] = paper

def update_paper(self, paper_pk: str, updates: Dict[str, Optional[str]]) -> bool:
def update_paper(self, paper_pk: str, updates: Dict[str, Any]) -> bool:
if paper_pk in self.data:
for key, value in updates.items():
if value is not None:
Expand Down
2 changes: 1 addition & 1 deletion research_town/dbs/progress_db.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import uuid
from typing import Any, Dict, List, Optional, Type, TypeVar

from beartype.typing import Any, Dict, List, Optional, Type, TypeVar
from pydantic import BaseModel, Field

T = TypeVar('T', bound=BaseModel)
Expand Down
2 changes: 1 addition & 1 deletion research_town/envs/env_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from beartype.typing import List

from ..agents.agent_base import BaseResearchAgent
from ..dbs import AgentProfile, EnvLogDB
Expand Down
3 changes: 1 addition & 2 deletions research_town/envs/env_paper_rebuttal.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Dict, List, Tuple

from beartype import beartype
from beartype.typing import Dict, List, Tuple

from ..dbs import (
AgentPaperMetaReviewLog,
Expand Down
3 changes: 1 addition & 2 deletions research_town/envs/env_paper_submission.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Dict, List

from beartype import beartype
from beartype.typing import Dict, List

from ..agents.agent_base import BaseResearchAgent
from ..dbs import (
Expand Down
3 changes: 1 addition & 2 deletions research_town/evaluators/output_format.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@

from typing import Type, TypeVar

from beartype.typing import Type, TypeVar
from pydantic import BaseModel, Extra, Field, validator

T = TypeVar('T', bound=BaseModel)
Expand Down
3 changes: 2 additions & 1 deletion research_town/evaluators/quality_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

import re
from typing import Any

from beartype.typing import Any

from ..utils.decorator import parsing_error_exponential_backoff
from ..utils.eval_prompter import (
Expand Down
3 changes: 1 addition & 2 deletions research_town/utils/agent_collector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Any, Dict, List, Tuple

from arxiv import Client, Search
from beartype.typing import Any, Dict, List, Tuple
from tqdm import tqdm


Expand Down
3 changes: 1 addition & 2 deletions research_town/utils/agent_prompter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Dict, List, Union

from beartype import beartype
from beartype.typing import Dict, List, Union

from .model_prompting import model_prompting
from .paper_collector import get_related_papers
Expand Down
9 changes: 8 additions & 1 deletion research_town/utils/decorator.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import math
import time
from functools import wraps
from typing import Any, Callable, List, Optional, TypeVar, cast

from beartype.typing import (
Any,
Callable,
List,
Optional,
TypeVar,
cast,
)
from pydantic import BaseModel

INF = float(math.inf)
Expand Down
3 changes: 1 addition & 2 deletions research_town/utils/eval_prompter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Dict

from beartype import beartype
from beartype.typing import Dict

from .model_prompting import model_prompting

Expand Down
3 changes: 2 additions & 1 deletion research_town/utils/logging.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from typing import Dict, List, Union

from beartype.typing import Dict, List, Union

logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
def logging_callback(messages: Union[List[Dict[str, str]], None] = None) -> None:
Expand Down
3 changes: 1 addition & 2 deletions research_town/utils/model_prompting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import List, Optional

import litellm
from beartype import beartype
from beartype.typing import List, Optional

from .decorator import api_calling_error_exponential_backoff

Expand Down
2 changes: 1 addition & 1 deletion research_town/utils/paper_collector.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import datetime
from typing import Any, Dict, List, Tuple
from xml.etree import ElementTree

import arxiv
import faiss
import requests
import torch
from beartype.typing import Any, Dict, List, Tuple
from transformers import BertModel, BertTokenizer

ATOM_NAMESPACE = "{http://www.w3.org/2005/Atom}"
Expand Down
2 changes: 1 addition & 1 deletion research_town/utils/serializer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import importlib
from typing import Any, Dict, List, Set, Tuple, Union

from beartype.typing import Any, Dict, List, Set, Tuple, Union
from pydantic import BaseModel


Expand Down
2 changes: 1 addition & 1 deletion research_town/utils/string_mapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Union
from beartype.typing import Dict, List, Union


def map_idea_list_to_str(ideas: List[Dict[str, str]]) -> str:
Expand Down
3 changes: 2 additions & 1 deletion research_town/utils/tools.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import datetime
import json
import os
from typing import Any, Dict

from beartype.typing import Any, Dict


def show_time() -> str:
Expand Down
Loading
Loading