-
Notifications
You must be signed in to change notification settings - Fork 18
/
gen_experiment.py
executable file
·40 lines (32 loc) · 1.18 KB
/
gen_experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#!/usr/bin/env python3
"""Script for generating experiment.txt.

Writes one command line per experiment to ``experiment.txt`` in the current
working directory: every (learning rate, gamma) combination, each repeated
``repeats`` times. Also prints the number of experiments and a rough
wall-clock estimate for running them on ``nr_servers`` nodes.
"""
import getpass
import math
import os

# The home dir on the node's scratch disk.
# $USER can be unset (e.g. in some non-interactive shells); fall back to
# getpass so the path never silently becomes '/disk/scratch/None'.
USER = os.getenv('USER') or getpass.getuser()
# This may need changing to e.g. /disk/scratch_fast depending on the cluster
SCRATCH_DISK = '/disk/scratch'
SCRATCH_HOME = f'{SCRATCH_DISK}/{USER}'
DATA_HOME = f'{SCRATCH_HOME}/mnist/data'
# Name of the generated file (one experiment command per line).
OUTPUT_FILE = 'experiment.txt'

base_call = (f"python main.py -i {DATA_HOME}/input -o {DATA_HOME}/output "
             "--epochs 50")
repeats = 3
learning_rates = [10, 1, 1e-1, 1e-2]
gammas = [.4, .5, .6, .7, .8]
# One entry per experiment. Repeats of the same (lr, gamma) produce identical
# command lines on purpose (see the note on seeds in the write loop below).
settings = [(lr, gam, rep) for lr in learning_rates for gam in gammas
            for rep in range(repeats)]
# Derived from `settings` so the reported count always matches the number of
# lines written.
nr_expts = len(settings)
nr_servers = 10
avg_expt_time = 20  # mins


def make_call(lr, gamma):
    """Return the command line for one experiment with the given lr/gamma."""
    return f"{base_call} --lr {lr} --gamma {gamma}"


def estimate_hours(n_expts, n_servers, mins_per_expt):
    """Return the estimated wall-clock hours to run ``n_expts`` experiments.

    Assumes each server runs one experiment at a time, so experiments run in
    waves of ``n_servers``. A partially filled final wave still takes a full
    ``mins_per_expt``, hence the ceiling.
    """
    return math.ceil(n_expts / n_servers) * mins_per_expt / 60


est_hours = estimate_hours(nr_expts, nr_servers, avg_expt_time)
print(f'Total experiments = {nr_expts}')
print(f'Estimated time = {est_hours:.1f} hrs')

# The context manager guarantees the file is flushed and closed even if a
# write fails part-way through.
with open(OUTPUT_FILE, "w", encoding="utf-8") as output_file:
    for lr, gam, rep in settings:
        # Note that we don't set a seed for rep - a seed is selected at random
        # and recorded in the output data by the python script
        print(make_call(lr, gam), file=output_file)