-
Notifications
You must be signed in to change notification settings - Fork 1
/
test.py
52 lines (42 loc) · 1.25 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
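"""
Simple script to test trained models by running them greedily and rendering the corresponding environments.

Usage: pass the path of a saved model as the only argument; the model's file name must start with the name of its
environment followed by an underscore.
"""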
# STD
import os
import random
import sys
import time

# EXT
import torch
import gym
from gym.spaces import Box

# PROJECT
from analyze import discrete_to_continuous
from train import ENVIRONMENTS
def select_greedy_action(model, state):
    """
    Pick the action the model scores highest for the given state.

    Parameters
    ----------
    model: callable
        Maps a float tensor built from ``state`` to action scores (e.g. a torch module).
    state:
        Observation convertible by ``torch.Tensor`` (a sequence or array of numbers).

    Returns
    -------
    int
        Index of the maximal score. ``torch.argmax`` is applied without a ``dim``, i.e. over the flattened output,
        so this assumes the model returns one score per action for a single state -- TODO confirm for batched models.
    """
    state_tensor = torch.Tensor(state)

    # Only the forward pass needs gradient tracking disabled; no backward pass happens at test time
    with torch.no_grad():
        scores = model(state_tensor)

    best_action = torch.argmax(scores)
    return best_action.item()
def test_reinforce_model(model, env, num_episodes, *, episode_delay=1, step_delay=0.05):
    """
    Run a model greedily in an environment for several episodes, rendering every step.

    The environment is closed exactly once when all episodes are finished -- also if an exception (e.g. a
    KeyboardInterrupt while watching) interrupts the run, so the render window is not leaked.

    Parameters
    ----------
    model: callable
        Model passed to ``select_greedy_action``; the index of its highest output is taken as the action.
    env: gym.Env
        Environment to run and render; it is closed when this function returns.
    num_episodes: int
        Number of episodes to play.
    episode_delay: float
        Seconds to pause after each reset, before the first frame of an episode is rendered.
    step_delay: float
        Seconds to pause after rendering each step, so the episode can be followed visually.
    """
    try:
        for _ in range(num_episodes):
            state = env.reset()
            time.sleep(episode_delay)
            env.render()
            done = False

            while not done:
                action = select_greedy_action(model, state)  # Greedy action index

                # Box (continuous) action spaces cannot take a discrete index directly; convert it with the
                # project's discrete_to_continuous helper and wrap it in a list as the env's step expects
                if isinstance(env.action_space, Box):
                    action = [discrete_to_continuous(action, env)]

                state, _, done, _ = env.step(action)
                env.render()
                time.sleep(step_delay)
    finally:
        env.close()
if __name__ == "__main__":
    # Exactly one positional argument is needed: the path of the saved model
    if len(sys.argv) < 2:
        sys.exit("Usage: python {} <model_file>".format(sys.argv[0]))

    filename = sys.argv[1]
    print(filename)

    # The environment name is the part of the model's file name before the first underscore ("<env name>_<rest>").
    # Directory components are stripped first so paths like "some/dir/<env name>_<rest>" resolve as well.
    env_name = os.path.basename(filename).split("_")[0]
    if env_name not in ENVIRONMENTS:
        sys.exit(
            "Unknown environment '{}' (derived from '{}'); expected one of: {}".format(
                env_name, filename, ", ".join(ENVIRONMENTS)
            )
        )

    # NOTE: torch.load unpickles arbitrary Python objects -- only load model files you trust.
    # Loading onto the CPU matches select_greedy_action, which feeds the model CPU tensors, and allows
    # models saved from a GPU run to be tested on a machine without a GPU.
    model = torch.load(filename, map_location="cpu")

    # Create only the environment named by the model file; test_reinforce_model closes it when done
    env = gym.envs.make(env_name)
    test_reinforce_model(model, env, num_episodes=10)