diff --git a/models/llava/7b_config.json b/models/llava/7b_config.json
new file mode 100755
index 00000000..12d01a02
--- /dev/null
+++ b/models/llava/7b_config.json
@@ -0,0 +1,54 @@
+{
+ "embedding": ["vision_language"],
+ "vision_language_emb":{
+ "vision_encoder":{
+ "image_height": 336,
+ "image_width": 336,
+ "patch_size": 14,
+ "emb_size": 1024,
+ "feedforward_size": 4096,
+ "hidden_size": 1024,
+ "hidden_act": "gelu_quick",
+ "heads_num": 16,
+ "layers_num": 24,
+ "dropout": 0.0,
+ "max_seq_length": 577,
+ "embedding": ["patch", "pos"],
+ "patch_proj_bias": false,
+ "remove_embedding_layernorm": false,
+ "remove_transformer_bias": false,
+ "rotary_position_embedding": false,
+ "encoder": "transformer",
+ "feed_forward": "dense",
+ "mask": "fully_visible",
+ "layernorm_positioning": "pre",
+ "layernorm": "normal",
+ "has_cls": false
+ },
+ "projection":{
+ "mlp_hidden_size": 4096,
+ "num_mlp_layer": 2
+ },
+ "text":{
+ "embedding": ["word"]
+ }
+ },
+ "emb_size": 4096,
+ "feedforward_size": 11008,
+ "hidden_size": 4096,
+ "hidden_act": "silu",
+ "heads_num": 32,
+ "layers_num": 32,
+ "dropout": 0.0,
+ "data_processor": "llava",
+ "max_seq_length": 2048,
+ "remove_transformer_bias": true,
+ "remove_embedding_layernorm": true,
+ "rotary_position_embedding": true,
+ "encoder": "transformer",
+ "feed_forward": "gated",
+ "mask": "causal",
+ "layernorm_positioning": "pre",
+ "layernorm": "rms",
+ "target": ["lm"]
+ }
diff --git a/preprocess.py b/preprocess.py
index 7f61c54c..8c0f1a00 100644
--- a/preprocess.py
+++ b/preprocess.py
@@ -28,12 +28,14 @@ def main():
parser.add_argument("--data_processor",
choices=["bert", "lm", "mlm", "bilm", "albert", "mt", "t5", "cls", "prefixlm",
"gsg", "bart", "cls_mlm", "vit", "vilt", "clip", "s2t", "beit", "dalle",
- "llm_pretrain", "llm_sft"], default="bert",
+ "llm_pretrain", "llm_sft", "llava"], default="bert",
help="The data processor of the pretraining model.")
parser.add_argument("--docs_buffer_size", type=int, default=100000,
help="The buffer size of documents in memory, specific to targets that require negative sampling.")
parser.add_argument("--seq_length", type=int, default=128, help="Sequence length of instances.")
parser.add_argument("--tgt_seq_length", type=int, default=128, help="Target sequence length of instances.")
+ parser.add_argument("--vision_seq_length_in_VL", type=int, default=576,
+ help="Number of image patches in the vision language model(LLaVa).")
parser.add_argument("--dup_factor", type=int, default=5,
help="Duplicate instances multiple times.")
parser.add_argument("--short_seq_prob", type=float, default=0.1,
diff --git a/pretrain.py b/pretrain.py
index 14c744c3..8a667160 100644
--- a/pretrain.py
+++ b/pretrain.py
@@ -13,6 +13,8 @@ def main():
help="Path of the preprocessed dataset.")
parser.add_argument("--pretrained_model_path", type=str, default=None,
help="Path of the pretrained model.")
+ parser.add_argument("--vision_model_in_VL_emb_path", type=str, default=None,
+ help="Path of the vision pretrained model in the vision language embedding.")
parser.add_argument("--output_model_path", type=str, required=True,
help="Path of the output model.")
parser.add_argument("--config_path", type=str, default="models/bert/base_config.json",
diff --git a/scripts/convert_llava_from_huggingface_to_tencentpretrain.py b/scripts/convert_llava_from_huggingface_to_tencentpretrain.py
new file mode 100755
index 00000000..606a0f9c
--- /dev/null
+++ b/scripts/convert_llava_from_huggingface_to_tencentpretrain.py
@@ -0,0 +1,77 @@
+import argparse
+import collections
+import torch
+import os
+import json
+
+
+parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument("--input_model_path", type=str, default="models/llava-v1.5-7b/",
+ help=".")
+parser.add_argument("--output_model_path", type=str, default="models/llava-v1.5-7b.bin",
+ help=".")
+parser.add_argument("--type", choices=["7B", "13B", "33B", "65B"], default="7B")
+
+args = parser.parse_args()
+
+model_config = {"7B" : [32, 4096, 32],
+ "13B": [40, 5120, 40],
+ "33B": [60, 6656, 52],
+ "65B": [80, 8192, 64]
+ }
+
+layers_num, dim, n_heads = model_config[args.type]
+
+files = os.listdir(args.input_model_path)
+model_files = [f for f in files if f[-4:] == ".bin"]
+input_models = {f: torch.load(os.path.join(args.input_model_path, f), map_location="cpu") for f in model_files}
+
+with open(os.path.join(args.input_model_path, "pytorch_model.bin.index.json")) as f:
+ model_index = json.load(f)
+ weight_map = model_index["weight_map"]
+
+
+output_model = collections.OrderedDict()
+
+def get_weight_from_name(layer_name):
+ return input_models[weight_map[layer_name]][layer_name]
+
+def unpermute(w):
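+ # Undo the query/key weight permutation of the Huggingface LLaMA layout, restoring the
+ # interleaved rotary-embedding ordering that TencentPretrain expects.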
+ return w.reshape(n_heads, 2, dim // n_heads // 2, dim).transpose(2, 1).reshape(dim, dim)
+
+output_model["embedding.vision_language.text_embedding.word.embedding.weight"] = get_weight_from_name("model.embed_tokens.weight")
+
+for i in range(layers_num):
+
+ output_model["encoder.transformer." + str(i) + ".self_attn.linear_layers.0.weight"] = \
+ unpermute(get_weight_from_name("model.layers." + str(i) + ".self_attn.q_proj.weight"))
+ output_model["encoder.transformer." + str(i) + ".self_attn.linear_layers.1.weight"] = \
+ unpermute(get_weight_from_name("model.layers." + str(i) + ".self_attn.k_proj.weight"))
+
+ output_model["encoder.transformer." + str(i) + ".self_attn.linear_layers.2.weight"] = \
+ get_weight_from_name("model.layers." + str(i) + ".self_attn.v_proj.weight")
+ output_model["encoder.transformer." + str(i) + ".self_attn.final_linear.weight"] = \
+ get_weight_from_name("model.layers." + str(i) + ".self_attn.o_proj.weight")
+
+ output_model["encoder.transformer." + str(i) + ".layer_norm_1.weight"] = \
+ get_weight_from_name("model.layers." + str(i) + ".input_layernorm.weight")
+
+ output_model["encoder.transformer." + str(i) + ".feed_forward.linear_gate.weight"] = \
+ get_weight_from_name("model.layers." + str(i) + ".mlp.gate_proj.weight")
+ output_model["encoder.transformer." + str(i) + ".feed_forward.linear_1.weight"] = \
+ get_weight_from_name("model.layers." + str(i) + ".mlp.up_proj.weight")
+ output_model["encoder.transformer." + str(i) + ".feed_forward.linear_2.weight"] = \
+ get_weight_from_name("model.layers." + str(i) + ".mlp.down_proj.weight")
+
+ output_model["encoder.transformer." + str(i) + ".layer_norm_2.weight"] = \
+ get_weight_from_name("model.layers." + str(i) + ".post_attention_layernorm.weight")
+
+output_model["encoder.layer_norm.weight"] = get_weight_from_name("model.norm.weight")
+output_model["target.lm.output_layer.weight"] = get_weight_from_name("lm_head.weight")
+
+output_model["embedding.vision_language.projection.0.weight"] = get_weight_from_name("model.mm_projector.0.weight")
+output_model["embedding.vision_language.projection.0.bias"] = get_weight_from_name("model.mm_projector.0.bias")
+output_model["embedding.vision_language.projection.2.weight"] = get_weight_from_name("model.mm_projector.2.weight")
+output_model["embedding.vision_language.projection.2.bias"] = get_weight_from_name("model.mm_projector.2.bias")
+
+torch.save(output_model, args.output_model_path)
diff --git a/scripts/convert_llm_in_llava.py b/scripts/convert_llm_in_llava.py
new file mode 100755
index 00000000..652bcb26
--- /dev/null
+++ b/scripts/convert_llm_in_llava.py
@@ -0,0 +1,30 @@
+import argparse
+import collections
+import torch
+
+
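+# Renames the word-embedding key of a TencentPretrain language model checkpoint so that it can be
+# loaded under the LLaVA vision_language embedding; all other parameters are copied unchanged.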
+def main():
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument("--input_model_path", type=str, default="models/input_model.bin",
+ help=".")
+ parser.add_argument("--output_model_path", type=str, default="models/output_model.bin",
+ help=".")
+
+
+ args = parser.parse_args()
+
+ input_model = torch.load(args.input_model_path, map_location="cpu")
+
+ output_model = collections.OrderedDict()
+
+ for k in input_model.keys():
+ if k == "embedding.word.embedding.weight":
+ output_model["embedding.vision_language.text_embedding.word.embedding.weight"] = input_model[k]
+ else:
+ output_model[k] = input_model[k]
+
+ torch.save(output_model, args.output_model_path)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/convert_model_add_prefix.py b/scripts/convert_model_add_prefix.py
new file mode 100755
index 00000000..f75ca66e
--- /dev/null
+++ b/scripts/convert_model_add_prefix.py
@@ -0,0 +1,28 @@
+import argparse
+import collections
+import torch
+
+
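+# Adds a fixed prefix to every parameter name in a checkpoint, e.g. to nest a standalone
+# vision model's weights under the vision-language embedding namespace (the prefix is user-supplied).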
+def main():
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument("--input_model_path", type=str, default="models/input_model.bin",
+ help=".")
+ parser.add_argument("--output_model_path", type=str, default="models/output_model.bin",
+ help=".")
+ parser.add_argument("--prefix", type=str, default="", help="prefix to add")
+
+
+ args = parser.parse_args()
+
+ input_model = torch.load(args.input_model_path, map_location="cpu")
+
+ output_model = collections.OrderedDict()
+
+ for k in input_model.keys():
+ output_model[args.prefix + k] = input_model[k]
+
+ torch.save(output_model, args.output_model_path)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/generate_lm_llava_deepspeed.py b/scripts/generate_lm_llava_deepspeed.py
new file mode 100755
index 00000000..f4d960ed
--- /dev/null
+++ b/scripts/generate_lm_llava_deepspeed.py
@@ -0,0 +1,333 @@
+"""
+ This script provides an example of wrapping TencentPretrain for LLaVA-style generation.
+ Given an image and the beginning of a text, the language model generates the rest.
+"""
+import sys
+import os
+import argparse
+import json
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributed as dist
+from torchvision import transforms
+from torchvision.io import read_image
+from torchvision.io.image import ImageReadMode
+import imghdr
+import deepspeed
+import numpy as np
+from PIL import Image
+import torchvision.transforms.functional as transform
+
+tencentpretrain_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+sys.path.append(tencentpretrain_dir)
+
+from tencentpretrain.embeddings import *
+from tencentpretrain.encoders import *
+from tencentpretrain.targets import *
+from tencentpretrain.utils.constants import *
+from tencentpretrain.utils import *
+from tencentpretrain.utils.config import load_hyperparam
+from tencentpretrain.opts import infer_opts, tokenizer_opts, log_opts, mp_opts
+from tencentpretrain.opts import deepspeed_opts
+from tencentpretrain.utils.logging import init_logger
+from tencentpretrain.model_loader import _load_state_dict_into_model, load_model
+from tencentpretrain.utils.misc import pooling, ZeroOneNormalize, expand2square
+
+
+class LLaVaGenerate(nn.Module):
+ def __init__(self, args):
+ super(LLaVaGenerate, self).__init__()
+ self.args = args
+ self.embedding = Embedding(args)
+ for embedding_name in args.embedding:
+ tmp_emb = str2embedding[embedding_name](args, len(args.tokenizer.vocab))
+ self.embedding.update(tmp_emb, embedding_name)
+
+ self.encoder = str2encoder[args.encoder](args)
+ self.pooling_type = args.pooling
+
+ self.target = Target()
+ self.target.update(LmTarget(args, len(args.tokenizer.vocab)), "lm")
+ print("tokenizer vocab nums:", len(args.tokenizer.vocab))
+
+ def forward(self, src_text, seg_text, src_image, seg_image, image_pos):
+ """
+ Args:
+ src_text: [batch_size x text_seq_length]
+ seg_text: [batch_size x text_seq_length]
+ src_image: [batch_size x channels_num x image_height x image_width]
+ seg_image: [batch_size x (patch_num + 1)]
+ image_pos: [batch_size]
+ """
+ # Embedding.
+ src = src_text, src_image, seg_text, seg_image, image_pos
+ emb = self.embedding(src, None)
+ seg = torch.cat((seg_image[:,1:], seg_text), 1)
+ # encoder
+ output = self.encoder(emb, seg)
+ # # Target.
+ output = self.target.lm.output_layer(output)
+ return output
+
+
+def top_k_top_p_filtering(logits, top_k, top_p):
+ top_k = min(top_k, logits.size(-1)) # Safety check
+ if top_k > 0:
+ # Remove all tokens with a probability less than the last token of the top-k
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+ logits[indices_to_remove] = -float("Inf")
+
+ if top_p > 0.0:
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+ # Remove tokens with cumulative probability above the threshold
+ sorted_indices_to_remove = cumulative_probs > top_p
+ # Shift the indices to the right to keep also the first token above the threshold
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+ sorted_indices_to_remove[..., 0] = 0
+
+ indices_to_remove = sorted_indices[sorted_indices_to_remove]
+ logits[indices_to_remove] = -float("Inf")
+ return logits
+
+
+def load_or_initialize_parameters(args, model):
+ if args.pretrained_model_path is not None:
+ # Initialize with pretrained model.
+ args.logger.info("loading model from {0}".format(args.pretrained_model_path))
+ keys_info = model.load_state_dict(torch.load(args.pretrained_model_path, map_location="cpu"), strict=False)
+ args.logger.info("missing_keys: {0}".format(keys_info.missing_keys))
+ args.logger.info("unexpected_keys: {0}".format(keys_info.unexpected_keys))
+
+ if args.vision_model_in_VL_emb_path is not None:
+ args.logger.info("loading model from {0}".format(args.vision_model_in_VL_emb_path))
+ model = load_model(model, args.vision_model_in_VL_emb_path)
+ else:
+ # Initialize with normal distribution.
+ for n, p in list(model.named_parameters()):
+ if "gamma" not in n and "beta" not in n:
+ p.data.normal_(0, 0.02)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+
+ infer_opts(parser)
+
+ parser.add_argument("--top_k", type=int, default=70)
+ parser.add_argument("--top_p", type=float, default=0)
+ parser.add_argument("--temperature", type=float, default=1.0)
+ parser.add_argument("--instruction_template", type=str, choices=["sys0", "sys1", "sys2", "sys3", "sys4", "sys5"],
+ help="The instruction type for training large language-vision model.", default="sys0")
+ parser.add_argument("--vision_model_in_VL_emb_path", type=str, default=None,
+ help="Path of the vision pretrained model in the vision language embedding.")
+ tokenizer_opts(parser)
+
+ deepspeed_opts(parser)
+
+ log_opts(parser)
+
+ mp_opts(parser)
+
+ args = parser.parse_args()
+
+ args.target = "lm"
+ args.batch_size = 1
+
+ args = load_hyperparam(args)
+
+ args.tokenizer = str2tokenizer[args.tokenizer](args)
+
+ args.logger = init_logger(args)
+
+ args.pretrained_model_path = args.load_model_path
+
+ # Load or initialize parameters.
+ if args.enable_zero3:
+ print("enable_zero3:", args.enable_zero3)
+ with deepspeed.zero.Init(config_dict_or_path=args.deepspeed_config):
+ model = LLaVaGenerate(args)
+ if args.pretrained_model_path:
+ model = _load_state_dict_into_model(model, args.pretrained_model_path)
+ if args.vision_model_in_VL_emb_path is not None:
+ model = _load_state_dict_into_model(model, args.vision_model_in_VL_emb_path)
+ else:
+ model = LLaVaGenerate(args)
+ load_or_initialize_parameters(args, model)
+
+ deepspeed.init_distributed()
+ model = deepspeed.initialize(model=model,config_params=args.deepspeed_config)[0]
+
+ rank = dist.get_rank()
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.eval()
+
+ image_height = args.vision_language_emb["vision_encoder"]["image_height"]
+ image_width = args.vision_language_emb["vision_encoder"]["image_width"]
+ patch_size = args.vision_language_emb["vision_encoder"]["patch_size"]
+
+ preprocess_pipeline = []
+ if "corp" in args.image_preprocess:
+ preprocess_pipeline.append(transforms.RandomResizedCrop(max(image_height, image_width)))
+ elif "center_crop" in args.image_preprocess:
+ preprocess_pipeline.append(transforms.Resize(min(image_height, image_width)))
+ preprocess_pipeline.append(transforms.CenterCrop((image_height, image_width)))
+ if "horizontal_flip" in args.image_preprocess:
+ preprocess_pipeline.append(transforms.RandomHorizontalFlip())
+ preprocess_pipeline.append(transforms.Resize((image_height, image_width)))
+ preprocess_pipeline.append(ZeroOneNormalize())
+ if "normalize" in args.image_preprocess:
+ preprocess_pipeline.append(transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)))
+ transform = transforms.Compose(preprocess_pipeline)
+
+ prompt_template = {
+ "sys0": "",
+ "sys1": "<>\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n<>\n\n",
+ "sys2": "<>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n<>\n\n",
+ "sys3": " You are a helpful language and vision assistant. \n",
+ "sys4": "[INST]<>\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n<>\n\n",
+ "sys5": "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
+ }
+ if args.instruction_template == "sys0":
+ role1, role2 = "### Instruction", "### Output"
+ else:
+ role1, role2 = "USER", "ASSISTANT"
+
+ im_start, im_end = " ", ""
+ num_image_tokens = int(image_width / patch_size) * int(image_height / patch_size) + 1  # (336 / 14) * (336 / 14) = 576 patches, plus 1 for the CLS position
+ seq_text = args.seq_length - num_image_tokens
+ input_f = open(args.test_path, mode="r", encoding="utf-8")
+ datas = json.load(input_f)
+ try:
+ prompt_overall = prompt_template[args.instruction_template]
+ except KeyError:
+ args.logger.info("unsupported prompt template!")
+ raise NotImplementedError("Unsupported instruction template: {}".format(args.instruction_template))
+
+ with open(args.prediction_path, mode="w", encoding="utf-8") as outf:
+ for line_id, item in enumerate(datas):
+ try:
+ if "datasets" not in item["image"]:
+ image_path = "datasets/llava/" + item["image"]
+ else:
+ image_path = item["image"]
+ if not os.path.isfile(image_path):
+ continue
+ if imghdr.what(image_path) != 'jpeg' and imghdr.what(image_path) != 'png':
+ continue
+ image = Image.open(image_path)
+ if "pad" in args.image_preprocess:
+ image = expand2square(image)
+ image = torch.from_numpy((np.array(image).transpose(2,0,1)))
+ image = image.to(device)
+ src_image = transform(image)
+ except:
+ print("sth wrong with item{}".format(item))
+ continue
+
+ prompt_before_image = prompt_overall + role1 + ": "
+ ground_truth = []
+ prompt_answer_id = []
+ if "conversations" in item:
+ conversations = item["conversations"]
+ for i, conv in enumerate(conversations):
+ # 1 round
+ if i > 1:
+ continue
+ if i == 0:
+ if isinstance(conv, str):
+ prompt = conv
+ else:
+ prompt = conv["value"]
+ if "" in prompt:
+ before_image, after_image = prompt.split("")
+ prompt_before_image = prompt_before_image + before_image + im_start
+ prompt_after_image = im_end + "\n" + after_image + " " + role2 + ":"
+ else:
+ prompt_before_image = prompt_before_image + im_start
+ prompt_after_image = im_end + "\n" + prompt + " " + role2 + ":"
+
+ prompt_before_image_id = args.tokenizer.convert_tokens_to_ids(
+ [CLS_TOKEN] + args.tokenizer.tokenize(prompt_before_image)
+ )
+ prompt_after_image_id = args.tokenizer.convert_tokens_to_ids(
+ args.tokenizer.tokenize(prompt_after_image)
+ )
+ seg_before_image = [1] * len(prompt_before_image_id)
+ seg_after_image = [1] * len(prompt_after_image_id)
+ if len(prompt_before_image_id) + len(prompt_after_image_id) > seq_text:
+ args.logger.info("promt too long, jump for now")
+ break
+ prompt_answer_id = [prompt_before_image_id + prompt_after_image_id]
+ prompt_answer_seg = [seg_before_image + seg_after_image]
+ elif i % 2 == 0: # human
+ prompt = conv["value"]
+ prompt_id = args.tokenizer.convert_tokens_to_ids(
+ args.tokenizer.tokenize(role1 + ":" + prompt + " " + role2 + ":")
+ )
+ if prompt_answer_id:
+ prompt_answer_id.append(prompt_id)
+ prompt_answer_seg.append([1] * len(prompt_id))
+ else:
+ args.logger.info("no prompt, or prompt too long, jumping")
+ break
+ else: # gpt
+ if isinstance(conv, str):
+ answer = conv
+ else:
+ answer = conv["value"]
+ ground_truth.append(answer)
+ else:
+ prompt = item["instruction"]
+ prompt_before_image = prompt_before_image + im_start
+ prompt_after_image = im_end + "\n" + prompt + " " + role2 + ":"
+ prompt_before_image_id = args.tokenizer.convert_tokens_to_ids(
+ args.tokenizer.tokenize(prompt_before_image)
+ )
+ prompt_after_image_id = args.tokenizer.convert_tokens_to_ids(
+ args.tokenizer.tokenize(prompt_after_image)
+ )
+ seg_before_image = [1] * len(prompt_before_image_id)
+ seg_after_image = [1] * len(prompt_after_image_id)
+ if len(prompt_before_image_id) + len(prompt_after_image_id) > seq_text:
+ args.logger.info("promt too long, jump for now")
+ break
+ prompt_answer_id = [prompt_before_image_id + prompt_after_image_id]
+ prompt_answer_seg = [seg_before_image + seg_after_image]
+
+ image_pos = len(prompt_before_image_id)
+
+ # image_tensor = torch.unsqueeze(src_image, 0).half()
+ image_tensor = torch.unsqueeze(src_image, 0).bfloat16()
+ image_seg_tensor = torch.ones(1, num_image_tokens).to(device)
+ image_pos = torch.LongTensor([image_pos]).to(device)
+ SEP_ID = args.tokenizer.convert_tokens_to_ids([SEP_TOKEN])
+ text_tensor = None
+ for i, prompt in enumerate(prompt_answer_id):
+ if text_tensor is None:
+ text_tensor, text_seg_tensor = torch.LongTensor([prompt]).to(device), torch.LongTensor([prompt_answer_seg[i]]).to(device)
+ else:
+ text_tensor = torch.cat([text_tensor, torch.LongTensor([prompt]).to(device)], dim=1)
+ text_seg_tensor = torch.cat([text_seg_tensor, torch.LongTensor([prompt_answer_seg[i]]).to(device)], dim=1)
+
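+ # Sampling loop: draw tokens with top-k / top-p filtering until the SEP token is produced
+ # or the combined image + text length reaches args.seq_length.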
+ while text_tensor.shape[1] + num_image_tokens <= args.seq_length:
+ output = model(text_tensor, text_seg_tensor, image_tensor, image_seg_tensor, image_pos)
+ next_token_logits = output[0][-1] / args.temperature
+ filtered_logits = top_k_top_p_filtering(next_token_logits, args.top_k, args.top_p)
+ next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
+
+ text_tensor = torch.cat([text_tensor, next_token.view(1, 1)], dim=1)
+ text_seg_tensor = torch.cat([text_seg_tensor, torch.tensor([[1]]).to(device)], dim=1)
+ if next_token.cpu().tolist() == SEP_ID:
+ break
+ if rank == 0 and text_tensor is not None:
+ tokens = [token_id.item() for token_id in text_tensor[0]]
+ if args.tokenizer.sp_model is not None:
+ generated_sentence = args.tokenizer.sp_model.decode(tokens)
+ else:
+ generated_sentence = "".join(args.tokenizer.convert_ids_to_tokens(tokens))
+ print(item)
+ print(generated_sentence)
+ print(item, file=outf)
+ print("\n", file=outf)
+ print(generated_sentence + "\n\n", file=outf)
+
diff --git a/tencentpretrain/embeddings/__init__.py b/tencentpretrain/embeddings/__init__.py
index 3eff2e32..a90ae889 100644
--- a/tencentpretrain/embeddings/__init__.py
+++ b/tencentpretrain/embeddings/__init__.py
@@ -8,13 +8,14 @@
from tencentpretrain.embeddings.word_patch_embedding import WordPatchEmbedding
from tencentpretrain.embeddings.speech_embedding import SpeechEmbedding
from tencentpretrain.embeddings.masked_patch_embedding import MaskedPatchEmbedding
+from tencentpretrain.embeddings.vision_language_embedding import VisionLanguageEmbedding
str2embedding = {"word": WordEmbedding, "pos": PosEmbedding, "seg": SegEmbedding,
"sinusoidalpos": SinusoidalposEmbedding, "dual": DualEmbedding,
"patch": PatchEmbedding, "word_patch": WordPatchEmbedding, "speech": SpeechEmbedding,
- "masked_patch": MaskedPatchEmbedding}
+ "masked_patch": MaskedPatchEmbedding, "vision_language": VisionLanguageEmbedding}
__all__ = ["Embedding", "WordEmbedding", "PosEmbedding", "SegEmbedding", "SinusoidalposEmbedding",
"DualEmbedding", "PatchEmbedding", "WordPatchEmbedding", "SpeechEmbedding",
- "MaskedPatchEmbedding", "str2embedding"]
+ "MaskedPatchEmbedding", "VisionLanguageEmbedding", "str2embedding"]
diff --git a/tencentpretrain/embeddings/vision_language_embedding.py b/tencentpretrain/embeddings/vision_language_embedding.py
new file mode 100755
index 00000000..d6d9117a
--- /dev/null
+++ b/tencentpretrain/embeddings/vision_language_embedding.py
@@ -0,0 +1,68 @@
+from argparse import Namespace
+import torch
+import torch.nn as nn
+import copy
+
+from tencentpretrain.embeddings.embedding import Embedding
+from tencentpretrain.embeddings.word_embedding import WordEmbedding
+from tencentpretrain.embeddings.pos_embedding import PosEmbedding
+from tencentpretrain.embeddings.patch_embedding import PatchEmbedding
+from tencentpretrain.encoders import str2encoder,TransformerEncoder
+
+str2embedding = {"word": WordEmbedding, "pos": PosEmbedding, "patch": PatchEmbedding}
+
+
+class VisionLanguageEmbedding(nn.Module):
+ '''
+ A combination of a vision encoder, a vision-to-text projection, and a text embedding.
+ '''
+ def __init__(self, args, vocab_size):
+ super(VisionLanguageEmbedding, self).__init__()
+ # vision model for vision features
+ vision_encoder_args = copy.deepcopy(vars(args))
+ vision_encoder_args.update(args.vision_language_emb["vision_encoder"])
+ vision_encoder_args = Namespace(**vision_encoder_args)
+ self.vision_embedding = Embedding(vision_encoder_args)
+ for embedding_name in vision_encoder_args.embedding:
+ tmp_emb = str2embedding[embedding_name](vision_encoder_args, None)
+ self.vision_embedding.update(tmp_emb, embedding_name)
+ self.vision_encoder = str2encoder[vision_encoder_args.encoder](vision_encoder_args)
+
+ # map the output of vision model into the same space as the text features
+ projection_args = copy.deepcopy(vars(args))
+ projection_args.update(args.vision_language_emb["projection"])
+ projection_args = Namespace(**projection_args)
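+ # With num_mlp_layer = 2 this builds Linear -> GELU -> Linear, mirroring the two linear layers
+ # (indices 0 and 2) of the Huggingface model.mm_projector mapped in the conversion script.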
+ projection_modules = [nn.Linear(vision_encoder_args.emb_size, projection_args.mlp_hidden_size)]
+ for _ in range(1, projection_args.num_mlp_layer):
+ projection_modules.append(nn.GELU())
+ projection_modules.append(nn.Linear(projection_args.mlp_hidden_size, projection_args.mlp_hidden_size))
+ self.projection = nn.Sequential(*projection_modules)
+
+ # text embedding
+ text_args = copy.deepcopy(vars(args))
+ text_args.update(args.vision_language_emb["text"])
+ text_args = Namespace(**text_args)
+ self.text_embedding = Embedding(text_args)
+ for embedding_name in text_args.embedding:
+ tmp_emb = str2embedding[embedding_name](text_args, len(args.tokenizer.vocab))
+ self.text_embedding.update(tmp_emb, embedding_name)
+
+ def forward(self, src, seg=None):
+ src_text, src_image, seg_text, seg_image, image_pos = src
+ # image features
+ with torch.no_grad():
+ image_emb = self.vision_embedding(src_image, seg_image)
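+ # Use the hidden states of the penultimate encoder layer and drop the CLS position, as in LLaVA.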
+ image_emb = self.vision_encoder(image_emb, seg_image, output_layer=-2)[:,1:,:]
+ image_emb = self.projection(image_emb)
+ # text embedding
+ text_emb = self.text_embedding(src_text, seg_text)
+ # combine text and image
+ if text_emb.shape[0] == 1:
+ emb = torch.cat((text_emb[:,:image_pos[0],:], image_emb, text_emb[:,image_pos[0]:,:]), 1)
+ else:
+ emb = torch.cat((text_emb[0,:image_pos[0],:], image_emb[0], text_emb[0,image_pos[0]:,:]), 0).unsqueeze(0)
+ for i in range(1, text_emb.shape[0]):
+ tmp = torch.cat((text_emb[i,:image_pos[i],:], image_emb[i], text_emb[i,image_pos[i]:,:]), 0).unsqueeze(0)
+ emb = torch.cat((emb, tmp), 0)
+
+ return emb
diff --git a/tencentpretrain/encoders/transformer_encoder.py b/tencentpretrain/encoders/transformer_encoder.py
index f3dd6531..309bf5c2 100644
--- a/tencentpretrain/encoders/transformer_encoder.py
+++ b/tencentpretrain/encoders/transformer_encoder.py
@@ -63,7 +63,7 @@ def __init__(self, args):
self.freqs_cis = precompute_freqs_cis(args.hidden_size // args.heads_num, args.max_seq_length * 2)
- def forward(self, emb, seg):
+ def forward(self, emb, seg, output_layer=-1):
"""
Args:
emb: [batch_size x seq_length x emb_size]
@@ -129,12 +129,16 @@ def custom_forward(*inputs):
while l < self.layers_num:
inputs = checkpointing.checkpoint(custom(l, l + self.deepspeed_checkpoint_layers_num), inputs)
l += self.deepspeed_checkpoint_layers_num
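+ # When output_layer is negative (e.g. -2, as used by the LLaVA vision encoder), return the
+ # hidden states of that intermediate layer instead of running the remaining layers.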
+ if output_layer != -1 and l == self.layers_num + output_layer:
+ return inputs[0]
else:
for i in range(self.layers_num):
if self.parameter_sharing:
inputs = self.transformer(inputs)
else:
inputs = self.transformer[i](inputs)
+ if output_layer != -1 and i == self.layers_num + output_layer:
+ return inputs[0]
hidden = inputs[0]
diff --git a/tencentpretrain/models/model.py b/tencentpretrain/models/model.py
index f7b30d65..c46abe1b 100755
--- a/tencentpretrain/models/model.py
+++ b/tencentpretrain/models/model.py
@@ -29,6 +29,18 @@ def __init__(self, args, embedding, encoder, tgt_embedding, decoder, target):
if self.decoder is not None and args.share_embedding:
self.tgt_embedding.word.embedding.weight = self.embedding.word.embedding.weight
+ if args.freeze_parameters:
+ name_mapping = {
+ "embedding": self.embedding, "encoder": self.encoder, "tgt_embedding": self.tgt_embedding,
+ "decoder": self.decoder, "target": self.target
+ }
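+ # Freeze all parameters of the listed modules, except those whose names contain freeze_exclude_by_name.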
+ for freeze_name in args.freeze_parameters:
+ if name_mapping[freeze_name] is None:
+ continue
+ for name, param in name_mapping[freeze_name].named_parameters():
+ if args.freeze_exclude_by_name == "" or args.freeze_exclude_by_name not in name:
+ param.requires_grad = False
+
def forward(self, src, tgt, seg, tgt_in=None, tgt_seg=None):
emb = self.embedding(src, seg)
memory_bank = self.encoder(emb, seg)
diff --git a/tencentpretrain/opts.py b/tencentpretrain/opts.py
index 57b03cde..88ab02da 100755
--- a/tencentpretrain/opts.py
+++ b/tencentpretrain/opts.py
@@ -52,7 +52,10 @@ def model_opts(parser):
help="whether use alibi position embedding.")
parser.add_argument("--layer_number_scale", action="store_true",
help="whether use layer number scaling.")
-
+ parser.add_argument("--freeze_parameters", choices=["embedding", "encoder", "tgt_embedding", "decoder", "target"],
+ default="", nargs='+', help="Which module to be frozen during training.")
+ parser.add_argument("--freeze_exclude_by_name", type=str, default="",
+ help="Exclude some modules with the specific string in the name when freezing parameters.")
vision_opts(parser)
audio_opts(parser)
@@ -67,7 +70,7 @@ def vision_opts(parser):
parser.add_argument("--channels_num", type=int, default=3,
help="Channels num.")
parser.add_argument("--image_preprocess", type=str, default=["crop", "normalize"], nargs='+',
- help="Preprocess and data augmentation methods. Choices: [\"crop\", \"horizontal_flip\", \"normalize\"]. ")
+ help="Preprocess and data augmentation methods. Choices: [\"crop\" or \"center_crop\" or \"pad\", \"horizontal_flip\", \"normalize\"]. ")
def audio_opts(parser):
diff --git a/tencentpretrain/trainer.py b/tencentpretrain/trainer.py
index 5ee00ac3..105ce750 100755
--- a/tencentpretrain/trainer.py
+++ b/tencentpretrain/trainer.py
@@ -1,6 +1,7 @@
import os
import json
import time
+import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
@@ -93,6 +94,9 @@ def init_model(args):
else:
model_for_dataloader = None
+ if args.vision_model_in_VL_emb_path is not None:
+ args.logger.info("loading: {}".format(args.vision_model_in_VL_emb_path))
+ model_for_training = _load_state_dict_into_model(model_for_training, args.vision_model_in_VL_emb_path)
return model_for_training, model_for_dataloader
@@ -649,12 +653,23 @@ class LlmSftTrainer(LmTrainer):
pass
+class LlavaTrainer(LmTrainer):
+ def forward_propagation(self, batch, model):
+ src_text, src_img, tgt, seg_text, seg_img, seg_tgt, image_pos = batch
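+ # Drop the CLS position from the image segment so seg matches the concatenated
+ # image-patch + text sequence produced by the vision-language embedding.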
+ seg = torch.cat((seg_img[:,1:], seg_text), 1)
+ loss = model((src_text, src_img, seg_text, seg_img, image_pos), tgt, seg, tgt_seg=seg_tgt)
+
+ self.total_loss += loss.item()
+ loss = loss / self.accumulation_steps
+ return loss
+
+
str2trainer = {"bert": BertTrainer, "mlm": MlmTrainer, "lm": LmTrainer,
"albert": AlbertTrainer, "bilm": BilmTrainer, "cls": ClsTrainer,
"mt": MtTrainer, "t5": T5Trainer, "gsg": GsgTrainer,
"bart": BartTrainer, "prefixlm": PrefixlmTrainer, "cls_mlm": ClsMlmTrainer,
"vit": VitTrainer, "vilt": ViltTrainer, "clip": ClipTrainer, "s2t": S2tTrainer,
- "beit": BeitTrainer, "dalle": DalleTrainer, "llm_sft": LlmSftTrainer}
+ "beit": BeitTrainer, "dalle": DalleTrainer, "llm_sft": LlmSftTrainer, "llava": LlavaTrainer}
def worker(local_rank, gpu_ranks, args):
@@ -676,6 +691,9 @@ def worker(local_rank, gpu_ranks, args):
# Build model.
model_for_training, model_for_dataloader = init_model(args)
+ if global_rank == 0:
+ args.logger.info("model: {}".format(model_for_training))
+
# Build optimizer.
custom_optimizer, custom_scheduler, optimizer_grouped_parameters = init_optimizer(args, model_for_training)
diff --git a/tencentpretrain/utils/__init__.py b/tencentpretrain/utils/__init__.py
index efa947e4..385ddba1 100644
--- a/tencentpretrain/utils/__init__.py
+++ b/tencentpretrain/utils/__init__.py
@@ -13,15 +13,17 @@
"t5": T5Dataset, "gsg": GsgDataset, "bart": BartDataset,
"cls": ClsDataset, "prefixlm": PrefixlmDataset, "cls_mlm": ClsMlmDataset,
"vit": VitDataset, "vilt": ViltDataset, "clip": ClipDataset, "s2t": S2tDataset,
- "beit":BeitDataset, "dalle": DalleDataset, "llm_sft": LlmSftDataset, "llm_pretrain": LlmPretrainDataset}
+ "beit":BeitDataset, "dalle": DalleDataset, "llm_sft": LlmSftDataset,
+ "llm_pretrain": LlmPretrainDataset, "llava": LlavaDataset}
str2dataloader = {"bert": BertDataloader, "lm": LmDataloader, "mlm": MlmDataloader,
"bilm": BilmDataloader, "albert": AlbertDataloader, "mt": MtDataloader,
"t5": T5Dataloader, "gsg": GsgDataloader, "bart": BartDataloader,
"cls": ClsDataloader, "prefixlm": PrefixlmDataloader, "cls_mlm": ClsMlmDataloader,
- "vit": VitDataloader, "vilt": ViltDataloader, "clip": ClipDataloader, "s2t": S2tDataloader,
- "beit":BeitDataloader, "dalle": DalleDataloader, "llm_sft": LlmSftDataloader}
+ "vit": VitDataloader, "vilt": ViltDataloader, "clip": ClipDataloader,
+ "s2t": S2tDataloader, "beit":BeitDataloader, "dalle": DalleDataloader,
+ "llm_sft": LlmSftDataloader, "llava":LlavaDataloader}
-str2act = {"gelu": gelu, "gelu_fast": gelu_fast, "relu": relu, "silu": silu, "linear": linear}
+str2act = {"gelu": gelu, "gelu_fast": gelu_fast, "relu": relu, "silu": silu, "linear": linear, "gelu_quick": gelu_quick}
str2optimizer = {"adamw": AdamW, "adafactor": Adafactor}
@@ -38,11 +40,12 @@
"BertDataset", "LmDataset", "MlmDataset", "BilmDataset",
"AlbertDataset", "MtDataset", "T5Dataset", "GsgDataset",
"BartDataset", "ClsDataset", "PrefixlmDataset", "ClsMlmDataset",
- "VitDataset", "ViltDataset", "ClipDataset", "BeitDataset", "DalleDataset", "LlmSftDataset", "str2dataset",
- "BertDataloader", "LmDataloader", "MlmDataloader", "BilmDataloader",
+ "VitDataset", "ViltDataset", "ClipDataset", "BeitDataset", "DalleDataset", "LlmSftDataset", "LlavaDataset",
+ "str2dataset", "BertDataloader", "LmDataloader", "MlmDataloader", "BilmDataloader",
"AlbertDataloader", "MtDataloader", "T5Dataloader", "GsgDataloader",
"BartDataloader", "ClsDataloader", "PrefixlmDataloader", "ClsMlmDataloader",
- "VitDataloader", "ViltDataloader", "ClipDataloader", "BeitDataloader", "DalleDataloader", "LlmSftDataloader", "str2dataloader",
+ "VitDataloader", "ViltDataloader", "ClipDataloader", "BeitDataloader", "DalleDataloader",
+ "LlmSftDataloader", "LlavaDataloader", "str2dataloader",
"gelu", "gelu_fast", "relu", "silu", "linear", "str2act",
"AdamW", "Adafactor", "str2optimizer",
"get_linear_schedule_with_warmup", "get_cosine_schedule_with_warmup",
diff --git a/tencentpretrain/utils/act_fun.py b/tencentpretrain/utils/act_fun.py
index be306a6b..125607c2 100644
--- a/tencentpretrain/utils/act_fun.py
+++ b/tencentpretrain/utils/act_fun.py
@@ -10,6 +10,12 @@ def gelu(x):
def gelu_fast(x):
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
+def gelu_quick(x):
+ """
+ Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
+ """
+ return x * torch.sigmoid(1.702 * x)
+
def relu(x):
return F.relu(x)
diff --git a/tencentpretrain/utils/dataloader.py b/tencentpretrain/utils/dataloader.py
index b7269518..d1fd48ef 100755
--- a/tencentpretrain/utils/dataloader.py
+++ b/tencentpretrain/utils/dataloader.py
@@ -546,6 +546,7 @@ def __init__(self, args, dataset_path, batch_size, global_rank, world_size, loca
self.patch_size = args.patch_size
self.image_height = args.image_height
self.image_width = args.image_width
+ self.args = args
from torchvision import transforms
from tencentpretrain.utils.misc import ZeroOneNormalize
@@ -553,6 +554,9 @@ def __init__(self, args, dataset_path, batch_size, global_rank, world_size, loca
preprocess_pipeline = []
if "corp" in args.image_preprocess:
preprocess_pipeline.append(transforms.RandomResizedCrop(max(self.image_height, self.image_width)))
+ elif "center_crop" in args.image_preprocess:
+ preprocess_pipeline.append(transforms.Resize(min(self.image_height, self.image_width)))
+ preprocess_pipeline.append(transforms.CenterCrop((self.image_height, self.image_width)))
if "horizontal_flip" in args.image_preprocess:
preprocess_pipeline.append(transforms.RandomHorizontalFlip())
preprocess_pipeline.append(transforms.Resize((self.image_height, self.image_width)))
@@ -963,3 +967,93 @@ def __iter__(self):
yield torch.LongTensor(src), \
torch.LongTensor(tgt), \
torch.LongTensor(seg)
+
+
+class LlavaDataloader(VisionDataloader):
+
+ def __iter__(self):
+ """
+ instances: ((src, tgt), (seg_src_nums, seg_tgt_nums), (src_image, image_pos))
+ src, tgt: Token ids of the text sample
+ seg_src_nums, seg_tgt_nums: Segment length counts of the text sample
+ src_image: Path of the image sample
+ image_pos: Position of the image in the text sample
+
+ Returns:
+ src_text: [batch_size x text_seq_length]
+ src_image: [batch_size x channels_num x image_height x image_width]
+ tgt: [batch_size x seq_length]
+ seg_text: [batch_size x text_seq_length]
+ seg_image: [batch_size x (patch_num + 1)]
+ seg_tgt: [batch_size x seq_length]
+ image_pos: [batch_size]
+
+ """
+ from torchvision.io import read_image
+ from torchvision.io.image import ImageReadMode
+ from tencentpretrain.utils.misc import expand2square
+
+ seg_image_num = (self.image_height // self.patch_size) * (self.image_width // self.patch_size)
+ while True:
+ while self._empty():
+ self._fill_buf()
+ if self.start + self.batch_size >= self.end:
+ instances = self.buffer[self.start:]
+ else:
+ instances = self.buffer[self.start: self.start + self.batch_size]
+
+ self.start += self.batch_size
+
+ src_text = []
+ src_image = []
+ tgt = []
+ seg_text = []
+ seg_image = []
+ seg_tgt = []
+ image_pos = []
+
+ for ins in instances:
+ ins_src, ins_tgt = ins[0]
+ ins_seg_nums_src, ins_seg_nums_tgt = ins[1]
+ ins_src_image, ins_image_pos = ins[2]
+ seq_length = len(ins_src)
+ text_seq_length = seq_length - seg_image_num
+
+ try:
+ if "pad" in self.args.image_preprocess:
+ from PIL import Image
+ import numpy as np
+ import torchvision.transforms.functional as transform
+ image = Image.open(ins_src_image)
+ image = expand2square(image)
+ image = torch.from_numpy((np.array(image).transpose(2,0,1)))
+ else:
+ image = read_image(ins_src_image, ImageReadMode.RGB)
+ except:
+ print("Something is wrong when reading {}, just skipped!".format(ins_src_image))
+ continue
+ image = image.cuda(self.local_rank)
+ src_image.append(self.transform(image))
+ seg_image.append([1] * (seg_image_num + 1))
+ image_pos.append(ins_image_pos)
+
+ src_text.append(ins_src[:text_seq_length])
+ ins_seg_src = [1] * ins_seg_nums_src[0] + [0] * ins_seg_nums_src[1]
+ seg_text.append(ins_seg_src[:text_seq_length])
+
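+ # Prepend PAD targets for the image-patch positions so tgt stays aligned with the
+ # concatenated image + text sequence fed to the encoder.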
+ ins_tgt_new = [self.vocab.get(PAD_TOKEN)] * seg_image_num + ins_tgt
+ tgt.append(ins_tgt_new[:seq_length])
+ ins_seg_tgt = [0] * seg_image_num
+ for i, num in enumerate(ins_seg_nums_tgt):
+ ins_seg_tgt = ins_seg_tgt + [i % 2] * num
+ seg_tgt.append(ins_seg_tgt[:seq_length])
+
+ if len(src_image) == 0:
+ continue
+ yield torch.LongTensor(src_text), \
+ torch.stack(src_image, 0).half(), \
+ torch.LongTensor(tgt), \
+ torch.LongTensor(seg_text), \
+ torch.LongTensor(seg_image), \
+ torch.LongTensor(seg_tgt), \
+ image_pos
diff --git a/tencentpretrain/utils/dataset.py b/tencentpretrain/utils/dataset.py
index 9c7a682b..ea927b2e 100755
--- a/tencentpretrain/utils/dataset.py
+++ b/tencentpretrain/utils/dataset.py
@@ -57,6 +57,7 @@ def __init__(self, args, vocab, tokenizer):
self.span_max_length = args.span_max_length
self.docs_buffer_size = args.docs_buffer_size
self.dup_factor = args.dup_factor
+ self.args = args
def build_and_save(self, workers_num):
"""
@@ -1083,3 +1084,109 @@ def worker(self, proc_id, start, end):
break
dataset_writer.close()
+
+
+class LlavaDataset(Dataset):
+ def worker(self, proc_id, start, end):
+ import json
+ PAD_ID = self.tokenizer.convert_tokens_to_ids([PAD_TOKEN])[0]
+ role1, role2 = "USER", "ASSISTANT"
+ im_start, im_end = " ", ""
+ print("Worker %d is building dataset ... " % proc_id)
+ set_seed(self.seed)
+ dataset_writer = open("dataset-tmp-" + str(proc_id) + ".pt", "wb")
+ pos = start
+ skip_item = 0
+ with open(self.corpus_path, mode="r", encoding="utf-8") as f:
+ data = json.load(f)
+ while True:
+ item = data[pos]
+ pos += 1
+ try:
+ path = item["image"]
+ if not os.path.isfile(path):
+ continue
+ except:
+ skip_item += 1
+ continue
+ conversations = item["conversations"]
+ if "instruction" in item.keys():
+ inst = item["instruction"]
+ else:
+ inst = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
+ if inst:
+ prompt_before_image = inst + " " + role1 + ":"
+ else:
+ prompt_before_image = role1 + ":"
+ prompt_answer_seg_nums, tgt_seg_nums = [], []
+ for i, conv in enumerate(conversations):
+ if i == 0:
+ if isinstance(conv, str):
+ prompt = conv
+ else:
+ prompt = conv["value"]
+ if "" in prompt:
+ before_image, after_image = prompt.split("")
+ prompt_before_image = prompt_before_image + before_image + im_start
+ prompt_after_image = im_end + "\n" + after_image + " " + role2 + ":"
+ else:
+ prompt_before_image = prompt_before_image + im_start
+ prompt_after_image = im_end + "\n" + prompt + " " + role2 + ":"
+ prompt_before_image_id = self.tokenizer.convert_tokens_to_ids([CLS_TOKEN] + self.tokenizer.tokenize(prompt_before_image))
+ prompt_after_image_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(prompt_after_image))
+ seg_before_image = [1] * len(prompt_before_image_id)
+ seg_after_image = [1] * len(prompt_after_image_id)
+ if len(prompt_before_image_id) + len(prompt_after_image_id) > self.seq_length:
+ print("promt too long, jumped")
+ continue
+ prompt_answer_id = prompt_before_image_id + prompt_after_image_id
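+ # Prompt positions get PAD targets so that only answer tokens contribute to the loss;
+ # the -1 accounts for the one-token shift between inputs and next-token targets.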
+ tgt_id = [PAD_ID] * (len(prompt_answer_id) - 1)
+ tgt_seg_nums = [len(tgt_id)]
+ elif i % 2 == 0: # human
+ if isinstance(conv, str):
+ prompt = conv
+ else:
+ prompt = conv["value"]
+ prompt_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(role1 + ":" + prompt + " " + role2 + ":"))
+ prompt_answer_id = prompt_answer_id + prompt_id
+ tgt_id = tgt_id + [PAD_ID] * len(prompt_id)
+ if len(tgt_seg_nums) == 1:
+ tgt_seg_nums[0] = tgt_seg_nums[0] + len(prompt_id)
+ else:
+ tgt_seg_nums = tgt_seg_nums + [len(prompt_id)]
+ else: # gpt
+ if isinstance(conv, str):
+ answer = conv
+ else:
+ answer = conv["value"]
+ answer_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(answer) + [SEP_TOKEN])
+ prompt_answer_id = prompt_answer_id + answer_id
+ tgt_id = tgt_id + answer_id
+ tgt_seg_nums = tgt_seg_nums + [len(answer_id)]
+
+ if len(tgt_id) > self.seq_length:
+ tgt_id = tgt_id[:self.seq_length]
+ pad_num = self.seq_length - len(tgt_id)
+ tgt_id = tgt_id + [PAD_ID] * pad_num
+ while sum(tgt_seg_nums) > self.seq_length:
+ tgt_seg_nums = tgt_seg_nums[:-1]
+ pad_num = self.seq_length - sum(tgt_seg_nums)
+ tgt_seg_nums = tgt_seg_nums + [pad_num]
+
+ if len(prompt_answer_id) > self.seq_length :
+ prompt_answer_id = prompt_answer_id[:self.seq_length]
+
+ pad_num = self.seq_length - len(prompt_answer_id)
+ prompt_answer_seg_nums = [len(prompt_answer_id), pad_num]
+ prompt_answer_id = prompt_answer_id + [PAD_ID] * pad_num
+
+ image_pos = len(prompt_before_image_id)
+ src = (prompt_answer_id, tgt_id)
+ seg_nums = (prompt_answer_seg_nums, tgt_seg_nums)
+ image = (path, image_pos)
+ pickle.dump((src, seg_nums, image), dataset_writer)
+
+ if pos >= end:
+ break
+
+ dataset_writer.close()
diff --git a/tencentpretrain/utils/misc.py b/tencentpretrain/utils/misc.py
index 01545650..eea8541f 100644
--- a/tencentpretrain/utils/misc.py
+++ b/tencentpretrain/utils/misc.py
@@ -4,6 +4,11 @@
def count_lines(file_path):
lines_num = 0
+ if file_path.endswith(".json"):
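+ # JSON corpora (e.g. the LLaVA instruction data) hold one list of samples, so its length is the sample count.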
+ import json
+ with open(file_path, 'rb') as f:
+ data = json.load(f)
+ return len(data)
with open(file_path, 'rb') as f:
while True:
data = f.read(2 ** 20)
@@ -34,6 +39,24 @@ def pooling(memory_bank, seg, pooling_type):
features = memory_bank[:, 0, :]
return features
+
class ZeroOneNormalize(object):
def __call__(self, img):
return img.float().div(255)
+
+
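+# Pad a PIL image to a square canvas; the default background color is the CLIP pixel mean
+# (0.481, 0.458, 0.408) scaled to the 0-255 range.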
+def expand2square(img, background_color=(122, 116, 104)):
+ from PIL import Image
+ width, height = img.size
+ if img.mode != "RGB":
+ img = img.convert("RGB")
+ if width == height:
+ return img
+ elif width > height:
+ result = Image.new(img.mode, (width, width), background_color)
+ result.paste(img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(img.mode, (height, height), background_color)
+ result.paste(img, ((height - width) // 2, 0))
+ return result