# RL_task_overview.py
from utils_lib.utils import Utils
import random
import numpy as np
import matplotlib.pyplot as plt
# --- Experiment configuration -------------------------------------------
# Select environment / policy / feature extractor by index from the
# project-level Utils registries.
utils = Utils()
env_j = 16       # index into utils.all_envs
policie_i = 2    # index into utils.all_policies
fe_k=2           # index into utils.all_feature_extractor
fev_l=3          # index into the chosen feature extractor's "order" variants
# Warn (but do NOT abort) if the chosen policy/env pair is incompatible;
# execution continues regardless.
if not(utils.compatible_env_policie(policie_i,env_j)):
    print("not compatible")
policie = utils.all_policies[policie_i]["policie"]          # algorithm class
policie_name = utils.all_policies[policie_i]["name"]
compute_opti = utils.all_policies[policie_i]["compute_opti"]
# NOTE(review): the configured device above is immediately overridden here —
# looks like a debugging leftover; confirm whether "cpu" is intended.
compute_opti = "cpu"
env = utils.all_envs[env_j]["env"]
env_name = utils.all_envs[env_j]["name"]
# Bounds check only prints a diagnostic; an out-of-range fe_k will still
# raise IndexError on the indexing line that follows.
if len(utils.all_feature_extractor) <= fe_k:
    print(len(utils.all_feature_extractor))
    print("bad fe {}".format(fe_k))
feature_extract = utils.all_feature_extractor[fe_k]
# Same pattern: diagnostic print only, IndexError still possible below.
if len(feature_extract["order"]) <= fev_l:
    print("bad fev {}".format(fev_l))
feature_extract_name = feature_extract["name"]
feature_order = feature_extract["order"][fev_l]
feature_obs_shape = feature_extract["obs_shape"]
# Wrap/instantiate the env with the observation shape the extractor expects.
env = utils.get_env(env,env_j,feature_obs_shape)
# Build policy kwargs (feature-extractor config) for the algorithm;
# may be None — see the guarded model construction below.
policy_kwargs = utils.get_fe_kwargs(env,feature_extract,feature_order,compute_opti)
# Instantiate the algorithm only when feature-extractor kwargs were produced.
# NOTE(review): when policy_kwargs is None, `model` is never defined, so the
# later model.learn(...) call raises NameError — consider exiting early instead.
if policy_kwargs is not None:
    model = policie(
        policy="MlpPolicy",
        #learning_rate = old_lr,
        env=env,
        policy_kwargs = policy_kwargs,
        device=compute_opti,  # forced to "cpu" in the configuration section
        verbose=1,
        #seed=random.randint(100,100000),
    )
rewards_tab = []  # mean step-reward of each completed episode
r_epi = []        # step rewards accumulated within the current episode

def register_reward(locals_, _globals):
    """Stable-baselines3 functional step callback.

    SB3 invokes this positionally after every environment step with the
    algorithm's ``locals()`` and ``globals()`` dicts.  Records the reward of
    the first (presumably only — TODO confirm n_envs == 1) vectorized env;
    when that env reports done, appends the episode's mean step-reward to
    ``rewards_tab`` and resets the per-episode buffer.

    Returns:
        True, always — SB3 treats a falsy return from a step callback as a
        request to stop training, so the original implicit ``None`` return
        risked halting learning after the first step.
    """
    # Renamed from `input` (shadowed the builtin); SB3 calls positionally,
    # so the rename is caller-compatible.
    r_epi.append(locals_["rewards"][0])
    if locals_["dones"][0]:
        # Episode finished: store its mean step-reward (== sum/len).
        rewards_tab.append(np.mean(r_epi))
        r_epi.clear()
    return True
# Train for 1M timesteps, logging every iteration; register_reward collects
# per-episode mean rewards as a side effect.
# NOTE(review): `model` only exists when policy_kwargs was not None above —
# otherwise this line raises NameError.
model.learn(
    total_timesteps=1000000,
    log_interval=1,
    callback=register_reward
)
# Plot the per-episode mean rewards gathered by the callback
# (plt.show() blocks until the window is closed).
plt.plot(rewards_tab)
plt.show()
# Greedy (deterministic) rollout for visual inspection of the trained policy.
obs = env.reset()
for i in range(1000000):
    action, _state = model.predict(obs, deterministic=True)
    # Old Gym 4-tuple step API; presumably `env` is an SB3 VecEnv that
    # auto-resets on done — TODO confirm, otherwise reset when `done` is True.
    obs, reward, done, info = env.step(action)
    env.render()