From 31742c7a050e3e172f8c0ac2528f5c58370579bc Mon Sep 17 00:00:00 2001 From: Haofei Yu <1125027232@qq.com> Date: Tue, 3 Dec 2024 05:17:27 -0600 Subject: [PATCH] fix the before_year mode --- .../create_bench_from_paper_links.py | 17 ++++--- research_bench/create_oodbench.sh | 2 +- research_bench/eval_only.py | 8 +++- research_bench/utils.py | 8 +++- research_town/dbs/db_profile.py | 3 +- research_town/utils/profile_collector.py | 47 ++++--------------- 6 files changed, 34 insertions(+), 51 deletions(-) diff --git a/research_bench/create_bench_from_paper_links.py b/research_bench/create_bench_from_paper_links.py index 9564b647..892d3c94 100644 --- a/research_bench/create_bench_from_paper_links.py +++ b/research_bench/create_bench_from_paper_links.py @@ -28,14 +28,14 @@ def get_arxiv_ids(input_file: str) -> List[str]: return arxiv_ids -def process_single_arxiv_id(arxiv_id: str, config: Config) -> Tuple[str, Any]: +def process_single_arxiv_id(arxiv_id: str, config: Config, with_year_limit: bool) -> Tuple[str, Any]: """Processes a single arXiv ID, handling any errors gracefully.""" try: paper_data = get_paper_data(arxiv_id) return arxiv_id, { 'paper_data': paper_data, 'author_data': get_author_data( - arxiv_id, paper_data['authors'], paper_data['title'], config + arxiv_id, paper_data['authors'], paper_data['title'], config, with_year_limit=with_year_limit, ), 'reference_proposal': get_proposal_from_paper( arxiv_id, paper_data['introduction'], config @@ -56,7 +56,7 @@ def save_benchmark_data(data: Dict[str, Any], output: str) -> None: def process_arxiv_ids( - arxiv_ids: List[str], output: str, config: Config, num_processes: int + arxiv_ids: List[str], output: str, config: Config, num_processes: int, with_year_limit: bool ) -> None: """Processes arXiv IDs using multiprocessing, saving results after each batch.""" arxiv_ids_chunks = [ @@ -69,14 +69,14 @@ def process_arxiv_ids( if num_processes == 1: # Single-process mode results = [ - process_single_arxiv_id(arxiv_id, config) for arxiv_id in chunk + process_single_arxiv_id(arxiv_id, config, with_year_limit) for arxiv_id in chunk ] else: # Multiprocessing mode with Pool(processes=num_processes) as pool: results = pool.starmap( process_single_arxiv_id, - [(arxiv_id, config) for arxiv_id in chunk], + [(arxiv_id, config, with_year_limit) for arxiv_id in chunk], ) # Filter out None results and save data @@ -101,6 +101,11 @@ def parse_args() -> argparse.Namespace: default=1, help='Number of processes to use. Set to 1 for single-process mode. Default is based on available CPU cores.', ) + parser.add_argument( + '--with_year_limit', + action='store_true', + help='Limit the number of papers to those published within the same year as the input paper.', + ) return parser.parse_args() @@ -108,7 +113,7 @@ def main() -> None: args = parse_args() arxiv_ids = get_arxiv_ids(args.input) config = Config('../configs') - process_arxiv_ids(arxiv_ids, args.output, config, args.num_processes) + process_arxiv_ids(arxiv_ids, args.output, config, args.num_processes, args.with_year_limit) if __name__ == '__main__': diff --git a/research_bench/create_oodbench.sh b/research_bench/create_oodbench.sh index 197cd43f..e8bf82e5 100644 --- a/research_bench/create_oodbench.sh +++ b/research_bench/create_oodbench.sh @@ -1 +1 @@ -python create_bench_from_paper_links.py --input ./oodbench/oodbench_paper_links.txt --output ./oodbench/oodbench_1202.json +python create_bench_from_paper_links.py --input ./oodbench/oodbench_paper_links.txt --output ./oodbench/oodbench_1202.json --with_year_limit diff --git a/research_bench/eval_only.py b/research_bench/eval_only.py index 3f48c631..c146ec7d 100644 --- a/research_bench/eval_only.py +++ b/research_bench/eval_only.py @@ -90,12 +90,16 @@ def plot_sorted_metrics(metric1, metric2): # file1_path = './results/paper_bench_result_4o_mini_fake_research_town.jsonl' # file2_path = './results/paper_bench_result_4o_mini_citation_only_part1.jsonl' - file1_path = './results/cross_bench_1202_result_4o_mini_fake_research_town.jsonl' - file2_path = './results/cross_bench_1202_result_4o_mini_fake_research_town_twice.jsonl' + #file1_path = './results/cross_bench_1202_result_4o_mini_fake_research_town.jsonl' + #file2_path = './results/cross_bench_1202_result_4o_mini_fake_research_town_twice.jsonl' #file1_path = './results/paper_bench_mid_500_result_4o_mini_fake_research_town.jsonl' #file2_path = './results/paper_bench_mid_500_result_4o_mini_fake_research_town_twice.jsonl' + file1_path = './results/paper_bench_hard_500_result_4o_mini_fake_research_town.jsonl' + file2_path = './results/paper_bench_hard_500_result_4o_mini_swarm.jsonl' + + print("Finding shared paper_ids...") shared_ids = get_shared_ids(file1_path, file2_path) print(f"Number of shared paper_ids: {len(shared_ids)}") diff --git a/research_bench/utils.py b/research_bench/utils.py index e5ca8aad..53f0cc46 100644 --- a/research_bench/utils.py +++ b/research_bench/utils.py @@ -142,11 +142,15 @@ def get_proposal_from_paper(arxiv_id: str, intro: str, config: Config) -> str: @with_cache(cache_dir='author_data') def get_author_data( - arxiv_id: str, authors: List[str], title: str, config: Config + arxiv_id: str, authors: List[str], title: str, config: Config, with_year_limit: bool = False ) -> Dict[str, Any]: + if with_year_limit: + before_year = int('20' + arxiv_id.split('.')[0][:2]) + else: + before_year = None profile_db = ProfileDB(config.database) profile_pks = profile_db.pull_profiles( - names=authors, config=config, known_paper_titles=[title] + names=authors, config=config, known_paper_titles=[title], before_year=before_year ) author_data = {} for pk in profile_pks: diff --git a/research_town/dbs/db_profile.py b/research_town/dbs/db_profile.py index 360a23a3..168c47d0 100644 --- a/research_town/dbs/db_profile.py +++ b/research_town/dbs/db_profile.py @@ -37,13 +37,14 @@ def pull_profiles( names: List[str], config: Config, known_paper_titles: Optional[List[str]] = None, + before_year: Optional[int] = None, ) -> List[str]: profiles: List[Profile] = [] for name in names: try: pub_abstracts, pub_titles, collaborators = ( collect_publications_and_coauthors( - name, paper_max_num=20, known_paper_titles=known_paper_titles + name, paper_max_num=20, known_paper_titles=known_paper_titles, before_year=before_year ) ) logger.info(f'Collected publications for {name}: {pub_titles}') diff --git a/research_town/utils/profile_collector.py b/research_town/utils/profile_collector.py index 71c8f2a6..ab1ccbac 100644 --- a/research_town/utils/profile_collector.py +++ b/research_town/utils/profile_collector.py @@ -33,7 +33,7 @@ def match_author_ids( search_results = semantic_client.search_author( author_name, fields=['authorId', 'papers.title'], - limit=100, + limit=250, ) author_ids = set() @@ -82,56 +82,25 @@ def get_papers_from_author_id( return papers[:paper_max_num] else: # Filter papers based on the year - filtered_papers = [ - paper for paper in papers if paper.get('year', datetime.now().year) < before_year - ] + filtered_papers = [] + for paper in papers: + if paper['year'] is None: + paper['year'] = 2024 + if paper['year'] < before_year: + filtered_papers.append(paper) return filtered_papers[:paper_max_num] -@api_calling_error_exponential_backoff(retries=6, base_wait_time=1) -def get_paper_publish_year(paper_title: str, author_id: str) -> int: - url = "https://api.semanticscholar.org/graph/v1/paper/search" - params = { - "query": paper_title, - "fields": "title,year,authors", - "author_ids": author_id # Include the author ID as a parameter - } - response = requests.get(url, params=params) - - if response.status_code == 200: - data = response.json() - if data.get("data"): - # Find the first matching paper by the given author - for paper in data["data"]: - if any(author.get("authorId") == author_id for author in paper.get("authors", [])): - return paper.get("year", None) - print("No matching papers found for the given author.") - return None - else: - print("No matching papers found.") - return None - else: - print(f"Error: {response.status_code}, {response.text}") - return None - def collect_publications_and_coauthors( author: str, known_paper_titles: Optional[List[str]] = None, paper_max_num: int = 20, exclude_known: bool = True, + before_year: Optional[int] = None, ) -> Tuple[List[str], List[str], List[str]]: matched_author_ids = match_author_ids(author, known_paper_titles) author_id = matched_author_ids.pop() # Only one author ID is expected - before_year = None - if known_paper_titles is not None and len(known_paper_titles) > 0: - years = [get_paper_publish_year(paper_title=title, author_id=author_id) for title in known_paper_titles] - years = [year for year in years if year is not None] - if len(years) > 0: - before_year = min(year for year in years if year is not None) - else: - before_year = None - papers = get_papers_from_author_id(author_id, paper_max_num, before_year=before_year) paper_abstracts = [] paper_titles = []