diff --git a/pearl/utils/uci_data.py b/pearl/utils/uci_data.py index 03074218..bb315fbb 100644 --- a/pearl/utils/uci_data.py +++ b/pearl/utils/uci_data.py @@ -29,27 +29,27 @@ def download_uci_data(data_path: str) -> None: """ for dataset_name in uci_urls.keys(): - url = uci_urls[dataset_name]["url"] + "/" + uci_urls[dataset_name]["file_name"] - filename = os.path.join(data_path, uci_urls[dataset_name]["file_name"]) + unzipped_dataset_dirpath = os.path.join(data_path, dataset_name) + zip_filepath = os.path.join(data_path, uci_urls[dataset_name]["file_name"]) # Download the zip file + url = uci_urls[dataset_name]["url"] + "/" + uci_urls[dataset_name]["file_name"] response = requests_get(url) if response.status_code != 200: raise Exception(f"Failed to download {dataset_name} dataset from {url}.") # Locally save the zip file - with open(filename, "wb") as f: + with open(zip_filepath, "wb") as f: f.write(response.content) # Unzip the file - unzip_filepath = os.path.join(data_path, dataset_name) try: - with zipfile.ZipFile(filename, "r") as z: - z.extractall(unzip_filepath) + with zipfile.ZipFile(zip_filepath, "r") as z: + z.extractall(unzipped_dataset_dirpath) except zipfile.BadZipFile: raise zipfile.BadZipFile( - f"Bad zip file: {filename}. Please delete corrupt file and run again." + f"Bad zip file: {zip_filepath}. Please delete corrupt file and run again." ) # Delete the zip file - os.remove(filename) + os.remove(zip_filepath)