run_single.py
import gymnasium as gym
import numpy as np
import torch

from seq_env_wrapper import SequenceEnvironmentWrapper
from single_qtransformer import QTransformer

def transform_state(history):
    # Convert the wrapper's stacked observation history into a (1, ...) float batch tensor.
    return torch.from_numpy(history['observations']).unsqueeze(0).float()

def play_episode(env: SequenceEnvironmentWrapper, model):
    history = env.reset()
    done = False
    steps = 0
    episode_return = 0
    while not done:
        # Debug output: the raw Q-values for the current observation stack.
        print(model(transform_state(history)))
        action = model.predict_action(transform_state(history))[0]
        history, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        steps += 1
        episode_return += reward
    return episode_return

def evaluate(env: SequenceEnvironmentWrapper, model, episodes=10):
    # Run in eval mode without gradient tracking, then restore train mode.
    # (Renamed from `eval` to avoid shadowing the Python builtin.)
    model.eval()
    with torch.no_grad():
        scores = [play_episode(env, model) for _ in range(episodes)]
    model.train()
    return scores

if __name__ == '__main__':
    # Model hyperparameters must match the configuration used to train the checkpoint.
    model = QTransformer(4, 2, 128, 4, 3)
    checkpoint = torch.load('models_single/model-8.pt')
    model.load_state_dict(checkpoint['model_state'])

    env = SequenceEnvironmentWrapper(
        gym.make('CartPole-v1', render_mode='human'), num_stack_frames=4)
    scores = evaluate(env, model, episodes=1)
    print(np.max(scores))
    print(np.min(scores))
    print(np.mean(scores))