Skip to content

Commit

Permalink
allow approximant to be read from config file (gwastro#4805)
Browse files Browse the repository at this point in the history
* allow approximant to be read from config file

* fix

* cc
  • Loading branch information
ahnitz authored Jul 2, 2024
1 parent d993ca1 commit 70bd500
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions bin/bank/pycbc_brute_bank
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ parser.add_argument('--min',
help='list of the minimum parameter values', nargs='+', type=float)
parser.add_argument('--max',
help='list of the maximum parameter values', nargs='+', type=float)
parser.add_argument('--approximant', required=True,
help='The waveform approximant to place')
parser.add_argument('--approximant', required=False,
help='The waveform approximant to place.')
parser.add_argument('--minimal-match', default=0.97, type=float)
parser.add_argument('--buffer-length', default=2, type=float,
help='size of waveform buffer in seconds')
Expand All @@ -71,12 +71,13 @@ parser.add_argument('--min-mchirp', type=float)
parser.add_argument('--max-mchirp', type=float)
parser.add_argument('--fixed-params', type=str, nargs='*')
parser.add_argument('--fixed-values', type=float, nargs='*')
parser.add_argument('--use-cross', action='store_true')
parser.add_argument('--max-q', type=float)
parser.add_argument('--tau0-crawl', type=float)
parser.add_argument('--tau0-start', type=float)
parser.add_argument('--tau0-end', type=float)
parser.add_argument('--tau0-cutoff-frequency', type=float, default=15.0)
parser.add_argument('--nprocesses', type=int, default=1,
parser.add_argument('--nprocesses', type=int, default=1,
help='Number of processes to use for waveform generation parallelization. If not given then only a single core will be used.')
pycbc.psd.insert_psd_option_group(parser)
args = parser.parse_args()
Expand Down Expand Up @@ -254,7 +255,7 @@ class TriangleBank(object):

pool = pycbc.pool.choose_pool(args.nprocesses)
for return_wf in pool.imap_unordered(
wf_wrapper,
wf_wrapper,
({k: params[k][idx] for k in params} for idx in range(total_num))
):
waveform_cache += [return_wf]
Expand All @@ -263,7 +264,7 @@ class TriangleBank(object):
for hp in waveform_cache:
if hp is not None:
hp.gen = gen
hp.threshold = threshold
hp.threshold = threshold
if hp not in self:
num_added += 1
self.insert(hp)
Expand Down Expand Up @@ -302,11 +303,14 @@ class GenUniformWaveform(object):
f = self.f_lower

kwds['f_lower'] = f
if hasattr(kwds['approximant'], 'decode'):
kwds['approximant'] = kwds['approximant'].decode()

if kwds['approximant'] in pycbc.waveform.fd_approximants():
hp, hc = pycbc.waveform.get_fd_waveform(delta_f=self.delta_f,
f_ref=10.0, **kwds)

if args.use_cross:
hp = hc

if 'fratio' in kwds:
hp = hc * kwds['fratio'] + hp * (1 - kwds['fratio'])
Expand All @@ -316,12 +320,13 @@ class GenUniformWaveform(object):
hp = pycbc.waveform.get_waveform_filter(
pycbc.types.zeros(self.flen, dtype=numpy.complex64),
delta_f=self.delta_f, delta_t=dt,
f_lower=self.f_lower, **kwds)
**kwds)

hp.resize(self.flen)
hp = hp.astype(numpy.complex64)
hp[self.kmin:-1] *= self.w
s = float(1.0 / pycbc.filter.sigmasq(hp, low_frequency_cutoff=self.f_lower) ** 0.5)
s = float(1.0 / pycbc.filter.sigmasq(hp,
low_frequency_cutoff=f) ** 0.5)
hp *= s
hp.params = kwds
hp.view = hp[self.kmin:-1]
Expand Down Expand Up @@ -350,7 +355,8 @@ def wf_wrapper(p):
try:
hp = gen.generate(**p)
return hp
except Exception:
except Exception as e:
print(e)
return None

if args.input_file:
Expand Down Expand Up @@ -399,7 +405,8 @@ def draw(rtype):
if args.input_config is not None and waveform_transforms is not None:
params = transforms.apply_transforms(params, waveform_transforms)

params['approximant'] = numpy.array([args.approximant]*size)
if args.approximant is not None:
params['approximant'] = numpy.array([args.approximant]*size)

# Filter out stuff (kde method may also generate samples outside boundaries).
l = None
Expand Down Expand Up @@ -429,7 +436,6 @@ def draw(rtype):
l = dists_joint.contains(params)

params = {k: params[k][l] for k in params}

return params

def cdraw(rtype, ts, te):
Expand Down

0 comments on commit 70bd500

Please sign in to comment.