-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathrecord_video.py
91 lines (67 loc) · 2.74 KB
/
record_video.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import numpy as np
import torch
from datasets import load_dataset
from tqdm import tqdm
from evaluation.evaluate_episodes import evaluate_episode_rtg, evaluate_episode_recurrent, evaluate_episode_few_shot
from models.decision_mamba import TrainableDM, TrainableDT
from transformers import Trainer
import mujoco
import gymnasium as gym
def evaluate_episodes(num_eval_episodes, model):
    """Roll out ``num_eval_episodes`` episodes and summarise returns and lengths.

    Every episode is run with ``evaluate_episode_rtg`` conditioned on
    ``TARGET_RETURN``.

    NOTE: besides its two arguments, this function reads the module-level
    globals ``env``, ``state_dim``, ``act_dim``, ``scale``, ``state_mean``,
    ``state_std``, ``device`` and ``TARGET_RETURN``. In this script they are
    only assigned inside the ``if __name__ == '__main__':`` block.

    Args:
        num_eval_episodes: Number of evaluation episodes to roll out.
        model: Policy handed to the episode evaluator. It is switched to
            eval mode and queried under ``torch.no_grad()``.

    Returns:
        dict with keys ``target_<TARGET_RETURN>_return_mean``,
        ``..._return_std``, ``..._length_mean`` and ``..._length_std``,
        holding the mean/std of the per-episode returns and lengths.
    """
    # TrainableDM and TrainableDT are both evaluated with the return-to-go
    # rollout; evaluate_episode_recurrent / evaluate_episode_few_shot are
    # alternative evaluators imported at the top of the file.
    eval_fn = evaluate_episode_rtg

    returns, lengths = [], []
    model.eval()
    with torch.no_grad():
        for _ in tqdm(range(num_eval_episodes)):
            ret, length = eval_fn(
                env=env,
                state_dim=state_dim,
                act_dim=act_dim,
                model=model,
                scale=scale,
                state_mean=state_mean,
                state_std=state_std,
                device=device,
                target_return=TARGET_RETURN,
            )
            returns.append(ret)
            # Record the episode length (not the return) so the length
            # statistics below are actually length statistics.
            lengths.append(length)

    prefix = f'target_{TARGET_RETURN}'
    return {
        f'{prefix}_return_mean': np.mean(returns),
        f'{prefix}_return_std': np.std(returns),
        f'{prefix}_length_mean': np.mean(lengths),
        f'{prefix}_length_std': np.std(lengths),
    }
if __name__ == '__main__':
    # Load a pretrained HalfCheetah policy, derive observation-normalisation
    # statistics from the expert dataset, and record rollout videos.
    # The names assigned here (env, state_dim, act_dim, scale, state_mean,
    # state_std, device, TARGET_RETURN) are module globals that
    # evaluate_episodes reads, so they must stay at module level.
    is_mamba = True  # False -> evaluate the Decision Transformer checkpoint
    if is_mamba:
        model = TrainableDM.from_pretrained('trained_models/dm_halfcheetah-expert-v2')
    else:
        model = TrainableDT.from_pretrained('trained_models/dt_halfcheetah-expert-v2')

    dataset = load_dataset(
        "edbeeching/decision_transformer_gym_replay", "halfcheetah-expert-v2"
    )
    train = dataset['train']
    act_dim = len(train[0]["actions"][0])
    state_dim = len(train[0]["observations"][0])

    # Per-dimension mean/std over every step of every trajectory, used to
    # normalise states; the 1e-6 keeps the std strictly positive.
    states = np.vstack([step for traj in train["observations"] for step in traj])
    state_mean = np.mean(states, axis=0).astype(np.float32)
    state_std = (np.std(states, axis=0) + 1e-6).astype(np.float32)

    # Fall back to CPU so the script also runs on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)

    env = gym.make("HalfCheetah-v4", render_mode='rgb_array')
    env = gym.wrappers.RecordVideo(env, 'video')  # clips go to ./video

    scale = 1000.0  # normalization for rewards/returns
    TARGET_RETURN = 12000 / scale  # evaluation is conditioned on a return of 12000, scaled accordingly

    try:
        print(evaluate_episodes(3, model))
    finally:
        # Closing the RecordVideo wrapper finalizes any recording still in
        # progress, even if evaluation raised.
        env.close()