forked from jacklxc/Virtual-Rat
-
Notifications
You must be signed in to change notification settings - Fork 0
/
BlockTime.py
executable file
·72 lines (63 loc) · 2.4 KB
/
BlockTime.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#!/usr/bin/env python
"""
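This script trains the Virtual Rat model with different block lengths. The
positional job-array index selects both the block length and the model (whose
index is also used as the random seed).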
"""
import argparse
import cPickle
import os

import numpy as npp
import minpy.numpy as np
import matplotlib.pyplot as plt
import minpy

# Set minpy's policy before the project modules below are imported, in case
# they consult it at import time.
minpy.set_global_policy('only_numpy')

from VirtualRatModel import VirtualRatModel
from VirtualRatSolver import VirtualRatSolver
from VirtualRatBox import VirtualRatBox
from VirtualRat import VirtualRat
from dataProcessFunctions import *
# --- Command line -----------------------------------------------------------
# The single positional argument is the job-array index; it is decomposed into
# (block length, model) below so one array covers every combination.
parser = argparse.ArgumentParser(
    description="Train one Virtual Rat model with a given block length.")
parser.add_argument("index", help="job_array_index",
                    type=int)
# Optional override; the default is the directory this script always used.
parser.add_argument("--save_directory",
                    default="/scratch/xl1066/VirtualRat/publication/BlockTime/",
                    help="directory the trained-weight checkpoints are written to")
args = parser.parse_args()
if args.index < 0:
    # A negative index would give a negative model index, which is later used
    # as a random seed; fail here with a clear message instead.
    parser.error("index must be non-negative")

# --- Run configuration ------------------------------------------------------
# Candidate block lengths. Plain Python ints: they only feed the box's
# block_size and the checkpoint file names, so no (minpy) array is needed.
block_lengths = (5, 10, 15, 20, 30, 40, 50)

# Block length varies fastest across consecutive job indices. `//` keeps the
# model index an int on Python 3 as well (plain `/` only floors on Python 2).
block_index = args.index % len(block_lengths)
model_index = args.index // len(block_lengths)

epoch_per_loop = 100  # episodes per solver.train() call (num_episodes)
num_loop = 100        # total number of train/save rounds after the first one
pre_train = 30        # rounds run before the box switches to "alternative"
block_length = block_lengths[block_index]
save_directory = args.save_directory
def _checkpoint_path(directory, length, model, num_episodes):
    """Return the path of one weight checkpoint.

    File names follow
    ``trainedBlockTime-<block length>-<model index>-<episode label>.pkl``;
    ``os.path.join`` also copes with a directory given without a trailing
    slash.
    """
    file_name = ("trainedBlockTime-" + str(length) + "-" + str(model) + "-"
                 + str(num_episodes) + ".pkl")
    return os.path.join(directory, file_name)


def _train_and_save(solver, num_episodes, directory, length, model):
    """Run one ``solver.train()`` call and save the parameters it produced.

    ``num_episodes`` is only the label written into the checkpoint file name.
    """
    solver.train()
    params = solver.save_params()
    save_weights(_checkpoint_path(directory, length, model, num_episodes),
                 params)


# Create the output directory up front so a bad path fails before any
# training time is spent. Jobs of the same array may race to create it, so an
# "already exists" error is tolerated.
try:
    os.makedirs(save_directory)
except OSError:
    if not os.path.isdir(save_directory):
        raise

# Each job trains exactly one model, identified by model_index.
j = model_index
ratname = 'VirtualRat' + str(j)
print(ratname)
# Seed both the minpy and the plain NumPy generators with the model index.
np.random.seed(j)
npp.random.seed(j)

model = VirtualRatModel()
# NOTE(review): `rat` is not used below; presumably constructed for its side
# effects on `model` -- confirm before removing.
rat = VirtualRat(model)
box = VirtualRatBox(mode="no_rule", length=500000, block_size=block_length,
                    trial_per_episode=30, repeat=False, p2a=0.5,
                    block_correction=True, left_right_correction=True)
solver = VirtualRatSolver(model, box,
                          update_rule='adam',
                          optim_config={
                              'learning_rate': 5e-3,
                              'decay_rate': 1
                          },
                          init_rule='xavier',
                          num_episodes=epoch_per_loop,
                          verbose=False,
                          supervised=True,
                          stop=False,
                          print_every=50)
solver.init()

# NOTE(review): the checkpoint labelled 0 is written after one training call
# (epoch_per_loop episodes); labels are kept as-is for compatibility with
# existing outputs -- confirm this offset is intended.
_train_and_save(solver, 0, save_directory, block_length, j)

# Phase 1: pre_train rounds with the box in "no_rule" mode.
for i in range(1, pre_train + 1):
    _train_and_save(solver, i * epoch_per_loop, save_directory,
                    block_length, j)

# Phase 2: switch to "alternative" mode for the remaining rounds; labels keep
# counting on from the end of phase 1.
box.change_mode("alternative")
for i in range(pre_train + 1, num_loop + 1):
    _train_and_save(solver, i * epoch_per_loop, save_directory,
                    block_length, j)