diff --git a/LM_Cocktail/LM_Cocktail/cocktail.py b/LM_Cocktail/LM_Cocktail/cocktail.py
index 3a3ed00e..8da72438 100644
--- a/LM_Cocktail/LM_Cocktail/cocktail.py
+++ b/LM_Cocktail/LM_Cocktail/cocktail.py
@@ -81,7 +81,7 @@ def mix_models_with_data(model_names_or_paths: List[str],
     mix model based on given a few examples
     Args:
         model_names_or_paths (List[str]): a list of names or paths to models
-        model_type (str): type of model to mix, should be in ["decoder", "encoder"]
+        model_type (str): type of model to mix, should be in ["decoder", "encoder", "encoder-decoder", "reranker"]
         example_data (List[Any]): a list of examples
         temperature (float, optional): temperature can impact the distribution of weights . Defaults to 3.0.
         batch_size (int, optional): batch size to compute loss. Defaults to 2.
@@ -93,7 +93,7 @@ def mix_models_with_data(model_names_or_paths: List[str],
         new model
     """

-    assert model_type in ['decoder', 'encoder', 'encoder-decoder']
+    assert model_type in ['decoder', 'encoder', 'encoder-decoder', 'reranker']

     model = load_model(model_names_or_paths[0], model_type=model_type)
     tokenizer = AutoTokenizer.from_pretrained(model_names_or_paths[0], trust_remote_code=True)
diff --git a/LM_Cocktail/LM_Cocktail/utils.py b/LM_Cocktail/LM_Cocktail/utils.py
index 4091a461..c6b74b0a 100644
--- a/LM_Cocktail/LM_Cocktail/utils.py
+++ b/LM_Cocktail/LM_Cocktail/utils.py
@@ -6,9 +6,10 @@ import random
 import numpy as np
 from tqdm import tqdm

-from typing import List, Dict, Any
+from typing import List, Dict, Any, Tuple

-from transformers import AutoModelForCausalLM, AutoModel, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM, is_torch_npu_available
+from transformers import AutoModelForCausalLM, AutoModel, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM, is_torch_npu_available, BatchEncoding
+from transformers.modeling_outputs import SequenceClassifierOutput


 def load_llm(model_name:str, trust_remote_code:bool):
@@ -142,6 +143,9 @@ def compute_weights(base_model, tokenizer, param_list: List[Dict], model_type: s
     elif model_type == 'encoder-decoder':
         input_data = preprocess_data_for_seq2seq(example_data=example_data, tokenizer=tokenizer, device=device, batch_size=batch_size, max_input_length=max_input_length)
         loss_func = seq2seq_loss
+    elif model_type == "reranker":
+        input_data = preprocess_data_for_reranker(example_data=example_data, tokenizer=tokenizer, device=device, batch_size=batch_size, max_input_length=max_input_length, neg_number=neg_number)
+        loss_func = reranker_loss

     example_loss = []
     with torch.no_grad():
@@ -235,6 +239,44 @@ def generate_embeddings(model, inputs):
     loss = float(loss / len(input_data))
     return float(loss)

+def preprocess_data_for_reranker(example_data, tokenizer, device, batch_size:int=64, max_input_length:int=512, neg_number:int=7):
+    input_data = []
+    pending_encoding = []
+    for e in tqdm(example_data, desc="Tokenizing"):
+        for pos in e['pos']:
+            pending_encoding.append((e['query'], pos))
+
+        if len(e['neg']) < neg_number:
+            num = math.ceil(neg_number / len(e['neg']))
+            negs = random.sample(e['neg'] * num, neg_number)
+        else:
+            negs = random.sample(e['neg'], neg_number)
+
+        for neg in negs:
+            pending_encoding.append((e['query'], neg))
+
+        input_data.append(tokenizer(pending_encoding, padding=True, truncation=True, max_length=max_input_length, return_tensors='pt').to(device))
+        pending_encoding.clear()
+    if len(pending_encoding) > 0:
+        input_data.append(tokenizer(pending_encoding, padding=True, truncation=True, max_length=max_input_length, return_tensors='pt').to(device))
+        pending_encoding.clear()
+    return input_data
+
+def reranker_loss(base_model: AutoModelForSequenceClassification, input_data: List[BatchEncoding]):
+    with torch.no_grad():
+        loss = 0
+        p_bar = tqdm(enumerate(input_data), total=len(input_data))
+        for idx, batch_encoding in p_bar:
+            ranker_out: SequenceClassifierOutput = base_model(**batch_encoding, return_dict=True)
+            logits = ranker_out.logits.view(-1)
+            cross_entropy = torch.nn.CrossEntropyLoss(reduction='mean')
+            target_label = torch.eye(1, logits.shape[0], dtype=torch.float, device=logits.device).view(-1)
+            batch_loss = cross_entropy(logits, target_label)
+            loss += batch_loss.cpu()
+            p_bar.set_description(f"Calculating loss {loss/(idx+1)}")
+        loss = float(loss / len(input_data))
+    return float(loss)
+
 def preprocess_data_for_llm(example_data, tokenizer, device, batch_size:int=2, max_input_length:int=2048):
     batch_input_ids = []
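
Note on the target construction in `reranker_loss`, since it is easy to misread: each `BatchEncoding` holds one example's pairs with the positives first (the order `preprocess_data_for_reranker` builds them in), and `torch.eye(1, n).view(-1)` produces the one-hot vector `[1, 0, ..., 0]`, so the soft-label cross entropy rewards ranking the first pair above all others. If an example carries several positives, only the one at index 0 is credited as the target. A standalone sketch of the computation with made-up scores (soft-label targets require PyTorch >= 1.10):

```python
import torch

# Scores for one group: 1 positive pair followed by 2 negative pairs
# (illustrative numbers, not real model output).
logits = torch.tensor([2.0, 0.5, -1.0])

# torch.eye(1, n) is a 1 x n matrix whose single 1 sits in column 0, so the
# flattened target marks the first pair as the one that should rank highest.
target = torch.eye(1, logits.shape[0]).view(-1)  # tensor([1., 0., 0.])

# Soft-label cross entropy; equivalent here to -log_softmax(logits)[0].
loss = torch.nn.CrossEntropyLoss(reduction='mean')(logits, target)
```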
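For completeness, a minimal usage sketch of the new path. The checkpoint paths and example data are placeholders, and it assumes the models being mixed are rerankers with identical architectures, as any weight-space mix requires:

```python
from LM_Cocktail import mix_models_with_data

# Placeholder paths: two fine-tuned reranker checkpoints sharing one architecture.
model_paths = ["path/to/reranker_ckpt_A", "path/to/reranker_ckpt_B"]

# Schema consumed by preprocess_data_for_reranker: a query, a list of
# positive passages, and a list of negative passages per example.
example_data = [
    {
        "query": "what is a cross-encoder reranker",
        "pos": ["A cross-encoder scores a query-passage pair jointly..."],
        "neg": ["Bananas are rich in potassium.", "Paris is the capital of France."],
    },
]

model = mix_models_with_data(
    model_names_or_paths=model_paths,
    model_type="reranker",  # option enabled by this change
    example_data=example_data,
)
```

`neg_number` defaults to 7 in `preprocess_data_for_reranker`; examples with fewer negatives are handled by the `e['neg'] * num` branch, which resamples them with repetition.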