support llava1.5 lora finetuning. #1487

Merged — 23 commits merged on Feb 17, 2025
Changes from 8 commits
Commits
78d1ab6
support llava1.5 lora finetuning.
lkk12014402 Nov 14, 2024
bd85520
Merge branch 'main' into llava1.5
lkk12014402 Dec 2, 2024
9e22eb4
make style
lkk12014402 Dec 2, 2024
2e3580c
Merge branch 'huggingface:main' into llava1.5
lkk12014402 Dec 4, 2024
e0be37f
for transformers==v4.45.2.
lkk12014402 Dec 4, 2024
1ff6602
for transformers==v4.45.2.
lkk12014402 Dec 4, 2024
d83743b
Merge branch 'huggingface:main' into llava1.5
lkk12014402 Dec 6, 2024
9a547b5
merge two scripts.
lkk12014402 Dec 6, 2024
8af7c56
Merge branch 'main' into llava1.5
lkk12014402 Jan 27, 2025
5a7e0f8
update llava training command readme.
lkk12014402 Jan 27, 2025
8a9824e
remove command text
lkk12014402 Feb 5, 2025
1a28e90
Merge branch 'huggingface:main' into llava1.5
lkk12014402 Feb 5, 2025
0f959b7
Remove useless cmds.
lkk12014402 Feb 5, 2025
b469e82
Merge branch 'huggingface:main' into llava1.5
lkk12014402 Feb 11, 2025
7f12447
Update examples/image-to-text/run_image2text_lora_finetune.py
libinta Feb 12, 2025
b5df89e
Update examples/image-to-text/run_image2text_lora_finetune.py
libinta Feb 12, 2025
d76db0c
Update examples/image-to-text/run_image2text_lora_finetune.py
libinta Feb 12, 2025
6dd64c6
Update examples/image-to-text/run_image2text_lora_finetune.py
libinta Feb 12, 2025
2c43cc6
Update examples/image-to-text/run_image2text_lora_finetune.py
libinta Feb 12, 2025
6c730dc
Update examples/image-to-text/run_image2text_lora_finetune.py
libinta Feb 12, 2025
b417be3
Add regression test
regisss Feb 17, 2025
01cf51d
Merge remote-tracking branch 'optimum-habana/main' into llava1.5
regisss Feb 17, 2025
d3def06
Make style
regisss Feb 17, 2025
34 changes: 34 additions & 0 deletions examples/image-to-text/README.md
@@ -375,6 +375,40 @@ python3 ../gaudi_spawn.py \
--lora_target_modules '".*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"'
```

Here is a single-card training command example for llava-hf/llava-1.5-7b-hf.

```
python3 run_image2text_lora_finetune.py \
--model_name_or_path llava-hf/llava-1.5-7b-hf \
--dataset_name nielsr/docvqa_1200_examples \
--bf16 True \
--output_dir ./model_lora_llava \
--num_train_epochs 1 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 8 \
--weight_decay 0.01 \
--logging_steps 25 \
--eval_strategy "no" \
--save_strategy "no" \
--learning_rate 5e-5 \
--warmup_steps 50 \
--lr_scheduler_type "constant" \
--input_column_names 'image' 'query' \
--output_column_names 'answers' \
--remove_unused_columns False \
--do_train \
--do_eval \
--use_habana \
--use_lazy_mode \
--lora_rank=8 \
--lora_alpha=8 \
--lora_dropout=0.1 \
--max_seq_length=512 \
--use_hpu_graphs_for_inference \
--low_cpu_mem_usage True
```
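
A multi-card run can reuse the same arguments through the `../gaudi_spawn.py` launcher used elsewhere in this README. The command below is a sketch only, assuming 8 available HPUs and the launcher's `--world_size`/`--use_mpi` flags:

```
python3 ../gaudi_spawn.py --world_size 8 --use_mpi \
run_image2text_lora_finetune.py \
--model_name_or_path llava-hf/llava-1.5-7b-hf \
--dataset_name nielsr/docvqa_1200_examples \
--bf16 True \
--output_dir ./model_lora_llava \
--num_train_epochs 1 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 8 \
--weight_decay 0.01 \
--logging_steps 25 \
--eval_strategy "no" \
--save_strategy "no" \
--learning_rate 5e-5 \
--warmup_steps 50 \
--lr_scheduler_type "constant" \
--input_column_names 'image' 'query' \
--output_column_names 'answers' \
--remove_unused_columns False \
--do_train \
--do_eval \
--use_habana \
--use_lazy_mode \
--lora_rank=8 \
--lora_alpha=8 \
--lora_dropout=0.1 \
--max_seq_length=512 \
--use_hpu_graphs_for_inference \
--low_cpu_mem_usage True
```

With 8 workers the effective batch size grows accordingly, so you may want to revisit `--gradient_accumulation_steps` or the learning rate.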

## Multi-HPU inference

### BF16 Inference with FusedSDPA on 8 HPUs
175 changes: 146 additions & 29 deletions examples/image-to-text/run_image2text_lora_finetune.py
@@ -297,8 +297,58 @@ def __call__(self, examples):

return batch

class LLavaDataCollator:
def __init__(self, processor, max_seq_length):
self.processor = processor

num_image_tokens = (self.processor.image_processor.crop_size["height"] // self.processor.patch_size) * (
self.processor.image_processor.crop_size["width"] // self.processor.patch_size
) + 1
if self.processor.vision_feature_select_strategy == "default":
num_image_tokens -= 1

# text length + image length
self.max_seq_length = max_seq_length + num_image_tokens
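# Illustrative arithmetic, assuming llava-1.5-7b defaults: crop size 336x336 with patch size 14
# gives (336 // 14) ** 2 = 576 patch tokens plus 1 CLS token = 577; the "default"
# vision_feature_select_strategy drops the CLS token, so 576 image tokens are added to the text
# budget and the collator pads/truncates to max_seq_length + 576 tokens overall.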

def __call__(self, examples):
texts = []
images = []

keys = list(examples[0].keys())
if not all(key in ["image", "query", "answers"] for key in keys):
raise ValueError("Unsupported dataset format")
for example in examples:
image = example["image"]
question = example["query"]["en"]
answer = random.choice(example["answers"])
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "Answer briefly."},
{"type": "image"},
{"type": "text", "text": question},
],
},
{"role": "assistant", "content": [{"type": "text", "text": answer}]},
]
text = self.processor.apply_chat_template(messages, add_generation_prompt=False)
texts.append(text.strip())
images.append(image)

batch = self.processor(
images, texts, return_tensors="pt", padding="max_length", truncation=True, max_length=self.max_seq_length
)

labels = batch["input_ids"].clone()
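# Replace padding positions with -100 so the cross-entropy loss ignores them; prompt and image
# placeholder tokens are left unmasked by this collator.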
if self.processor.tokenizer.pad_token_id is not None:
labels[labels == self.processor.tokenizer.pad_token_id] = -100
batch["labels"] = labels

return batch

def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, max_seq_length):

def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, max_seq_length, model_arc=""):
from tqdm import tqdm

answers_unique = []
@@ -307,7 +307,6 @@ def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, max_seq_length):
for i in tqdm(range(0, len(dataset), batch_size)):
examples = dataset[i : i + batch_size]
answers_unique.extend(examples["answers"])
images = [[im] for im in examples["image"]]
texts = []
for q in examples["query"]:
messages = [
@@ -322,14 +371,31 @@ def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, max_seq_length):
]
text = processor.apply_chat_template(messages, add_generation_prompt=True)
texts.append(text.strip())
inputs = processor(
text=texts,
images=images,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_seq_length,
)

if "Llava" in model_arc:
images = []
for im in examples["image"]:
images.append(im)

inputs = processor(
images,
texts,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_seq_length,
padding_side="left",
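# Left padding so that, in a batched generate() call, new tokens continue directly after
# each prompt rather than after padding tokens.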
)
else:
images = [[im] for im in examples["image"]]
inputs = processor(
text=texts,
images=images,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_seq_length,
)
inputs = {k: v.to("hpu") for k, v in inputs.items()}
generated_ids = model.generate(
**inputs, max_new_tokens=64, ignore_eos=False, lazy_mode=use_lazy_mode, hpu_graphs=use_hpu_graphs
@@ -346,6 +412,22 @@ def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, max_seq_length):
return anls


def find_all_linear_names(model):
cls = torch.nn.Linear
lora_module_names = set()
multimodal_keywords = ["mm_projector", "vision_tower", "vision_resampler"]
for name, module in model.named_modules():
if any(mm_keyword in name for mm_keyword in multimodal_keywords):
continue
if isinstance(module, cls):
names = name.split(".")
lora_module_names.add(names[0] if len(names) == 1 else names[-1])

if "lm_head" in lora_module_names: # needed for 16-bit
lora_module_names.remove("lm_head")
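# With llava-1.5 this typically yields the Linear layers outside the vision tower, e.g.
# q_proj/k_proj/v_proj/o_proj/gate_proj/up_proj/down_proj and the multimodal projector's
# linear layers; the exact set depends on the loaded checkpoint, so treat it as illustrative.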
return list(lora_module_names)


def main():
parser = HfArgumentParser((ModelArguments, DataArguments, GaudiTrainingArguments, FinetuneArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
@@ -380,7 +462,7 @@ def main():
do_image_splitting=model_args.do_image_splitting,
padding_side="right",
)
setattr(processor.image_processor, "pad_to_longest_edge", True)

config_kwargs = {
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
@@ -395,7 +477,17 @@ def main():
else:
raise ValueError("Please provide value for model_name_or_path or config_name.")

# Load model
model_arc = ""
if config.architectures is not None:
model_arc = config.architectures[0]

if "Llava" in model_arc:
setattr(processor, "patch_size", config.vision_config.patch_size)
setattr(processor, "vision_feature_select_strategy", config.vision_feature_select_strategy)
else:
setattr(processor.image_processor, "pad_to_longest_edge", True)

# Load model
if model_args.model_name_or_path:
model_dtype = torch.bfloat16 if training_args.bf16 else None
model = AutoModelForVision2Seq.from_pretrained(
@@ -413,11 +505,16 @@ def main():
else:
raise ValueError("Must provide model_name_or_path to load a pretrained CausalLM model.")

if finetune_args.lora_target_modules is None:
target_modules = find_all_linear_names(model)
else:
target_modules = finetune_args.lora_target_modules

lora_config = LoraConfig(
r=finetune_args.lora_rank,
lora_alpha=finetune_args.lora_alpha,
lora_dropout=finetune_args.lora_dropout,
target_modules=finetune_args.lora_target_modules,
target_modules=target_modules,
init_lora_weights="gaussian",
)
model = get_peft_model(model, lora_config)
@@ -456,15 +553,19 @@ def main():
if col not in (data_args.input_column_names + data_args.output_column_names)
]
)
if hasattr(config, "image_token_id"):
# idefics
image_token_id = config.image_token_id
elif hasattr(config, "image_token_index"):
# mllama
image_token_id = config.image_token_index
if "Llava" in model_arc:
data_collator = LLavaDataCollator(processor, max_seq_length=data_args.max_seq_length)
else:
raise ValueError("Please provide value for image_token_id")
data_collator = MyDataCollator(processor, max_seq_length=data_args.max_seq_length, image_token_id=image_token_id)
if hasattr(config, "image_token_id"):
# idefics
image_token_id = config.image_token_id
elif hasattr(config, "image_token_index"):
# mllama
image_token_id = config.image_token_index
else:
raise ValueError("Please provide value for image_token_id")

data_collator = MyDataCollator(processor, max_seq_length=data_args.max_seq_length, image_token_id=image_token_id)

gaudi_config = GaudiConfig()
gaudi_config.use_fused_adam = True
@@ -509,14 +610,29 @@ def main():
}
]
text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(
text=[text.strip()],
images=[image],
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=data_args.max_seq_length,
)

if "Llava" in model_arc:
# don't expand image_token_id
setattr(processor, "patch_size", None)
setattr(processor, "vision_feature_select_strategy", None)
inputs = processor(
[image],
[text.strip()],
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=data_args.max_seq_length,
padding_side="left",
)
else:
inputs = processor(
text=[text.strip()],
images=[image],
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=data_args.max_seq_length,
)
inputs = {k: v.to("hpu") for k, v in inputs.items()}
generated_ids = model.generate(
**inputs,
@@ -543,6 +659,7 @@ def main():
use_lazy_mode=training_args.use_lazy_mode,
use_hpu_graphs=training_args.use_hpu_graphs_for_inference,
max_seq_length=data_args.max_seq_length,
model_arc=model_arc
)
eval_metrics = {"eval_accuracy": anls}
trainer.log_metrics("eval", eval_metrics)