Skip to content

Commit

Permalink
support review paper and find collaborators (#41)
Browse files Browse the repository at this point in the history
* Initial commit

* support project template (#1)

* support template

* init test

* change test file name

* fix mypy

* support empty class structure (#3)

* just create a structure

* pass all tests

* Make mypy checking more relaxed in research town (#7)

Import errors eliminated when push

* sync with template repo (#9)

* delete useless stubs and sync with template repo

* sync with project template

* Fix error in init town (#8)

* Add files via upload

* Add files via upload

* Add files via upload

* Test whether more relaxed mypy could reduce captured error in init-town

* sync template

* support research_town

* fix codespell bug

* fix formatting

* test pytest

* fix mypy errors

* try to pass pytest

* try to pass pytest

* pass all pre-commit

* pass all pre-commit

* pass all pre-commit

* fix isort bug

* fix isort bug

* fix mypy types

* fix ruff

* delete install mypy things

* fix isort

* fix isort

* update poetry dependency

* add torch support

* add isort support

* add transformers

* fix pytest error

* fix pytest error

* solve pytest warning

* solve pytest

---------

Co-authored-by: ft2023 <[email protected]>
Co-authored-by: Haofei Yu <[email protected]>

* fix typo in issue template (#18)

* Update codebase_reorg.yml

* fix bug issue

* split the utils and reorg (#19)

* split the utils

* fix isort

* move prompting into agents

* move prompting into agents

* Fix pytest warning on arxiv lib (#26)

* support communicate method for agent (#24)

Co-authored-by: Haofei Yu <[email protected]>

* support write abstract method for agent (#25)

* Implemented write abstract method for agent

* fix merge error

* add communicate func

---------

Co-authored-by: Haofei Yu <[email protected]>

* fix pytest segmentation fault bug (#29)

* Add files via upload

* Add files via upload

* Add files via upload

* Add files via upload

* Add files via upload

* Update test_agent_base.py

* Update agent_base.py

* Update test_agent_base.py

* fix ruff

* fix testing path error

---------

Co-authored-by: ft2023 <[email protected]>

* support review paper in agent class (#15)

* test review paper (#15)

* test review paper (#15)

* support find collaborators (#16)

* fix test errors

* fix test errors

* fix test error

* update prompting function name

---------

Co-authored-by: Haofei Yu <[email protected]>
Co-authored-by: Haofei Yu <[email protected]>
Co-authored-by: Jinwei <[email protected]>
Co-authored-by: ft2023 <[email protected]>
Co-authored-by: timsanders256 <[email protected]>
  • Loading branch information
6 people authored May 15, 2024
1 parent d73b031 commit 947c280
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 26 deletions.
60 changes: 40 additions & 20 deletions research_town/agents/agent_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import requests

from ..utils.agent_prompting import (
communicate_with_multiple_researchers,
generate_ideas,
summarize_research_direction,
summarize_research_field,
write_paper_abstract,
communicate_with_multiple_researchers_prompting,
find_collaborators_prompting,
generate_ideas_prompting,
review_paper_prompting,
summarize_research_direction_prompting,
summarize_research_field_prompting,
write_paper_abstract_prompting,
)
from ..utils.author_relation import bfs
from ..utils.paper_collection import get_bert_embedding
Expand Down Expand Up @@ -42,20 +44,24 @@ def get_profile(self, author_name: str) -> Dict[str, Any]:
papers_by_year: Dict[int, List[ElementTree.Element]] = {}

for entry in entries:
title = self.find_text(entry, "{http://www.w3.org/2005/Atom}title")
title = self.find_text(
entry, "{http://www.w3.org/2005/Atom}title")
published = self.find_text(
entry, "{http://www.w3.org/2005/Atom}published"
)
abstract = self.find_text(entry, "{http://www.w3.org/2005/Atom}summary")
authors_elements = entry.findall("{http://www.w3.org/2005/Atom}author")
abstract = self.find_text(
entry, "{http://www.w3.org/2005/Atom}summary")
authors_elements = entry.findall(
"{http://www.w3.org/2005/Atom}author")
authors = [
self.find_text(author, "{http://www.w3.org/2005/Atom}name")
for author in authors_elements
]
link = self.find_text(entry, "{http://www.w3.org/2005/Atom}id")

if author_name in authors:
coauthors = [author for author in authors if author != author_name]
coauthors = [
author for author in authors if author != author_name]
coauthors_str = ", ".join(coauthors)

papers_list.append(
Expand Down Expand Up @@ -122,15 +128,15 @@ def get_profile(self, author_name: str) -> Dict[str, Any]:
personal_info = "; ".join(
[f"{details['Title & Abstract']}" for details in papers_list]
)
info = summarize_research_direction(personal_info)
info = summarize_research_direction_prompting(personal_info)
return {"name": author_name, "profile": info[0]}

else:
print("Failed to fetch data from arXiv.")
return {"info": "fail!"}

def communicate(self, message: Dict[str, str]) -> str:
    """Answer *message* via the LLM and return the first completion.

    The scraped diff had left a stale line calling the pre-rename helper
    `communicate_with_multiple_researchers`, which no longer exists in the
    import block and would raise NameError; only the renamed helper is kept.
    """
    return communicate_with_multiple_researchers_prompting(message)[0]

def read_paper(
self, external_data: Dict[str, Dict[str, List[str]]], domain: str
Expand All @@ -142,21 +148,34 @@ def read_paper(
papers_embedding = get_bert_embedding(papers)
time_chunks_embed[time] = papers_embedding

trend = summarize_research_field(
trend = summarize_research_field_prompting(
profile=self.profile,
keywords=[domain],
dataset=dataset,
data_embedding=time_chunks_embed,
) # trend
trend_output=trend[0]
trend_output = trend[0]
return trend_output

def find_collaborators(self, input: Dict[str, str], parameter: float = 0.5, max_number: int = 3) -> List[str]:
    """Find likely collaborators for the task described in *input*.

    Runs a BFS over the co-author graph rooted at this agent, collects
    candidate co-author names, asks the LLM to pick collaborators given
    everyone's profiles, and returns the candidates the LLM mentioned.

    Args:
        input: mapping of timestamp -> task abstract describing the work.
        parameter: forwarded to the prompting helper (its exact semantics
            live there).
        max_number: BFS node limit and maximum collaborators requested.

    Returns:
        Candidate names that appear in the LLM's generated answer.
    """
    start_author = [self.name]
    graph, _, _ = bfs(author_list=start_author, node_limit=max_number)
    collaborators = list(
        {name for pair in graph for name in pair if name != self.name}
    )
    self_profile = {self.name: self.profile["profile"]}
    collaborator_profiles = {
        author: self.get_profile(author)["profile"] for author in collaborators
    }
    result = find_collaborators_prompting(
        input, self_profile, collaborator_profiles, parameter, max_number
    )
    # Bug fix: `result` is a list of generated completions (free text), not
    # a list of names, so the previous exact membership test
    # (`collaborator in result`) was almost always False. Match candidate
    # names as substrings of the generated text instead.
    return [
        collaborator
        for collaborator in collaborators
        if any(collaborator in text for text in result)
    ]

def get_co_author_relationships(self, name: str, max_node: int) -> Tuple[List[Tuple[str, str]], Dict[str, List[Dict[str, Any]]], Dict[str, List[Dict[str, Any]]]]:
    """Return the co-author graph reachable from *name* via BFS.

    Args:
        name: author to start the traversal from.
        max_node: node limit passed to `bfs`.

    Returns:
        (edge list, node features, edge features) as produced by `bfs`.
    """
    # The diff overlay had this statement duplicated (old + new wrapped
    # form), which would have run the BFS twice; a single call suffices.
    graph, node_feat, edge_feat = bfs(author_list=[name], node_limit=max_node)
    return graph, node_feat, edge_feat

def generate_idea(
Expand All @@ -169,25 +188,26 @@ def generate_idea(
papers_embedding = get_bert_embedding(papers)
time_chunks_embed[time] = papers_embedding

trends, paper_links = summarize_research_field(
trends, paper_links = summarize_research_field_prompting(
profile=self.profile,
keywords=[domain],
dataset=dataset,
data_embedding=time_chunks_embed,
) # trend
ideas: List[str] = []
for trend in trends:
idea = generate_ideas(trend)[0]
idea = generate_ideas_prompting(trend)[0]
ideas.append(idea)

return ideas

def write_paper(self, input: List[str], external_data: Dict[str, Dict[str, List[str]]]) -> str:
    """Draft a paper abstract from idea strings and related papers.

    The span contained a stale pre-rename line calling the now-undefined
    `write_paper_abstract` (renamed to `*_prompting` in the import block);
    only the renamed call is kept.
    """
    return write_paper_abstract_prompting(input, external_data)[0]

def review_paper(self, input: Dict[str, str], external_data: Dict[str, str]) -> str:
    """Generate review comments for the given titles and related papers.

    The span retained the old placeholder `return "review comments"`,
    which would have short-circuited before the real prompting call; the
    stub is removed so the LLM-backed path runs.
    """
    paper_review = review_paper_prompting(input, external_data)
    return paper_review[0]

def make_review_decision(
self, input: Dict[str, str], external_data: Dict[str, str]
Expand Down
63 changes: 58 additions & 5 deletions research_town/utils/agent_prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def find_nearest_neighbors(data_embeddings: List[Any], query_embedding: Any, num
return neighbors.tolist()


def summarize_research_field(
def summarize_research_field_prompting(
profile: Dict[str, str],
keywords: List[str],
dataset: Dict[str, Any],
Expand Down Expand Up @@ -71,8 +71,32 @@ def summarize_research_field(

return openai_prompting(llm_model, prompt)

def find_collaborators_prompting(input: Dict[str, str], self_profile: Dict[str, str], collaborator_profiles: Dict[str, str], parameter: float = 0.5, max_number: int = 3, llm_model: Optional[str] = "mistralai/Mixtral-8x7B-Instruct-v0.1",) -> List[str]:
    """Ask the LLM to pick up to *max_number* collaborators for a task.

    Args:
        input: mapping of timestamp -> task abstract.
        self_profile: mapping of the requesting agent's name -> profile.
        collaborator_profiles: mapping of candidate name -> profile.
        parameter: currently unused here.
        max_number: number of collaborators requested in the prompt.
        llm_model: model identifier forwarded to `openai_prompting`.

    Returns:
        The raw list of completions from `openai_prompting`.
    """
    self_serialize_all = "\n\n".join(
        f"Name: {name}\nProfile: {profile}" for name, profile in self_profile.items()
    )
    task_serialize_all = "\n\n".join(
        f"Time: {timestamp}\nAbstract: {abstract}\n" for timestamp, abstract in input.items()
    )
    collaborator_serialize_all = "\n\n".join(
        f"Name: {name}\nProfile: {profile}"
        for name, profile in collaborator_profiles.items()
    )

    prompt_qa = (
        "Given the name and profile of me, could you find {max_number} collaborators for the following collaboration task?"
        "Here is my profile: {self_serialize_all}"
        "The collaboration task include: {task_serialize_all}"
        "Here are a full list of the names and profiles of potential collaborators: {collaborators_serialize_all}"
        "Generate the collaborator in a list separated by '-' for each collaborator"
    )
    # NOTE(review): `parameter` is accepted but never used — confirm whether
    # it should shape the prompt or be dropped from the signature upstream.
    # A dedicated local is used instead of rebinding the `input` parameter,
    # which the original did (shadowing both the argument and the builtin).
    template_input = {
        "max_number": str(max_number),
        "self_serialize_all": self_serialize_all,
        "task_serialize_all": task_serialize_all,
        "collaborators_serialize_all": collaborator_serialize_all,
    }
    prompt = prompt_qa.format_map(template_input)
    return openai_prompting(llm_model, prompt)

def generate_ideas(
def generate_ideas_prompting(
trend: str,
llm_model: Optional[str] = "mistralai/Mixtral-8x7B-Instruct-v0.1",
) -> List[str]:
Expand All @@ -89,7 +113,7 @@ def generate_ideas(
return openai_prompting(llm_model, prompt)


def summarize_research_direction(
def summarize_research_direction_prompting(
personal_info: str,
llm_model: Optional[str] = "mistralai/Mixtral-8x7B-Instruct-v0.1",
) -> List[str]:
Expand All @@ -106,7 +130,7 @@ def summarize_research_direction(
return openai_prompting(llm_model, prompt)


def write_paper_abstract(
def write_paper_abstract_prompting(
ideas: List[str],
external_data: Dict[str, Dict[str, List[str]]],
llm_model: Optional[str] = "mistralai/Mixtral-8x7B-Instruct-v0.1",
Expand Down Expand Up @@ -138,8 +162,37 @@ def write_paper_abstract(
prompt = prompt_template.format_map(template_input)
return openai_prompting(llm_model, prompt)

def review_paper_prompting(titles: Dict[str, str], external_data: Dict[str, str], llm_model: Optional[str] = "mistralai/Mixtral-8x7B-Instruct-v0.1") -> List[str]:
    """
    Review paper from using list, and external data (published papers).

    Args:
        titles: mapping of timestamp -> paper title to be reviewed.
        external_data: mapping of timestamp -> related-paper abstract.
        llm_model: model identifier forwarded to `openai_prompting`.

    Returns:
        The raw list of completions from `openai_prompting`.
    """
    # Bug fix: the titles loop previously serialized
    # `external_data[timestamp]` under the label "Paper:", so the actual
    # titles were never sent to the model — and any timestamp present only
    # in `titles` raised KeyError. Serialize the titles themselves.
    titles_serialize = []
    for timestamp, title in titles.items():
        titles_serialize.append(f"Time: {timestamp}\nTitle: {title}")
    titles_serialize_all = "\n\n".join(titles_serialize)

    papers_serialize = []
    for timestamp, paper in external_data.items():
        papers_serialize.append(f"Time: {timestamp}\nPaper: {paper}")
    papers_serialize_all = "\n\n".join(papers_serialize)

    prompt_qa = (
        "Please give some reviews based on the following inputs and external data."
        "You might use two or more of these titles if they are related and works well together."
        "Here are the titles: {titles_serialize_all}"
        "Here are the external data, which is a list of related papers: {papers_serialize_all}"
    )

    # Dedicated local instead of rebinding the builtin name `input`.
    template_input = {
        "titles_serialize_all": titles_serialize_all,
        "papers_serialize_all": papers_serialize_all,
    }

    prompt = prompt_qa.format_map(template_input)
    return openai_prompting(llm_model, prompt)

def communicate_with_multiple_researchers(
def communicate_with_multiple_researchers_prompting(
input: Dict[str, str],
llm_model: Optional[str] = "mistralai/Mixtral-8x7B-Instruct-v0.1",
) -> List[str]:
Expand Down
20 changes: 19 additions & 1 deletion tests/test_agent_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@

from typing import List
from unittest.mock import MagicMock, patch

from research_town.agents.agent_base import BaseResearchAgent
Expand Down Expand Up @@ -42,6 +42,15 @@ def test_write_paper_abstract(mock_openai_prompting: MagicMock) -> None:
assert isinstance(abstract, str)
assert abstract != ""

@patch("research_town.utils.agent_prompting.openai_prompting")
def test_review_paper(mock_openai_prompting: MagicMock) -> None:
    # The LLM entry point is patched so the test runs offline and
    # deterministically; the mock's list return mirrors the real helper,
    # which returns a list of completions.
    mock_openai_prompting.return_value = ["This is a paper review for MambaOut."]

    research_agent = BaseResearchAgent("Jiaxuan You")
    # `input` carries the title and `external_data` the abstract, both
    # keyed by the same timestamp string.
    review = research_agent.review_paper(input={"13 May 2024": "MambaOut: Do We Really Need Mamba for Vision?"}, external_data={"13 May 2024": "Mamba, an architecture with RNN-like token mixer of state space model (SSM), was recently introduced to address the quadratic complexity of the attention mechanism and subsequently applied to vision tasks. Nevertheless, the performance of Mamba for vision is often underwhelming when compared with convolutional and attention-based models. In this paper, we delve into the essence of Mamba, and conceptually conclude that Mamba is ideally suited for tasks with long-sequence and autoregressive characteristics. For vision tasks, as image classification does not align with either characteristic, we hypothesize that Mamba is not necessary for this task; Detection and segmentation tasks are also not autoregressive, yet they adhere to the long-sequence characteristic, so we believe it is still worthwhile to explore Mamba's potential for these tasks. To empirically verify our hypotheses, we construct a series of models named \\emph{MambaOut} through stacking Mamba blocks while removing their core token mixer, SSM. Experimental results strongly support our hypotheses. Specifically, our MambaOut model surpasses all visual Mamba models on ImageNet image classification, indicating that Mamba is indeed unnecessary for this task. As for detection and segmentation, MambaOut cannot match the performance of state-of-the-art visual Mamba models, demonstrating the potential of Mamba for long-sequence visual tasks."})
    assert isinstance(review, str)
    assert review != ""

@patch("research_town.utils.agent_prompting.openai_prompting")
def test_read_paper(mock_openai_prompting: MagicMock) -> None:
mock_openai_prompting.return_value = ["This is a paper"]
Expand All @@ -51,3 +60,12 @@ def test_read_paper(mock_openai_prompting: MagicMock) -> None:
research_agent = BaseResearchAgent("Jiaxuan You")
summary = research_agent.read_paper(external_data, domain)
assert isinstance(summary, str)

@patch("research_town.utils.agent_prompting.openai_prompting")
def test_find_collaborators(mock_openai_prompting: MagicMock) -> None:
    # The LLM entry point is patched so the test runs offline and
    # deterministically.
    mock_openai_prompting.return_value = ["These are collaborators including Jure Leskovec, Rex Ying, Saining Xie, Kaiming He."]

    research_agent = BaseResearchAgent("Jiaxuan You")
    collaborators = research_agent.find_collaborators(
        input={"11 May 2024": "Organize a workshop on how far are we from AGI (artificial general intelligence) at ICLR 2024. This workshop aims to become a melting pot for ideas, discussions, and debates regarding our proximity to AGI."}, parameter=0.5, max_number=3)
    # Use the builtin `list`, not `typing.List`: isinstance checks against
    # typing aliases are discouraged and raise for parameterized forms.
    assert isinstance(collaborators, list)
    assert all(isinstance(name, str) for name in collaborators)

0 comments on commit 947c280

Please sign in to comment.