-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path06_dqn_dueling.py
executable file
·97 lines (77 loc) · 3.18 KB
/
06_dqn_dueling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#!/usr/bin/env python3
import gym
import ptan
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter
from lib import common
class DuelingDQN(nn.Module):
    """Convolutional Q-network with separate value and advantage heads.

    A shared three-layer convolutional stack feeds two fully connected
    heads: ``fc_val`` produces one scalar per sample and ``fc_adv`` produces
    one number per action.  ``forward`` combines them as
    ``val + (adv - mean(adv))``, so the mean of the output over the action
    dimension equals the value head's output.

    Args:
        input_shape: Shape of a single observation without the batch
            dimension; ``input_shape[0]`` is the number of input channels.
        n_actions: Width of the advantage head and of the network output.
    """

    def __init__(self, input_shape, n_actions):
        super().__init__()

        # NOTE: layers are created in the same order as before (conv,
        # advantage head, value head) so seeded initialisation is unchanged.
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )

        # Both heads consume the flattened convolutional features.
        conv_out_size = self._get_conv_out(input_shape)
        self.fc_adv = nn.Sequential(
            nn.Linear(conv_out_size, 256),
            nn.ReLU(),
            nn.Linear(256, n_actions),
        )
        self.fc_val = nn.Sequential(
            nn.Linear(conv_out_size, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )

    def _get_conv_out(self, shape) -> int:
        """Return the number of features ``self.conv`` yields for one sample.

        The size is measured by pushing a single all-zero observation of the
        given ``shape`` through the convolutional stack.
        """
        # Only the output size is needed, so skip autograd bookkeeping.
        with torch.no_grad():
            o = self.conv(torch.zeros(1, *shape))
        return o.numel()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute one Q-value per action for every observation in ``x``.

        Args:
            x: Batch of observations; the first dimension is treated as the
                batch.  The input is cast to float and divided by 256
                (presumably byte-valued image frames -- confirm with the
                environment wrappers).

        Returns:
            Tensor of shape ``(batch, n_actions)``.
        """
        fx = x.float() / 256
        conv_out = self.conv(fx).view(fx.size(0), -1)
        val = self.fc_val(conv_out)
        adv = self.fc_adv(conv_out)
        # Centre the advantages per sample before adding the scalar value.
        return val + (adv - adv.mean(dim=1, keepdim=True))
if __name__ == "__main__":
    # Training entry point: builds the environment, the dueling network and
    # its target copy, then runs the sample/optimise loop until the reward
    # tracker signals that training should stop.
    hparams = common.HYPERPARAMS['pong']

    # Command line: a single optional flag choosing the compute device.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--cuda", default=False, action="store_true", help="Enable cuda")
    cli = arg_parser.parse_args()
    device = torch.device("cuda" if cli.cuda else "cpu")

    # Environment, wrapped by ptan's DQN wrapper stack.
    env = ptan.common.wrappers.wrap_dqn(gym.make(hparams['env_name']))

    writer = SummaryWriter(comment="-" + hparams['run_name'] + "-dueling")

    # Online network plus the target network wrapper used by the loss.
    net = DuelingDQN(env.observation_space.shape, env.action_space.n).to(device)
    tgt_net = ptan.agent.TargetNet(net)

    # Epsilon-greedy agent; the tracker is handed the selector and is called
    # with the frame counter on every iteration below.
    action_selector = ptan.actions.EpsilonGreedyActionSelector(epsilon=hparams['epsilon_start'])
    eps_tracker = common.EpsilonTracker(action_selector, hparams)
    agent = ptan.agent.DQNAgent(net, action_selector, device=device)

    # Experience source feeding the replay buffer.
    discount = hparams['gamma']
    exp_source = ptan.experience.ExperienceSourceFirstLast(env, agent, gamma=discount, steps_count=1)
    replay = ptan.experience.ExperienceReplayBuffer(exp_source, buffer_size=hparams['replay_size'])
    optimizer = optim.Adam(net.parameters(), lr=hparams['learning_rate'])

    step = 0
    with common.RewardTracker(writer, hparams['stop_reward']) as reward_tracker:
        while True:
            step += 1
            replay.populate(1)
            eps_tracker.frame(step)

            # Report the first newly available episode reward; leave the
            # loop when the tracker's reward() returns a truthy value.
            finished = exp_source.pop_total_rewards()
            if finished and reward_tracker.reward(finished[0], step, action_selector.epsilon):
                break

            # No optimisation (and no target sync) until the buffer holds
            # at least 'replay_initial' entries.
            if len(replay) >= hparams['replay_initial']:
                optimizer.zero_grad()
                samples = replay.sample(hparams['batch_size'])
                loss = common.calc_loss_dqn(samples, net, tgt_net.target_model, gamma=hparams['gamma'], device=device)
                loss.backward()
                optimizer.step()

                # Periodically sync the target network.
                if step % hparams['target_net_sync'] == 0:
                    tgt_net.sync()