-
Notifications
You must be signed in to change notification settings - Fork 9
/
llama_pipe.py
489 lines (440 loc) · 21.4 KB
/
llama_pipe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
import torch
from torch import nn
import transformers
import accelerate
from pipeline_model import ComputeMetrics, LayerSpec, PipelineModel, move_data_to_device, set_data
from utils import DTYPE_MAP
class EmbeddingPipe(nn.Module):
    """Pipeline stage that embeds token ids and prepares the attention mask / position ids.

    ``forward`` takes ``(input_ids, attention_mask, position_ids, labels)`` and returns
    ``(hidden_states, attention_mask, position_ids, labels)`` for the decoder-layer stages.

    Args:
        loader_util: object whose ``load_state_dict_into_module`` is called on this module at the
            end of ``__init__`` (presumably loads the real weights -- not visible here).
        orig: the embedding module applied to ``input_ids``.
        model: the underlying transformers base model (e.g. LlamaModel). Only its ``config`` and
            its private ``_update_causal_mask`` method are used here.
        embedding_on_cpu: if True, the embedding lookup runs on CPU and the result is moved back
            to the device ``input_ids`` arrived on.
    """

    def __init__(self, loader_util, orig, model, embedding_on_cpu=False):
        super().__init__()
        self.orig = orig
        # The original model object, e.g. LlamaModel. Use a list so the nn.Module isn't registered to this module.
        self.model = [model]
        self.embedding_on_cpu = embedding_on_cpu
        loader_util.load_state_dict_into_module(self)

    def forward(self, inputs):
        input_ids, attention_mask, position_ids, labels = inputs
        original_device = input_ids.device
        if self.embedding_on_cpu:
            # Move both the embedding module and the ids to CPU for the lookup; the embeddings
            # are sent back to the original device on the next line.
            self.orig.to('cpu')
            input_ids = input_ids.to('cpu')
        inputs_embeds = self.orig(input_ids).to(original_device)
        original_attention_mask = attention_mask
        past_key_values = None # always None for training
        # With no cache, past_seen_tokens is 0 and cache_position is 0..inputs_embeds.shape[1]-1.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
        # _update_causal_mask is a private transformers method called positionally; the mistral
        # variant is passed one extra trailing False.
        # NOTE(review): the meaning of these positional flags depends on the pinned transformers
        # version -- confirm against that version's LlamaModel/MistralModel signatures.
        if self.model[0].config.model_type == 'mistral':
            attention_mask = self.model[0]._update_causal_mask(
                attention_mask, inputs_embeds, cache_position, past_key_values, False, False
            )
        else:
            attention_mask = self.model[0]._update_causal_mask(
                attention_mask, inputs_embeds, cache_position, past_key_values, False
            )
        if attention_mask is None:
            # With FA, attention_mask can end up being None. But with deepspeed we can't pass None
            # between GPUs. So force it back to the original attention_mask.
            attention_mask = original_attention_mask
        hidden_states = inputs_embeds
        if self.model[0].config.model_type == 'gemma2':
            # Gemma2 scales the embeddings by sqrt(hidden_size), in the embedding dtype.
            normalizer = torch.tensor(self.model[0].config.hidden_size**0.5, dtype=hidden_states.dtype)
            hidden_states = hidden_states * normalizer
        # We have to do this so activation checkpointing with reentrant checkpoint function (the default) works.
        # We could just use non-reentrant instead, but that has some weird bug with flash attn where the memory usage is very high.
        hidden_states.requires_grad_(True)
        # Without flash attn, the attention_mask is a float. With pipeline parallel, any float tensors sent across GPUs must have requires_grad.
        # This is a workaround, theoretically there's no reason to require this.
        if torch.is_floating_point(attention_mask):
            attention_mask.requires_grad_(True)
        return hidden_states, attention_mask, position_ids, labels
class LlamaRMSNormPipe(nn.Module):
    """Pipeline stage applying the final norm and narrowing the tuple to ``(hidden_states, labels)``.

    The incoming tuple is ``(hidden_states, attention_mask, position_ids, labels)``. The mask and
    position ids are dropped because the head stages that follow only accept
    ``(hidden_states, labels)``.

    Args:
        loader_util: object whose ``load_state_dict_into_module`` is called on this module.
        orig: the normalization module applied to ``hidden_states``.
    """

    def __init__(self, loader_util, orig):
        super().__init__()
        self.orig = orig
        loader_util.load_state_dict_into_module(self)

    def forward(self, inputs):
        # Strict 4-way unpack: a malformed tuple still raises, exactly as before.
        states, _mask, _positions, targets = inputs
        normed_states = self.orig(states)
        return normed_states, targets
class LmHeadPipe(nn.Module):
    """Pipeline stage projecting hidden states to vocabulary logits.

    ``forward`` maps ``(hidden_states, labels)`` to ``(logits, labels)``.

    Args:
        loader_util: object whose ``load_state_dict_into_module`` is called on this module.
        lm_head: the output projection module (must expose ``.weight``).
        logit_scale: multiplier applied to ``hidden_states`` before the projection. For a
            bias-free projection this equals scaling the logits (assumed -- confirm for new models).
        tie_weights: if truthy, stored as ``lm_head.weight.original_name`` (e.g. the embedding
            weight name when the head is tied to the embeddings).
    """

    def __init__(self, loader_util, lm_head, logit_scale=1.0, tie_weights=None):
        super().__init__()
        # Unlike the other wrapper classes, this is called lm_head and not orig. Because this is directly a
        # nn.Linear layer, it needs to keep the same attribute name so quantization knows not to quantize it.
        self.lm_head = lm_head
        self.logit_scale = logit_scale
        if tie_weights:
            self.lm_head.weight.original_name = tie_weights
        loader_util.load_state_dict_into_module(self)

    def forward(self, inputs):
        hidden_states, labels = inputs
        # Multiplying by 1.0 is a numerical no-op but still materializes a full copy of
        # hidden_states, so only scale when a non-default logit_scale was given.
        if self.logit_scale != 1.0:
            hidden_states = hidden_states * self.logit_scale
        return self.lm_head(hidden_states), labels
class Gemma2LmHeadPipe(nn.Module):
    """Gemma2 LM-head stage: projection followed by optional final-logit soft-capping.

    When ``model_config.final_logit_softcapping`` is set to ``cap``, logits become
    ``cap * tanh(logits / cap)``. Otherwise the raw projection is returned.

    Args:
        loader_util: object whose ``load_state_dict_into_module`` is called on this module.
        lm_head: the output projection module (must expose ``.weight``).
        model_config: config object read for ``final_logit_softcapping``.
        tie_weights: if truthy, stored as ``lm_head.weight.original_name``.
    """

    def __init__(self, loader_util, lm_head, model_config, tie_weights=None):
        super().__init__()
        self.lm_head = lm_head
        self.model_config = model_config
        if tie_weights:
            self.lm_head.weight.original_name = tie_weights
        loader_util.load_state_dict_into_module(self)

    def forward(self, inputs):
        states, targets = inputs
        raw_logits = self.lm_head(states)
        cap = self.model_config.final_logit_softcapping
        if cap is None:
            return raw_logits, targets
        capped_logits = torch.tanh(raw_logits / cap) * cap
        return capped_logits, targets
class LlamaDecoderLayerPipe(nn.Module):
    """Pipeline stage wrapping one decoder layer, with optional CPU offloading of its MLP weights.

    ``forward`` maps ``(hidden_states, attention_mask, position_ids, labels)`` to the same tuple
    with ``hidden_states`` replaced by the wrapped layer's first output.

    Args:
        loader_util: object whose ``load_state_dict_into_module`` is called on this module.
        orig: the decoder layer. When offloading is enabled it must expose ``mlp.up_proj``,
            ``mlp.down_proj`` and ``mlp.gate_proj``.
    """

    # Attribute names on ``orig.mlp`` handled by offloading, in the order they are moved.
    _offloadable_mlp_projs = ('up_proj', 'down_proj', 'gate_proj')

    def __init__(self, loader_util, orig):
        super().__init__()
        self.orig = orig
        self.mlp_offloaded_to_cpu = False
        loader_util.load_state_dict_into_module(self)

    # A note on MLP offloading:
    # We take advantage of how activation checkpointing works with reentrant checkpointing functions.
    # During the forward pass, if gradients are disabled (eval or first forward pass of activation checkpointing)
    # we offload the weights back to CPU at the end of the function. If gradients are enabled (second forward pass
    # of activation checkpointing) we leave the weights on GPU, and use a backward hook to offload to CPU after the
    # backward pass of this function is completed. This way the weights stay on the GPU for the backward pass.
    def forward(self, inputs):
        hidden_states, attention_mask, position_ids, labels = inputs
        cpu_weights = []

        def restore_cpu_weights(*_unused_grad):
            # Re-point each projection at the CPU data captured below. Doubles as the tensor
            # backward hook; returning None leaves the incoming gradient unchanged.
            mlp = self.orig.mlp
            for proj_name, cpu_weight in zip(self._offloadable_mlp_projs, cpu_weights):
                set_data(getattr(mlp, proj_name), cpu_weight)

        if self.mlp_offloaded_to_cpu:
            if hidden_states.requires_grad:
                hidden_states.register_hook(restore_cpu_weights)
            target_device = hidden_states.device
            for proj_name in self._offloadable_mlp_projs:
                cpu_weights.append(
                    move_data_to_device(getattr(self.orig.mlp, proj_name), target_device)
                )

        layer_output = self.orig(
            hidden_states, attention_mask=attention_mask, position_ids=position_ids
        )[0]
        if self.mlp_offloaded_to_cpu and not torch.is_grad_enabled():
            # Gradients are disabled for this call, so offload immediately (see note above).
            restore_cpu_weights()
        return layer_output, attention_mask, position_ids, labels

    def offload_mlp_to_cpu(self):
        """Enable MLP offloading and move the MLP projections' data to CPU."""
        self.mlp_offloaded_to_cpu = True
        for proj_name in self._offloadable_mlp_projs:
            move_data_to_device(getattr(self.orig.mlp, proj_name), 'cpu')
class Phi3DecoderLayerPipe(nn.Module):
    """Phi3 variant of the decoder-layer stage, with optional CPU offloading of its MLP weights.

    Identical tuple contract to the Llama stage; the difference is that Phi3's MLP exposes a
    fused ``gate_up_proj`` plus ``down_proj``, which are the projections offloaded here.

    Args:
        loader_util: object whose ``load_state_dict_into_module`` is called on this module.
        orig: the Phi3 decoder layer.
    """

    # Attribute names on ``orig.mlp`` handled by offloading, in the order they are moved.
    _offloadable_mlp_projs = ('gate_up_proj', 'down_proj')

    def __init__(self, loader_util, orig):
        super().__init__()
        self.orig = orig
        self.mlp_offloaded_to_cpu = False
        loader_util.load_state_dict_into_module(self)

    # A note on MLP offloading:
    # We take advantage of how activation checkpointing works with reentrant checkpointing functions.
    # During the forward pass, if gradients are disabled (eval or first forward pass of activation checkpointing)
    # we offload the weights back to CPU at the end of the function. If gradients are enabled (second forward pass
    # of activation checkpointing) we leave the weights on GPU, and use a backward hook to offload to CPU after the
    # backward pass of this function is completed. This way the weights stay on the GPU for the backward pass.
    def forward(self, inputs):
        hidden_states, attention_mask, position_ids, labels = inputs
        cpu_weights = []

        def restore_cpu_weights(*_unused_grad):
            # Re-point each projection at the CPU data captured below. Doubles as the tensor
            # backward hook; returning None leaves the incoming gradient unchanged.
            mlp = self.orig.mlp
            for proj_name, cpu_weight in zip(self._offloadable_mlp_projs, cpu_weights):
                set_data(getattr(mlp, proj_name), cpu_weight)

        if self.mlp_offloaded_to_cpu:
            if hidden_states.requires_grad:
                hidden_states.register_hook(restore_cpu_weights)
            target_device = hidden_states.device
            for proj_name in self._offloadable_mlp_projs:
                cpu_weights.append(
                    move_data_to_device(getattr(self.orig.mlp, proj_name), target_device)
                )

        layer_output = self.orig(
            hidden_states, attention_mask=attention_mask, position_ids=position_ids
        )[0]
        if self.mlp_offloaded_to_cpu and not torch.is_grad_enabled():
            # Gradients are disabled for this call, so offload immediately (see note above).
            restore_cpu_weights()
        return layer_output, attention_mask, position_ids, labels

    def offload_mlp_to_cpu(self):
        """Enable MLP offloading and move the MLP projections' data to CPU."""
        self.mlp_offloaded_to_cpu = True
        for proj_name in self._offloadable_mlp_projs:
            move_data_to_device(getattr(self.orig.mlp, proj_name), 'cpu')
# A little bit of inheritance and MRO trickery, since LlamaForCausalLM.__init__ only takes a single
# positional argument (the model config). We inherit PipelineModel first, but call LlamaForCausalLM's
# __init__ first, and make sure PipelineModel doesn't have a super().__init__() call.
class LlamaForCausalLMPipe(PipelineModel, transformers.LlamaForCausalLM):
    """Llama causal LM that can be split into pipeline stages via ``to_layer_specs``.

    Args:
        config: training config dict. Reads ``'model'`` (path given to ``from_pretrained``) and
            optional ``'model_weight_dtype'`` (key into DTYPE_MAP, default ``'bfloat16'``).
        quantization_config: forwarded unchanged to ``PipelineModel.__init__``.
    """

    def __init__(self, config, quantization_config):
        model_config = transformers.LlamaConfig.from_pretrained(config['model'])
        model_config._attn_implementation = 'flash_attention_2'
        torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')])
        # torch's default dtype is process-global: restore it even if construction raises, so a
        # failed build cannot leave later tensor creation in the model weight dtype.
        try:
            with accelerate.init_empty_weights():
                transformers.LlamaForCausalLM.__init__(self, model_config)
                PipelineModel.__init__(self, config, quantization_config, model_config)
        finally:
            torch.set_default_dtype(torch.float32)

    def to_layer_specs(self):
        """Return the ordered pipeline layers: position-id setup, embedding, one stage per
        decoder layer, final norm, LM head, and the loss/metrics stage."""

        def initial_layer(inputs):
            # Expand (input_ids, attention_mask, labels) into the 4-tuple the embedding stage
            # unpacks, adding position_ids = [[0, 1, ..., seq_len - 1]] of shape (1, seq_len).
            input_ids, attention_mask, labels = inputs
            seq_length = input_ids.shape[1]
            position_ids = torch.arange(
                seq_length, dtype=torch.long, device=input_ids.device
            ).unsqueeze(0)
            return input_ids, attention_mask, position_ids, labels

        result = [
            initial_layer,
            LayerSpec(
                EmbeddingPipe,
                self.loader_util,
                self.model.embed_tokens,
                self.model,
                embedding_on_cpu=not self.train_config['full_fine_tune']
            ),
        ]
        for block in self.model.layers:
            result.append(LayerSpec(LlamaDecoderLayerPipe, self.loader_util, block))
        result.append(LayerSpec(LlamaRMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0))
        result.append(LayerSpec(
            LmHeadPipe,
            self.loader_util,
            self.lm_head,
            # Tied checkpoints have no separate lm_head weight; point it at the embedding.
            tie_weights='model.embed_tokens.weight' if self.config.tie_word_embeddings else None,
            _estimated_size=0
        ))
        result.append(
            LayerSpec(
                ComputeMetrics,
                loss_type=self.loss_type,
                focal_loss_gamma=self.focal_loss_gamma
            )
        )
        return result
class Qwen2ForCausalLMPipe(PipelineModel, transformers.Qwen2ForCausalLM):
    """Qwen2 causal LM that can be split into pipeline stages via ``to_layer_specs``.

    Uses the same PipelineModel-first inheritance trick as the Llama pipe class.

    Args:
        config: training config dict. Reads ``'model'`` (path given to ``from_pretrained``) and
            optional ``'model_weight_dtype'`` (key into DTYPE_MAP, default ``'bfloat16'``).
        quantization_config: forwarded unchanged to ``PipelineModel.__init__``.
    """

    def __init__(self, config, quantization_config):
        model_config = transformers.Qwen2Config.from_pretrained(config['model'])
        model_config._attn_implementation = 'flash_attention_2'
        torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')])
        # torch's default dtype is process-global: restore it even if construction raises.
        try:
            with accelerate.init_empty_weights():
                transformers.Qwen2ForCausalLM.__init__(self, model_config)
                PipelineModel.__init__(self, config, quantization_config, model_config)
        finally:
            torch.set_default_dtype(torch.float32)

    def to_layer_specs(self):
        """Return the ordered pipeline layers: position-id setup, embedding, one stage per
        decoder layer, final norm, LM head, and the loss/metrics stage."""

        def initial_layer(inputs):
            # Expand (input_ids, attention_mask, labels) into the 4-tuple the embedding stage
            # unpacks, adding position_ids = [[0, 1, ..., seq_len - 1]] of shape (1, seq_len).
            input_ids, attention_mask, labels = inputs
            seq_length = input_ids.shape[1]
            position_ids = torch.arange(
                seq_length, dtype=torch.long, device=input_ids.device
            ).unsqueeze(0)
            return input_ids, attention_mask, position_ids, labels

        result = [
            initial_layer,
            LayerSpec(
                EmbeddingPipe,
                self.loader_util,
                self.model.embed_tokens,
                self.model,
                embedding_on_cpu=not self.train_config['full_fine_tune']
            ),
        ]
        for block in self.model.layers:
            result.append(LayerSpec(LlamaDecoderLayerPipe, self.loader_util, block))
        result.append(LayerSpec(LlamaRMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0))
        result.append(LayerSpec(
            LmHeadPipe,
            self.loader_util,
            self.lm_head,
            # Tied checkpoints have no separate lm_head weight; point it at the embedding.
            tie_weights='model.embed_tokens.weight' if self.config.tie_word_embeddings else None,
            _estimated_size=0
        ))
        result.append(
            LayerSpec(
                ComputeMetrics,
                loss_type=self.loss_type,
                focal_loss_gamma=self.focal_loss_gamma
            )
        )
        return result
class CohereForCausalLMPipe(PipelineModel, transformers.CohereForCausalLM):
    """Cohere causal LM that can be split into pipeline stages via ``to_layer_specs``.

    The LM head is always tied to the embedding and applies ``self.logit_scale`` (an attribute
    provided by ``transformers.CohereForCausalLM`` -- not set in this file).

    Args:
        config: training config dict. Reads ``'model'`` (path given to ``from_pretrained``) and
            optional ``'model_weight_dtype'`` (key into DTYPE_MAP, default ``'bfloat16'``).
        quantization_config: forwarded unchanged to ``PipelineModel.__init__``.
    """

    def __init__(self, config, quantization_config):
        model_config = transformers.CohereConfig.from_pretrained(config['model'])
        model_config._attn_implementation = 'flash_attention_2'
        torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')])
        # torch's default dtype is process-global: restore it even if construction raises.
        try:
            with accelerate.init_empty_weights():
                transformers.CohereForCausalLM.__init__(self, model_config)
                PipelineModel.__init__(self, config, quantization_config, model_config)
        finally:
            torch.set_default_dtype(torch.float32)

    def to_layer_specs(self):
        """Return the ordered pipeline layers, with load-balancing weights for the large
        embedding table and the tied LM head's loss stage."""
        # the embedding table for this model is huge; load balance it better with some heuristics
        embedding_relative_size = 4

        def initial_layer(inputs):
            # Expand (input_ids, attention_mask, labels) into the 4-tuple the embedding stage
            # unpacks, adding position_ids = [[0, 1, ..., seq_len - 1]] of shape (1, seq_len).
            input_ids, attention_mask, labels = inputs
            seq_length = input_ids.shape[1]
            position_ids = torch.arange(
                seq_length, dtype=torch.long, device=input_ids.device
            ).unsqueeze(0)
            return input_ids, attention_mask, position_ids, labels

        embedding_on_cpu = not self.train_config['full_fine_tune']
        result = [
            initial_layer,
            LayerSpec(
                EmbeddingPipe,
                self.loader_util,
                self.model.embed_tokens,
                self.model,
                embedding_on_cpu=embedding_on_cpu,
                # Weighted as one layer when the lookup runs on CPU, otherwise by embedding_relative_size.
                _estimated_size=1 if embedding_on_cpu else embedding_relative_size,
            ),
        ]
        for block in self.model.layers:
            result.append(LayerSpec(LlamaDecoderLayerPipe, self.loader_util, block))
        result.append(LayerSpec(LlamaRMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0))
        result.append(
            LayerSpec(
                LmHeadPipe,
                self.loader_util,
                self.lm_head,
                logit_scale=self.logit_scale,
                tie_weights='model.embed_tokens.weight'
            )
        )
        result.append(
            LayerSpec(
                ComputeMetrics,
                loss_type=self.loss_type,
                focal_loss_gamma=self.focal_loss_gamma,
                _estimated_size=embedding_relative_size
            )
        )
        return result
class Phi3ForCausalLMPipe(PipelineModel, transformers.Phi3ForCausalLM):
    """Phi3 causal LM that can be split into pipeline stages via ``to_layer_specs``.

    Args:
        config: training config dict. Reads ``'model'`` (path given to ``from_pretrained``) and
            optional ``'model_weight_dtype'`` (key into DTYPE_MAP, default ``'bfloat16'``).
        quantization_config: forwarded unchanged to ``PipelineModel.__init__``.
    """

    def __init__(self, config, quantization_config):
        model_config = transformers.Phi3Config.from_pretrained(config['model'])
        model_config._attn_implementation = 'flash_attention_2'
        torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')])
        # torch's default dtype is process-global: restore it even if construction raises.
        try:
            with accelerate.init_empty_weights():
                transformers.Phi3ForCausalLM.__init__(self, model_config)
                PipelineModel.__init__(self, config, quantization_config, model_config)
        finally:
            torch.set_default_dtype(torch.float32)

    def to_layer_specs(self):
        """Return the ordered pipeline layers: position-id setup, embedding, one Phi3 stage per
        decoder layer, final norm, LM head, and the loss/metrics stage."""

        def initial_layer(inputs):
            # Expand (input_ids, attention_mask, labels) into the 4-tuple the embedding stage
            # unpacks, adding position_ids = [[0, 1, ..., seq_len - 1]] of shape (1, seq_len).
            input_ids, attention_mask, labels = inputs
            seq_length = input_ids.shape[1]
            position_ids = torch.arange(
                seq_length, dtype=torch.long, device=input_ids.device
            ).unsqueeze(0)
            return input_ids, attention_mask, position_ids, labels

        result = [
            initial_layer,
            LayerSpec(
                EmbeddingPipe,
                self.loader_util,
                self.model.embed_tokens,
                self.model,
                embedding_on_cpu=not self.train_config['full_fine_tune']
            ),
        ]
        for block in self.model.layers:
            result.append(LayerSpec(Phi3DecoderLayerPipe, self.loader_util, block))
        result.append(LayerSpec(LlamaRMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0))
        result.append(LayerSpec(
            LmHeadPipe,
            self.loader_util,
            self.lm_head,
            # Same as the Llama/Qwen2 pipes: a tied checkpoint has no separate lm_head weight,
            # so point the head at the embedding weight instead.
            tie_weights='model.embed_tokens.weight' if self.config.tie_word_embeddings else None,
            _estimated_size=0
        ))
        result.append(
            LayerSpec(
                ComputeMetrics,
                loss_type=self.loss_type,
                focal_loss_gamma=self.focal_loss_gamma
            )
        )
        return result
class Gemma2ForCausalLMPipe(PipelineModel, transformers.Gemma2ForCausalLM):
    """Gemma2 causal LM that can be split into pipeline stages via ``to_layer_specs``.

    Uses eager attention and a tied LM head with final-logit soft-capping.

    Args:
        config: training config dict. Reads ``'model'`` (path given to ``from_pretrained``) and
            optional ``'model_weight_dtype'`` (key into DTYPE_MAP, default ``'bfloat16'``).
        quantization_config: forwarded unchanged to ``PipelineModel.__init__``.
    """

    def __init__(self, config, quantization_config):
        model_config = transformers.Gemma2Config.from_pretrained(config['model'])
        # TODO: change this when Gemma works with other attn implementations
        model_config._attn_implementation = 'eager'
        torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')])
        # torch's default dtype is process-global: restore it even if construction raises.
        try:
            with accelerate.init_empty_weights():
                transformers.Gemma2ForCausalLM.__init__(self, model_config)
                PipelineModel.__init__(self, config, quantization_config, model_config)
        finally:
            torch.set_default_dtype(torch.float32)

    def to_layer_specs(self):
        """Return the ordered pipeline layers, with load-balancing weights for the large
        embedding table and the tied LM head's loss stage."""
        # the embedding table for this model is huge; load balance it better with some heuristics
        # this value optimized for LoRA, pipeline_stages=2
        embedding_relative_size = 8

        def initial_layer(inputs):
            # Expand (input_ids, attention_mask, labels) into the 4-tuple the embedding stage
            # unpacks, adding position_ids = [[0, 1, ..., seq_len - 1]] of shape (1, seq_len).
            input_ids, attention_mask, labels = inputs
            seq_length = input_ids.shape[1]
            position_ids = torch.arange(
                seq_length, dtype=torch.long, device=input_ids.device
            ).unsqueeze(0)
            return input_ids, attention_mask, position_ids, labels

        embedding_on_cpu = not self.train_config['full_fine_tune']
        result = [
            initial_layer,
            LayerSpec(
                EmbeddingPipe,
                self.loader_util,
                self.model.embed_tokens,
                self.model,
                embedding_on_cpu=embedding_on_cpu,
                # Weighted as one layer when the lookup runs on CPU, otherwise by embedding_relative_size.
                _estimated_size=1 if embedding_on_cpu else embedding_relative_size,
            ),
        ]
        for block in self.model.layers:
            result.append(LayerSpec(LlamaDecoderLayerPipe, self.loader_util, block))
        result.append(LayerSpec(LlamaRMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0))
        result.append(LayerSpec(Gemma2LmHeadPipe, self.loader_util, self.lm_head, self.config, tie_weights='model.embed_tokens.weight'))
        result.append(
            LayerSpec(
                ComputeMetrics,
                loss_type=self.loss_type,
                focal_loss_gamma=self.focal_loss_gamma,
                _estimated_size=embedding_relative_size
            )
        )
        return result
class MistralForCausalLMPipe(PipelineModel, transformers.MistralForCausalLM):
    """Mistral causal LM that can be split into pipeline stages via ``to_layer_specs``.

    Args:
        config: training config dict. Reads ``'model'`` (path given to ``from_pretrained``) and
            optional ``'model_weight_dtype'`` (key into DTYPE_MAP, default ``'bfloat16'``).
        quantization_config: forwarded unchanged to ``PipelineModel.__init__``.
    """

    def __init__(self, config, quantization_config):
        model_config = transformers.MistralConfig.from_pretrained(config['model'])
        model_config._attn_implementation = 'flash_attention_2'
        torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')])
        # torch's default dtype is process-global: restore it even if construction raises.
        try:
            with accelerate.init_empty_weights():
                transformers.MistralForCausalLM.__init__(self, model_config)
                PipelineModel.__init__(self, config, quantization_config, model_config)
        finally:
            torch.set_default_dtype(torch.float32)

    def to_layer_specs(self):
        """Return the ordered pipeline layers: position-id setup, embedding, one stage per
        decoder layer, final norm, LM head, and the loss/metrics stage."""

        def initial_layer(inputs):
            # Expand (input_ids, attention_mask, labels) into the 4-tuple the embedding stage
            # unpacks, adding position_ids = [[0, 1, ..., seq_len - 1]] of shape (1, seq_len).
            input_ids, attention_mask, labels = inputs
            seq_length = input_ids.shape[1]
            position_ids = torch.arange(
                seq_length, dtype=torch.long, device=input_ids.device
            ).unsqueeze(0)
            return input_ids, attention_mask, position_ids, labels

        result = [
            initial_layer,
            LayerSpec(
                EmbeddingPipe,
                self.loader_util,
                self.model.embed_tokens,
                self.model,
                embedding_on_cpu=not self.train_config['full_fine_tune']
            ),
        ]
        for block in self.model.layers:
            result.append(LayerSpec(LlamaDecoderLayerPipe, self.loader_util, block))
        result.append(LayerSpec(LlamaRMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0))
        result.append(LayerSpec(
            LmHeadPipe,
            self.loader_util,
            self.lm_head,
            # Same as the Llama/Qwen2 pipes: a tied checkpoint has no separate lm_head weight,
            # so point the head at the embedding weight instead.
            tie_weights='model.embed_tokens.weight' if self.config.tie_word_embeddings else None,
            _estimated_size=0
        ))
        result.append(
            LayerSpec(
                ComputeMetrics,
                loss_type=self.loss_type,
                focal_loss_gamma=self.focal_loss_gamma
            )
        )
        return result