-
Notifications
You must be signed in to change notification settings - Fork 23
/
workload_applyer.py
executable file
·397 lines (372 loc) · 15.5 KB
/
workload_applyer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
"""
Copyright (c) 2021, Alibaba Group;
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
import sys
import math
import time
from utils.utils import WorkloadWriter, CommGroup, CommType, ReduceOp
from utils.benchmark_logger import bench_logger
import utils.utils as utils
class WorkloadApplyer:
    """Replay a recorded communication/computation workload on the current rank.

    Each workload item is dispatched by its ``comm_type`` (see
    ``self.comm_type_function``) to a handler that issues the matching
    ``torch.distributed`` operation on a preallocated bfloat16 buffer, or runs
    a computation stand-in. The default process group must already be
    initialized: the constructor queries ``torch.distributed`` for the world
    size and rank and creates the tp/dp/pp/ep sub-groups.
    """

    def __init__(self, workload=None, args=None, filename=None) -> None:
        """Load the workload, bind this rank to a GPU and build process groups.

        Args:
            workload: Object whose ``.workload`` attribute is the list of items
                to replay. Used together with ``args``.
            args: Configuration namespace recorded alongside the workload.
            filename: Path handed to ``WorkloadWriter.load_workload`` when
                ``workload`` or ``args`` is not supplied.

        Raises:
            ValueError: if ``workload``/``args`` are missing and no
                ``filename`` is given to load them from.
        """
        if workload is None or args is None:
            # Loading from file is the fallback, so a filename is mandatory here.
            if filename is None:
                raise ValueError(
                    "you should either pass workload,args or filename to init WorkloadApplyer"
                )
            workload, args = WorkloadWriter.load_workload(filename)
        self.args = args

        world_size = torch.distributed.get_world_size()
        if args.world_size != world_size:
            print(
                f"WARNING: world_size is {args.world_size} when generating workload, but now world size is {world_size}"
            )
        args.world_size = world_size

        # Pick a local GPU from args.rank.
        # NOTE(review): args.rank comes from the (possibly loaded) args, not
        # torch.distributed.get_rank(); assumes they agree — confirm.
        device_count = torch.cuda.device_count()
        self.device = args.rank % device_count
        torch.cuda.set_device(self.device)
        self.device = torch.cuda.current_device()

        # new_group() is collective: all ranks must create groups in the same order.
        self.comm_group_info, self.pp_global_rank_info = (
            self._generate_dp_tp_pp_ep_groups()
        )
        self.workload = workload
        self.comm_type_function = {
            CommType.barrier: self._apply_barrier,
            CommType.broadcast: self._apply_broadcast,
            CommType.reduce: self._apply_reduce,
            CommType.all_reduce: self._apply_all_reduce,
            CommType.all_gather: self._apply_all_gather,
            CommType.reduce_scatter: self._apply_reduce_scatter,
            CommType.isend: self._apply_p2pcommunication,
            CommType.irecv: self._apply_p2pcommunication,
            CommType.all_gather_into_tensor: self._apply_all_gather,
            CommType.reduce_scatter_tensor: self._apply_reduce_scatter,
            CommType.computation: self._apply_computation,
            CommType.all_to_all: self._apply_all_to_all,
            CommType.epoch_end: bench_logger.end_epoch,
        }

        def _msg_numel(msg_size):
            # Integer sizes are used as-is; computation items carry a pair of
            # matmul operand shapes, sized by the sum of their element counts.
            if isinstance(msg_size, int):
                return msg_size
            return math.prod(msg_size[0]) + math.prod(msg_size[1])

        # default=0 keeps an empty workload from crashing max().
        max_msg_size = max(
            (_msg_numel(item.msg_size) for item in self.workload.workload),
            default=0,
        )

        self.gemm_cache = {}
        self.computation_aiob = bool(args.aiob_enable and args.frame == "Megatron")
        self.skip_computation = False
        self.always_apply_gemm = False
        self.gemm_iters = 1 if self.always_apply_gemm else 50
        # Handlers use msg_size // 2 elements (presumably msg_size is a byte
        # count of bfloat16 data), so max_msg_size elements is always enough
        # for unpadded messages.
        self.buffer = torch.empty(
            (max_msg_size,), dtype=torch.bfloat16, device=self.device
        )

    def _generate_dp_tp_pp_ep_groups(self):
        """Create the parallel process groups (borrowed from Megatron-LM).

        Also records the global rank in ``self.rank``.

        Returns:
            A tuple ``(groups, pp_global_ranks)``: ``groups`` maps each
            ``CommGroup`` member to the process group containing this rank, and
            ``pp_global_ranks`` lists the global ranks of this rank's pipeline
            group in pipeline order.
        """
        rank = torch.distributed.get_rank()
        self.rank = rank
        rank_generator = utils.RankGenerator(
            tp=self.args.tensor_model_parallel_size,
            ep=self.args.expert_model_parallel_size,
            dp=self.args.dp_num,
            pp=self.args.pipeline_model_parallel,
            cp=self.args.context_parallel_size,
            order='tp-cp-ep-dp-pp',
        )

        def _own_group(all_ranks):
            # Every rank must call new_group() for every rank list, even the
            # ones it is not a member of; keep only the group containing us.
            own_group, own_ranks = None, None
            for ranks in all_ranks:
                group = torch.distributed.new_group(ranks)
                if rank in ranks:
                    own_group, own_ranks = group, ranks
            return own_group, own_ranks

        # The order below must stay identical on every rank.
        ep_group, _ = _own_group(rank_generator.get_ranks('ep', independent_ep=True))
        tp_group, _ = _own_group(rank_generator.get_ranks('tp'))
        # Megatron's embedding / position-embedding groups are not created here.
        pp_group, pp_global_rank = _own_group(rank_generator.get_ranks('pp'))
        dp_group, _ = _own_group(rank_generator.get_ranks('dp'))
        ep_tp_group, _ = _own_group(
            rank_generator.get_ranks('tp-ep', independent_ep=True)
        )
        ep_dp_group, _ = _own_group(
            rank_generator.get_ranks('dp', independent_ep=True)
        )
        return {
            CommGroup.tp_group: tp_group,
            CommGroup.dp_group: dp_group,
            CommGroup.pp_group: pp_group,
            CommGroup.ep_group: ep_group,
            CommGroup.ep_tp_group: ep_tp_group,
            CommGroup.ep_dp_group: ep_dp_group,
        }, pp_global_rank

    def _get_pipeline_parallel_size(self):
        """Return the number of ranks in this rank's pipeline group."""
        group = self.comm_group_info[CommGroup.pp_group]
        return torch.distributed.get_world_size(group)

    def _get_pipeline_parallel_rank(self):
        """Return this rank's index within its pipeline group."""
        group = self.comm_group_info[CommGroup.pp_group]
        return torch.distributed.get_rank(group)

    def _get_pipeline_prev_rank(self):
        """Return the global rank of the previous pipeline stage (wraps around)."""
        rank_in_pipeline = self._get_pipeline_parallel_rank()
        pp_size = self._get_pipeline_parallel_size()
        return self.pp_global_rank_info[(rank_in_pipeline - 1) % pp_size]

    def _get_pipeline_next_rank(self):
        """Return the global rank of the next pipeline stage (wraps around)."""
        rank_in_pipeline = self._get_pipeline_parallel_rank()
        pp_size = self._get_pipeline_parallel_size()
        return self.pp_global_rank_info[(rank_in_pipeline + 1) % pp_size]

    @staticmethod
    def _pad_to_multiple(num_elements, multiple):
        """Round ``num_elements`` up to the nearest multiple of ``multiple``."""
        remainder = num_elements % multiple
        return num_elements + (multiple - remainder) if remainder else num_elements

    @bench_logger.log_timing("comm")
    def _apply_p2pcommunication(self, item):
        """Issue one pipeline point-to-point transfer described by ``item``.

        ``item.additional`` selects the direction: ``send_prev``,
        ``send_next``, ``recv_prev`` or ``recv_next``. Transfers toward a
        missing neighbour are skipped: the first stage has no previous stage,
        and the last stage (``args.pipeline_model_parallel - 1``) has no next
        stage.
        """
        num_elements = item.msg_size // 2
        tensor = torch.narrow(self.buffer, 0, 0, num_elements)
        pp_rank = self._get_pipeline_parallel_rank()
        is_first_stage = pp_rank == 0
        is_last_stage = pp_rank == self.args.pipeline_model_parallel - 1
        direction = item.additional

        ops = []
        if direction == "send_prev" and not is_first_stage:
            ops.append(
                torch.distributed.P2POp(
                    torch.distributed.isend, tensor, self._get_pipeline_prev_rank()
                )
            )
        elif direction == "send_next" and not is_last_stage:
            ops.append(
                torch.distributed.P2POp(
                    torch.distributed.isend, tensor, self._get_pipeline_next_rank()
                )
            )
        elif direction == "recv_prev" and not is_first_stage:
            recv_buffer = torch.empty(
                num_elements, dtype=torch.bfloat16, device=self.device
            )
            ops.append(
                torch.distributed.P2POp(
                    torch.distributed.irecv,
                    recv_buffer,
                    self._get_pipeline_prev_rank(),
                )
            )
        elif direction == "recv_next" and not is_last_stage:
            recv_buffer = torch.empty(
                num_elements, dtype=torch.bfloat16, device=self.device
            )
            ops.append(
                torch.distributed.P2POp(
                    torch.distributed.irecv,
                    recv_buffer,
                    self._get_pipeline_next_rank(),
                )
            )

        if ops:
            for req in torch.distributed.batch_isend_irecv(ops):
                req.wait()
        torch.cuda.synchronize()

    def _apply_barrier(self, item):
        """Block until all ranks in the default group reach this point."""
        torch.distributed.barrier()

    @bench_logger.log_timing("comm")
    def _apply_broadcast(self, item):
        """Broadcast from the group's rank 0 to every member of ``item.comm_group``."""
        tensor = torch.narrow(self.buffer, 0, 0, item.msg_size // 2)
        group = self.comm_group_info[item.comm_group]
        # broadcast() takes a global rank, so translate group rank 0.
        src = torch.distributed.get_global_rank(group, 0)
        return torch.distributed.broadcast(
            tensor=tensor, src=src, group=group, async_op=False
        )

    @bench_logger.log_timing("comm")
    def _apply_reduce(self, item):
        """SUM-reduce the message onto ``item.dst`` within ``item.comm_group``.

        NOTE(review): ``item.dst`` is passed straight to
        ``torch.distributed.reduce``, which expects a global rank — assumes the
        workload records it that way.
        """
        tensor = torch.narrow(self.buffer, 0, 0, item.msg_size // 2)
        group = self.comm_group_info[item.comm_group]
        return torch.distributed.reduce(
            tensor=tensor,
            dst=item.dst,
            op=torch.distributed.ReduceOp.SUM,
            group=group,
            async_op=False,
        )

    @bench_logger.log_timing("comm")
    def _apply_all_reduce(self, item):
        """SUM all-reduce the message across ``item.comm_group``."""
        tensor = torch.narrow(self.buffer, 0, 0, item.msg_size // 2)
        group = self.comm_group_info[item.comm_group]
        return torch.distributed.all_reduce(
            tensor=tensor,
            op=torch.distributed.ReduceOp.SUM,
            group=group,
            async_op=False,
        )

    @bench_logger.log_timing("comm")
    def _apply_all_gather(self, item):
        """All-gather into a buffer slice padded to a multiple of the group size.

        Each rank contributes the in-place chunk of the output at its group
        rank's offset.
        """
        group = self.comm_group_info[item.comm_group]
        num_elements = self._pad_to_multiple(item.msg_size // 2, group.size())
        output_tensor = torch.narrow(self.buffer, 0, 0, num_elements)
        chunk = output_tensor.numel() // group.size()
        group_rank = torch.distributed.get_group_rank(group, self.rank)
        input_tensor = torch.narrow(output_tensor, 0, group_rank * chunk, chunk)
        return torch.distributed.all_gather_into_tensor(
            output_tensor, input_tensor, group=group, async_op=False
        )

    @bench_logger.log_timing("comm")
    def _overlap(self, item):
        """Mark ``item`` as overlapped; issues no communication itself."""
        item.additional = 'overlap'

    @bench_logger.log_timing("comm")
    def _apply_reduce_scatter(self, item):
        """Reduce-scatter a buffer slice padded to a multiple of the group size.

        Each rank receives into the in-place chunk of the input at its group
        rank's offset.
        """
        group = self.comm_group_info[item.comm_group]
        num_elements = self._pad_to_multiple(item.msg_size // 2, group.size())
        input_tensor = torch.narrow(self.buffer, 0, 0, num_elements)
        chunk = input_tensor.numel() // group.size()
        group_rank = torch.distributed.get_group_rank(group, self.rank)
        output_tensor = torch.narrow(input_tensor, 0, group_rank * chunk, chunk)
        return torch.distributed.reduce_scatter_tensor(
            output_tensor, input_tensor, group=group, async_op=False
        )

    @bench_logger.log_timing("comm")
    def _apply_all_to_all(self, item):
        """Run ``all_to_all_single`` over ``item.comm_group``.

        The output is freshly allocated with ``group.size()`` times the input's
        element count. NOTE(review): the input is not padded, so this
        presumably relies on ``msg_size // 2`` being divisible by the group
        size — confirm.
        """
        group = self.comm_group_info[item.comm_group]
        num_elements = item.msg_size // 2
        input_tensor = torch.narrow(self.buffer, 0, 0, num_elements)
        output_tensor = torch.empty(
            num_elements * group.size(),
            dtype=self.buffer.dtype,
            device=self.buffer.device,
        )
        return torch.distributed.all_to_all_single(
            output_tensor, input_tensor, group=group
        )

    @bench_logger.log_timing("comp")
    def _apply_computation(self, item):
        """Simulate a computation item.

        Does nothing when ``skip_computation`` is set. In AIOB mode it sleeps
        for ``item._elapsed_time / 1e9`` seconds (presumably a nanosecond
        duration — confirm). Otherwise it multiplies two random tensors whose
        shapes are the pair stored in ``item.msg_size``.
        """
        if self.skip_computation:
            return
        if self.computation_aiob:
            time.sleep(item._elapsed_time / 1e9)
            return
        input_shape1, input_shape2 = item.msg_size
        lhs = torch.rand(input_shape1, device=self.device)
        rhs = torch.rand(input_shape2, device=self.device)
        torch.matmul(lhs, rhs)

    def apply_workload(self):
        """Replay every workload item in order.

        Returns:
            Wall-clock seconds for the whole replay, measured between two
            device synchronizations.
        """
        torch.cuda.synchronize(self.device)
        start = time.perf_counter()
        key = "backward"
        for item in self.workload.workload:
            if (
                self.computation_aiob
                and item.comm_type == CommType.all_reduce
                and key in item.stage
            ):
                # NOTE(review): in AIOB mode, backward-stage all_reduce items are
                # not issued (presumably treated as overlapped with compute).
                # This reads the scraped original's lost indentation as
                # "call only in the else branch" — confirm against upstream.
                continue
            self.comm_type_function[item.comm_type](item)
        torch.cuda.synchronize(self.device)
        return time.perf_counter() - start
if __name__ == "__main__":
    # Optional first argument: path to the workload file to replay.
    filename = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "results/model_workload/local_deepspeed_stage3.csv"
    )
    # WorkloadApplyer queries torch.distributed (world size, ranks, new_group)
    # but does not create the default process group itself. With no
    # init_method, init_process_group reads the env:// rendezvous variables
    # (e.g. as set by a launcher such as torchrun).
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl")
    applyer = WorkloadApplyer(filename=filename)
    applyer.apply_workload()
    if torch.distributed.get_rank() == 0:
        bench_logger.analyze_comm_log(bench_logger.comm_log)