Skip to content

Commit

Permalink
added way to track random seeds for reproducibility
Browse files Browse the repository at this point in the history
  • Loading branch information
Bethany Lusch authored and Bethany Lusch committed Oct 13, 2019
1 parent f06ad04 commit cad0651
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 0 deletions.
4 changes: 4 additions & 0 deletions helperfns.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,10 @@ def set_defaults(params):
params['d'] = len(params['widths']) # d must be calculated like this

# defaults related to initialization of parameters
if 'seed' not in params:
random_seed = np.random.randint(2 ** 30)
print("setting default: choosing random seed of %d and saving to params" % random_seed)
params['seed'] = random_seed
if 'dist_weights' not in params:
print("setting default: distribution for weights on main net is tn (truncated normal)")
params['dist_weights'] = 'tn'
Expand Down
2 changes: 2 additions & 0 deletions training.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,8 @@ def main_exp(params):
if not os.path.exists(params['folder_name']):
os.makedirs(params['folder_name'])

tf.set_random_seed(params['seed'])
np.random.seed(params['seed'])
# data is num_steps x num_examples x n but load flattened version (matrix instead of tensor)
data_val = np.loadtxt(('./data/%s_val_x.csv' % (params['data_name'])), delimiter=',', dtype=np.float64)
try_net(data_val, params)

0 comments on commit cad0651

Please sign in to comment.