-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathmain.py
102 lines (87 loc) · 3.96 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import argparse
import os,json, sys
import numpy as np
# Single-GPU selection: pin this process to the GPU that currently has the
# most free memory. This has to run before ``import torch`` below so that
# CUDA only ever sees the chosen device.
import subprocess


def parse_free_memory(smi_output):
    """Parse ``nvidia-smi`` CSV output holding one free-memory value per line.

    Args:
        smi_output: Text produced by
            ``nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits``.

    Returns:
        A list with one integer per non-blank line, in the order reported.
        Lines that are not plain integers (e.g. ``[N/A]``) become ``-1`` so
        that list positions stay aligned with GPU indices.
    """
    free = []
    for raw_line in smi_output.splitlines():
        value = raw_line.strip()
        if not value:
            continue
        try:
            free.append(int(value))
        except ValueError:
            free.append(-1)
    return free


def query_free_gpu_memory():
    """Return the free memory of every GPU reported by ``nvidia-smi``.

    Returns:
        A list of integers indexed by GPU ordinal, or an empty list when
        ``nvidia-smi`` is not installed or exits with an error.
    """
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.free",
             "--format=csv,noheader,nounits"],
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            universal_newlines=True,
            check=True,
        )
    except (OSError, subprocess.CalledProcessError):
        # No driver / no GPU: fall through to whatever device torch picks.
        return []
    return parse_free_memory(result.stdout)


memory_gpu = query_free_gpu_memory()
if memory_gpu:
    # Only restrict visibility when at least one GPU was detected; on a
    # CPU-only machine the environment is left untouched.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(np.argmax(memory_gpu))
import torch
import utils
def setup_seed(seed):
    """Seed the global PyTorch and NumPy random number generators.

    Calls ``torch.manual_seed``, ``torch.cuda.manual_seed_all`` and
    ``np.random.seed`` with ``seed`` (in that order), then sets
    ``torch.backends.cudnn.deterministic`` to ``True``.

    Args:
        seed: Value handed unchanged to each seeding function.
    """
    seeders = (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed)
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True
# Modes accepted as the first command-line token (``python main.py <mode> ...``).
MODES = ("eval", "eval_RL", "train", "train_RL", "baseline", "random", "always")


def build_parser():
    """Build the option parser shared by every mode.

    Returns:
        An ``argparse.ArgumentParser`` holding all ``--flag`` options. The
        mode token itself is not part of the parser; it is stripped from
        ``sys.argv`` by ``main`` before parsing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="SimpleDoorKey", help="SimpleDoorKey, KeyInBox, RandomBoxKey, ColoredDoorKey")
    parser.add_argument("--save_name", type=str, required=True, help="path to folder containing policy and run details")
    parser.add_argument("--logdir", type=str, default="./log/")  # Where to log diagnostics to
    parser.add_argument("--record", default=False, action='store_true')
    parser.add_argument("--seed", default=None)
    parser.add_argument("--ask_lambda", type=float, default=0.01, help="weight on communication penalty term")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lam", type=float, default=0.95, help="Generalized advantage estimate discount")
    parser.add_argument("--gamma", type=float, default=0.99, help="MDP discount")
    parser.add_argument("--n_itr", type=int, default=1000, help="Number of iterations of the learning algorithm")
    parser.add_argument("--policy", type=str, default='ppo')
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    parser.add_argument("--traj_per_itr", type=int, default=10)
    parser.add_argument("--show", default=False, action='store_true')
    parser.add_argument("--test_num", type=int, default=100)
    parser.add_argument("--frame_stack", type=int, default=1)
    parser.add_argument("--run_seed_list", type=int, nargs="*", default=[0])
    return parser


def load_policy(args):
    """Load the saved policy for ``args`` and switch it to eval mode.

    The checkpoint is read from
    ``<logdir>/<policy>/<task>/<save_name>/acmodel.pt``.

    Args:
        args: Parsed options; ``logdir``, ``policy``, ``task`` and
            ``save_name`` are read.

    Returns:
        The object returned by ``torch.load`` after ``.eval()`` was called.
    """
    output_dir = os.path.join(args.logdir, args.policy, args.task, args.save_name)
    # NOTE: torch.load unpickles the stored object; only load checkpoints
    # from a trusted source.
    # NOTE(review): recent PyTorch releases may default to weights_only=True,
    # which can reject a fully pickled model - confirm against the pinned
    # torch version.
    policy = torch.load(output_dir + "/acmodel.pt")
    policy.eval()
    return policy


def main():
    """Entry point: dispatch on the mode given as the first CLI token.

    Modes and what they run:
        eval      -- load the saved policy, ``Eval.eval_policy``
        eval_RL   -- load the saved policy, ``Eval.eval_RL_policy``
        train     -- one ``Game`` training run per entry of ``--run_seed_list``
        train_RL  -- a single ``Game_RL`` training run
        baseline  -- ``Eval.eval_baseline`` without a policy
        random    -- ``Eval.eval_policy`` without a policy
        always    -- ``Eval.eval_always_ask`` without a policy

    An unknown mode prints a message and returns; a missing mode is reported
    through ``parser.error`` instead of raising ``IndexError``.
    """
    utils.print_logo(subtitle="Maintained by Research Center for Applied Mathematics and Machine Intelligence, Zhejiang Lab")
    parser = build_parser()

    if len(sys.argv) < 2:
        parser.error("missing mode; expected one of: " + ", ".join(MODES))

    mode = sys.argv[1]
    if mode not in MODES:
        print("Invalid option '{}'".format(mode))
        return

    # Drop the mode token by position (not by value) so that the remaining
    # argv contains only ``--flag`` options for argparse.
    del sys.argv[1]
    args = parser.parse_args()

    if mode in ("eval", "eval_RL"):
        evaluator = utils.Eval(args, load_policy(args))
        if mode == "eval":
            evaluator.eval_policy(args.test_num)
        else:
            evaluator.eval_RL_policy(args.test_num)
    elif mode == "train":
        from env.Game import Game
        # Keep the user-supplied name so every seed gets "<name><seed>"
        # rather than an ever-growing suffix ("<name>0", "<name>01", ...).
        base_save_name = args.save_name
        for run_seed in args.run_seed_list:
            setup_seed(run_seed)
            args.save_name = base_save_name + str(run_seed)
            game = Game(args, run_seed=run_seed)
            game.reset()
            game.train()
    elif mode == "train_RL":
        from env.Game_RL import Game_RL
        game = Game_RL(args)
        game.reset()
        game.train()
    elif mode == "baseline":
        utils.Eval(args).eval_baseline(args.test_num)
    elif mode == "random":
        utils.Eval(args).eval_policy(args.test_num)
    else:  # mode == "always"
        utils.Eval(args).eval_always_ask(args.test_num)


if __name__ == "__main__":
    main()