diff --git a/bin/bank/pycbc_brute_bank b/bin/bank/pycbc_brute_bank index 58cfcc9ff19..0fd9709cd4e 100644 --- a/bin/bank/pycbc_brute_bank +++ b/bin/bank/pycbc_brute_bank @@ -26,6 +26,7 @@ import argparse import pickle import numpy.random from scipy.stats import gaussian_kde +import multiprocessing import pycbc.waveform, pycbc.filter, pycbc.types, pycbc.psd, pycbc.fft, pycbc.conversions from pycbc import transforms @@ -75,6 +76,7 @@ parser.add_argument('--tau0-crawl', type=float) parser.add_argument('--tau0-start', type=float) parser.add_argument('--tau0-end', type=float) parser.add_argument('--tau0-cutoff-frequency', type=float, default=15.0) +parser.add_argument('--num_cores', type=int, default=1, help='number of cores for parallelization.') pycbc.psd.insert_psd_option_group(parser) args = parser.parse_args() @@ -247,20 +249,52 @@ class TriangleBank(object): def check_params(self, gen, params, threshold): num_tried = 0 num_added = 0 - for i in range(len(tuple(params.values())[0])): - num_tried += 1.0 - - try: - hp = gen.generate(**{key:params[key][i] for key in params}) - except Exception as err: - print(err) - continue - - hp.gen = gen - hp.threshold = threshold - if hp not in self: - num_added += 1 - self.insert(hp) + total_num = len(tuple(params.values())[0]) + + if args.num_cores == 1: + for i in range(total_num): + num_tried += 1.0 + + try: + hp = gen.generate(**{key:params[key][i] for key in params}) + except Exception as err: + print(err) + continue + + hp.gen = gen + hp.threshold = threshold + if hp not in self: + num_added += 1 + self.insert(hp) + + # use multiprocessing + elif args.num_cores > 1: + waveform_cache = [] + with multiprocessing.Pool(args.num_cores) as pool: + for return_wf in pool.imap_unordered( + wf_wrapper, + ({k: params[k][idx] for k in params} for idx in range(total_num)) + ): + waveform_cache += [return_wf] + + for i in range(total_num): + num_tried += 1 + hp = waveform_cache[i] + + if isinstance(hp, pycbc.types.FrequencySeries): + hp.gen = gen + hp.threshold = threshold + if hp not in self: + num_added += 1 + self.insert(hp) + elif numpy.isnan(hp): + logging.info("Waveform generation failed!") + continue + else: + raise ValueError("Waveform generation problem.") + + else: + raise ValueError("Number of cores should be a positive integer.") return bank, num_added / float(num_tried) @@ -337,6 +371,14 @@ gen = GenUniformWaveform(args.buffer_length, args.sample_rate, args.low_frequency_cutoff) bank = TriangleBank() +if args.num_cores > 1: + def wf_wrapper(p): + try: + hp = gen.generate(**p) + return hp + except Exception: + return numpy.nan + if args.input_file: f = h5py.File(args.input_file, 'r') params = {k: f[k][:] for k in f}