diff --git a/pearl/utils/scripts/cb_benchmark/run_cb_benchmarks.py b/pearl/utils/scripts/cb_benchmark/run_cb_benchmarks.py index 74f11a4a..de9c1207 100644 --- a/pearl/utils/scripts/cb_benchmark/run_cb_benchmarks.py +++ b/pearl/utils/scripts/cb_benchmark/run_cb_benchmarks.py @@ -271,18 +271,20 @@ def run_cb_benchmarks( run_config: dictionary with config files of the run parameters. """ - # Download uci datasets if dont exist + # Create UCI data directory if it does not already exist uci_data_path = "./utils/instantiations/environments/uci_datasets" if not os.path.exists(uci_data_path): os.makedirs(uci_data_path) - download_uci_data(data_path=uci_data_path) - # Path to save results + # Download UCI data + download_uci_data(data_path=uci_data_path) + + # Create folder for result if it does not already exist save_results_path: str = "./utils/scripts/cb_benchmark/experiments_results" if not os.path.exists(save_results_path): os.makedirs(save_results_path) - # run all CB algorithms on all benchmarks + # Run all CB algorithms on all benchmarks for algorithm in cb_algorithms_config.keys(): for dataset_name in test_environments_config.keys(): env = SLCBEnvironment(**test_environments_config[dataset_name]) diff --git a/pearl/utils/uci_data.py b/pearl/utils/uci_data.py index 0c8bd6ed..03074218 100644 --- a/pearl/utils/uci_data.py +++ b/pearl/utils/uci_data.py @@ -29,19 +29,27 @@ def download_uci_data(data_path: str) -> None: """ for dataset_name in uci_urls.keys(): - url = os.path.join( - uci_urls[dataset_name]["url"], uci_urls[dataset_name]["file_name"] - ) + url = uci_urls[dataset_name]["url"] + "/" + uci_urls[dataset_name]["file_name"] filename = os.path.join(data_path, uci_urls[dataset_name]["file_name"]) # Download the zip file + response = requests_get(url) + if response.status_code != 200: + raise Exception(f"Failed to download {dataset_name} dataset from {url}.") + + # Locally save the zip file with open(filename, "wb") as f: - f.write(requests_get(url).content) + f.write(response.content) # Unzip the file unzip_filepath = os.path.join(data_path, dataset_name) - with zipfile.ZipFile(filename, "r") as z: - z.extractall(unzip_filepath) + try: + with zipfile.ZipFile(filename, "r") as z: + z.extractall(unzip_filepath) + except zipfile.BadZipFile: + raise zipfile.BadZipFile( + f"Bad zip file: {filename}. Please delete corrupt file and run again." + ) # Delete the zip file os.remove(filename)