Implement checkpoint-resume logic and automatic selection of an available port #728

Open · wants to merge 34 commits into master
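The hunks below cover the checkpoint-resume half of the title; the automatic port selection is not visible in the loaded portion of the diff. For orientation only, a common way to pick a free port for a distributed-training launcher is to bind a socket to port 0 and read back the port the OS assigned. The sketch below illustrates that general technique; the function name and the MASTER_PORT usage are assumptions, not code taken from this PR.

import socket

def find_free_port() -> int:
    """Ask the OS for an unused TCP port by binding to port 0."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))                 # port 0: the kernel picks any free port
        return s.getsockname()[1]       # the port actually assigned

# Hypothetical usage before launching torch.distributed training:
# os.environ["MASTER_PORT"] = str(find_free_port())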
3 changes: 3 additions & 0 deletions .gitignore
@@ -136,3 +136,6 @@ pic2.py
 
 # Pyre type checker
 .pyre/
+wandb/
+*.txt
+result/
2 changes: 1 addition & 1 deletion FlagEmbedding/baai_general_embedding/finetune/data.py
@@ -8,7 +8,7 @@
 from torch.utils.data import Dataset
 from transformers import DataCollatorWithPadding, PreTrainedTokenizer
 
-from .arguments import DataArguments
+from arguments import DataArguments
 
 
 class TrainDatasetForEmbedding(Dataset):
2 changes: 2 additions & 0 deletions FlagEmbedding/baai_general_embedding/finetune/hn_mine.py
@@ -59,6 +59,7 @@ def find_knn_neg(model, input_file, candidate_pool, output_file, sample_range, n
     corpus = []
     queries = []
     train_data = []
+    # input_file is JSONL: every line is a {query, pos, neg} triple; all pos and neg passages go into corpus, and each query goes into queries
     for line in open(input_file):
         line = json.loads(line.strip())
         train_data.append(line)
@@ -67,6 +68,7 @@ def find_knn_neg(model, input_file, candidate_pool, output_file, sample_range, n
         corpus.extend(line['neg'])
         queries.append(line['query'])
 
+    # either the candidate pool or the corpus built above is used, not both
    if candidate_pool is not None:
         if not isinstance(candidate_pool, list):
             candidate_pool = get_corpus(candidate_pool)
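For context on the comment added above, each line of input_file is a standalone JSON object with query/pos/neg fields. Only the field names come from the code; the concrete strings below are made-up illustrations.

import json

# One illustrative input_file record: pos and neg passages feed `corpus`,
# the query feeds `queries`, exactly as in the loop above.
record = {
    "query": "what is dense retrieval",
    "pos": ["Dense retrieval encodes queries and passages into vectors ..."],
    "neg": ["A recipe for sourdough bread ..."],
}
print(json.dumps(record, ensure_ascii=False))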
10 changes: 6 additions & 4 deletions FlagEmbedding/baai_general_embedding/finetune/modeling.py
@@ -60,19 +60,21 @@ def __init__(self,
     def gradient_checkpointing_enable(self, **kwargs):
         self.model.gradient_checkpointing_enable(**kwargs)
 
-    def sentence_embedding(self, hidden_state, mask):
+    def sentence_embedding(self, output, mask):
         if self.sentence_pooling_method == 'mean':
-            s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
+            s = torch.sum(output.last_hidden_state * mask.unsqueeze(-1).float(), dim=1)
             d = mask.sum(axis=1, keepdim=True).float()
             return s / d
         elif self.sentence_pooling_method == 'cls':
-            return hidden_state[:, 0]
+            return output.last_hidden_state[:, 0]
+        elif self.sentence_pooling_method == 'cls_after_pooler':
+            return output.pooler_output
 
     def encode(self, features):
         if features is None:
             return None
         psg_out = self.model(**features, return_dict=True)
-        p_reps = self.sentence_embedding(psg_out.last_hidden_state, features['attention_mask'])
+        p_reps = self.sentence_embedding(psg_out, features['attention_mask'])
         if self.normlized:
             p_reps = torch.nn.functional.normalize(p_reps, dim=-1)
         return p_reps.contiguous()
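The three pooling branches above can be exercised outside the trainer. A minimal sketch, assuming a BERT-style checkpoint whose output exposes both last_hidden_state and pooler_output; the model name and input text are placeholders.

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")   # placeholder checkpoint
model = AutoModel.from_pretrained("bert-base-uncased")

features = tokenizer(["a sample sentence"], return_tensors="pt")
output = model(**features, return_dict=True)
mask = features["attention_mask"]

# 'mean': average the token embeddings, ignoring padding positions
summed = torch.sum(output.last_hidden_state * mask.unsqueeze(-1).float(), dim=1)
mean_rep = summed / mask.sum(dim=1, keepdim=True).float()

# 'cls': first token of the last hidden state
cls_rep = output.last_hidden_state[:, 0]

# 'cls_after_pooler': the model's pooler output (tanh-projected [CLS] in BERT-style models)
pooler_rep = output.pooler_output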
67 changes: 49 additions & 18 deletions FlagEmbedding/baai_general_embedding/finetune/run.py
@@ -1,38 +1,55 @@
 import logging
 import os
 from pathlib import Path
+os.environ["WANDB_DISABLED"]="true"
 
 from transformers import AutoConfig, AutoTokenizer
 from transformers import (
     HfArgumentParser,
     set_seed,
 )
 
-from .arguments import ModelArguments, DataArguments, \
+from arguments import ModelArguments, DataArguments, \
     RetrieverTrainingArguments as TrainingArguments
-from .data import TrainDatasetForEmbedding, EmbedCollator
-from .modeling import BiEncoderModel
-from .trainer import BiTrainer
+from data import TrainDatasetForEmbedding, EmbedCollator
+from modeling import BiEncoderModel
+from trainer import BiTrainer
+import sys
+import transformers
+sys.path.append("/opt/tiger/FlagEmbedding")
+from utils import get_complete_last_checkpoint
 
 logger = logging.getLogger(__name__)
 
 
+def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
+    """Collects the state dict and dump to disk."""
+    state_dict = trainer.model.state_dict()
+    if trainer.args.should_save:
+        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
+        del state_dict
+        trainer._save(output_dir, state_dict=cpu_state_dict)
+
 def main():
     parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
     model_args, data_args, training_args = parser.parse_args_into_dataclasses()
     model_args: ModelArguments
     data_args: DataArguments
     training_args: TrainingArguments
-
-    if (
-            os.path.exists(training_args.output_dir)
-            and os.listdir(training_args.output_dir)
-            and training_args.do_train
-            and not training_args.overwrite_output_dir
-    ):
-        raise ValueError(
-            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
-        )
+    # check and load checkpoint
+    last_checkpoint = None
+    if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
+        last_checkpoint = get_complete_last_checkpoint(training_args.output_dir)
+        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+            logger.info(
+                f"Output directory ({training_args.output_dir}) already exists and is empty."
+                "Train from scratch"
+            )
+        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
+            logger.info(
+                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
+                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
+            )
 
     # Setup logging
     logging.basicConfig(
@@ -75,6 +92,14 @@ def main():
         temperature=training_args.temperature,
         use_inbatch_neg=training_args.use_inbatch_neg,
     )
+
+    checkpoint = None
+    if training_args.resume_from_checkpoint is not None:
+        logger.info(f"train start from {training_args.resume_from_checkpoint}")
+        checkpoint = training_args.resume_from_checkpoint
+    elif last_checkpoint is not None:
+        logger.info(f"train start from {last_checkpoint}")
+        checkpoint = last_checkpoint
 
     if training_args.fix_position_embedding:
         for k, v in model.named_parameters():
@@ -99,12 +124,18 @@ def main():
     Path(training_args.output_dir).mkdir(parents=True, exist_ok=True)
 
     # Training
-    trainer.train()
-    trainer.save_model()
+    trainer.train(resume_from_checkpoint=checkpoint)
+    trainer.save_state()
+    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
+    # try:
+    #     trainer.train(resume_from_checkpoint=checkpoint)
+    # except:
+    #     trainer.train()
+    # trainer.save_model()
     # For convenience, we also re-save the tokenizer to the same directory,
     # so that you can share your model easily on huggingface.co/models =)
-    if trainer.is_world_process_zero():
-        tokenizer.save_pretrained(training_args.output_dir)
+    # if trainer.is_world_process_zero():
+    #     tokenizer.save_pretrained(training_args.output_dir)
 
 
 if __name__ == "__main__":
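`get_complete_last_checkpoint` is imported from a local `utils` module that is not included in the hunks shown above, so its exact behavior is unknown. A plausible sketch, under the assumption that "complete" means the checkpoint folder finished writing its trainer_state.json:

import os
import re

def get_complete_last_checkpoint(output_dir: str):
    """Sketch: return the newest checkpoint-<step> dir that looks fully saved, else None."""
    pattern = re.compile(r"^checkpoint-(\d+)$")
    complete = []
    for name in os.listdir(output_dir):
        match = pattern.match(name)
        path = os.path.join(output_dir, name)
        # Treat the presence of trainer_state.json as the completeness marker (assumption).
        if match and os.path.isdir(path) and os.path.isfile(os.path.join(path, "trainer_state.json")):
            complete.append((int(match.group(1)), path))
    return max(complete)[1] if complete else None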
1 change: 1 addition & 0 deletions FlagEmbedding/reranker/arguments.py
@@ -21,6 +21,7 @@ class ModelArguments:
     cache_dir: Optional[str] = field(
         default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
     )
+    model_type: str = field(default="CrossEncoder")
 
 
 @dataclass
25 changes: 24 additions & 1 deletion FlagEmbedding/reranker/data.py
@@ -10,7 +10,7 @@
 from transformers import DataCollatorWithPadding
 from transformers import PreTrainedTokenizer, BatchEncoding
 
-from .arguments import DataArguments
+from arguments import DataArguments
 
 
 class TrainDatasetForCE(Dataset):
@@ -62,6 +62,29 @@ def __getitem__(self, item) -> List[BatchEncoding]:
 
         return batch_data
 
+class TrainDatasetForCL(TrainDatasetForCE):
+    def create_one_example(self, input):
+        item = self.tokenizer(
+            input,
+            truncation=True,
+            max_length=self.args.max_len,
+            padding=False,
+        )
+        return item
+
+    def __getitem__(self, item) -> List[BatchEncoding]:
+        query = self.dataset[item]['query']
+        pos = random.choice(self.dataset[item]['pos'])
+        if len(self.dataset[item]['neg']) < self.args.train_group_size - 1:
+            num = math.ceil((self.args.train_group_size - 1) / len(self.dataset[item]['neg']))
+            negs = random.sample(self.dataset[item]['neg'] * num, self.args.train_group_size - 1)
+        else:
+            negs = random.sample(self.dataset[item]['neg'], self.args.train_group_size - 1)
+        batch_data = []
+        batch_data.append(self.create_one_example(query))
+        batch_data.append(self.create_one_example(pos))
+        for neg in negs: batch_data.append(self.create_one_example(neg))
+        return batch_data
 
 
 @dataclass
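One detail worth spelling out in TrainDatasetForCL.__getitem__: when a query has fewer than train_group_size - 1 negatives, the list is tiled before sampling. With train_group_size = 8 and only 3 negatives, num = ceil(7 / 3) = 3, the pool is tiled to 9 entries, and 7 are drawn without replacement from that tiled pool, so individual negatives can repeat within a group. A standalone check of that arithmetic with illustrative values:

import math
import random

train_group_size = 8
neg_pool = ["n1", "n2", "n3"]            # only 3 negatives for this query

if len(neg_pool) < train_group_size - 1:
    num = math.ceil((train_group_size - 1) / len(neg_pool))   # ceil(7 / 3) = 3
    negs = random.sample(neg_pool * num, train_group_size - 1)
else:
    negs = random.sample(neg_pool, train_group_size - 1)

print(len(negs), sorted(set(negs)))      # 7 samples drawn from {'n1', 'n2', 'n3'}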
125 changes: 125 additions & 0 deletions FlagEmbedding/reranker/embedding_proj_run.py
@@ -0,0 +1,125 @@
import logging
import os
from pathlib import Path

from transformers import AutoConfig, AutoTokenizer, TrainingArguments
from transformers import (
    HfArgumentParser,
    set_seed,
)
from arguments import ModelArguments, DataArguments
from data import TrainDatasetForCE, GroupCollator
from modeling import CLProjEncoder
from trainer import CETrainer

logger = logging.getLogger(__name__)
from pprint import pprint as pp
import sys
sys.path.append("/opt/tiger/FlagEmbedding")
from FlagEmbedding.reranker.data import TrainDatasetForCL
from utils import get_complete_last_checkpoint
import transformers
import os
os.environ["WANDB_DISABLED"]="true"

def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dump to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)

def main():
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model_args: ModelArguments
    data_args: DataArguments
    training_args: TrainingArguments

    # for args in (model_args, data_args, training_args): pp(args)

    # check and load checkpoint
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
        last_checkpoint = get_complete_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            logger.info(
                f"Output directory ({training_args.output_dir}) already exists and is empty."
                "Train from scratch"
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)
    logger.info("Model parameters %s", model_args)
    logger.info("Data parameters %s", data_args)

    set_seed(training_args.seed)

    num_labels = 1

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=False,
    )

    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        cache_dir=model_args.cache_dir,
        trust_remote_code=True
    )
    _model_class = CLProjEncoder

    model = _model_class.from_pretrained(
        model_args, data_args, training_args,
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        trust_remote_code=True
    )

    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        logger.info(f"train start from {training_args.resume_from_checkpoint}")
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        logger.info(f"train start from {last_checkpoint}")
        checkpoint = last_checkpoint

    train_dataset = TrainDatasetForCL(data_args, tokenizer=tokenizer)
    _trainer_class = CETrainer

    trainer = _trainer_class(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=GroupCollator(tokenizer),  # the group of encodings is still flattened into one batch here
        tokenizer=tokenizer
    )
    trainer.train(resume_from_checkpoint=checkpoint)
    trainer.save_state()
    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)

if __name__ == "__main__":
    main()
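Batch-shape bookkeeping for this script, following the dataset and collator above: each TrainDatasetForCL item yields one query, one positive, and train_group_size - 1 negatives, and GroupCollator flattens the group into a single padded batch, per the inline comment. The numbers below are placeholders, not values from the PR.

# Rough size bookkeeping for one training step (illustrative values only):
train_group_size = 8           # DataArguments.train_group_size
per_device_batch_size = 4      # TrainingArguments.per_device_train_batch_size

encodings_per_item = 1 + 1 + (train_group_size - 1)   # query + positive + negatives = 9
sequences_per_step = per_device_batch_size * encodings_per_item
print(sequences_per_step)      # 36 padded sequences after flattening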