[Epic] Prismatic Structural Refactor (TRI-ML#16)
This is an (intentionally) massive PR that completely refactors the base
Prismatic VLM codebase following TRI-ML#15. **For your sanity, please do not
try to review this entire PR; instead, see the "Proof-of-Correctness" section below.**

All Prismatic models now load through HuggingFace's
`transformers.AutoModelForVision2Seq`, have native compatibility with
external HF libraries (e.g., `peft`, `accelerate`, `bitsandbytes`),
and can easily be integrated with existing training frameworks (HF
Trainer, PyTorch Lightning).
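
For instance, wrapping a converted checkpoint with `peft` LoRA adapters now looks roughly like the sketch below (untested; the `TRI-ML/prism-dinosiglip-7b` path comes from the updated README quickstart, while the LoRA `target_modules` names are illustrative guesses that depend on the underlying LM backbone):

```python
import torch
from peft import LoraConfig, get_peft_model

from prismatic import PrismaticForVision2Seq

# Load a converted "v1" checkpoint (path taken from the updated README quickstart)
vlm = PrismaticForVision2Seq.from_pretrained(
    "TRI-ML/prism-dinosiglip-7b",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
)

# Attach LoRA adapters via `peft`; `target_modules` is an illustrative guess -- the real
# projection names depend on which LLM backbone the checkpoint uses
lora_cfg = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, target_modules=["q_proj", "v_proj"])
vlm = get_peft_model(vlm, lora_cfg)
vlm.print_trainable_parameters()  # only the adapter weights remain trainable
```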

Because this PR represents what I feel is the most "stable" version of
the Prismatic codebase, I've bumped the major version to `v1.0.0`.

Additionally, this PR implements:

- Support for batched generation (speeds up decoding; see the sketch after this list)
- Conversion scripts for "v0" model weights, with all
configs/models/processors pushed to
[huggingface.co/TRI-ML](https://huggingface.co/collections/TRI-ML/prismatic-vlms-66857a7c64b6a6b6fbc84ea4)
  - Most of the "popular" Prismatic checkpoints have already been converted + uploaded as **private** models; the remaining models will be converted iteratively.
- Simplified interactive generation script at `scripts/generate.py`
- Basic validation + tests.
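
A rough sketch of the batched-generation path referenced above (untested; it assumes `PrismaticProcessor` accepts lists of prompts/images and pads them like standard HF processors, and -- following the README quickstart -- that `generate()` returns decoded text):

```python
import torch
from PIL import Image

from prismatic import PrismaticForVision2Seq, PrismaticProcessor
from prismatic.preprocessing import get_prompt_builder_fn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = "TRI-ML/prism-dinosiglip-7b"

processor = PrismaticProcessor.from_pretrained(model_path)
vlm = PrismaticForVision2Seq.from_pretrained(model_path, torch_dtype=torch.bfloat16).to(device)

# Build one prompt per example (mirrors the README quickstart), then batch-encode
questions = ["What is going on in this image?", "How many animals are in this photo?"]
images = [Image.open(p).convert("RGB") for p in ("example_1.png", "example_2.png")]

prompts = []
for question in questions:
    prompt_builder = get_prompt_builder_fn(vlm.config.llm_backbone_id)()
    prompt_builder.add_turn(role="human", message=question, add_image_token=True)
    prompts.append(prompt_builder.get_prompt())

# ASSUMPTION: list inputs + `padding=True` are supported by `PrismaticProcessor.__call__`
inputs = processor(prompts, images, padding=True).to(device, torch.bfloat16)

with torch.inference_mode():
    generated_texts = vlm.generate(**inputs, do_sample=False, max_new_tokens=128)
```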

CC @blake-wulfe-tri @charlesrichter-tri @jmercat @paarthshah-tri
@jensen-gao-tri for visibility.

Resolves TRI-ML#15 

---

**Proof-of-Correctness**: Rather than review all files, I structured
this PR as a series of commits that uphold two invariants:
- **Fully (Bitwise) Reproducible Training** - For two model configs,
assert that running 64 gradient steps across 8 GPUs results in the
*exact same loss curves and performance*.
- **Deterministic Generation Output** - When loading (converted)
checkpoints, assert that generation output is exactly identical
(modulo some CUDA non-determinism); a sketch of this check follows below.
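
Concretely, the second invariant boils down to a check along these lines (a hypothetical sketch, not the actual test harness -- the image and reference-transcript paths are placeholders, and `generate()` is assumed to return decoded text as in the README quickstart):

```python
import torch
from PIL import Image

from prismatic import PrismaticForVision2Seq, PrismaticProcessor
from prismatic.preprocessing import get_prompt_builder_fn

model_path = "TRI-ML/prism-dinosiglip-7b"
processor = PrismaticProcessor.from_pretrained(model_path)
vlm = PrismaticForVision2Seq.from_pretrained(model_path, torch_dtype=torch.bfloat16).to("cuda")

# Same prompt/image that was fed to the pre-refactor codebase when recording the reference transcript
prompt_builder = get_prompt_builder_fn(vlm.config.llm_backbone_id)()
prompt_builder.add_turn(role="human", message="What is going on in this image?", add_image_token=True)
inputs = processor(prompt_builder.get_prompt(), Image.open("ref.png").convert("RGB")).to("cuda", torch.bfloat16)

# Greedy decoding; compare against the transcript recorded with the original code (placeholder path)
with torch.inference_mode():
    generated_text = vlm.generate(**inputs, do_sample=False, max_new_tokens=256)
reference_text = open("reference_generation.txt").read().strip()
assert generated_text.strip() == reference_text, "Converted checkpoint diverged from the reference output!"
```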

**Commits** (for interpreting the W&B loss curves below):

- `base` -- Gray line, the original loss curves from training
`siglip-224px+7b` and `prism-dinosiglip-224px-controlled` several months ago.
- `#fd2a0e4` -- Purple line, represents the latest commit on `vlm-core`
(upstream branch); sanity check to make sure nothing has changed in the
time since the original models were trained.
- [NEW] `#fc78732` -- Green line, implements the changes to the Prismatic
base class needed to prepare (unify the API) for the full refactor: adds the
`<image>` token, adds batched generation, and restructures the `forward()`
pass to remove the dependence on `multimodal_indices`.
- [NEW] `#b322374` -- Red line, **full HF "parallel" implementation**
(side-by-side with original Prismatic code). Adds new preprocessing
objects following HF convention, defines `PrismaticForVision2Seq` core
VLM model, conversion scripts, `hf_pretrain.py` and `hf_generate.py`.
- [NEW] `#b63f704` -- Orange line, finalizes the refactor: purges all "old"
Prismatic code, promotes the `hf_*` files to first-class citizens, and
rewrites the README and installation instructions.

*Note*: The `siglip-224px` training runs are bitwise identical across
all of the above commits; the `prism-dinosiglip-224px` runs are nearly identical,
modulo a slight difference in randomness that stems from the new "fused
backbone" API (it only affects weight initialization in a small way).

<img width="2427" alt="reproducible-loss-curves"
src="https://github.com/TRI-ML/prismatic-dev/assets/126100644/b35e877e-d7c0-4c24-b44d-5373720b4a67">
siddk-tri authored Jul 9, 2024
2 parents fd2a0e4 + d09c240 commit 777025a
Showing 70 changed files with 2,022 additions and 4,675 deletions.
3 changes: 1 addition & 2 deletions .gitignore
@@ -144,6 +144,5 @@ dmypy.json
# Mac OS
.DS_Store

# Caches and Datasets
# Caches
cache/
data/
60 changes: 19 additions & 41 deletions README.md
@@ -6,8 +6,6 @@ To facilitate a clean workflow with open-source/public code, this internal repos

- **[Default]** `vlm-core` - Treat this as the `main` branch for developing new VLM-related changes; always PR to this
branch in lieu of `main`.
- `vla-core` - This is the central branch for developing on vision-language-action models; this is synced with external
collaborators. If working on the OpenVLA project, always PR to this branch!
- `main` - Treat this as a **locked branch**; it tracks the latest stable code in the open-source repository.

**Important:** Assume that all commits/features developed for `vlm-core` will be eventually merged into the upstream
@@ -114,33 +112,38 @@ import requests
import torch

from PIL import Image
from pathlib import Path

from prismatic import load
from prismatic import PrismaticForVision2Seq, PrismaticProcessor
from prismatic.preprocessing import get_prompt_builder_fn

# For gated LMs like Llama-2, make sure to request official access, and generate an access token
hf_token = Path(".hf_token").read_text().strip()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load a pretrained VLM (either local path, or ID to auto-download from the HF Hub)
model_id = "prism-dinosiglip+7b"
vlm = load(model_id, hf_token=hf_token)
vlm.to(device, dtype=torch.bfloat16)
model_path = "TRI-ML/prism-dinosiglip-7b"
processor = PrismaticProcessor.from_pretrained(model_path)
vlm = PrismaticForVision2Seq.from_pretrained(
model_path,
attn_implementation="flash_attention_2",
device_map=device,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
).to(device)

# Download an image and specify a prompt
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png"
image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")
user_prompt = "What is going on in this image?"

# Build prompt
prompt_builder = vlm.get_prompt_builder()
prompt_builder.add_turn(role="human", message=user_prompt)
prompt_builder = get_prompt_builder_fn(vlm.config.llm_backbone_id)()
prompt_builder.add_turn(role="human", message=user_prompt, add_image_token=True)
prompt_text = prompt_builder.get_prompt()

# Generate!
inputs = processor(prompt_text, image).to(device, torch.bfloat16)
generated_text = vlm.generate(
image,
prompt_text,
**inputs,
do_sample=True,
temperature=0.4,
max_new_tokens=512,
@@ -153,32 +156,10 @@ For a complete terminal-based CLI for interacting with our VLMs, check out [scri
## Pretrained Models

We release **all 49** VLMs trained as part of our work, with a range of different visual representations, language
models, data, and scale. The exhaustive set of models (with structured descriptions) can be found in
[`prismatic/models/registry.py](prismatic/models/registry.py) - we will continue to update this registry as we train
additional models.
models, data, and scale. The exhaustive set of models (with structured descriptions) can be found at
[huggingface.co/TRI-ML](https://huggingface.co/collections/TRI-ML/prismatic-vlms-66857a7c64b6a6b6fbc84ea4) -- we will
continue to update this collection as we train new models.

We also provide a top-level API for instantiating models from the names mentioned in the various Figures of our paper,
as well as for generally browsing our pretrained models by description:

```python
from prismatic import available_model_names, available_models, get_model_description
from pprint import pprint

# List all Pretrained VLMs (by HF Hub IDs)
pprint(available_models())

# List all Pretrained VLMs + Descriptions (by explicit labels / names from paper figures)
pprint(available_model_names())

# Print and return a targeted description of a model (by name or ID)
# =>> See `prismatic/models/registry.py` for explicit schema
description = get_model_description("Prism-DINOSigLIP 13B (Controlled)")
```

Currently, our best performing models are the `Prism-DINOSigLIP` series, with especially strong performance on spatial
understanding and localization tasks.

---
**Explicit Notes on Model Licensing & Commercial Use**: While all code in this repository is released under an MIT
License, our pretrained models inherit restrictions from the _datasets_ and _underlying LMs_ we use for training.

@@ -188,10 +169,7 @@ additionally train on the LLaVa Instruct Tuning data, which is synthetically gen
(subject to the [OpenAI Terms of Use](https://openai.com/policies/terms-of-use)).

**[5/21/24]** We release two `mistral-*-v0.1*` models derived from
[Mistral v0.1](https://mistral.ai/news/announcing-mistral-7b/) which is subject to an Apache 2.0 License. We also
release `phi-2+3b` derived from
[Phi-2](https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/) released
under an MIT License.
[Mistral v0.1](https://mistral.ai/news/announcing-mistral-7b/) which is subject to an Apache 2.0 License.

As we train new models, we will update this section of the README (and the LICENSE files associated with each model)
appropriately. If there are any questions, please file an Issue!
11 changes: 10 additions & 1 deletion prismatic/__init__.py
@@ -1 +1,10 @@
from .models import available_model_names, available_models, get_model_description, load
from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor

from prismatic.models import PrismaticConfig, PrismaticForVision2Seq
from prismatic.preprocessing.processors import PrismaticImageProcessor, PrismaticProcessor

# === Register Models / Processors / Configs to the appropriate HF AutoClasses (required for `.from_pretrained()`)
AutoConfig.register("prismatic", PrismaticConfig)
AutoImageProcessor.register(PrismaticConfig, PrismaticImageProcessor)
AutoProcessor.register(PrismaticConfig, PrismaticProcessor)
AutoModelForVision2Seq.register(PrismaticConfig, PrismaticForVision2Seq)
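
With these registrations in place, a bare `import prismatic` is all the generic HF AutoClass entrypoints need to resolve converted checkpoints -- a minimal, untested sketch (assuming the `TRI-ML/prism-dinosiglip-7b` repo referenced in the README):

```python
from transformers import AutoModelForVision2Seq, AutoProcessor

import prismatic  # noqa: F401 -- imported for its side effect of registering the Prismatic AutoClasses

processor = AutoProcessor.from_pretrained("TRI-ML/prism-dinosiglip-7b")
vlm = AutoModelForVision2Seq.from_pretrained("TRI-ML/prism-dinosiglip-7b")  # resolves to `PrismaticForVision2Seq`
```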
7 changes: 0 additions & 7 deletions prismatic/conf/models.py
@@ -284,12 +284,6 @@ class Ext_Exp_7B_Mistral_Instruct_V1(Exp_7B_One_Stage):
llm_backbone_id: str = "mistral-v0.1-7b-instruct"


@dataclass
class Ext_Exp_3B_Phi_2(Exp_7B_One_Stage):
model_id: str = "phi-2+3b"
llm_backbone_id: str = "phi-2-3b"


# Section 4.3B :: ✌️ --> Co-training on Language-only Data
# =>> Note :: Run with `--dataset.type "llava-multimodal" (multimodal data only / no co-training)
@dataclass
@@ -537,7 +531,6 @@ class ModelRegistry(Enum):
EXT_EXP_LLAMA2_CHAT_13B = Ext_Exp_13B_Llama2_Chat
EXT_EXP_MISTRAL_V1_7B = Ext_Exp_7B_Mistral_V1
EXT_EXP_MISTRAL_INSTRUCT_V1_7B = Ext_Exp_7B_Mistral_Instruct_V1
EXT_EXP_PHI_2_3B = Ext_Exp_3B_Phi_2

# Cotraining w/ Unimodal Data
EXP_VICUNA_NO_COTRAINING_7B = Exp_7B_Vicuna_No_Cotraining
2 changes: 2 additions & 0 deletions prismatic/data/__init__.py
@@ -0,0 +1,2 @@
from .download import convert_to_jpg, download_extract
from .materialize import get_dataset_and_collator
@@ -12,15 +12,14 @@
import copy
import json
from pathlib import Path
from typing import Dict, List, Tuple, Type
from typing import Callable, Dict, List, Tuple, Type

import torch
from PIL import Image
from torch.utils.data import Dataset
from transformers import CodeGenTokenizerFast, LlamaTokenizerFast, PreTrainedTokenizerBase
from transformers import LlamaTokenizerFast, PreTrainedTokenizerBase

from prismatic.models.backbones.llm.prompting import PromptBuilder
from prismatic.models.backbones.vision import ImageTransform
from prismatic.preprocessing.prompting import PromptBuilder

# HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels)
IGNORE_INDEX = -100
@@ -31,7 +30,7 @@ def __init__(
self,
chat_json: Path,
image_dir: Path,
image_transform: ImageTransform,
image_transform: Callable[[Image.Image], torch.Tensor],
tokenizer: PreTrainedTokenizerBase,
) -> None:
super().__init__()
@@ -40,7 +39,7 @@ def __init__(
self.dataset_type = "align"

# Create Prompt Template
self.prompt_template = "{caption}" + self.tokenizer.eos_token
self.prompt_template = "<image>{caption}" + self.tokenizer.eos_token

# Load Chat JSON
with open(self.chat_json, "r") as f:
@@ -95,6 +94,7 @@ def get_modality_lengths(self, n_image_patches: int) -> List[Tuple[bool, int]]:
is_multimodal = "image" in example
n_words = sum([len(turn["value"].replace("<image>", "").split()) for turn in example["conversations"]])
modality_lengths.append((is_multimodal, (n_image_patches + n_words) if is_multimodal else n_words))

return modality_lengths

def __len__(self) -> int:
@@ -106,7 +106,7 @@ def __init__(
self,
instruct_json: Path,
image_dir: Path,
image_transform: ImageTransform,
image_transform: Callable[[Image.Image], torch.Tensor],
tokenizer: PreTrainedTokenizerBase,
prompt_builder_fn: Type[PromptBuilder],
) -> None:
@@ -134,21 +134,18 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
:return: Dictionary of {"pixel_values": torch.Tensor, "input_ids": torch.Tensor, "labels": torch.Tensor}
"""
conversation = self.examples[idx]["conversations"]
has_image = "image" in self.examples[idx]

# Create Prompt Builder --> add each message sequentially
prompt_builder, input_ids, labels = self.prompt_builder_fn(model_family="prismatic"), [], []
for turn_idx, turn in enumerate(conversation):
# Get "effective" string added to prompt --> handle whitespace for tokenizer type!
msg = prompt_builder.add_turn(turn["from"], turn["value"])
msg = prompt_builder.add_turn(turn["from"], turn["value"], add_image_token=has_image)

# Llama Tokenizer (Fast) adds extra character if a string ends in whitespace --> strip if non-empty!
if isinstance(self.tokenizer, LlamaTokenizerFast):
msg = msg.rstrip()

# Phi-2 Tokenizer == CodeGenTokenizer (Fast) -- no special handling!
elif isinstance(self.tokenizer, CodeGenTokenizerFast):
pass

else:
raise ValueError(f"Tokenizer of type `{type(self.tokenizer)}` is not explicitly handled!")

@@ -172,7 +169,7 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
input_ids, labels = input_ids[: self.tokenizer.model_max_length], labels[: self.tokenizer.model_max_length]

# === Handle "unimodal" (language-only) vs. "multimodal" ===
if "image" in self.examples[idx]:
if has_image:
image_path = Path(self.examples[idx]["image"])

# Set the <BOS> token's label to IGNORE_INDEX (since we're inserting the image patches right after)
@@ -185,7 +182,7 @@

else:
# No image --> return `pixel_values` = None; Collator will do the smart batch handling for us!
return dict(pixel_values=None, input_ids=input_ids, labels=labels)
return dict(input_ids=input_ids, labels=labels)

def get_modality_lengths(self) -> List[Tuple[bool, int]]:
"""Get a list of modalities (unimodal / text-only vs. multimodal) and length of conversations per example."""
@@ -194,6 +191,7 @@ def get_modality_lengths(self) -> List[Tuple[bool, int]]:
is_multimodal = "image" in example
n_words = sum([len(turn["value"].split()) for turn in example["conversations"]])
modality_lengths.append((is_multimodal, n_words))

return modality_lengths

def __len__(self) -> int:
File renamed without changes.
69 changes: 69 additions & 0 deletions prismatic/data/materialize.py
@@ -0,0 +1,69 @@
"""
materialize.py
Factory class for initializing pretraining datasets on a per-VLM basis; provides and exports individual functions for
clear control flow.
"""

from typing import Callable, Tuple, Type

import torch
from PIL import Image
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizerBase

from prismatic.conf import DatasetConfig
from prismatic.data.datasets import AlignDataset, FinetuneDataset
from prismatic.preprocessing.prompting import PromptBuilder
from prismatic.util.data_utils import PaddedCollatorForLanguageModeling

# Dataset Initializers =>> Maps Stage --> cls()
DATASET_INITIALIZER = {"align": AlignDataset, "finetune": FinetuneDataset, "full-finetune": FinetuneDataset}


def get_dataset_and_collator(
stage: str,
dataset_cfg: DatasetConfig,
image_transform: Callable[[Image.Image], torch.Tensor],
tokenizer: PreTrainedTokenizerBase,
prompt_builder_fn: Type[PromptBuilder],
padding_side: str = "right",
) -> Tuple[Dataset, PaddedCollatorForLanguageModeling]:
dataset_cls = DATASET_INITIALIZER[stage]
dataset_root_dir = dataset_cfg.dataset_root_dir
collator = PaddedCollatorForLanguageModeling(
tokenizer.model_max_length, tokenizer.pad_token_id, padding_side=padding_side
)

# Switch on `stage`
if stage == "align":
annotation_json, image_dir = dataset_cfg.align_stage_components
dataset = dataset_cls(
dataset_root_dir / annotation_json, dataset_root_dir / image_dir, image_transform, tokenizer
)
return dataset, collator

elif stage == "finetune":
annotation_json, image_dir = dataset_cfg.finetune_stage_components
dataset = dataset_cls(
dataset_root_dir / annotation_json,
dataset_root_dir / image_dir,
image_transform,
tokenizer,
prompt_builder_fn=prompt_builder_fn,
)
return dataset, collator

elif stage == "full-finetune":
annotation_json, image_dir = dataset_cfg.finetune_stage_components
dataset = dataset_cls(
dataset_root_dir / annotation_json,
dataset_root_dir / image_dir,
image_transform,
tokenizer,
prompt_builder_fn=prompt_builder_fn,
)
return dataset, collator

else:
raise ValueError(f"Stage `{stage}` is not supported!")
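
For reference, wiring `get_dataset_and_collator` into a dataloader looks roughly like the sketch below; everything marked "hypothetical" (the dataset paths, tokenizer choice, image transform, and backbone ID) is a stand-in for illustration, not taken from this PR:

```python
from pathlib import Path
from types import SimpleNamespace

import torchvision.transforms as T
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from prismatic.data import get_dataset_and_collator
from prismatic.preprocessing import get_prompt_builder_fn

# Hypothetical config exposing the fields `get_dataset_and_collator` reads (real runs use `DatasetConfig`)
dataset_cfg = SimpleNamespace(
    dataset_root_dir=Path("data/llava-v1.5-instruct"),
    finetune_stage_components=(Path("chat.json"), Path("images")),
    align_stage_components=(Path("align.json"), Path("images")),
)

image_transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])      # any PIL.Image -> torch.Tensor callable
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # hypothetical LM backbone tokenizer
tokenizer.pad_token = tokenizer.eos_token                              # ensure a pad token exists for the collator
prompt_builder_fn = get_prompt_builder_fn("llama2-7b-pure")            # backbone ID is illustrative

dataset, collator = get_dataset_and_collator(
    stage="finetune",
    dataset_cfg=dataset_cfg,
    image_transform=image_transform,
    tokenizer=tokenizer,
    prompt_builder_fn=prompt_builder_fn,
)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=collator)
```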
Empty file removed prismatic/extern/__init__.py