# run.py
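#
# Fine-tunes a LLaMA-based embedding model (LlamaModelEmbedding) for retrieval
# with the Hugging Face Trainer, using query/passage training data supplied by
# TrainDatasetForEmbedding and batched by EmbedCollator.
#
# Example invocation (a sketch only: flag names mirror the dataclass fields in
# arguments.py and transformers.TrainingArguments, and the paths and values
# below are placeholders):
#
#   torchrun --nproc_per_node 8 run.py \
#       --model_name_or_path <path-to-llama-checkpoint> \
#       --output_dir ./output \
#       --do_train \
#       --fp16 \
#       --query_max_len 64 \
#       --passage_max_len 256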
import logging
import os

from transformers import (
    AutoTokenizer,
    HfArgumentParser,
    LlamaConfig,
    set_seed,
)
from transformers.trainer import Trainer
from transformers.trainer_utils import get_last_checkpoint

from arguments import ModelArguments, DataArguments, \
    RetrieverTrainingArguments as TrainingArguments
from data import TrainDatasetForEmbedding, EmbedCollator
from modeling import LlamaModelEmbedding

logger = logging.getLogger(__name__)


def main():
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model_args: ModelArguments
    data_args: DataArguments
    training_args: TrainingArguments

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is not None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)
    logger.info("Model parameters %s", model_args)
    logger.info("Data parameters %s", data_args)

    # Set seed
    set_seed(training_args.seed)

    num_labels = 1
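    # Load the tokenizer; use_fast=False selects the slow SentencePiece implementation.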
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=False,
    )
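    # Carry the embedding-specific settings on the config so that
    # LlamaModelEmbedding can pick them up when the model is instantiated.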
    config = LlamaConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        cache_dir=model_args.cache_dir,
    )
    config.temperature = training_args.temperature
    config.normalized = model_args.normalized
    config.sentence_pooling_method = model_args.sentence_pooling_method
    logger.info('Config: %s', config)
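    # LLaMA tokenizers ship without a pad token; reuse the unk token and pad on
    # the right so batches can be padded to a fixed length.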
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.padding_side = "right"
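    # Load the embedding model with FlashAttention-2 enabled; this requires the
    # flash-attn package and a supported GPU.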
    model = LlamaModelEmbedding.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        use_flash_attention_2=True,
    )
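    # Optionally freeze every parameter whose name contains "position_embeddings".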
    if training_args.fix_position_embedding:
        for k, v in model.named_parameters():
            if "position_embeddings" in k:
                logger.info(f"Freezing parameters for {k}")
                v.requires_grad = False
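    # Build the training dataset and hand everything to the Hugging Face Trainer;
    # the collator receives the per-field length limits (query_max_len / passage_max_len).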
    train_dataset = TrainDatasetForEmbedding(args=data_args, tokenizer=tokenizer)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=EmbedCollator(
            tokenizer,
            query_max_len=data_args.query_max_len,
            passage_max_len=data_args.passage_max_len,
        ),
        tokenizer=tokenizer,
    )
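    # An explicit --resume_from_checkpoint takes precedence over a checkpoint
    # auto-detected in --output_dir.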
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        try:
            train_result = trainer.train(resume_from_checkpoint=checkpoint)
        except Exception as e:
            error_msg = str(e)
            logger.error(error_msg)
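            # If the failure was caused by the checkpoint itself, move it aside so
            # that the next launch does not try to resume from it again.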
            if checkpoint and 'checkpoint' in error_msg:
                os.system(f'mv {checkpoint} {checkpoint}-temp')


if __name__ == "__main__":
    main()