Fix InternVL2 model sharding #480

Merged · 1 commit · Dec 27, 2024
50 changes: 48 additions & 2 deletions lmms_eval/models/internvl2.py
@@ -119,12 +119,55 @@ def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=3
     return pixel_values, num_patches_list
 
 
+import math
+from datetime import timedelta
+
+from accelerate.state import AcceleratorState
+from accelerate.utils import InitProcessGroupKwargs
+
+
+# This mapping avoids multi-GPU inference errors caused by tensors landing on different devices: the first and last layers of the language model (LLM) are kept on the same device (GPU 0).
+def split_model(model_name, num_layers=None):
+    device_map = {}
+    world_size = torch.cuda.device_count()
+    if num_layers is None:
+        num_layers = {
+            "InternVL2_5-1B": 24,
+            "InternVL2_5-2B": 24,
+            "InternVL2_5-4B": 36,
+            "InternVL2_5-8B": 32,
+            "InternVL2_5-26B": 48,
+            "InternVL2_5-38B": 64,
+            "InternVL2_5-78B": 80,
+            "InternVL2-1B": 24,
+            "InternVL2-2B": 24,
+            "InternVL2-4B": 32,
+            "InternVL2-8B": 32,
+            "InternVL2-26B": 48,
+            "InternVL2-40B": 60,
+            "InternVL2-Llama3-76B": 80,
+        }[model_name]
+    # Since the first GPU will be used for ViT, treat it as half a GPU.
+    num_layers_per_gpu = math.ceil(num_layers / (world_size - 0.5))
+    num_layers_per_gpu = [num_layers_per_gpu] * world_size
+    num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.5)
+    layer_cnt = 0
+    for i, num_layer in enumerate(num_layers_per_gpu):
+        for j in range(num_layer):
+            device_map[f"language_model.model.layers.{layer_cnt}"] = i
+            layer_cnt += 1
+    device_map["vision_model"] = 0
+    device_map["mlp1"] = 0
+    device_map["language_model.model.tok_embeddings"] = 0
+    device_map["language_model.model.embed_tokens"] = 0
+    device_map["language_model.output"] = 0
+    device_map["language_model.model.norm"] = 0
+    device_map["language_model.lm_head"] = 0
+    device_map[f"language_model.model.layers.{num_layers - 1}"] = 0
+
+    return device_map
+
+
 @register_model("internvl2")
 class InternVL2(lmms):
     def __init__(
@@ -135,13 +178,12 @@ def __init__(
         device_map: str = "cuda:0",
         batch_size: str = "1",
         num_frame: int = 32,
+        num_layers=None,
         **kwargs,
     ):
         super().__init__()
 
         self.path = pretrained
-        self._model = AutoModel.from_pretrained(self.path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True, device_map=device_map).eval()
-        self._tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True, device_map=device_map)
         self.num_frame = num_frame
 
         batch_size = int(batch_size)
@@ -156,11 +198,15 @@ def __init__(
             self.device_map = f"cuda:{accelerator.local_process_index}"
         elif accelerator.num_processes == 1 and device_map == "auto":
             self._device = torch.device(device)
+            device_map = split_model(pretrained.split("/")[-1], num_layers=num_layers)
             self.device_map = device_map
         else:
             self._device = torch.device(f"cuda:{accelerator.local_process_index}")
             self.device_map = f"cuda:{accelerator.local_process_index}"
 
+        self._model = AutoModel.from_pretrained(self.path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True, device_map=device_map).eval()
+        self._tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True, device_map=device_map)
+
         if accelerator.num_processes > 1:
             assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
             # If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model
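
For intuition about the split, here is a minimal, self-contained sketch of the layer budgeting that split_model performs, assuming InternVL2-8B (32 LLM layers) on a hypothetical 4-GPU node; the numbers are illustrative and not part of the diff.

import math

# Reproduce split_model's budgeting for InternVL2-8B (32 layers) on 4 GPUs.
# GPU 0 also hosts the ViT and embeddings, so it is budgeted as half a GPU.
num_layers, world_size = 32, 4
per_gpu = math.ceil(num_layers / (world_size - 0.5))  # ceil(32 / 3.5) = 10
plan = [per_gpu] * world_size
plan[0] = math.ceil(plan[0] * 0.5)  # GPU 0: ceil(10 * 0.5) = 5
print(plan)  # [5, 10, 10, 10] -> 35 slots for the 32 real layers

In split_model's loop, the surplus slot indices (32-34 here) are simply mapped to the last GPU for layer names the model does not have, and the final real layer (31) is then re-pinned to GPU 0 alongside layer 0, which is the fix this PR relies on.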
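As a usage sketch: with a single process and device_map="auto", __init__ now builds the sharded map via split_model before loading the weights. The import path and custom checkpoint name below are assumptions for illustration, not part of the PR.

from lmms_eval.models.internvl2 import InternVL2  # assumed import path

# Single-process, multi-GPU run: takes the split_model branch above.
model = InternVL2(pretrained="OpenGVLab/InternVL2-8B", device_map="auto")

# For a checkpoint missing from the lookup table, override the layer count:
# model = InternVL2(pretrained="org/custom-internvl", device_map="auto", num_layers=40)  # hypothetical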