mpi_tools.py
from mpi4py import MPI
import os, subprocess, sys
import numpy as np
import torch
import copy
# -------------- mpi functions ----------------
# from spinning-up: https://github.com/openai/spinningup/blob/038665d62d569055401d91856abb287263096178/spinup/utils/mpi_tools.py#L56
def mpi_fork(n, bind_to_core=False):
    """
    Re-launches the current script with workers linked by MPI.

    Also, terminates the original process that launched it.

    Taken almost without modification from the Baselines function of the
    `same name`_.

    .. _`same name`: https://github.com/openai/baselines/blob/master/baselines/common/mpi_fork.py

    Args:
        n (int): Number of processes to split into.

        bind_to_core (bool): Bind each MPI process to a core.
    """
    if n <= 1:
        return
    if os.getenv("IN_MPI") is None:
        env = os.environ.copy()
        env.update(
            MKL_NUM_THREADS="1",
            OMP_NUM_THREADS="1",
            IN_MPI="1"
        )
        args = ["mpirun", "-np", str(n)]
        if bind_to_core:
            args += ["-bind-to", "core"]
        args += [sys.executable] + sys.argv
        subprocess.check_call(args, env=env)
        sys.exit()
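
# Example usage (illustrative sketch, not part of the original module): a
# training script typically calls mpi_fork() before any other MPI work, so the
# script re-launches itself under mpirun. `train()` is a hypothetical placeholder.
#
#   if __name__ == '__main__':
#       mpi_fork(4)                 # re-launch this script across 4 MPI workers
#       setup_pytorch_for_mpi()     # give each worker a fair share of Torch threads
#       print('hello from rank %d of %d' % (proc_id(), num_procs()))
#       train()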
def msg(m, string=''):
    """Print a message tagged with the rank of the calling process."""
    print(('Message from %d: %s \t ' % (MPI.COMM_WORLD.Get_rank(), string)) + str(m))


def proc_id():
    """Get rank of calling process."""
    return MPI.COMM_WORLD.Get_rank()


def allreduce(*args, **kwargs):
    return MPI.COMM_WORLD.Allreduce(*args, **kwargs)


def num_procs():
    """Count active MPI processes."""
    return MPI.COMM_WORLD.Get_size()


def broadcast(x, root=0):
    MPI.COMM_WORLD.Bcast(x, root=root)
def mpi_op(x, op):
    """Apply an MPI reduction op (e.g. MPI.SUM) to a scalar or vector across processes."""
    x, scalar = ([x], True) if np.isscalar(x) else (x, False)
    x = np.asarray(x, dtype=np.float32)
    buff = np.zeros_like(x, dtype=np.float32)
    allreduce(x, buff, op=op)
    return buff[0] if scalar else buff


def mpi_sum(x):
    """Sum a scalar or vector over MPI processes."""
    return mpi_op(x, MPI.SUM)


def mpi_avg(x):
    """Average a scalar or vector over MPI processes."""
    return mpi_sum(x) / num_procs()
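
# Example (illustrative sketch, not part of the original module): averaging a
# per-process scalar, e.g. a local loss value, across all MPI workers.
# `compute_loss()` is a hypothetical placeholder.
#
#   local_loss = compute_loss()          # per-rank scalar
#   global_loss = mpi_avg(local_loss)    # same averaged value on every rank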
def mpi_statistics_scalar(x, with_min_and_max=False):
    """
    Get mean/std and optional min/max of scalar x across MPI processes.

    Args:
        x: An array containing samples of the scalar to produce statistics
            for.

        with_min_and_max (bool): If true, return min and max of x in
            addition to mean and std.
    """
    x = np.array(x, dtype=np.float32)
    global_sum, global_n = mpi_sum([np.sum(x), len(x)])
    mean = global_sum / global_n

    global_sum_sq = mpi_sum(np.sum((x - mean)**2))
    std = np.sqrt(global_sum_sq / global_n)  # compute global std

    if with_min_and_max:
        global_min = mpi_op(np.min(x) if len(x) > 0 else np.inf, op=MPI.MIN)
        global_max = mpi_op(np.max(x) if len(x) > 0 else -np.inf, op=MPI.MAX)
        return mean, std, global_min, global_max
    return mean, std
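
# Example (illustrative sketch, not part of the original module): pooling
# episode returns collected independently on each rank into global statistics.
#
#   ep_returns = [...]                                   # this rank's episode returns
#   mean_ret, std_ret, min_ret, max_ret = \
#       mpi_statistics_scalar(ep_returns, with_min_and_max=True)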
def setup_pytorch_for_mpi():
    """
    Avoid slowdowns caused by each separate process's PyTorch using
    more than its fair share of CPU resources.
    """
    # print('Proc %d: Reporting original number of Torch threads as %d.' % (proc_id(), torch.get_num_threads()), flush=True)
    if torch.get_num_threads() == 1:
        return
    fair_num_threads = max(int(torch.get_num_threads() / num_procs()), 1)
    torch.set_num_threads(fair_num_threads)
    # print('Proc %d: Reporting new number of Torch threads as %d.' % (proc_id(), torch.get_num_threads()), flush=True)
def mpi_avg_grads(module):
    """Average contents of gradient buffers across MPI processes."""
    if num_procs() == 1:
        return
    for p in module.parameters():
        p_grad_numpy = p.grad.numpy()   # numpy view of tensor data
        avg_p_grad = mpi_avg(p.grad)
        p_grad_numpy[:] = avg_p_grad[:]
def mpi_avg_params(module):
    """Average contents of parameter buffers across MPI processes."""
    if num_procs() == 1:
        return
    for p in module.parameters():
        p_numpy = p.data.numpy()        # numpy view of the parameter tensor
        avg_p = mpi_avg(p.data)
        # copy the averaged values element-wise into the existing buffer,
        # so the parameter keeps its original storage (no rebinding)
        p_numpy[:] = avg_p[:]
def sync_params(module):
    """Sync all parameters of module across all MPI processes (broadcast from rank 0)."""
    if num_procs() == 1:
        return
    for p in module.parameters():
        p_numpy = p.data.numpy()
        broadcast(p_numpy)
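
# Example (illustrative sketch, not part of the original module): a typical
# data-parallel training step built from these helpers. `model`, `optimizer`,
# `data`, and `compute_loss` are hypothetical placeholders.
#
#   sync_params(model)              # once at startup: all ranks share rank 0's weights
#   for batch in data:              # each rank iterates over its own shard of data
#       optimizer.zero_grad()
#       loss = compute_loss(model, batch)
#       loss.backward()
#       mpi_avg_grads(model)        # average gradients across ranks before stepping
#       optimizer.step()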
# ---------- for EASGD (elastic averaging SGD) ----------------
def mpi_avg_params_ac(module):
    """Return a deep copy of `module` whose parameters are the MPI average of
    `module`'s parameters across all processes; `module` itself is left unchanged."""
    if num_procs() == 1:
        return module
    module_cp = copy.deepcopy(module)
    for (cp, p) in zip(module_cp.parameters(), module.parameters()):
        cp_numpy = cp.data.numpy()
        avg_p = mpi_avg(p.data)
        cp_numpy[:] = avg_p[:]
    return module_cp
def mpi_inplace_add(amod, bmod, consts=(1, 1)):
    """In place, set each parameter of `amod` to
    consts[0] * amod_param + consts[1] * bmod_param (element-wise)."""
    for (ap, bp) in zip(amod.parameters(), bmod.parameters()):
        a_numpy = ap.data.numpy()
        b_numpy = bp.data.numpy()
        a_numpy[:] = a_numpy * consts[0] + b_numpy * consts[1]
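
# Example (illustrative sketch, not part of the original module): an
# elastic-averaging style update, where each worker is pulled toward the MPI
# average of all workers' parameters. `model` and `alpha` are hypothetical
# placeholders; this is one plausible way to combine the two helpers above.
#
#   center = mpi_avg_params_ac(model)                            # averaged copy; model untouched
#   mpi_inplace_add(model, center, consts=(1 - alpha, alpha))    # pull model toward the average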