Skip to content

Commit

Permalink
fix: refactor fillna lambda function with a function that also applies …
Browse files Browse the repository at this point in the history
…fillna of pandas to solve multiple nan encodings; creating directory for cached data when not available.
  • Loading branch information
baraldian committed Apr 15, 2024
1 parent 1762722 commit 97bfde3
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
19 changes: 11 additions & 8 deletions folktables/acs.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ def get_definitions(self, download=False):
return load_definitions(root_dir=self._root_dir, year=self._survey_year, horizon=self._horizon,
download=download)

def fillna_safe(x, value=-1):
    """Replace missing entries of a feature array with ``value``.

    Two passes are applied so that several encodings of "missing" are covered:

    1. ``np.nan_to_num`` handles floating-point (and complex) arrays: NaN
       becomes ``value`` and +/-inf become the largest/smallest finite
       numbers of the dtype.  Arrays of any other dtype (e.g. ``object``)
       are passed through unchanged by NumPy.
    2. ``DataFrame.fillna`` then fills whatever pandas still regards as
       missing (for instance ``None`` or NaN stored in object columns).

    Args:
        x: Array-like of features.  The input is not modified; a copy is
            made by ``np.nan_to_num``.
        value: Replacement for missing entries.  Defaults to ``-1``.

    Returns:
        ``numpy.ndarray`` taken from ``DataFrame.values``.  Because the data
        goes through a ``DataFrame``, a 1-D input comes back as a single
        column of shape ``(n, 1)``.
    """
    # ``nan`` must be given by keyword: the second positional parameter of
    # ``np.nan_to_num`` is ``copy``, so ``np.nan_to_num(x, value)`` silently
    # replaced NaN with the default 0.0 instead of ``value``.
    x = np.nan_to_num(x, nan=value)
    return pd.DataFrame(x).fillna(value=value).values

def adult_filter(data):
"""Mimic the filters in place for Adult data.
Expand Down Expand Up @@ -98,7 +101,7 @@ def adult_filter(data):
target_transform=lambda x: x > 50000,
group='RAC1P',
preprocess=adult_filter,
postprocess=lambda x: np.nan_to_num(x, -1),
postprocess=fillna_safe,
)

ACSEmployment = folktables.BasicProblem(
Expand All @@ -124,7 +127,7 @@ def adult_filter(data):
target_transform=lambda x: x == 1,
group='RAC1P',
preprocess=lambda x: x,
postprocess=lambda x: np.nan_to_num(x, -1),
postprocess=fillna_safe,
)

ACSHealthInsurance = folktables.BasicProblem(
Expand Down Expand Up @@ -159,7 +162,7 @@ def adult_filter(data):
target_transform=lambda x: x == 1,
group='RAC1P',
preprocess=lambda x: x,
postprocess=lambda x: np.nan_to_num(x, -1),
postprocess=fillna_safe,
)

def public_coverage_filter(data):
Expand Down Expand Up @@ -197,7 +200,7 @@ def public_coverage_filter(data):
target_transform=lambda x: x == 1,
group='RAC1P',
preprocess=public_coverage_filter,
postprocess=lambda x: np.nan_to_num(x, -1),
postprocess=fillna_safe,
)

def travel_time_filter(data):
Expand Down Expand Up @@ -233,7 +236,7 @@ def travel_time_filter(data):
target_transform=lambda x: x > 20,
group='RAC1P',
preprocess=travel_time_filter,
postprocess=lambda x: np.nan_to_num(x, -1),
postprocess=fillna_safe,
)


Expand Down Expand Up @@ -274,7 +277,7 @@ def mobility_filter(data):
target_transform=lambda x: x == 1,
group='RAC1P',
preprocess=mobility_filter,
postprocess=lambda x: np.nan_to_num(x, -1),
postprocess=fillna_safe,
)

def employment_filter(data):
Expand Down Expand Up @@ -311,7 +314,7 @@ def employment_filter(data):
target_transform=lambda x: x == 1,
group='RAC1P',
preprocess=employment_filter,
postprocess=lambda x: np.nan_to_num(x, -1),
postprocess=fillna_safe,
)

ACSIncomePovertyRatio = folktables.BasicProblem(
Expand Down Expand Up @@ -341,5 +344,5 @@ def employment_filter(data):
target_transform=lambda x: x < 250,
group='RAC1P',
preprocess=lambda x: x,
postprocess=lambda x: np.nan_to_num(x, -1),
postprocess=fillna_safe,
)
1 change: 1 addition & 0 deletions folktables/load_acs.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def load_definitions(root_dir, year=2018, horizon='1-Year', download=False):
year_string = year if horizon == '1-Year' else f'{year - 4}-{year}'
url = f'https://www2.census.gov/programs-surveys/acs/tech_docs/pums/data_dict/PUMS_Data_Dictionary_{year_string}.csv'

os.makedirs(base_datadir, exist_ok=True)
response = requests.get(url)
with open(file_path, 'wb') as handle:
handle.write(response.content)
Expand Down

0 comments on commit 97bfde3

Please sign in to comment.