diff --git a/research_town/agents/agent_base.py b/research_town/agents/agent_base.py index 5d8a5181..40157ab1 100644 --- a/research_town/agents/agent_base.py +++ b/research_town/agents/agent_base.py @@ -16,6 +16,7 @@ from ..utils.author_relation import bfs from ..utils.paper_collection import get_bert_embedding +ATOM_NAMESPACE = "{http://www.w3.org/2005/Atom}" class BaseResearchAgent(object): def __init__(self, name: str) -> None: @@ -38,103 +39,104 @@ def get_profile(self, author_name: str) -> Dict[str, Any]: if response.status_code == 200: root = ElementTree.fromstring(response.content) - entries = root.findall("{http://www.w3.org/2005/Atom}entry") - - total_papers = 0 - papers_by_year: Dict[int, List[ElementTree.Element]] = {} - - for entry in entries: - title = self.find_text( - entry, "{http://www.w3.org/2005/Atom}title") - published = self.find_text( - entry, "{http://www.w3.org/2005/Atom}published" - ) - abstract = self.find_text( - entry, "{http://www.w3.org/2005/Atom}summary") - authors_elements = entry.findall( - "{http://www.w3.org/2005/Atom}author") - authors = [ - self.find_text(author, "{http://www.w3.org/2005/Atom}name") - for author in authors_elements - ] - link = self.find_text(entry, "{http://www.w3.org/2005/Atom}id") - - if author_name in authors: - coauthors = [ - author for author in authors if author != author_name] - coauthors_str = ", ".join(coauthors) - - papers_list.append( - { - "date": published, - "Title & Abstract": f"{title}; {abstract}", - "coauthors": coauthors_str, - "link": link, - } - ) - - total_papers += 1 - published_date = published - date_obj = datetime.datetime.strptime( - published_date, "%Y-%m-%dT%H:%M:%SZ" - ) - year = date_obj.year - if year not in papers_by_year: - papers_by_year[year] = [] - papers_by_year[year].append(entry) - - if total_papers > 40: - for cycle_start in range( - min(papers_by_year), max(papers_by_year) + 1, 5 - ): - cycle_end = cycle_start + 4 - for year in range(cycle_start, cycle_end + 1): - if year in papers_by_year: - selected_papers = papers_by_year[year][:2] - for paper in selected_papers: - title = self.find_text( - paper, "{http://www.w3.org/2005/Atom}title" - ) - abstract = self.find_text( - paper, "{http://www.w3.org/2005/Atom}summary" - ) - authors_elements = paper.findall( - "{http://www.w3.org/2005/Atom}author" - ) - co_authors = [ - self.find_text( - author, "{http://www.w3.org/2005/Atom}name" - ) - for author in authors_elements - if self.find_text( - author, "{http://www.w3.org/2005/Atom}name" - ) - != author_name - ] - - papers_list.append( - { - "Author": author_name, - "Title & Abstract": f"{title}; {abstract}", - "Date Period": f"{year}", - "Cycle": f"{cycle_start}-{cycle_end}", - "Co_author": ", ".join(co_authors), - } - ) + entries = root.findall(f"{ATOM_NAMESPACE}entry") + papers_list, papers_by_year = self._get_papers(entries, author_name) + if len(papers_list) > 40: + papers_list = self._select_papers(papers_by_year, author_name) + # Trim the list to the 10 most recent papers papers_list = papers_list[:10] personal_info = "; ".join( [f"{details['Title & Abstract']}" for details in papers_list] ) - info = summarize_research_direction_prompting(personal_info) - return {"name": author_name, "profile": info[0]} + profile_info = summarize_research_direction_prompting(personal_info) + return {"name": author_name, "profile": profile_info[0]} else: print("Failed to fetch data from arXiv.") return {"info": "fail!"} + def _get_papers(self, entries: List[ElementTree.Element], author_name: str) -> Tuple[List[Dict[str, Any]], Dict[int, List[ElementTree.Element]]]: + papers_list: List[Dict[str, Any]] = [] + papers_by_year: Dict[int, List[ElementTree.Element]] = {} + + for entry in entries: + title = self.find_text(entry, f"{ATOM_NAMESPACE}title") + published = self.find_text(entry, f"{ATOM_NAMESPACE}published") + abstract = self.find_text(entry, f"{ATOM_NAMESPACE}summary") + authors_elements = entry.findall(f"{ATOM_NAMESPACE}author") + authors = [ + self.find_text(author, f"{ATOM_NAMESPACE}name") + for author in authors_elements + ] + link = self.find_text(entry, f"{ATOM_NAMESPACE}id") + + if author_name in authors: + coauthors = [author for author in authors if author != author_name] + coauthors_str = ", ".join(coauthors) + + papers_list.append( + { + "date": published, + "Title & Abstract": f"{title}; {abstract}", + "coauthors": coauthors_str, + "link": link, + } + ) + + published_date = published + date_obj = datetime.datetime.strptime( + published_date, "%Y-%m-%dT%H:%M:%SZ" + ) + year = date_obj.year + if year not in papers_by_year: + papers_by_year[year] = [] + papers_by_year[year].append(entry) + + return papers_list, papers_by_year + + def _select_papers(self, papers_by_year: Dict[int, List[ElementTree.Element]], author_name: str) -> List[Dict[str, Any]]: + papers_list: List[Dict[str, Any]] = [] + + for cycle_start in range(min(papers_by_year), max(papers_by_year) + 1, 5): + cycle_end = cycle_start + 4 + for year in range(cycle_start, cycle_end + 1): + if year in papers_by_year: + selected_papers = papers_by_year[year][:2] + for paper in selected_papers: + title = self.find_text( + paper, f"{ATOM_NAMESPACE}title" + ) + abstract = self.find_text( + paper, f"{ATOM_NAMESPACE}summary" + ) + authors_elements = paper.findall( + f"{ATOM_NAMESPACE}author" + ) + co_authors = [ + self.find_text( + author, f"{ATOM_NAMESPACE}name" + ) + for author in authors_elements + if self.find_text( + author, f"{ATOM_NAMESPACE}name" + ) + != author_name + ] + + papers_list.append( + { + "Author": author_name, + "Title & Abstract": f"{title}; {abstract}", + "Date Period": f"{year}", + "Cycle": f"{cycle_start}-{cycle_end}", + "Co_author": ", ".join(co_authors), + } + ) + return papers_list + def communicate(self, message: Dict[str, str]) -> str: return communicate_with_multiple_researchers_prompting(message)[0]