
[Draft] Add autocast to prediction_step for SFTTrainer #2310

Draft · wants to merge 19 commits into main
Conversation

@pdufour commented Nov 3, 2024

What does this PR do?

Note: not ready for review until huggingface/transformers#32346 lands.

There is a PR that adds predict_with_generate to the base Trainer class: https://github.com/huggingface/transformers/pull/32346/files. However, without any changes to SFTTrainer, it throws an error if you use it with mixed precision turned on.

You can see a similar bug (and fix) here: #1203

This pull request aims to make SFTTrainer compatible with the soon-to-land predict_with_generate changes.
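
For context, the fix follows the same pattern as the #1203 fix for compute_loss: run the step under autocast when the PEFT model has been cast to bf16. A minimal sketch of the shape of the change (not the actual diff; the override signature mirrors transformers' Trainer.prediction_step):

from contextlib import nullcontext

import torch
from transformers import Trainer


class SFTTrainer(Trainer):
    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        # Sketch only: autocast eval/generation when the adapters were cast
        # to bf16, mirroring the _peft_has_been_casted_to_bf16 guard that the
        # #1203 fix added around compute_loss.
        ctx = torch.amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
        with ctx:
            return super().prediction_step(
                model, inputs, prediction_loss_only, ignore_keys=ignore_keys
            )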

Reproducible Example:

import torch
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer
from datasets import Dataset
from peft import LoraConfig
from PIL import Image
from io import BytesIO

model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16)
model = AutoModelForVision2Seq.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16, quantization_config=bnb_config)

processor = AutoProcessor.from_pretrained(model_id)
tokenizer = processor.tokenizer

peft_config = LoraConfig(target_modules=["k_proj", "o_proj", "q_proj", "v_proj"], task_type="CAUSAL_LM")

args = SFTConfig(
    max_steps=2, output_dir="/tmp/test-model", gradient_checkpointing=True, bf16=True, bf16_full_eval=True,
    logging_steps=1, push_to_hub=False, report_to="wandb", eval_strategy="steps", eval_steps=2,
    generation_config={"num_beams": 1, "max_length": 1024}, predict_with_generate=True,
    gradient_checkpointing_kwargs={"use_reentrant": False}, dataset_kwargs={"skip_prepare_dataset": True},
    remove_unused_columns=False
)

# Example of structured data without URLs for images
data = [
    {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Generate code for a web page that looks exactly like this."},
                    {"type": "image", "image": "placeholder"}
                ]
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": "<html><body>Sample HTML content for the web page...</body></html>"}
                ]
            }
        ]
    }
]

dataset = Dataset.from_list(data)

def generate_placeholder_image():
    placeholder_image = Image.new("RGB", (256, 256), color="gray")
    return placeholder_image

def collate_fn(examples, mode):
    # Full conversations (prompt + answer) are the training texts; for eval,
    # keep only the user turn and add the generation prompt.
    texts = [processor.apply_chat_template(example["messages"], tokenize=False).strip() for example in examples]
    texts_eval = [processor.apply_chat_template([example["messages"][0]], add_generation_prompt=True).strip() for example in examples]

    images = [[generate_placeholder_image()] for _ in examples]

    # Pad on the right for training batches.
    processor.tokenizer.padding_side = "right"
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
    
    if mode == 'test':
        # Generation requires left padding; build separate generation inputs.
        processor.tokenizer.padding_side = "left"
        batch_eval = processor(text=texts_eval, images=images, return_tensors="pt", padding=True)
        batch["generation_input_ids"] = batch_eval["input_ids"]
        batch["generation_attention_mask"] = batch_eval["attention_mask"]

    # Mask out padding and image tokens so they are ignored by the loss.
    labels = batch["input_ids"].clone()
    labels[labels == tokenizer.pad_token_id] = -100
    image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
    labels[labels == image_token_id] = -100
    batch["labels"] = labels
    
    return batch

trainer = SFTTrainer(
    model=model, args=args, train_dataset=dataset, data_collator=lambda x: collate_fn(x, mode='train'), eval_dataset=dataset,
    dataset_text_field="messages", tokenizer=tokenizer, peft_config=peft_config,
    compute_metrics=lambda eval_preds: {},  # compute_metrics is called with an EvalPrediction
)

# Use the test-mode collator (which adds the generation inputs) during evaluation
trainer.eval_data_collator = lambda examples: collate_fn(examples, 'test')

trainer.train()
print('Done')

Test Plan

Test fail case

  1. Install the transformers branch with the train_predict changes: pip install git+https://github.com/pdufour/transformers.git@train_predict
  2. Edit the above code: comment out the generation_config and predict_with_generate=True args, and the print(f"bf16 casting status: {trainer._peft_has_been_casted_to_bf16}") line
  3. Run the code
  4. Observe the error
[screenshot of the error]

Test success case

  1. Install the trl branch with this fix (this PR): pip install git+https://github.com/pdufour/trl.git@train_predict
  2. Run the code
  3. See no error

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@pdufour changed the title Add autocast to prediction_step → Add autocast to prediction_step for SFTTrainer on Nov 3, 2024
@pdufour changed the title Add autocast to prediction_step for SFTTrainer → [Draft] Add autocast to prediction_step for SFTTrainer on Nov 3, 2024