From 98b74b460a841b575e6c3c21c9cc37df26c29557 Mon Sep 17 00:00:00 2001
From: Vladimir Khodygo
Date: Thu, 23 Feb 2023 16:43:32 +0000
Subject: [PATCH] UPD: add generator for couples

---
 .../synthesizer/abstract/base_synthesizer.py |  18 ++-
 src/mortality_module/synthesizer/couples.py  | 106 ++++++++++++++++++
 2 files changed, 118 insertions(+), 6 deletions(-)
 create mode 100644 src/mortality_module/synthesizer/couples.py

diff --git a/src/mortality_module/synthesizer/abstract/base_synthesizer.py b/src/mortality_module/synthesizer/abstract/base_synthesizer.py
index a978a55..8312250 100644
--- a/src/mortality_module/synthesizer/abstract/base_synthesizer.py
+++ b/src/mortality_module/synthesizer/abstract/base_synthesizer.py
@@ -1,13 +1,13 @@
+import random
 import uuid
 from abc import ABC, abstractmethod
+from typing import Tuple, final
+
 import numpy as np
-import random
 import pandas as pd
 
 from mortality_module.synthesizer.constants import COUNTRY_MAP, SEX_MAP
 
-from typing import Tuple, List, final
-
 
 class Synthesizer(ABC):
     def __init__(self, seed: int = 1337):
@@ -85,7 +85,9 @@ def extract_subset(self,
         else:
             hh_match = self._df[household_column_name] == hh_codes
 
-        self._data = self._df[hh_match][list(column_names)].reset_index(drop=True)
+        self._data = self._df[hh_match][list(column_names)]. \
+            reset_index(drop=True)
+
 
     @abstractmethod
     def augment_data(self):
@@ -95,6 +97,7 @@ def augment_data(self):
     def generate_new_population(self):
         pass
 
+    @final
     def data_preprocessing(self):
         self._df['COUNTRY'] = self._df['COUNTRY'] \
             .replace(4, 3) \
@@ -106,7 +109,6 @@ def data_preprocessing(self):
             .replace(SEX_MAP)
         self._df['AGE'] = self._df['AGE'].astype(int)
 
-
     def generate_hh_id(self, ss: int) -> list[uuid.UUID, ...]:
         """Generates unique household ids.
 
@@ -124,4 +126,8 @@ def generate_hh_id(self, ss: int) -> list[uuid.UUID, ...]:
 
     @final
     def cancel_changes(self):
-        self._data = None
\ No newline at end of file
+        self._data = None
+
+    @staticmethod
+    def _validate_household_size(dataset):
+        pass
\ No newline at end of file
diff --git a/src/mortality_module/synthesizer/couples.py b/src/mortality_module/synthesizer/couples.py
new file mode 100644
index 0000000..420e39c
--- /dev/null
+++ b/src/mortality_module/synthesizer/couples.py
@@ -0,0 +1,106 @@
+import itertools as it
+import math as m
+
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+
+from mortality_module.synthesizer.abstract.base_synthesizer import Synthesizer
+from mortality_module.synthesizer.sanitizer import Sanitizer
+from mortality_module.synthesizer.utils import data_range
+
+
+class UKCouplesHH(Synthesizer):
+    def __init__(self, seed: int = 13371):
+        super().__init__(seed)
+
+    def run_sanity_checks(self):
+        bad_ids = self._validate_household_size(self._data)
+        if len(bad_ids) > 0:
+            print("""Households with inconsistent number of people have been
+            found, filtering them out.""")
+            self._data = self._data[~self._data['HSERIALP'].isin(bad_ids)]
+
+    def augment_data(self) -> None:
+        self._data = pd.pivot_table(self._data,
+                                    values=['AGE', 'PHHWT14'],
+                                    index=['HSERIALP', 'COUNTRY', 'hhtype6'],
+                                    columns=['SEX']). \
+            reset_index(). \
+            drop(columns=('PHHWT14', 'f'))
+
+        new_columns = [s1 if s2 == '' else s1 + '_' + str(s2) for (s1, s2) in
+                       self._data.columns.tolist()]
+        self._data.columns = new_columns
+
+        self._data['HH_W'] = self._data['PHHWT14_m']
+        self._data.drop(columns='PHHWT14_m', inplace=True)
+        self._data.rename(columns={"AGE_f": "f", "AGE_m": "m", "HH_W": "w"},
+                          inplace=True)
+
+    def generate_new_population(self) -> pd.DataFrame:
+        self.data_preprocessing()
+        self.extract_subset(('COUNTRY', 'SEX', 'AGE', 'PHHWT14', 'HSERIALP',
+                             'hhtype6'),
+                            (3, 4),
+                            'hhtype6')
+        self.run_sanity_checks()
+        self.augment_data()
+        return self.populate_couples()
+
+    @staticmethod
+    def _validate_household_size(dataset):
+        """Ensures that every household is composed of exactly two people."""
+        return Sanitizer.household_size(dataset, 'HSERIALP', 2)
+
+    def populate_couples(self) -> pd.DataFrame:
+        all_data : list = []
+
+        for (country_, hh_type) in tqdm(it.product(('e', 'w', 's', 'ni'),
+                                                   (3, 4))):
+            t = self._data[(self._data['COUNTRY'] == country_) &
+                           (self._data['hhtype6'] == hh_type)]
+
+            num_bins_f, range_f = data_range(t['f'])
+            num_bins_m, range_m = data_range(t['m'])
+
+            dist, ages_f, ages_m = np.histogram2d(t['f'],
+                                                  t['m'],
+                                                  bins=[num_bins_f, num_bins_m],
+                                                  range=[range_f, range_m],
+                                                  weights=t['w'],
+                                                  density=True)
+
+            assert m.fsum(dist.flatten()) == 1, \
+                'Probabilities must add up to 1.'
+
+            total_sample_households = int(t['w'].sum())
+
+            linear_dist = dist.flatten()
+            sample_index = np.random.choice(a=linear_dist.size,
+                                            p=linear_dist,
+                                            size=total_sample_households)
+            index_ = np.unravel_index(sample_index, dist.shape)
+
+            ids = self.generate_hh_id(total_sample_households)
+
+            all_data.append(pd.DataFrame(data={'f': ages_f[index_[0]],
+                                               'm': ages_m[index_[1]],
+                                               'COUNTRY': country_,
+                                               'HH_ID': ids,
+                                               'HH_TYPE': hh_type}))
+
+
+        result = pd.concat(all_data,
+                           ignore_index=True).melt(
+            id_vars=['COUNTRY', 'HH_ID', 'HH_TYPE'], value_vars=['f', 'm'],
+            var_name='SEX', value_name='AGE')
+        result['AGE'] = result['AGE'].astype(int)
+        result['HH_TYPE'] = result['HH_TYPE'].astype(int)
+
+        return result.sort_values(by=['HH_ID'])
+
+if __name__ == "__main__":
+    ukchh = UKCouplesHH()
+    ukchh.read_data(input())
+    ukchh.generate_new_population().to_csv('couples.csv', index=False)
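
Note for reviewers (not part of the patch): the sketch below illustrates the resampling idea behind populate_couples in isolation, namely building a weighted 2D histogram of partner ages, treating the normalised cell weights as a probability table, drawing flattened cell indices, and mapping them back to (female age, male age) pairs with np.unravel_index. It is a minimal standalone example on made-up data; the ages, weights and bin edges are arbitrary, and neither data_range, the survey columns, nor the Synthesizer base class are used here.

import numpy as np
import pandas as pd

rng = np.random.default_rng(1337)

# Made-up stand-in for one (country, household type) subset: one row per
# two-person household with a female age, a male age and a survey weight.
toy = pd.DataFrame({'f': rng.integers(20, 90, size=500),
                    'm': rng.integers(20, 90, size=500),
                    'w': rng.uniform(0.5, 3.0, size=500)})

# Weighted joint histogram of (female age, male age); dividing by the total
# weight turns the cell weights into sampling probabilities.
edges = np.arange(20, 91)
counts, ages_f, ages_m = np.histogram2d(toy['f'], toy['m'],
                                        bins=[edges, edges],
                                        weights=toy['w'])
probs = counts.ravel() / counts.sum()

# Draw flattened cell indices, then map each index back to a
# (female bin, male bin) pair, mirroring populate_couples.
n_households = int(toy['w'].sum())
flat_idx = rng.choice(probs.size, size=n_households, p=probs)
row_idx, col_idx = np.unravel_index(flat_idx, counts.shape)

synthetic = pd.DataFrame({'f': ages_f[row_idx], 'm': ages_m[col_idx]})
print(synthetic.head())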