diff --git a/examples/image-to-text/README.md b/examples/image-to-text/README.md index 51f4a5dda2..4f578b6282 100644 --- a/examples/image-to-text/README.md +++ b/examples/image-to-text/README.md @@ -375,6 +375,40 @@ python3 ../gaudi_spawn.py \ --lora_target_modules '".*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"' ``` +Here are single card training command examples for llava-hf/llava-1.5-7b-hf. + +``` +python3 run_image2text_lora_finetune.py \ + --model_name_or_path llava-hf/llava-1.5-7b-hf \ + --dataset_name nielsr/docvqa_1200_examples \ + --bf16 True \ + --output_dir ./model_lora_llava \ + --num_train_epochs 1 \ + --per_device_train_batch_size 2 \ + --per_device_eval_batch_size 2 \ + --gradient_accumulation_steps 8 \ + --weight_decay 0.01 \ + --logging_steps 25 \ + --eval_strategy "no" \ + --save_strategy "no" \ + --learning_rate 5e-5 \ + --warmup_steps 50 \ + --lr_scheduler_type "constant" \ + --input_column_names 'image' 'query' \ + --output_column_names 'answers' \ + --remove_unused_columns False \ + --do_train \ + --do_eval \ + --use_habana \ + --use_lazy_mode \ + --lora_rank=8 \ + --lora_alpha=8 \ + --lora_dropout=0.1 \ + --max_seq_length=512 \ + --use_hpu_graphs_for_inference \ + --low_cpu_mem_usage True +``` + ## Multi-HPU inference ### BF16 Inference with FusedSDPA on 8 HPUs diff --git a/examples/image-to-text/run_image2text_lora_finetune.py b/examples/image-to-text/run_image2text_lora_finetune.py index ded60e6d52..b86247376f 100644 --- a/examples/image-to-text/run_image2text_lora_finetune.py +++ b/examples/image-to-text/run_image2text_lora_finetune.py @@ -297,8 +297,58 @@ def __call__(self, examples): return batch +class LLavaDataCollator: + def __init__(self, processor, max_seq_length): + self.processor = processor + + num_image_tokens = (self.processor.image_processor.crop_size["height"] // self.processor.patch_size) * ( + self.processor.image_processor.crop_size["width"] // self.processor.patch_size + ) + 1 + if self.processor.vision_feature_select_strategy == "default": + num_image_tokens -= 1 + + # text length + image length + self.max_seq_length = max_seq_length + num_image_tokens + + def __call__(self, examples): + texts = [] + images = [] + + keys = list(examples[0].keys()) + if not all(key in ["image", "query", "answers"] for key in keys): + raise ValueError("Unsupported dataset format") + for example in examples: + image = example["image"] + question = example["query"]["en"] + answer = random.choice(example["answers"]) + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Answer briefly."}, + {"type": "image"}, + {"type": "text", "text": question}, + ], + }, + {"role": "assistant", "content": [{"type": "text", "text": answer}]}, + ] + text = self.processor.apply_chat_template(messages, add_generation_prompt=False) + texts.append(text.strip()) + images.append(image) + + batch = self.processor( + images, texts, return_tensors="pt", padding="max_length", truncation=True, max_length=self.max_seq_length + ) + + labels = batch["input_ids"].clone() + if self.processor.tokenizer.pad_token_id is not None: + labels[labels == self.processor.tokenizer.pad_token_id] = -100 + batch["labels"] = labels + + return batch -def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, max_seq_length): + +def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, max_seq_length, model_arc=""): from tqdm import tqdm answers_unique = [] @@ -307,7 +357,6 @@ def eval(processor, model, dataset, 
batch_size, use_lazy_mode, use_hpu_graphs, m for i in tqdm(range(0, len(dataset), batch_size)): examples = dataset[i : i + batch_size] answers_unique.extend(examples["answers"]) - images = [[im] for im in examples["image"]] texts = [] for q in examples["query"]: messages = [ @@ -322,14 +371,31 @@ def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, m ] text = processor.apply_chat_template(messages, add_generation_prompt=True) texts.append(text.strip()) - inputs = processor( - text=texts, - images=images, - return_tensors="pt", - padding="max_length", - truncation=True, - max_length=max_seq_length, - ) + + if "Llava" in model_arc: + images = [] + for im in examples["image"]: + images.append(im) + + inputs = processor( + images, + texts, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_seq_length, + padding_side="left", + ) + else: + images = [[im] for im in examples["image"]] + inputs = processor( + text=texts, + images=images, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_seq_length, + ) inputs = {k: v.to("hpu") for k, v in inputs.items()} generated_ids = model.generate( **inputs, max_new_tokens=64, ignore_eos=False, lazy_mode=use_lazy_mode, hpu_graphs=use_hpu_graphs @@ -346,6 +412,22 @@ def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, m return anls +def find_all_linear_names(model): + cls = torch.nn.Linear + lora_module_names = set() + multimodal_keywords = ["mm_projector", "vision_tower", "vision_resampler"] + for name, module in model.named_modules(): + if any(mm_keyword in name for mm_keyword in multimodal_keywords): + continue + if isinstance(module, cls): + names = name.split(".") + lora_module_names.add(names[0] if len(names) == 1 else names[-1]) + + if "lm_head" in lora_module_names: # needed for 16-bit + lora_module_names.remove("lm_head") + return list(lora_module_names) + + def main(): parser = HfArgumentParser((ModelArguments, DataArguments, GaudiTrainingArguments, FinetuneArguments)) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): @@ -380,7 +462,7 @@ def main(): do_image_splitting=model_args.do_image_splitting, padding_side="right", ) - setattr(processor.image_processor, "pad_to_longest_edge", True) + config_kwargs = { "cache_dir": model_args.cache_dir, "revision": model_args.model_revision, @@ -395,7 +477,17 @@ def main(): else: raise ValueError("Please provide value for model_name_or_path or config_name.") - # Load model + model_arc = "" + if config.architectures is not None: + model_arc = config.architectures[0] + + if "Llava" in model_arc: + setattr(processor, "patch_size", config.vision_config.patch_size) + setattr(processor, "vision_feature_select_strategy", config.vision_feature_select_strategy) + else: + setattr(processor.image_processor, "pad_to_longest_edge", True) + + # Load model if model_args.model_name_or_path: model_dtype = torch.bfloat16 if training_args.bf16 else None model = AutoModelForVision2Seq.from_pretrained( @@ -413,11 +505,16 @@ def main(): else: raise ValueError("Must provide model_name_or_path to load a pretrained CausalLM model.") + if finetune_args.lora_target_modules is None: + target_modules = find_all_linear_names(model) + else: + target_modules = finetune_args.lora_target_modules + lora_config = LoraConfig( r=finetune_args.lora_rank, lora_alpha=finetune_args.lora_alpha, lora_dropout=finetune_args.lora_dropout, - target_modules=finetune_args.lora_target_modules, + target_modules=target_modules, 
init_lora_weights="gaussian", ) model = get_peft_model(model, lora_config) @@ -456,15 +553,19 @@ def main(): if col not in (data_args.input_column_names + data_args.output_column_names) ] ) - if hasattr(config, "image_token_id"): - # idefics - image_token_id = config.image_token_id - elif hasattr(config, "image_token_index"): - # mllama - image_token_id = config.image_token_index + if "Llava" in model_arc: + data_collator = LLavaDataCollator(processor, max_seq_length=data_args.max_seq_length) else: - raise ValueError("Please provide value for image_token_id") - data_collator = MyDataCollator(processor, max_seq_length=data_args.max_seq_length, image_token_id=image_token_id) + if hasattr(config, "image_token_id"): + # idefics + image_token_id = config.image_token_id + elif hasattr(config, "image_token_index"): + # mllama + image_token_id = config.image_token_index + else: + raise ValueError("Please provide value for image_token_id") + + data_collator = MyDataCollator(processor, max_seq_length=data_args.max_seq_length, image_token_id=image_token_id) gaudi_config = GaudiConfig() gaudi_config.use_fused_adam = True @@ -509,14 +610,29 @@ def main(): } ] text = processor.apply_chat_template(messages, add_generation_prompt=True) - inputs = processor( - text=[text.strip()], - images=[image], - return_tensors="pt", - padding="max_length", - truncation=True, - max_length=data_args.max_seq_length, - ) + + if "Llava" in model_arc: + # don't expand image_token_id + setattr(processor, "patch_size", None) + setattr(processor, "vision_feature_select_strategy", None) + inputs = processor( + [image], + [text.strip()], + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=data_args.max_seq_length, + padding_side="left", + ) + else: + inputs = processor( + text=[text.strip()], + images=[image], + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=data_args.max_seq_length, + ) inputs = {k: v.to("hpu") for k, v in inputs.items()} generated_ids = model.generate( **inputs, @@ -543,6 +659,7 @@ def main(): use_lazy_mode=training_args.use_lazy_mode, use_hpu_graphs=training_args.use_hpu_graphs_for_inference, max_seq_length=data_args.max_seq_length, + model_arc=model_arc ) eval_metrics = {"eval_accuracy": anls} trainer.log_metrics("eval", eval_metrics) diff --git a/optimum/habana/transformers/models/llava/modeling_llava.py b/optimum/habana/transformers/models/llava/modeling_llava.py index 997c16d700..cd5fb17d79 100644 --- a/optimum/habana/transformers/models/llava/modeling_llava.py +++ b/optimum/habana/transformers/models/llava/modeling_llava.py @@ -22,6 +22,7 @@ from typing import List, Optional, Tuple, Union import torch +import torch.nn as nn from transformers.cache_utils import Cache from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration from transformers.utils import logging @@ -129,57 +130,74 @@ def forward( flash_attention_recompute: Optional[bool] = False, ) -> Union[Tuple, LlavaCausalLMOutputWithPast]: """ - Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llava/modeling_llava.py + Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/llava/modeling_llava.py#L362 The only differences are: - add new args token_idx - add new args image_offset - add new args tokens_pos """ - if token_idx is not None: - output_attentions = output_attentions if output_attentions 
is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - vision_feature_layer = ( - vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" ) - vision_feature_select_strategy = ( - vision_feature_select_strategy - if vision_feature_select_strategy is not None - else self.config.vision_feature_select_strategy + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" ) - # 1. Extra the input embeddings + legacy_processing = False + if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) - image_features = None - # 2. Merge text and images - if pixel_values is not None and input_ids.shape[1] != 1: - image_outputs = self.vision_tower( - pixel_values, - output_hidden_states=True, - use_flash_attention=use_flash_attention, - flash_attention_recompute=flash_attention_recompute, - ) - # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated. - selected_image_feature = image_outputs.hidden_states[vision_feature_layer] + # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing + # not very reliable, but we don't expect one to actually pass 500+ images for one prompt + # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True + legacy_processing = ( + (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length + ) or (input_ids.shape[-1] == 1 and pixel_values is not None) - if vision_feature_select_strategy == "default": - selected_image_feature = selected_image_feature[:, 1:] - elif vision_feature_select_strategy == "full": - selected_image_feature = selected_image_feature - else: - raise ValueError( - f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}" - ) - image_features = self.multi_modal_projector(selected_image_feature) - inputs_embeds = _merge_input_ids_with_image_features( - image_features, inputs_embeds, input_ids, self.config.image_token_index - ) + image_features = None + # 2. 
Merge text and images + if pixel_values is not None and input_ids.shape[1] != 1: + image_outputs = self.vision_tower( + pixel_values, + output_hidden_states=True, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + ) + # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated. + selected_image_feature = image_outputs.hidden_states[vision_feature_layer] + + if vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + else: + raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}") + image_features = self.multi_modal_projector(selected_image_feature) + inputs_embeds = _merge_input_ids_with_image_features( + image_features, inputs_embeds, input_ids, self.config.image_token_index + ) + + if token_idx is not None: outputs = self.language_model( attention_mask=attention_mask, position_ids=position_ids, @@ -190,8 +208,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, - # TODO: from Transformers v4.45, `generate` sets `num_logits_to_keep` to 1 if not given, which we don't want here - # num_logits_to_keep=num_logits_to_keep, + num_logits_to_keep=num_logits_to_keep, token_idx=token_idx + image_offset, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, @@ -220,20 +237,50 @@ def forward( ) else: - return super().forward( - input_ids=input_ids, - pixel_values=pixel_values, + outputs = self.language_model( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, - vision_feature_layer=vision_feature_layer, - vision_feature_select_strategy=vision_feature_select_strategy, - labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + ) + + logits = outputs[0] + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + if attention_mask is not None: + shift_attention_mask = attention_mask[..., 1:] + shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() + else: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return LlavaCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, ) def prepare_inputs_for_generation(
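
For reference, the new `LLavaDataCollator` sizes its effective sequence length by adding an image-token budget on top of the requested text length, because the LLaVA processor expands the single `<image>` placeholder into one token per vision patch. A minimal sketch of that arithmetic, assuming the llava-hf/llava-1.5-7b-hf defaults (336x336 crop, 14x14 patches, `vision_feature_select_strategy="default"`) and the `--max_seq_length=512` used in the README command:

```
# Sketch only: mirrors the arithmetic in LLavaDataCollator.__init__,
# with llava-1.5-7b-hf defaults hard-coded instead of read from the processor.
crop_height, crop_width = 336, 336   # processor.image_processor.crop_size
patch_size = 14                      # processor.patch_size

# one embedding per vision patch, plus the CLS feature from the CLIP vision tower
num_image_tokens = (crop_height // patch_size) * (crop_width // patch_size) + 1  # 24 * 24 + 1 = 577

# the "default" feature-select strategy drops the CLS feature
num_image_tokens -= 1                # 576

text_max_seq_length = 512            # --max_seq_length from the README example
effective_max_length = text_max_seq_length + num_image_tokens
print(effective_max_length)          # 1088
```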
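When `--lora_target_modules` is not supplied, the script now derives the target list with `find_all_linear_names`, i.e. every `torch.nn.Linear` leaf module except those under the vision tower or multimodal projector, with `lm_head` removed afterwards. A self-contained sketch of that selection logic on a toy module (the `ToyVLM` class and its layer sizes are invented for illustration, not part of the PR):

```
import torch.nn as nn


class ToyVLM(nn.Module):
    """Invented stand-in for a vision-language model, just to exercise the helper."""

    def __init__(self):
        super().__init__()
        self.vision_tower = nn.Linear(8, 8)   # skipped: multimodal keyword
        self.mm_projector = nn.Linear(8, 8)   # skipped: multimodal keyword
        self.q_proj = nn.Linear(8, 8)         # kept -> becomes a LoRA target
        self.lm_head = nn.Linear(8, 8)        # collected, then removed for 16-bit training


def find_all_linear_names(model):
    # same selection logic as the helper added in run_image2text_lora_finetune.py
    lora_module_names = set()
    multimodal_keywords = ["mm_projector", "vision_tower", "vision_resampler"]
    for name, module in model.named_modules():
        if any(keyword in name for keyword in multimodal_keywords):
            continue
        if isinstance(module, nn.Linear):
            names = name.split(".")
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    lora_module_names.discard("lm_head")
    return list(lora_module_names)


print(find_all_linear_names(ToyVLM()))  # ['q_proj']
```

The resulting list is what gets passed to `LoraConfig(target_modules=...)` when `finetune_args.lora_target_modules is None`.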
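Likewise, the non-`token_idx` branch of `GaudiLlavaForConditionalGeneration.forward` now computes the causal LM loss itself rather than deferring to the parent class: logits are shifted by one position against the labels and padded positions are filtered out via the attention mask. A short sketch of that standard next-token loss, using made-up tensor shapes rather than real model outputs:

```
import torch
import torch.nn as nn

batch, seq_len, vocab = 2, 6, 11          # made-up shapes for illustration
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))
attention_mask = torch.ones(batch, seq_len, dtype=torch.long)
attention_mask[:, :2] = 0                 # pretend the first two positions are padding

# tokens < n predict token n: drop the last logit and the first label,
# and keep only positions whose shifted attention mask is non-zero
shift_attention_mask = attention_mask[..., 1:]
shift_logits = logits[..., :-1, :][shift_attention_mask != 0]
shift_labels = labels[..., 1:][shift_attention_mask != 0]

loss = nn.CrossEntropyLoss()(shift_logits.view(-1, vocab), shift_labels.view(-1))
print(loss)
```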