From da8b860a9a270b2439fa220e5e090e477c278e83 Mon Sep 17 00:00:00 2001 From: Andrea Date: Fri, 25 Nov 2022 11:55:47 +0100 Subject: [PATCH 1/3] fix/feat: Improved loading function using pandas only instead of StringIO --- folktables/load_acs.py | 65 ++++++++++++++---------------------------- 1 file changed, 21 insertions(+), 44 deletions(-) diff --git a/folktables/load_acs.py b/folktables/load_acs.py index dee48d4..e353a51 100644 --- a/folktables/load_acs.py +++ b/folktables/load_acs.py @@ -8,14 +8,12 @@ import numpy as np import pandas as pd - state_list = ['AL', 'AK', 'AZ', 'AR', 'CA', 'CO', 'CT', 'DE', 'FL', 'GA', 'HI', 'ID', 'IL', 'IN', 'IA', 'KS', 'KY', 'LA', 'ME', 'MD', 'MA', 'MI', 'MN', 'MS', 'MO', 'MT', 'NE', 'NV', 'NH', 'NJ', 'NM', 'NY', 'NC', 'ND', 'OH', 'OK', 'OR', 'PA', 'RI', 'SC', 'SD', 'TN', 'TX', 'UT', 'VT', 'VA', 'WA', 'WV', 'WI', 'WY', 'PR'] - _STATE_CODES = {'AL': '01', 'AK': '02', 'AZ': '04', 'AR': '05', 'CA': '06', 'CO': '08', 'CT': '09', 'DE': '10', 'FL': '12', 'GA': '13', 'HI': '15', 'ID': '16', 'IL': '17', 'IN': '18', 'IA': '19', @@ -35,10 +33,10 @@ def download_and_extract(url, datadir, remote_fname, file_name, delete_download= response = requests.get(url) with open(download_path, 'wb') as handle: handle.write(response.content) - + with zipfile.ZipFile(download_path, 'r') as zip_ref: zip_ref.extract(file_name, path=datadir) - + if delete_download and download_path != os.path.join(datadir, file_name): os.remove(download_path) @@ -57,24 +55,26 @@ def initialize_and_download(datadir, state, year, horizon, survey, download=Fals else: # 2016 and earlier use different file names file_name = f'ss{str(year)[-2:]}{survey_code}{state.lower()}.csv' - + # Assume is the path exists and is a file, then it has been downloaded # correctly file_path = os.path.join(datadir, file_name) if os.path.isfile(file_path): return file_path if not download: - raise FileNotFoundError(f'Could not find {year} {horizon} {survey} survey data for {state} in {datadir}. 
Call get_data with download=True to download the dataset.') - + raise FileNotFoundError( + f'Could not find {year} {horizon} {survey} survey data for {state} in {datadir}. Call get_data with download=True to download the dataset.') + print(f'Downloading data for {year} {horizon} {survey} survey for {state}...') # Download and extract file - base_url= f'https://www2.census.gov/programs-surveys/acs/data/pums/{year}/{horizon}' + base_url = f'https://www2.census.gov/programs-surveys/acs/data/pums/{year}/{horizon}' remote_fname = f'csv_{survey_code}{state.lower()}.zip' url = f'{base_url}/{remote_fname}' try: download_and_extract(url, datadir, remote_fname, file_name, delete_download=True) except Exception as e: - print(f'\n{os.path.join(datadir, remote_fname)} may be corrupted. Please try deleting it and rerunning this command.\n') + print( + f'\n{os.path.join(datadir, remote_fname)} may be corrupted. Please try deleting it and rerunning this command.\n') print(f'Exception: ', e) return file_path @@ -99,50 +99,27 @@ def load_acs(root_dir, states=None, year=2018, horizon='1-Year', if states is None: states = state_list - + random.seed(random_seed) - + base_datadir = os.path.join(root_dir, str(year), horizon) os.makedirs(base_datadir, exist_ok=True) - + file_names = [] for state in states: file_names.append( initialize_and_download(base_datadir, state, year, horizon, survey, download=download) ) - sample = io.StringIO() - - first = True - + dtypes = {'PINCP': np.float64, 'RT': str, 'SOCP': str, 'SERIALNO': str, 'NAICSP': str} + df_list = [] for file_name in file_names: - - with open(file_name, 'r') as f: - - if first: - sample.write(next(f)) - first = False - else: - next(f) - - if serial_filter_list is None: - for line in f: - if random.uniform(0, 1) < density: - # strip whitespace found in some early files - sample.write(line.replace(' ','')) - else: - for line in f: - serialno = line.split(',')[1] - if serialno in serial_filter_list: - # strip whitespace found in some 
early files - sample.write(line.replace(' ','')) - - - sample.seek(0) - - dtypes = {'PINCP' : np.float64, 'RT' : str, 'SOCP' : str, 'SERIALNO' : str, 'NAICSP' : str} - - return pd.read_csv(sample, dtype=dtypes) + df = pd.read_csv(file_name, dtype=dtypes).replace(' ','') + if serial_filter_list is not None: + df = df[df['SERIALNO'].isin(serial_filter_list)] + df_list.append(df) + all_df = pd.concat(df_list) + return all_df def load_definitions(root_dir, year=2018, horizon='1-Year', download=False): @@ -214,4 +191,4 @@ def generate_categories(features, definition_df): del mapping_dict[-99999999999999.0] categories[feature] = mapping_dict - return categories \ No newline at end of file + return categories From 8f99ce381a5daae387daac0c234f95a52a7d0c52 Mon Sep 17 00:00:00 2001 From: Andrea Baraldi Date: Thu, 2 Feb 2023 15:12:35 +0100 Subject: [PATCH 2/3] add: mobility filter to speed up filter change pandas loading engine='c' --- folktables/acs.py | 12 +++++++++++- folktables/load_acs.py | 2 +- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/folktables/acs.py b/folktables/acs.py index d5d142b..f6d6d03 100644 --- a/folktables/acs.py +++ b/folktables/acs.py @@ -236,6 +236,16 @@ def travel_time_filter(data): postprocess=lambda x: np.nan_to_num(x, -1), ) + +def mobility_filter(data): + """ + Filters for the employment prediction task + """ + df = data + df = df[df['AGEP'] > 18] + df = df[df['AGEP'] < 35] + return df + ACSMobility = folktables.BasicProblem( features=[ 'AGEP', @@ -263,7 +273,7 @@ def travel_time_filter(data): target="MIG", target_transform=lambda x: x == 1, group='RAC1P', - preprocess=lambda x: x.drop(x.loc[(x['AGEP'] <= 18) | (x['AGEP'] >= 35)].index), + preprocess=mobility_filter, postprocess=lambda x: np.nan_to_num(x, -1), ) diff --git a/folktables/load_acs.py b/folktables/load_acs.py index e353a51..acd609a 100644 --- a/folktables/load_acs.py +++ b/folktables/load_acs.py @@ -114,7 +114,7 @@ def load_acs(root_dir, states=None, year=2018, 
horizon='1-Year', dtypes = {'PINCP': np.float64, 'RT': str, 'SOCP': str, 'SERIALNO': str, 'NAICSP': str} df_list = [] for file_name in file_names: - df = pd.read_csv(file_name, dtype=dtypes).replace(' ','') + df = pd.read_csv(file_name, dtype=dtypes, engine="c").replace(' ','') if serial_filter_list is not None: df = df[df['SERIALNO'].isin(serial_filter_list)] df_list.append(df) From 97bfde320560a85813a5087f585fcb6e6e7ffd05 Mon Sep 17 00:00:00 2001 From: Andrea Date: Mon, 15 Apr 2024 15:23:08 +0200 Subject: [PATCH 3/3] fix: refactor fillna lambda function with a function that also applies pandas fillna to solve multiple NaN encodings; create the directory for cached data when not available. --- folktables/acs.py | 19 +++++++++++-------- folktables/load_acs.py | 1 + 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/folktables/acs.py b/folktables/acs.py index f6d6d03..dedd209 100644 --- a/folktables/acs.py +++ b/folktables/acs.py @@ -65,6 +65,9 @@ def get_definitions(self, download=False): return load_definitions(root_dir=self._root_dir, year=self._survey_year, horizon=self._horizon, download=download) +def fillna_safe(x, value=-1): + x = np.nan_to_num(x, value) + return pd.DataFrame(x).fillna(value=value).values def adult_filter(data): """Mimic the filters in place for Adult data. 
@@ -98,7 +101,7 @@ def adult_filter(data): target_transform=lambda x: x > 50000, group='RAC1P', preprocess=adult_filter, - postprocess=lambda x: np.nan_to_num(x, -1), + postprocess=fillna_safe, ) ACSEmployment = folktables.BasicProblem( @@ -124,7 +127,7 @@ def adult_filter(data): target_transform=lambda x: x == 1, group='RAC1P', preprocess=lambda x: x, - postprocess=lambda x: np.nan_to_num(x, -1), + postprocess=fillna_safe, ) ACSHealthInsurance = folktables.BasicProblem( @@ -159,7 +162,7 @@ def adult_filter(data): target_transform=lambda x: x == 1, group='RAC1P', preprocess=lambda x: x, - postprocess=lambda x: np.nan_to_num(x, -1), + postprocess=fillna_safe, ) def public_coverage_filter(data): @@ -197,7 +200,7 @@ def public_coverage_filter(data): target_transform=lambda x: x == 1, group='RAC1P', preprocess=public_coverage_filter, - postprocess=lambda x: np.nan_to_num(x, -1), + postprocess=fillna_safe, ) def travel_time_filter(data): @@ -233,7 +236,7 @@ def travel_time_filter(data): target_transform=lambda x: x > 20, group='RAC1P', preprocess=travel_time_filter, - postprocess=lambda x: np.nan_to_num(x, -1), + postprocess=fillna_safe, ) @@ -274,7 +277,7 @@ def mobility_filter(data): target_transform=lambda x: x == 1, group='RAC1P', preprocess=mobility_filter, - postprocess=lambda x: np.nan_to_num(x, -1), + postprocess=fillna_safe, ) def employment_filter(data): @@ -311,7 +314,7 @@ def employment_filter(data): target_transform=lambda x: x == 1, group='RAC1P', preprocess=employment_filter, - postprocess=lambda x: np.nan_to_num(x, -1), + postprocess=fillna_safe, ) ACSIncomePovertyRatio = folktables.BasicProblem( @@ -341,5 +344,5 @@ def employment_filter(data): target_transform=lambda x: x < 250, group='RAC1P', preprocess=lambda x: x, - postprocess=lambda x: np.nan_to_num(x, -1), + postprocess=fillna_safe, ) diff --git a/folktables/load_acs.py b/folktables/load_acs.py index 9464a71..f954a84 100644 --- a/folktables/load_acs.py +++ b/folktables/load_acs.py @@ -144,6 +144,7 
@@ def load_definitions(root_dir, year=2018, horizon='1-Year', download=False): year_string = year if horizon == '1-Year' else f'{year - 4}-{year}' url = f'https://www2.census.gov/programs-surveys/acs/tech_docs/pums/data_dict/PUMS_Data_Dictionary_{year_string}.csv' + os.makedirs(base_datadir, exist_ok=True) response = requests.get(url) with open(file_path, 'wb') as handle: handle.write(response.content)