Skip to content

Commit

Permalink
fix the before_year mode
Browse files Browse the repository at this point in the history
  • Loading branch information
lwaekfjlk committed Dec 3, 2024
1 parent 530f394 commit 31742c7
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 51 deletions.
17 changes: 11 additions & 6 deletions research_bench/create_bench_from_paper_links.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ def get_arxiv_ids(input_file: str) -> List[str]:
return arxiv_ids


def process_single_arxiv_id(arxiv_id: str, config: Config) -> Tuple[str, Any]:
def process_single_arxiv_id(arxiv_id: str, config: Config, with_year_limit: bool) -> Tuple[str, Any]:
"""Processes a single arXiv ID, handling any errors gracefully."""
try:
paper_data = get_paper_data(arxiv_id)
return arxiv_id, {
'paper_data': paper_data,
'author_data': get_author_data(
arxiv_id, paper_data['authors'], paper_data['title'], config
arxiv_id, paper_data['authors'], paper_data['title'], config, with_year_limit=with_year_limit,
),
'reference_proposal': get_proposal_from_paper(
arxiv_id, paper_data['introduction'], config
Expand All @@ -56,7 +56,7 @@ def save_benchmark_data(data: Dict[str, Any], output: str) -> None:


def process_arxiv_ids(
arxiv_ids: List[str], output: str, config: Config, num_processes: int
arxiv_ids: List[str], output: str, config: Config, num_processes: int, with_year_limit: bool
) -> None:
"""Processes arXiv IDs using multiprocessing, saving results after each batch."""
arxiv_ids_chunks = [
Expand All @@ -69,14 +69,14 @@ def process_arxiv_ids(
if num_processes == 1:
# Single-process mode
results = [
process_single_arxiv_id(arxiv_id, config) for arxiv_id in chunk
process_single_arxiv_id(arxiv_id, config, with_year_limit) for arxiv_id in chunk
]
else:
# Multiprocessing mode
with Pool(processes=num_processes) as pool:
results = pool.starmap(
process_single_arxiv_id,
[(arxiv_id, config) for arxiv_id in chunk],
[(arxiv_id, config, with_year_limit) for arxiv_id in chunk],
)

# Filter out None results and save data
Expand All @@ -101,14 +101,19 @@ def parse_args() -> argparse.Namespace:
default=1,
help='Number of processes to use. Set to 1 for single-process mode. Default is based on available CPU cores.',
)
parser.add_argument(
'--with_year_limit',
action='store_true',
help='Limit the number of papers to those published within the same year as the input paper.',
)
return parser.parse_args()


def main() -> None:
    """Entry point: build benchmark data for the arXiv IDs listed in the input file.

    Reads CLI arguments, loads the project config, and processes all IDs,
    forwarding the ``--with_year_limit`` flag so author publication lists can
    be restricted to papers published before the input paper's year.
    """
    args = parse_args()
    arxiv_ids = get_arxiv_ids(args.input)
    config = Config('../configs')
    # Single call with the new with_year_limit argument (the pre-commit
    # four-argument call was a stale diff artifact and is dropped).
    process_arxiv_ids(
        arxiv_ids, args.output, config, args.num_processes, args.with_year_limit
    )


if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion research_bench/create_oodbench.sh
Original file line number Diff line number Diff line change
@@ -1 +1 @@
python create_bench_from_paper_links.py --input ./oodbench/oodbench_paper_links.txt --output ./oodbench/oodbench_1202.json
python create_bench_from_paper_links.py --input ./oodbench/oodbench_paper_links.txt --output ./oodbench/oodbench_1202.json --with_year_limit
8 changes: 6 additions & 2 deletions research_bench/eval_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,16 @@ def plot_sorted_metrics(metric1, metric2):
# file1_path = './results/paper_bench_result_4o_mini_fake_research_town.jsonl'
# file2_path = './results/paper_bench_result_4o_mini_citation_only_part1.jsonl'

file1_path = './results/cross_bench_1202_result_4o_mini_fake_research_town.jsonl'
file2_path = './results/cross_bench_1202_result_4o_mini_fake_research_town_twice.jsonl'
#file1_path = './results/cross_bench_1202_result_4o_mini_fake_research_town.jsonl'
#file2_path = './results/cross_bench_1202_result_4o_mini_fake_research_town_twice.jsonl'

#file1_path = './results/paper_bench_mid_500_result_4o_mini_fake_research_town.jsonl'
#file2_path = './results/paper_bench_mid_500_result_4o_mini_fake_research_town_twice.jsonl'

file1_path = './results/paper_bench_hard_500_result_4o_mini_fake_research_town.jsonl'
file2_path = './results/paper_bench_hard_500_result_4o_mini_swarm.jsonl'


print("Finding shared paper_ids...")
shared_ids = get_shared_ids(file1_path, file2_path)
print(f"Number of shared paper_ids: {len(shared_ids)}")
Expand Down
8 changes: 6 additions & 2 deletions research_bench/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,15 @@ def get_proposal_from_paper(arxiv_id: str, intro: str, config: Config) -> str:

@with_cache(cache_dir='author_data')
def get_author_data(
arxiv_id: str, authors: List[str], title: str, config: Config
arxiv_id: str, authors: List[str], title: str, config: Config, with_year_limit: bool = False
) -> Dict[str, Any]:
if with_year_limit:
before_year = int('20' + arxiv_id.split('.')[0][:2])
else:
before_year = None
profile_db = ProfileDB(config.database)
profile_pks = profile_db.pull_profiles(
names=authors, config=config, known_paper_titles=[title]
names=authors, config=config, known_paper_titles=[title], before_year=before_year
)
author_data = {}
for pk in profile_pks:
Expand Down
3 changes: 2 additions & 1 deletion research_town/dbs/db_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,14 @@ def pull_profiles(
names: List[str],
config: Config,
known_paper_titles: Optional[List[str]] = None,
before_year: Optional[int] = None,
) -> List[str]:
profiles: List[Profile] = []
for name in names:
try:
pub_abstracts, pub_titles, collaborators = (
collect_publications_and_coauthors(
name, paper_max_num=20, known_paper_titles=known_paper_titles
name, paper_max_num=20, known_paper_titles=known_paper_titles, before_year=before_year
)
)
logger.info(f'Collected publications for {name}: {pub_titles}')
Expand Down
47 changes: 8 additions & 39 deletions research_town/utils/profile_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def match_author_ids(
search_results = semantic_client.search_author(
author_name,
fields=['authorId', 'papers.title'],
limit=100,
limit=250,
)

author_ids = set()
Expand Down Expand Up @@ -82,56 +82,25 @@ def get_papers_from_author_id(
return papers[:paper_max_num]
else:
# Filter papers based on the year
filtered_papers = [
paper for paper in papers if paper.get('year', datetime.now().year) < before_year
]
filtered_papers = []
for paper in papers:
if paper['year'] is None:
paper['year'] = 2024
if paper['year'] < before_year:
filtered_papers.append(paper)
return filtered_papers[:paper_max_num]


@api_calling_error_exponential_backoff(retries=6, base_wait_time=1)
def get_paper_publish_year(paper_title: str, author_id: str) -> Optional[int]:
    """Look up a paper's publication year on Semantic Scholar.

    Queries the paper-search endpoint by title and returns the year of the
    first result whose author list contains ``author_id``.

    Args:
        paper_title: Title to search for.
        author_id: Semantic Scholar author ID that must appear on the paper.

    Returns:
        The publication year, or ``None`` when no paper matches, the matched
        paper has no year, or the request returns a non-200 status.
        (Original annotation said ``int`` but three paths return ``None``.)
    """
    url = "https://api.semanticscholar.org/graph/v1/paper/search"
    params = {
        "query": paper_title,
        "fields": "title,year,authors",
        # NOTE(review): 'author_ids' is sent as a query parameter — confirm the
        # paper-search endpoint honors it; the author match below is done
        # client-side regardless, so results stay correct either way.
        "author_ids": author_id,
    }
    response = requests.get(url, params=params)

    if response.status_code != 200:
        print(f"Error: {response.status_code}, {response.text}")
        return None

    data = response.json()
    if not data.get("data"):
        print("No matching papers found.")
        return None

    # Return the first hit that actually lists the target author.
    for paper in data["data"]:
        authors = paper.get("authors", [])
        if any(author.get("authorId") == author_id for author in authors):
            return paper.get("year")

    print("No matching papers found for the given author.")
    return None

def collect_publications_and_coauthors(
author: str,
known_paper_titles: Optional[List[str]] = None,
paper_max_num: int = 20,
exclude_known: bool = True,
before_year: Optional[int] = None,
) -> Tuple[List[str], List[str], List[str]]:
matched_author_ids = match_author_ids(author, known_paper_titles)
author_id = matched_author_ids.pop() # Only one author ID is expected

before_year = None
if known_paper_titles is not None and len(known_paper_titles) > 0:
years = [get_paper_publish_year(paper_title=title, author_id=author_id) for title in known_paper_titles]
years = [year for year in years if year is not None]
if len(years) > 0:
before_year = min(year for year in years if year is not None)
else:
before_year = None

papers = get_papers_from_author_id(author_id, paper_max_num, before_year=before_year)
paper_abstracts = []
paper_titles = []
Expand Down

0 comments on commit 31742c7

Please sign in to comment.