-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* rebuild data type * add reserach paper class * change dict to list class * prompter should keep the same without any class in and out * rebuild data type in rebuttal-related functions * update agent list * rebuild data type for test functions for rebuttal-related functions * add constants for testing * add a rebuttal instance * update get profile * rebuild data type for submission-related functions * fix test errors * rebuild data type for test functions for submission-related functions * fix test errors * rebuild the database class * fix some bugs * add some optional param * udpate db class content * fix id to pk * fix pre-commit bug * fix pre-commit bug * fix some mypy error * add a TODO comment * fix part of mypy errors * fix part of mypy errors * fix mypy.ini format * recover test functions * fix pre-commit * add default --------- Co-authored-by: chengzr01 <[email protected]>
- Loading branch information
Showing
20 changed files
with
798 additions
and
278 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
[mypy] | ||
plugins = pydantic.mypy | ||
|
||
[mypy-arxiv.*] | ||
ignore_missing_imports = True | ||
|
||
[mypy-faiss.*] | ||
ignore_missing_imports = True | ||
|
||
[mypy-transformers.*] | ||
ignore_missing_imports = True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from .agent_db import AgentProfile, AgentProfileDB | ||
from .env_db import ( | ||
AgentAgentDiscussionLog, | ||
AgentPaperMetaReviewLog, | ||
AgentPaperRebuttalLog, | ||
AgentPaperReviewLog, | ||
EnvLogDB, | ||
) | ||
from .paper_db import PaperProfile, PaperProfileDB | ||
from .progress_db import ( | ||
ResearchIdea, | ||
ResearchPaperDraft, | ||
ResearchProgressDB, | ||
) | ||
|
||
__all__ = [ | ||
"AgentAgentDiscussionLog", | ||
"AgentPaperMetaReviewLog", | ||
"AgentPaperRebuttalLog", | ||
"AgentPaperReviewLog", | ||
"PaperProfile", | ||
"AgentProfile", | ||
"ResearchIdea", | ||
"ResearchPaperDraft", | ||
"EnvLogDB", | ||
"PaperProfileDB", | ||
"AgentProfileDB", | ||
"ResearchProgressDB" | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import json | ||
import uuid | ||
from typing import Any, Dict, List, Optional | ||
|
||
from pydantic import BaseModel, Field | ||
|
||
|
||
class AgentProfile(BaseModel): | ||
pk: str = Field(default_factory=lambda: str(uuid.uuid4())) | ||
name: Optional[str] = Field(default=None) | ||
bio: Optional[str] = Field(default=None) | ||
|
||
|
||
class AgentProfileDB(object): | ||
def __init__(self) -> None: | ||
self.data: Dict[str, AgentProfile] = {} | ||
|
||
def add(self, agent: AgentProfile) -> None: | ||
self.data[agent.pk] = agent | ||
|
||
def update(self, agent_pk: str, updates: Dict[str, Optional[str]]) -> bool: | ||
if agent_pk in self.data: | ||
for key, value in updates.items(): | ||
if value is not None: | ||
setattr(self.data[agent_pk], key, value) | ||
return True | ||
return False | ||
|
||
def delete(self, agent_pk: str) -> bool: | ||
if agent_pk in self.data: | ||
del self.data[agent_pk] | ||
return True | ||
return False | ||
|
||
def get(self, **conditions: Dict[str, Any]) -> List[AgentProfile]: | ||
result = [] | ||
for agent in self.data.values(): | ||
if all(getattr(agent, key) == value for key, value in conditions.items()): | ||
result.append(agent) | ||
return result | ||
|
||
def save_to_file(self, file_name: str) -> None: | ||
with open(file_name, "w") as f: | ||
json.dump({aid: agent.dict() | ||
for aid, agent in self.data.items()}, f, indent=2) | ||
|
||
def load_from_file(self, file_name: str) -> None: | ||
with open(file_name, "r") as f: | ||
data = json.load(f) | ||
self.data = {aid: AgentProfile(**agent_data) | ||
for aid, agent_data in data.items()} | ||
|
||
def update_db(self, data: Dict[str, List[Dict[str, Any]]]) -> None: | ||
for date, agents in data.items(): | ||
for agent_data in agents: | ||
agent = AgentProfile(**agent_data) | ||
self.add(agent) |
Oops, something went wrong.