From cf29c5b5cd0305d708fdd445042159b3b2b968da Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 14 Aug 2024 00:01:20 +0000 Subject: [PATCH 01/17] feat: expand vlm support and add image token logic and tests --- integration-tests/models/test_idefics3.py | 104 +++++++ router/src/config.rs | 8 + router/src/lib.rs | 1 + .../text_generation_server/models/__init__.py | 24 ++ .../custom_modeling/flash_llama_modeling.py | 14 +- .../models/custom_modeling/idefics2.py | 277 +++++++++++++++++- .../models/custom_modeling/vlm.py | 2 +- .../models/vlm_causal_lm.py | 5 + 8 files changed, 424 insertions(+), 11 deletions(-) create mode 100644 integration-tests/models/test_idefics3.py diff --git a/integration-tests/models/test_idefics3.py b/integration-tests/models/test_idefics3.py new file mode 100644 index 00000000000..8bcaeda2d5f --- /dev/null +++ b/integration-tests/models/test_idefics3.py @@ -0,0 +1,104 @@ +import pytest +import base64 + + +# TODO fix the server parsser to count inline image tokens correctly +def get_chicken(): + with open("integration-tests/images/chicken_on_money.png", "rb") as image_file: + encoded_string = base64.b64encode(image_file.read()) + return f"data:image/png;base64,{encoded_string.decode('utf-8')}" + + +def get_cow_beach(): + with open("integration-tests/images/cow_beach.png", "rb") as image_file: + encoded_string = base64.b64encode(image_file.read()) + return f"data:image/png;base64,{encoded_string.decode('utf-8')}" + + +@pytest.fixture(scope="module") +def flash_idefics3_next_handle(launcher): + with launcher( + "HuggingFaceM4/Idefics3-8B-Llama3", + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_idefics3_next(flash_idefics3_next_handle): + await flash_idefics3_next_handle.health(300) + return flash_idefics3_next_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_idefics3_next_simple(flash_idefics3_next, response_snapshot): + chicken = get_chicken() + response = await flash_idefics3_next.generate( + f"User:![]({chicken})Write me a short story \nAssistant:", + max_new_tokens=10, + ) + assert ( + response.generated_text == " A chicken is sitting on a pile of money." + ), f"{repr(response.generated_text)}" + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_idefics3_two_images(flash_idefics3_next, response_snapshot): + chicken = get_chicken() + cow_beach = get_cow_beach() + response = await flash_idefics3_next.generate( + f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken? \nAssistant:", + max_new_tokens=20, + ) + assert ( + response.generated_text + == " The cow is standing on the beach and the chicken is sitting on a pile of money." + ), f"{repr(response.generated_text)}" + assert response.details.generated_tokens == 19 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_idefics3_next_all_params(flash_idefics3_next, response_snapshot): + response = await flash_idefics3_next.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_idefics3_next_load( + flash_idefics3_next, generate_load, response_snapshot +): + chicken = get_chicken() + responses = await generate_load( + flash_idefics3_next, + f"User:![]({chicken})Write me a short story \nAssistant:", + max_new_tokens=10, + n=4, + ) + generated_texts = [r.generated_text for r in responses] + assert generated_texts[0] == " A chicken is sitting on a pile of money." + assert len(generated_texts) == 4 + assert all([r.generated_text == generated_texts[0] for r in responses]) + + assert responses == response_snapshot diff --git a/router/src/config.rs b/router/src/config.rs index 5d07a293ecb..8510d3560f1 100644 --- a/router/src/config.rs +++ b/router/src/config.rs @@ -110,6 +110,13 @@ pub struct ClipVisionModel { patch_size: usize, } +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct Idefics3 { + pub(crate) vision_encoder_max_image_size: usize, + pub(crate) image_seq_len: usize, +} + #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub struct Idefics2 {} @@ -178,6 +185,7 @@ pub enum Config { Idefics, Mllama, Idefics2(Idefics2), + Idefics3(Idefics3), Ssm, GptBigcode, Granite, diff --git a/router/src/lib.rs b/router/src/lib.rs index 84e9bc48286..21c45241308 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -170,6 +170,7 @@ impl TokenizerConfigToken { #[serde(tag = "processor_class")] pub enum HubPreprocessorConfig { Idefics2Processor(Idefics2Preprocessor), + Idefics3Processor(Idefics2Preprocessor), } impl HubPreprocessorConfig { diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index fcc79608645..c8b6896093e 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -151,6 +151,7 @@ ) from text_generation_server.models.custom_modeling.idefics2 import ( Idefics2ForConditionalGeneration, + Idefics3ForConditionalGeneration, ) from text_generation_server.models.custom_modeling.qwen2_vl import ( Qwen2VLForConditionalGeneration, @@ -188,6 +189,12 @@ class ModelType(enum.Enum): "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b", "multimodal": True, } + IDEFICS3 = { + "type": "idefics3", + "name": "Idefics 3", + "url": "https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3", + "multimodal": True, + } LLAVA_NEXT = { "type": "llava_next", "name": "Llava Next (1.6)", @@ -1253,6 +1260,23 @@ def get_model( ) else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) + if model_type == IDEFICS3: + if FLASH_ATTENTION: + return VlmCausalLM( + model_id=model_id, + model_class=Idefics3ForConditionalGeneration, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + # XXX: Extremely important to cap resolution in order to limit + # VRAM usage. + processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}}, + ) + else: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) if model_type == PALIGEMMA: if FLASH_ATTENTION: return VlmCausalLM( diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 2c007d15648..3e565109c7a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -515,9 +515,7 @@ def __init__(self, prefix, config, weights): self.layers.append( FlashLlamaLayer( index=0, - prefix=( - "model.layers.0" if not prefix else f"{prefix}.model.layers.0" - ), + prefix=("model.layers.0" if not prefix else f"{prefix}.layers.0"), config=config, weights=weights, ) @@ -564,7 +562,7 @@ def __init__(self, prefix, config, weights): prefix=( f"model.layers.{last_layer_id}" if not prefix - else f"{prefix}.model.layers.{last_layer_id}" + else f"{prefix}.layers.{last_layer_id}" ), config=config, weights=weights, @@ -572,7 +570,7 @@ def __init__(self, prefix, config, weights): ) self.norm = FastRMSNorm.load( - prefix="model.norm" if not prefix else f"{prefix}.model.norm", + prefix="model.norm" if not prefix else f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps, ) @@ -635,9 +633,7 @@ def __init__(self, prefix: str, config, weights): with no_fp8(weights): self.embed_tokens = TensorParallelEmbedding( prefix=( - "model.embed_tokens" - if not prefix - else f"{prefix}.model.embed_tokens" + "model.embed_tokens" if not prefix else f"{prefix}.embed_tokens" ), weights=weights, ) @@ -655,7 +651,7 @@ def __init__(self, prefix: str, config, weights): with no_fp8(weights): self.lm_head = SpeculativeHead.load( config, - prefix=suffix if not prefix else f"{prefix}.{suffix}", + prefix=suffix, weights=weights, ) diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py index 923123d61b6..55cfb9e62d5 100644 --- a/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/server/text_generation_server/models/custom_modeling/idefics2.py @@ -679,6 +679,281 @@ def forward(self, image_hidden_states, attention_mask): return image_hidden_states +class Idefics3Connector(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.modality_projection = TensorParallelRowLinear.load( + prefix=f"{prefix}.modality_projection.proj", + config=config, + weights=weights, + bias=False, + ) + self.scale_factor = config.scale_factor + + def pixel_shuffle(self, x, scale_factor=2): + bsz, seq, embed_dim = x.size() + height = width = int(seq**0.5) + x = x.view(bsz, height, width, embed_dim) + x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor) + x = x.permute(0, 2, 1, 3) + x = x.reshape( + bsz, + int(width / scale_factor), + int(height / scale_factor), + embed_dim * (scale_factor**2), + ) + x = x.permute(0, 2, 1, 3) + x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2)) + return x + + def forward(self, image_hidden_states, attention_mask): + print(image_hidden_states.device, self.modality_projection.linear.weight.device) + image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor) + image_hidden_states = self.modality_projection(image_hidden_states) + return image_hidden_states + + +class PerceiverConfig: + def __init__(self, config_dict): + self._name_or_path = config_dict.get("_name_or_path", "") + self.add_cross_attention = config_dict.get("add_cross_attention", False) + self.architectures = config_dict.get("architectures", None) + self.attention_dropout = config_dict.get("attention_dropout", 0.0) + self.bad_words_ids = config_dict.get("bad_words_ids", None) + self.begin_suppress_tokens = config_dict.get("begin_suppress_tokens", None) + self.bos_token_id = config_dict.get("bos_token_id", None) + self.chunk_size_feed_forward = config_dict.get("chunk_size_feed_forward", 0) + self.cross_attention_hidden_size = config_dict.get( + "cross_attention_hidden_size", None + ) + self.decoder_start_token_id = config_dict.get("decoder_start_token_id", None) + self.diversity_penalty = config_dict.get("diversity_penalty", 0.0) + self.do_sample = config_dict.get("do_sample", False) + self.early_stopping = config_dict.get("early_stopping", False) + self.encoder_no_repeat_ngram_size = config_dict.get( + "encoder_no_repeat_ngram_size", 0 + ) + self.eos_token_id = config_dict.get("eos_token_id", None) + self.exponential_decay_length_penalty = config_dict.get( + "exponential_decay_length_penalty", None + ) + self.finetuning_task = config_dict.get("finetuning_task", None) + self.forced_bos_token_id = config_dict.get("forced_bos_token_id", None) + self.forced_eos_token_id = config_dict.get("forced_eos_token_id", None) + self.hidden_act = config_dict.get("hidden_act", "silu") + self.id2label = config_dict.get("id2label", {"0": "LABEL_0", "1": "LABEL_1"}) + self.is_decoder = config_dict.get("is_decoder", False) + self.is_encoder_decoder = config_dict.get("is_encoder_decoder", False) + self.label2id = config_dict.get("label2id", {"LABEL_0": 0, "LABEL_1": 1}) + self.length_penalty = config_dict.get("length_penalty", 1.0) + self.max_length = config_dict.get("max_length", 20) + self.min_length = config_dict.get("min_length", 0) + self.model_type = config_dict.get("model_type", "idefics3") + self.no_repeat_ngram_size = config_dict.get("no_repeat_ngram_size", 0) + self.num_beam_groups = config_dict.get("num_beam_groups", 1) + self.num_beams = config_dict.get("num_beams", 1) + self.num_key_value_heads = config_dict.get("num_key_value_heads", 1) + self.num_return_sequences = config_dict.get("num_return_sequences", 1) + self.output_attentions = config_dict.get("output_attentions", False) + self.output_hidden_states = config_dict.get("output_hidden_states", False) + self.output_scores = config_dict.get("output_scores", False) + self.pad_token_id = config_dict.get("pad_token_id", 128002) + self.prefix = config_dict.get("prefix", None) + self.problem_type = config_dict.get("problem_type", None) + self.pruned_heads = config_dict.get("pruned_heads", {}) + self.qk_layer_norms_perceiver = config_dict.get( + "qk_layer_norms_perceiver", False + ) + self.remove_invalid_values = config_dict.get("remove_invalid_values", False) + self.repetition_penalty = config_dict.get("repetition_penalty", 1.0) + self.resampler_depth = config_dict.get("resampler_depth", 6) + self.resampler_head_dim = config_dict.get("resampler_head_dim", 96) + self.resampler_n_heads = config_dict.get("resampler_n_heads", 16) + self.resampler_n_latents = config_dict.get("resampler_n_latents", 64) + self.return_dict = config_dict.get("return_dict", True) + self.return_dict_in_generate = config_dict.get("return_dict_in_generate", False) + self.sep_token_id = config_dict.get("sep_token_id", None) + self.suppress_tokens = config_dict.get("suppress_tokens", None) + self.task_specific_params = config_dict.get("task_specific_params", None) + self.temperature = config_dict.get("temperature", 1.0) + self.tf_legacy_loss = config_dict.get("tf_legacy_loss", False) + self.tie_encoder_decoder = config_dict.get("tie_encoder_decoder", False) + self.tie_word_embeddings = config_dict.get("tie_word_embeddings", True) + self.tokenizer_class = config_dict.get("tokenizer_class", None) + self.top_k = config_dict.get("top_k", 50) + self.top_p = config_dict.get("top_p", 1.0) + self.torch_dtype = config_dict.get("torch_dtype", None) + self.torchscript = config_dict.get("torchscript", False) + self.transformers_version = config_dict.get("transformers_version", "4.43.2") + self.typical_p = config_dict.get("typical_p", 1.0) + self.use_bfloat16 = config_dict.get("use_bfloat16", False) + + +class Idefics3ForConditionalGeneration(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + config.vision_config.quantize = None + config.vision_config.speculator = config.speculator + config.text_config.quantize = config.quantize + config.text_config.speculator = config.speculator + + vision_config = config.vision_config + self.text_model = load_text_model( + prefix="model" if not prefix else f"{prefix}.model", + config=config.text_config, + weights=weights, + name="text_model", + ) + self.dtype = weights.dtype + + # The vision and connector models are not quantized. + with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)): + self.vision_model = Idefics2VisionTransformer( + prefix=( + f"{prefix}.model.vision_model" if prefix else "model.vision_model" + ), + config=vision_config, + weights=weights, + ) + + config.quantize = None + self.connector = Idefics3Connector( + prefix=f"{prefix}.model.connector" if prefix else "model.connector", + config=config, + weights=weights, + ) + + self.config = config + config.text_config.perceiver_config = PerceiverConfig( + config_dict=config.text_config.perceiver_config + ) + self.image_seq_len = config.text_config.perceiver_config.resampler_n_latents + self.image_token_id = config.image_token_id + self.pad_token_id = ( + config.pad_token_id if config.pad_token_id is not None else -1 + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + pixel_values: torch.FloatTensor = None, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + # Unused here + image_sizes: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ): + inputs_embeds = self.text_model.embed_tokens(input_ids) + if pixel_values is not None: + batch_size, num_images, num_channels, height, width = pixel_values.shape + all_states = [] + all_pixel_values = pixel_values + all_pixel_mask = pixel_attention_mask + for i in range(batch_size): + pixel_values = all_pixel_values.to( + dtype=self.dtype + ) # fp16 compatibility + pixel_values = pixel_values[i : i + 1] + pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:]) + + # Remove padding images - padding images are full 0. + nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum( + dim=(-1, -2, -3) + ) != nb_values_per_image + pixel_values = pixel_values[real_images_inds].contiguous() + + # Handle the vision attention mask + if pixel_attention_mask is None: + pixel_attention_mask = torch.ones( + size=( + pixel_values.size(0), + pixel_values.size(2), + pixel_values.size(3), + ), + dtype=torch.bool, + device=pixel_values.device, + ) + else: + # Remove padding images from the mask/pP p + pixel_attention_mask = all_pixel_mask[i : i + 1] + pixel_attention_mask = pixel_attention_mask.view( + 1 * num_images, *pixel_attention_mask.shape[2:] + ) + pixel_attention_mask = pixel_attention_mask[ + real_images_inds + ].contiguous() + + patch_size = self.config.vision_config.patch_size + patches_subgrid = pixel_attention_mask.unfold( + dimension=1, size=patch_size, step=patch_size + ) + patches_subgrid = patches_subgrid.unfold( + dimension=2, size=patch_size, step=patch_size + ) + patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ) + + # Modality projection & resampling + image_hidden_states = self.connector( + image_hidden_states, + attention_mask=patch_attention_mask.view(pixel_values.size(0), -1), + ) + + all_states.append(image_hidden_states) + image_hidden_states = torch.stack(all_states, dim=0) + # When we generate, we don't want to replace the potential image_token_id that we generated by images + # that simply don't exist + # TODO: finish implementing the image token replacement + + # inputs_embeds = self.inputs_merger( + # input_ids=input_ids, + # inputs_embeds=inputs_embeds, + # image_hidden_states=image_hidden_states, + # ) + + # import ipdb; ipdb.set_trace() + # num_images, _, vision_hidden_size = image_hidden_states.shape + # special_image_token_mask = input_ids == self.image_token_id + # new_inputs_embeds = inputs_embeds.clone() + # reshaped_image_hidden_states = image_hidden_states.view(-1, vision_hidden_size).to( + # inputs_embeds.dtype + # ) # cast to the dtype of the input_embeds to support quantized models + # new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states + # inputs_embeds = new_inputs_embeds + + hidden_states = self.text_model.model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, + max_s=max_s, + true_max_s=max_s, + prefill_cache_indices=None, + adapter_data=adapter_data, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.text_model.lm_head(hidden_states) + return logits, speculative_logits + + class Idefics2ForConditionalGeneration(nn.Module): def __init__(self, prefix, config, weights): super().__init__() @@ -707,7 +982,7 @@ def __init__(self, prefix, config, weights): ) config.quantize = None - self.connector = Idefics2Connector( + self.connector = Idefics3Connector( prefix=f"{prefix}.model.connector" if prefix else "model.connector", config=config, weights=weights, diff --git a/server/text_generation_server/models/custom_modeling/vlm.py b/server/text_generation_server/models/custom_modeling/vlm.py index 82e409a673c..04edd0a4064 100644 --- a/server/text_generation_server/models/custom_modeling/vlm.py +++ b/server/text_generation_server/models/custom_modeling/vlm.py @@ -4,7 +4,7 @@ def load_text_model(prefix, config, weights, name=None): FlashLlamaForCausalLM, ) - return FlashLlamaForCausalLM(prefix, config, weights) + return FlashLlamaForCausalLM(f"{prefix}.text_model", config, weights) elif config.model_type == "mistral": from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( FlashMistralForCausalLM, diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 81b4369b986..61562f8a2da 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -54,6 +54,10 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str if processor.image_processor.do_image_splitting: image_str *= 5 return image_str + if config.model_type == "idefics3": + image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN}{IDEFICS2_FAKE_TOKEN}" + image_str = "" + return image_str elif config.model_type == "llava_next": height, width = image_input["image_sizes"][image_id] num_features = get_number_of_features(height, width, config) @@ -288,6 +292,7 @@ def __init__( **processor_kwargs, ) self.batch_class = batch_class + # import ipdb; ipdb.set_trace() super().__init__( model_id=model_id, revision=revision, From 305db7ea1e7c356ef806433f6b487ec348760d3f Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 19 Aug 2024 13:36:39 +0000 Subject: [PATCH 02/17] fix: avoid unused perceiver config --- .../models/custom_modeling/idefics2.py | 80 ------------------- 1 file changed, 80 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py index 55cfb9e62d5..1b7b7983552 100644 --- a/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/server/text_generation_server/models/custom_modeling/idefics2.py @@ -713,82 +713,6 @@ def forward(self, image_hidden_states, attention_mask): return image_hidden_states -class PerceiverConfig: - def __init__(self, config_dict): - self._name_or_path = config_dict.get("_name_or_path", "") - self.add_cross_attention = config_dict.get("add_cross_attention", False) - self.architectures = config_dict.get("architectures", None) - self.attention_dropout = config_dict.get("attention_dropout", 0.0) - self.bad_words_ids = config_dict.get("bad_words_ids", None) - self.begin_suppress_tokens = config_dict.get("begin_suppress_tokens", None) - self.bos_token_id = config_dict.get("bos_token_id", None) - self.chunk_size_feed_forward = config_dict.get("chunk_size_feed_forward", 0) - self.cross_attention_hidden_size = config_dict.get( - "cross_attention_hidden_size", None - ) - self.decoder_start_token_id = config_dict.get("decoder_start_token_id", None) - self.diversity_penalty = config_dict.get("diversity_penalty", 0.0) - self.do_sample = config_dict.get("do_sample", False) - self.early_stopping = config_dict.get("early_stopping", False) - self.encoder_no_repeat_ngram_size = config_dict.get( - "encoder_no_repeat_ngram_size", 0 - ) - self.eos_token_id = config_dict.get("eos_token_id", None) - self.exponential_decay_length_penalty = config_dict.get( - "exponential_decay_length_penalty", None - ) - self.finetuning_task = config_dict.get("finetuning_task", None) - self.forced_bos_token_id = config_dict.get("forced_bos_token_id", None) - self.forced_eos_token_id = config_dict.get("forced_eos_token_id", None) - self.hidden_act = config_dict.get("hidden_act", "silu") - self.id2label = config_dict.get("id2label", {"0": "LABEL_0", "1": "LABEL_1"}) - self.is_decoder = config_dict.get("is_decoder", False) - self.is_encoder_decoder = config_dict.get("is_encoder_decoder", False) - self.label2id = config_dict.get("label2id", {"LABEL_0": 0, "LABEL_1": 1}) - self.length_penalty = config_dict.get("length_penalty", 1.0) - self.max_length = config_dict.get("max_length", 20) - self.min_length = config_dict.get("min_length", 0) - self.model_type = config_dict.get("model_type", "idefics3") - self.no_repeat_ngram_size = config_dict.get("no_repeat_ngram_size", 0) - self.num_beam_groups = config_dict.get("num_beam_groups", 1) - self.num_beams = config_dict.get("num_beams", 1) - self.num_key_value_heads = config_dict.get("num_key_value_heads", 1) - self.num_return_sequences = config_dict.get("num_return_sequences", 1) - self.output_attentions = config_dict.get("output_attentions", False) - self.output_hidden_states = config_dict.get("output_hidden_states", False) - self.output_scores = config_dict.get("output_scores", False) - self.pad_token_id = config_dict.get("pad_token_id", 128002) - self.prefix = config_dict.get("prefix", None) - self.problem_type = config_dict.get("problem_type", None) - self.pruned_heads = config_dict.get("pruned_heads", {}) - self.qk_layer_norms_perceiver = config_dict.get( - "qk_layer_norms_perceiver", False - ) - self.remove_invalid_values = config_dict.get("remove_invalid_values", False) - self.repetition_penalty = config_dict.get("repetition_penalty", 1.0) - self.resampler_depth = config_dict.get("resampler_depth", 6) - self.resampler_head_dim = config_dict.get("resampler_head_dim", 96) - self.resampler_n_heads = config_dict.get("resampler_n_heads", 16) - self.resampler_n_latents = config_dict.get("resampler_n_latents", 64) - self.return_dict = config_dict.get("return_dict", True) - self.return_dict_in_generate = config_dict.get("return_dict_in_generate", False) - self.sep_token_id = config_dict.get("sep_token_id", None) - self.suppress_tokens = config_dict.get("suppress_tokens", None) - self.task_specific_params = config_dict.get("task_specific_params", None) - self.temperature = config_dict.get("temperature", 1.0) - self.tf_legacy_loss = config_dict.get("tf_legacy_loss", False) - self.tie_encoder_decoder = config_dict.get("tie_encoder_decoder", False) - self.tie_word_embeddings = config_dict.get("tie_word_embeddings", True) - self.tokenizer_class = config_dict.get("tokenizer_class", None) - self.top_k = config_dict.get("top_k", 50) - self.top_p = config_dict.get("top_p", 1.0) - self.torch_dtype = config_dict.get("torch_dtype", None) - self.torchscript = config_dict.get("torchscript", False) - self.transformers_version = config_dict.get("transformers_version", "4.43.2") - self.typical_p = config_dict.get("typical_p", 1.0) - self.use_bfloat16 = config_dict.get("use_bfloat16", False) - - class Idefics3ForConditionalGeneration(nn.Module): def __init__(self, prefix, config, weights): super().__init__() @@ -824,10 +748,6 @@ def __init__(self, prefix, config, weights): ) self.config = config - config.text_config.perceiver_config = PerceiverConfig( - config_dict=config.text_config.perceiver_config - ) - self.image_seq_len = config.text_config.perceiver_config.resampler_n_latents self.image_token_id = config.image_token_id self.pad_token_id = ( config.pad_token_id if config.pad_token_id is not None else -1 From a59b7faf0c5374e8ea2bec33ad50f1ef83dc08dd Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 20 Aug 2024 16:33:39 +0000 Subject: [PATCH 03/17] feat: integrate image tokens into inputs embeds --- router/src/config.rs | 17 ++- .../models/custom_modeling/idefics2.py | 48 +++++---- .../models/vlm_causal_lm.py | 100 +++++++++++++++++- 3 files changed, 138 insertions(+), 27 deletions(-) diff --git a/router/src/config.rs b/router/src/config.rs index 8510d3560f1..4d5fcfa0639 100644 --- a/router/src/config.rs +++ b/router/src/config.rs @@ -112,9 +112,20 @@ pub struct ClipVisionModel { #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] -pub struct Idefics3 { - pub(crate) vision_encoder_max_image_size: usize, - pub(crate) image_seq_len: usize, +pub struct Idefics3 {} + +impl Idefics3 { + pub fn get_max_longest_edge(&self) -> usize { + 364 + } + + pub fn get_number_of_features(&self) -> usize { + 169 + } + + pub fn get_max_longest_edge_for_image_resize(&self) -> usize { + 1456 + } } #[derive(Clone, Debug, Serialize, Deserialize)] diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py index 1b7b7983552..6040625bfb2 100644 --- a/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/server/text_generation_server/models/custom_modeling/idefics2.py @@ -753,6 +753,19 @@ def __init__(self, prefix, config, weights): config.pad_token_id if config.pad_token_id is not None else -1 ) + def _merge_input_ids_with_image_features( + self, + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + image_features: torch.Tensor, + ): + """In place merges in vision_embeddings with inputs_embeds.""" + # mask = input_ids == self.config.image_token_index + mask = input_ids == self.config.image_token_id + # Let's pray we have enabled enough slots ! + inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -835,25 +848,22 @@ def forward( all_states.append(image_hidden_states) image_hidden_states = torch.stack(all_states, dim=0) - # When we generate, we don't want to replace the potential image_token_id that we generated by images - # that simply don't exist - # TODO: finish implementing the image token replacement - - # inputs_embeds = self.inputs_merger( - # input_ids=input_ids, - # inputs_embeds=inputs_embeds, - # image_hidden_states=image_hidden_states, - # ) - - # import ipdb; ipdb.set_trace() - # num_images, _, vision_hidden_size = image_hidden_states.shape - # special_image_token_mask = input_ids == self.image_token_id - # new_inputs_embeds = inputs_embeds.clone() - # reshaped_image_hidden_states = image_hidden_states.view(-1, vision_hidden_size).to( - # inputs_embeds.dtype - # ) # cast to the dtype of the input_embeds to support quantized models - # new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states - # inputs_embeds = new_inputs_embeds + # TODO: remove when prefill image tokens are handled correctly + # * for now dummy tokens are added instead of the image tokens output byt the vision model + mask_size = (input_ids == self.config.image_token_id).sum().item() + unrolled_image_size = ( + image_hidden_states.shape[1] * image_hidden_states.shape[2] + ) + diff = mask_size - unrolled_image_size + if diff > 0: + print( + f"Mask size {mask_size} is greater than the number of images {unrolled_image_size}." + ) + + if mask_size == unrolled_image_size: + inputs_embeds = self._merge_input_ids_with_image_features( + input_ids, inputs_embeds, image_hidden_states + ) hidden_states = self.text_model.model( inputs_embeds=inputs_embeds, diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 61562f8a2da..bc1fd073113 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -23,6 +23,75 @@ IDEFICS2_FAKE_TOKEN = "" IDEFICS2_IMAGE_TOKEN = "" +IDEFICS3_IMAGE_TOKEN = "" +IDEFICS3_FAKE_IMAGE_TOKEN = "" +IDEFICS3_GLOBAL_IMG_TOKEN = "" + + +def _prompt_split_image( + image_seq_len, + image_rows, + image_cols, + fake_token_around_image, + image_token, + global_img_token, +): + """Prompt with expanded image tokens for when the image is split into patches.""" + text_split_images = "" + for n_h in range(image_rows): + for n_w in range(image_cols): + text_split_images += ( + f"{fake_token_around_image}" + + f"" + + f"{image_token}" * image_seq_len + ) + text_split_images += "\n" + + text_split_images += ( + f"\n{fake_token_around_image}" + + f"{global_img_token}" + + f"{image_token}" * image_seq_len + + f"{fake_token_around_image}" + ) + return text_split_images + + +def _prompt_single_image( + image_seq_len, fake_token_around_image, image_token, global_img_token +): + """Prompt with expanded image tokens for a single image.""" + return ( + f"{fake_token_around_image}" + + f"{global_img_token}" + + f"{image_token}" * image_seq_len + + f"{fake_token_around_image}" + ) + + +def get_image_prompt_string( + image_rows, + image_cols, + image_seq_len, + fake_token_around_image, + image_token, + global_img_token, +): + if image_rows == 0 and image_cols == 0: + return _prompt_single_image( + image_seq_len, + fake_token_around_image=fake_token_around_image, + image_token=image_token, + global_img_token=global_img_token, + ) + return _prompt_split_image( + image_seq_len, + image_rows, + image_cols, + fake_token_around_image, + image_token, + global_img_token, + ) + def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): """ @@ -55,8 +124,22 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str image_str *= 5 return image_str if config.model_type == "idefics3": - image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN}{IDEFICS2_FAKE_TOKEN}" - image_str = "" + # TODO: implement this in a more general way + n_rows = image_input["rows"][0][image_id] + n_cols = image_input["cols"][0][image_id] + + # TODO: avoid using hardcoded values + image_seq_len = 169 # default value + # image_seq_len = int(((image_size // patch_size) ** 2) / (scale_factor**2)) + + image_str = get_image_prompt_string( + n_rows, + n_cols, + image_seq_len, + image_token=IDEFICS3_IMAGE_TOKEN, + fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN, + global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN, + ) return image_str elif config.model_type == "llava_next": height, width = image_input["image_sizes"][image_id] @@ -85,6 +168,10 @@ def image_text_replacement_fixup(config, text: str) -> str: return text.replace( f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN ) + if config.model_type == "idefics3": + return text.replace( + f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN + ) return text @@ -198,7 +285,9 @@ def batch_tokenized_inputs( raise RuntimeError(f"Invalid chunk type {chunk_type}") if images: - image_inputs = processor.image_processor(images, return_tensors="pt") + image_inputs = processor.image_processor( + images, return_tensors="pt", return_row_col_info=True + ) else: image_inputs = None @@ -212,9 +301,10 @@ def batch_tokenized_inputs( if chunk_type == "text": full_text += chunk.text elif chunk_type == "image": - full_text += image_text_replacement( + replacement_text = image_text_replacement( processor, image_inputs, config, image_id ) + full_text += replacement_text image_id += 1 full_text = image_text_replacement_fixup(config, full_text) @@ -289,7 +379,7 @@ def __init__( model_id, revision=revision, trust_remote_code=trust_remote_code, - **processor_kwargs, + # **processor_kwargs, ) self.batch_class = batch_class # import ipdb; ipdb.set_trace() From ebef284b3d7ad1725e53332f84a6f6dfeecc35e5 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 20 Aug 2024 16:55:22 +0000 Subject: [PATCH 04/17] feat: add simple idefics3 test --- integration-tests/conftest.py | 4 + .../test_flash_idefics3_next_simple_url.json | 73 +++++++++++++++++ integration-tests/models/test_idefics3.py | 80 +++++-------------- 3 files changed, 97 insertions(+), 60 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_idefics3/test_flash_idefics3_next_simple_url.json diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index c9c477665a3..c702ae709ee 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -354,6 +354,7 @@ def local_launcher( kv_cache_dtype: Optional[str] = None, revision: Optional[str] = None, max_input_length: Optional[int] = None, + max_input_tokens: Optional[int] = None, max_batch_prefill_tokens: Optional[int] = None, max_total_tokens: Optional[int] = None, lora_adapters: Optional[List[str]] = None, @@ -402,6 +403,9 @@ def local_launcher( if max_input_length: args.append("--max-input-length") args.append(str(max_input_length)) + if max_input_tokens: + args.append("--max-input-tokens") + args.append(str(max_input_tokens)) if max_batch_prefill_tokens: args.append("--max-batch-prefill-tokens") args.append(str(max_batch_prefill_tokens)) diff --git a/integration-tests/models/__snapshots__/test_idefics3/test_flash_idefics3_next_simple_url.json b/integration-tests/models/__snapshots__/test_idefics3/test_flash_idefics3_next_simple_url.json new file mode 100644 index 00000000000..052318df2c9 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_idefics3/test_flash_idefics3_next_simple_url.json @@ -0,0 +1,73 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 578, + "logprob": -0.2475586, + "special": false, + "text": " The" + }, + { + "id": 2217, + "logprob": -0.017303467, + "special": false, + "text": " image" + }, + { + "id": 62991, + "logprob": -0.7368164, + "special": false, + "text": " depicts" + }, + { + "id": 279, + "logprob": -0.39990234, + "special": false, + "text": " the" + }, + { + "id": 89675, + "logprob": -0.34350586, + "special": false, + "text": " Statue" + }, + { + "id": 315, + "logprob": -0.0002901554, + "special": false, + "text": " of" + }, + { + "id": 32492, + "logprob": -0.0009598732, + "special": false, + "text": " Liberty" + }, + { + "id": 11, + "logprob": -0.2355957, + "special": false, + "text": "," + }, + { + "id": 264, + "logprob": -0.66503906, + "special": false, + "text": " a" + }, + { + "id": 97937, + "logprob": -0.9199219, + "special": false, + "text": " colossal" + } + ], + "top_tokens": null + }, + "generated_text": " The image depicts the Statue of Liberty, a colossal" +} diff --git a/integration-tests/models/test_idefics3.py b/integration-tests/models/test_idefics3.py index 8bcaeda2d5f..1f55872a066 100644 --- a/integration-tests/models/test_idefics3.py +++ b/integration-tests/models/test_idefics3.py @@ -2,23 +2,19 @@ import base64 -# TODO fix the server parsser to count inline image tokens correctly def get_chicken(): with open("integration-tests/images/chicken_on_money.png", "rb") as image_file: encoded_string = base64.b64encode(image_file.read()) return f"data:image/png;base64,{encoded_string.decode('utf-8')}" -def get_cow_beach(): - with open("integration-tests/images/cow_beach.png", "rb") as image_file: - encoded_string = base64.b64encode(image_file.read()) - return f"data:image/png;base64,{encoded_string.decode('utf-8')}" - - @pytest.fixture(scope="module") def flash_idefics3_next_handle(launcher): with launcher( "HuggingFaceM4/Idefics3-8B-Llama3", + max_total_tokens=3000, + max_batch_prefill_tokens=2501, + max_input_tokens=2500, ) as handle: yield handle @@ -29,76 +25,40 @@ async def flash_idefics3_next(flash_idefics3_next_handle): return flash_idefics3_next_handle.client +# TODO: dont skip when token issue is resolved +@pytest.mark.skip @pytest.mark.asyncio @pytest.mark.private -async def test_flash_idefics3_next_simple(flash_idefics3_next, response_snapshot): +async def test_flash_idefics3_next_simple_base64( + flash_idefics3_next, response_snapshot +): chicken = get_chicken() + query = "Write me a short story" response = await flash_idefics3_next.generate( - f"User:![]({chicken})Write me a short story \nAssistant:", + f"<|begin_of_text|><|begin_of_text|>User:![]({chicken}){query}\nAssistant:", max_new_tokens=10, ) assert ( response.generated_text == " A chicken is sitting on a pile of money." ), f"{repr(response.generated_text)}" - assert response.details.generated_tokens == 10 - assert response == response_snapshot + # assert response.details.generated_tokens == 10 + # assert response == response_snapshot @pytest.mark.asyncio @pytest.mark.private -async def test_flash_idefics3_two_images(flash_idefics3_next, response_snapshot): - chicken = get_chicken() - cow_beach = get_cow_beach() +async def test_flash_idefics3_next_simple_url(flash_idefics3_next, response_snapshot): + ny_skyline = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg" + query = "What is in this image?" response = await flash_idefics3_next.generate( - f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken? \nAssistant:", - max_new_tokens=20, + f"<|begin_of_text|><|begin_of_text|>User:![]({ny_skyline}){query}\nAssistant:", + max_new_tokens=10, + seed=1337, ) + print(response) assert ( response.generated_text - == " The cow is standing on the beach and the chicken is sitting on a pile of money." + == " The image depicts the Statue of Liberty, a colossal" ), f"{repr(response.generated_text)}" - assert response.details.generated_tokens == 19 - assert response == response_snapshot - - -@pytest.mark.asyncio -@pytest.mark.private -async def test_flash_idefics3_next_all_params(flash_idefics3_next, response_snapshot): - response = await flash_idefics3_next.generate( - "Test request", - max_new_tokens=10, - repetition_penalty=1.2, - return_full_text=True, - stop_sequences=["test"], - temperature=0.5, - top_p=0.9, - top_k=10, - truncate=5, - typical_p=0.9, - watermark=True, - decoder_input_details=True, - seed=0, - ) - assert response.details.generated_tokens == 10 assert response == response_snapshot - - -@pytest.mark.asyncio -@pytest.mark.private -async def test_flash_idefics3_next_load( - flash_idefics3_next, generate_load, response_snapshot -): - chicken = get_chicken() - responses = await generate_load( - flash_idefics3_next, - f"User:![]({chicken})Write me a short story \nAssistant:", - max_new_tokens=10, - n=4, - ) - generated_texts = [r.generated_text for r in responses] - assert generated_texts[0] == " A chicken is sitting on a pile of money." - assert len(generated_texts) == 4 - assert all([r.generated_text == generated_texts[0] for r in responses]) - - assert responses == response_snapshot From dbe1666bc7651400dc606fbdf5e6bb054ceac8bc Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 17 Dec 2024 19:44:24 +0000 Subject: [PATCH 05/17] feat: update docs, image token logic and weight names --- docs/source/supported_models.md | 1 + router/src/validation.rs | 70 ++++++++++++++++++- .../custom_modeling/flash_llama_modeling.py | 4 +- .../models/custom_modeling/idefics2.py | 8 ++- 4 files changed, 78 insertions(+), 5 deletions(-) diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md index 0f39ff28e42..5ac903516d9 100644 --- a/docs/source/supported_models.md +++ b/docs/source/supported_models.md @@ -5,6 +5,7 @@ Text Generation Inference enables serving optimized models. The following sectio - [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2) - [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal) +- [Idefics 3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3) (Multimodal) - [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal) - [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f) - [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) diff --git a/router/src/validation.rs b/router/src/validation.rs index 8137ac58d2b..6d5b06bd39e 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -614,6 +614,73 @@ fn image_tokens( image_string } + Idefics3(config) => { + const FAKE: &str = ""; + const IMAGE: &str = ""; + const GLOBAL_IMG: &str = ""; + + let max_longest_edge_for_image_resize = config.get_max_longest_edge_for_image_resize(); + + // resize image if it is larger than max_longest_edge_for_image_resize keeping aspect ratio + let (height, width) = if height > max_longest_edge_for_image_resize + || width > max_longest_edge_for_image_resize + { + let aspect_ratio = height as f32 / width as f32; + if height > width { + ( + max_longest_edge_for_image_resize, + (max_longest_edge_for_image_resize as f32 / aspect_ratio) as usize, + ) + } else { + ( + (max_longest_edge_for_image_resize as f32 * aspect_ratio) as usize, + max_longest_edge_for_image_resize, + ) + } + } else { + (height, width) + }; + + let image_seq_len = config.get_number_of_features(); + let max_edge = config.get_max_longest_edge(); + + let (image_rows, image_cols) = if height > max_edge || width > max_edge { + ( + (height as f32 / max_edge as f32).ceil() as usize, + (width as f32 / max_edge as f32).ceil() as usize, + ) + } else { + (0, 0) + }; + + let mut image_string = String::new(); + + if image_rows == 0 && image_cols == 0 { + // Single image case + image_string.push_str(FAKE); + image_string.push_str(GLOBAL_IMG); + image_string.push_str(&IMAGE.repeat(image_seq_len)); + image_string.push_str(FAKE); + } else { + // Split image case + for n_h in 0..image_rows { + for n_w in 0..image_cols { + image_string.push_str(FAKE); + image_string.push_str(&format!("", n_h + 1, n_w + 1)); + image_string.push_str(&IMAGE.repeat(image_seq_len)); + } + image_string.push('\n'); + } + + image_string.push('\n'); + image_string.push_str(FAKE); + image_string.push_str(GLOBAL_IMG); + image_string.push_str(&IMAGE.repeat(image_seq_len)); + image_string.push_str(FAKE); + } + + image_string + } Paligemma(config) => "".repeat(config.get_number_of_features(height, width)), LlavaNext(config) => "".repeat(config.get_number_of_features(height, width)), Qwen2Vl(config) => format!( @@ -647,7 +714,8 @@ fn prepare_input( static RE: Lazy = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap()); let (tokenizer_query, input_chunks) = match config { Some( - config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_) | Qwen2Vl(_)), + config @ (Idefics | Mllama | Idefics2(_) | Idefics3(_) | Paligemma(_) | LlavaNext(_) + | Qwen2Vl(_)), ) => { let mut input_chunks = Vec::new(); let mut tokenizer_query = String::with_capacity(inputs.len()); diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 3e565109c7a..6320793837e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -534,7 +534,7 @@ def __init__(self, prefix, config, weights): prefix=( f"model.layers.{layer_id}" if not prefix - else f"{prefix}.model.layers.{layer_id}" + else f"{prefix}.layers.{layer_id}" ), config=config, weights=weights, @@ -547,7 +547,7 @@ def __init__(self, prefix, config, weights): prefix=( f"model.layers.{layer_id}" if not prefix - else f"{prefix}.model.layers.{layer_id}" + else f"{prefix}.layers.{layer_id}" ), config=config, weights=weights, diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py index 6040625bfb2..b1967ec3d51 100644 --- a/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/server/text_generation_server/models/custom_modeling/idefics2.py @@ -774,7 +774,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, @@ -783,6 +783,10 @@ def forward( # Unused here image_sizes: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + image_indices=None, ): inputs_embeds = self.text_model.embed_tokens(input_ids) if pixel_values is not None: @@ -872,7 +876,7 @@ def forward( kv_cache=kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths, + seqlen=seqlen, max_s=max_s, true_max_s=max_s, prefill_cache_indices=None, From c9573ddf2856eb18f9c768249b95f2dce5b19b47 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 18 Dec 2024 01:36:36 +0000 Subject: [PATCH 06/17] fix: improve image processing --- server/text_generation_server/models/__init__.py | 3 ++- .../models/vlm_causal_lm.py | 15 ++++++--------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index c8b6896093e..a96cb37f95c 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -1269,11 +1269,12 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, # XXX: Extremely important to cap resolution in order to limit # VRAM usage. - processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}}, + processor_kwargs={"size": {"longest_edge": 1456}}, ) else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index bc1fd073113..0548fbc61b3 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -127,11 +127,10 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str # TODO: implement this in a more general way n_rows = image_input["rows"][0][image_id] n_cols = image_input["cols"][0][image_id] - - # TODO: avoid using hardcoded values - image_seq_len = 169 # default value - # image_seq_len = int(((image_size // patch_size) ** 2) / (scale_factor**2)) - + image_seq_len = int( + ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) + / (config.scale_factor**2) + ) image_str = get_image_prompt_string( n_rows, n_cols, @@ -301,10 +300,9 @@ def batch_tokenized_inputs( if chunk_type == "text": full_text += chunk.text elif chunk_type == "image": - replacement_text = image_text_replacement( + full_text += image_text_replacement( processor, image_inputs, config, image_id ) - full_text += replacement_text image_id += 1 full_text = image_text_replacement_fixup(config, full_text) @@ -379,10 +377,9 @@ def __init__( model_id, revision=revision, trust_remote_code=trust_remote_code, - # **processor_kwargs, + **processor_kwargs, ) self.batch_class = batch_class - # import ipdb; ipdb.set_trace() super().__init__( model_id=model_id, revision=revision, From 53b2bea6b99d0dfcf1c088853adcf45dfadfe171 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 18 Dec 2024 03:25:22 +0000 Subject: [PATCH 07/17] feat: improve prefix for idefics3 --- .../custom_modeling/flash_llama_modeling.py | 36 +++++++++++-------- .../models/vlm_causal_lm.py | 26 ++++++-------- 2 files changed, 32 insertions(+), 30 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 6320793837e..d2c4f7515e0 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -515,7 +515,7 @@ def __init__(self, prefix, config, weights): self.layers.append( FlashLlamaLayer( index=0, - prefix=("model.layers.0" if not prefix else f"{prefix}.layers.0"), + prefix=f"{prefix}.layers.0" if prefix else "model.layers.0", config=config, weights=weights, ) @@ -532,9 +532,9 @@ def __init__(self, prefix, config, weights): FlashLlamaCrossLayer( index=layer_id, prefix=( - f"model.layers.{layer_id}" - if not prefix - else f"{prefix}.layers.{layer_id}" + f"{prefix}.layers.{layer_id}" + if prefix + else f"model.layers.{layer_id}" ), config=config, weights=weights, @@ -545,9 +545,9 @@ def __init__(self, prefix, config, weights): FlashLlamaLayer( index=layer_id, prefix=( - f"model.layers.{layer_id}" - if not prefix - else f"{prefix}.layers.{layer_id}" + f"{prefix}.layers.{layer_id}" + if prefix + else f"model.layers.{layer_id}" ), config=config, weights=weights, @@ -560,9 +560,9 @@ def __init__(self, prefix, config, weights): FlashLlamaLayer( index=last_layer_id, prefix=( - f"model.layers.{last_layer_id}" - if not prefix - else f"{prefix}.layers.{last_layer_id}" + f"{prefix}.layers.{last_layer_id}" + if prefix + else f"model.layers.{last_layer_id}" ), config=config, weights=weights, @@ -570,7 +570,7 @@ def __init__(self, prefix, config, weights): ) self.norm = FastRMSNorm.load( - prefix="model.norm" if not prefix else f"{prefix}.norm", + prefix=f"{prefix}.norm" if prefix else "model.norm", weights=weights, eps=config.rms_norm_eps, ) @@ -630,11 +630,12 @@ class FlashLlamaForCausalLM(torch.nn.Module): def __init__(self, prefix: str, config, weights): super().__init__() + if config.model_type == "mllama_text_model": + prefix = f"{prefix}.model" + with no_fp8(weights): self.embed_tokens = TensorParallelEmbedding( - prefix=( - "model.embed_tokens" if not prefix else f"{prefix}.embed_tokens" - ), + prefix=(f"{prefix}.embed_tokens" if prefix else "model.embed_tokens"), weights=weights, ) self.model = FlashLlamaModel(prefix, config, weights) @@ -648,6 +649,13 @@ def __init__(self, prefix: str, config, weights): if embedding_multiplier is not None: self.embed_tokens.weight.data *= embedding_multiplier + if config.model_type == "mllama_text_model": + prefix = prefix.replace(".model", "") + suffix = f"{prefix}.{suffix}" + + if config.model_type == "granite": + suffix = f"{prefix}.{suffix}" + with no_fp8(weights): self.lm_head = SpeculativeHead.load( config, diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 0548fbc61b3..306da49753a 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -167,10 +167,6 @@ def image_text_replacement_fixup(config, text: str) -> str: return text.replace( f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN ) - if config.model_type == "idefics3": - return text.replace( - f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN - ) return text @@ -290,8 +286,8 @@ def batch_tokenized_inputs( else: image_inputs = None - batch_inputs = [] - max_truncation = 0 + batch_tokenized_inputs = [] + max_length = 0 image_id = 0 for r in requests: full_text = "" @@ -306,16 +302,14 @@ def batch_tokenized_inputs( image_id += 1 full_text = image_text_replacement_fixup(config, full_text) - - batch_inputs.append(full_text) - max_truncation = max(max_truncation, r.truncate) - - batch_tokenized_inputs = tokenizer( - batch_inputs, - truncation=True, - max_length=max_truncation, - add_special_tokens=not config.model_type == "paligemma", - )["input_ids"] + input_ids = tokenizer( + full_text, + truncation=True, + max_length=r.truncate, + add_special_tokens=r.add_special_tokens, + )["input_ids"] + max_length = max(max_length, len(input_ids)) + batch_tokenized_inputs.append(input_ids) return batch_tokenized_inputs, image_inputs From 34174af8c8a3b4f32e82d18174593afe894834d1 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 18 Dec 2024 05:06:49 +0000 Subject: [PATCH 08/17] fix: bump idefics3 tests and snapshots --- .../test_flash_idefics3_next_simple_url.json | 68 +++++++++---------- integration-tests/models/test_idefics3.py | 39 +---------- 2 files changed, 34 insertions(+), 73 deletions(-) diff --git a/integration-tests/models/__snapshots__/test_idefics3/test_flash_idefics3_next_simple_url.json b/integration-tests/models/__snapshots__/test_idefics3/test_flash_idefics3_next_simple_url.json index 052318df2c9..6bf2b93a2da 100644 --- a/integration-tests/models/__snapshots__/test_idefics3/test_flash_idefics3_next_simple_url.json +++ b/integration-tests/models/__snapshots__/test_idefics3/test_flash_idefics3_next_simple_url.json @@ -1,73 +1,67 @@ { "details": { "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, + "finish_reason": "eos_token", + "generated_tokens": 9, "prefill": [], "seed": null, "tokens": [ { - "id": 578, - "logprob": -0.2475586, + "id": 2684, + "logprob": -0.24902344, "special": false, - "text": " The" + "text": " There" }, { - "id": 2217, - "logprob": -0.017303467, - "special": false, - "text": " image" - }, - { - "id": 62991, - "logprob": -0.7368164, + "id": 374, + "logprob": -0.0703125, "special": false, - "text": " depicts" + "text": " is" }, { - "id": 279, - "logprob": -0.39990234, + "id": 264, + "logprob": -0.23535156, "special": false, - "text": " the" + "text": " a" }, { - "id": 89675, - "logprob": -0.34350586, + "id": 35372, + "logprob": -0.125, "special": false, - "text": " Statue" + "text": " statue" }, { - "id": 315, - "logprob": -0.0002901554, + "id": 304, + "logprob": -0.30273438, "special": false, - "text": " of" + "text": " in" }, { - "id": 32492, - "logprob": -0.0009598732, + "id": 279, + "logprob": -0.20507812, "special": false, - "text": " Liberty" + "text": " the" }, { - "id": 11, - "logprob": -0.2355957, + "id": 2217, + "logprob": -0.076171875, "special": false, - "text": "," + "text": " image" }, { - "id": 264, - "logprob": -0.66503906, + "id": 13, + "logprob": -0.053710938, "special": false, - "text": " a" + "text": "." }, { - "id": 97937, - "logprob": -0.9199219, - "special": false, - "text": " colossal" + "id": 128258, + "logprob": -0.011352539, + "special": true, + "text": "" } ], "top_tokens": null }, - "generated_text": " The image depicts the Statue of Liberty, a colossal" + "generated_text": " There is a statue in the image." } diff --git a/integration-tests/models/test_idefics3.py b/integration-tests/models/test_idefics3.py index 1f55872a066..80be2350fad 100644 --- a/integration-tests/models/test_idefics3.py +++ b/integration-tests/models/test_idefics3.py @@ -1,21 +1,9 @@ import pytest -import base64 - - -def get_chicken(): - with open("integration-tests/images/chicken_on_money.png", "rb") as image_file: - encoded_string = base64.b64encode(image_file.read()) - return f"data:image/png;base64,{encoded_string.decode('utf-8')}" @pytest.fixture(scope="module") def flash_idefics3_next_handle(launcher): - with launcher( - "HuggingFaceM4/Idefics3-8B-Llama3", - max_total_tokens=3000, - max_batch_prefill_tokens=2501, - max_input_tokens=2500, - ) as handle: + with launcher("HuggingFaceM4/Idefics3-8B-Llama3") as handle: yield handle @@ -25,26 +13,6 @@ async def flash_idefics3_next(flash_idefics3_next_handle): return flash_idefics3_next_handle.client -# TODO: dont skip when token issue is resolved -@pytest.mark.skip -@pytest.mark.asyncio -@pytest.mark.private -async def test_flash_idefics3_next_simple_base64( - flash_idefics3_next, response_snapshot -): - chicken = get_chicken() - query = "Write me a short story" - response = await flash_idefics3_next.generate( - f"<|begin_of_text|><|begin_of_text|>User:![]({chicken}){query}\nAssistant:", - max_new_tokens=10, - ) - assert ( - response.generated_text == " A chicken is sitting on a pile of money." - ), f"{repr(response.generated_text)}" - # assert response.details.generated_tokens == 10 - # assert response == response_snapshot - - @pytest.mark.asyncio @pytest.mark.private async def test_flash_idefics3_next_simple_url(flash_idefics3_next, response_snapshot): @@ -57,8 +25,7 @@ async def test_flash_idefics3_next_simple_url(flash_idefics3_next, response_snap ) print(response) assert ( - response.generated_text - == " The image depicts the Statue of Liberty, a colossal" + response.generated_text == " There is a statue in the image." ), f"{repr(response.generated_text)}" - assert response.details.generated_tokens == 10 + assert response.details.generated_tokens == 9 assert response == response_snapshot From 064e040ee30635f6abb25acb2d7587489a80afde Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 18 Dec 2024 14:58:27 +0000 Subject: [PATCH 09/17] fix: improve text model loading --- .../text_generation_server/models/custom_modeling/idefics2.py | 2 +- server/text_generation_server/models/custom_modeling/vlm.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py index b1967ec3d51..6c1d5823185 100644 --- a/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/server/text_generation_server/models/custom_modeling/idefics2.py @@ -723,7 +723,7 @@ def __init__(self, prefix, config, weights): vision_config = config.vision_config self.text_model = load_text_model( - prefix="model" if not prefix else f"{prefix}.model", + prefix=f"{prefix}.model.text_model" if prefix else "model.text_model", config=config.text_config, weights=weights, name="text_model", diff --git a/server/text_generation_server/models/custom_modeling/vlm.py b/server/text_generation_server/models/custom_modeling/vlm.py index 04edd0a4064..82e409a673c 100644 --- a/server/text_generation_server/models/custom_modeling/vlm.py +++ b/server/text_generation_server/models/custom_modeling/vlm.py @@ -4,7 +4,7 @@ def load_text_model(prefix, config, weights, name=None): FlashLlamaForCausalLM, ) - return FlashLlamaForCausalLM(f"{prefix}.text_model", config, weights) + return FlashLlamaForCausalLM(prefix, config, weights) elif config.model_type == "mistral": from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( FlashMistralForCausalLM, From 0d1bf9e983100206ef8d9e332532ee54564f8483 Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 19 Dec 2024 01:54:10 +0000 Subject: [PATCH 10/17] feat: consolidate changes with existing vlms and add support and test for smolvlm --- .../test_flash_smolvlm_next_simple_url.json | 61 +++++++++++++++++++ integration-tests/models/test_smolvlm.py | 31 ++++++++++ .../models/custom_modeling/idefics2.py | 2 +- .../models/vlm_causal_lm.py | 7 ++- 4 files changed, 99 insertions(+), 2 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_smolvlm/test_flash_smolvlm_next_simple_url.json create mode 100644 integration-tests/models/test_smolvlm.py diff --git a/integration-tests/models/__snapshots__/test_smolvlm/test_flash_smolvlm_next_simple_url.json b/integration-tests/models/__snapshots__/test_smolvlm/test_flash_smolvlm_next_simple_url.json new file mode 100644 index 00000000000..17a69d0d409 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_smolvlm/test_flash_smolvlm_next_simple_url.json @@ -0,0 +1,61 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "eos_token", + "generated_tokens": 8, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 330, + "logprob": -0.118652344, + "special": false, + "text": " A" + }, + { + "id": 11426, + "logprob": -0.28320312, + "special": false, + "text": " bee" + }, + { + "id": 335, + "logprob": -0.95703125, + "special": false, + "text": " on" + }, + { + "id": 253, + "logprob": -0.06982422, + "special": false, + "text": " a" + }, + { + "id": 11986, + "logprob": -0.49414062, + "special": false, + "text": " pink" + }, + { + "id": 8525, + "logprob": -0.07763672, + "special": false, + "text": " flower" + }, + { + "id": 30, + "logprob": -1.0703125, + "special": false, + "text": "." + }, + { + "id": 49154, + "logprob": -0.092285156, + "special": true, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": " A bee on a pink flower." +} diff --git a/integration-tests/models/test_smolvlm.py b/integration-tests/models/test_smolvlm.py new file mode 100644 index 00000000000..cd105d84cb5 --- /dev/null +++ b/integration-tests/models/test_smolvlm.py @@ -0,0 +1,31 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_smolvlm_next_handle(launcher): + with launcher("HuggingFaceTB/SmolVLM-Instruct") as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_smolvlm_next(flash_smolvlm_next_handle): + await flash_smolvlm_next_handle.health(300) + return flash_smolvlm_next_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_smolvlm_next_simple_url(flash_smolvlm_next, response_snapshot): + ny_skyline = "https://huggingface.co/spaces/merve/chameleon-7b/resolve/main/bee.jpg" + query = "What is in this image?" + response = await flash_smolvlm_next.generate( + f"<|begin_of_text|><|begin_of_text|>User:![]({ny_skyline}){query}\nAssistant:", + max_new_tokens=10, + seed=1337, + ) + print(response) + assert ( + response.generated_text == " A bee on a pink flower." + ), f"{repr(response.generated_text)}" + assert response.details.generated_tokens == 8 + assert response == response_snapshot diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py index 6c1d5823185..2e49900102d 100644 --- a/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/server/text_generation_server/models/custom_modeling/idefics2.py @@ -916,7 +916,7 @@ def __init__(self, prefix, config, weights): ) config.quantize = None - self.connector = Idefics3Connector( + self.connector = Idefics2Connector( prefix=f"{prefix}.model.connector" if prefix else "model.connector", config=config, weights=weights, diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 306da49753a..c1908d8ec49 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -280,8 +280,13 @@ def batch_tokenized_inputs( raise RuntimeError(f"Invalid chunk type {chunk_type}") if images: + kwargs = {} + match processor.image_processor_class: + case "Idefics3ImageProcessor": + kwargs["return_row_col_info"] = True + image_inputs = processor.image_processor( - images, return_tensors="pt", return_row_col_info=True + images, return_tensors="pt", **kwargs ) else: image_inputs = None From 575d97339c27c93350463bca2cd771f1251044b0 Mon Sep 17 00:00:00 2001 From: drbh Date: Sat, 21 Dec 2024 00:27:29 +0000 Subject: [PATCH 11/17] fix: create new idefic3 file, simplify logic and adjust llama weight loading --- .../text_generation_server/models/__init__.py | 2 + .../custom_modeling/flash_llama_modeling.py | 53 +- .../models/custom_modeling/idefics2.py | 209 ---- .../models/custom_modeling/idefics3.py | 1040 +++++++++++++++++ .../models/vlm_causal_lm.py | 85 +- 5 files changed, 1093 insertions(+), 296 deletions(-) create mode 100644 server/text_generation_server/models/custom_modeling/idefics3.py diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index a96cb37f95c..beefeb01672 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -151,6 +151,8 @@ ) from text_generation_server.models.custom_modeling.idefics2 import ( Idefics2ForConditionalGeneration, + ) + from text_generation_server.models.custom_modeling.idefics3 import ( Idefics3ForConditionalGeneration, ) from text_generation_server.models.custom_modeling.qwen2_vl import ( diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index d2c4f7515e0..43c5dfb42a0 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -507,6 +507,7 @@ def __init__(self, prefix, config, weights): process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() + base_model = "" if prefix.endswith("text_model") else ".model" # Skip fp8 quant for first and last layers self.layers = nn.ModuleList() @@ -515,7 +516,11 @@ def __init__(self, prefix, config, weights): self.layers.append( FlashLlamaLayer( index=0, - prefix=f"{prefix}.layers.0" if prefix else "model.layers.0", + prefix=( + "model.layers.0" + if not prefix + else f"{prefix}{base_model}.layers.0" + ), config=config, weights=weights, ) @@ -532,9 +537,9 @@ def __init__(self, prefix, config, weights): FlashLlamaCrossLayer( index=layer_id, prefix=( - f"{prefix}.layers.{layer_id}" - if prefix - else f"model.layers.{layer_id}" + f"model.layers.{layer_id}" + if not prefix + else f"{prefix}{base_model}.layers.{layer_id}" ), config=config, weights=weights, @@ -545,9 +550,9 @@ def __init__(self, prefix, config, weights): FlashLlamaLayer( index=layer_id, prefix=( - f"{prefix}.layers.{layer_id}" - if prefix - else f"model.layers.{layer_id}" + f"model.layers.{layer_id}" + if not prefix + else f"{prefix}{base_model}.layers.{layer_id}" ), config=config, weights=weights, @@ -560,9 +565,9 @@ def __init__(self, prefix, config, weights): FlashLlamaLayer( index=last_layer_id, prefix=( - f"{prefix}.layers.{last_layer_id}" - if prefix - else f"model.layers.{last_layer_id}" + f"model.layers.{last_layer_id}" + if not prefix + else f"{prefix}{base_model}.layers.{last_layer_id}" ), config=config, weights=weights, @@ -570,7 +575,7 @@ def __init__(self, prefix, config, weights): ) self.norm = FastRMSNorm.load( - prefix=f"{prefix}.norm" if prefix else "model.norm", + prefix="model.norm" if not prefix else f"{prefix}{base_model}.norm", weights=weights, eps=config.rms_norm_eps, ) @@ -629,18 +634,20 @@ def forward( class FlashLlamaForCausalLM(torch.nn.Module): def __init__(self, prefix: str, config, weights): super().__init__() - - if config.model_type == "mllama_text_model": - prefix = f"{prefix}.model" + base_model = "" if prefix.endswith("text_model") else ".model" with no_fp8(weights): self.embed_tokens = TensorParallelEmbedding( - prefix=(f"{prefix}.embed_tokens" if prefix else "model.embed_tokens"), + prefix=( + "model.embed_tokens" + if not prefix + else f"{prefix}{base_model}.embed_tokens" + ), weights=weights, ) self.model = FlashLlamaModel(prefix, config, weights) if config.tie_word_embeddings: - suffix = "model.embed_tokens" + suffix = f"model.embed_tokens" else: suffix = "lm_head" @@ -649,17 +656,17 @@ def __init__(self, prefix: str, config, weights): if embedding_multiplier is not None: self.embed_tokens.weight.data *= embedding_multiplier - if config.model_type == "mllama_text_model": - prefix = prefix.replace(".model", "") - suffix = f"{prefix}.{suffix}" - - if config.model_type == "granite": - suffix = f"{prefix}.{suffix}" + if not prefix: + head_prefix = suffix + elif prefix.endswith("text_model"): + head_prefix = suffix + else: + head_prefix = f"{prefix}.{suffix}" with no_fp8(weights): self.lm_head = SpeculativeHead.load( config, - prefix=suffix, + prefix=head_prefix, weights=weights, ) diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py index 2e49900102d..923123d61b6 100644 --- a/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/server/text_generation_server/models/custom_modeling/idefics2.py @@ -679,215 +679,6 @@ def forward(self, image_hidden_states, attention_mask): return image_hidden_states -class Idefics3Connector(nn.Module): - def __init__(self, prefix, config, weights): - super().__init__() - self.modality_projection = TensorParallelRowLinear.load( - prefix=f"{prefix}.modality_projection.proj", - config=config, - weights=weights, - bias=False, - ) - self.scale_factor = config.scale_factor - - def pixel_shuffle(self, x, scale_factor=2): - bsz, seq, embed_dim = x.size() - height = width = int(seq**0.5) - x = x.view(bsz, height, width, embed_dim) - x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor) - x = x.permute(0, 2, 1, 3) - x = x.reshape( - bsz, - int(width / scale_factor), - int(height / scale_factor), - embed_dim * (scale_factor**2), - ) - x = x.permute(0, 2, 1, 3) - x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2)) - return x - - def forward(self, image_hidden_states, attention_mask): - print(image_hidden_states.device, self.modality_projection.linear.weight.device) - image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor) - image_hidden_states = self.modality_projection(image_hidden_states) - return image_hidden_states - - -class Idefics3ForConditionalGeneration(nn.Module): - def __init__(self, prefix, config, weights): - super().__init__() - config.vision_config.quantize = None - config.vision_config.speculator = config.speculator - config.text_config.quantize = config.quantize - config.text_config.speculator = config.speculator - - vision_config = config.vision_config - self.text_model = load_text_model( - prefix=f"{prefix}.model.text_model" if prefix else "model.text_model", - config=config.text_config, - weights=weights, - name="text_model", - ) - self.dtype = weights.dtype - - # The vision and connector models are not quantized. - with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)): - self.vision_model = Idefics2VisionTransformer( - prefix=( - f"{prefix}.model.vision_model" if prefix else "model.vision_model" - ), - config=vision_config, - weights=weights, - ) - - config.quantize = None - self.connector = Idefics3Connector( - prefix=f"{prefix}.model.connector" if prefix else "model.connector", - config=config, - weights=weights, - ) - - self.config = config - self.image_token_id = config.image_token_id - self.pad_token_id = ( - config.pad_token_id if config.pad_token_id is not None else -1 - ) - - def _merge_input_ids_with_image_features( - self, - input_ids: torch.Tensor, - inputs_embeds: torch.Tensor, - image_features: torch.Tensor, - ): - """In place merges in vision_embeddings with inputs_embeds.""" - # mask = input_ids == self.config.image_token_index - mask = input_ids == self.config.image_token_id - # Let's pray we have enabled enough slots ! - inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) - return inputs_embeds - - def forward( - self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - cu_seqlen_prefill: Optional[torch.Tensor], - kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, - slots: torch.Tensor, - seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], - lm_head_indices: Optional[torch.Tensor] = None, - pixel_values: torch.FloatTensor = None, - pixel_attention_mask: Optional[torch.BoolTensor] = None, - # Unused here - image_sizes: Optional[torch.Tensor] = None, - adapter_data: Optional[torch.Tensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - cross_attention_states: Optional[torch.Tensor] = None, - image_indices=None, - ): - inputs_embeds = self.text_model.embed_tokens(input_ids) - if pixel_values is not None: - batch_size, num_images, num_channels, height, width = pixel_values.shape - all_states = [] - all_pixel_values = pixel_values - all_pixel_mask = pixel_attention_mask - for i in range(batch_size): - pixel_values = all_pixel_values.to( - dtype=self.dtype - ) # fp16 compatibility - pixel_values = pixel_values[i : i + 1] - pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:]) - - # Remove padding images - padding images are full 0. - nb_values_per_image = pixel_values.shape[1:].numel() - real_images_inds = (pixel_values == 0.0).sum( - dim=(-1, -2, -3) - ) != nb_values_per_image - pixel_values = pixel_values[real_images_inds].contiguous() - - # Handle the vision attention mask - if pixel_attention_mask is None: - pixel_attention_mask = torch.ones( - size=( - pixel_values.size(0), - pixel_values.size(2), - pixel_values.size(3), - ), - dtype=torch.bool, - device=pixel_values.device, - ) - else: - # Remove padding images from the mask/pP p - pixel_attention_mask = all_pixel_mask[i : i + 1] - pixel_attention_mask = pixel_attention_mask.view( - 1 * num_images, *pixel_attention_mask.shape[2:] - ) - pixel_attention_mask = pixel_attention_mask[ - real_images_inds - ].contiguous() - - patch_size = self.config.vision_config.patch_size - patches_subgrid = pixel_attention_mask.unfold( - dimension=1, size=patch_size, step=patch_size - ) - patches_subgrid = patches_subgrid.unfold( - dimension=2, size=patch_size, step=patch_size - ) - patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() - - # Get sequence from the vision encoder - image_hidden_states = self.vision_model( - pixel_values=pixel_values, - patch_attention_mask=patch_attention_mask, - ) - - # Modality projection & resampling - image_hidden_states = self.connector( - image_hidden_states, - attention_mask=patch_attention_mask.view(pixel_values.size(0), -1), - ) - - all_states.append(image_hidden_states) - image_hidden_states = torch.stack(all_states, dim=0) - # TODO: remove when prefill image tokens are handled correctly - # * for now dummy tokens are added instead of the image tokens output byt the vision model - mask_size = (input_ids == self.config.image_token_id).sum().item() - unrolled_image_size = ( - image_hidden_states.shape[1] * image_hidden_states.shape[2] - ) - diff = mask_size - unrolled_image_size - if diff > 0: - print( - f"Mask size {mask_size} is greater than the number of images {unrolled_image_size}." - ) - - if mask_size == unrolled_image_size: - inputs_embeds = self._merge_input_ids_with_image_features( - input_ids, inputs_embeds, image_hidden_states - ) - - hidden_states = self.text_model.model( - inputs_embeds=inputs_embeds, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=kv_cache, - block_tables=block_tables, - slots=slots, - seqlen=seqlen, - max_s=max_s, - true_max_s=max_s, - prefill_cache_indices=None, - adapter_data=adapter_data, - ) - if lm_head_indices is not None: - hidden_states = hidden_states[lm_head_indices] - logits, speculative_logits = self.text_model.lm_head(hidden_states) - return logits, speculative_logits - - class Idefics2ForConditionalGeneration(nn.Module): def __init__(self, prefix, config, weights): super().__init__() diff --git a/server/text_generation_server/models/custom_modeling/idefics3.py b/server/text_generation_server/models/custom_modeling/idefics3.py new file mode 100644 index 00000000000..2c467877b4c --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/idefics3.py @@ -0,0 +1,1040 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Idefics2 model.""" + +from typing import List, Optional, Tuple + +import torch +import torch.utils.checkpoint +from torch import nn +import math + +from transformers.activations import ACT2FN +from text_generation_server.models.custom_modeling.vlm import ( + load_text_model, +) +from text_generation_server.layers.attention import Seqlen +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask + +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, +) +from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Idefics2VisionEmbeddings(nn.Module): + """ + This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable + resolution. + + The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304) + which allows treating images in their native aspect ratio and without the need to resize them to the same + fixed size. In particular, we start from the original pre-trained SigLIP model + (which uses images of fixed-size square images) and adapt it by training on images of variable resolutions. + """ + + def __init__(self, prefix, config, weights): + super().__init__() + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + self.patch_embedding.weight = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False + ) + self.patch_embedding.bias = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False + ) + + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side**2 + self.num_positions = self.num_patches + self.position_embedding = TensorParallelEmbedding( + prefix=f"{prefix}.position_embedding", weights=weights + ) + + def forward( + self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor + ) -> torch.Tensor: + batch_size, _, max_im_h, max_im_w = pixel_values.shape + + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + max_nb_patches_h, max_nb_patches_w = ( + max_im_h // self.patch_size, + max_im_w // self.patch_size, + ) + boundaries = torch.arange( + 1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side + ) + position_ids = torch.full( + size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0 + ) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + + bucket_coords_h = torch.bucketize( + fractional_coords_h, boundaries, right=True + ) + bucket_coords_w = torch.bucketize( + fractional_coords_w, boundaries, right=True + ) + + pos_ids = ( + bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w + ).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + + position_ids = position_ids.to(self.position_embedding.weight.device) + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +class Idefics2VisionAttention(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_size = self.embed_dim // self.num_heads + if self.head_size * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_size**-0.5 + self.dropout = config.attention_dropout + + self.num_heads = self.num_heads // weights.process_group.size() + self.embed_dim = self.embed_dim // weights.process_group.size() + + self.qkv = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=True, + ) + self.out_proj = TensorParallelRowLinear.load( + config=config, prefix=f"{prefix}.out_proj", weights=weights, bias=True + ) + self.is_causal = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, q_len, _ = hidden_states.size() + + qkv = self.qkv(hidden_states) + query_states, key_states, value_states = qkv.split( + [ + self.head_size * self.num_heads, + self.head_size * self.num_heads, + self.head_size * self.num_heads, + ], + dim=2, + ) + + query_states = query_states.view( + batch_size, q_len, self.num_heads, self.head_size + ).transpose(1, 2) + key_states = key_states.view( + batch_size, q_len, self.num_heads, self.head_size + ).transpose(1, 2) + value_states = value_states.view( + batch_size, q_len, self.num_heads, self.head_size + ).transpose(1, 2) + + k_v_seq_len = key_states.shape[-2] + attn_weights = ( + torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + ) + + if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): + raise ValueError( + f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_size): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_size)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output + + +class Idefics2VisionMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True + ) + self.fc2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Idefics2EncoderLayer(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = Idefics2VisionAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.layer_norm1 = nn.LayerNorm.load( + prefix=f"{prefix}.layer_norm1", eps=config.layer_norm_eps, weights=weights + ) + self.layer_norm2 = nn.LayerNorm.load( + prefix=f"{prefix}.layer_norm2", eps=config.layer_norm_eps, weights=weights + ) + self.mlp = Idefics2VisionMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights + ) + + # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + ) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class Idefics2Encoder(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [ + Idefics2EncoderLayer( + prefix=f"{prefix}.layers.{i}", config=config, weights=weights + ) + for i in range(config.num_hidden_layers) + ] + ) + + # Ignore copy + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + ): + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer( + hidden_states, + attention_mask, + ) + return hidden_states + + +class Idefics2VisionTransformer(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.embeddings = Idefics2VisionEmbeddings( + prefix=f"{prefix}.embeddings", config=config, weights=weights + ) + self.encoder = Idefics2Encoder( + prefix=f"{prefix}.encoder", config=config, weights=weights + ) + self.post_layernorm = nn.LayerNorm.load( + prefix=f"{prefix}.post_layernorm", + weights=weights, + eps=config.layer_norm_eps, + ) + + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + ): + batch_size = pixel_values.size(0) + if patch_attention_mask is None: + patch_size = self.config.patch_size + patch_attention_mask = torch.ones( + ( + batch_size, + pixel_values.size(2) // patch_size, + pixel_values.size(3) // patch_size, + ) + ) + patch_attention_mask = patch_attention_mask.to( + dtype=torch.bool, device=pixel_values.device + ) + + hidden_states = self.embeddings( + pixel_values=pixel_values, patch_attention_mask=patch_attention_mask + ) + + patch_attention_mask = patch_attention_mask.view(batch_size, -1) + # The call to `_upad_input` in `_flash_attention_forward` is expensive + # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence), + # avoiding passing the attention_mask, which is equivalent to attending to the full sequence + if not torch.any(~patch_attention_mask): + patch_attention_mask = None + else: + patch_attention_mask = _prepare_4d_attention_mask( + patch_attention_mask, hidden_states.dtype + ) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=patch_attention_mask, + ) + + last_hidden_state = encoder_outputs + last_hidden_state = self.post_layernorm(last_hidden_state) + + return last_hidden_state + + +class Idefics2MLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + act = config.text_config.hidden_act + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" + ), + ) + ) + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + + def forward(self, hidden_states): + start_shape = hidden_states.shape[:-1] + gate_up_states = self.gate_up_proj(hidden_states) + intermediate_size = gate_up_states.shape[-1] // 2 + gate_up_states = gate_up_states.view(-1, 2, intermediate_size) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1] + ).view(*start_shape, -1) + + +class Idefics2RMSNorm(nn.Module): + def __init__(self, prefix, weights, eps): + """ + Idefics2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter( + weights.get_tensor(f"{prefix}.weight"), requires_grad=False + ) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class Idefics2PerceiverAttention(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + + self.layer_idx = None + self.hidden_size = config.text_config.hidden_size + self.num_heads = config.perceiver_config.resampler_n_heads + self.head_size = config.perceiver_config.resampler_head_dim + self.num_key_value_heads = config.perceiver_config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.attention_dropout = config.perceiver_config.attention_dropout + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + self.num_key_value_heads // weights.process_group.size() + ) + + self.q_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.q_proj", + weights=weights, + bias=False, + ) + self.kv = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) + self.o_proj = TensorParallelRowLinear.load( + config=config, prefix=f"{prefix}.o_proj", weights=weights, bias=False + ) + + self.is_causal = False + + def forward( + self, + latents: torch.Tensor, + context: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = latents.size() + kv_seq_len = q_len + context.size()[1] + + hidden_states = torch.concat([context, latents], dim=-2) + query_states = self.q_proj(latents) + kv = self.kv(hidden_states) + key_states, value_states = kv.split( + [ + self.head_size * self.num_key_value_heads, + self.head_size * self.num_key_value_heads, + ], + dim=2, + ) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_size + ).transpose(1, 2) + key_states = key_states.view( + bsz, kv_seq_len, self.num_key_value_heads, self.head_size + ).transpose(1, 2) + value_states = value_states.view( + bsz, kv_seq_len, self.num_key_value_heads, self.head_size + ).transpose(1, 2) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(self.head_size) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_size): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_size)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_size) + + attn_output = self.o_proj(attn_output) + + return attn_output + + +class Idefics2PerceiverLayer(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.hidden_size = config.text_config.hidden_size + self.n_latents = config.perceiver_config.resampler_n_latents + self.depth = config.perceiver_config.resampler_depth + self.rms_norm_eps = config.text_config.rms_norm_eps + + self.input_latents_norm = Idefics2RMSNorm( + prefix=f"{prefix}.input_latents_norm", + weights=weights, + eps=self.rms_norm_eps, + ) + self.input_context_norm = Idefics2RMSNorm( + prefix=f"{prefix}.input_context_norm", + weights=weights, + eps=self.rms_norm_eps, + ) + self.self_attn = Idefics2PerceiverAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.post_attention_layernorm = Idefics2RMSNorm( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=self.rms_norm_eps, + ) + self.mlp = Idefics2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + + def forward( + self, + latents: torch.Tensor, + context: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ): + """ + Args: + latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + """ + residual = latents + + latents = self.input_latents_norm(latents) + context = self.input_context_norm(context) + + latents = self.self_attn( + latents=latents, + context=context, + attention_mask=attention_mask, + ) + latents = residual + latents + residual = latents + + latents = self.post_attention_layernorm(latents) + latents = self.mlp(latents) + latents = residual + latents + + return latents + + +class Idefics2PerceiverResampler(nn.Module): + def __init__(self, prefix, config, weights) -> None: + super().__init__() + self.hidden_size = config.text_config.hidden_size + self.hidden_act = config.perceiver_config.hidden_act + self.n_latents = config.perceiver_config.resampler_n_latents + self.depth = config.perceiver_config.resampler_depth + self.rms_norm_eps = config.text_config.rms_norm_eps + + # Create Latents for Perceiver + self.latents = weights.get_tensor(f"{prefix}.latents") + + # Create Transformer Blocks + self.layers = nn.ModuleList( + [ + Idefics2PerceiverLayer( + prefix=f"{prefix}.layers.{idx}", config=config, weights=weights + ) + for idx in range(self.depth) + ] + ) + self.norm = Idefics2RMSNorm( + prefix=f"{prefix}.norm", + weights=weights, + eps=config.text_config.rms_norm_eps, + ) + + def forward( + self, + context: torch.Tensor, + attention_mask, + ) -> torch.Tensor: + # seq embed -> bsz seq embed + latents = self.latents.unsqueeze(0).expand( + (context.shape[0], *self.latents.size()) + ) + + latent_attention_mask = torch.ones( + (attention_mask.size(0), latents.size(1)), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + attention_mask = torch.cat([attention_mask, latent_attention_mask], dim=-1) + attention_mask = _prepare_4d_attention_mask( + attention_mask, latents.dtype, tgt_len=self.n_latents + ) + + compressed_context = latents + for perceiver_layer in self.layers: + compressed_context = perceiver_layer( + compressed_context, + context, + attention_mask=attention_mask, + ) + compressed_context = self.norm(compressed_context) + + return compressed_context + + +class Idefics2Connector(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.modality_projection = Idefics2MLP( + prefix=f"{prefix}.modality_projection", config=config, weights=weights + ) + self.perceiver_resampler = Idefics2PerceiverResampler( + prefix=f"{prefix}.perceiver_resampler", config=config, weights=weights + ) + + def forward(self, image_hidden_states, attention_mask): + image_hidden_states = self.modality_projection(image_hidden_states) + image_hidden_states = self.perceiver_resampler( + context=image_hidden_states, attention_mask=attention_mask + ) + return image_hidden_states + + +class Idefics3Connector(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.modality_projection = TensorParallelRowLinear.load( + prefix=f"{prefix}.modality_projection.proj", + config=config, + weights=weights, + bias=False, + ) + self.scale_factor = config.scale_factor + + def pixel_shuffle(self, x, scale_factor=2): + bsz, seq, embed_dim = x.size() + height = width = int(seq**0.5) + x = x.view(bsz, height, width, embed_dim) + x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor) + x = x.permute(0, 2, 1, 3) + x = x.reshape( + bsz, + int(width / scale_factor), + int(height / scale_factor), + embed_dim * (scale_factor**2), + ) + x = x.permute(0, 2, 1, 3) + x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2)) + return x + + def forward(self, image_hidden_states, attention_mask): + print(image_hidden_states.device, self.modality_projection.linear.weight.device) + image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor) + image_hidden_states = self.modality_projection(image_hidden_states) + return image_hidden_states + + +class Idefics3ForConditionalGeneration(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + config.vision_config.quantize = None + config.vision_config.speculator = config.speculator + config.text_config.quantize = config.quantize + config.text_config.speculator = config.speculator + # set tie_word_embeddings to True to load `.embed_tokens.weight` instead of `.lm_head.weight` + # since Idefics3 uses the `embed_tokens` for the final prediction + # config.text_config.tie_word_embeddings = True + + vision_config = config.vision_config + self.text_model = load_text_model( + prefix=f"{prefix}.model.text_model" if prefix else "model.text_model", + config=config.text_config, + weights=weights, + name="text_model", + ) + self.dtype = weights.dtype + + # The vision and connector models are not quantized. + with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)): + self.vision_model = Idefics2VisionTransformer( + prefix=( + f"{prefix}.model.vision_model" if prefix else "model.vision_model" + ), + config=vision_config, + weights=weights, + ) + + config.quantize = None + self.connector = Idefics3Connector( + prefix=f"{prefix}.model.connector" if prefix else "model.connector", + config=config, + weights=weights, + ) + + self.config = config + self.image_token_id = config.image_token_id + self.pad_token_id = ( + config.pad_token_id if config.pad_token_id is not None else -1 + ) + + def _merge_input_ids_with_image_features( + self, + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + image_features: torch.Tensor, + ): + """In place merges in vision_embeddings with inputs_embeds.""" + # mask = input_ids == self.config.image_token_index + mask = input_ids == self.config.image_token_id + # Let's pray we have enabled enough slots ! + inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + pixel_values: torch.FloatTensor = None, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + # Unused here + image_sizes: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + image_indices=None, + ): + inputs_embeds = self.text_model.embed_tokens(input_ids) + if pixel_values is not None: + batch_size, num_images, num_channels, height, width = pixel_values.shape + all_states = [] + all_pixel_values = pixel_values + all_pixel_mask = pixel_attention_mask + for i in range(batch_size): + pixel_values = all_pixel_values.to( + dtype=self.dtype + ) # fp16 compatibility + pixel_values = pixel_values[i : i + 1] + pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:]) + + # Remove padding images - padding images are full 0. + nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum( + dim=(-1, -2, -3) + ) != nb_values_per_image + pixel_values = pixel_values[real_images_inds].contiguous() + + # Handle the vision attention mask + if pixel_attention_mask is None: + pixel_attention_mask = torch.ones( + size=( + pixel_values.size(0), + pixel_values.size(2), + pixel_values.size(3), + ), + dtype=torch.bool, + device=pixel_values.device, + ) + else: + # Remove padding images from the mask/pP p + pixel_attention_mask = all_pixel_mask[i : i + 1] + pixel_attention_mask = pixel_attention_mask.view( + 1 * num_images, *pixel_attention_mask.shape[2:] + ) + pixel_attention_mask = pixel_attention_mask[ + real_images_inds + ].contiguous() + + patch_size = self.config.vision_config.patch_size + patches_subgrid = pixel_attention_mask.unfold( + dimension=1, size=patch_size, step=patch_size + ) + patches_subgrid = patches_subgrid.unfold( + dimension=2, size=patch_size, step=patch_size + ) + patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ) + + # Modality projection & resampling + image_hidden_states = self.connector( + image_hidden_states, + attention_mask=patch_attention_mask.view(pixel_values.size(0), -1), + ) + + all_states.append(image_hidden_states) + image_hidden_states = torch.stack(all_states, dim=0) + + inputs_embeds = self._merge_input_ids_with_image_features( + input_ids, inputs_embeds, image_hidden_states + ) + + hidden_states = self.text_model.model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + true_max_s=max_s, + prefill_cache_indices=None, + adapter_data=adapter_data, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.text_model.lm_head(hidden_states) + return logits, speculative_logits + + +class Idefics2ForConditionalGeneration(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + config.vision_config.quantize = None + config.vision_config.speculator = config.speculator + config.text_config.quantize = config.quantize + config.text_config.speculator = config.speculator + + vision_config = config.vision_config + self.text_model = load_text_model( + prefix="model" if not prefix else f"{prefix}.model", + config=config.text_config, + weights=weights, + name="text_model", + ) + self.dtype = weights.dtype + + # The vision and connector models are not quantized. + with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)): + self.vision_model = Idefics2VisionTransformer( + prefix=( + f"{prefix}.model.vision_model" if prefix else "model.vision_model" + ), + config=vision_config, + weights=weights, + ) + + config.quantize = None + self.connector = Idefics2Connector( + prefix=f"{prefix}.model.connector" if prefix else "model.connector", + config=config, + weights=weights, + ) + + self.config = config + self.image_seq_len = config.perceiver_config.resampler_n_latents + self.image_token_id = config.image_token_id + self.pad_token_id = ( + config.pad_token_id if config.pad_token_id is not None else -1 + ) + + def _merge_input_ids_with_image_features( + self, + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + image_features: torch.Tensor, + ): + """In place merges in vision_embeddings with inputs_embeds.""" + # mask = input_ids == self.config.image_token_index + mask = input_ids == self.config.image_token_id + # Let's pray we have enabled enough slots ! + inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + pixel_values: torch.FloatTensor = None, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + # Unused here + image_sizes: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + ): + inputs_embeds = self.text_model.embed_tokens(input_ids) + if pixel_values is not None: + batch_size, num_images, num_channels, height, width = pixel_values.shape + all_states = [] + all_pixel_values = pixel_values + all_pixel_mask = pixel_attention_mask + for i in range(batch_size): + pixel_values = all_pixel_values.to( + dtype=self.dtype + ) # fp16 compatibility + pixel_values = pixel_values[i : i + 1] + pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:]) + + # Remove padding images - padding images are full 0. + nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum( + dim=(-1, -2, -3) + ) != nb_values_per_image + pixel_values = pixel_values[real_images_inds].contiguous() + + # Handle the vision attention mask + if pixel_attention_mask is None: + pixel_attention_mask = torch.ones( + size=( + pixel_values.size(0), + pixel_values.size(2), + pixel_values.size(3), + ), + dtype=torch.bool, + device=pixel_values.device, + ) + else: + # Remove padding images from the mask/pP p + pixel_attention_mask = all_pixel_mask[i : i + 1] + pixel_attention_mask = pixel_attention_mask.view( + 1 * num_images, *pixel_attention_mask.shape[2:] + ) + pixel_attention_mask = pixel_attention_mask[ + real_images_inds + ].contiguous() + + patch_size = self.config.vision_config.patch_size + patches_subgrid = pixel_attention_mask.unfold( + dimension=1, size=patch_size, step=patch_size + ) + patches_subgrid = patches_subgrid.unfold( + dimension=2, size=patch_size, step=patch_size + ) + patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ) + + # Modality projection & resampling + image_hidden_states = self.connector( + image_hidden_states, + attention_mask=patch_attention_mask.view(pixel_values.size(0), -1), + ) + all_states.append(image_hidden_states) + image_hidden_states = torch.stack(all_states, dim=0) + # When we generate, we don't want to replace the potential image_token_id that we generated by images + # that simply don't exist + inputs_embeds = self._merge_input_ids_with_image_features( + input_ids, inputs_embeds, image_hidden_states + ) + + hidden_states = self.text_model.model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + true_max_s=max_s, + prefill_cache_indices=None, + adapter_data=adapter_data, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.text_model.lm_head(hidden_states) + return logits, speculative_logits diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index c1908d8ec49..082f4b81840 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -28,69 +28,26 @@ IDEFICS3_GLOBAL_IMG_TOKEN = "" -def _prompt_split_image( - image_seq_len, - image_rows, - image_cols, - fake_token_around_image, - image_token, - global_img_token, +def get_image_prompt_string( + rows=0, + cols=0, + seq_len=1, + fake_token=IDEFICS3_FAKE_IMAGE_TOKEN, + img_token=IDEFICS3_IMAGE_TOKEN, + global_token=IDEFICS3_GLOBAL_IMG_TOKEN, ): - """Prompt with expanded image tokens for when the image is split into patches.""" - text_split_images = "" - for n_h in range(image_rows): - for n_w in range(image_cols): - text_split_images += ( - f"{fake_token_around_image}" - + f"" - + f"{image_token}" * image_seq_len - ) - text_split_images += "\n" + tokens = img_token * seq_len + end_token = f"{fake_token}{global_token}{tokens}{fake_token}" - text_split_images += ( - f"\n{fake_token_around_image}" - + f"{global_img_token}" - + f"{image_token}" * image_seq_len - + f"{fake_token_around_image}" - ) - return text_split_images + if rows == 0 or cols == 0: + return end_token - -def _prompt_single_image( - image_seq_len, fake_token_around_image, image_token, global_img_token -): - """Prompt with expanded image tokens for a single image.""" - return ( - f"{fake_token_around_image}" - + f"{global_img_token}" - + f"{image_token}" * image_seq_len - + f"{fake_token_around_image}" + grid = "\n".join( + "".join(f"{fake_token}{tokens}" for j in range(cols)) + for i in range(rows) ) - -def get_image_prompt_string( - image_rows, - image_cols, - image_seq_len, - fake_token_around_image, - image_token, - global_img_token, -): - if image_rows == 0 and image_cols == 0: - return _prompt_single_image( - image_seq_len, - fake_token_around_image=fake_token_around_image, - image_token=image_token, - global_img_token=global_img_token, - ) - return _prompt_split_image( - image_seq_len, - image_rows, - image_cols, - fake_token_around_image, - image_token, - global_img_token, - ) + return f"{grid}\n\n{end_token}" def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): @@ -132,12 +89,12 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str / (config.scale_factor**2) ) image_str = get_image_prompt_string( - n_rows, - n_cols, - image_seq_len, - image_token=IDEFICS3_IMAGE_TOKEN, - fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN, - global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN, + rows=n_rows, + cols=n_cols, + seq_len=image_seq_len, + fake_token=IDEFICS3_FAKE_IMAGE_TOKEN, + img_token=IDEFICS3_IMAGE_TOKEN, + global_token=IDEFICS3_GLOBAL_IMG_TOKEN, ) return image_str elif config.model_type == "llava_next": From 4c8f5cdc359fd9167c9ecdb226ba0fd9415ed67a Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 23 Dec 2024 14:39:55 +0000 Subject: [PATCH 12/17] fix: lint with ruff --- .../models/custom_modeling/flash_llama_modeling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 43c5dfb42a0..20bab01b88b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -647,7 +647,7 @@ def __init__(self, prefix: str, config, weights): ) self.model = FlashLlamaModel(prefix, config, weights) if config.tie_word_embeddings: - suffix = f"model.embed_tokens" + suffix = "model.embed_tokens" else: suffix = "lm_head" From 765ca78014f0c601a31bb80f8e17c761a479fbce Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 7 Jan 2025 22:05:47 +0000 Subject: [PATCH 13/17] fix: clean up idefics 3 and improve prefix handling --- .../custom_modeling/flash_llama_modeling.py | 27 +- .../models/custom_modeling/idefics3.py | 509 +----------------- .../models/custom_modeling/vlm.py | 2 +- .../models/vlm_causal_lm.py | 61 ++- 4 files changed, 76 insertions(+), 523 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 20bab01b88b..7525940a341 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -632,20 +632,24 @@ def forward( class FlashLlamaForCausalLM(torch.nn.Module): - def __init__(self, prefix: str, config, weights): + def __init__(self, prefix: str, config, weights, name=None): + if name is None: + name = "model" super().__init__() - base_model = "" if prefix.endswith("text_model") else ".model" - with no_fp8(weights): self.embed_tokens = TensorParallelEmbedding( prefix=( - "model.embed_tokens" + f"{name}.embed_tokens" if not prefix - else f"{prefix}{base_model}.embed_tokens" + else f"{prefix}.{name}.embed_tokens" ), weights=weights, ) - self.model = FlashLlamaModel(prefix, config, weights) + self.model = FlashLlamaModel( + prefix=name if not prefix else f"{prefix}.{name}", + config=config, + weights=weights, + ) if config.tie_word_embeddings: suffix = "model.embed_tokens" else: @@ -656,18 +660,13 @@ def __init__(self, prefix: str, config, weights): if embedding_multiplier is not None: self.embed_tokens.weight.data *= embedding_multiplier - if not prefix: - head_prefix = suffix - elif prefix.endswith("text_model"): - head_prefix = suffix - else: - head_prefix = f"{prefix}.{suffix}" + prefix = "lm_head" if not prefix or name != "model" else f"{prefix}.lm_head" with no_fp8(weights): self.lm_head = SpeculativeHead.load( config, - prefix=head_prefix, - weights=weights, + prefix, + weights, ) # Used in Granite diff --git a/server/text_generation_server/models/custom_modeling/idefics3.py b/server/text_generation_server/models/custom_modeling/idefics3.py index 2c467877b4c..81e03943575 100644 --- a/server/text_generation_server/models/custom_modeling/idefics3.py +++ b/server/text_generation_server/models/custom_modeling/idefics3.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch Idefics2 model.""" +""" PyTorch Idefics3 model.""" from typing import List, Optional, Tuple @@ -50,7 +50,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -class Idefics2VisionEmbeddings(nn.Module): +class Idefics3VisionEmbeddings(nn.Module): """ This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable resolution. @@ -131,7 +131,7 @@ def forward( return embeddings -class Idefics2VisionAttention(nn.Module): +class Idefics3VisionAttention(nn.Module): def __init__(self, prefix, config, weights): super().__init__() self.config = config @@ -229,7 +229,7 @@ def forward( return attn_output -class Idefics2VisionMLP(nn.Module): +class Idefics3VisionMLP(nn.Module): def __init__(self, prefix, config, weights): super().__init__() self.config = config @@ -248,11 +248,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class Idefics2EncoderLayer(nn.Module): +class Idefics3EncoderLayer(nn.Module): def __init__(self, prefix, config, weights): super().__init__() self.embed_dim = config.hidden_size - self.self_attn = Idefics2VisionAttention( + self.self_attn = Idefics3VisionAttention( prefix=f"{prefix}.self_attn", config=config, weights=weights ) self.layer_norm1 = nn.LayerNorm.load( @@ -261,7 +261,7 @@ def __init__(self, prefix, config, weights): self.layer_norm2 = nn.LayerNorm.load( prefix=f"{prefix}.layer_norm2", eps=config.layer_norm_eps, weights=weights ) - self.mlp = Idefics2VisionMLP( + self.mlp = Idefics3VisionMLP( prefix=f"{prefix}.mlp", config=config, weights=weights ) @@ -288,13 +288,13 @@ def forward( return hidden_states -class Idefics2Encoder(nn.Module): +class Idefics3Encoder(nn.Module): def __init__(self, prefix, config, weights): super().__init__() self.config = config self.layers = nn.ModuleList( [ - Idefics2EncoderLayer( + Idefics3EncoderLayer( prefix=f"{prefix}.layers.{i}", config=config, weights=weights ) for i in range(config.num_hidden_layers) @@ -316,14 +316,14 @@ def forward( return hidden_states -class Idefics2VisionTransformer(nn.Module): +class Idefics3VisionTransformer(nn.Module): def __init__(self, prefix, config, weights): super().__init__() self.config = config - self.embeddings = Idefics2VisionEmbeddings( + self.embeddings = Idefics3VisionEmbeddings( prefix=f"{prefix}.embeddings", config=config, weights=weights ) - self.encoder = Idefics2Encoder( + self.encoder = Idefics3Encoder( prefix=f"{prefix}.encoder", config=config, weights=weights ) self.post_layernorm = nn.LayerNorm.load( @@ -377,317 +377,26 @@ def forward( return last_hidden_state -class Idefics2MLP(nn.Module): +class Idefics3SimpleMLP(nn.Module): def __init__(self, prefix, config, weights): super().__init__() - act = config.text_config.hidden_act - self.act = ( - ACT2FN[act] - if "gelu" not in act - else lambda x: torch.nn.functional.gelu( - x, - approximate=( - "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" - ), - ) - ) - self.gate_up_proj = TensorParallelColumnLinear.load_multi( - config, - prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], - weights=weights, - dim=0, - bias=False, - ) - self.down_proj = TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.down_proj", - weights=weights, - bias=False, - ) - - def forward(self, hidden_states): - start_shape = hidden_states.shape[:-1] - gate_up_states = self.gate_up_proj(hidden_states) - intermediate_size = gate_up_states.shape[-1] // 2 - gate_up_states = gate_up_states.view(-1, 2, intermediate_size) - return self.down_proj( - self.act(gate_up_states[:, 0]) * gate_up_states[:, 1] - ).view(*start_shape, -1) - - -class Idefics2RMSNorm(nn.Module): - def __init__(self, prefix, weights, eps): - """ - Idefics2RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter( - weights.get_tensor(f"{prefix}.weight"), requires_grad=False - ) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -class Idefics2PerceiverAttention(nn.Module): - def __init__(self, prefix, config, weights): - super().__init__() - - self.layer_idx = None - self.hidden_size = config.text_config.hidden_size - self.num_heads = config.perceiver_config.resampler_n_heads - self.head_size = config.perceiver_config.resampler_head_dim - self.num_key_value_heads = config.perceiver_config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.attention_dropout = config.perceiver_config.attention_dropout - self.num_heads = self.num_heads // weights.process_group.size() - self.num_key_value_heads = ( - self.num_key_value_heads // weights.process_group.size() - ) - - self.q_proj = TensorParallelColumnLinear.load( - config, - prefix=f"{prefix}.q_proj", - weights=weights, - bias=False, - ) - self.kv = TensorParallelColumnLinear.load_multi( - config, - prefixes=[f"{prefix}.k_proj", f"{prefix}.v_proj"], - dim=0, - weights=weights, - bias=False, - ) - self.o_proj = TensorParallelRowLinear.load( - config=config, prefix=f"{prefix}.o_proj", weights=weights, bias=False - ) - - self.is_causal = False - - def forward( - self, - latents: torch.Tensor, - context: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = latents.size() - kv_seq_len = q_len + context.size()[1] - - hidden_states = torch.concat([context, latents], dim=-2) - query_states = self.q_proj(latents) - kv = self.kv(hidden_states) - key_states, value_states = kv.split( - [ - self.head_size * self.num_key_value_heads, - self.head_size * self.num_key_value_heads, - ], - dim=2, - ) - - query_states = query_states.view( - bsz, q_len, self.num_heads, self.head_size - ).transpose(1, 2) - key_states = key_states.view( - bsz, kv_seq_len, self.num_key_value_heads, self.head_size - ).transpose(1, 2) - value_states = value_states.view( - bsz, kv_seq_len, self.num_key_value_heads, self.head_size - ).transpose(1, 2) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul( - query_states, key_states.transpose(2, 3) - ) / math.sqrt(self.head_size) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32 - ).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) + input_size = config.vision_config.hidden_size * (config.scale_factor**2) + output_size = config.text_config.hidden_size + proj = nn.Parameter( + weights.get_tensor(f"{prefix}.modality_projection.proj.weight"), + requires_grad=False, + ).to(weights.dtype) + self.proj = nn.Linear(input_size, output_size, bias=False) + self.proj.weight = proj - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_size): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_size)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_size) - - attn_output = self.o_proj(attn_output) - - return attn_output - - -class Idefics2PerceiverLayer(nn.Module): - def __init__(self, prefix, config, weights): - super().__init__() - self.hidden_size = config.text_config.hidden_size - self.n_latents = config.perceiver_config.resampler_n_latents - self.depth = config.perceiver_config.resampler_depth - self.rms_norm_eps = config.text_config.rms_norm_eps - - self.input_latents_norm = Idefics2RMSNorm( - prefix=f"{prefix}.input_latents_norm", - weights=weights, - eps=self.rms_norm_eps, - ) - self.input_context_norm = Idefics2RMSNorm( - prefix=f"{prefix}.input_context_norm", - weights=weights, - eps=self.rms_norm_eps, - ) - self.self_attn = Idefics2PerceiverAttention( - prefix=f"{prefix}.self_attn", config=config, weights=weights - ) - self.post_attention_layernorm = Idefics2RMSNorm( - prefix=f"{prefix}.post_attention_layernorm", - weights=weights, - eps=self.rms_norm_eps, - ) - self.mlp = Idefics2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights) - - def forward( - self, - latents: torch.Tensor, - context: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - ): - """ - Args: - latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, sequence_length)` where padding elements are indicated by 0. - """ - residual = latents - - latents = self.input_latents_norm(latents) - context = self.input_context_norm(context) - - latents = self.self_attn( - latents=latents, - context=context, - attention_mask=attention_mask, - ) - latents = residual + latents - residual = latents - - latents = self.post_attention_layernorm(latents) - latents = self.mlp(latents) - latents = residual + latents - - return latents - - -class Idefics2PerceiverResampler(nn.Module): - def __init__(self, prefix, config, weights) -> None: - super().__init__() - self.hidden_size = config.text_config.hidden_size - self.hidden_act = config.perceiver_config.hidden_act - self.n_latents = config.perceiver_config.resampler_n_latents - self.depth = config.perceiver_config.resampler_depth - self.rms_norm_eps = config.text_config.rms_norm_eps - - # Create Latents for Perceiver - self.latents = weights.get_tensor(f"{prefix}.latents") - - # Create Transformer Blocks - self.layers = nn.ModuleList( - [ - Idefics2PerceiverLayer( - prefix=f"{prefix}.layers.{idx}", config=config, weights=weights - ) - for idx in range(self.depth) - ] - ) - self.norm = Idefics2RMSNorm( - prefix=f"{prefix}.norm", - weights=weights, - eps=config.text_config.rms_norm_eps, - ) - - def forward( - self, - context: torch.Tensor, - attention_mask, - ) -> torch.Tensor: - # seq embed -> bsz seq embed - latents = self.latents.unsqueeze(0).expand( - (context.shape[0], *self.latents.size()) - ) - - latent_attention_mask = torch.ones( - (attention_mask.size(0), latents.size(1)), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - attention_mask = torch.cat([attention_mask, latent_attention_mask], dim=-1) - attention_mask = _prepare_4d_attention_mask( - attention_mask, latents.dtype, tgt_len=self.n_latents - ) - - compressed_context = latents - for perceiver_layer in self.layers: - compressed_context = perceiver_layer( - compressed_context, - context, - attention_mask=attention_mask, - ) - compressed_context = self.norm(compressed_context) - - return compressed_context - - -class Idefics2Connector(nn.Module): - def __init__(self, prefix, config, weights): - super().__init__() - self.modality_projection = Idefics2MLP( - prefix=f"{prefix}.modality_projection", config=config, weights=weights - ) - self.perceiver_resampler = Idefics2PerceiverResampler( - prefix=f"{prefix}.perceiver_resampler", config=config, weights=weights - ) - - def forward(self, image_hidden_states, attention_mask): - image_hidden_states = self.modality_projection(image_hidden_states) - image_hidden_states = self.perceiver_resampler( - context=image_hidden_states, attention_mask=attention_mask - ) - return image_hidden_states + def forward(self, x): + return self.proj(x) class Idefics3Connector(nn.Module): def __init__(self, prefix, config, weights): super().__init__() - self.modality_projection = TensorParallelRowLinear.load( - prefix=f"{prefix}.modality_projection.proj", - config=config, - weights=weights, - bias=False, - ) + self.modality_projection = Idefics3SimpleMLP(prefix, config, weights) self.scale_factor = config.scale_factor def pixel_shuffle(self, x, scale_factor=2): @@ -706,8 +415,7 @@ def pixel_shuffle(self, x, scale_factor=2): x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2)) return x - def forward(self, image_hidden_states, attention_mask): - print(image_hidden_states.device, self.modality_projection.linear.weight.device) + def forward(self, image_hidden_states): image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor) image_hidden_states = self.modality_projection(image_hidden_states) return image_hidden_states @@ -726,7 +434,7 @@ def __init__(self, prefix, config, weights): vision_config = config.vision_config self.text_model = load_text_model( - prefix=f"{prefix}.model.text_model" if prefix else "model.text_model", + prefix="model" if not prefix else f"{prefix}.model", config=config.text_config, weights=weights, name="text_model", @@ -735,7 +443,7 @@ def __init__(self, prefix, config, weights): # The vision and connector models are not quantized. with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)): - self.vision_model = Idefics2VisionTransformer( + self.vision_model = Idefics3VisionTransformer( prefix=( f"{prefix}.model.vision_model" if prefix else "model.vision_model" ), @@ -810,7 +518,6 @@ def forward( dim=(-1, -2, -3) ) != nb_values_per_image pixel_values = pixel_values[real_images_inds].contiguous() - # Handle the vision attention mask if pixel_attention_mask is None: pixel_attention_mask = torch.ones( @@ -850,7 +557,6 @@ def forward( # Modality projection & resampling image_hidden_states = self.connector( image_hidden_states, - attention_mask=patch_attention_mask.view(pixel_values.size(0), -1), ) all_states.append(image_hidden_states) @@ -877,164 +583,3 @@ def forward( hidden_states = hidden_states[lm_head_indices] logits, speculative_logits = self.text_model.lm_head(hidden_states) return logits, speculative_logits - - -class Idefics2ForConditionalGeneration(nn.Module): - def __init__(self, prefix, config, weights): - super().__init__() - config.vision_config.quantize = None - config.vision_config.speculator = config.speculator - config.text_config.quantize = config.quantize - config.text_config.speculator = config.speculator - - vision_config = config.vision_config - self.text_model = load_text_model( - prefix="model" if not prefix else f"{prefix}.model", - config=config.text_config, - weights=weights, - name="text_model", - ) - self.dtype = weights.dtype - - # The vision and connector models are not quantized. - with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)): - self.vision_model = Idefics2VisionTransformer( - prefix=( - f"{prefix}.model.vision_model" if prefix else "model.vision_model" - ), - config=vision_config, - weights=weights, - ) - - config.quantize = None - self.connector = Idefics2Connector( - prefix=f"{prefix}.model.connector" if prefix else "model.connector", - config=config, - weights=weights, - ) - - self.config = config - self.image_seq_len = config.perceiver_config.resampler_n_latents - self.image_token_id = config.image_token_id - self.pad_token_id = ( - config.pad_token_id if config.pad_token_id is not None else -1 - ) - - def _merge_input_ids_with_image_features( - self, - input_ids: torch.Tensor, - inputs_embeds: torch.Tensor, - image_features: torch.Tensor, - ): - """In place merges in vision_embeddings with inputs_embeds.""" - # mask = input_ids == self.config.image_token_index - mask = input_ids == self.config.image_token_id - # Let's pray we have enabled enough slots ! - inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) - return inputs_embeds - - def forward( - self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - cu_seqlen_prefill: Optional[torch.Tensor], - kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, - slots: torch.Tensor, - seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], - lm_head_indices: Optional[torch.Tensor] = None, - pixel_values: torch.FloatTensor = None, - pixel_attention_mask: Optional[torch.BoolTensor] = None, - # Unused here - image_sizes: Optional[torch.Tensor] = None, - adapter_data: Optional[torch.Tensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - ): - inputs_embeds = self.text_model.embed_tokens(input_ids) - if pixel_values is not None: - batch_size, num_images, num_channels, height, width = pixel_values.shape - all_states = [] - all_pixel_values = pixel_values - all_pixel_mask = pixel_attention_mask - for i in range(batch_size): - pixel_values = all_pixel_values.to( - dtype=self.dtype - ) # fp16 compatibility - pixel_values = pixel_values[i : i + 1] - pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:]) - - # Remove padding images - padding images are full 0. - nb_values_per_image = pixel_values.shape[1:].numel() - real_images_inds = (pixel_values == 0.0).sum( - dim=(-1, -2, -3) - ) != nb_values_per_image - pixel_values = pixel_values[real_images_inds].contiguous() - - # Handle the vision attention mask - if pixel_attention_mask is None: - pixel_attention_mask = torch.ones( - size=( - pixel_values.size(0), - pixel_values.size(2), - pixel_values.size(3), - ), - dtype=torch.bool, - device=pixel_values.device, - ) - else: - # Remove padding images from the mask/pP p - pixel_attention_mask = all_pixel_mask[i : i + 1] - pixel_attention_mask = pixel_attention_mask.view( - 1 * num_images, *pixel_attention_mask.shape[2:] - ) - pixel_attention_mask = pixel_attention_mask[ - real_images_inds - ].contiguous() - - patch_size = self.config.vision_config.patch_size - patches_subgrid = pixel_attention_mask.unfold( - dimension=1, size=patch_size, step=patch_size - ) - patches_subgrid = patches_subgrid.unfold( - dimension=2, size=patch_size, step=patch_size - ) - patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() - - # Get sequence from the vision encoder - image_hidden_states = self.vision_model( - pixel_values=pixel_values, - patch_attention_mask=patch_attention_mask, - ) - - # Modality projection & resampling - image_hidden_states = self.connector( - image_hidden_states, - attention_mask=patch_attention_mask.view(pixel_values.size(0), -1), - ) - all_states.append(image_hidden_states) - image_hidden_states = torch.stack(all_states, dim=0) - # When we generate, we don't want to replace the potential image_token_id that we generated by images - # that simply don't exist - inputs_embeds = self._merge_input_ids_with_image_features( - input_ids, inputs_embeds, image_hidden_states - ) - - hidden_states = self.text_model.model( - inputs_embeds=inputs_embeds, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=kv_cache, - block_tables=block_tables, - slots=slots, - seqlen=seqlen, - max_s=max_s, - true_max_s=max_s, - prefill_cache_indices=None, - adapter_data=adapter_data, - ) - if lm_head_indices is not None: - hidden_states = hidden_states[lm_head_indices] - logits, speculative_logits = self.text_model.lm_head(hidden_states) - return logits, speculative_logits diff --git a/server/text_generation_server/models/custom_modeling/vlm.py b/server/text_generation_server/models/custom_modeling/vlm.py index 82e409a673c..94b8522d4b6 100644 --- a/server/text_generation_server/models/custom_modeling/vlm.py +++ b/server/text_generation_server/models/custom_modeling/vlm.py @@ -4,7 +4,7 @@ def load_text_model(prefix, config, weights, name=None): FlashLlamaForCausalLM, ) - return FlashLlamaForCausalLM(prefix, config, weights) + return FlashLlamaForCausalLM(prefix, config, weights, name=name) elif config.model_type == "mistral": from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( FlashMistralForCausalLM, diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 082f4b81840..daf5d063009 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -13,6 +13,7 @@ FlashCausalLM, ) from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION +from loguru import logger from text_generation_server.utils.log import log_master from transformers import AutoProcessor from text_generation_server.layers.attention import Seqlen @@ -29,25 +30,32 @@ def get_image_prompt_string( - rows=0, - cols=0, - seq_len=1, - fake_token=IDEFICS3_FAKE_IMAGE_TOKEN, - img_token=IDEFICS3_IMAGE_TOKEN, - global_token=IDEFICS3_GLOBAL_IMG_TOKEN, + *, + image_seq_len, + image_rows, + image_cols, + fake_token_around_image, + image_token, + global_img_token, ): - tokens = img_token * seq_len - end_token = f"{fake_token}{global_token}{tokens}{fake_token}" - - if rows == 0 or cols == 0: - return end_token + """Prompt with expanded image tokens for when the image is split into patches.""" + text_split_images = "" + for n_h in range(image_rows): + for n_w in range(image_cols): + text_split_images += ( + f"{fake_token_around_image}" + + f"" + + f"{image_token}" * image_seq_len + ) + text_split_images += "\n" - grid = "\n".join( - "".join(f"{fake_token}{tokens}" for j in range(cols)) - for i in range(rows) + text_split_images += ( + f"\n{fake_token_around_image}" + + f"{global_img_token}" + + f"{image_token}" * image_seq_len + + f"{fake_token_around_image}" ) - - return f"{grid}\n\n{end_token}" + return text_split_images def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): @@ -89,18 +97,17 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str / (config.scale_factor**2) ) image_str = get_image_prompt_string( - rows=n_rows, - cols=n_cols, - seq_len=image_seq_len, - fake_token=IDEFICS3_FAKE_IMAGE_TOKEN, - img_token=IDEFICS3_IMAGE_TOKEN, - global_token=IDEFICS3_GLOBAL_IMG_TOKEN, + image_seq_len=image_seq_len, + image_rows=n_rows, + image_cols=n_cols, + fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN, + image_token=IDEFICS3_IMAGE_TOKEN, + global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN, ) return image_str elif config.model_type == "llava_next": height, width = image_input["image_sizes"][image_id] num_features = get_number_of_features(height, width, config) - from loguru import logger log_master( logger.info, @@ -238,9 +245,11 @@ def batch_tokenized_inputs( if images: kwargs = {} - match processor.image_processor_class: - case "Idefics3ImageProcessor": - kwargs["return_row_col_info"] = True + if ( + hasattr(processor, "image_processor_class") + and processor.image_processor_class == "Idefics3ImageProcessor" + ): + kwargs["return_row_col_info"] = True image_inputs = processor.image_processor( images, return_tensors="pt", **kwargs From df504e9f1e274727c5fc62f512ec62e5cd800d29 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 7 Jan 2025 22:06:54 +0000 Subject: [PATCH 14/17] fix: improve typing --- .../text_generation_server/models/vlm_causal_lm.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index daf5d063009..0ff3c3fa461 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -31,12 +31,12 @@ def get_image_prompt_string( *, - image_seq_len, - image_rows, - image_cols, - fake_token_around_image, - image_token, - global_img_token, + image_seq_len: int, + image_rows: int, + image_cols: int, + fake_token_around_image: str, + image_token: str, + global_img_token: str, ): """Prompt with expanded image tokens for when the image is split into patches.""" text_split_images = "" From d397748ca86bfd4aafb4f4e92f49ad4914cd61d2 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 7 Jan 2025 22:09:52 +0000 Subject: [PATCH 15/17] fix: improve prompt_split_image with ref to original impl --- server/text_generation_server/models/vlm_causal_lm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 0ff3c3fa461..8937c1c5917 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -29,6 +29,7 @@ IDEFICS3_GLOBAL_IMG_TOKEN = "" +# copied from: https://github.com/huggingface/transformers/blob/02ed609285c2448b3b54c31e362f2c389fa952ab/src/transformers/models/idefics3/processing_idefics3.py#L44-L60 def get_image_prompt_string( *, image_seq_len: int, From 78004db1e6a0190a5e39b2a60baa9e25582251e0 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 7 Jan 2025 22:25:38 +0000 Subject: [PATCH 16/17] fix: adjust ruff lints and small refactors --- .../models/custom_modeling/flash_llama_modeling.py | 2 +- .../text_generation_server/models/custom_modeling/idefics3.py | 1 - server/text_generation_server/models/vlm_causal_lm.py | 4 ++-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 7525940a341..4a96d27ad32 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -660,7 +660,7 @@ def __init__(self, prefix: str, config, weights, name=None): if embedding_multiplier is not None: self.embed_tokens.weight.data *= embedding_multiplier - prefix = "lm_head" if not prefix or name != "model" else f"{prefix}.lm_head" + prefix = "lm_head" if not prefix or name != "model" else f"{prefix}.{suffix}" with no_fp8(weights): self.lm_head = SpeculativeHead.load( diff --git a/server/text_generation_server/models/custom_modeling/idefics3.py b/server/text_generation_server/models/custom_modeling/idefics3.py index 81e03943575..580398cb32e 100644 --- a/server/text_generation_server/models/custom_modeling/idefics3.py +++ b/server/text_generation_server/models/custom_modeling/idefics3.py @@ -19,7 +19,6 @@ import torch import torch.utils.checkpoint from torch import nn -import math from transformers.activations import ACT2FN from text_generation_server.models.custom_modeling.vlm import ( diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 8937c1c5917..db78341d1ed 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -30,7 +30,7 @@ # copied from: https://github.com/huggingface/transformers/blob/02ed609285c2448b3b54c31e362f2c389fa952ab/src/transformers/models/idefics3/processing_idefics3.py#L44-L60 -def get_image_prompt_string( +def _prompt_split_image( *, image_seq_len: int, image_rows: int, @@ -97,7 +97,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2) ) - image_str = get_image_prompt_string( + image_str = _prompt_split_image( image_seq_len=image_seq_len, image_rows=n_rows, image_cols=n_cols, From daa397c515203a3161c03f7c479e7f276c9ccc34 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 8 Jan 2025 13:49:11 +0000 Subject: [PATCH 17/17] fix: adjust FlashLlamaModel prefix logic --- .../custom_modeling/flash_llama_modeling.py | 27 ++++--------------- .../models/flash_causal_lm.py | 2 +- 2 files changed, 6 insertions(+), 23 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 4a96d27ad32..b89b42981b0 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -507,7 +507,6 @@ def __init__(self, prefix, config, weights): process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - base_model = "" if prefix.endswith("text_model") else ".model" # Skip fp8 quant for first and last layers self.layers = nn.ModuleList() @@ -516,11 +515,7 @@ def __init__(self, prefix, config, weights): self.layers.append( FlashLlamaLayer( index=0, - prefix=( - "model.layers.0" - if not prefix - else f"{prefix}{base_model}.layers.0" - ), + prefix=f"{prefix}.layers.0", config=config, weights=weights, ) @@ -536,11 +531,7 @@ def __init__(self, prefix, config, weights): self.layers.append( FlashLlamaCrossLayer( index=layer_id, - prefix=( - f"model.layers.{layer_id}" - if not prefix - else f"{prefix}{base_model}.layers.{layer_id}" - ), + prefix=(f"{prefix}.layers.{layer_id}"), config=config, weights=weights, ) @@ -549,11 +540,7 @@ def __init__(self, prefix, config, weights): self.layers.append( FlashLlamaLayer( index=layer_id, - prefix=( - f"model.layers.{layer_id}" - if not prefix - else f"{prefix}{base_model}.layers.{layer_id}" - ), + prefix=(f"{prefix}.layers.{layer_id}"), config=config, weights=weights, ) @@ -564,18 +551,14 @@ def __init__(self, prefix, config, weights): self.layers.append( FlashLlamaLayer( index=last_layer_id, - prefix=( - f"model.layers.{last_layer_id}" - if not prefix - else f"{prefix}{base_model}.layers.{last_layer_id}" - ), + prefix=(f"{prefix}.layers.{last_layer_id}"), config=config, weights=weights, ) ) self.norm = FastRMSNorm.load( - prefix="model.norm" if not prefix else f"{prefix}{base_model}.norm", + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps, ) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 5d37699074f..87987ee5303 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1288,7 +1288,7 @@ def __init__( weights_loader=weights_loader, ) - prefix = "" + prefix = None model = model_class(prefix, config, weights) torch.distributed.barrier(group=self.process_group)