Skip to content

Commit

Permalink
add multiprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
yi-fan-wang committed Jun 27, 2024
1 parent 534e7ef commit 15cc266
Showing 1 changed file with 56 additions and 14 deletions.
70 changes: 56 additions & 14 deletions bin/bank/pycbc_brute_bank
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import argparse
import pickle
import numpy.random
from scipy.stats import gaussian_kde
import multiprocessing

import pycbc.waveform, pycbc.filter, pycbc.types, pycbc.psd, pycbc.fft, pycbc.conversions
from pycbc import transforms
Expand Down Expand Up @@ -75,6 +76,7 @@ parser.add_argument('--tau0-crawl', type=float)
parser.add_argument('--tau0-start', type=float)
parser.add_argument('--tau0-end', type=float)
parser.add_argument('--tau0-cutoff-frequency', type=float, default=15.0)
parser.add_argument('--num_cores', type=int, default=1, help='number of cores for parallelization.')
pycbc.psd.insert_psd_option_group(parser)
args = parser.parse_args()

Expand Down Expand Up @@ -247,20 +249,52 @@ class TriangleBank(object):
def check_params(self, gen, params, threshold):
num_tried = 0
num_added = 0
for i in range(len(tuple(params.values())[0])):
num_tried += 1.0

try:
hp = gen.generate(**{key:params[key][i] for key in params})
except Exception as err:
print(err)
continue

hp.gen = gen
hp.threshold = threshold
if hp not in self:
num_added += 1
self.insert(hp)
total_num = len(tuple(params.values())[0])

if args.num_cores == 1:
for i in range(total_num):
num_tried += 1.0

try:
hp = gen.generate(**{key:params[key][i] for key in params})
except Exception as err:
print(err)
continue

hp.gen = gen
hp.threshold = threshold
if hp not in self:
num_added += 1
self.insert(hp)

# use multiprocessing
elif args.num_cores > 1:
waveform_cache = []
with multiprocessing.Pool(args.num_cores) as pool:
for return_wf in pool.imap_unordered(
wf_wrapper,
({k: params[k][idx] for k in params} for idx in range(total_num))
):
waveform_cache += [return_wf]

for i in range(total_num):
num_tried += 1
hp = waveform_cache[i]

if isinstance(hp, pycbc.types.FrequencySeries):
hp.gen = gen
hp.threshold = threshold
if hp not in self:
num_added += 1
self.insert(hp)
elif numpy.isnan(hp):
logging.info("Waveform generation failed!")
continue
else:
raise ValueError("Waveform generation problem.")

else:
raise ValueError("Number of cores should be a positive integer.")

return bank, num_added / float(num_tried)

Expand Down Expand Up @@ -337,6 +371,14 @@ gen = GenUniformWaveform(args.buffer_length,
args.sample_rate, args.low_frequency_cutoff)
bank = TriangleBank()

if args.num_cores > 1:
def wf_wrapper(p):
try:
hp = gen.generate(**p)
return hp
except Exception:
return numpy.nan

if args.input_file:
f = h5py.File(args.input_file, 'r')
params = {k: f[k][:] for k in f}
Expand Down

0 comments on commit 15cc266

Please sign in to comment.