visualize.py
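# Interactive rollout viewer: loads a trained GNNPolicy checkpoint by run id and
# replays it in a rendered PLEnv environment, with breakpoint hooks for debugging.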
import os

import torch
import yaml

from agent.arguments import get_args_visualizer, fetch_params
from agent.envs import PLEnv
from agent.policy import GNNPolicy
from test_gen_util import generate_tests


def main(log_name, run_id):
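    """Load the trained policy for run_id and step it through a rendered
    environment, pausing at breakpoints for manual inspection."""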
    print(log_name, run_id)
    params = fetch_params(run_id)
    print(params["make_tests"])
    generate_tests(params=params)

    # Account for changes in logging
    # params["env"]["cursor_start_pos"] = [6, 6]
    # params["env"]["perturbation"] = 0
    # params["env"]["max_episode_steps"] = 2
    # params["env"]["assignment_dir"] = "data/tests"
    # params["base"]["done_action"] = params["env"]["done_action"]

    path = os.path.join("save", str(run_id) + ".pt")

    torch.manual_seed(params["seed"])
    torch.cuda.manual_seed_all(params["seed"])
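    # Make cuDNN deterministic when requested so visualized rollouts are reproducible.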
if params["cuda"] and torch.cuda.is_available() and params["cuda_deterministic"]:
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads(1)
    device = torch.device("cuda" if params["cuda"] else "cpu")

    params["num_processes"] = 1
    env_kwargs = params["env"]  # used to be "eval"
    env_kwargs["max_episode_steps"] = params["env"]["max_episode_steps"]
    env = PLEnv.make_vec_envs(
        params["seed"], params["num_processes"], device, render=True, **env_kwargs
    )
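    # Rebuild the policy with the same observation/action spaces and base config
    # it was trained with, then load the saved weights from the checkpoint.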
base_kwargs = params["base"]
actor_critic = GNNPolicy(
env.get_attr("orig_obs_space")[0],
env.get_attr("action_space")[0],
env.get_attr("num_actions")[0],
base_kwargs=base_kwargs,
device=device,
done_action = params['env']['done_action'],
)
actor_critic.to(device)
actor_critic.load_state_dict(torch.load(path)[0])
actor_critic.eval()
    obs = env.reset()
    env.render()
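    # Interactive debugging switches: each stop_on_* flag drops into breakpoint()
    # at the corresponding event so the rollout (and the chosen action) can be
    # inspected or overridden by hand.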
    step = 0
    stop_on_update = False
    made_hole = False
    stop_on_hole = False
    stop_on_fail = True
    stop_on_hole_success = True
    stop_on_new_episode = True
    num_tests = 0
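    # Step the policy until interrupted; the vectorized env resets itself when an
    # episode finishes.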
    while True:
        with torch.no_grad():
            (_, action, _, _,) = actor_critic.act(
                obs,
                None,
                None,
            )

        if stop_on_new_episode and step == 0: breakpoint()
        step += 1
        print(f"step: {step}")
        print(f"Action: {action}")
        if stop_on_update:
            breakpoint()  # to allow manual change, action[0] = n
        print()
        # if needed, add a python input() call here
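        # Action id 7 is treated as the hole-creating action here; track it so the
        # hole-related breakpoints below can fire.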
        if action[0] == 7:
            made_hole = True
            if stop_on_hole: breakpoint()

        obs, reward, done, info = env.step(action.reshape((-1,)))
        if done[0]:
            print(f"Reward: {info[0]['episode']['r']}")
            step = 0
            print()
            print(f"step: {step}")
            # breakpoint()
            if info[0]["episode"]["r"] == 0:
                print("failed")
                if stop_on_fail: breakpoint()
            else:
                # succeeded
                if made_hole:
                    print("Succeeded despite hole creation")
                    if stop_on_hole_success: breakpoint()
            num_tests += 1
            # breakpoint()
            print("---------------Environment reset---------------")
            if num_tests % 5000 == 0:
                breakpoint()
            made_hole = False

        env.render()
        print()
if __name__ == "__main__":
args = get_args_visualizer()
main(args.log_name, args.run_id)