-
Notifications
You must be signed in to change notification settings - Fork 0
/
core.py
84 lines (71 loc) · 2.3 KB
/
core.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import numpy as np
import torch
import scipy.signal as signal
import time
import glfw, gym, mujoco_py
from mpi_tools import *
# -------------- logger ----------------
class Logger:
    """Accumulates key/value statistics and free-form lines, then prints
    the whole batch from the rank-0 MPI process on dump()."""

    def __init__(self, CONST=80):
        # CONST: width of the horizontal rule printed around each dump.
        self.store_dict = dict()
        self.string = ""
        self.CONST = CONST

    def add_string(self, s):
        """Append *s* as its own blank-line-framed entry to the pending output."""
        self.string += "\n%s\n" % s

    def store(self, **kwargs):
        """Stash keyword values under their keys for later log() calls."""
        for key, value in kwargs.items():
            self.store_dict[key] = value

    def log(self, key, val=None, mean=False, with_min_max=False):
        """Format one value (given via *val* or previously store()d under *key*)
        into the pending output, optionally as MPI-aggregated statistics."""
        if val is not None:
            self.store_dict[key] = val
        v = self.store_dict[key]
        if with_min_max:
            avg, std, lo, hi = mpi_statistics_scalar([v], with_min_and_max=True)
            self.add_string("|%s -> min:%.2e | max:%.2e | mean:%.3e | std:%.3e" \
                % (key, lo, hi, avg, std) + "|")
        elif mean:
            avg, std = mpi_statistics_scalar([v], with_min_and_max=False)
            self.add_string("|%s -> %.1f" % (key, avg) + "|")
        else:
            self.add_string("|%s -> %s" % (key, str(v)) + "|")

    def dump(self):
        """Print everything accumulated so far (rank 0 only), then reset."""
        if proc_id() == 0:
            print("-" * self.CONST)
            print(self.string)
            print("-" * self.CONST)
        self.string = ""
        self.store_dict = dict()
# -------------- rl functions -----------------
def discount(x, disc):
    """Return discounted cumulative sums of *x*: out[t] = sum_k disc**k * x[t+k].

    Implemented as an IIR filter (y[n] = x[n] + disc * y[n-1]) run over the
    reversed sequence, then reversed back — a single C-speed pass.
    """
    reversed_sums = signal.lfilter([1], [1, -float(disc)], x=x[::-1])
    return reversed_sums[::-1]
def gae(path, gamma, lam, avg_grad):
    """Compute returns and GAE(lambda) advantages for one trajectory, in place.

    Writes into *path*:
      - "ret": discounted returns of path["rew"] with discount gamma.
      - "adv": normalized GAE advantages from the TD residuals
        rew[t] + gamma * value[t+1] - value[t] (length len(rew) - 1).

    If avg_grad, mean/std for normalization are aggregated across MPI
    processes via mpi_statistics_scalar; otherwise computed locally.
    NOTE(review): assumes path["rew"] and path["value"] are 1-D numpy
    arrays of equal length — confirm against the rollout code.
    """
    path["ret"] = discount(path["rew"], gamma)
    bln = path["value"]
    td = path["rew"][:-1] + gamma * bln[1:] - bln[:-1]
    adv = path["adv"] = discount(td, gamma * lam)
    # TODO: check/understand effect of normalization (empirically/theoretically)
    if avg_grad:
        mean, std = mpi_statistics_scalar(adv)
    else:
        mean, std = adv.mean(), adv.std()
    # Fix: constant advantages (or paths of length <= 2) give std == 0; the
    # original divided by it, silently producing NaN advantages.  Fall back
    # to std = 1 so the mean-centering still happens and training survives.
    if std == 0:
        std = 1.0
    path["adv"] = (path["adv"] - mean) / std
def validate(ac, env_name, timeout=500, render=True, seed=42):
    """Roll out policy *ac* in a fresh gym env for up to *timeout* steps.

    ac must expose step(obs) -> (action, value, logp) with a torch action.
    Returns the per-step rewards as a numpy array.
    NOTE(review): mixes reset(seed=...) with a 4-tuple step() return —
    this matches gym 0.21–0.25; confirm the pinned gym version.
    """
    env = gym.make(env_name)
    obs = env.reset(seed=seed)
    rewards = []
    done = False
    for _ in range(timeout):
        action, value, logp = ac.step(obs)
        obs, reward, done, info = env.step(action.detach().cpu().numpy())
        rewards.append(reward)
        if render:
            env.render()
        if done:
            break
    env.reset()
    env.close()
    return np.array(rewards)
# --------- torch ----------
def change_lr(optim, lr):
    """Set the learning rate of every parameter group in *optim* to *lr*."""
    for group in optim.param_groups:
        group['lr'] = lr