diff --git a/models/llava/7b_config.json b/models/llava/7b_config.json new file mode 100755 index 00000000..12d01a02 --- /dev/null +++ b/models/llava/7b_config.json @@ -0,0 +1,54 @@ +{ + "embedding": ["vision_language"], + "vision_language_emb":{ + "vision_encoder":{ + "image_height": 336, + "image_width": 336, + "patch_size": 14, + "emb_size": 1024, + "feedforward_size": 4096, + "hidden_size": 1024, + "hidden_act": "gelu_quick", + "heads_num": 16, + "layers_num": 24, + "dropout": 0.0, + "max_seq_length": 577, + "embedding": ["patch", "pos"], + "patch_proj_bias": false, + "remove_embedding_layernorm": false, + "remove_transformer_bias": false, + "rotary_position_embedding": false, + "encoder": "transformer", + "feed_forward": "dense", + "mask": "fully_visible", + "layernorm_positioning": "pre", + "layernorm": "normal", + "has_cls": false + }, + "projection":{ + "mlp_hidden_size": 4096, + "num_mlp_layer": 2 + }, + "text":{ + "embedding": ["word"] + } + }, + "emb_size": 4096, + "feedforward_size": 11008, + "hidden_size": 4096, + "hidden_act": "silu", + "heads_num": 32, + "layers_num": 32, + "dropout": 0.0, + "data_processor": "llava", + "max_seq_length": 2048, + "remove_transformer_bias": true, + "remove_embedding_layernorm": true, + "rotary_position_embedding": true, + "encoder": "transformer", + "feed_forward": "gated", + "mask": "causal", + "layernorm_positioning": "pre", + "layernorm": "rms", + "target": ["lm"] + } diff --git a/preprocess.py b/preprocess.py index 7f61c54c..8c0f1a00 100644 --- a/preprocess.py +++ b/preprocess.py @@ -28,12 +28,14 @@ def main(): parser.add_argument("--data_processor", choices=["bert", "lm", "mlm", "bilm", "albert", "mt", "t5", "cls", "prefixlm", "gsg", "bart", "cls_mlm", "vit", "vilt", "clip", "s2t", "beit", "dalle", - "llm_pretrain", "llm_sft"], default="bert", + "llm_pretrain", "llm_sft", "llava"], default="bert", help="The data processor of the pretraining model.") parser.add_argument("--docs_buffer_size", type=int, default=100000, help="The buffer size of documents in memory, specific to targets that require negative sampling.") parser.add_argument("--seq_length", type=int, default=128, help="Sequence length of instances.") parser.add_argument("--tgt_seq_length", type=int, default=128, help="Target sequence length of instances.") + parser.add_argument("--vision_seq_length_in_VL", type=int, default=576, + help="Number of image patches in the vision-language model (LLaVA).") parser.add_argument("--dup_factor", type=int, default=5, help="Duplicate instances multiple times.") parser.add_argument("--short_seq_prob", type=float, default=0.1, diff --git a/pretrain.py b/pretrain.py index 14c744c3..8a667160 100644 --- a/pretrain.py +++ b/pretrain.py @@ -13,6 +13,8 @@ def main(): help="Path of the preprocessed dataset.") parser.add_argument("--pretrained_model_path", type=str, default=None, help="Path of the pretrained model.") + parser.add_argument("--vision_model_in_VL_emb_path", type=str, default=None, + help="Path of the vision pretrained model in the vision language embedding.") parser.add_argument("--output_model_path", type=str, required=True, help="Path of the output model.") parser.add_argument("--config_path", type=str, default="models/bert/base_config.json", diff --git a/scripts/convert_llava_from_huggingface_to_tencentpretrain.py b/scripts/convert_llava_from_huggingface_to_tencentpretrain.py new file mode 100755 index 00000000..606a0f9c --- /dev/null +++ b/scripts/convert_llava_from_huggingface_to_tencentpretrain.py @@ -0,0 +1,77 @@ +import argparse
+import collections +import torch +import os +import json + + +parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument("--input_model_path", type=str, default="models/llava-v1.5-7b/", + help="Path of the input LLaVA model directory in Huggingface format.") +parser.add_argument("--output_model_path", type=str, default="models/llava-v1.5-7b.bin", + help="Path of the output model in TencentPretrain format.") +parser.add_argument("--type", choices=["7B", "13B", "33B", "65B"], default="7B") + +args = parser.parse_args() + +model_config = {"7B" : [32, 4096, 32], + "13B": [40, 5120, 40], + "33B": [60, 6656, 52], + "65B": [80, 8192, 64] + } + +layers_num, dim, n_heads = model_config[args.type] + +files = os.listdir(args.input_model_path) +model_files = [f for f in files if f[-4:] == ".bin"] +input_models = {f: torch.load(os.path.join(args.input_model_path, f), map_location="cpu") for f in model_files} + +with open(os.path.join(args.input_model_path, "pytorch_model.bin.index.json")) as f: + model_index = json.load(f) + weight_map = model_index["weight_map"] + + +output_model = collections.OrderedDict() + +def get_weight_from_name(layer_name): + return input_models[weight_map[layer_name]][layer_name] + +def unpermute(w): + return w.reshape(n_heads, 2, dim // n_heads // 2, dim).transpose(2, 1).reshape(dim, dim) + +output_model["embedding.vision_language.text_embedding.word.embedding.weight"] = get_weight_from_name("model.embed_tokens.weight") + +for i in range(layers_num): + + output_model["encoder.transformer." + str(i) + ".self_attn.linear_layers.0.weight"] = \ + unpermute(get_weight_from_name("model.layers." + str(i) + ".self_attn.q_proj.weight")) + output_model["encoder.transformer." + str(i) + ".self_attn.linear_layers.1.weight"] = \ + unpermute(get_weight_from_name("model.layers." + str(i) + ".self_attn.k_proj.weight")) + + output_model["encoder.transformer." + str(i) + ".self_attn.linear_layers.2.weight"] = \ + get_weight_from_name("model.layers." + str(i) + ".self_attn.v_proj.weight") + output_model["encoder.transformer." + str(i) + ".self_attn.final_linear.weight"] = \ + get_weight_from_name("model.layers." + str(i) + ".self_attn.o_proj.weight") + + output_model["encoder.transformer." + str(i) + ".layer_norm_1.weight"] = \ + get_weight_from_name("model.layers." + str(i) + ".input_layernorm.weight") + + output_model["encoder.transformer." + str(i) + ".feed_forward.linear_gate.weight"] = \ + get_weight_from_name("model.layers." + str(i) + ".mlp.gate_proj.weight") + output_model["encoder.transformer." + str(i) + ".feed_forward.linear_1.weight"] = \ + get_weight_from_name("model.layers." + str(i) + ".mlp.up_proj.weight") + output_model["encoder.transformer." + str(i) + ".feed_forward.linear_2.weight"] = \ + get_weight_from_name("model.layers." + str(i) + ".mlp.down_proj.weight") + + output_model["encoder.transformer." + str(i) + ".layer_norm_2.weight"] = \ + get_weight_from_name("model.layers."
+ str(i) + ".post_attention_layernorm.weight") + +output_model["encoder.layer_norm.weight"] = get_weight_from_name("model.norm.weight") +output_model["target.lm.output_layer.weight"] = get_weight_from_name("lm_head.weight") + +output_model["embedding.vision_language.projection.0.weight"] = get_weight_from_name("model.mm_projector.0.weight") +output_model["embedding.vision_language.projection.0.bias"] = get_weight_from_name("model.mm_projector.0.bias") +output_model["embedding.vision_language.projection.2.weight"] = get_weight_from_name("model.mm_projector.2.weight") +output_model["embedding.vision_language.projection.2.bias"] = get_weight_from_name("model.mm_projector.2.bias") + +torch.save(output_model, args.output_model_path) diff --git a/scripts/convert_llm_in_llava.py b/scripts/convert_llm_in_llava.py new file mode 100755 index 00000000..652bcb26 --- /dev/null +++ b/scripts/convert_llm_in_llava.py @@ -0,0 +1,30 @@ +import argparse +import collections +import torch + + +def main(): + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--input_model_path", type=str, default="models/input_model.bin", + help="Path of the input LLM model in TencentPretrain format.") + parser.add_argument("--output_model_path", type=str, default="models/output_model.bin", + help="Path of the output model whose word embedding key is renamed for the vision-language embedding.") + + + args = parser.parse_args() + + input_model = torch.load(args.input_model_path, map_location="cpu") + + output_model = collections.OrderedDict() + + for k in input_model.keys(): + if k == "embedding.word.embedding.weight": + output_model["embedding.vision_language.text_embedding.word.embedding.weight"] = input_model[k] + else: + output_model[k] = input_model[k] + + torch.save(output_model, args.output_model_path) + + +if __name__ == "__main__": + main() diff --git a/scripts/convert_model_add_prefix.py b/scripts/convert_model_add_prefix.py new file mode 100755 index 00000000..f75ca66e --- /dev/null +++ b/scripts/convert_model_add_prefix.py @@ -0,0 +1,28 @@ +import argparse +import collections +import torch + + +def main(): + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--input_model_path", type=str, default="models/input_model.bin", + help="Path of the input model.") + parser.add_argument("--output_model_path", type=str, default="models/output_model.bin", + help="Path of the output model.") + parser.add_argument("--prefix", type=str, default="", help="Prefix added to each parameter name.") + + + args = parser.parse_args() + + input_model = torch.load(args.input_model_path, map_location="cpu") + + output_model = collections.OrderedDict() + + for k in input_model.keys(): + output_model[args.prefix + k] = input_model[k] + + torch.save(output_model, args.output_model_path) + + +if __name__ == "__main__": + main() diff --git a/scripts/generate_lm_llava_deepspeed.py b/scripts/generate_lm_llava_deepspeed.py new file mode 100755 index 00000000..f4d960ed --- /dev/null +++ b/scripts/generate_lm_llava_deepspeed.py @@ -0,0 +1,334 @@ +""" + This script provides an example to wrap TencentPretrain for generation. + Given the beginning of a text, the language model generates the rest. +""" +import sys +import os +import json +import argparse +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +from torchvision import transforms +from torchvision.io import read_image +from torchvision.io.image import ImageReadMode +import imghdr +import deepspeed +import numpy as np +from PIL import Image +import torchvision.transforms.functional as transform + +tencentpretrain_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.append(tencentpretrain_dir) + +from tencentpretrain.embeddings import * +from tencentpretrain.encoders import * +from tencentpretrain.targets import * +from tencentpretrain.utils.constants import * +from tencentpretrain.utils import * +from tencentpretrain.utils.config import load_hyperparam +from tencentpretrain.opts import infer_opts, tokenizer_opts, log_opts, mp_opts +from tencentpretrain.opts import deepspeed_opts +from tencentpretrain.utils.logging import init_logger +from tencentpretrain.model_loader import _load_state_dict_into_model, load_model +from tencentpretrain.utils.misc import pooling, ZeroOneNormalize, expand2square + + +class LLaVaGenerate(nn.Module): + def __init__(self, args): + super(LLaVaGenerate, self).__init__() + self.args = args + self.embedding = Embedding(args) + for embedding_name in args.embedding: + tmp_emb = str2embedding[embedding_name](args, len(args.tokenizer.vocab)) + self.embedding.update(tmp_emb, embedding_name) + + self.encoder = str2encoder[args.encoder](args) + self.pooling_type = args.pooling + + self.target = Target() + self.target.update(LmTarget(args, len(args.tokenizer.vocab)), "lm") + print("tokenizer vocab nums:", len(args.tokenizer.vocab)) + + def forward(self, src_text, seg_text, src_image, seg_image, image_pos): + """ + Args: + src_text, seg_text: [batch_size x text_seq_length] + src_image, seg_image: [batch_size x channels x height x width], [batch_size x (patch_num + 1)] + image_pos: [batch_size], position where the image features are inserted + """ + # Embedding. + src = src_text, src_image, seg_text, seg_image, image_pos + emb = self.embedding(src, None) + seg = torch.cat((seg_image[:,1:], seg_text), 1) + # Encoder. + output = self.encoder(emb, seg) + # Target. + output = self.target.lm.output_layer(output) + return output + + +def top_k_top_p_filtering(logits, top_k, top_p): + top_k = min(top_k, logits.size(-1)) # Safety check + if top_k > 0: + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = -float("Inf") + + if top_p > 0.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + indices_to_remove = sorted_indices[sorted_indices_to_remove] + logits[indices_to_remove] = -float("Inf") + return logits + + +def load_or_initialize_parameters(args, model): + if args.pretrained_model_path is not None: + # Initialize with pretrained model.
+ args.logger.info("loading model from {0}".format(args.pretrained_model_path)) + keys_info = model.load_state_dict(torch.load(args.pretrained_model_path, map_location="cpu"), strict=False) + args.logger.info("missing_keys: {0}".format(keys_info.missing_keys)) + args.logger.info("unexpected_keys: {0}".format(keys_info.unexpected_keys)) + + if args.vision_model_in_VL_emb_path is not None: + args.logger.info("loading model from {0}".format(args.vision_model_in_VL_emb_path)) + model = load_model(model, args.vision_model_in_VL_emb_path) + else: + # Initialize with normal distribution. + for n, p in list(model.named_parameters()): + if "gamma" not in n and "beta" not in n: + p.data.normal_(0, 0.02) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + infer_opts(parser) + + parser.add_argument("--top_k", type=int, default=70) + parser.add_argument("--top_p", type=float, default=0) + parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument("--instruction_template", type=str, choices=["sys0", "sys1", "sys2", "sys3", "sys4", "sys5"], + help="The instruction type for training large language-vision model.", default="sys0") + parser.add_argument("--vision_model_in_VL_emb_path", type=str, default=None, + help="Path of the vision pretrained model in the vision language embedding.") + tokenizer_opts(parser) + + deepspeed_opts(parser) + + log_opts(parser) + + mp_opts(parser) + + args = parser.parse_args() + + args.target = "lm" + args.batch_size = 1 + + args = load_hyperparam(args) + + args.tokenizer = str2tokenizer[args.tokenizer](args) + + args.logger = init_logger(args) + + args.pretrained_model_path = args.load_model_path + + # Load or initialize parameters. + if args.enable_zero3: + print("enable_zero3:", args.enable_zero3) + with deepspeed.zero.Init(config_dict_or_path=args.deepspeed_config): + model = LLaVaGenerate(args) + if args.pretrained_model_path: + model = _load_state_dict_into_model(model, args.pretrained_model_path) + if args.vision_model_in_VL_emb_path is not None: + model = _load_state_dict_into_model(model, args.vision_model_in_VL_emb_path) + else: + model = LLaVaGenerate(args) + load_or_initialize_parameters(args, model) + + deepspeed.init_distributed() + model = deepspeed.initialize(model=model,config_params=args.deepspeed_config)[0] + + rank = dist.get_rank() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.eval() + + image_height = args.vision_language_emb["vision_encoder"]["image_height"] + image_width = args.vision_language_emb["vision_encoder"]["image_width"] + patch_size = args.vision_language_emb["vision_encoder"]["patch_size"] + + preprocess_pipeline = [] + if "corp" in args.image_preprocess: + preprocess_pipeline.append(transforms.RandomResizedCrop(max(image_height, image_width))) + elif "center_crop" in args.image_preprocess: + preprocess_pipeline.append(transforms.Resize(min(image_height, image_width))) + preprocess_pipeline.append(transforms.CenterCrop((image_height, image_width))) + if "horizontal_flip" in args.image_preprocess: + preprocess_pipeline.append(transforms.RandomHorizontalFlip()) + preprocess_pipeline.append(transforms.Resize((image_height, image_width))) + preprocess_pipeline.append(ZeroOneNormalize()) + if "normalize" in args.image_preprocess: + preprocess_pipeline.append(transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))) + transform = transforms.Compose(preprocess_pipeline) + + 
prompt_template = { + "sys0": "", + "sys1": "<<SYS>>\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n<</SYS>>\n\n", + "sys2": "<<SYS>>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n<</SYS>>\n\n", + "sys3": " You are a helpful language and vision assistant. \n", + "sys4": "[INST]<<SYS>>\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n<</SYS>>\n\n", + "sys5": "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n" + } + if args.instruction_template == "sys0": + role1, role2 = "### Instruction", "### Output" + else: + role1, role2 = "USER", "ASSISTANT" + + im_start, im_end = " ", "" + num_image_tokens = int(image_width / patch_size) * int(image_height / patch_size) + 1 # (336 / 14) * (336 / 14) = 576 patches, plus one class token + seq_text = args.seq_length - num_image_tokens + input_f = open(args.test_path, mode="r", encoding="utf-8") + datas = json.load(input_f) + try: + prompt_overall = prompt_template[args.instruction_template] + except: + args.logger.info("unsupported prompt template!") + raise NotImplementedError + + with open(args.prediction_path, mode="w", encoding="utf-8") as outf: + for line_id, item in enumerate(datas): + try: + if "datasets" not in item["image"]: + image_path = "datasets/llava/" + item["image"] + else: + image_path = item["image"] + if not os.path.isfile(image_path): + continue + if imghdr.what(image_path) != 'jpeg' and imghdr.what(image_path) != 'png': + continue + image = Image.open(image_path) + if "pad" in args.image_preprocess: + image = expand2square(image) + image = torch.from_numpy((np.array(image).transpose(2,0,1))) + image = image.to(device) + src_image = transform(image) + except: + print("Something went wrong with item {}, skipped.".format(item)) + continue + + prompt_before_image = prompt_overall + role1 + ": " + ground_truth = [] + prompt_answer_id = [] + if "conversations" in item: + conversations = item["conversations"] + for i, conv in enumerate(conversations): + # Only the first round is used. + if i > 1: + continue + if i == 0: + if isinstance(conv, str): + prompt = conv + else: + prompt = conv["value"] + if "<image>" in prompt: + before_image, after_image = prompt.split("<image>") + prompt_before_image = prompt_before_image + before_image + im_start + prompt_after_image = im_end + "\n" + after_image + " " + role2 + ":" + else: + prompt_before_image = prompt_before_image + im_start + prompt_after_image = im_end + "\n" + prompt + " " + role2 + ":" + + prompt_before_image_id = args.tokenizer.convert_tokens_to_ids( + [CLS_TOKEN] + args.tokenizer.tokenize(prompt_before_image) + ) + prompt_after_image_id = args.tokenizer.convert_tokens_to_ids( + args.tokenizer.tokenize(prompt_after_image) + ) + seg_before_image = [1] * len(prompt_before_image_id) + seg_after_image = [1] * len(prompt_after_image_id) + if len(prompt_before_image_id) + len(prompt_after_image_id) > seq_text: + args.logger.info("prompt too long, skipped for now") + break + prompt_answer_id = [prompt_before_image_id + prompt_after_image_id] + prompt_answer_seg = [seg_before_image + seg_after_image] + elif i % 2 == 0: # human + prompt = conv["value"] + prompt_id = args.tokenizer.convert_tokens_to_ids( + args.tokenizer.tokenize(role1 + ":" + prompt + " " + role2 + ":") + ) + if prompt_answer_id: + prompt_answer_id.append(prompt_id) + prompt_answer_seg.append([1] * len(prompt_id)) + else: + args.logger.info("no prompt, or prompt too long, jumping") + break + else: # gpt + if isinstance(conv, str): + answer = conv + else: + answer = conv["value"] + ground_truth.append(answer) + else: + prompt = item["instruction"] + prompt_before_image = prompt_before_image + im_start + prompt_after_image = im_end + "\n" + prompt + " " + role2 + ":" + prompt_before_image_id = args.tokenizer.convert_tokens_to_ids( + args.tokenizer.tokenize(prompt_before_image) + ) + prompt_after_image_id = args.tokenizer.convert_tokens_to_ids( + args.tokenizer.tokenize(prompt_after_image) + ) + seg_before_image = [1] * len(prompt_before_image_id) + seg_after_image = [1] * len(prompt_after_image_id) + if len(prompt_before_image_id) + len(prompt_after_image_id) > seq_text: + args.logger.info("prompt too long, skipped for now") + continue + prompt_answer_id = [prompt_before_image_id + prompt_after_image_id] + prompt_answer_seg = [seg_before_image + seg_after_image] + + image_pos = len(prompt_before_image_id) + + # image_tensor = torch.unsqueeze(src_image, 0).half() + image_tensor = torch.unsqueeze(src_image, 0).bfloat16() + image_seg_tensor = torch.ones(1, num_image_tokens).to(device) + image_pos = torch.LongTensor([image_pos]).to(device) + SEP_ID = args.tokenizer.convert_tokens_to_ids([SEP_TOKEN]) + text_tensor = None + for i, prompt in enumerate(prompt_answer_id): + if text_tensor is None: + text_tensor, text_seg_tensor = torch.LongTensor([prompt]).to(device), torch.LongTensor([prompt_answer_seg[i]]).to(device) + else: + text_tensor = torch.cat([text_tensor, torch.LongTensor([prompt]).to(device)], dim=1) + text_seg_tensor = torch.cat([text_seg_tensor, torch.LongTensor([prompt_answer_seg[i]]).to(device)], dim=1) + + while text_tensor.shape[1] + num_image_tokens <= args.seq_length: + output = model(text_tensor, text_seg_tensor, image_tensor, image_seg_tensor, image_pos) + next_token_logits = output[0][-1] / args.temperature + filtered_logits = top_k_top_p_filtering(next_token_logits, args.top_k, args.top_p) + next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1) + + text_tensor = torch.cat([text_tensor, next_token.view(1, 1)], dim=1) + text_seg_tensor = torch.cat([text_seg_tensor, torch.tensor([[1]]).to(device)], dim=1) + if next_token.cpu().tolist() == SEP_ID: + break + if rank == 0 and text_tensor is not None: + tokens = [token_id.item() for token_id in text_tensor[0]] + if args.tokenizer.sp_model is not None: + generated_sentence = args.tokenizer.sp_model.decode(tokens) + else: + generated_sentence = "".join(args.tokenizer.convert_ids_to_tokens(tokens)) + print(item) + print(generated_sentence) + print(item, file=outf) + print("\n", file=outf) + print(generated_sentence + "\n\n", file=outf) + diff --git a/tencentpretrain/embeddings/__init__.py b/tencentpretrain/embeddings/__init__.py index 3eff2e32..a90ae889 100644 --- a/tencentpretrain/embeddings/__init__.py +++ b/tencentpretrain/embeddings/__init__.py @@ -8,13 +8,14 @@ from tencentpretrain.embeddings.word_patch_embedding import WordPatchEmbedding from tencentpretrain.embeddings.speech_embedding import SpeechEmbedding from tencentpretrain.embeddings.masked_patch_embedding import MaskedPatchEmbedding +from tencentpretrain.embeddings.vision_language_embedding import VisionLanguageEmbedding str2embedding = {"word": WordEmbedding, "pos": PosEmbedding, "seg":
SegEmbedding, "sinusoidalpos": SinusoidalposEmbedding, "dual": DualEmbedding, "patch": PatchEmbedding, "word_patch": WordPatchEmbedding, "speech": SpeechEmbedding, - "masked_patch": MaskedPatchEmbedding} + "masked_patch": MaskedPatchEmbedding, "vision_language": VisionLanguageEmbedding} __all__ = ["Embedding", "WordEmbedding", "PosEmbedding", "SegEmbedding", "SinusoidalposEmbedding", "DualEmbedding", "PatchEmbedding", "WordPatchEmbedding", "SpeechEmbedding", - "MaskedPatchEmbedding", "str2embedding"] + "MaskedPatchEmbedding", "VisionLanguageEmbedding", "str2embedding"] diff --git a/tencentpretrain/embeddings/vision_language_embedding.py b/tencentpretrain/embeddings/vision_language_embedding.py new file mode 100755 index 00000000..d6d9117a --- /dev/null +++ b/tencentpretrain/embeddings/vision_language_embedding.py @@ -0,0 +1,68 @@ +from argparse import Namespace +import torch +import torch.nn as nn +import copy + +from tencentpretrain.embeddings.embedding import Embedding +from tencentpretrain.embeddings.word_embedding import WordEmbedding +from tencentpretrain.embeddings.pos_embedding import PosEmbedding +from tencentpretrain.embeddings.patch_embedding import PatchEmbedding +from tencentpretrain.encoders import str2encoder, TransformerEncoder + +str2embedding = {"word": WordEmbedding, "pos": PosEmbedding, "patch": PatchEmbedding} + + +class VisionLanguageEmbedding(nn.Module): + ''' + A combination of a vision encoder and a text embedding. + ''' + def __init__(self, args, vocab_size): + super(VisionLanguageEmbedding, self).__init__() + # Vision model that extracts vision features. + vision_encoder_args = copy.deepcopy(vars(args)) + vision_encoder_args.update(args.vision_language_emb["vision_encoder"]) + vision_encoder_args = Namespace(**vision_encoder_args) + self.vision_embedding = Embedding(vision_encoder_args) + for embedding_name in vision_encoder_args.embedding: + tmp_emb = str2embedding[embedding_name](vision_encoder_args, None) + self.vision_embedding.update(tmp_emb, embedding_name) + self.vision_encoder = str2encoder[vision_encoder_args.encoder](vision_encoder_args) + + # MLP that maps the output of the vision model into the same space as the text features. + projection_args = copy.deepcopy(vars(args)) + projection_args.update(args.vision_language_emb["projection"]) + projection_args = Namespace(**projection_args) + projection_modules = [nn.Linear(vision_encoder_args.emb_size, projection_args.mlp_hidden_size)] + for _ in range(1, projection_args.num_mlp_layer): + projection_modules.append(nn.GELU()) + projection_modules.append(nn.Linear(projection_args.mlp_hidden_size, projection_args.mlp_hidden_size)) + self.projection = nn.Sequential(*projection_modules) + + # Text embedding. + text_args = copy.deepcopy(vars(args)) + text_args.update(args.vision_language_emb["text"]) + text_args = Namespace(**text_args) + self.text_embedding = Embedding(text_args) + for embedding_name in text_args.embedding: + tmp_emb = str2embedding[embedding_name](text_args, len(args.tokenizer.vocab)) + self.text_embedding.update(tmp_emb, embedding_name) + + def forward(self, src, seg=None): + src_text, src_image, seg_text, seg_image, image_pos = src + # Image features from the frozen vision encoder (penultimate layer, first token removed). + with torch.no_grad(): + image_emb = self.vision_embedding(src_image, seg_image) + image_emb = self.vision_encoder(image_emb, seg_image, output_layer=-2)[:,1:,:] + image_emb = self.projection(image_emb) + # Text embedding. + text_emb = self.text_embedding(src_text, seg_text) + # Combine text and image, inserting the image features at image_pos. + if text_emb.shape[0] == 1: + emb =
torch.cat((text_emb[:,:image_pos[0],:], image_emb, text_emb[:,image_pos[0]:,:]), 1) + else: + emb = torch.cat((text_emb[0,:image_pos[0],:], image_emb[0], text_emb[0,image_pos[0]:,:]), 0).unsqueeze(0) + for i in range(1, text_emb.shape[0]): + tmp = torch.cat((text_emb[i,:image_pos[i],:], image_emb[i], text_emb[i,image_pos[i]:,:]), 0).unsqueeze(0) + emb = torch.cat((emb, tmp), 0) + + return emb diff --git a/tencentpretrain/encoders/transformer_encoder.py b/tencentpretrain/encoders/transformer_encoder.py index f3dd6531..309bf5c2 100644 --- a/tencentpretrain/encoders/transformer_encoder.py +++ b/tencentpretrain/encoders/transformer_encoder.py @@ -63,7 +63,7 @@ def __init__(self, args): self.freqs_cis = precompute_freqs_cis(args.hidden_size // args.heads_num, args.max_seq_length * 2) - def forward(self, emb, seg): + def forward(self, emb, seg, output_layer=-1): """ Args: emb: [batch_size x seq_length x emb_size] @@ -129,12 +129,16 @@ def custom_forward(*inputs): while l < self.layers_num: inputs = checkpointing.checkpoint(custom(l, l + self.deepspeed_checkpoint_layers_num), inputs) l += self.deepspeed_checkpoint_layers_num + if output_layer != -1 and l == self.layers_num + output_layer: + return inputs[0] else: for i in range(self.layers_num): if self.parameter_sharing: inputs = self.transformer(inputs) else: inputs = self.transformer[i](inputs) + if output_layer != -1 and i == self.layers_num + output_layer: + return inputs[0] hidden = inputs[0] diff --git a/tencentpretrain/models/model.py b/tencentpretrain/models/model.py index f7b30d65..c46abe1b 100755 --- a/tencentpretrain/models/model.py +++ b/tencentpretrain/models/model.py @@ -29,6 +29,18 @@ def __init__(self, args, embedding, encoder, tgt_embedding, decoder, target): if self.decoder is not None and args.share_embedding: self.tgt_embedding.word.embedding.weight = self.embedding.word.embedding.weight + if args.freeze_parameters: + name_mapping = { + "embedding": self.embedding, "encoder": self.encoder, "tgt_embedding": self.tgt_embedding, + "decoder": self.decoder, "target": self.target + } + for freeze_name in args.freeze_parameters: + if name_mapping[freeze_name] is None: + continue + for name, param in name_mapping[freeze_name].named_parameters(): + if args.freeze_exclude_by_name == "" or args.freeze_exclude_by_name not in name: + param.requires_grad = False + def forward(self, src, tgt, seg, tgt_in=None, tgt_seg=None): emb = self.embedding(src, seg) memory_bank = self.encoder(emb, seg) diff --git a/tencentpretrain/opts.py b/tencentpretrain/opts.py index 57b03cde..88ab02da 100755 --- a/tencentpretrain/opts.py +++ b/tencentpretrain/opts.py @@ -52,7 +52,10 @@ def model_opts(parser): help="whether use alibi position embedding.") parser.add_argument("--layer_number_scale", action="store_true", help="whether use layer number scaling.") - + parser.add_argument("--freeze_parameters", choices=["embedding", "encoder", "tgt_embedding", "decoder", "target"], + default="", nargs='+', help="Which module to be frozen during training.") + parser.add_argument("--freeze_exclude_by_name", type=str, default="", + help="Exclude some modules with the specific string in the name when freezing parameters.") vision_opts(parser) audio_opts(parser) @@ -67,7 +70,7 @@ def vision_opts(parser): parser.add_argument("--channels_num", type=int, default=3, help="Channels num.") parser.add_argument("--image_preprocess", type=str, default=["crop", "normalize"], nargs='+', - help="Preprocess and data augmentation methods. 
Choices: [\"crop\", \"horizontal_flip\", \"normalize\"]. ") + help="Preprocess and data augmentation methods. Choices: [\"crop\" or \"center_crop\" or \"pad\", \"horizontal_flip\", \"normalize\"]. ") def audio_opts(parser): diff --git a/tencentpretrain/trainer.py b/tencentpretrain/trainer.py index 5ee00ac3..105ce750 100755 --- a/tencentpretrain/trainer.py +++ b/tencentpretrain/trainer.py @@ -1,6 +1,7 @@ import os import json import time +import torch import torch.distributed as dist import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel @@ -93,6 +94,9 @@ def init_model(args): else: model_for_dataloader = None + if args.vision_model_in_VL_emb_path is not None: + args.logger.info("loading: {}".format(args.vision_model_in_VL_emb_path)) + model_for_training = _load_state_dict_into_model(model_for_training, args.vision_model_in_VL_emb_path) return model_for_training, model_for_dataloader @@ -649,12 +653,23 @@ class LlmSftTrainer(LmTrainer): pass +class LlavaTrainer(LmTrainer): + def forward_propagation(self, batch, model): + src_text, src_img, tgt, seg_text, seg_img, seg_tgt, image_pos = batch + seg = torch.cat((seg_img[:,1:], seg_text), 1) + loss = model((src_text, src_img, seg_text, seg_img, image_pos), tgt, seg, tgt_seg=seg_tgt) + + self.total_loss += loss.item() + loss = loss / self.accumulation_steps + return loss + + str2trainer = {"bert": BertTrainer, "mlm": MlmTrainer, "lm": LmTrainer, "albert": AlbertTrainer, "bilm": BilmTrainer, "cls": ClsTrainer, "mt": MtTrainer, "t5": T5Trainer, "gsg": GsgTrainer, "bart": BartTrainer, "prefixlm": PrefixlmTrainer, "cls_mlm": ClsMlmTrainer, "vit": VitTrainer, "vilt": ViltTrainer, "clip": ClipTrainer, "s2t": S2tTrainer, - "beit": BeitTrainer, "dalle": DalleTrainer, "llm_sft": LlmSftTrainer} + "beit": BeitTrainer, "dalle": DalleTrainer, "llm_sft": LlmSftTrainer, "llava": LlavaTrainer} def worker(local_rank, gpu_ranks, args): @@ -676,6 +691,9 @@ def worker(local_rank, gpu_ranks, args): # Build model. model_for_training, model_for_dataloader = init_model(args) + if global_rank == 0: + args.logger.info("model: {}".format(model_for_training)) + # Build optimizer. 
custom_optimizer, custom_scheduler, optimizer_grouped_parameters = init_optimizer(args, model_for_training) diff --git a/tencentpretrain/utils/__init__.py b/tencentpretrain/utils/__init__.py index efa947e4..385ddba1 100644 --- a/tencentpretrain/utils/__init__.py +++ b/tencentpretrain/utils/__init__.py @@ -13,15 +13,17 @@ "t5": T5Dataset, "gsg": GsgDataset, "bart": BartDataset, "cls": ClsDataset, "prefixlm": PrefixlmDataset, "cls_mlm": ClsMlmDataset, "vit": VitDataset, "vilt": ViltDataset, "clip": ClipDataset, "s2t": S2tDataset, - "beit":BeitDataset, "dalle": DalleDataset, "llm_sft": LlmSftDataset, "llm_pretrain": LlmPretrainDataset} + "beit":BeitDataset, "dalle": DalleDataset, "llm_sft": LlmSftDataset, + "llm_pretrain": LlmPretrainDataset, "llava": LlavaDataset} str2dataloader = {"bert": BertDataloader, "lm": LmDataloader, "mlm": MlmDataloader, "bilm": BilmDataloader, "albert": AlbertDataloader, "mt": MtDataloader, "t5": T5Dataloader, "gsg": GsgDataloader, "bart": BartDataloader, "cls": ClsDataloader, "prefixlm": PrefixlmDataloader, "cls_mlm": ClsMlmDataloader, - "vit": VitDataloader, "vilt": ViltDataloader, "clip": ClipDataloader, "s2t": S2tDataloader, - "beit":BeitDataloader, "dalle": DalleDataloader, "llm_sft": LlmSftDataloader} + "vit": VitDataloader, "vilt": ViltDataloader, "clip": ClipDataloader, + "s2t": S2tDataloader, "beit":BeitDataloader, "dalle": DalleDataloader, + "llm_sft": LlmSftDataloader, "llava":LlavaDataloader} -str2act = {"gelu": gelu, "gelu_fast": gelu_fast, "relu": relu, "silu": silu, "linear": linear} +str2act = {"gelu": gelu, "gelu_fast": gelu_fast, "relu": relu, "silu": silu, "linear": linear, "gelu_quick": gelu_quick} str2optimizer = {"adamw": AdamW, "adafactor": Adafactor} @@ -38,11 +40,12 @@ "BertDataset", "LmDataset", "MlmDataset", "BilmDataset", "AlbertDataset", "MtDataset", "T5Dataset", "GsgDataset", "BartDataset", "ClsDataset", "PrefixlmDataset", "ClsMlmDataset", - "VitDataset", "ViltDataset", "ClipDataset", "BeitDataset", "DalleDataset", "LlmSftDataset", "str2dataset", - "BertDataloader", "LmDataloader", "MlmDataloader", "BilmDataloader", + "VitDataset", "ViltDataset", "ClipDataset", "BeitDataset", "DalleDataset", "LlmSftDataset", "LlavaDataset", + "str2dataset", "BertDataloader", "LmDataloader", "MlmDataloader", "BilmDataloader", "AlbertDataloader", "MtDataloader", "T5Dataloader", "GsgDataloader", "BartDataloader", "ClsDataloader", "PrefixlmDataloader", "ClsMlmDataloader", - "VitDataloader", "ViltDataloader", "ClipDataloader", "BeitDataloader", "DalleDataloader", "LlmSftDataloader", "str2dataloader", + "VitDataloader", "ViltDataloader", "ClipDataloader", "BeitDataloader", "DalleDataloader", + "LlmSftDataloader", "LlavaDataloader", "str2dataloader", "gelu", "gelu_fast", "relu", "silu", "linear", "str2act", "AdamW", "Adafactor", "str2optimizer", "get_linear_schedule_with_warmup", "get_cosine_schedule_with_warmup", diff --git a/tencentpretrain/utils/act_fun.py b/tencentpretrain/utils/act_fun.py index be306a6b..125607c2 100644 --- a/tencentpretrain/utils/act_fun.py +++ b/tencentpretrain/utils/act_fun.py @@ -10,6 +10,12 @@ def gelu(x): def gelu_fast(x): return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) +def gelu_quick(x): + """ + Applies GELU approximation that is fast but somewhat inaccurate. 
See: https://github.com/hendrycks/GELUs + """ + return x * torch.sigmoid(1.702 * x) + def relu(x): return F.relu(x) diff --git a/tencentpretrain/utils/dataloader.py b/tencentpretrain/utils/dataloader.py index b7269518..d1fd48ef 100755 --- a/tencentpretrain/utils/dataloader.py +++ b/tencentpretrain/utils/dataloader.py @@ -546,6 +546,7 @@ def __init__(self, args, dataset_path, batch_size, global_rank, world_size, loca self.patch_size = args.patch_size self.image_height = args.image_height self.image_width = args.image_width + self.args = args from torchvision import transforms from tencentpretrain.utils.misc import ZeroOneNormalize @@ -553,6 +554,9 @@ def __init__(self, args, dataset_path, batch_size, global_rank, world_size, loca preprocess_pipeline = [] if "corp" in args.image_preprocess: preprocess_pipeline.append(transforms.RandomResizedCrop(max(self.image_height, self.image_width))) + elif "center_crop" in args.image_preprocess: + preprocess_pipeline.append(transforms.Resize(min(self.image_height, self.image_width))) + preprocess_pipeline.append(transforms.CenterCrop((self.image_height, self.image_width))) if "horizontal_flip" in args.image_preprocess: preprocess_pipeline.append(transforms.RandomHorizontalFlip()) preprocess_pipeline.append(transforms.Resize((self.image_height, self.image_width))) @@ -963,3 +967,93 @@ def __iter__(self): yield torch.LongTensor(src), \ torch.LongTensor(tgt), \ torch.LongTensor(seg) + + +class LlavaDataloader(VisionDataloader): + + def __iter__(self): + """ + instances: ((src, tgt), (seg_src_nums, seg_tgt_nums), (src_image, image_pos)) + src, tgt: Tokens of the text sample + seg_src_nums, seg_tgt_nums: Segment length counts of the text sample + src_image: Path of the image sample + image_pos: Position of the image in the text sample + + Returns: + src_text: [batch_size x seq_length] + src_image: [batch_size x channel_size x width x height] + tgt: [batch_size x seq_length] + seg_text: [batch_size x seq_length] + seg_image: [batch_size x (patch_num + 1)] + seg_tgt: [batch_size x seq_length] + image_pos: [batch_size] + + """ + from torchvision.io import read_image + from torchvision.io.image import ImageReadMode + from tencentpretrain.utils.misc import expand2square + + seg_image_num = (self.image_height // self.patch_size) * (self.image_width // self.patch_size) + while True: + while self._empty(): + self._fill_buf() + if self.start + self.batch_size >= self.end: + instances = self.buffer[self.start:] + else: + instances = self.buffer[self.start: self.start + self.batch_size] + + self.start += self.batch_size + + src_text = [] + src_image = [] + tgt = [] + seg_text = [] + seg_image = [] + seg_tgt = [] + image_pos = [] + + for ins in instances: + ins_src, ins_tgt = ins[0] + ins_seg_nums_src, ins_seg_nums_tgt = ins[1] + ins_src_image, ins_image_pos = ins[2] + seq_length = len(ins_src) + text_seq_length = seq_length - seg_image_num + + try: + if "pad" in self.args.image_preprocess: + from PIL import Image + import numpy as np + import torchvision.transforms.functional as transform + image = Image.open(ins_src_image) + image = expand2square(image) + image = torch.from_numpy((np.array(image).transpose(2,0,1))) + else: + image = read_image(ins_src_image, ImageReadMode.RGB) + except: + print("Something is wrong when reading {}, just skipped!".format(ins_src_image)) + continue + image = image.cuda(self.local_rank) + src_image.append(self.transform(image)) + seg_image.append([1] * (seg_image_num + 1)) + image_pos.append(ins_image_pos) + +
src_text.append(ins_src[:text_seq_length]) + ins_seg_src = [1] * ins_seg_nums_src[0] + [0] * ins_seg_nums_src[1] + seg_text.append(ins_seg_src[:text_seq_length]) + + ins_tgt_new = [self.vocab.get(PAD_TOKEN)] * seg_image_num + ins_tgt + tgt.append(ins_tgt_new[:seq_length]) + ins_seg_tgt = [0] * seg_image_num + for i, num in enumerate(ins_seg_nums_tgt): + ins_seg_tgt = ins_seg_tgt + [i % 2] * num + seg_tgt.append(ins_seg_tgt[:seq_length]) + + if len(src_image) == 0: + continue + yield torch.LongTensor(src_text), \ + torch.stack(src_image, 0).half(), \ + torch.LongTensor(tgt), \ + torch.LongTensor(seg_text), \ + torch.LongTensor(seg_image), \ + torch.LongTensor(seg_tgt), \ + image_pos diff --git a/tencentpretrain/utils/dataset.py b/tencentpretrain/utils/dataset.py index 9c7a682b..ea927b2e 100755 --- a/tencentpretrain/utils/dataset.py +++ b/tencentpretrain/utils/dataset.py @@ -57,6 +57,7 @@ def __init__(self, args, vocab, tokenizer): self.span_max_length = args.span_max_length self.docs_buffer_size = args.docs_buffer_size self.dup_factor = args.dup_factor + self.args = args def build_and_save(self, workers_num): """ @@ -1083,3 +1084,109 @@ def worker(self, proc_id, start, end): break dataset_writer.close() + + +class LlavaDataset(Dataset): + def worker(self, proc_id, start, end): + import json + PAD_ID = self.tokenizer.convert_tokens_to_ids([PAD_TOKEN])[0] + role1, role2 = "USER", "ASSISTANT" + im_start, im_end = " ", "" + print("Worker %d is building dataset ... " % proc_id) + set_seed(self.seed) + dataset_writer = open("dataset-tmp-" + str(proc_id) + ".pt", "wb") + pos = start + skip_item = 0 + with open(self.corpus_path, mode="r", encoding="utf-8") as f: + data = json.load(f) + while True: + item = data[pos] + pos += 1 + try: + path = item["image"] + if not os.path.isfile(path): + continue + except: + skip_item += 1 + continue + conversations = item["conversations"] + if "instruction" in item.keys(): + inst = item["instruction"] + else: + inst = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions." 
+ if inst: + prompt_before_image = inst + " " + role1 + ":" + else: + prompt_before_image = role1 + ":" + prompt_answer_seg_nums, tgt_seg_nums = [], [] + for i, conv in enumerate(conversations): + if i == 0: + if isinstance(conv, str): + prompt = conv + else: + prompt = conv["value"] + if "<image>" in prompt: + before_image, after_image = prompt.split("<image>") + prompt_before_image = prompt_before_image + before_image + im_start + prompt_after_image = im_end + "\n" + after_image + " " + role2 + ":" + else: + prompt_before_image = prompt_before_image + im_start + prompt_after_image = im_end + "\n" + prompt + " " + role2 + ":" + prompt_before_image_id = self.tokenizer.convert_tokens_to_ids([CLS_TOKEN] + self.tokenizer.tokenize(prompt_before_image)) + prompt_after_image_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(prompt_after_image)) + seg_before_image = [1] * len(prompt_before_image_id) + seg_after_image = [1] * len(prompt_after_image_id) + if len(prompt_before_image_id) + len(prompt_after_image_id) > self.seq_length: + print("prompt too long, skipped") + continue + prompt_answer_id = prompt_before_image_id + prompt_after_image_id + tgt_id = [PAD_ID] * (len(prompt_answer_id) - 1) + tgt_seg_nums = [len(tgt_id)] + elif i % 2 == 0: # human + if isinstance(conv, str): + prompt = conv + else: + prompt = conv["value"] + prompt_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(role1 + ":" + prompt + " " + role2 + ":")) + prompt_answer_id = prompt_answer_id + prompt_id + tgt_id = tgt_id + [PAD_ID] * len(prompt_id) + if len(tgt_seg_nums) == 1: + tgt_seg_nums[0] = tgt_seg_nums[0] + len(prompt_id) + else: + tgt_seg_nums = tgt_seg_nums + [len(prompt_id)] + else: # gpt + if isinstance(conv, str): + answer = conv + else: + answer = conv["value"] + answer_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(answer) + [SEP_TOKEN]) + prompt_answer_id = prompt_answer_id + answer_id + tgt_id = tgt_id + answer_id + tgt_seg_nums = tgt_seg_nums + [len(answer_id)] + + if len(tgt_id) > self.seq_length: + tgt_id = tgt_id[:self.seq_length] + pad_num = self.seq_length - len(tgt_id) + tgt_id = tgt_id + [PAD_ID] * pad_num + while sum(tgt_seg_nums) > self.seq_length: + tgt_seg_nums = tgt_seg_nums[:-1] + pad_num = self.seq_length - sum(tgt_seg_nums) + tgt_seg_nums = tgt_seg_nums + [pad_num] + + if len(prompt_answer_id) > self.seq_length: + prompt_answer_id = prompt_answer_id[:self.seq_length] + + pad_num = self.seq_length - len(prompt_answer_id) + prompt_answer_seg_nums = [len(prompt_answer_id), pad_num] + prompt_answer_id = prompt_answer_id + [PAD_ID] * pad_num + + image_pos = len(prompt_before_image_id) + src = (prompt_answer_id, tgt_id) + seg_nums = (prompt_answer_seg_nums, tgt_seg_nums) + image = (path, image_pos) + pickle.dump((src, seg_nums, image), dataset_writer) + + if pos >= end: + break + + dataset_writer.close() diff --git a/tencentpretrain/utils/misc.py b/tencentpretrain/utils/misc.py index 01545650..eea8541f 100644 --- a/tencentpretrain/utils/misc.py +++ b/tencentpretrain/utils/misc.py @@ -4,6 +4,11 @@ def count_lines(file_path): lines_num = 0 + if file_path.endswith(".json"): + import json + with open(file_path, 'rb') as f: + data = json.load(f) + return len(data) with open(file_path, 'rb') as f: while True: data = f.read(2 ** 20) @@ -34,6 +39,24 @@ def pooling(memory_bank, seg, pooling_type): features = memory_bank[:, 0, :] return features + class ZeroOneNormalize(object): def __call__(self, img): return img.float().div(255) + + +def expand2square(img,
background_color=(122, 116, 104)): + from PIL import Image + width, height = img.size + if img.mode != "RGB": + img = img.convert("RGB") + if width == height: + return img + elif width > height: + result = Image.new(img.mode, (width, width), background_color) + result.paste(img, (0, (width - height) // 2)) + return result + else: + result = Image.new(img.mode, (height, height), background_color) + result.paste(img, ((height - width) // 2, 0)) + return result
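For reference, a minimal sketch of the splice performed by VisionLanguageEmbedding.forward: the 576 projected image patch embeddings are inserted into the text embedding sequence at image_pos. Sizes follow models/llava/7b_config.json; the tensors below are random placeholders used purely for illustration, not part of the patch.

import torch

# Illustrative sizes from models/llava/7b_config.json:
# (336 / 14) * (336 / 14) = 576 image patches, LLM hidden size 4096.
batch_size, num_patches, text_len, hidden = 1, 576, 64, 4096

text_emb = torch.randn(batch_size, text_len, hidden)      # word embedding of the prompt tokens
image_emb = torch.randn(batch_size, num_patches, hidden)  # vision encoder output after the MLP projection
image_pos = 8                                              # token index where the <image> placeholder appeared

# Text before the image, then the projected patches, then the rest of the text.
emb = torch.cat((text_emb[:, :image_pos, :], image_emb, text_emb[:, image_pos:, :]), dim=1)
assert emb.shape == (batch_size, text_len + num_patches, hidden)

The combined sequence (text_len + 576 positions) is what the causal LLM encoder sees, which is why the generation script reserves those positions via seq_text = args.seq_length - num_image_tokens.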