Skip to content

Commit

Permalink
allow numpy 2.0 to run (gwastro#4829)
Browse files Browse the repository at this point in the history
* allow numpy 2.0 to run

* fix for np2

* fix np2 bug

* tmpltbank fixes

* fixes

* fix string issue

* be explicit about types in cases where the input might use type prediction

* this place too

* cc
  • Loading branch information
ahnitz authored Sep 6, 2024
1 parent 90ccd26 commit 8363708
Show file tree
Hide file tree
Showing 9 changed files with 24 additions and 19 deletions.
4 changes: 2 additions & 2 deletions bin/all_sky_search/pycbc_coinc_statmap
Original file line number Diff line number Diff line change
Expand Up @@ -508,9 +508,9 @@ f.attrs['hierarchical_removal_iterations'] = h_iterations

# Write whether hierarchical removals were removed against the
# inclusive background or the exclusive background. Have to use
# numpy.string_ datatype.
# numpy.bytes_ datatype.
if h_iterations != 0:
hrm_method = args.hierarchical_removal_against
f.attrs['hierarchical_removal_method'] = numpy.string_(hrm_method)
f.attrs['hierarchical_removal_method'] = numpy.bytes_(hrm_method)

logging.info("Done")
4 changes: 2 additions & 2 deletions bin/all_sky_search/pycbc_sngls_statmap
Original file line number Diff line number Diff line change
Expand Up @@ -436,10 +436,10 @@ f.attrs['hierarchical_removal_iterations'] = h_iterations

# Write whether hierarchical removals were removed against the
# inclusive background or the exclusive background. Have to use
# numpy.string_ datatype.
# numpy.bytes_ datatype.
if h_iterations != 0:
hrm_method = args.hierarchical_removal_against
f.attrs['hierarchical_removal_method'] = numpy.string_(hrm_method)
f.attrs['hierarchical_removal_method'] = numpy.bytes_(hrm_method)

logging.info("Done")

4 changes: 2 additions & 2 deletions pycbc/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ def ensurearray(*args):
inputs was an array.
"""
input_is_array = any(isinstance(arg, numpy.ndarray) for arg in args)
args = numpy.broadcast_arrays(*args)
args = list(numpy.broadcast_arrays(*args))
args.append(input_is_array)
return args
return tuple(args)


def formatreturn(arg, input_is_array=False):
Expand Down
4 changes: 2 additions & 2 deletions pycbc/events/eventmgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ def findchirp_cluster_over_window(times, values, window_length):

indices = numpy.zeros(len(times), dtype=numpy.int32)
tlen = len(times)
absvalues = numpy.array(abs(values), copy=False)
times = numpy.array(times, dtype=numpy.int32, copy=False)
absvalues = numpy.asarray(abs(values))
times = numpy.asarray(times, dtype=numpy.int32)
k = findchirp_cluster_over_window_cython(times, absvalues, window_length,
indices, tlen)

Expand Down
15 changes: 10 additions & 5 deletions pycbc/events/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,8 @@ def get_hist(self, ifos=None):

# renormalise to PDF
self.weights[ifo] = \
weights[ifo] / (weights[ifo].sum() * bin_volume)
(weights[ifo] / (weights[ifo].sum() * bin_volume))
self.weights[ifo] = self.weights[ifo].astype(numpy.float32)

if param[ifo].dtype == numpy.int8:
# Older style, incorrectly sorted histogram file
Expand Down Expand Up @@ -641,13 +642,17 @@ def logsignalrate(self, stats, shift, to_shift):

# Get reference ifo information
rate = numpy.zeros(len(shift), dtype=numpy.float32)
ps = {ifo: numpy.array(stats[ifo]['coa_phase'], ndmin=1)
ps = {ifo: numpy.array(stats[ifo]['coa_phase'],
dtype=numpy.float32, ndmin=1)
for ifo in self.ifos}
ts = {ifo: numpy.array(stats[ifo]['end_time'], ndmin=1)
ts = {ifo: numpy.array(stats[ifo]['end_time'],
dtype=numpy.float64, ndmin=1)
for ifo in self.ifos}
ss = {ifo: numpy.array(stats[ifo]['snr'], ndmin=1)
ss = {ifo: numpy.array(stats[ifo]['snr'],
dtype=numpy.float32, ndmin=1)
for ifo in self.ifos}
sigs = {ifo: numpy.array(stats[ifo]['sigmasq'], ndmin=1)
sigs = {ifo: numpy.array(stats[ifo]['sigmasq'],
dtype=numpy.float32, ndmin=1)
for ifo in self.ifos}
for ref_ifo in self.ifos:
rtype = rtypes[ref_ifo]
Expand Down
2 changes: 1 addition & 1 deletion pycbc/io/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -1507,7 +1507,7 @@ def _isstring(dtype):
"""Given a numpy dtype, determines whether it is a string. Returns True
if the dtype is string or unicode.
"""
return dtype.type == numpy.unicode_ or dtype.type == numpy.string_
return dtype.type == numpy.unicode_ or dtype.type == numpy.bytes_


def aliases_from_fields(fields):
Expand Down
6 changes: 3 additions & 3 deletions pycbc/tmpltbank/coord_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ def get_mu_params(lambdas, metricParams, fUpper):
mus : list of floats or numpy.arrays
Position of the system(s) in the mu coordinate system
"""
lambdas = numpy.array(lambdas, copy=False)
lambdas = numpy.asarray(lambdas)
# If original inputs were floats we need to make this a 2D array
if len(lambdas.shape) == 1:
resize_needed = True
Expand All @@ -451,7 +451,7 @@ def get_mu_params(lambdas, metricParams, fUpper):
evecs = metricParams.evecs[fUpper]
evals = metricParams.evals[fUpper]

evecs = numpy.array(evecs, copy=False)
evecs = numpy.asarray(evecs)

mus = ((lambdas.T).dot(evecs)).T
mus = mus * numpy.sqrt(evals)[:,None]
Expand Down Expand Up @@ -479,7 +479,7 @@ def get_covaried_params(mus, evecsCV):
xis : list of floats or numpy.arrays
Position of the system(s) in the xi coordinate system
"""
mus = numpy.array(mus, copy=False)
mus = numpy.asarray(mus)
# If original inputs were floats we need to make this a 2D array
if len(mus.shape) == 1:
resize_needed = True
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ astropy>=2.0.3,!=4.2.1,!=4.0.5
Mako>=1.0.1
scipy>=0.16.0
matplotlib>=2.0.0
numpy>=1.16.0,!=1.19.0,<2.0.0
numpy>=1.16.0,!=1.19.0
pillow
h5py>=3.0.0,!=3.7.0
jinja2
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
setup_requires = ['numpy>=1.16.0']
install_requires = setup_requires + [
'cython>=0.29',
'numpy>=1.16.0,!=1.19.0,<2.0.0',
'numpy>=1.16.0,!=1.19.0',
'scipy>=0.16.0',
'astropy>=2.0.3,!=4.2.1,!=4.0.5',
'matplotlib>=1.5.1',
Expand Down

0 comments on commit 8363708

Please sign in to comment.