Skip to content

Commit

Permalink
Merge pull request #177 from UCL/age_dependent_sex_balance
Browse files Browse the repository at this point in the history
Age dependent sex balance
  • Loading branch information
mmcleod89 authored May 20, 2024
2 parents 9ddce88 + 7c56af2 commit a6f3021
Show file tree
Hide file tree
Showing 5 changed files with 191 additions and 53 deletions.
8 changes: 4 additions & 4 deletions src/hivpy/column_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
LTP_AGE_GROUP = "ltp_age_group" # int: discrete age group for starting / longevity of ltp

RISK = "risk" # float: overall risk value from combined factors
RISK_AGE = "risk_age" # float: risk reduction factor based on age
RISK_ADC = "risk_adc" # float: risk reduction for AIDS defining condition
RISK_BALANCE = "risk_balance" # float: risk reduction factor to re-balance male & female partner numbers
RISK_DIAGNOSIS = "risk_diagnosis" # float: risk reduction associated with recent HIV diagnosis
RISK_AGE = "risk_age" # float: risk factor based on age
RISK_ADC = "risk_adc" # float: risk for AIDS defining condition
RISK_BALANCE = "risk_balance" # float: risk factor to re-balance male & female partner numbers
RISK_DIAGNOSIS = "risk_diagnosis" # float: risk associated with recent HIV diagnosis
RISK_PERSONAL = "risk_personal" # float: individual risk reduction applied with a certain probability
RISK_LTP = "risk_long_term_partnered" # float: risk reduction for people in long term partnerships
RISK_ART_ADHERENCE = "risk_art_adherence" # float: risk reduction associated with low ART adherence
Expand Down
55 changes: 41 additions & 14 deletions src/hivpy/output.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import itertools
import math
import operator
from itertools import product

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -63,6 +65,10 @@ def _init_df(self, start_date, stop_date, time_step):
"Non-HIV deaths (over 15, female)", "Non-HIV deaths (20-59, male)",
"Non-HIV deaths (20-59, female)"]

for (age, sex) in product([15, 25, 35, 45, 55], (SexType.Male, SexType.Female)):
key = f"Short term partners ({age}-{age+9}, {sex})"
output_columns.insert(17, key)

for age_bound in range(self.age_min, self.age_max, self.age_step):
if age_bound < self.age_max_active:
# inserted after 'Partner sex balance (female)'
Expand Down Expand Up @@ -139,7 +145,10 @@ def _update_HIV_prevalence(self, pop: Population):

# Update HIV prevalence in female sex workers
sex_workers_idx = pop.get_sub_pop([(col.SEX_WORKER, operator.eq, True)])
self.output_stats.loc[self.step, "Sex worker (ratio)"] = self._ratio(sex_workers_idx, women_idx)
potential_sw = pop.get_sub_pop([(col.SEX, operator.eq, SexType.Female),
(col.AGE, operator.ge, 15),
(col.AGE, operator.lt, 50)])
self.output_stats.loc[self.step, "Sex worker (ratio)"] = self._ratio(sex_workers_idx, potential_sw)
self.output_stats.loc[self.step, "HIV prevalence (sex worker)"] = (
self._ratio(pop.get_sub_pop_intersection(sex_workers_idx, HIV_pos_idx), sex_workers_idx))

Expand Down Expand Up @@ -247,6 +256,18 @@ def _update_partners(self, pop: Population):
(col.AGE, operator.lt, 65),
(col.NUM_PARTNERS, operator.ge, 1)])
self.output_stats.loc[self.step, "Short term partners (15-64)"] = self._ratio(stp_idx, age_idx)

# Proportion of people with at least one short term partner
for (age, sex) in product([15, 25, 35, 45, 55], (SexType.Male, SexType.Female)):
self.output_stats.loc[self.step, f"Short term partners ({age}-{age+9}, {sex})"] = \
self._ratio(pop.get_sub_pop([(col.AGE, operator.ge, age),
(col.AGE, operator.lt, age+10),
(col.SEX, operator.eq, sex),
(col.NUM_PARTNERS, operator.ge, 1)]),
pop.get_sub_pop([(col.AGE, operator.ge, age),
(col.AGE, operator.lt, age+10),
(col.SEX, operator.eq, sex)]))

# Update proportion of people with 5+ short term partners
stp_over_5_idx = pop.get_sub_pop([(col.AGE, operator.ge, 15),
(col.AGE, operator.lt, 65),
Expand All @@ -257,16 +278,20 @@ def _update_partner_sex_balance(self, pop: Population):
# Update short term partner sex balance statistics
men_idx = pop.get_sub_pop([(col.SEX, operator.eq, SexType.Male)])
women_idx = pop.get_sub_pop([(col.SEX, operator.eq, SexType.Female)])
active_idx = pop.get_sub_pop([(col.NUM_PARTNERS, operator.gt, 0)])
active_idx = pop.get_sub_pop([(col.AGE, operator.ge, 15),
(col.AGE, operator.lt, 65),
(col.NUM_PARTNERS, operator.gt, 0)])
active_men = pop.get_sub_pop_intersection(active_idx, men_idx)
active_women = pop.get_sub_pop_intersection(active_idx, women_idx)
# Get flattened lists of partner age groups (values 0-4)
women_stp_age_list = pop.get_variable(col.STP_AGE_GROUPS, active_women).values
women_stp_age_list = (np.concatenate(women_stp_age_list).ravel() if len(women_stp_age_list) > 0
else women_stp_age_list).tolist()
men_stp_age_list = pop.get_variable(col.STP_AGE_GROUPS, active_men).values
men_stp_age_list = (np.concatenate(men_stp_age_list).ravel() if len(men_stp_age_list) > 0
else men_stp_age_list).tolist()

def get_partners_in_groups(stp_of_sex):
a, f = np.unique(list(itertools.chain.from_iterable(stp_of_sex)), return_counts=True)
return dict(zip(a, f))
male_stp_in_age_groups = get_partners_in_groups(women_stp_age_list)
female_stp_in_age_groups = get_partners_in_groups(men_stp_age_list)

# FIXME: should we log all ratios here or have this step happen in post?
# NOTE: sum type converted from numpy.int64
Expand All @@ -280,22 +305,24 @@ def _update_partner_sex_balance(self, pop: Population):
# Update short term partner sex balance statistics by age group
for age_bound in range(self.age_min, self.age_max_active, self.age_step):
age_group = int(age_bound/10)-1
age_idx = pop.get_sub_pop([(col.AGE, operator.ge, age_bound),
(col.AGE, operator.lt, age_bound+self.age_step)])
men_of_age = pop.get_sub_pop_intersection(age_idx, active_men)
women_of_age = pop.get_sub_pop_intersection(age_idx, active_women)

key = f"Partner sex balance ({age_bound}-{age_bound+(self.age_step-1)}, male)"
# Count occurrences of current age group
women_stp_num = women_stp_age_list.count(age_group)
n_male_stp = male_stp_in_age_groups.get(age_group)
if n_male_stp is None:
n_male_stp = 0
self.output_stats.loc[self.step, key] = self._log(
self._ratio(int(pop.get_variable(col.NUM_PARTNERS, men_of_age).sum()), women_stp_num))
pop.sexual_behaviour.num_stp_in_age_sex_group[age_group][SexType.Male] /
pop.sexual_behaviour.num_stp_of_age_sex_group[age_group][SexType.Male])

key = f"Partner sex balance ({age_bound}-{age_bound+(self.age_step-1)}, female)"
# Count occurrences of current age group
men_stp_num = men_stp_age_list.count(age_group)
n_female_stp = female_stp_in_age_groups.get(age_group)
if n_female_stp is None:
n_female_stp = 0
self.output_stats.loc[self.step, key] = self._log(
self._ratio(int(pop.get_variable(col.NUM_PARTNERS, women_of_age).sum()), men_stp_num))
pop.sexual_behaviour.num_stp_in_age_sex_group[age_group][SexType.Female] /
pop.sexual_behaviour.num_stp_of_age_sex_group[age_group][SexType.Female])

def _update_births(self, pop: Population, time_step):
# Update total births
Expand Down
34 changes: 31 additions & 3 deletions src/hivpy/sexual_behaviour.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import importlib.resources
import logging
import operator
from enum import IntEnum
from typing import TYPE_CHECKING
Expand All @@ -10,9 +11,13 @@

import hivpy.column_names as col

from .common import AND, COND, SexType, date, diff_years, rng, timedelta
from .common import (AND, COND, SexType, date, diff_years, opposite_sex, rng,
timedelta)
from .sex_behaviour_data import SexualBehaviourData

# import warnings


if TYPE_CHECKING:
from .population import Population

Expand Down Expand Up @@ -118,6 +123,10 @@ def __init__(self, **kwargs):
self.balance_thresholds = [0.1, 0.03, 0.005, 0.004, 0.003, 0.002, 0.001]
self.balance_factors = [0.1, 0.7, 0.7, 0.75, 0.8, 0.9, 0.97]
self.p_risk_p = self.sb_data.p_risk_p_dist.sample()
# Number of short term partners of people in a demographic group by age and sex
self.num_stp_of_age_sex_group = np.zeros([self.num_sex_mix_groups, 2])
# Number of short term partners who themselves are in a demographic group by age and sex
self.num_stp_in_age_sex_group = np.zeros([self.num_sex_mix_groups, 2])

# long term partnerships parameters
self.new_ltp_rate = 0.1 * np.exp(rng.normal() * 0.25) # three month ep-rate
Expand Down Expand Up @@ -154,7 +163,7 @@ def init_sex_behaviour(self, population: Population):
population.init_variable(col.LTP_AGE_GROUP, 0)
population.init_variable(col.LTP_LONGEVITY, 0)
population.init_variable(col.SEX_MIX_AGE_GROUP, 0)
population.init_variable(col.STP_AGE_GROUPS, np.array([[0]]*population.size))
population.init_variable(col.STP_AGE_GROUPS, [np.array([])]*population.size)
population.init_variable(col.RISK_LTP, 1)
population.init_variable(col.LIFE_SEX_RISK, 1)
population.init_variable(col.SEX_WORKER, False)
Expand All @@ -178,6 +187,7 @@ def update_sex_behaviour(self, population: Population):
self.update_sex_groups(population)
self.num_short_term_partners(population)
self.assign_stp_ages(population)
self.update_sex_age_balance(population)
self.update_long_term_partners(population)

# Code for sex work ---------------------------------------------------------------------------
Expand Down Expand Up @@ -564,14 +574,20 @@ def update_risk_balance(self, population: Population):
population.set_present_variable(col.RISK_BALANCE, 1/risk_balance, men)

def gen_stp_ages(self, sex, age_group, num_partners, size):
# TODO: Check if this needs additional balancing factors for age
stp_age_probs = self.sex_mixing_matrix[sex][age_group]
stp_age_groups = rng.choice(self.num_sex_mix_groups, [size, num_partners], p=stp_age_probs)
self.num_stp_of_age_sex_group[age_group][sex] += (num_partners * size)
for i in stp_age_groups.flatten():
self.num_stp_in_age_sex_group[i][opposite_sex(sex)] += 1
return list(stp_age_groups) # dataframe won't accept a 2D numpy array

def assign_stp_ages(self, population: Population):
"""Calculate the ages of a persons short term partners
from the mixing matrices."""
# reset stp age/sex counts
self.num_stp_in_age_sex_group = np.zeros([self.num_sex_mix_groups, 2])
self.num_stp_of_age_sex_group = np.zeros([self.num_sex_mix_groups, 2])

population.set_present_variable(col.SEX_MIX_AGE_GROUP,
(np.digitize(population.get_variable(col.AGE),
self.sex_mix_age_groups) - 1))
Expand All @@ -581,6 +597,18 @@ def assign_stp_ages(self, population: Population):
self.gen_stp_ages,
sub_pop=active_pop)

def update_sex_age_balance(self, population: Population):
def get_ratio(sex, age):
if (self.num_stp_of_age_sex_group[age][sex] > 0):
ratio = self.num_stp_in_age_sex_group[age][sex] / self.num_stp_of_age_sex_group[age][sex]
logging.info(f"Ratio (sex, age): {sex}, {age} = {ratio}\n")
return ratio
else:
return 1
for age_group in range(self.risk_categories+1):
for sex in [0, 1]:
self.age_based_risk[age_group, sex] = self.age_based_risk[age_group, sex] * get_ratio(sex, age_group//2)

# Code for long term partnerships -------------------------------------------------------------

def update_ltp_rate_change(self, date):
Expand Down
24 changes: 1 addition & 23 deletions src/tests/test_output.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from math import isclose, log
from math import isclose

import hivpy.column_names as col
from hivpy.common import SexType, date, timedelta
Expand Down Expand Up @@ -64,25 +64,3 @@ def test_HIV_incidence():
age_group = int(age_bound/10)-1
key = f"HIV incidence ({age_bound}-{age_bound+(age_step-1)}, female)"
assert isclose(out.output_stats[key], 1-age_group*0.25)


def test_partner_sex_balance():

# build population
N = 100
pop = Population(size=N, start_date=date(1990, 1, 1))
pop.data.loc[:int(N/2)-1, col.SEX] = SexType.Female
pop.data.loc[int(N/2):, col.SEX] = SexType.Male
pop.data[col.AGE] = 25
pop.data[col.NUM_PARTNERS] = 2
pop.data[col.STP_AGE_GROUPS] = [[1, 2]]*N

out = SimulationOutput(date(1990, 1, 1), date(1990, 3, 1), timedelta(days=90))
out._update_partner_sex_balance(pop)

# check overall sex balance is equal
assert isclose(out.output_stats["Partner sex balance (male)"], 0)
assert isclose(out.output_stats["Partner sex balance (female)"], 0)
# check age-specific ratio is 2
assert isclose(out.output_stats["Partner sex balance (25-34, male)"], log(2, 10))
assert isclose(out.output_stats["Partner sex balance (25-34, female)"], log(2, 10))
Loading

0 comments on commit a6f3021

Please sign in to comment.