support llava1.5 lora finetuning. #1487

Open · wants to merge 8 commits into main
34 changes: 34 additions & 0 deletions examples/image-to-text/README.md
@@ -375,6 +375,40 @@ python3 ../gaudi_spawn.py \
--lora_target_modules '".*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"'
```

Here is a single-card training command example for llava-hf/llava-1.5-7b-hf.

```
python3 run_image2text_lora_finetune.py \
--model_name_or_path llava-hf/llava-1.5-7b-hf \
--dataset_name nielsr/docvqa_1200_examples \
--bf16 True \
--output_dir ./model_lora_llava \
--num_train_epochs 1 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 8 \
--weight_decay 0.01 \
--logging_steps 25 \
--eval_strategy "no" \
--save_strategy "no" \
--learning_rate 5e-5 \
--warmup_steps 50 \
--lr_scheduler_type "constant" \
--input_column_names 'image' 'query' \
--output_column_names 'answers' \
--remove_unused_columns False \
--do_train \
--do_eval \
--use_habana \
--use_lazy_mode \
--lora_rank=8 \
--lora_alpha=8 \
--lora_dropout=0.1 \
--max_seq_length=512 \
--use_hpu_graphs_for_inference \
--low_cpu_mem_usage True
```
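
If the resulting adapter weights end up under `--output_dir` (`./model_lora_llava` in the command above), they could be loaded back onto the base model for inference. A minimal sketch, not part of this PR; the save location and dtype are assumptions:

```
import torch
from peft import PeftModel
from transformers import AutoModelForVision2Seq, AutoProcessor

# Load the base model and attach the LoRA adapter produced by the training run (path assumed).
base_model = AutoModelForVision2Seq.from_pretrained("llava-hf/llava-1.5-7b-hf", torch_dtype=torch.bfloat16)
model = PeftModel.from_pretrained(base_model, "./model_lora_llava")
model = model.merge_and_unload()  # optionally fold the LoRA weights into the base model

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
```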

## Multi-HPU inference

### BF16 Inference with FusedSDPA on 8 HPUs
175 changes: 146 additions & 29 deletions examples/image-to-text/run_image2text_lora_finetune.py
@@ -297,8 +297,58 @@ def __call__(self, examples):

return batch

class LLavaDataCollator:
    def __init__(self, processor, max_seq_length):
        self.processor = processor

        num_image_tokens = (self.processor.image_processor.crop_size["height"] // self.processor.patch_size) * (
            self.processor.image_processor.crop_size["width"] // self.processor.patch_size
        ) + 1
        if self.processor.vision_feature_select_strategy == "default":
            num_image_tokens -= 1

        # text length + image length
        self.max_seq_length = max_seq_length + num_image_tokens

    def __call__(self, examples):
        texts = []
        images = []

        keys = list(examples[0].keys())
        if not all(key in ["image", "query", "answers"] for key in keys):
            raise ValueError("Unsupported dataset format")
        for example in examples:
            image = example["image"]
            question = example["query"]["en"]
            answer = random.choice(example["answers"])
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Answer briefly."},
                        {"type": "image"},
                        {"type": "text", "text": question},
                    ],
                },
                {"role": "assistant", "content": [{"type": "text", "text": answer}]},
            ]
            text = self.processor.apply_chat_template(messages, add_generation_prompt=False)
            texts.append(text.strip())
            images.append(image)

        batch = self.processor(
            images, texts, return_tensors="pt", padding="max_length", truncation=True, max_length=self.max_seq_length
        )

        labels = batch["input_ids"].clone()
        if self.processor.tokenizer.pad_token_id is not None:
            labels[labels == self.processor.tokenizer.pad_token_id] = -100
        batch["labels"] = labels

        return batch
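
For orientation, a minimal sketch of how this collator could be exercised on its own, assuming the class above is in scope. The record is invented, and the `patch_size` / `vision_feature_select_strategy` values are assumptions; `main()` copies the real ones from the model config further down:

```
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf", padding_side="right")
# main() sets these from config.vision_config before building the collator (values assumed here).
processor.patch_size = 14
processor.vision_feature_select_strategy = "default"

collator = LLavaDataCollator(processor, max_seq_length=512)
examples = [
    {
        "image": Image.new("RGB", (336, 336)),  # stand-in for a document image
        "query": {"en": "What is the invoice number?"},
        "answers": ["12345"],
    }
]
batch = collator(examples)  # dict with input_ids, attention_mask, pixel_values and labels
```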

def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, max_seq_length):

def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, max_seq_length, model_arc=""):
from tqdm import tqdm

answers_unique = []
@@ -307,7 +357,6 @@ def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, m
for i in tqdm(range(0, len(dataset), batch_size)):
examples = dataset[i : i + batch_size]
answers_unique.extend(examples["answers"])
images = [[im] for im in examples["image"]]
texts = []
for q in examples["query"]:
messages = [
@@ -322,14 +371,31 @@ def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, m
]
text = processor.apply_chat_template(messages, add_generation_prompt=True)
texts.append(text.strip())
inputs = processor(
text=texts,
images=images,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_seq_length,
)

if "Llava" in model_arc:
images = []
for im in examples["image"]:
images.append(im)

inputs = processor(
images,
texts,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_seq_length,
padding_side="left",
)
else:
images = [[im] for im in examples["image"]]
inputs = processor(
text=texts,
images=images,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_seq_length,
)
inputs = {k: v.to("hpu") for k, v in inputs.items()}
generated_ids = model.generate(
**inputs, max_new_tokens=64, ignore_eos=False, lazy_mode=use_lazy_mode, hpu_graphs=use_hpu_graphs
@@ -346,6 +412,22 @@ def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, m
return anls


def find_all_linear_names(model):
    cls = torch.nn.Linear
    lora_module_names = set()
    multimodal_keywords = ["mm_projector", "vision_tower", "vision_resampler"]
    for name, module in model.named_modules():
        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
            continue
        if isinstance(module, cls):
            names = name.split(".")
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if "lm_head" in lora_module_names:  # needed for 16-bit
        lora_module_names.remove("lm_head")
    return list(lora_module_names)
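
In effect, when `--lora_target_modules` is not given, the helper collects every leaf `torch.nn.Linear` whose qualified name does not contain one of the multimodal keywords, then drops `lm_head`. A toy illustration, assuming the function above is in scope; the module names are invented, not LLaVA's real ones:

```
import torch


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_tower = torch.nn.Linear(4, 4)  # skipped: name matches a multimodal keyword
        self.language_model = torch.nn.ModuleDict(
            {"q_proj": torch.nn.Linear(4, 4), "lm_head": torch.nn.Linear(4, 4)}
        )


print(sorted(find_all_linear_names(Toy())))  # ['q_proj']  (lm_head and vision_tower excluded)
```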


def main():
parser = HfArgumentParser((ModelArguments, DataArguments, GaudiTrainingArguments, FinetuneArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
@@ -380,7 +462,7 @@ def main():
do_image_splitting=model_args.do_image_splitting,
padding_side="right",
)
setattr(processor.image_processor, "pad_to_longest_edge", True)

config_kwargs = {
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
@@ -395,7 +477,17 @@ def main():
else:
raise ValueError("Please provide value for model_name_or_path or config_name.")

# Load model
model_arc = ""
if config.architectures is not None:
model_arc = config.architectures[0]

if "Llava" in model_arc:
setattr(processor, "patch_size", config.vision_config.patch_size)
setattr(processor, "vision_feature_select_strategy", config.vision_feature_select_strategy)
else:
setattr(processor.image_processor, "pad_to_longest_edge", True)

# Load model
if model_args.model_name_or_path:
model_dtype = torch.bfloat16 if training_args.bf16 else None
model = AutoModelForVision2Seq.from_pretrained(
@@ -413,11 +505,16 @@ def main():
else:
raise ValueError("Must provide model_name_or_path to load a pretrained CausalLM model.")

if finetune_args.lora_target_modules is None:
target_modules = find_all_linear_names(model)
else:
target_modules = finetune_args.lora_target_modules

lora_config = LoraConfig(
r=finetune_args.lora_rank,
lora_alpha=finetune_args.lora_alpha,
lora_dropout=finetune_args.lora_dropout,
target_modules=finetune_args.lora_target_modules,
target_modules=target_modules,
init_lora_weights="gaussian",
)
model = get_peft_model(model, lora_config)
@@ -456,15 +553,19 @@ def main():
if col not in (data_args.input_column_names + data_args.output_column_names)
]
)
if hasattr(config, "image_token_id"):
# idefics
image_token_id = config.image_token_id
elif hasattr(config, "image_token_index"):
# mllama
image_token_id = config.image_token_index
if "Llava" in model_arc:
data_collator = LLavaDataCollator(processor, max_seq_length=data_args.max_seq_length)
else:
raise ValueError("Please provide value for image_token_id")
data_collator = MyDataCollator(processor, max_seq_length=data_args.max_seq_length, image_token_id=image_token_id)
if hasattr(config, "image_token_id"):
# idefics
image_token_id = config.image_token_id
elif hasattr(config, "image_token_index"):
# mllama
image_token_id = config.image_token_index
else:
raise ValueError("Please provide value for image_token_id")

data_collator = MyDataCollator(processor, max_seq_length=data_args.max_seq_length, image_token_id=image_token_id)

gaudi_config = GaudiConfig()
gaudi_config.use_fused_adam = True
@@ -509,14 +610,29 @@ def main():
}
]
text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(
text=[text.strip()],
images=[image],
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=data_args.max_seq_length,
)

if "Llava" in model_arc:
# don't expand image_token_id
setattr(processor, "patch_size", None)
setattr(processor, "vision_feature_select_strategy", None)
inputs = processor(
[image],
[text.strip()],
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=data_args.max_seq_length,
padding_side="left",
)
else:
inputs = processor(
text=[text.strip()],
images=[image],
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=data_args.max_seq_length,
)
inputs = {k: v.to("hpu") for k, v in inputs.items()}
generated_ids = model.generate(
**inputs,
@@ -543,6 +659,7 @@ def main():
use_lazy_mode=training_args.use_lazy_mode,
use_hpu_graphs=training_args.use_hpu_graphs_for_inference,
max_seq_length=data_args.max_seq_length,
model_arc=model_arc
)
eval_metrics = {"eval_accuracy": anls}
trainer.log_metrics("eval", eval_metrics)