How to finetune llava-ov from PREV_STAGE_CHECKPOINT? #378

BBBBchan opened this issue Dec 29, 2024 · 0 comments

After reading the scripts/train/README.md, I am attempting to reproduce the training of LLaVA-OneVision from scratch. I successfully ran the scripts/train/pretrain_siglip.sh script, specifying the output directory as checkpoints/projectors/${BASE_RUN_NAME}, where BASE_RUN_NAME is set to llavanext-model_zoo_google_siglip-so400m-patch14-384-model_zoo_Qwen_Qwen2.5-0.5B-Instruct-mlp2x_gelu-pretrain_blip558k_plain.

Upon completion of the pretraining phase, the output directory contains the following files:

config.json
mm_projector.bin
trainer_state.json
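
For reference, inspecting mm_projector.bin directly (a quick sketch; the key names are my own reading of the file, not something documented) shows that it holds only the projector weights, not the language model:

import torch

# mm_projector.bin is a torch-saved state dict produced by the pretraining stage;
# <BASE_RUN_NAME> stands for the long run name given above.
projector_path = "checkpoints/projectors/<BASE_RUN_NAME>/mm_projector.bin"
state = torch.load(projector_path, map_location="cpu")

# In my run this prints only the projector's MLP tensors (keys along the lines of
# "model.mm_projector.0.weight" -- my assumption), i.e. nothing that
# from_pretrained could use to rebuild the full model.
for name, tensor in state.items():
    print(name, tuple(tensor.shape))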

Then, following the scripts/train/finetune_ov.sh script, I set the PREV_STAGE_CHECKPOINT variable to checkpoints/projectors/llavanext-model_zoo_google_siglip-so400m-patch14-384-model_zoo_Qwen_Qwen2.5-0.5B-Instruct-mlp2x_gelu-pretrain_blip558k_plain, i.e. the Stage-1 output directory. However, running the finetune script then failed with the following error:

Traceback (most recent call last):
  File "/mnt/data/LLaVA-NeXT/llava/train/train_mem.py", line 4, in <module>
    train()
  File "/mnt/data/LLaVA-NeXT/llava/train/train.py", line 1496, in train
    model = get_model(model_args, training_args, bnb_model_from_pretrained_args)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/data/LLaVA-NeXT/llava/train/train.py", line 1428, in get_model
    model = LlavaQwenForCausalLM.from_pretrained(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/modeling_utils.py", line 3144, in from_pretrained
    raise EnvironmentError(
OSError: Error no file named pytorch_model.bin, tf_model.h5, model.ckpt.index or flax_model.msgpack found in directory checkpoints/projectors/llavanext-model_zoo_google_siglip-so400m-patch14-384-model_zoo_Qwen_Qwen2.5-0.5B-Instruct-mlp2x_gelu-pretrain_blip558k_plain.

It seems the complete checkpoint is missing: the previous stage saved only the mm_projector.bin file rather than the entire model. How can I obtain a full checkpoint from mm_projector.bin?
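
To make the question concrete, this is roughly the kind of merge I imagine would be needed (just a sketch under my own assumptions; I have not verified that LlavaQwenForCausalLM can be initialized from the plain Qwen weights this way, or that the key names in mm_projector.bin match the model's parameter names):

import torch
from llava.model import LlavaQwenForCausalLM  # the class shown in the traceback above

# Assumption: start from the plain Qwen2.5-0.5B-Instruct weights (a local path may
# be needed instead of the hub id) and graft the pretrained projector onto them,
# then save a directory in the layout from_pretrained expects.
model = LlavaQwenForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

projector_state = torch.load(
    "checkpoints/projectors/<BASE_RUN_NAME>/mm_projector.bin",
    map_location="cpu",
)

# strict=False because the file only holds the projector tensors; every other
# parameter keeps its base-model value. Whether the key names actually line up
# is exactly what I am unsure about.
missing, unexpected = model.load_state_dict(projector_state, strict=False)
print("unexpected keys:", unexpected)

model.save_pretrained("checkpoints/llava-qwen2.5-0.5b-stage1-merged")

If the intended workflow instead passes mm_projector.bin to the finetune stage through a script argument rather than a merged checkpoint, a pointer to the right option would be just as helpful.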

P.S. In case my understanding was incorrect from the start, my assumption is that the correct script workflow for training llava-ov from scratch is:

Stage-1: pretrain_siglip.sh
Stage-1.5: finetune_ov.sh (using the checkpoint from Stage-1)
Stage-2: finetune_ov.sh (using the checkpoint from Stage-1.5)