Skip to content

Commit

Permalink
UPD: add generator for couples
Browse files Browse the repository at this point in the history
  • Loading branch information
vkhodygo committed Feb 23, 2023
1 parent 3c7d12f commit 98b74b4
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 6 deletions.
18 changes: 12 additions & 6 deletions src/mortality_module/synthesizer/abstract/base_synthesizer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import random
import uuid
from abc import ABC, abstractmethod
from typing import Tuple, final

import numpy as np
import random
import pandas as pd

from mortality_module.synthesizer.constants import COUNTRY_MAP, SEX_MAP

from typing import Tuple, List, final


class Synthesizer(ABC):
def __init__(self, seed: int = 1337):
Expand Down Expand Up @@ -85,7 +85,9 @@ def extract_subset(self,
else:
hh_match = self._df[household_column_name] == hh_codes

self._data = self._df[hh_match][list(column_names)].reset_index(drop=True)
self._data = self._df[hh_match][list(column_names)]. \
reset_index(drop=True)


@abstractmethod
def augment_data(self):
Expand All @@ -95,6 +97,7 @@ def augment_data(self):
def generate_new_population(self):
pass

@final
def data_preprocessing(self):
self._df['COUNTRY'] = self._df['COUNTRY'] \
.replace(4, 3) \
Expand All @@ -106,7 +109,6 @@ def data_preprocessing(self):
.replace(SEX_MAP)
self._df['AGE'] = self._df['AGE'].astype(int)


def generate_hh_id(self, ss: int) -> list[uuid.UUID, ...]:
"""Generates unique household ids.
Expand All @@ -124,4 +126,8 @@ def generate_hh_id(self, ss: int) -> list[uuid.UUID, ...]:

@final
def cancel_changes(self):
self._data = None
self._data = None

@staticmethod
def _validate_household_size(dataset):
pass
106 changes: 106 additions & 0 deletions src/mortality_module/synthesizer/couples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import itertools as it
import math as m

import numpy as np
import pandas as pd
from tqdm import tqdm

from mortality_module.synthesizer.abstract.base_synthesizer import Synthesizer
from mortality_module.synthesizer.sanitizer import Sanitizer
from mortality_module.synthesizer.utils import data_range


class UKCouplesHH(Synthesizer):
    """Synthesizer that generates two-person (couple) households.

    The pipeline implemented by :meth:`generate_new_population` is:

    1. preprocess the raw data (inherited ``data_preprocessing``);
    2. keep the rows whose ``hhtype6`` code is 3 or 4;
    3. drop households that fail the two-person size check;
    4. pivot to one row per household with the columns ``f`` (female age),
       ``m`` (male age) and ``w`` (household weight);
    5. sample new couples from the weighted joint age distribution.
    """

    def __init__(self, seed: int = 13371):
        """Initialise the synthesizer, forwarding ``seed`` to the base class."""
        super().__init__(seed)

    def run_sanity_checks(self):
        """Remove households flagged by the household-size validation.

        Rows of ``self._data`` whose ``HSERIALP`` is among the ids returned
        by :meth:`_validate_household_size` are filtered out, and a notice
        is printed when at least one such household was found.
        """
        bad_ids = self._validate_household_size(self._data)
        # ``len`` rather than truthiness: ``bad_ids`` may be an array-like
        # object for which ``bool()`` is ambiguous.
        if len(bad_ids) > 0:
            print("Households with inconsistent number of people have been\n"
                  "found, filtering them out.")
            self._data = self._data[~self._data['HSERIALP'].isin(bad_ids)]

    def augment_data(self) -> None:
        """Reshape ``self._data`` to one row per household.

        After this call ``self._data`` holds the columns ``HSERIALP``,
        ``COUNTRY``, ``hhtype6``, ``f``, ``m`` and ``w``, where ``f``/``m``
        are the pivoted ``AGE`` values for ``SEX`` equal to ``'f'``/``'m'``
        and ``w`` is the pivoted ``PHHWT14`` value of the ``'m'`` column.
        """
        # Pivot AGE and PHHWT14 by SEX.  The tuple passed to ``drop`` is a
        # single MultiIndex label: the weight reported for the female
        # member is discarded and only the male one is kept.
        self._data = pd.pivot_table(self._data,
                                    values=['AGE', 'PHHWT14'],
                                    index=['HSERIALP', 'COUNTRY', 'hhtype6'],
                                    columns=['SEX']) \
            .reset_index() \
            .drop(columns=('PHHWT14', 'f'))

        # Flatten the two-level column index: ('AGE', 'f') -> 'AGE_f',
        # ('COUNTRY', '') -> 'COUNTRY'.
        self._data.columns = [
            s1 if s2 == '' else s1 + '_' + str(s2)
            for (s1, s2) in self._data.columns.tolist()
        ]

        self._data['HH_W'] = self._data['PHHWT14_m']
        self._data = self._data.drop(columns='PHHWT14_m')
        self._data = self._data.rename(
            columns={"AGE_f": "f", "AGE_m": "m", "HH_W": "w"})

    def generate_new_population(self) -> pd.DataFrame:
        """Run the whole pipeline and return the synthetic population.

        Returns:
            The data frame produced by :meth:`populate_couples`.
        """
        self.data_preprocessing()
        self.extract_subset(('COUNTRY', 'SEX', 'AGE', 'PHHWT14', 'HSERIALP',
                             'hhtype6'),
                            (3, 4),
                            'hhtype6')
        self.run_sanity_checks()
        self.augment_data()
        return self.populate_couples()

    @staticmethod
    def _validate_household_size(dataset):
        """Ensures that every household is composed of exactly two people."""
        return Sanitizer.household_size(dataset, 'HSERIALP', 2)

    def populate_couples(self) -> pd.DataFrame:
        """Sample synthetic couples from the weighted joint age distribution.

        For every combination of country (``'e'``, ``'w'``, ``'s'``,
        ``'ni'``) and household type (3, 4) a weighted 2-D histogram of the
        female and male ages is built, normalised to a probability mass
        function and sampled ``int(sum of weights)`` times.  Each sampled
        household receives a fresh id from ``generate_hh_id``.

        Sampling uses ``np.random.choice``, i.e. NumPy's global random
        state.

        Returns:
            A long-format data frame with the columns ``COUNTRY``,
            ``HH_ID``, ``HH_TYPE``, ``SEX`` (``'f'`` or ``'m'``) and
            ``AGE``, sorted by ``HH_ID``; every household contributes two
            rows.

        Raises:
            AssertionError: if the probabilities of a stratum do not add up
                to 1 (for example when its total weight is zero).
        """
        all_data: list = []

        # Materialise the strata so that tqdm knows the total and can show
        # real progress instead of a bare counter.
        strata = list(it.product(('e', 'w', 's', 'ni'), (3, 4)))

        for (country_, hh_type) in tqdm(strata):
            t = self._data[(self._data['COUNTRY'] == country_) &
                           (self._data['hhtype6'] == hh_type)]

            num_bins_f, range_f = data_range(t['f'])
            num_bins_m, range_m = data_range(t['m'])

            weighted_counts, ages_f, ages_m = np.histogram2d(
                t['f'],
                t['m'],
                bins=[num_bins_f, num_bins_m],
                range=[range_f, range_m],
                weights=t['w'])

            # Normalise to a probability *mass* function explicitly.
            # ``density=True`` yields a density, which sums to 1 only when
            # every bin has unit area; dividing by the total is correct for
            # any bin width and identical for unit-sized bins.
            linear_dist = weighted_counts.flatten() / weighted_counts.sum()

            # Compare with a tolerance: exact float equality can fail
            # spuriously because of rounding.  NaN (zero total weight) is
            # rejected as well, since isclose() is False for NaN.
            if not m.isclose(m.fsum(linear_dist), 1.0, rel_tol=1e-9):
                raise AssertionError('Probabilities must add up to 1.')

            # int() truncates the summed weights towards zero.
            total_sample_households = int(t['w'].sum())

            sample_index = np.random.choice(a=linear_dist.size,
                                            p=linear_dist,
                                            size=total_sample_households)
            index_ = np.unravel_index(sample_index, weighted_counts.shape)

            ids = self.generate_hh_id(total_sample_households)

            # ``ages_f``/``ages_m`` are bin edges, so indexing them with a
            # bin number yields the lower edge of the sampled bin.
            all_data.append(pd.DataFrame(data={'f': ages_f[index_[0]],
                                               'm': ages_m[index_[1]],
                                               'COUNTRY': country_,
                                               'HH_ID': ids,
                                               'HH_TYPE': hh_type}))

        # Wide (one row per household) -> long (one row per person).
        result = pd.concat(all_data, ignore_index=True).melt(
            id_vars=['COUNTRY', 'HH_ID', 'HH_TYPE'],
            value_vars=['f', 'm'],
            var_name='SEX',
            value_name='AGE')
        result['AGE'] = result['AGE'].astype(int)
        result['HH_TYPE'] = result['HH_TYPE'].astype(int)

        return result.sort_values(by=['HH_ID'])

if __name__ == "__main__":
    # Script entry point: one line read from stdin is handed to
    # ``read_data`` (presumably the path of the input dataset -- confirm
    # against the base class), then the generated population is written to
    # ``couples.csv`` without the index column.
    synthesizer = UKCouplesHH()
    synthesizer.read_data(input())
    population = synthesizer.generate_new_population()
    population.to_csv('couples.csv', index=False)

0 comments on commit 98b74b4

Please sign in to comment.