Support GPU Training #64

Open · wants to merge 3 commits into main
Changes from 2 commits
13 changes: 10 additions & 3 deletions .gitignore
@@ -237,7 +237,7 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/

### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
@@ -289,5 +289,12 @@ $RECYCLE.BIN/

# Windows shortcuts
*.lnk

# End of https://www.toptal.com/developers/gitignore/api/web,linux,macos,python,windows,data,jupyternotebooks
# End of https://www.toptal.com/developers/gitignore/api/web,linux,macos,python,windows,data,jupyternotebooks

# others
slurm_tmp/
wandb/
*.out
checkpoints/
.deepspeed_env
test.sh
12 changes: 9 additions & 3 deletions README.md
@@ -91,7 +91,7 @@ pip install -e ".[tpu]"
pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 -f https://storage.googleapis.com/libtpu-releases/index.html
```

### GPU Inference
### GPU Training & Inference
1. Clone this repository and navigate into the codebase
```bash
git clone https://github.com/cambrian-mllm/cambrian
@@ -268,12 +268,15 @@ To begin, please visit our [Hugging Face alignment data page](https://huggingfac
- [Alignment Data (JSONL file)](https://huggingface.co/datasets/nyu-visionx/Cambrian-Alignment/blob/main/jsons/alignment_2.5m.jsonl)
- [Corresponding Images](https://huggingface.co/datasets/nyu-visionx/Cambrian-Alignment/tree/main)

We provide sample training scripts in:
We provide sample training scripts for TPU in:

- [scripts/cambrian/pretrain_cambrian_8b.sh](scripts/cambrian/pretrain_cambrian_8b.sh)
- [scripts/cambrian/pretrain_cambrian_13b.sh](scripts/cambrian/pretrain_cambrian_13b.sh)
- [scripts/cambrian/pretrain_cambrian_34b.sh](scripts/cambrian/pretrain_cambrian_34b.sh)

For GPU:

- [scripts/gpu_cambrian/pretrain_cambrian_8b.sh](scripts/gpu_cambrian/pretrain_cambrian_8b.sh)
#### Using Custom Data

If you wish to train with other data sources or custom data, we support the commonly used LLaVA data format. For handling very large files, we use JSONL format instead of JSON format for lazy data loading to optimize memory usage.
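For illustration, a minimal sketch of a single training record written as one JSONL line in the commonly used LLaVA conversation format (the field names follow the usual LLaVA convention; the exact schema this repo expects is an assumption — the released JSONL files are authoritative):

```python
import json

# Minimal sketch of one LLaVA-style record appended as a single JSONL line.
# The id and image path below are hypothetical placeholders.
record = {
    "id": "example-0001",
    "image": "images/example-0001.jpg",
    "conversations": [
        {"from": "human", "value": "<image>\nWhat is shown in this image?"},
        {"from": "gpt", "value": "A short answer describing the image."},
    ],
}

with open("custom_data.jsonl", "a") as f:
    f.write(json.dumps(record) + "\n")  # one JSON object per line
```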
@@ -288,12 +291,15 @@ Similar to Training SVA, please visit our [Cambrian-10M data](https://huggingfac
- [Cambrian7M Data (JSONL file)](https://huggingface.co/datasets/nyu-visionx/Cambrian-10M/blob/main/jsons/Cambrian7M_withsystemprompt.jsonl)
- [Corresponding Images](https://huggingface.co/datasets/nyu-visionx/Cambrian-10M)

We provide sample training scripts in:
We provide sample training scripts for TPU in:

- [scripts/cambrian/finetune_cambrian_8b.sh](scripts/cambrian/finetune_cambrian_8b.sh)
- [scripts/cambrian/finetune_cambrian_13b.sh](scripts/cambrian/finetune_cambrian_13b.sh)
- [scripts/cambrian/finetune_cambrian_34b.sh](scripts/cambrian/finetune_cambrian_34b.sh)

For GPU:
- [scripts/gpu_cambrian/finetune_cambrian_8b.sh](scripts/gpu_cambrian/finetune_cambrian_8b.sh)

### Options to note:

- `--mm_projector_type`: To use our SVA module, set this value to `sva`. To use the LLaVA style 2-layer MLP projector, set this value to `mlp2x_gelu`.
8 changes: 4 additions & 4 deletions cambrian/conversation.py
@@ -98,9 +98,9 @@ def get_prompt(self):
ret += ""
ret = ret.lstrip(self.sep)
elif self.sep_style == SeparatorStyle.LLAMA_3:
wrap_sys = lambda msg: f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>{msg}<|eot_id|>" if len(msg) > 0 else msg
wrap_inst_user = lambda msg: f"<|start_header_id|>user<|end_header_id|>{msg}<|eot_id|>"
wrap_inst_assistant = lambda msg: f"<|start_header_id|>assistant<|end_header_id|>{msg}<|eot_id|>"
wrap_sys = lambda msg: f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{msg}<|eot_id|>" if len(msg) > 0 else msg
wrap_inst_user = lambda msg: f"<|start_header_id|>user<|end_header_id|>\n\n{msg}<|eot_id|>"
wrap_inst_assistant = lambda msg: f"<|start_header_id|>assistant<|end_header_id|>\n\n{msg}<|eot_id|>"
ret = ""

for i, (role, message) in enumerate(messages):
@@ -120,7 +120,7 @@ def get_prompt(self):
ret += message
else:
ret += ""
ret += "<|start_header_id|>assistant<|end_header_id|>"
ret += "<|start_header_id|>assistant<|end_header_id|>\n\n"
elif self.sep_style == SeparatorStyle.MISTRAL:
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
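To see what the `SeparatorStyle.LLAMA_3` change produces, a minimal standalone sketch of the resulting prompt string after this diff (the wrappers mirror the new lambdas above; the system and user messages are hypothetical):

```python
# Standalone sketch of the Llama-3 prompt layout after this change:
# each header tag is now followed by "\n\n" before the message body.
def wrap_sys(msg):
    return (f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
            f"{msg}<|eot_id|>") if msg else msg

def wrap_user(msg):
    return f"<|start_header_id|>user<|end_header_id|>\n\n{msg}<|eot_id|>"

prompt = wrap_sys("You are a helpful assistant.")
prompt += wrap_user("<image>\nDescribe the image.")
# The prompt ends with an open assistant header so generation continues from here.
prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
print(prompt)
```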
10 changes: 6 additions & 4 deletions cambrian/mm_utils.py
@@ -223,13 +223,15 @@ def insert_separator(X, sep):
return input_ids


def tokenizer_image_token_llama3(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
def tokenizer_image_token_llama3(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None,add_special_tokens=False):
prompt_chunks = [tokenizer(chunk,add_special_tokens=False).input_ids for chunk in prompt.split('<image>')]

def insert_separator(X, sep):
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]

input_ids = []
if add_special_tokens:
input_ids = [tokenizer.bos_token_id]
else:
input_ids = []
for x in insert_separator(prompt_chunks, [image_token_index]):
input_ids.extend(x)

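A minimal standalone sketch of the token interleaving `tokenizer_image_token_llama3` performs, with a toy tokenizer so the snippet runs without loading a model (the BOS id and the `IMAGE_TOKEN_INDEX` value of -200 are assumptions following the LLaVA convention; the repo's constants are authoritative):

```python
IMAGE_TOKEN_INDEX = -200  # assumed placeholder id, per the LLaVA convention
BOS_TOKEN_ID = 1          # assumed BOS id for illustration only

def toy_tokenize(text):
    return [ord(c) for c in text]  # stand-in for tokenizer(...).input_ids

def tokenize_with_image(prompt, add_special_tokens=False):
    chunks = [toy_tokenize(c) for c in prompt.split("<image>")]
    ids = [BOS_TOKEN_ID] if add_special_tokens else []
    for i, chunk in enumerate(chunks):
        if i > 0:
            ids.append(IMAGE_TOKEN_INDEX)  # marker later replaced by image features
        ids.extend(chunk)
    return ids

print(tokenize_with_image("Hi <image>!", add_special_tokens=True))
# -> [1, 72, 105, 32, -200, 33]
```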
47 changes: 37 additions & 10 deletions cambrian/model/cambrian_arch.py
@@ -30,6 +30,34 @@
from cambrian.utils import IS_XLA_AVAILABLE


# Original version: the training code path ran only when IS_XLA_AVAILABLE == True.
# Now IF_TRAIN=True must be set in pretrain.sh and finetune.sh to enable it on GPU as well.
import os
IF_TRAIN = os.getenv('IF_TRAIN', False)
print(f"IF_TRAIN: {IF_TRAIN}")


def print_weights(model_or_weights):
if isinstance(model_or_weights, nn.Sequential):
for name, weight in model_or_weights.named_parameters():
print(name, weight.shape, weight.dtype)
elif isinstance(model_or_weights, dict):
for name, weight in model_or_weights.items():
print(name, weight.shape, weight.dtype)
else:
print(type(model_or_weights))


"""
def set_trace_rank0():
import torch.distributed as dist
if dist.get_rank() == 0:
import ipdb
ipdb.set_trace()
dist.barrier()
"""


class CambrianMetaModel:

def __init__(self, config):
@@ -80,7 +108,7 @@ def __init__(self, config):

else:
self.vision_tower_aux_list = build_vision_tower_aux_list(config, delay_load=True)
config.mm_hidden_size = sum([vision_tower_aux.hidden_size for vision_tower_aux in self.vision_tower_aux_list])
config.mm_hidden_size = sum([vision_tower_aux.hidden_size for vision_tower_aux in self.vision_tower_aux_list])
self.mm_projector = build_vision_projector(config)
self.image_newline = nn.Parameter(
torch.empty(config.hidden_size, dtype=self.dtype)
@@ -162,14 +190,13 @@ def initialize_vision_modules(self, model_args, fsdp=None):
self.vision_query = nn.Parameter(
torch.randn((num_query_group, vision_hidden_size), dtype=self.dtype) * vision_embed_std
)

embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
self.image_newline = nn.Parameter(
torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
)

else:
self.config.mm_hidden_size = sum([vision_tower_aux.hidden_size for vision_tower_aux in vision_tower_aux_list])
self.config.mm_hidden_size = sum([vision_tower_aux.hidden_size for vision_tower_aux in vision_tower_aux_list])
self.mm_projector = build_vision_projector(self.config)
embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
self.image_newline = nn.Parameter(
@@ -383,7 +410,7 @@ def prepare_inputs_labels_for_multimodal(
query_features_i = self.get_model().vision_query[query_group_i, :].view(1, 1, 1, -1).expand(bs, query_num, -1, -1)
global_context_feature_i = global_context_feature.expand(-1, query_num, 1, -1).flatten(0,1)
query_side_len = int(query_num**0.5)
if IS_XLA_AVAILABLE:
if IS_XLA_AVAILABLE or IF_TRAIN:
vision_tower_aux_feature_list_i, vision_tower_aux_attention_masks_list_i = self.rearrange_vision_tower_features_train(vision_tower_aux_feature_list, image_aux_attention_masks_list, query_side_len)
else:
vision_tower_aux_feature_list_i, vision_tower_aux_attention_masks_list_i = self.rearrange_vision_tower_features_inference(vision_tower_aux_feature_list, query_side_len,
@@ -394,14 +421,14 @@
# interpolate to the final target size
if query_side_len != final_height:
query_features_i = query_features_i.permute(0, 2, 1).contiguous().view(bs, -1, query_side_len, query_side_len)
query_features_i = F.interpolate(query_features_i.float(),
size=(final_height, final_width),
mode='bilinear',
query_features_i = F.interpolate(query_features_i.float(),
size=(final_height, final_width),
mode='bilinear',
align_corners=False).to(dtype=query_features_i.dtype)
query_features_i = query_features_i.permute(0, 2, 3, 1).contiguous().flatten(1, 2)
final_image_features_list.append(query_features_i)

if IS_XLA_AVAILABLE:
if IS_XLA_AVAILABLE or IF_TRAIN:
vision_tower_aux_feature_list_final, vision_tower_aux_attention_masks_list_final = self.rearrange_vision_tower_features_train(vision_tower_aux_feature_list, image_aux_attention_masks_list, final_height)
global_context_feature_final = global_context_feature.expand(-1, final_height*final_width, 1, -1).flatten(0,1)
else:
Expand All @@ -410,7 +437,7 @@ def prepare_inputs_labels_for_multimodal(
image_features = torch.cat(final_image_features_list, -1)
image_features = self.get_model().mm_projector(image_features).to(dtype)

if IS_XLA_AVAILABLE:
if IS_XLA_AVAILABLE or IF_TRAIN:
image_features = image_features.view(image_features.shape[0], final_height, final_width, -1)
image_features = torch.cat((
image_features,
@@ -454,7 +481,7 @@ def prepare_inputs_labels_for_multimodal(
if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
raise NotImplementedError

if IS_XLA_AVAILABLE:
if IS_XLA_AVAILABLE or IF_TRAIN:

# embed the input_ids
new_input_ids_padded_for_emb = torch.where(input_ids==IMAGE_TOKEN_INDEX, 0, input_ids)
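The new `IS_XLA_AVAILABLE or IF_TRAIN` gates above key on an environment variable; a minimal sketch of how that flag behaves when read with `os.getenv` (assuming, per the comment in this file, that the GPU pretrain/finetune scripts export it before launch):

```python
import os

# Assumption: the GPU launch script exports IF_TRAIN before starting training.
os.environ["IF_TRAIN"] = "True"

# os.getenv returns the raw string when the variable is set (and the default
# otherwise), so the gate is effectively "variable set to a non-empty string".
IF_TRAIN = os.getenv("IF_TRAIN", False)
print(bool(IF_TRAIN))  # True -> the training-path branches are taken
```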
5 changes: 2 additions & 3 deletions cambrian/model/language_model/cambrian_llama.py
@@ -11,8 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
from typing import List, Optional, Tuple, Union

import torch
@@ -349,7 +348,7 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# training
if IS_XLA_AVAILABLE:
if IS_XLA_AVAILABLE or os.getenv('IF_TRAIN', False):
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
41 changes: 22 additions & 19 deletions cambrian/model/multimodal_encoder/builder.py
@@ -18,7 +18,7 @@
from .sam_encoder import SAMVisionTower
from .diffusion_encoder import DiffusionVisionTower
from .maws_encoder import MawsVisionTower

from concurrent.futures import ThreadPoolExecutor

def build_vision_tower(vision_tower_cfg, **kwargs):
vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
@@ -86,63 +86,66 @@ def build_vision_tower(vision_tower_cfg, **kwargs):
def build_vision_tower_aux_list(vision_tower_cfg, **kwargs):
vision_tower_aux_name_list = getattr(vision_tower_cfg, 'mm_vision_tower_aux_list', getattr(vision_tower_cfg, 'vision_tower_aux_list', None))
vision_tower_aux_token_len_list = getattr(vision_tower_cfg, 'mm_vision_tower_aux_token_len_list', getattr(vision_tower_cfg, 'vision_tower_aux_token_len_list', None))
vision_tower_aux_list = []
for vision_tower_aux_name, vision_tower_aux_token_len in zip(vision_tower_aux_name_list, vision_tower_aux_token_len_list):
def worker(vision_tower_aux_name,vision_tower_aux_token_len):
config = copy.deepcopy(vision_tower_cfg)
vision_tower = None
vision_tower_aux_name += "-interp{}".format(vision_tower_aux_token_len)
if "maws" in vision_tower_aux_name.lower():
logger.info(f"Loading **MAWS** Vision Tower: {vision_tower_aux_name}")
vision_tower_aux_list.append(MawsVisionTower(vision_tower_aux_name, args=config, **kwargs))
vision_tower=MawsVisionTower(vision_tower_aux_name, args=config, **kwargs)

# CLIP-based Vision Towers
elif "openai/clip" in vision_tower_aux_name.lower():
logger.info(f"Loading **OpenAI CLIP** Vision Tower: {vision_tower_aux_name}")
vision_tower_aux_list.append(ClipVisionTower(vision_tower_aux_name, args=config, **kwargs))
vision_tower=ClipVisionTower(vision_tower_aux_name, args=config, **kwargs)
elif "apple/dfn" in vision_tower_aux_name.lower():
logger.info(f"Loading **Apple DFN CLIP** Vision Tower: {vision_tower_aux_name}")
vision_tower_aux_list.append(DfnClipVisionTower(vision_tower_aux_name, args=config, **kwargs))
vision_tower=DfnClipVisionTower(vision_tower_aux_name, args=config, **kwargs)
elif "siglip" in vision_tower_aux_name.lower():
logger.info(f"Loading **SigLIP CLIP** Vision Tower: {vision_tower_aux_name}")
vision_tower_aux_list.append(SiglipVisionTower(vision_tower_aux_name, args=config, **kwargs))
vision_tower=SiglipVisionTower(vision_tower_aux_name, args=config, **kwargs)
elif "eva/clip" in vision_tower_aux_name.lower():
logger.info(f"Loading **EVA CLIP** Vision Tower: {vision_tower_aux_name}")
vision_tower_aux_list.append(EvaClipVisionTower(vision_tower_aux_name, args=config, **kwargs))
vision_tower=EvaClipVisionTower(vision_tower_aux_name, args=config, **kwargs)
elif "clip-convnext" in vision_tower_aux_name.lower():
logger.info(f"Loading **ConvNeXt CLIP** Vision Tower: {vision_tower_aux_name}")
vision_tower_aux_list.append(CLIPConvNextTower(vision_tower_aux_name, args=config, **kwargs))
vision_tower=CLIPConvNextTower(vision_tower_aux_name, args=config, **kwargs)

# SSL-based Vision Towers
elif "dinov2" in vision_tower_aux_name.lower():
logger.info(f"Loading **DINO Vision Tower: {vision_tower_aux_name}")
vision_tower_aux_list.append(DinoVisionTower(vision_tower_aux_name, args=config, **kwargs))
vision_tower=DinoVisionTower(vision_tower_aux_name, args=config, **kwargs)
elif "mae" in vision_tower_aux_name.lower():
logger.info(f"Loading **MAE** Vision Tower: {vision_tower_aux_name}")
vision_tower_aux_list.append(MAEVisionTower(vision_tower_aux_name, args=config, **kwargs))
vision_tower=MAEVisionTower(vision_tower_aux_name, args=config, **kwargs)
elif "moco" in vision_tower_aux_name.lower():
logger.info(f"Loading **MoCo** Vision Tower: {vision_tower_aux_name}")
vision_tower_aux_list.append(MoCoVisionTower(vision_tower_aux_name, args=config, **kwargs))
vision_tower=MoCoVisionTower(vision_tower_aux_name, args=config, **kwargs)
elif "ijepa" in vision_tower_aux_name.lower():
logger.info(f"Loading **IJepa** Vision Tower: {vision_tower_aux_name}")
vision_tower_aux_list.append(IJepaVisionTower(vision_tower_aux_name, args=config, **kwargs))
vision_tower=IJepaVisionTower(vision_tower_aux_name, args=config, **kwargs)

# Supervised Vision Towers
elif "supervised-vit" in vision_tower_aux_name.lower():
logger.info(f"Loading **Supervised** Vision Tower: {vision_tower_aux_name}")
vision_tower_aux_list.append(SupervisedViT_VisionTower(vision_tower_aux_name, args=config, **kwargs))
vision_tower=SupervisedViT_VisionTower(vision_tower_aux_name, args=config, **kwargs)

# Other Vision Towers
elif "hybridmodel" in vision_tower_aux_name.lower():
logger.info(f"Loading **Hybrid** Vision Tower: {vision_tower_aux_name}")
vision_tower_aux_list.append(HybridVisionTower(vision_tower_aux_name, args=config, **kwargs))
vision_tower=HybridVisionTower(vision_tower_aux_name, args=config, **kwargs)
elif "diffusion" in vision_tower_aux_name.lower():
logger.info(f"Loading **Diffusion CLIP** Vision Tower: {vision_tower_aux_name}")
vision_tower_aux_list.append(DiffusionVisionTower(vision_tower_aux_name, args=config, **kwargs))
vision_tower=DiffusionVisionTower(vision_tower_aux_name, args=config, **kwargs)
elif "midas" in vision_tower_aux_name.lower():
logger.info(f"Loading **MiDaS** Vision Tower: {vision_tower_aux_name}")
vision_tower_aux_list.append(MiDaSVisionTower(vision_tower_aux_name, args=config, **kwargs))
vision_tower=MiDaSVisionTower(vision_tower_aux_name, args=config, **kwargs)
elif "sam" in vision_tower_aux_name.lower():
logger.info(f"Loading **SAM Vision Tower: {vision_tower_aux_name}")
vision_tower_aux_list.append(SAMVisionTower(vision_tower_aux_name, args=config, **kwargs))
vision_tower=SAMVisionTower(vision_tower_aux_name, args=config, **kwargs)
else:
raise ValueError(f'Unknown vision tower: {vision_tower_aux_name}')
return vision_tower_aux_list
return vision_tower
with ThreadPoolExecutor() as executor:
vision_tower_aux_list = executor.map(worker, vision_tower_aux_name_list,vision_tower_aux_token_len_list)
return list(vision_tower_aux_list)
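The rewritten builder constructs the auxiliary vision towers concurrently; a minimal sketch of the `ThreadPoolExecutor.map` pattern it relies on, with a toy worker standing in for tower construction, illustrating that results come back in input order:

```python
from concurrent.futures import ThreadPoolExecutor
import time

def worker(name, token_len):
    # Stand-in for building a vision tower; real loading is largely I/O-bound
    # (e.g. fetching pretrained weights), which is where threads help.
    time.sleep(0.1)
    return f"{name}-interp{token_len}"

names = ["siglip", "dinov2"]       # hypothetical tower names for illustration
token_lens = [576, 576]

with ThreadPoolExecutor() as executor:
    # map() runs the workers concurrently but yields results in input order,
    # so the tower list keeps the order of the configured names.
    towers = list(executor.map(worker, names, token_lens))

print(towers)  # ['siglip-interp576', 'dinov2-interp576']
```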