# test_eval_IQL.py (from a fork of tinkoff-ai/CORL)
import jax
import optax
import flax
from algorithms.offline.rebrac_Fetch_UR5 import DetActor, ActorTrainState, ReplayBuffer, Config
import gym
import gym_UR5_FetchReach
import torch
from pathlib import Path
from algorithms.offline.rebrac_Fetch_UR5 import wrap_env, evaluate
import yaml
import numpy as np
import argparse
from algorithms.offline.iql_Fetch_UR5 import ImplicitQLearning, eval_actor, compute_mean_std

# Load a serialized flax TrainState from a file.
def load_train_state(save_path, state_structure):
    with open(save_path, 'rb') as f:
        state_dict = flax.serialization.from_bytes(state_structure, f.read())
    return state_dict


# Create a fresh ActorTrainState to serve as the target structure when deserializing.
def create_train_state(actor_module, actor_key, init_state, actor_learning_rate):
    return ActorTrainState.create(
        apply_fn=actor_module.apply,
        params=actor_module.init(actor_key, init_state),
        target_params=actor_module.init(actor_key, init_state),
        tx=optax.adam(learning_rate=actor_learning_rate),
    )
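
# The two helpers above are not called in main() below; they exist for restoring a
# JAX/flax actor saved with flax.serialization.to_bytes (the REBRAC side of this repo).
# A minimal sketch of how they fit together, kept as a comment because the module
# arguments, PRNG key, and checkpoint path are illustrative placeholders rather than
# values taken from this repository:
#
#   rng = jax.random.PRNGKey(0)
#   actor_module = DetActor(...)               # constructor args as defined in rebrac_Fetch_UR5
#   init_state = np.zeros((1, obs_dim))        # dummy observation batch for initialization
#   template = create_train_state(actor_module, rng, init_state, actor_learning_rate=1e-3)
#   actor_state = load_train_state('actor_checkpoint.bin', template)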

def main(env_name, num_episodes, config_path, model_path, seed):
    # Load the training configuration first so the trainer can be rebuilt to match it.
    with open(config_path) as yaml_file:
        config_dict = yaml.safe_load(yaml_file)
    config = Config.from_dict(Config, config_dict)

    # NOTE: `kwargs` is not defined in this script. It has to be assembled from the
    # loaded config to match the ImplicitQLearning constructor used during training
    # (actor/critic networks, their optimizers, and the IQL hyperparameters).
    trainer = ImplicitQLearning(**kwargs)
    policy_file = Path(model_path)
    trainer.load_state_dict(torch.load(policy_file))
    actor = trainer.actor

    # The offline dataset is only needed for its observation statistics (mean/std),
    # which wrap_env uses to normalize observations during evaluation.
    dataset_name = '/home/nikisim/Mag_diplom/CORL/data/UR5_FetchReach_new_action2.npy'
    replay_buffer = ReplayBuffer()
    replay_buffer.create_from_d4rl(
        dataset_name, False, False
    )

    env = gym.make('gym_UR5_FetchReach/UR5_FetchReachEnv-v0', render=True)
    env.action_space.seed(seed)
    env.observation_space.seed(seed)
    env = wrap_env(env, replay_buffer.mean, replay_buffer.std)

    eval_scores, eval_success = eval_actor(
        env,
        actor,
        device=config.device,
        n_episodes=num_episodes,
        seed=config.seed,
    )
    print(
        f"Evaluation over {num_episodes} episodes: "
        f"mean score {np.mean(eval_scores):.3f}, success rate {np.mean(eval_success):.3f}"
    )
    env.close()

if __name__ == "__main__":
    # Set up the argument parser
    parser = argparse.ArgumentParser(description="Evaluate a CORL pre-trained model.")
    parser.add_argument("--env_name", type=str, default='FetchReach', help="Name of the environment to run.")
    parser.add_argument("--config_path", type=str, default='/home/nikisim/Mag_diplom/CORL/data/saved_models/IQL_UR5_FetchReach_new_action/IQL-FetchReach_UR5-221b042a/config.yaml', help="Path to the configuration YAML file.")
    parser.add_argument("--model_path", type=str, default='/home/nikisim/Mag_diplom/CORL/data/saved_models/IQL_UR5_FetchReach_new_action/IQL-FetchReach_UR5-221b042a/checkpoint_429999.pt', help="Path to the saved model.")
    parser.add_argument("--num_episodes", type=int, default=5, help="Number of episodes to run.")
    parser.add_argument("--seed", type=int, default=1, help="Random seed for reproducibility.")

    # Parse the command-line arguments
    args = parser.parse_args()

    # Call the main function with the parsed arguments
    main(args.env_name, args.num_episodes, args.config_path,
         args.model_path, args.seed)
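
# Example invocation (hypothetical paths; the defaults above point to a local setup,
# so adjust --config_path and --model_path to wherever your config.yaml and
# checkpoint .pt actually live):
#
#   python test_eval_IQL.py --env_name FetchReach --num_episodes 5 --seed 1 \
#       --config_path /path/to/config.yaml \
#       --model_path /path/to/checkpoint_429999.pt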