forked from TRI-ML/prismatic-vlms
[Epic] Prismatic Structural Refactor (TRI-ML#16)
This is an (intentionally) massive PR that completely refactors the base Prismatic VLM codebase following TRI-ML#15. **Please do not try to review this entire PR for your sanity; instead, see "Proof-of-Correctness" below.** All Prismatic models are now instances of HuggingFace `transformers.AutoModelForVision2Seq`, have native compatibility with external HF libraries (e.g., `peft`, `accelerate`, `bitsandbytes`), and can easily be integrated with existing training frameworks (HF Trainer, PyTorch Lightning).

Because this PR represents what I feel is the most "stable" version of the Prismatic codebase, I've bumped the major version to `v1.0.0`. Additionally, this PR implements:

- Support for batched generation (speeds up decoding).
- Conversion scripts for "v0" model weights, with all configs/models/processors pushed to [huggingface.co/TRI-ML](https://huggingface.co/collections/TRI-ML/prismatic-vlms-66857a7c64b6a6b6fbc84ea4).
  - Most of the "popular" Prismatic checkpoints have already been converted and uploaded as **private** models; the remaining models will be converted iteratively.
- A simplified interactive generation script at `scripts/generate.py`.
- Basic validation and tests.

CC @blake-wulfe-tri @charlesrichter-tri @jmercat @paarthshah-tri @jensen-gao-tri for visibility.

Resolves TRI-ML#15

---

**Proof-of-Correctness**: Rather than review all files, I structured this PR as a series of commits that uphold two invariants:

- **Fully (Bitwise) Reproducible Training** -- For two model configs, assert that running 64 gradient steps across 8 GPUs results in the *exact same loss curves and performance*.
- **Deterministic Generation Output** -- When loading (converted) checkpoints, assert that generation output is exactly identical (plus/minus some CUDA non-determinism).

**Commits** (for parsing the W&B loss curves below):

- `base` -- Gray line; the original loss curves from training the original models (`siglip-224px+7b` and `prism-dinosiglip-224px-controlled`) several months ago.
- `#fd2a0e4` -- Purple line; the latest commit on `vlm-core` (upstream branch), a sanity check that nothing has changed since the original models were trained.
- [NEW] `#fc78732` -- Green line; implements the changes to the Prismatic base class needed to prepare (unify the API) for the full refactor: adds the `<image>` token, adds batched generation, and restructures the `forward()` pass to remove the dependence on `multimodal_indices`.
- [NEW] `#b322374` -- Red line; the **full HF "parallel" implementation** (side-by-side with the original Prismatic code). Adds new preprocessing objects following HF conventions, defines the `PrismaticForVision2Seq` core VLM model, and adds conversion scripts, `hf_pretrain.py`, and `hf_generate.py`.
- [NEW] `#b63f704` -- Orange line; finalizes the refactor. Purges all "old" Prismatic code, makes the `hf_*` files first-class citizens, and refactors the README and installation instructions.

*Note*: The `siglip-224px` training runs are bitwise identical across all of the above commits; the `prism-dinosiglip-224px` runs are effectively identical, modulo a slight difference in randomness stemming from the new "fused backbone" API (it only affects weight initialization in a small way).

<img width="2427" alt="reproducible-loss-curves" src="https://github.com/TRI-ML/prismatic-dev/assets/126100644/b35e877e-d7c0-4c24-b44d-5373720b4a67">
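To make the new HF-facing surface concrete, here is a minimal usage sketch for a converted checkpoint, assuming only what the diff below shows (importing `prismatic` registers the classes with the HF AutoClasses). The hub id, image path, prompt string, and dtype handling are illustrative placeholders, not part of this commit.

```python
# Sketch: load a converted "v1.0.0" Prismatic checkpoint via the generic HF AutoClasses.
# MODEL_ID, the image path, and the prompt format are placeholders.
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

import prismatic  # noqa: F401 -- importing the package registers the Prismatic AutoClasses

MODEL_ID = "TRI-ML/prism-dinosiglip-224px-controlled"  # hypothetical hub id from the collection above

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForVision2Seq.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16).to("cuda")

image = Image.open("example.jpg")
prompt = "In: What is happening in this image?\nOut: "  # actual format depends on the base LLM's prompt builder

# Assumes the processor returns a BatchFeature, so floating-point tensors (pixel values)
# are cast to the model dtype while integer token ids are left untouched.
inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda", dtype=torch.bfloat16)
with torch.inference_mode():
    generated_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
```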
70 changed files with 2,022 additions and 4,675 deletions.
`.gitignore`:

```diff
@@ -144,6 +144,5 @@ dmypy.json
 # Mac OS
 .DS_Store

-# Caches and Datasets
+# Caches
 cache/
-data/
```
`prismatic/__init__.py`:

```diff
@@ -1 +1,10 @@
-from .models import available_model_names, available_models, get_model_description, load
+from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor
+
+from prismatic.models import PrismaticConfig, PrismaticForVision2Seq
+from prismatic.preprocessing.processors import PrismaticImageProcessor, PrismaticProcessor
+
+# === Register Models / Processors / Configs to the appropriate HF AutoClasses (required for `.from_pretrained()`)
+AutoConfig.register("prismatic", PrismaticConfig)
+AutoImageProcessor.register(PrismaticConfig, PrismaticImageProcessor)
+AutoProcessor.register(PrismaticConfig, PrismaticProcessor)
+AutoModelForVision2Seq.register(PrismaticConfig, PrismaticForVision2Seq)
```
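As a quick illustration of what these registrations buy (a sketch, not code from this PR): once `prismatic` has been imported, the generic AutoClasses resolve any checkpoint whose config declares the `"prismatic"` model type. The checkpoint path below is a placeholder.

```python
# Sketch: the AutoClasses dispatch on the registered "prismatic" model_type.
import prismatic  # noqa: F401 -- side effect: runs the register(...) calls shown above
from transformers import AutoConfig, AutoModelForVision2Seq

config = AutoConfig.from_pretrained("path/to/converted-prismatic-checkpoint")  # placeholder path
assert config.model_type == "prismatic"

# Instantiate the architecture from the config alone (random weights); use
# AutoModelForVision2Seq.from_pretrained(...) instead to also load the converted weights.
model = AutoModelForVision2Seq.from_config(config)
```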
A new package `__init__.py`:

```diff
@@ -0,0 +1,2 @@
+from .download import convert_to_jpg, download_extract
+from .materialize import get_dataset_and_collator
```
File renamed without changes.
`materialize.py` (new file, 69 added lines):

```python
"""
materialize.py

Factory class for initializing pretraining datasets on a per-VLM basis; provides and exports individual functions for
clear control flow.
"""

from typing import Callable, Tuple, Type

import torch
from PIL import Image
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizerBase

from prismatic.conf import DatasetConfig
from prismatic.data.datasets import AlignDataset, FinetuneDataset
from prismatic.preprocessing.prompting import PromptBuilder
from prismatic.util.data_utils import PaddedCollatorForLanguageModeling

# Dataset Initializers =>> Maps Stage --> cls()
DATASET_INITIALIZER = {"align": AlignDataset, "finetune": FinetuneDataset, "full-finetune": FinetuneDataset}


def get_dataset_and_collator(
    stage: str,
    dataset_cfg: DatasetConfig,
    image_transform: Callable[[Image.Image], torch.Tensor],
    tokenizer: PreTrainedTokenizerBase,
    prompt_builder_fn: Type[PromptBuilder],
    padding_side: str = "right",
) -> Tuple[Dataset, PaddedCollatorForLanguageModeling]:
    dataset_cls = DATASET_INITIALIZER[stage]
    dataset_root_dir = dataset_cfg.dataset_root_dir
    collator = PaddedCollatorForLanguageModeling(
        tokenizer.model_max_length, tokenizer.pad_token_id, padding_side=padding_side
    )

    # Switch on `stage`
    if stage == "align":
        annotation_json, image_dir = dataset_cfg.align_stage_components
        dataset = dataset_cls(
            dataset_root_dir / annotation_json, dataset_root_dir / image_dir, image_transform, tokenizer
        )
        return dataset, collator

    elif stage == "finetune":
        annotation_json, image_dir = dataset_cfg.finetune_stage_components
        dataset = dataset_cls(
            dataset_root_dir / annotation_json,
            dataset_root_dir / image_dir,
            image_transform,
            tokenizer,
            prompt_builder_fn=prompt_builder_fn,
        )
        return dataset, collator

    elif stage == "full-finetune":
        annotation_json, image_dir = dataset_cfg.finetune_stage_components
        dataset = dataset_cls(
            dataset_root_dir / annotation_json,
            dataset_root_dir / image_dir,
            image_transform,
            tokenizer,
            prompt_builder_fn=prompt_builder_fn,
        )
        return dataset, collator

    else:
        raise ValueError(f"Stage `{stage}` is not supported!")
```
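To show how this factory is typically consumed, here is a hypothetical wiring of its output into a standard PyTorch `DataLoader`. The helper name and batch size are made up; the `dataset_cfg`, `image_transform`, `tokenizer`, and `prompt_builder_fn` objects are assumed to be constructed elsewhere in the codebase, and `get_dataset_and_collator` is the function defined in `materialize.py` above.

```python
# Hypothetical helper: wire the (dataset, collator) pair from the factory above into a DataLoader.
from torch.utils.data import DataLoader


def build_finetune_dataloader(dataset_cfg, image_transform, tokenizer, prompt_builder_fn, batch_size=16):
    """Illustrative only: build a DataLoader for the `finetune` stage."""
    dataset, collator = get_dataset_and_collator(  # defined in materialize.py above
        stage="finetune",
        dataset_cfg=dataset_cfg,                   # prismatic.conf.DatasetConfig instance
        image_transform=image_transform,           # the vision backbone's preprocessing transform
        tokenizer=tokenizer,                       # the LLM backbone's PreTrainedTokenizerBase
        prompt_builder_fn=prompt_builder_fn,       # PromptBuilder subclass for the base LLM's chat format
        padding_side="right",
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collator)
```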