Skip to content

Commit

Permalink
parallelize shrinker
Browse files Browse the repository at this point in the history
  • Loading branch information
yi-fan-wang committed Jul 13, 2024
1 parent 8d33ec3 commit 579f58b
Showing 1 changed file with 51 additions and 22 deletions.
73 changes: 51 additions & 22 deletions bin/bank/pycbc_brute_bank
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,16 @@ if args.fixed_params:
fdict = {p: v for (p, v) in zip(args.fixed_params, args.fixed_values)}

class Shrinker(object):
def __init__(self, data):
def __init__(self, data, nprocesses):
self.data = data
self.nprocesses = nprocesses

def pop(self):
if len(self.data) == 0:
return None
l = self.data[-1]
self.data = self.data[:-1]
num_pop = min(len(self.data), self.nprocesses)
l = self.data[-1*num_pop:]
self.data = self.data[:-1*num_pop]
return l

class TriangleBank(object):
Expand Down Expand Up @@ -218,9 +220,10 @@ class TriangleBank(object):
mtau = len(r)

# Try to do some actual matches
inc = Shrinker(r*1)
inc = Shrinker(r*1, args.nprocesses)
while 1:
j = inc.pop()
print(j)
if j is None:
hp.matches = matches[r]
hp.indices = r
Expand All @@ -229,25 +232,41 @@ class TriangleBank(object):
% (mmax, len(self), msig, mtau, mnum))
return False

hc = self[j]
m = hp.gen.match(hp, hc)
matches[j] = m
mnum += 1

# Update bounding match values, apply triangle inequality
maxmatches = hc.matches - m + 1.10
update = numpy.where(maxmatches < matches[hc.indices])[0]
matches[hc.indices[update]] = maxmatches[update]

# Update where to calculate matches
skip_threshold = 1 - (1 - hp.threshold) * 2.0
inc.data = inc.data[matches[inc.data] > skip_threshold]

if m > hp.threshold:
match_cache = {}
pool = pycbc.pool.choose_pool(args.nprocesses)
for return_idx, return_m in pool.imap_unordered(
match_wrapper,
({'idx': idx,
'newhp_data': hp.data,
'newhp_delta_f': hp.delta_f,
'bankhp_data': self[idx].data,
'bankhp_delta_f': self[idx].delta_f} for idx in j)
):
match_cache[return_idx] = return_m
pool.close_pool()
del pool

mnum += len(j)
if any(mv > hp.threshold for mv in match_cache.values()):
return True
if m > mmax:
mmax = m

else:
for idx in j:
hc = self[idx]
m = match_cache[idx]
matches[idx] = m

if m > mmax:
mmax = m

# Update bounding match values, apply triangle inequality
maxmatches = hc.matches - m + 1.10
update = numpy.where(maxmatches < matches[hc.indices])[0]
matches[hc.indices[update]] = maxmatches[update]

# Update where to calculate matches
skip_threshold = 1 - (1 - hp.threshold) * 2.0
inc.data = inc.data[matches[inc.data] > skip_threshold]

def check_params(self, gen, params, threshold):
num_added = 0
total_num = len(tuple(params.values())[0])
Expand Down Expand Up @@ -360,6 +379,16 @@ def wf_wrapper(p):
print(e)
return None

def match_wrapper(p):
'''A wrapper function to compute match
'''
idx= p['idx']
new_hp =pycbc.types.FrequencySeries(initial_array=p['newhp_data'], delta_f=p['newhp_delta_f'])
bank_hp =pycbc.types.FrequencySeries(initial_array=p['bankhp_data'], delta_f=p['bankhp_delta_f'])
new_hp.view = new_hp[gen.kmin:-1]
bank_hp.view = bank_hp[gen.kmin:-1]
return idx, gen.match(new_hp, bank_hp)

if args.input_file:
f = h5py.File(args.input_file, 'r')
params = {k: f[k][:] for k in f}
Expand Down

0 comments on commit 579f58b

Please sign in to comment.