-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #194 from Urban-Analytics/better-opencl-seeding
OpenCL has same initial case seeding logic
- Loading branch information
Showing
9 changed files
with
369 additions
and
99 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# This is a workaround to allow multiprocessing.Pool to work in the pf_experiments_plots notebook. | ||
# The function called by pool.map ('count_wiggles') needs to be defined in this separate file and imported. | ||
# https://stackoverflow.com/questions/41385708/multiprocessing-example-giving-attributeerror/42383397 | ||
import os | ||
import multiprocessing | ||
import itertools # TEMP | ||
|
||
from microsim.opencl.ramp.snapshot import Snapshot | ||
from microsim.opencl.ramp.simulator import Simulator | ||
from microsim.opencl.ramp.run import run_headless | ||
|
||
|
||
def run_opencl_model_multiprocess(*args): | ||
#*al_i, l_snapshot_filepath, l_params, l_opencl_dir, l_num_seed_days, l_use_gpu): | ||
try: | ||
with multiprocessing.Pool(processes=int(os.cpu_count())) as pool: | ||
results = pool.starmap(_run_opencl_model, zip(*args)) | ||
#results = itertools.starmap(_run_opencl_model, zip(*args)) | ||
return results | ||
|
||
finally: # Make sure they get closed (shouldn't be necessary) | ||
pool.close() | ||
|
||
|
||
def _run_opencl_model(i, iterations, snapshot_filepath, params, opencl_dir, num_seed_days, use_gpu, | ||
store_detailed_counts=True): | ||
|
||
# load snapshot | ||
snapshot = Snapshot.load_full_snapshot(path=snapshot_filepath) | ||
|
||
# set params | ||
snapshot.update_params(params) | ||
|
||
# set the random seed of the model for each repetition, otherwise it is completely deterministic | ||
snapshot.seed_prngs(i) | ||
|
||
# Create a simulator and upload the snapshot data to the OpenCL device | ||
simulator = Simulator(snapshot, opencl_dir=opencl_dir, gpu=use_gpu, num_seed_days=num_seed_days) | ||
simulator.upload_all(snapshot.buffers) | ||
|
||
print(f"Running simulation {i+1}.") | ||
summary, final_state = run_headless(simulator, snapshot, iterations, quiet=True, | ||
store_detailed_counts=store_detailed_counts) | ||
return summary, final_state |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import pandas as pd | ||
import numpy as np | ||
import os | ||
|
||
|
||
class InitialCases: | ||
def __init__(self, area_codes, not_home_probs, data_dir="microsim/opencl/data/"): | ||
""" | ||
This class loads the initial cases data for seeding infections in the model. | ||
Once the data is loaded it selects the people from higher risk area codes who | ||
spend more time outside of their home. | ||
""" | ||
|
||
# load initial case data | ||
self.initial_cases = pd.read_csv(os.path.join(data_dir, "devon_initial_cases.csv")) | ||
|
||
msoa_risks_df = pd.read_csv(os.path.join(data_dir, "msoas.csv"), usecols=[1, 2]) | ||
|
||
# combine into a single dataframe to allow easy filtering based on high risk area codes and | ||
# not home probabilities | ||
people_df = pd.DataFrame({"area_code": area_codes, | ||
"not_home_prob": not_home_probs}) | ||
people_df = people_df.merge(msoa_risks_df, on="area_code") | ||
|
||
# get people_ids for people in high risk MSOAs and high not home probability | ||
self.high_risk_ids = np.where((people_df["risk"] == "High") & (people_df["not_home_prob"] > 0.3))[0] | ||
|
||
def get_seed_people_ids_for_day(self, day): | ||
"""Randomly choose a given number of people ids from the high risk people""" | ||
|
||
num_cases = self.initial_cases.loc[day, "num_cases"] | ||
if num_cases > self.high_risk_ids.shape[0]: # if there aren't enough high risk individuals then return all of them | ||
return self.high_risk_ids | ||
|
||
selected_ids = np.random.choice(self.high_risk_ids, num_cases, replace=False) | ||
|
||
# remove people from high_risk_ids so they are not chosen again | ||
self.high_risk_ids = np.setdiff1d(self.high_risk_ids, selected_ids) | ||
|
||
return selected_ids |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters