-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathtrain.py
705 lines (625 loc) · 28.2 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
import os
import sys
from typing import List
import torch
import transformers
import datasets
from datasets import load_from_disk
from peft import (
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
prepare_model_for_kbit_training,
set_peft_model_state_dict,
)
from transformers import LlamaTokenizer
from transformers import set_seed
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers.optimization import get_scheduler
from transformers import Trainer
from typing import Optional
from torch import nn
# Module-level flag; nothing in the code visible in this file reads it.
_is_torch_generator_available = True
from transformers.file_utils import is_datasets_available
from torch.utils.data import RandomSampler, SequentialSampler
from transformers.trainer_pt_utils import (
LengthGroupedSampler,
)
from transformers.trainer_utils import has_length
from transformers.utils import is_sagemaker_mp_enabled
from transformers.utils import logging
from transformers.trainer_callback import TrainerCallback
# Module-level logger obtained from transformers' logging utilities.
logger = logging.get_logger(__name__)
# Output directory shared with ResetReloraCallback (it writes "base-model" under it).
# train() assigns the real value through `global SAVE_PATH`.
SAVE_PATH = ''
class ResetReloraCallback(TrainerCallback):
    """Trainer callback that periodically merges and re-initialises the adapters.

    Every ``T`` optimizer steps the callback merges the ``'default'`` adapter of
    each attention/MLP projection into its base weight, re-initialises the
    adapter (rotating the MoRA type when MoRA is in use), saves the merged base
    model under ``SAVE_PATH/base-model`` and, optionally, resets the optimizer
    state and restarts the learning-rate warmup.

    Args:
        T: Number of global steps between two resets.
        reset_optimizer: Whether optimizer state is cleared on a reset.
        relora_warmup_step: Warmup length written into the scheduler after a reset.
        is_pretrain: If true only the state of parameters whose name contains
            ``'lora'`` is dropped; otherwise the whole optimizer state is replaced.
        relora_scheduler: If true the scheduler is left untouched on a reset.
        remora_types: ``4`` cycles MoRA types 1->2->3->4->1; any other value
            alternates 1<->2.
    """

    def __init__(self, T=50, reset_optimizer=True,
                 relora_warmup_step=50, is_pretrain=False, relora_scheduler=False,
                 remora_types=2):
        self.T = T
        self.reset_optimizer = reset_optimizer
        self.relora_warmup_step = relora_warmup_step
        self.is_pretrain = is_pretrain
        self.relora_scheduler = relora_scheduler
        self.remora_types = remora_types

    def on_step_end(self, args, state, control, **kwargs):
        """Merge/reset the adapters when ``state.global_step`` is a positive multiple of ``T``."""
        peft_model = kwargs['model']
        model = peft_model.base_model
        optimizer = kwargs['optimizer']

        step = state.global_step
        if step <= 0 or step % self.T != 0:
            return

        # The rotation table does not depend on the layer, so build it once.
        if self.remora_types == 4:
            # 1->2->3->4
            mora_type_map = {1: 2, 2: 3, 3: 4, 4: 1}
        else:
            mora_type_map = {1: 2, 2: 1}

        # Assumes a LLaMA-style module tree (self_attn.*_proj / mlp.*_proj).
        for layer in peft_model.base_model.model.model.layers:
            attn, mlp = layer.self_attn, layer.mlp
            for linear in (attn.q_proj, attn.k_proj, attn.v_proj, attn.o_proj,
                           mlp.gate_proj, mlp.down_proj, mlp.up_proj):
                linear.merge()
                # Forget the merge so the adapter can be trained and merged again.
                linear.merged_adapters = []
                if linear.use_mora['default']:
                    current_type = linear.mora_type['default']
                    next_type = mora_type_map[current_type]
                    print('mora change type', current_type, end='->')
                    print(next_type)
                    linear.reset_lora_parameters('default', init_lora_weights=True,
                                                 mora_type=next_type)
                else:
                    linear.reset_lora_parameters('default', init_lora_weights=True)

        # save base model
        base_model_dir = os.path.join(SAVE_PATH, "base-model")
        print('save base model', base_model_dir, self.reset_optimizer)
        model.base_model.save_pretrained(base_model_dir)

        if not self.reset_optimizer:
            print('not reset optimizer')
            return

        # reset optimizer
        from collections import defaultdict
        if self.is_pretrain:
            for name, param in model.named_parameters():
                if 'lora' in name:
                    # pop() instead of del: a parameter without optimizer state
                    # (e.g. one that never received a gradient) must not raise KeyError.
                    optimizer.state.pop(param, None)
        else:
            optimizer.state = defaultdict(dict)

        if self.relora_scheduler:
            # if we use relora scheduler, we don't need to reset scheduler
            return

        # Restart the warmup: rewrite the keyword arguments bound into the schedule
        # lambda in place. Assumes lr_lambdas[0] is a functools.partial carrying
        # num_warmup_steps / num_training_steps -- TODO confirm for non-default schedulers.
        scheduler = kwargs['lr_scheduler']
        schedule_kwargs = scheduler.lr_lambdas[0].keywords
        schedule_kwargs['num_warmup_steps'] = self.relora_warmup_step
        schedule_kwargs['num_training_steps'] = state.max_steps - state.global_step
        scheduler._step_count = 0
        scheduler.last_epoch = 0
        # The new schedule starts from the learning rate reached so far.
        last_lr = scheduler._last_lr[0]
        scheduler.base_lrs[:] = [last_lr] * len(scheduler.base_lrs)
        print('reset scheduler', scheduler._last_lr[0], scheduler.state_dict())
class OurTrainer(Trainer):
    """``Trainer`` subclass with LoRA+ parameter groups, ReLoRA resets and optional unshuffled sampling.

    The class attributes below are defaults; ``train`` overwrites them on the
    instance after the trainer has been constructed.
    """

    # Use a RandomSampler (True) or a SequentialSampler (False) when not grouping by length.
    shuffle_data = True
    # Learning-rate multiplier applied to ``lora_B`` parameters when > 1.
    lora_plus_lambda = 1
    # ReLoRA settings forwarded to ResetReloraCallback / the ReLoRA scheduler.
    use_relora = False
    use_relora_step = 50
    use_relora_reset_optimizer = True
    relora_warmup_step = 50
    is_pretrain = False
    relora_scheduler = False
    remora_types = 2

    def create_optimizer(self):
        """
        Setup the optimizer.

        Builds weight-decay / no-weight-decay parameter groups from the trainable
        parameters. When ``lora_plus_lambda > 1`` the decayed ``lora_B`` parameters
        get their own group whose learning rate is ``learning_rate * lora_plus_lambda``.

        Returns:
            The optimizer stored on ``self.optimizer``.
        """
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

        if self.optimizer is None:
            # A set keeps the per-parameter membership checks O(1).
            decay_parameters = set(self.get_decay_parameter_names(opt_model))
            trainable = [(n, p) for n, p in opt_model.named_parameters() if p.requires_grad]
            decayed = [(n, p) for n, p in trainable if n in decay_parameters]
            not_decayed = [p for n, p in trainable if n not in decay_parameters]

            if self.lora_plus_lambda > 1:
                base_lr = self.args.learning_rate
                optimizer_grouped_parameters = [
                    {
                        "params": [p for n, p in decayed if 'lora_B' not in n],
                        "weight_decay": self.args.weight_decay,
                        "lr": base_lr,
                    },
                    {
                        "params": [p for n, p in decayed if 'lora_B' in n],
                        "weight_decay": self.args.weight_decay,
                        "lr": base_lr * self.lora_plus_lambda,
                    },
                    {
                        "params": not_decayed,
                        "weight_decay": 0.0,
                        "lr": base_lr,
                    },
                ]
                print(len(optimizer_grouped_parameters[0]['params']), len(optimizer_grouped_parameters[1]['params']), len(optimizer_grouped_parameters[2]['params']))
            else:
                optimizer_grouped_parameters = [
                    {
                        "params": [p for _, p in decayed],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        "params": not_decayed,
                        "weight_decay": 0.0,
                    },
                ]

            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

            if optimizer_cls.__name__ == "Adam8bit":
                import bitsandbytes

                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                skipped = 0
                for module in opt_model.modules():
                    if isinstance(module, nn.Embedding):
                        # Keyed by data_ptr so tied/shared weights are counted once.
                        skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                        logger.info(f"skipped {module}: {skipped/2**20}M params")
                        manager.register_module_override(module, "weight", {"optim_bits": 32})
                        logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                logger.info(f"skipped: {skipped/2**20}M params")

        if is_sagemaker_mp_enabled():
            # `smp` was referenced without ever being imported; bring it into
            # scope here so this path no longer raises NameError.
            import smdistributed.modelparallel.torch as smp

            self.optimizer = smp.DistributedOptimizer(self.optimizer)

        return self.optimizer

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        """Return the training sampler and register the ReLoRA callback when enabled.

        Returns ``None`` when there is no training dataset or it has no length.
        Otherwise returns a ``LengthGroupedSampler`` (``group_by_length``), a
        ``SequentialSampler`` (``shuffle_data`` is false) or a ``RandomSampler``.
        """
        if self.train_dataset is None or not has_length(self.train_dataset):
            return None

        # Register the reset callback only once: this method may be called more
        # than once, and a duplicate callback would merge/reset twice per step.
        if self.use_relora and not any(
            isinstance(cb, ResetReloraCallback) for cb in self.callback_handler.callbacks
        ):
            self.add_callback(ResetReloraCallback(T=self.use_relora_step,
                                                  reset_optimizer=self.use_relora_reset_optimizer,
                                                  relora_warmup_step=self.relora_warmup_step,
                                                  is_pretrain=self.is_pretrain,
                                                  relora_scheduler=self.relora_scheduler,
                                                  remora_types=self.remora_types))

        # Build the sampler.
        if self.args.group_by_length:
            lengths = None
            if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
                if self.args.length_column_name in self.train_dataset.column_names:
                    lengths = self.train_dataset[self.args.length_column_name]
            model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
            return LengthGroupedSampler(
                self.args.train_batch_size * self.args.gradient_accumulation_steps,
                dataset=self.train_dataset,
                lengths=lengths,
                model_input_name=model_input_name,
            )

        if not self.shuffle_data:
            return SequentialSampler(self.train_dataset)
        return RandomSampler(self.train_dataset)

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        Compute the training loss and record its cross-process mean on ``self.state.loss``.

        When ``inputs`` has no ``labels`` a clone of ``input_ids`` is used as labels.
        With a label smoother configured, the labels are removed from the inputs
        and the smoother computes the loss; otherwise the model's own loss is used.

        Returns:
            ``(loss, outputs)`` if ``return_outputs`` else ``loss``.

        Raises:
            ValueError: if the model returns a dict without a ``"loss"`` entry.
        """
        if 'labels' not in inputs:
            inputs['labels'] = inputs['input_ids'].clone()

        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None

        outputs = model(**inputs)

        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            # These names were used without being imported anywhere in the file,
            # so the label-smoothing path raised NameError; import them locally.
            from peft import PeftModel
            from transformers.modeling_utils import unwrap_model
            from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
            from transformers.utils import is_peft_available

            if is_peft_available() and isinstance(model, PeftModel):
                model_name = unwrap_model(model.base_model)._get_name()
            else:
                model_name = unwrap_model(model)._get_name()

            if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                loss = self.label_smoother(outputs, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        # sync loss from all processes
        self.state.loss = self._nested_gather(loss).mean().item()

        return (loss, outputs) if return_outputs else loss

    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        """
        Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
        passed as an argument.

        With ``relora_scheduler`` enabled, a ``cosine_restarts`` schedule from
        ``training_utils`` is built and ``num_training_steps`` is rounded up to a
        multiple of ``use_relora_step``; otherwise the scheduler named by
        ``args.lr_scheduler_type`` is created.

        Args:
            num_training_steps (int): The number of training steps to do.
            optimizer: Optimizer to schedule; falls back to ``self.optimizer`` when ``None``.

        Returns:
            The scheduler stored on ``self.lr_scheduler``.
        """
        if self.lr_scheduler is None:
            # Both branches use the same fallback; the ReLoRA branch used to pass
            # a possibly-None `optimizer` straight through.
            target_optimizer = self.optimizer if optimizer is None else optimizer

            if self.relora_scheduler:
                from training_utils import get_scheculer as relora_get_scheduler

                remainder = num_training_steps % self.use_relora_step
                if remainder > 0:
                    # Round up so training ends on a restart-cycle boundary.
                    num_training_steps = ((num_training_steps // self.use_relora_step) + 1) * self.use_relora_step
                self.lr_scheduler = relora_get_scheduler(
                    scheduler_type='cosine_restarts',
                    optimizer=target_optimizer,
                    num_training_steps=num_training_steps,
                    warmup_steps=self.args.get_warmup_steps(num_training_steps),
                    min_lr_ratio=0.1,
                    cycle_length=self.use_relora_step,
                    restart_warmup_steps=self.relora_warmup_step,
                    adjust_step=0,
                )
            else:
                self.lr_scheduler = get_scheduler(
                    self.args.lr_scheduler_type,
                    optimizer=target_optimizer,
                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                    num_training_steps=num_training_steps,
                )
            self._created_lr_scheduler = True
        return self.lr_scheduler
def train(
    # model/data params
    base_model: str = "",  # the only required argument
    data_path: str = "",
    # training hyperparams
    batch_size: int = 128,
    micro_batch_size: int = 4,
    num_epochs: int = 1,
    learning_rate: float = 3e-4,
    lr_scheduler_type: str = 'linear',
    cutoff_len: int = 2048,
    val_set_size: int = 0,
    # lora hyperparams
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    lora_target_modules: List[str] = [
        "q_proj","k_proj","v_proj","o_proj","gate_proj","down_proj","up_proj"
    ],
    # llm hyperparams
    add_eos_token: bool = False,
    group_by_length: bool = False,  # faster, but produces an odd training loss curve
    # wandb params
    wandb_project: str = "",
    wandb_run_name: str = "",
    wandb_watch: str = "",  # options: false | gradients | all
    wandb_log_model: str = "",  # options: false | true
    resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
    seed: int = 42,
    use_4bit: bool = False,
    use_16bit: bool = False,
    debug: bool = False,
    full_ft: bool = False,
    deepspeed: str = None,
    warmup_steps: int = 100,
    logging_steps: int = 10,
    use_flash_atten: bool = False,
    not_shuffle_data: bool = False,
    max_steps: int = -1,
    use_gptq: bool = False,
    use_bf16: bool = False,
    train_embhead: bool = False,
    max_samples: int = -1,
    save_total_limit: int = 7,
    new_pad_token: bool = False,
    save_steps: int = 200,
    grad_checkpoint: bool = False,
    pretrain: str = None,
    # dora
    use_dora: bool = False,
    # lora+
    lora_plus_lambda: int = 1,
    # adalora
    use_adalora: bool = False,
    # asylora
    use_asymmetriclora: bool = False,
    # mora
    use_mora: bool = False,
    mora_type: int = 1,
    # reslora
    use_relora: bool = False,
    use_relora_step: int = 50,
    use_relora_not_reset_optimizer: bool = False ,
    relora_warmup_step: int = 50,
    relora_scheduler: bool = False,
    remora_types: int = 4,
):
    """Train a causal LM with ``OurTrainer``, optionally wrapped in a PEFT adapter.

    The function seeds the RNGs, sets up distributed training when
    ``WORLD_SIZE != 1``, builds the model and tokenizer, wraps the model with
    PEFT unless ``full_ft`` is set, loads the dataset, runs training and saves
    the model to the output directory on rank 0.

    Selected arguments:
        base_model: Model name or path for the config, weights and tokenizer.
        data_path: Dataset location. A path containing ``'meta-math'`` selects the
            MetaMathQA pipeline from ``training_utils``; anything else is read
            with ``datasets.load_from_disk``.
        batch_size, micro_batch_size: Gradient accumulation steps are
            ``batch_size // micro_batch_size`` (further divided by the world size
            under DDP), never less than 1.
        cutoff_len: When different from 2048, sequences loaded from disk are
            truncated to this length and samples with no supervised label dropped.
        wandb_run_name: Also determines the output directory, ``'save_' + wandb_run_name``.
        debug: Build a randomly initialised one-layer model and disable wandb.
        pretrain: ``'250m'`` or ``'1b'`` builds a randomly initialised model from
            ``./configs``; any non-``None`` value saves the initial model.
        max_samples: When positive, shuffle the loaded dataset with seed 42 and
            keep at most this many samples.
        use_relora_not_reset_optimizer: Inverted into the trainer's
            ``use_relora_reset_optimizer`` flag.

    ``val_set_size``, ``add_eos_token`` and ``use_gptq`` are accepted but not
    read by this function.
    """
    global SAVE_PATH
    set_seed(seed)

    # Clamp to 1 so a batch_size smaller than micro_batch_size cannot yield 0 steps.
    gradient_accumulation_steps = max(1, batch_size // micro_batch_size)

    output_dir = wandb_run_name
    # bug in transformers
    if output_dir == wandb_run_name:
        output_dir = 'save_' + output_dir
    SAVE_PATH = output_dir

    # ---- distributed setup -------------------------------------------------
    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        gradient_accumulation_steps = max(1, gradient_accumulation_steps // world_size)
        torch.distributed.init_process_group("nccl")
        rank, world_size = torch.distributed.get_rank(), torch.distributed.get_world_size()
        device_id = rank % torch.cuda.device_count()
        device = torch.device(device_id)
        torch.cuda.set_device(device)
    else:
        rank = 0

    # ---- wandb -------------------------------------------------------------
    # Check if parameter passed or if set within environ
    use_wandb = len(wandb_project) > 0 or (
        "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
    )
    # Only overwrite environ if wandb param passed
    if len(wandb_project) > 0:
        os.environ["WANDB_PROJECT"] = wandb_project
    if len(wandb_watch) > 0:
        os.environ["WANDB_WATCH"] = wandb_watch
    if len(wandb_log_model) > 0:
        os.environ["WANDB_LOG_MODEL"] = wandb_log_model

    # ---- model -------------------------------------------------------------
    MODEL_CLASS = AutoModelForCausalLM
    half_dtype = torch.bfloat16 if use_bf16 else torch.float16

    if debug:
        # random init
        config = AutoConfig.from_pretrained(base_model)
        config.num_hidden_layers = 1
        model = MODEL_CLASS(config)
        use_wandb = False
    elif pretrain == '250m':
        config = AutoConfig.from_pretrained('./configs/llama_250m.json')
        model = MODEL_CLASS(config)
    elif pretrain == '1b':
        config = AutoConfig.from_pretrained('./configs/llama_1b.json')
        model = MODEL_CLASS(config)
    elif use_4bit:
        from transformers import BitsAndBytesConfig
        model = MODEL_CLASS.from_pretrained(
            base_model,
            load_in_4bit=True,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
                bnb_4bit_compute_dtype=half_dtype,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type='nf4',
            ),
            torch_dtype=half_dtype,
            device_map=device_map,
        )
    else:
        # if use zero3 not quantize
        skip_8bit = full_ft or (deepspeed and 'ds3' in deepspeed) or use_16bit
        model = MODEL_CLASS.from_pretrained(
            base_model,
            load_in_8bit=not skip_8bit,
            torch_dtype=half_dtype,
            device_map=device_map,
            use_flash_attention_2=use_flash_atten,
        )

    if pretrain is not None:
        print('saving init model')
        if rank == 0:
            model.save_pretrained(os.path.join(SAVE_PATH, "init-model"))

    # ---- tokenizer and collator --------------------------------------------
    tokenizer = AutoTokenizer.from_pretrained(base_model)

    if new_pad_token:
        import deepspeed as dsp
        # NOTE(review): this tokenizer is replaced by the from_pretrained call
        # right after it, so the settings below have no lasting effect.
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            base_model,
            model_max_length=512,
            padding_side="right",
            use_fast=False,
        )
        tokenizer.pad_token_id = (
            # NOTE: set this to eos token, set to unk(0) while make output nan
            2  # unk. we want this to be different from the eos token
        )
        tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=False)
        num_added_tokens = tokenizer.add_special_tokens({
            "bos_token": "<s>",
            "eos_token": "</s>",
            "unk_token": "<unk>",
            "pad_token": "<pad>",
        })
        assert num_added_tokens in [0, 1], "LlamaTokenizer should only add one special token - the pad_token, or no tokens if pad token present."
        embeddings = model.get_input_embeddings()
        # The vocabulary size is read inside the gather context.
        with dsp.zero.GatheredParameters(embeddings.weight, modifier_rank=None):
            if len(tokenizer) > embeddings.weight.shape[0]:
                model.resize_token_embeddings(len(tokenizer))
        data_collator = transformers.DataCollatorForSeq2Seq(
            tokenizer=tokenizer, model=model, padding="longest",
        )
    else:
        tokenizer.pad_token_id = (
            0  # unk. we want this to be different from the eos token
        )
        tokenizer.padding_side = "left"  # Allow batched inference
        data_collator = transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True,
        )

    # ---- PEFT wrapping -----------------------------------------------------
    if full_ft:
        print('not use peft')
    else:
        if grad_checkpoint:
            model.enable_input_require_grads()

        if (deepspeed and 'ds3' not in deepspeed) and not use_16bit:
            model = prepare_model_for_kbit_training(model)

        # 'q_proj k_proj v_proj o_proj gate_proj down_proj up_proj'
        if isinstance(lora_target_modules, str):
            lora_target_modules = [lora_target_modules]

        CONFIGCLASS = LoraConfig
        if use_adalora:
            from peft import AdaLoraConfig
            CONFIGCLASS = AdaLoraConfig

        kwargs = {}
        if use_dora:
            kwargs['use_dora'] = True
        if use_mora:
            kwargs['use_mora'] = True
            kwargs['mora_type'] = mora_type
            print('mora type', mora_type)
        if train_embhead:
            kwargs['modules_to_save'] = ['embed_tokens', 'lm_head', 'norm', 'input_layernorm', 'post_attention_layernorm' ]

        config = CONFIGCLASS(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=lora_target_modules,
            lora_dropout=lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            **kwargs,
        )
        model = get_peft_model(model, config)

        if use_4bit:
            from peft.tuners.lora import LoraLayer
            # Cast adapter layers and embeddings/head to the half dtype, norms to fp32.
            for name, module in model.named_modules():
                if isinstance(module, LoraLayer):
                    module = module.to(half_dtype)
                if 'norm' in name:
                    module = module.to(torch.float32)
                if 'lm_head' in name or 'embed_tokens' in name:
                    if hasattr(module, 'weight'):
                        module = module.to(half_dtype)

    if not full_ft:
        model.print_trainable_parameters()  # Be more transparent about the % of trainable params.

    if use_asymmetriclora:
        from tqdm import tqdm
        bar = tqdm(total=len([n for n, p in model.named_parameters() if 'lora_A' in n]))
        # Re-initialise every lora_A from the SVD of a random matrix.
        for name, param in model.named_parameters():
            if 'lora_A' in name:
                shape = param.shape
                random_w = torch.rand(shape[1], max(shape[1], 4096)).cuda()
                # slow here
                _, _, V_rand = torch.linalg.svd(random_w)
                print(name, shape, V_rand.std().item(), V_rand.mean().item())
                param.data = V_rand[:, :shape[0]].T.contiguous()
                bar.update(1)
        bar.close()

    if not ddp and torch.cuda.device_count() > 1:
        # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
        model.is_parallelizable = True
        model.model_parallel = True

    # ---- data --------------------------------------------------------------
    warmup_ratio = 0
    if 'meta-math' in data_path:
        class A:
            pass
        data_args = A()
        data_args.data_path = 'meta-math/MetaMathQA'
        data_args.data_length = 1000000
        from training_utils import make_supervised_data_module
        # This pipeline overrides the scheduler type, save interval and warmup.
        lr_scheduler_type = 'cosine'
        save_steps = 1000
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            base_model,
            model_max_length=512,
            padding_side="right",
            use_fast=False,
        )
        tokenizer.pad_token_id = (
            # NOTE: set this to eos token, set to unk(0) while make output nan
            2  # unk. we want this to be different from the eos token
        )
        data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
        train_data = data_module['train_dataset']
        data_collator = data_module['data_collator']
        warmup_steps, warmup_ratio = 0, 0.03
    else:
        train_data = load_from_disk(data_path)
        if 'open-instruct-tokenized' in data_path:
            prev_len = len(train_data)
            # NOTE(review): remap is defined but never applied, so the count
            # printed below is always 0 -- confirm whether a map() call is missing.
            def remap(entry):
                entry['input_ids'] = [x if x < 32000 else 0 for x in entry['input_ids']]
                return entry
            # this sample contain <pad> which is add new token in prev
            print(f'filter out {prev_len - len(train_data)} samples')

        if cutoff_len != 2048:
            def cut_off(entry):
                entry['input_ids'] = entry['input_ids'][:cutoff_len]
                entry['attention_mask'] = entry['attention_mask'][:cutoff_len]
                entry['labels'] = entry['labels'][:cutoff_len]
                return entry
            train_data = train_data.map(cut_off, num_proc=48)
            # Drop samples whose labels are all -100 after truncation.
            train_data = train_data.filter(lambda example: (torch.LongTensor(example['labels']) != -100).any(), num_proc=48)

    # Subsampling has to happen after the dataset exists; it previously ran
    # before `train_data` was assigned and raised UnboundLocalError.
    # Assumes train_data offers datasets.Dataset-style shuffle/select -- TODO
    # confirm for the meta-math dataset.
    if max_samples > 0:
        print(f'use max samples {max_samples}')
        train_data = train_data.shuffle(seed=42)
        # Clamp so a request larger than the dataset does not index out of range.
        train_data = train_data.select(range(min(max_samples, len(train_data))))

    # ---- trainer -----------------------------------------------------------
    use_fsdp = full_ft and not deepspeed and pretrain is None
    TRAINER_CLS = OurTrainer
    trainer = TRAINER_CLS(
        model=model,
        train_dataset=train_data,
        eval_dataset=None,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=warmup_steps,
            warmup_ratio=warmup_ratio,
            num_train_epochs=num_epochs,
            max_steps=max_steps,
            learning_rate=learning_rate,
            lr_scheduler_type=lr_scheduler_type,
            fp16=not use_bf16,
            bf16=use_bf16,
            logging_steps=logging_steps,
            optim="adamw_torch",
            evaluation_strategy="no",
            save_strategy="steps",
            eval_steps=None,
            save_steps=save_steps,
            output_dir=output_dir,
            save_total_limit=save_total_limit,
            load_best_model_at_end=False,
            ddp_find_unused_parameters=False if ddp else None,
            group_by_length=group_by_length,
            report_to="wandb" if use_wandb else "none",
            run_name=wandb_run_name if use_wandb else None,
            deepspeed=deepspeed,
            seed=seed,
            gradient_checkpointing=grad_checkpoint,
            fsdp='full_shard auto_wrap' if use_fsdp else '',
            fsdp_transformer_layer_cls_to_wrap='LlamaDecoderLayer' if use_fsdp else None,
        ),
        data_collator=data_collator,
    )

    # OurTrainer reads these as attributes rather than constructor arguments.
    trainer.lora_plus_lambda = lora_plus_lambda
    trainer.use_relora = use_relora
    trainer.use_relora_step = use_relora_step
    trainer.use_relora_reset_optimizer = not use_relora_not_reset_optimizer
    trainer.relora_warmup_step = relora_warmup_step
    trainer.is_pretrain = pretrain is not None
    trainer.relora_scheduler = relora_scheduler
    trainer.remora_types = remora_types
    if not_shuffle_data:
        trainer.shuffle_data = False

    model.config.use_cache = False

    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    if rank == 0:
        model.save_pretrained(output_dir)
if __name__ == "__main__":
    # Script entry point: hand train() to python-fire.
    from fire import Fire

    Fire(train)