Skip to content

Commit

Permalink
github review (1) combine the redundant code (2) some cleaning
Browse files Browse the repository at this point in the history
  • Loading branch information
yi-fan-wang committed Jun 28, 2024
1 parent ed5721d commit 02d02d4
Showing 1 changed file with 27 additions and 51 deletions.
78 changes: 27 additions & 51 deletions bin/bank/pycbc_brute_bank
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ from scipy.stats import gaussian_kde
import multiprocessing

import pycbc.waveform, pycbc.filter, pycbc.types, pycbc.psd, pycbc.fft, pycbc.conversions
import pycbc.pool
from pycbc import transforms
from pycbc.waveform.spa_tmplt import spa_length_in_time
from pycbc.distributions import read_params_from_config
Expand Down Expand Up @@ -76,7 +77,8 @@ parser.add_argument('--tau0-crawl', type=float)
parser.add_argument('--tau0-start', type=float)
parser.add_argument('--tau0-end', type=float)
parser.add_argument('--tau0-cutoff-frequency', type=float, default=15.0)
parser.add_argument('--num_cores', type=int, default=1, help='number of cores for waveform generation parallelization.')
parser.add_argument('--nprocesses', type=int, default=1,
help='Number of processes to use for waveform generation parallelization. If not given then only a single core will be used.')
pycbc.psd.insert_psd_option_group(parser)
args = parser.parse_args()

Expand Down Expand Up @@ -247,56 +249,31 @@ class TriangleBank(object):
mmax = m

def check_params(self, gen, params, threshold):
num_tried = 0
num_added = 0
total_num = len(tuple(params.values())[0])

if args.num_cores == 1:
for i in range(total_num):
num_tried += 1.0

try:
hp = gen.generate(**{key:params[key][i] for key in params})
except Exception as err:
print(err)
continue

waveform_cache = []

with multiprocessing.Pool(args.nprocesses) as pool:
for return_wf in pool.imap_unordered(
wf_wrapper,
({k: params[k][idx] for k in params} for idx in range(total_num))
):
waveform_cache += [return_wf]

for hp in waveform_cache:
if isinstance(hp, pycbc.types.FrequencySeries):
hp.gen = gen
hp.threshold = threshold
hp.threshold = threshold
if hp not in self:
num_added += 1
self.insert(hp)
elif hp is None:
logging.info("Waveform generation failed!")
continue
else:
raise ValueError("Waveform generation wrapper problem!")

# use multiprocessing
elif args.num_cores > 1:
waveform_cache = []
with multiprocessing.Pool(args.num_cores) as pool:
for return_wf in pool.imap_unordered(
wf_wrapper,
({k: params[k][idx] for k in params} for idx in range(total_num))
):
waveform_cache += [return_wf]

for i in range(total_num):
num_tried += 1
hp = waveform_cache[i]

if isinstance(hp, pycbc.types.FrequencySeries):
hp.gen = gen
hp.threshold = threshold
if hp not in self:
num_added += 1
self.insert(hp)
elif numpy.isnan(hp):
logging.info("Waveform generation failed!")
continue
else:
raise ValueError("Waveform generation problem.")

else:
raise ValueError("Number of cores should be a positive integer.")

return bank, num_added / float(num_tried)
return bank, num_added / total_num

class GenUniformWaveform(object):
def __init__(self, buffer_length, sample_rate, f_lower):
Expand Down Expand Up @@ -371,13 +348,12 @@ gen = GenUniformWaveform(args.buffer_length,
args.sample_rate, args.low_frequency_cutoff)
bank = TriangleBank()

if args.num_cores > 1:
def wf_wrapper(p):
try:
hp = gen.generate(**p)
return hp
except Exception:
return numpy.nan
def wf_wrapper(p):
try:
hp = gen.generate(**p)
return hp
except Exception:
return None

if args.input_file:
f = h5py.File(args.input_file, 'r')
Expand Down

0 comments on commit 02d02d4

Please sign in to comment.