Skip to content

Commit

Permalink
add multiprocessing to pycbc_brute_bank (gwastro#4803)
Browse files Browse the repository at this point in the history
* add multiprocessing

* be more specific

* github review (1) combine the redundant code (2) some cleaning

* implement pycbc pool to remove redundant codes for multiprocessing and for single core

* remove multiprocessing lib
  • Loading branch information
yi-fan-wang authored Jul 2, 2024
1 parent 7786c1a commit d993ca1
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 15 deletions.
46 changes: 31 additions & 15 deletions bin/bank/pycbc_brute_bank
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import numpy.random
from scipy.stats import gaussian_kde

import pycbc.waveform, pycbc.filter, pycbc.types, pycbc.psd, pycbc.fft, pycbc.conversions
import pycbc.pool
from pycbc import transforms
from pycbc.waveform.spa_tmplt import spa_length_in_time
from pycbc.distributions import read_params_from_config
Expand Down Expand Up @@ -75,6 +76,8 @@ parser.add_argument('--tau0-crawl', type=float)
parser.add_argument('--tau0-start', type=float)
parser.add_argument('--tau0-end', type=float)
parser.add_argument('--tau0-cutoff-frequency', type=float, default=15.0)
parser.add_argument('--nprocesses', type=int, default=1,
help='Number of processes to use for waveform generation parallelization. If not given then only a single core will be used.')
pycbc.psd.insert_psd_option_group(parser)
args = parser.parse_args()

Expand Down Expand Up @@ -245,24 +248,30 @@ class TriangleBank(object):
mmax = m

def check_params(self, gen, params, threshold):
num_tried = 0
num_added = 0
for i in range(len(tuple(params.values())[0])):
num_tried += 1.0

try:
hp = gen.generate(**{key:params[key][i] for key in params})
except Exception as err:
print(err)
total_num = len(tuple(params.values())[0])
waveform_cache = []

pool = pycbc.pool.choose_pool(args.nprocesses)
for return_wf in pool.imap_unordered(
wf_wrapper,
({k: params[k][idx] for k in params} for idx in range(total_num))
):
waveform_cache += [return_wf]
del pool

for hp in waveform_cache:
if hp is not None:
hp.gen = gen
hp.threshold = threshold
if hp not in self:
num_added += 1
self.insert(hp)
else:
logging.info("Waveform generation failed!")
continue

hp.gen = gen
hp.threshold = threshold
if hp not in self:
num_added += 1
self.insert(hp)

return bank, num_added / float(num_tried)
return bank, num_added / total_num

class GenUniformWaveform(object):
def __init__(self, buffer_length, sample_rate, f_lower):
Expand Down Expand Up @@ -337,6 +346,13 @@ gen = GenUniformWaveform(args.buffer_length,
args.sample_rate, args.low_frequency_cutoff)
bank = TriangleBank()

def wf_wrapper(p):
try:
hp = gen.generate(**p)
return hp
except Exception:
return None

if args.input_file:
f = h5py.File(args.input_file, 'r')
params = {k: f[k][:] for k in f}
Expand Down
1 change: 1 addition & 0 deletions pycbc/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def map(self, f, items):
# so that the general pool interfaces can use
# imap irrespective of the pool type.
imap = map
imap_unordered = map

def use_mpi(require_mpi=False, log=True):
""" Get whether MPI is enabled and if so the current size and rank
Expand Down

0 comments on commit d993ca1

Please sign in to comment.