Added ci-fix benchmark
Showing 7 changed files with 1,892 additions and 0 deletions.
@@ -0,0 +1,170 @@
from omegaconf import OmegaConf
import os
import pandas as pd
from datasets import load_dataset
import time
import json
from tqdm import tqdm

from benhmark_functions import process_datapoint, get_results
from benchmark_utils import read_jsonl, save_jsonl


def filter_files(directory, files):
    return [file for file in files if file != "meta_info.json"]


class CIFixBenchmark:
    def __init__(self, model_name, config_path, token_gh):
        # languages = ["Python", "Kotlin", "Rust", "C++", "Java"]
        benchmark_owner = "LCA-CI-fix-benchmark"
        self.config = OmegaConf.load(config_path)
        language = self.config.language
        self.credentials = {"username": self.config.username, "token": token_gh, "model": model_name}
        # TODO parents=True (??)
        os.makedirs(self.config.out_folder, exist_ok=True)
        os.makedirs(self.config.repos_folder, exist_ok=True)
        self.dataset_id = "JetBrains-Research/lca-ci-fixing"
        OmegaConf.update(self.config, "benchmark_owner", benchmark_owner, force_add=True)
        if hasattr(self.config, "data_cache_dir"):
            self.cache_dir = self.config.data_cache_dir
        else:
            self.cache_dir = None
        self.model_name = model_name

    def get_dataset(self, hf_token=None, num_dp=None, force_download=False, dataset_folder=None):
        # TODO remove hf_token when dataset becomes public
        if dataset_folder is not None:
            self.dataset = load_dataset(path=dataset_folder)["train"]
            return self.dataset
        download_mode = "force_redownload" if force_download else None
        self.dataset = load_dataset(
            self.dataset_id, token=hf_token, cache_dir=self.cache_dir, download_mode=download_mode, split="test"
        )
        if num_dp is not None:
            self.dataset = self.dataset.select(range(num_dp))
        return self.dataset

    # TODO remove test_dataset argument after debugging
    def run_dataset(self, fix_repo_function, test_dataset=None):
        if test_dataset is None:
            test_dataset = self.dataset
        self.jobs_ids = []
        jobs_ids_file_path = os.path.join(self.config.out_folder, f"jobs_ids_{self.model_name}.jsonl")
        with open(jobs_ids_file_path, "w") as writer:
            for datapoint in tqdm(test_dataset):
                job_identificator = process_datapoint(datapoint, fix_repo_function, self.config, self.credentials)
                self.jobs_ids.append(job_identificator)
                json.dump(job_identificator, writer)
                writer.write("\n")
        return self.jobs_ids

    # TODO remove jobs_ids argument after debugging
    def eval_jobs(self, jobs_ids=None, job_ids_file=None, result_filename=None):
        if result_filename is None:
            result_filename = f"jobs_results_{self.model_name}.jsonl"
        jobs_results_file_path = os.path.join(self.config.out_folder, result_filename)
        jobs_awaiting_file_path = os.path.join(self.config.out_folder, f"jobs_awaiting_{self.model_name}.jsonl")
        jobs_invalid_file_path = os.path.join(self.config.out_folder, f"jobs_invalid_{self.model_name}.jsonl")
        result_file = open(jobs_results_file_path, "w")
        if job_ids_file is not None:
            jobs_ids = read_jsonl(job_ids_file)
        elif jobs_ids is None:
            jobs_ids = self.jobs_ids
        jobs_ids_await = jobs_ids
        n_attempts = 0
        jobs_results = []
        jobs_ids_invalid = []
        # CI jobs take a while to finish, so poll for results with pauses in between.
        # TODO discuss number of attempts and waiting time
        while len(jobs_ids_await) > 0 and n_attempts < 12:
            jobs_ids_await_new = []
            for job_id in jobs_ids_await:
                job_url, conclusion = get_results(job_id, self.config, self.credentials)
                if conclusion == "waiting":
                    jobs_ids_await_new.append(job_id)
                elif conclusion == "error":
                    jobs_ids_invalid.append(job_id)
                else:
                    job_id["url"] = job_url
                    job_id["conclusion"] = conclusion
                    jobs_results.append(job_id)
                    json.dump(job_id, result_file)
                    result_file.write("\n")

            jobs_ids_await = jobs_ids_await_new
            if len(jobs_ids_await) != 0:
                result_file.close()
                save_jsonl(jobs_awaiting_file_path, jobs_ids_await)
                save_jsonl(jobs_invalid_file_path, jobs_ids_invalid)
                print(f"Waiting 300 s before the next evaluation poll. {len(jobs_ids_await)} jobs in the waiting list.")
                time.sleep(300)
                result_file = open(jobs_results_file_path, "a")

            n_attempts += 1

        result_file.close()
        print("Results received")
        print(f"{len(jobs_results)} jobs in results.")
        print(f"{len(jobs_ids_await)} jobs left in the waiting list.")
        print(f"{len(jobs_ids_invalid)} jobs are invalid.")
        self.jobs_results = jobs_results
        return jobs_results

    def analyze_results(self, jobs_results=None, jobs_results_file=None):
        if jobs_results_file is not None:
            jobs_results = read_jsonl(jobs_results_file)
        elif jobs_results is None:
            jobs_results = self.jobs_results

        results_df = pd.DataFrame(jobs_results)
        total_counts = results_df["conclusion"].value_counts()
        total_ratio = total_counts / len(results_df)
        difficulty_counts = results_df.groupby("difficulty")["conclusion"].value_counts().unstack().fillna(0)
        difficulty_ratios = difficulty_counts.div(difficulty_counts.sum(axis=1), axis=0)

        print("Overall results")
        print(total_counts)
        print("Overall results as ratios")
        print(total_ratio)
        print("Results aggregated by difficulty")
        print(difficulty_counts)
        print("Result ratios aggregated by difficulty")
        print(difficulty_ratios)

    def eval_dataset(
        self,
        fix_repo_function,
        hf_token=None,
        num_dp=None,
        force_download=False,
        result_filename=None,
        dataset_folder=None,
    ):
        print("---------------- Downloading data -------------------")
        self.get_dataset(hf_token, num_dp=num_dp, force_download=force_download, dataset_folder=dataset_folder)
        print(f"Got {len(self.dataset)} datapoints")
        print("---------------- Running datapoints -------------------")
        self.run_dataset(fix_repo_function)
        print("---------------- Getting results -------------------")
        self.eval_jobs(result_filename=result_filename)
        self.analyze_results()

    def run_datapoint(self, datapoint, fix_repo_function):
        # This method exists for debugging purposes
        jobs_ids_file_path = os.path.join(self.config.out_folder, f"jobs_ids_{self.model_name}.jsonl")
        with open(jobs_ids_file_path, "w") as writer:
            job_identificator = process_datapoint(datapoint, fix_repo_function, self.config, self.credentials)
            json.dump(job_identificator, writer)
            writer.write("\n")
        return job_identificator

    def eval_datapoint(self, job_identificator):
        # This method exists for debugging purposes
        pass
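
For reference, a minimal usage sketch (not part of this commit): fix_none is a hypothetical baseline that leaves the repository untouched; the exact signature expected of fix_repo_function is defined in benhmark_functions.py, which is not shown in this diff, and the config/token file names are assumptions.

# Minimal usage sketch; file names and the fix_none signature are assumptions.
from benchmark_utils import get_token_gh
# CIFixBenchmark is defined in the file above; its module name is not shown in this diff.

def fix_none(datapoint, *args, **kwargs):
    # Hypothetical baseline: change nothing and let CI run on the original state.
    return None

token_gh = get_token_gh("config_private.yaml")  # assumed private config path
bench = CIFixBenchmark(model_name="none", config_path="config.yaml", token_gh=token_gh)
bench.eval_dataset(fix_none, hf_token=None, num_dp=5)  # small run as a smoke test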
@@ -0,0 +1,6 @@
repos_folder: /mnt/data/shared-data/lca/CI-fix-benchmark/repos # cloned repos are stored here
out_folder: /mnt/data/galimzyanov/data/LCA/benchmark/out # result files are stored here
data_cache_dir: /mnt/data/galimzyanov/data/LCA/temp # the dataset cache is stored here
username: timur-for-test # your GitHub username
test_username: test_user # username that will be displayed in the benchmark
language: Python # dataset language (currently only Python is available)
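
Note that benchmark_owner does not appear in this file: CIFixBenchmark.__init__ injects it at runtime via OmegaConf.update(..., force_add=True). A small sketch of that mechanism, assuming the file above is saved as config.yaml:

from omegaconf import OmegaConf

config = OmegaConf.load("config.yaml")  # the config file shown above
# force_add=True creates the key even though it is absent from the YAML
OmegaConf.update(config, "benchmark_owner", "LCA-CI-fix-benchmark", force_add=True)
print(config.benchmark_owner)  # -> LCA-CI-fix-benchmark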
@@ -0,0 +1,57 @@
from omegaconf import OmegaConf
import json
import shutil
import os


def read_jsonl(file_path):
    data = []
    with open(file_path, "r") as f:
        for line in f:
            data.append(json.loads(line))
    return data


def save_jsonl(file_path, data):
    with open(file_path, "w") as f:
        for entry in data:
            json.dump(entry, f)
            f.write("\n")
def get_token_gh(config_path):
    config_private = OmegaConf.load(config_path)
    with open(config_private.token_gh_path) as f:
        token_gh = f.read().strip()  # strip the trailing newline editors usually add
    return token_gh


def get_token_hf(config_path):
    config_private = OmegaConf.load(config_path)
    token_hf = get_token(config_private.token_hf_path)
    return token_hf


def get_token(token_path):
    with open(token_path) as f:
        token = f.read().strip()  # strip the trailing newline editors usually add
    return token
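
# These token helpers expect a separate private config that points to
# plain-text token files. A hypothetical example (file name and paths are
# assumptions, not part of this commit):
#
#   # config_private.yaml
#   token_gh_path: /home/user/tokens/gh_token.txt
#   token_hf_path: /home/user/tokens/hf_token.txt
#
# Usage sketch:
#   token_gh = get_token_gh("config_private.yaml")
#   token_hf = get_token_hf("config_private.yaml")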

def filter_out_res(data_folder, out_folder):
    """
    Filter datapoints according to the benchmark results.

    Keeps only datapoints whose original repo state fails CI (the "none"
    baseline) and whose reference fix passes CI (the "diff" baseline).
    """
    results_none_path = os.path.join(out_folder, "jobs_results_none.jsonl")
    results_diff_path = os.path.join(out_folder, "jobs_results_diff.jsonl")
    results_none = read_jsonl(results_none_path)
    results_diff = read_jsonl(results_diff_path)
    orig_path = os.path.join(data_folder, "datapoints_json_verified")
    filtered_path = os.path.join(data_folder, "datapoints_json_filtered")
    os.makedirs(filtered_path, exist_ok=True)
    original_sha = {result["sha_original"][:7] for result in results_none if result["conclusion"] == "failure"}
    fixed_sha = {result["sha_original"][:7] for result in results_diff if result["conclusion"] == "success"}

    sha_valid = original_sha.intersection(fixed_sha)

    for sha in sha_valid:
        dp_file = os.path.join(orig_path, f"{sha}.json")
        dp_filtered = os.path.join(filtered_path, f"{sha}.json")
        shutil.copy2(dp_file, dp_filtered)
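
To regenerate the filtered datapoints, a hypothetical invocation, guarded so it only runs when the module is executed directly (the data folder path is an assumption; the out folder mirrors the sample config above):

if __name__ == "__main__":
    filter_out_res(
        "/mnt/data/shared-data/lca/CI-fix-benchmark/data",  # assumed layout containing datapoints_json_verified/
        "/mnt/data/galimzyanov/data/LCA/benchmark/out",
    )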