Skip to content

Commit

Permalink
Merge pull request #354 from LSSTDESC/issue/41/rando
Browse files Browse the repository at this point in the history
Addressing the last few remaining places where randoms aren't using default_rng and/or a configurable seed
  • Loading branch information
sschmidt23 authored May 3, 2023
2 parents ee34b98 + 6e4ccc4 commit a015dce
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
6 changes: 3 additions & 3 deletions src/rail/creation/degradation/grid_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def run(self):
If using a color-based redshift cut, galaxies with redshifts > the percentile cut are removed from the sample
before making the random selection.
"""
np.random.seed(self.config.random_seed)
rng = np.random.default_rng(seed=self.config.random_seed)

data = self.get_data('input')
with open(self.config.settings_file, 'rb') as handle:
Expand Down Expand Up @@ -189,13 +189,13 @@ def run(self):
number_to_keep = len(temp_data)

if int(number_to_keep) != number_to_keep:
random_num = np.random.uniform()
random_num = rng.uniform()
else:
random_num = 2

number_to_keep = np.floor(number_to_keep)
indices_to_list = list(temp_data.index.values)
np.random.shuffle(indices_to_list)
rng.shuffle(indices_to_list)

if random_num > xratio: # pragma: no cover
for j in range(int(number_to_keep)):
Expand Down
4 changes: 2 additions & 2 deletions src/rail/estimation/algos/knnpz.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ def run(self):
trainszs = np.array(training_data[self.config.redshift_col])
colordata = _computecolordata(knndf, self.config.ref_band, self.config.bands)
nobs = colordata.shape[0]
rng = np.random.default_rng
perm = rng().permutation(nobs)
rng = np.random.default_rng(seed=self.config.seed)
perm = rng.permutation(nobs)
ntrain = round(nobs * self.config.trainfrac)
xtrain_data = colordata[perm[:ntrain]]
train_data = copy.deepcopy(xtrain_data)
Expand Down
4 changes: 3 additions & 1 deletion src/rail/estimation/algos/randomPZ.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class RandomPZ(CatEstimator):
rand_zmin=Param(float, 0.0, msg="The minimum redshift of the z grid"),
rand_zmax=Param(float, 3.0, msg="The maximum redshift of the z grid"),
nzbins=Param(int, 301, msg="The number of gridpoints in the z grid"),
seed=Param(int, 87, msg="random seed"),
column_name=Param(str, "mag_i_lsst", msg="name of a column that has the correct number of galaxies to find length of"))

def __init__(self, args, comm=None):
Expand All @@ -35,7 +36,8 @@ def _process_chunk(self, start, end, data, first):
pdf = []
# allow for either format for now
numzs = len(data[self.config.column_name])
zmode = np.round(np.random.uniform(0.0, self.config.rand_zmax, numzs), 3)
rng = np.random.default_rng(seed=self.config.seed + start)
zmode = np.round(rng.uniform(0.0, self.config.rand_zmax, numzs), 3)
widths = self.config.rand_width * (1.0 + zmode)
self.zgrid = np.linspace(self.config.rand_zmin, self.config.rand_zmax, self.config.nzbins)
for i in range(numzs):
Expand Down

0 comments on commit a015dce

Please sign in to comment.