-
Notifications
You must be signed in to change notification settings - Fork 64
/
utils.py
75 lines (61 loc) · 1.91 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import numpy as np
import torch
import shutil
import torch.autograd as Variable
def soft_update(target, source, tau):
    """
    Move each target-network parameter a fraction ``tau`` toward the source network.

    Every target parameter y is overwritten in place with
        y = tau * x + (1 - tau) * y
    where x is the source parameter paired with it. Parameters are paired
    positionally via ``zip``, so surplus parameters on the longer module are ignored.

    :param target: Target network (PyTorch module); its parameters are modified in place
    :param source: Source network (PyTorch module); only read
    :param tau: Interpolation weight given to the source parameters
    :return: None
    """
    keep = 1.0 - tau
    for dst, src in zip(target.parameters(), source.parameters()):
        blended = dst.data * keep + src.data * tau
        dst.data.copy_(blended)
def hard_update(target, source):
    """
    Overwrite every target-network parameter with the matching source parameter.

    Parameters are paired positionally via ``zip``; if the two modules expose a
    different number of parameters, the surplus ones are ignored. Values are
    copied into the existing target tensors, so no storage is shared afterwards.

    :param target: Target network (PyTorch module); modified in place
    :param source: Source network (PyTorch module); only read
    :return: None
    """
    pairs = zip(target.parameters(), source.parameters())
    for dst, src in pairs:
        dst.data.copy_(src.data)
def save_training_checkpoint(state, is_best, episode_count, filename=None,
                             best_filename='model_best.pth.tar'):
    """
    Save a training checkpoint with ``torch.save``; optionally copy it as the best model.

    :param state: Object to serialize with ``torch.save`` (presumably a dict of
        model/optimizer state -- confirm against callers)
    :param is_best: If true, the written checkpoint is also copied to ``best_filename``
    :param episode_count: Used to build the default checkpoint file name
    :param filename: Optional explicit checkpoint path. When None, the legacy
        relative name ``str(episode_count) + 'checkpoint.path.rar'`` is used,
        i.e. the file lands in the current working directory.
    :param best_filename: Destination of the copy made when ``is_best`` is true
    :return: Path of the checkpoint file that was written
    """
    if filename is None:
        # NOTE(review): '.path.rar' looks like a typo for '.pth.tar' (compare
        # best_filename), but the legacy name is kept so existing loaders still
        # find the file.
        filename = str(episode_count) + 'checkpoint.path.rar'
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_filename)
    return filename
# Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
class OrnsteinUhlenbeckActionNoise:
    """
    Temporally correlated noise from a discretized Ornstein-Uhlenbeck process.

    Each call to ``sample`` advances the state by one Euler-Maruyama step:

        X <- X + (theta * (mu - X) * dt + sigma * sqrt(dt) * N(0, I))

    With the default ``dt=1.0`` this reduces exactly to
    ``X <- X + (theta * (mu - X) + sigma * N(0, I))``.

    Noise is drawn from NumPy's global RNG (``np.random``), so sequences are
    reproducible via ``np.random.seed``.
    """

    def __init__(self, action_dim, mu=0, theta=0.15, sigma=0.2, dt=1.0):
        """
        :param action_dim: Length of the action vector (int), or a shape tuple for
            multi-dimensional actions
        :param mu: Value the process reverts toward (scalar, or array broadcastable
            to the action shape); also the initial state
        :param theta: Mean-reversion rate
        :param sigma: Noise scale
        :param dt: Discretization time step; 1.0 reproduces the original update
        """
        self.action_dim = action_dim
        self.mu = mu
        self.theta = theta
        self.sigma = sigma
        self.dt = dt
        self.X = np.ones(self.action_dim) * self.mu

    def reset(self):
        """Reset the process state to ``mu``."""
        self.X = np.ones(self.action_dim) * self.mu

    def sample(self):
        """
        Advance the process by one step and return the new state.

        The state is rebound to a new array (not mutated in place), so samples
        returned by earlier calls are never modified.

        :return: np.ndarray with the same shape as the state
        """
        drift = self.theta * (self.mu - self.X) * self.dt
        # Draw standard normals matching the state's shape; for 1-D states this
        # consumes the RNG exactly like np.random.randn(len(self.X)).
        noise = np.random.randn(*self.X.shape)
        diffusion = self.sigma * np.sqrt(self.dt) * noise
        # Sum the increment first, then add it to X, matching the original
        # evaluation order bit-for-bit when dt == 1.0.
        self.X = self.X + (drift + diffusion)
        return self.X
# use this to plot Ornstein Uhlenbeck random motion
if __name__ == '__main__':
    # Draw 1000 consecutive samples from a 1-D OU process, then plot the trajectory.
    noise = OrnsteinUhlenbeckActionNoise(1)
    states = [noise.sample() for _ in range(1000)]
    # Imported lazily so matplotlib is only required for this demo.
    import matplotlib.pyplot as plt
    plt.plot(states)
    plt.show()