From 59cb7be0b4bad075998bea4975a7920eff0c569d Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Fri, 1 Nov 2024 15:13:19 +0000
Subject: [PATCH] Add prefix arg

Signed-off-by: Jee Jee Li
---
 vllm/model_executor/layers/resampler.py | 12 +++--
 vllm/model_executor/models/minicpmv.py  | 58 ++++++++++++++-----------
 2 files changed, 40 insertions(+), 30 deletions(-)

diff --git a/vllm/model_executor/layers/resampler.py b/vllm/model_executor/layers/resampler.py
index afd5b97a785ae..bce91f1d7fd5e 100644
--- a/vllm/model_executor/layers/resampler.py
+++ b/vllm/model_executor/layers/resampler.py
@@ -162,7 +162,8 @@ def __init__(self,
                  kv_dim: Optional[int] = None,
                  norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
                  do_post_projection: bool = True,
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "") -> None:
         super().__init__()
 
         self.num_queries = num_queries
@@ -175,7 +176,8 @@ def __init__(self,
             self.kv_proj = ReplicatedLinear(kv_dim,
                                             embed_dim,
                                             bias=False,
-                                            quant_config=quant_config)
+                                            quant_config=quant_config,
+                                            prefix=prefix)
         else:
             # Maintain the same return value with ReplicatedLinear.forward
             self.kv_proj = lambda *args, **kwargs: (  # type: ignore # noqa
@@ -220,14 +222,16 @@ def __init__(self,
                  norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
                  adaptive: bool = False,
                  do_post_projection: bool = True,
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "") -> None:
         super().__init__(grid_size**2,
                          embed_dim,
                          num_heads,
                          kv_dim,
                          norm_layer,
                          do_post_projection=do_post_projection,
-                         quant_config=quant_config)
+                         quant_config=quant_config,
+                         prefix=prefix)
 
         self.adaptive = adaptive
         pos_embed_arr = get_2d_sincos_pos_embed(embed_dim,
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index 54c9954c9cfb5..47b538002f496 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -138,13 +138,15 @@ def __init__(self,
                  kv_dim: Optional[int] = None,
                  norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
                  max_size: Tuple[int, int] = (70, 70),
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "") -> None:
         super().__init__(num_queries,
                          embed_dim,
                          num_heads,
                          kv_dim,
                          norm_layer,
-                         quant_config=quant_config)
+                         quant_config=quant_config,
+                         prefix=prefix)
 
         self.max_size = max_size
         self._set_2d_pos_cache(self.max_size)
@@ -410,7 +412,8 @@ def __init__(
         self.embed_dim = self.config.hidden_size
         self.resampler = self.init_resampler(self.embed_dim,
                                              self.vision_dim,
-                                             quant_config=quant_config)
+                                             quant_config=quant_config,
+                                             prefix="resampler")
         self.resampler.to(device="cuda", dtype=param_dtype)
         # TODO: why is there _KEYS_TO_MODIFY_MAPPING? lm_head should be in llm
         self.lm_head = ParallelLMHead(config.vocab_size,
@@ -667,11 +670,11 @@ def init_vision_module(
     ) -> nn.Module:
         raise NotImplementedError
 
-    def init_resampler(
-        self,
-        embed_dim: int,
-        vision_dim: int,
-        quant_config: Optional[QuantizationConfig] = None) -> nn.Module:
+    def init_resampler(self,
+                       embed_dim: int,
+                       vision_dim: int,
+                       quant_config: Optional[QuantizationConfig] = None,
+                       prefix: str = "") -> nn.Module:
         raise NotImplementedError
 
     def get_vision_embedding(
@@ -748,11 +751,11 @@ def init_vision_module(
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.embed_tokens(input_ids)
 
-    def init_resampler(
-        self,
-        embed_dim: int,
-        vision_dim: int,
-        quant_config: Optional[QuantizationConfig] = None) -> nn.Module:
+    def init_resampler(self,
+                       embed_dim: int,
+                       vision_dim: int,
+                       quant_config: Optional[QuantizationConfig] = None,
+                       prefix: str = "") -> nn.Module:
         with set_default_torch_dtype(torch.float16):
             resampler = Resampler2(embed_dim=embed_dim,
                                    num_heads=embed_dim // 128,
@@ -761,7 +764,8 @@ def init_resampler(
                                    kv_dim=vision_dim,
                                    adaptive=False,
                                    do_post_projection=True,
-                                   quant_config=quant_config)
+                                   quant_config=quant_config,
+                                   prefix=prefix)
 
         return resampler
 
@@ -896,17 +900,18 @@ def init_vision_module(
             model.encoder.layers = model.encoder.layers[:-1]
         return model
 
-    def init_resampler(
-        self,
-        embed_dim: int,
-        vision_dim: int,
-        quant_config: Optional[QuantizationConfig] = None) -> nn.Module:
+    def init_resampler(self,
+                       embed_dim: int,
+                       vision_dim: int,
+                       quant_config: Optional[QuantizationConfig] = None,
+                       prefix: str = "") -> nn.Module:
         with set_default_torch_dtype(torch.float16):
             resampler = Resampler2_5(num_queries=self.config.query_num,
                                      embed_dim=embed_dim,
                                      num_heads=embed_dim // 128,
                                      kv_dim=vision_dim,
-                                     quant_config=quant_config)
+                                     quant_config=quant_config,
+                                     prefix=prefix)
         return resampler
 
     def get_vision_embedding(
@@ -1047,18 +1052,19 @@ def init_vision_module(
             model.encoder.layers = model.encoder.layers[:-1]
         return model
 
-    def init_resampler(
-        self,
-        embed_dim: int,
-        vision_dim: int,
-        quant_config: Optional[QuantizationConfig] = None) -> nn.Module:
+    def init_resampler(self,
+                       embed_dim: int,
+                       vision_dim: int,
+                       quant_config: Optional[QuantizationConfig] = None,
+                       prefix: str = "") -> nn.Module:
         with set_default_torch_dtype(torch.float16):
             # The resampler in 2.6 remains consistent with the one in 2.5.
             resampler = Resampler2_5(num_queries=self.config.query_num,
                                      embed_dim=embed_dim,
                                      num_heads=embed_dim // 128,
                                      kv_dim=vision_dim,
-                                     quant_config=quant_config)
+                                     quant_config=quant_config,
+                                     prefix=prefix)
         return resampler
 
     def get_vision_embedding(
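
Note on the pattern: threading `prefix` from the model through `init_resampler` down to the
inner `ReplicatedLinear` (`kv_proj`) gives that layer a name it can report to the quantization
config, which is presumably used to decide how, or whether, to quantize it per layer. The
sketch below illustrates the general idea only; `QuantConfig`, `PrefixAwareLinear`, and
`ToyResampler` are hypothetical stand-ins, not vLLM classes, and the skip-list mechanism is an
assumption made for the example.

# Illustrative sketch only -- not vLLM's implementation.
from typing import Optional, Set

import torch.nn as nn


class QuantConfig:
    """Hypothetical quantization config keyed by layer names."""

    def __init__(self, skip_layers: Set[str]) -> None:
        self.skip_layers = skip_layers

    def is_quantized(self, layer_name: str) -> bool:
        # Layers whose name appears in the skip list stay unquantized.
        return layer_name not in self.skip_layers


class PrefixAwareLinear(nn.Linear):
    """Stand-in for a quantizable linear layer that knows its own name."""

    def __init__(self, in_features: int, out_features: int, *,
                 quant_config: Optional[QuantConfig] = None,
                 prefix: str = "") -> None:
        super().__init__(in_features, out_features, bias=False)
        # The prefix (e.g. "resampler") lets the config decide per layer.
        self.is_quantized = (quant_config is not None
                             and quant_config.is_quantized(prefix))


class ToyResampler(nn.Module):
    """Stand-in resampler: forwards its prefix unchanged to kv_proj,
    mirroring how the patched BaseResampler passes `prefix` on."""

    def __init__(self, kv_dim: int, embed_dim: int, *,
                 quant_config: Optional[QuantConfig] = None,
                 prefix: str = "") -> None:
        super().__init__()
        self.kv_proj = PrefixAwareLinear(kv_dim, embed_dim,
                                         quant_config=quant_config,
                                         prefix=prefix)


if __name__ == "__main__":
    cfg = QuantConfig(skip_layers={"resampler"})
    # Mirrors the patch: the model builds its resampler with prefix="resampler".
    resampler = ToyResampler(16, 32, quant_config=cfg, prefix="resampler")
    print(resampler.kv_proj.is_quantized)  # False: name matched the skip list

Because `prefix` defaults to "" everywhere in the patch, existing callers that do not pass it
keep their current behaviour.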