-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
119 lines (99 loc) · 3.97 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import math
import utils
import model
# Number of transitions requested from the replay memory on each
# Trainer.optimize() call.
BATCH_SIZE = 128
# Step size handed to the Adam optimizers of both the actor and the critic.
LEARNING_RATE = 0.001
# Discount factor applied to the target critic's next-state value when the
# critic's regression target is built.
GAMMA = 0.99
# Coefficient passed to utils.soft_update when the target networks are
# refreshed at the end of each optimize() call.
TAU = 0.001
class Trainer:
    """
    Actor-critic trainer with target networks (DDPG-style).

    Keeps an online actor/critic pair, their target counterparts, one Adam
    optimizer per online network, an Ornstein-Uhlenbeck noise source for
    exploration and a replay memory. Each call to ``optimize`` performs one
    critic update, one actor update and one soft update of both targets.
    """

    def __init__(self, agent, ram):
        """
        :param agent: mapping providing the entries "actor", "target_actor",
            "critic", "target_critic", "action_dim" and "action_lim"
        :param ram: replay memory; ``ram.sample(n)`` must return the four
            Numpy arrays ``(s1, a1, r1, s2)``
        """
        self.ram = ram
        # Iteration counter; nothing in this class advances it.
        self.iter = 0
        self.agent = agent
        self.noise = utils.OrnsteinUhlenbeckActionNoise(self.agent["action_dim"])

        self.actor = self.agent["actor"]
        self.target_actor = self.agent["target_actor"]
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), LEARNING_RATE)

        self.critic = self.agent['critic']
        self.target_critic = self.agent['target_critic']
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), LEARNING_RATE)

    def get_exploitation_action(self, state):
        """
        gets the action from the target actor, without any exploration noise
        :param state: state (Numpy array)
        :return: action (Numpy array)
        """
        state = Variable(torch.from_numpy(state))
        # Call the module itself rather than .forward() so that registered
        # hooks are honoured.
        action = self.target_actor(state).detach()
        return action.data.numpy()

    def get_exploration_action(self, state):
        """
        gets the action from the actor and adds exploration noise scaled by
        the agent's "action_lim" entry
        :param state: state (Numpy array)
        :return: sampled action (Numpy array)
        """
        state = Variable(torch.from_numpy(state))
        action = self.actor(state).detach()
        new_action = action.data.numpy() + (self.noise.sample() * self.agent["action_lim"])
        return new_action

    def optimize(self):
        """
        Samples a random batch from replay memory and performs one critic
        update, one actor update and a soft update of both target networks
        :return:
        """
        s1, a1, r1, s2 = self.ram.sample(BATCH_SIZE)

        s1 = Variable(torch.from_numpy(s1))
        a1 = Variable(torch.from_numpy(a1))
        r1 = Variable(torch.from_numpy(r1))
        s2 = Variable(torch.from_numpy(s2))

        # ---------------------- optimize critic ----------------------
        # Use target actor exploitation policy here for loss evaluation
        a2 = self.target_actor(s2).detach()
        # Rewards, target values and predictions are all flattened to 1-D
        # before they are combined. Otherwise a (B, 1) reward column added to
        # a (B,) value vector silently broadcasts into a (B, B) matrix and the
        # critic is fitted to the wrong target.
        next_val = self.target_critic(s2, a2).detach().contiguous().view(-1)
        # y_exp = r + gamma*Q'( s2, pi'(s2))
        # NOTE(review): the sampled batch carries no terminal flag, so the
        # bootstrap term is applied to every transition.
        y_expected = r1.contiguous().view(-1) + GAMMA * next_val
        # y_pred = Q( s1, a1)
        y_predicted = self.critic(s1, a1).contiguous().view(-1)
        # compute critic loss, and update the critic
        loss_critic = F.smooth_l1_loss(y_predicted, y_expected)
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        self.critic_optimizer.step()

        # ---------------------- optimize actor ----------------------
        pred_a1 = self.actor(s1)
        # Maximise Q(s1, pi(s1)). The value is summed (not averaged) over the
        # batch, so the gradient magnitude scales with the batch size.
        loss_actor = -1 * torch.sum(self.critic(s1, pred_a1))
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        self.actor_optimizer.step()

        utils.soft_update(self.target_actor, self.actor, TAU)
        utils.soft_update(self.target_critic, self.critic, TAU)

    def save_models(self, episode_count, folder):
        """
        saves the state dicts of the target actor and target critic as
        "<folder>/<episode_count>_actor.pt" and "<folder>/<episode_count>_critic.pt"
        :param episode_count: the count of episodes iterated
        :param folder: directory the files are written to
        :return:
        """
        prefix = folder + '/' + str(episode_count)
        torch.save(self.target_actor.state_dict(), prefix + '_actor.pt')
        torch.save(self.target_critic.state_dict(), prefix + '_critic.pt')
        print('Models saved successfully')

    def load_models(self, episode, folder):
        """
        loads the saved actor and critic state dicts into the online actor
        and critic, then passes each target/online pair to utils.hard_update
        :param episode: the count of episodes iterated (used to find the file name)
        :param folder: directory the files are read from
        :return:
        """
        prefix = folder + '/' + str(episode)
        self.actor.load_state_dict(torch.load(prefix + '_actor.pt'))
        self.critic.load_state_dict(torch.load(prefix + '_critic.pt'))
        utils.hard_update(self.target_actor, self.actor)
        utils.hard_update(self.target_critic, self.critic)
        print('Models loaded succesfully')