support LLaVa #119

Open · wants to merge 18 commits into base: main
54 changes: 54 additions & 0 deletions models/llava/7b_config.json
@@ -0,0 +1,54 @@
{
  "embedding": ["vision_language"],
  "vision_language_emb": {
    "vision_encoder": {
      "image_height": 336,
      "image_width": 336,
      "patch_size": 14,
      "emb_size": 1024,
      "feedforward_size": 4096,
      "hidden_size": 1024,
      "hidden_act": "gelu_quick",
      "heads_num": 16,
      "layers_num": 24,
      "dropout": 0.0,
      "max_seq_length": 577,
      "embedding": ["patch", "pos"],
      "patch_proj_bias": false,
      "remove_embedding_layernorm": false,
      "remove_transformer_bias": false,
      "rotary_position_embedding": false,
      "encoder": "transformer",
      "feed_forward": "dense",
      "mask": "fully_visible",
      "layernorm_positioning": "pre",
      "layernorm": "normal",
      "has_cls": false
    },
    "projection": {
      "mlp_hidden_size": 4096,
      "num_mlp_layer": 2
    },
    "text": {
      "embedding": ["word"]
    }
  },
  "emb_size": 4096,
  "feedforward_size": 11008,
  "hidden_size": 4096,
  "hidden_act": "silu",
  "heads_num": 32,
  "layers_num": 32,
  "dropout": 0.0,
  "data_processor": "llava",
  "max_seq_length": 2048,
  "remove_transformer_bias": true,
  "remove_embedding_layernorm": true,
  "rotary_position_embedding": true,
  "encoder": "transformer",
  "feed_forward": "gated",
  "mask": "causal",
  "layernorm_positioning": "pre",
  "layernorm": "rms",
  "target": ["lm"]
}
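For context, a minimal sketch of what the "projection" block above describes: a two-layer MLP (num_mlp_layer = 2, mlp_hidden_size = 4096) that maps the 1024-dim vision encoder output into the 4096-dim LLM embedding space. The class name is illustrative and not taken from this PR; the Linear indices 0 and 2 line up with the mm_projector.{0,2} weights converted by the script further down.

```python
import torch.nn as nn

# Illustrative sketch only, assuming the projector layout implied by
# 7b_config.json and the mm_projector.{0,2} keys converted below.
class VisionLanguageProjector(nn.Module):
    def __init__(self, vision_emb_size=1024, llm_emb_size=4096, mlp_hidden_size=4096):
        super().__init__()
        self.projection = nn.Sequential(
            nn.Linear(vision_emb_size, mlp_hidden_size),  # mm_projector.0
            nn.GELU(),                                     # mm_projector.1 (no weights)
            nn.Linear(mlp_hidden_size, llm_emb_size),      # mm_projector.2
        )

    def forward(self, patch_embeddings):
        # patch_embeddings: [batch, 576, 1024] -> [batch, 576, 4096]
        return self.projection(patch_embeddings)
```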
4 changes: 3 additions & 1 deletion preprocess.py
@@ -28,12 +28,14 @@ def main():
parser.add_argument("--data_processor",
choices=["bert", "lm", "mlm", "bilm", "albert", "mt", "t5", "cls", "prefixlm",
"gsg", "bart", "cls_mlm", "vit", "vilt", "clip", "s2t", "beit", "dalle",
"llm_pretrain", "llm_sft"], default="bert",
"llm_pretrain", "llm_sft", "llava"], default="bert",
help="The data processor of the pretraining model.")
parser.add_argument("--docs_buffer_size", type=int, default=100000,
help="The buffer size of documents in memory, specific to targets that require negative sampling.")
parser.add_argument("--seq_length", type=int, default=128, help="Sequence length of instances.")
parser.add_argument("--tgt_seq_length", type=int, default=128, help="Target sequence length of instances.")
parser.add_argument("--vision_seq_length_in_VL", type=int, default=576,
help="Number of image patches in the vision language model(LLaVa).")
parser.add_argument("--dup_factor", type=int, default=5,
help="Duplicate instances multiple times.")
parser.add_argument("--short_seq_prob", type=float, default=0.1,
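The default of 576 for --vision_seq_length_in_VL presumably follows from the vision encoder settings in models/llava/7b_config.json above (336x336 input, 14x14 patches); a quick check:

```python
# 576 image patches: (336 / 14) * (336 / 14) = 24 * 24
image_height = image_width = 336
patch_size = 14
num_patches = (image_height // patch_size) * (image_width // patch_size)
assert num_patches == 576
```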
2 changes: 2 additions & 0 deletions pretrain.py
@@ -13,6 +13,8 @@ def main():
help="Path of the preprocessed dataset.")
parser.add_argument("--pretrained_model_path", type=str, default=None,
help="Path of the pretrained model.")
parser.add_argument("--vision_model_in_VL_emb_path", type=str, default=None,
help="Path of the vision pretrained model in the vision language embedding.")
parser.add_argument("--output_model_path", type=str, required=True,
help="Path of the output model.")
parser.add_argument("--config_path", type=str, default="models/bert/base_config.json",
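The diff shown here does not spell out how the checkpoint behind --vision_model_in_VL_emb_path is laid out; presumably it holds the vision tower weights under the vision-language embedding namespace, for example after renaming them with scripts/convert_model_add_prefix.py below. A sanity-check sketch, where both the file name and the key prefix are assumptions:

```python
import torch

# Assumption: the vision checkpoint's parameter names carry the
# embedding.vision_language.vision_encoder. prefix; neither the file name
# nor the prefix is stated in this PR.
vision_ckpt = torch.load("models/vision_model_for_llava.bin", map_location="cpu")
print(all(k.startswith("embedding.vision_language.vision_encoder.") for k in vision_ckpt))
```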
77 changes: 77 additions & 0 deletions scripts/convert_llava_from_huggingface_to_tencentpretrain.py
@@ -0,0 +1,77 @@
import argparse
import collections
import torch
import os
import json


parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--input_model_path", type=str, default="models/llava-v1.5-7b/",
help=".")
parser.add_argument("--output_model_path", type=str, default="models/llava-v1.5-7b.bin",
help=".")
parser.add_argument("--type", choices=["7B", "13B", "33B", "65B"], default="7B")

args = parser.parse_args()

model_config = {"7B" : [32, 4096, 32],
                "13B": [40, 5120, 40],
                "33B": [60, 6656, 52],
                "65B": [80, 8192, 64]
                }

layers_num, dim, n_heads = model_config[args.type]

files = os.listdir(args.input_model_path)
model_files = [f for f in files if f[-4:] == ".bin"]
input_models = {f: torch.load(os.path.join(args.input_model_path, f), map_location="cpu") for f in model_files}

with open(os.path.join(args.input_model_path, "pytorch_model.bin.index.json")) as f:
    model_index = json.load(f)
    weight_map = model_index["weight_map"]


output_model = collections.OrderedDict()

def get_weight_from_name(layer_name):
    return input_models[weight_map[layer_name]][layer_name]

def unpermute(w):
    # Undo the row interleaving that HuggingFace LLaMA checkpoints apply to the
    # query/key projection weights for their rotary embedding implementation.
    return w.reshape(n_heads, 2, dim // n_heads // 2, dim).transpose(2, 1).reshape(dim, dim)

output_model["embedding.vision_language.text_embedding.word.embedding.weight"] = get_weight_from_name("model.embed_tokens.weight")

for i in range(layers_num):

    output_model["encoder.transformer." + str(i) + ".self_attn.linear_layers.0.weight"] = \
        unpermute(get_weight_from_name("model.layers." + str(i) + ".self_attn.q_proj.weight"))
    output_model["encoder.transformer." + str(i) + ".self_attn.linear_layers.1.weight"] = \
        unpermute(get_weight_from_name("model.layers." + str(i) + ".self_attn.k_proj.weight"))

    output_model["encoder.transformer." + str(i) + ".self_attn.linear_layers.2.weight"] = \
        get_weight_from_name("model.layers." + str(i) + ".self_attn.v_proj.weight")
    output_model["encoder.transformer." + str(i) + ".self_attn.final_linear.weight"] = \
        get_weight_from_name("model.layers." + str(i) + ".self_attn.o_proj.weight")

    output_model["encoder.transformer." + str(i) + ".layer_norm_1.weight"] = \
        get_weight_from_name("model.layers." + str(i) + ".input_layernorm.weight")

    output_model["encoder.transformer." + str(i) + ".feed_forward.linear_gate.weight"] = \
        get_weight_from_name("model.layers." + str(i) + ".mlp.gate_proj.weight")
    output_model["encoder.transformer." + str(i) + ".feed_forward.linear_1.weight"] = \
        get_weight_from_name("model.layers." + str(i) + ".mlp.up_proj.weight")
    output_model["encoder.transformer." + str(i) + ".feed_forward.linear_2.weight"] = \
        get_weight_from_name("model.layers." + str(i) + ".mlp.down_proj.weight")

    output_model["encoder.transformer." + str(i) + ".layer_norm_2.weight"] = \
        get_weight_from_name("model.layers." + str(i) + ".post_attention_layernorm.weight")

output_model["encoder.layer_norm.weight"] = get_weight_from_name("model.norm.weight")
output_model["target.lm.output_layer.weight"] = get_weight_from_name("lm_head.weight")

output_model["embedding.vision_language.projection.0.weight"] = get_weight_from_name("model.mm_projector.0.weight")
output_model["embedding.vision_language.projection.0.bias"] = get_weight_from_name("model.mm_projector.0.bias")
output_model["embedding.vision_language.projection.2.weight"] = get_weight_from_name("model.mm_projector.2.weight")
output_model["embedding.vision_language.projection.2.bias"] = get_weight_from_name("model.mm_projector.2.bias")

torch.save(output_model, args.output_model_path)
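As a quick check that the conversion wrote the expected parameter names, one could load the output and look for a few of the keys assigned above (a sketch, using the script's default output path):

```python
import torch

state_dict = torch.load("models/llava-v1.5-7b.bin", map_location="cpu")
for key in ("embedding.vision_language.text_embedding.word.embedding.weight",
            "embedding.vision_language.projection.0.weight",
            "encoder.transformer.0.self_attn.linear_layers.0.weight",
            "target.lm.output_layer.weight"):
    assert key in state_dict, key
print(len(state_dict), "tensors converted")
```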
30 changes: 30 additions & 0 deletions scripts/convert_llm_in_llava.py
@@ -0,0 +1,30 @@
import argparse
import collections
import torch


def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--input_model_path", type=str, default="models/input_model.bin",
                        help="Path of the input language model in TencentPretrain format.")
    parser.add_argument("--output_model_path", type=str, default="models/output_model.bin",
                        help="Path of the output model whose word embedding is renamed for the LLaVA vision-language embedding.")

    args = parser.parse_args()

    input_model = torch.load(args.input_model_path, map_location="cpu")

    output_model = collections.OrderedDict()

    # Move the word embedding under the vision-language embedding namespace;
    # all other parameters are copied unchanged.
    for k in input_model.keys():
        if k == "embedding.word.embedding.weight":
            output_model["embedding.vision_language.text_embedding.word.embedding.weight"] = input_model[k]
        else:
            output_model[k] = input_model[k]

    torch.save(output_model, args.output_model_path)


if __name__ == "__main__":
    main()
28 changes: 28 additions & 0 deletions scripts/convert_model_add_prefix.py
@@ -0,0 +1,28 @@
import argparse
import collections
import torch


def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--input_model_path", type=str, default="models/input_model.bin",
                        help="Path of the input model.")
    parser.add_argument("--output_model_path", type=str, default="models/output_model.bin",
                        help="Path of the output model.")
    parser.add_argument("--prefix", type=str, default="", help="Prefix to prepend to every parameter name.")

    args = parser.parse_args()

    input_model = torch.load(args.input_model_path, map_location="cpu")

    output_model = collections.OrderedDict()

    for k in input_model.keys():
        output_model[args.prefix + k] = input_model[k]

    torch.save(output_model, args.output_model_path)


if __name__ == "__main__":
    main()
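For illustration, the renaming this script performs, using a hypothetical prefix for the LLaVA vision tower (the exact prefix to use with this PR is an assumption, not stated here):

```python
import collections
import torch

prefix = "embedding.vision_language.vision_encoder."  # assumed value, see note above
dummy = collections.OrderedDict({"encoder.transformer.0.layer_norm_1.weight": torch.zeros(1)})
renamed = collections.OrderedDict((prefix + k, v) for k, v in dummy.items())
print(list(renamed.keys()))
# ['embedding.vision_language.vision_encoder.encoder.transformer.0.layer_norm_1.weight']
```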