Mock the paper retrieval process for testing and adjust input names (#67)
* rewrite find related papers

* fix the retrieval mocking problem

* support

* rewrite find related papers

* rewrite find related papers
lwaekfjlk authored May 22, 2024
1 parent ddba8ba commit 3b20c1a
Showing 12 changed files with 239 additions and 458 deletions.
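
The commit title mentions mocking the paper retrieval process for testing. As a hedged sketch of what such a test can look like (this is not code from the commit): patch the module-level get_related_papers and openai_prompting names that summarize_research_field_prompting resolves at call time. The test name and fixture data are hypothetical, and the snippet assumes research_town is importable with TOGETHER_API_KEY set to any dummy value, since agent_prompter reads it at import time.

    # A hedged sketch, not from this commit: stub out retrieval and the LLM
    # call so summarize_research_field_prompting runs without BERT, faiss,
    # or a live API.
    from unittest.mock import patch

    from research_town.utils.agent_prompter import summarize_research_field_prompting


    def test_summarize_with_mocked_retrieval() -> None:
        papers = {"2024-05-22": {"abstract": ["A mocked abstract about ML."]}}
        with patch(
            "research_town.utils.agent_prompter.get_related_papers",
            return_value=["A mocked abstract about ML."],
        ) as mock_retrieval, patch(
            "research_town.utils.agent_prompter.openai_prompting",
            return_value=["a mocked trend summary"],
        ):
            trends = summarize_research_field_prompting(
                profile={"name": "Alice Researcher"},
                keywords=["machine learning"],
                papers=papers,
            )
        mock_retrieval.assert_called_once()  # retrieval was stubbed, not run
        assert trends == ["a mocked trend summary"]  # LLM call was stubbed too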
497 changes: 137 additions & 360 deletions poetry.lock

Large diffs are not rendered by default.

40 changes: 12 additions & 28 deletions research_town/agents/agent_base.py
@@ -4,7 +4,7 @@
 
 import requests
 
-from ..utils.agent_prompting import (
+from ..utils.agent_prompter import (
     communicate_with_multiple_researchers_prompting,
     find_collaborators_prompting,
     generate_ideas_prompting,
@@ -16,8 +16,7 @@
     summarize_research_field_prompting,
     write_paper_abstract_prompting,
 )
-from ..utils.author_relation import bfs
-from ..utils.paper_collection import get_bert_embedding
+from ..utils.author_collector import bfs
 
 ATOM_NAMESPACE = "{http://www.w3.org/2005/Atom}"
 
@@ -144,21 +143,13 @@ def communicate(self, message: Dict[str, str]) -> str:
         return communicate_with_multiple_researchers_prompting(message)[0]
 
     def read_paper(
-        self, external_data: Dict[str, Dict[str, List[str]]], domain: str
+        self, papers: Dict[str, Dict[str, List[str]]], domain: str
     ) -> str:
-        time_chunks_embed = {}
-        dataset = external_data
-        for time in dataset.keys():
-            papers = dataset[time]["abstract"]
-            papers_embedding = get_bert_embedding(papers)
-            time_chunks_embed[time] = papers_embedding
-
         trend = summarize_research_field_prompting(
             profile=self.profile,
             keywords=[domain],
-            dataset=dataset,
-            data_embedding=time_chunks_embed,
-        ) # trend
+            papers=papers,
+        )
         trend_output = trend[0]
         return trend_output
 
@@ -184,34 +175,27 @@ def get_co_author_relationships(self, name: str, max_node: int) -> Tuple[List[Tu
         return graph, node_feat, edge_feat
 
     def generate_idea(
-        self, external_data: Dict[str, Dict[str, List[str]]], domain: str
+        self, papers: Dict[str, Dict[str, List[str]]], domain: str
     ) -> List[str]:
-        time_chunks_embed = {}
-        dataset = external_data
-        for time in dataset.keys():
-            papers = dataset[time]["abstract"]
-            papers_embedding = get_bert_embedding(papers)
-            time_chunks_embed[time] = papers_embedding
-
         trends = summarize_research_field_prompting(
             profile=self.profile,
             keywords=[domain],
-            dataset=dataset,
-            data_embedding=time_chunks_embed,
-        ) # trend
+            papers=papers,
+        )
         ideas: List[str] = []
         for trend in trends:
             idea = generate_ideas_prompting(trend)[0]
             ideas.append(idea)
 
         return ideas
 
-    def write_paper(self, input: List[str], external_data: Dict[str, Dict[str, List[str]]]) -> str:
-        paper_abstract = write_paper_abstract_prompting(input, external_data)
+    def write_paper(self, input: List[str], papers: Dict[str, Dict[str, List[str]]]) -> str:
+        paper_abstract = write_paper_abstract_prompting(input, papers)
         return paper_abstract[0]
 
-    def review_paper(self, external_data: Dict[str, str]) -> Tuple[int, str]:
-        paper_review = review_paper_prompting(external_data)[0]
+    def review_paper(self, paper: Dict[str, str]) -> Tuple[int, str]:
+        paper_review = review_paper_prompting(paper)[0]
         print(paper_review)
         review_score = review_score_prompting(paper_review)
         print(review_score, paper_review)
2 changes: 1 addition & 1 deletion research_town/envs/env_paper_rebuttal.py
@@ -50,7 +50,7 @@ def step(self) -> None:
         for name, role in self.roles.items():
             if role == "reviewer":
                 review_dict[name] = self.agents[name].review_paper(
-                    external_data=self.submission)
+                    paper=self.submission)
         self.submit_review(review_dict)
 
         # Decision Making
6 changes: 3 additions & 3 deletions research_town/envs/env_paper_submission.py
@@ -8,11 +8,11 @@ def __init__(self, agent_dict: Dict[str, str]) -> None:
         super(PaperSubmissionMultiAgentEnvironment, self).__init__(agent_dict)
 
     def step(self) -> None:
-        external_data = self.kb.get_data(10, "machine learning")
+        papers = self.kb.get_data(10, "machine learning")
         for agent_name, agent in self.agents.items():
-            agent.read_paper(external_data=external_data, domain="machine learning")
+            agent.read_paper(papers=papers, domain="machine learning")
             agent.find_collaborators({})
-            agent.generate_idea(external_data=external_data, domain="machine learning")
+            agent.generate_idea(papers=papers, domain="machine learning")
             agent.write_paper([], {})
 
         self.submit_paper()
2 changes: 1 addition & 1 deletion research_town/kbs/kb_base.py
@@ -1,6 +1,6 @@
 from typing import Dict, List
 
-from ..utils.paper_collection import get_daily_papers
+from ..utils.paper_collector import get_daily_papers
 
 
 class BaseKnowledgeBase(object):
research_town/utils/agent_prompter.py
@@ -1,15 +1,14 @@
 import os
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import openai
 
 from .decorator import exponential_backoff
-from .paper_collection import get_bert_embedding, neiborhood_search
+from .paper_collector import get_related_papers
 
 openai.api_base = "https://api.together.xyz"
 openai.api_key = os.environ["TOGETHER_API_KEY"]
 
 
 @exponential_backoff(retries=5, base_wait_time=1)
 def openai_prompting(
     llm_model: str,
@@ -28,23 +27,10 @@ def openai_prompting(
     return content_l
 
 
-def get_query_embedding(query: str) -> Any:
-    return get_bert_embedding([query])
-
-
-def find_nearest_neighbors(data_embeddings: List[Any], query_embedding: Any, num_neighbors: int) -> Any:
-    neighbors = neiborhood_search(
-        data_embeddings, query_embedding, num_neighbors)
-    neighbors = neighbors.reshape(-1)
-
-    return neighbors.tolist()
-
-
 def summarize_research_field_prompting(
     profile: Dict[str, str],
     keywords: List[str],
-    dataset: Dict[str, Any],
-    data_embedding: Dict[str, Any],
+    papers: Dict[str, Dict[str, List[str]]],
     llm_model: Optional[str] = "mistralai/Mixtral-8x7B-Instruct-v0.1",
 ) -> List[str]:
     """
@@ -55,22 +41,15 @@
         "Here is my profile: {profile}"
         "Here are the keywords: {keywords}"
     )
-
     template_input = {"profile": profile, "keywords": keywords}
     query = query_template.format_map(template_input)
 
-    query_embedding = get_query_embedding(query)
-
-    text_chunks = [abstract for papers in dataset.values()
-                   for abstract in papers["abstract"]]
-    data_embeddings = [embedding for embeddings in data_embedding.values()
-                       for embedding in embeddings]
+    corpus = [abstract for papers in papers.values()
+              for abstract in papers["abstract"]]
 
-    nearest_indices = find_nearest_neighbors(
-        data_embeddings, query_embedding, num_neighbors=10)
-    context = [text_chunks[i] for i in nearest_indices]
+    related_papers = get_related_papers(corpus, query, num=10)
 
-    template_input["papers"] = "; ".join(context)
+    template_input["papers"] = "; ".join(related_papers)
     prompt = query_template.format_map(template_input)
 
     return openai_prompting(llm_model, prompt)
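
For illustration (not from the commit): the papers argument that replaces dataset and data_embedding is a date-keyed dict whose values map field names to parallel lists. Only the "abstract" field is exercised in this diff; the keys and texts below are made up.

    from typing import Dict, List

    papers: Dict[str, Dict[str, List[str]]] = {
        "2024-05-21": {"abstract": ["A multi-agent simulator for research towns."]},
        "2024-05-22": {"abstract": ["Retrieval-augmented research idea generation."]},
    }

    # summarize_research_field_prompting flattens every abstract into one
    # corpus before retrieval, mirroring the comprehension in the hunk above:
    corpus = [abstract for day in papers.values() for abstract in day["abstract"]]
    assert corpus == [
        "A multi-agent simulator for research towns.",
        "Retrieval-augmented research idea generation.",
    ]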
File renamed without changes.
research_town/utils/paper_collector.py
@@ -6,6 +6,15 @@
 from transformers import BertModel, BertTokenizer
 
 
+def get_related_papers(
+    corpus: List[str], query: str, num: int
+) -> List[str]:
+    corpus_embedding = get_bert_embedding(corpus)
+    query_embedding = get_bert_embedding([query])
+    indices = neiborhood_search(corpus_embedding, query_embedding, num)
+    related_papers = [corpus[idx] for idx in indices[0].tolist()]
+    return related_papers
+
 def get_bert_embedding(instructions: List[str]) -> List[torch.Tensor]:
     tokenizer = BertTokenizer.from_pretrained("facebook/contriever")
     model = BertModel.from_pretrained("facebook/contriever").to(torch.device("cpu"))
@@ -39,9 +48,9 @@ def neiborhood_search(
     faiss.normalize_L2(xb)
     index.add(xb) # add vectors to the index
     data, index = index.search(xq, neiborhood_num)
-
     return index
-
+
+
 def get_daily_papers(
     topic: str, query: str = "slam", max_results: int = 2
 ) -> Tuple[Dict[str, Dict[str, List[str]]], str]:
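
A hypothetical smoke run of the new helper (not part of this commit); it downloads the facebook/contriever weights on first use and assumes torch, transformers, faiss, and the research_town package are installed.

    from research_town.utils.paper_collector import get_related_papers

    corpus = [
        "We study SLAM for indoor robots.",
        "Language-model agents that automate research workflows.",
        "Dense retrieval with contrastive sentence embeddings.",
    ]
    # Returns the num abstracts whose contriever embeddings are nearest the
    # query embedding (inner product over L2-normalized vectors via faiss).
    related = get_related_papers(corpus, query="LLM research agents", num=2)
    print(related)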