You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Hi @MorvanZhou. Thank you for your tutorial. I'm trying to modify the A3C from Cartpole to MsPacman. I found that after I change the network to a CNN, the code will get stuck in the forward function. It could be run on Mac without problems. But It will get stuck when running on Linux. To illustrate the problem, I simply changed the N_S to 10000 in discrete_A3C.py and use a randomly generated numpy vector as a state. It will also stuck in forward function and has no any warning or error information. Do you have any ideas about that?
"""
Reinforcement Learning (A3C) using Pytroch + multiprocessing.
The most simple implementation for continuous action.
View more on my Chinese tutorial page [莫烦Python](https://morvanzhou.github.io/).
"""
import torch
import numpy as np
import torch.nn as nn
from utils import v_wrap, set_init, push_and_pull, record
import torch.nn.functional as F
import torch.multiprocessing as mp
from shared_adam import SharedAdam
import gym
import os
os.environ["OMP_NUM_THREADS"] = "1"
UPDATE_GLOBAL_ITER = 10
GAMMA = 0.9
MAX_EP = 4000
env = gym.make('CartPole-v0')
N_S = 10000
N_A = env.action_space.n
class Net(nn.Module):
def __init__(self, s_dim, a_dim):
super(Net, self).__init__()
self.s_dim = s_dim
self.a_dim = a_dim
self.pi1 = nn.Linear(s_dim, 100)
self.pi2 = nn.Linear(100, a_dim)
self.v1 = nn.Linear(s_dim, 100)
self.v2 = nn.Linear(100, 1)
set_init([self.pi1, self.pi2, self.v1, self.v2])
self.distribution = torch.distributions.Categorical
def forward(self, x):
pi1 = F.relu(self.pi1(x))
logits = self.pi2(pi1)
v1 = F.relu(self.v1(x))
values = self.v2(v1)
return logits, values
def choose_action(self, s):
self.eval()
logits, _ = self.forward(s)
prob = F.softmax(logits, dim=1).data
m = self.distribution(prob)
return m.sample().numpy()[0]
def loss_func(self, s, a, v_t):
self.train()
logits, values = self.forward(s)
td = v_t - values
c_loss = td.pow(2)
probs = F.softmax(logits, dim=1)
m = self.distribution(probs)
exp_v = m.log_prob(a) * td.detach().squeeze()
a_loss = -exp_v
total_loss = (c_loss + a_loss).mean()
return total_loss
class Worker(mp.Process):
def __init__(self, gnet, opt, global_ep, global_ep_r, res_queue, name):
super(Worker, self).__init__()
self.name = 'w%i' % name
self.g_ep, self.g_ep_r, self.res_queue = global_ep, global_ep_r, res_queue
self.gnet, self.opt = gnet, opt
self.lnet = Net(N_S, N_A) # local network
self.env = gym.make('MsPacman-v0').unwrapped
def run(self):
total_step = 1
while self.g_ep.value < MAX_EP:
s = self.env.reset()
s = np.random.rand(N_S)
buffer_s, buffer_a, buffer_r = [], [], []
ep_r = 0.
while True:
if self.name == 'w0':
self.env.render()
a = self.lnet.choose_action(v_wrap(s[None, :]))
s_, r, done, _ = self.env.step(a)
s_ = np.random.rand(N_S)
if done: r = -1
ep_r += r
buffer_a.append(a)
buffer_s.append(s)
buffer_r.append(r)
if total_step % UPDATE_GLOBAL_ITER == 0 or done: # update global and assign to local net
# sync
push_and_pull(self.opt, self.lnet, self.gnet, done, s_, buffer_s, buffer_a, buffer_r, GAMMA)
buffer_s, buffer_a, buffer_r = [], [], []
if done: # done and print information
record(self.g_ep, ep_r, self.res_queue, self.name, 1, 0)
break
s = s_
total_step += 1
self.res_queue.put(None)
if __name__ == "__main__":
gnet = Net(N_S, N_A) # global network
gnet.share_memory() # share the global parameters in multiprocessing
opt = SharedAdam(gnet.parameters(), lr=0.0001) # global optimizer
global_ep, global_ep_r, res_queue = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue()
# parallel training
workers = [Worker(gnet, opt, global_ep, global_ep_r, res_queue, i) for i in range(mp.cpu_count())]
[w.start() for w in workers]
res = [] # record episode reward to plot
while True:
r = res_queue.get()
if r is not None:
res.append(r)
else:
break
[w.join() for w in workers]
import matplotlib.pyplot as plt
plt.plot(res)
plt.ylabel('Moving average ep reward')
plt.xlabel('Step')
plt.show()
The text was updated successfully, but these errors were encountered:
I encountered a similar problem long time ago. It will be stuck when using a large numpy array with multiprocessing. I haven't found any useful method to deal with this problem.
The problem could be solved by add mp.set_start_method("spawn") to the beginning of the if __name__ == '__main__' scope. The answer is referred from here.
Hi @MorvanZhou. Thank you for your tutorial. I'm trying to modify the A3C from Cartpole to MsPacman. I found that after I change the network to a CNN, the code will get stuck in the forward function. It could be run on Mac without problems. But It will get stuck when running on Linux. To illustrate the problem, I simply changed the N_S to 10000 in discrete_A3C.py and use a randomly generated numpy vector as a state. It will also stuck in forward function and has no any warning or error information. Do you have any ideas about that?
The text was updated successfully, but these errors were encountered: