From c0e9b46795153c452e05dc0b3b7cf23a3cd7b188 Mon Sep 17 00:00:00 2001 From: bmullick-amd Date: Mon, 2 Dec 2024 13:23:46 -0800 Subject: [PATCH 01/15] Update registry.py --- vllm/model_executor/models/registry.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index f17f31d208575..6bdd9a823f82a 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -90,6 +90,9 @@ "BartModel": ("bart", "BartForConditionalGeneration"), "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501 + "T5Model": ("t5", "T5ForConditionalGeneration"), + "T5ForConditionalGeneration": ("t5", "T5ForConditionalGeneration"), + "T5WithLMHeadModel": ("t5", "T5ForConditionalGeneration") } _EMBEDDING_MODELS = { @@ -484,4 +487,4 @@ def _run() -> None: if __name__ == "__main__": - _run() \ No newline at end of file + _run() From d9abfb916a8cc80bbbd0fcd063b00a115362dcb1 Mon Sep 17 00:00:00 2001 From: bmullick-amd Date: Mon, 2 Dec 2024 13:29:42 -0800 Subject: [PATCH 02/15] Update preprocess.py --- vllm/inputs/preprocess.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 82ce7d392b719..87a574fb2ba62 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -160,7 +160,8 @@ def _prepare_decoder_input_ids_for_generation( if decoder_input_ids is None: # no decoder prompt input -> # use decoder_start_token_id as decoder_input_ids - decoder_input_ids = self._get_default_enc_dec_decoder_prompt() + # decoder_input_ids = self._get_default_enc_dec_decoder_prompt() + decoder_input_ids = [decoder_start_token_id] if force_bos and (len(decoder_input_ids) == 0 or decoder_input_ids[0] != decoder_start_token_id): From e78a4957602f22a6ae50fbe4cdbc38862e70c092 Mon Sep 17 00:00:00 2001 From: bmullick-amd Date: Mon, 2 Dec 2024 13:33:18 -0800 Subject: [PATCH 03/15] added t5 model script --- vllm/model_executor/models/t5.py | 1017 ++++++++++++++++++++++++++++++ 1 file changed, 1017 insertions(+) create mode 100644 vllm/model_executor/models/t5.py diff --git a/vllm/model_executor/models/t5.py b/vllm/model_executor/models/t5.py new file mode 100644 index 0000000000000..1368939bdfe42 --- /dev/null +++ b/vllm/model_executor/models/t5.py @@ -0,0 +1,1017 @@ +# coding=utf-8 +# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
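+#
+# Illustrative usage sketch (not part of the upstream HuggingFace file): once
+# this module and the registry entries from the first patch are in place, a
+# T5 checkpoint should be runnable through vLLM's offline API roughly as
+# follows; the exact API surface may vary between vLLM versions.
+#
+#     from vllm import LLM, SamplingParams
+#
+#     llm = LLM(model="t5-small")
+#     params = SamplingParams(temperature=0.0, max_tokens=32)
+#     outputs = llm.generate(
+#         ["translate English to German: The house is wonderful."], params)
+#     print(outputs[0].outputs[0].text)
+#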
+""" PyTorch T5 model.""" + + +import copy + + +import math +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn +import torch.nn.functional as F +from transformers import T5Config +from transformers.utils import logging + +from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +# from flash_attn import flash_attn_func + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "T5Config" +_CHECKPOINT_FOR_DOC = "t5-small" + +#################################################### +# This dict contains ids and associated url +# for the pretrained weights provided with the models +#################################################### +T5_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "t5-small", + "t5-base", + "t5-large", + "t5-3b", + "t5-11b", + # See all T5 models at https://huggingface.co/models?filter=t5 +] + + +#################################################### +# This is a conversion method from TF 1.0 to PyTorch +# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28 +#################################################### +def load_tf_weights_in_t5(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." 
+ ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + tf_weights = {} + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + tf_weights[name] = array + + for txt_name in names: + name = txt_name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + if "_slot_" in name[-1]: + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + pointer = model + array = tf_weights[txt_name] + + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + elif scope_names[0] == "self_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[0] + elif scope_names[0] == "enc_dec_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[1] + elif scope_names[0] == "dense_relu_dense": + pointer = getattr(pointer, "layer") + pointer = pointer[2] + elif scope_names[0] == "rms_norm": + if hasattr(pointer, "layer_norm"): + pointer = getattr(pointer, "layer_norm") + elif hasattr(pointer, "final_layer_norm"): + pointer = getattr(pointer, "final_layer_norm") + elif scope_names[0] == "scale": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + elif scope_names[0] == "decoder" and name[1] == "logits": + continue + elif scope_names[0] == "logits": + pointer = getattr(pointer, "lm_head") + elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit(): + pointer = getattr(pointer, f"wi_{scope_names[1]}") + continue + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if scope_names[0] not in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + if scope_names[0] != "embedding": + logger.info(f"Transposing numpy weight of shape {array.shape} for {name}") + array = np.transpose(array) + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array.astype(np.float32)) + tf_weights.pop(txt_name, None) + + logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.") + return model + + +#################################################### +# PyTorch Models are constructed by sub-classing +# - torch.nn.Module for the layers and +# - PreTrainedModel for the models (it-self a sub-class of nn.Module) +#################################################### +PARALLELIZE_DOCSTRING = 
r""" + This is an experimental feature and is a subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the t5 models have the + following number of attention modules: + + - t5-small: 6 + - t5-base: 12 + - t5-large: 24 + - t5-3b: 24 + - t5-11b: 24 + + Example: + + ```python + # Here is an example of a device map on a machine with 4 GPUs using t5-3b, which has a total of 24 attention modules: + model = T5ForConditionalGeneration.from_pretrained("t5-3b") + device_map = { + 0: [0, 1, 2], + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23], + } + model.parallelize(device_map) + ``` +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + + Example: + + ```python + # On a 4 GPU machine with t5-3b: + model = T5ForConditionalGeneration.from_pretrained("t5-3b") + device_map = { + 0: [0, 1, 2], + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23], + } + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() + ``` +""" + + +class T5LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # w/o mean and there is no bias. 
Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +class T5DenseActDense(nn.Module): + def __init__(self, config: T5Config, quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.wi = ColumnParallelLinear(config.d_model, config.d_ff, bias=False, quant_config=quant_config) + self.wo = RowParallelLinear(config.d_ff, config.d_model, bias=False, quant_config=quant_config) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = get_act_fn(config.dense_act_fn, quant_config) + + def forward(self, hidden_states): + + hidden_states = self.wi(hidden_states)[0] + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states) + if ( + isinstance(self.wo.weight, torch.Tensor) + and hidden_states.dtype != self.wo.weight.dtype + and self.wo.weight.dtype != torch.int8 + ): + hidden_states = hidden_states.to(self.wo.weight.dtype) + hidden_states = self.wo(hidden_states)[0] + return hidden_states + + +class T5DenseGatedActDense(nn.Module): + def __init__(self, config: T5Config, quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.wi_0 = ColumnParallelLinear(config.d_model, config.d_ff, bias=False, quant_config=quant_config) + self.wi_1 = ColumnParallelLinear(config.d_model, config.d_ff, bias=False, quant_config=quant_config) + self.wo = RowParallelLinear(config.d_ff, config.d_model, bias=False, quant_config=quant_config) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = get_act_fn(config.dense_act_fn, quant_config) + + def forward(self, hidden_states): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + + # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32. 
+ # See https://github.com/huggingface/transformers/issues/20287 + # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`` + if ( + isinstance(self.wo.weight, torch.Tensor) + and hidden_states.dtype != self.wo.weight.dtype + and self.wo.weight.dtype != torch.int8 + ): + hidden_states = hidden_states.to(self.wo.weight.dtype) + + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5LayerFF(nn.Module): + def __init__(self, config: T5Config, quant_config: Optional[QuantizationConfig] = None): + super().__init__() + if config.is_gated_act: + self.DenseReluDense = T5DenseGatedActDense(config, quant_config) + else: + self.DenseReluDense = T5DenseActDense(config, quant_config) + + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + +class T5Attention(nn.Module): + def __init__(self, config: T5Config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + + self.qkv_proj = QKVParallelLinear( + self.d_model, + self.inner_dim // self.n_heads, + self.n_heads, + bias=False, + quant_config=quant_config, + ) + self.out_proj = RowParallelLinear( + self.inner_dim, + self.d_model, + bias=False, + quant_config=quant_config, + ) + self.attn = Attention(self.n_heads, + self.inner_dim // self.n_heads, + scale = 1, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + encoder_hidden_states: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 
+ """ + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([2048, 2048, 2048], dim=-1) + if encoder_hidden_states is None: + attn_output = F.scaled_dot_product_attention(q, + k, + v, + dropout_p=0.0) + else: + qkv_enc, _ = self.qkv_proj(encoder_hidden_states) + _, k, v = qkv.split([2048, 2048, 2048], dim=-1) + attn_output = F.scaled_dot_product_attention(q, + k, + v, + dropout_p=0.0) + output, _ = self.out_proj(attn_output) + present_key_value_state = (k, v) if self.is_decoder else None + return output, present_key_value_state + +class T5LayerSelfAttention(nn.Module): + def __init__(self, config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, has_relative_attention_bias=False): + super().__init__() + self.SelfAttention = T5Attention(config, cache_config, quant_config, has_relative_attention_bias=has_relative_attention_bias) + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata) -> torch.Tensor: + hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + hidden_states, + kv_cache, + attn_metadata) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +class T5LayerCrossAttention(nn.Module): + def __init__(self, config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.EncDecAttention = T5Attention(config, cache_config, quant_config, has_relative_attention_bias=False) + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + encoder_hidden_states: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + hidden_states, + kv_cache, + attn_metadata, + encoder_hidden_states, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + return outputs + + +class T5Block(nn.Module): + def __init__(self, config: T5Config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.self_attn = T5LayerSelfAttention(config, cache_config, quant_config, has_relative_attention_bias=has_relative_attention_bias) + if self.is_decoder: + self.cross_attn = T5LayerCrossAttention(config, cache_config, quant_config) + self.fc = T5LayerFF(config, quant_config) + + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + encoder_hidden_states: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + self_attention_outputs = self.self_attn(hidden_states, kv_cache, attn_metadata) + hidden, _ = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden).any(), + torch.finfo(hidden.dtype).max - 
1000, + torch.finfo(hidden.dtype).max, + ) + hidden = torch.clamp(hidden, min=-clamp_value, max=clamp_value) + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: + cross_attention_outputs = self.cross_attn(hidden, kv_cache, attn_metadata, encoder_hidden_states) + hidden = cross_attention_outputs[0] + + # clamp inf values to enable fp16 training + if hidden.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden).any(), + torch.finfo(hidden.dtype).max - 1000, + torch.finfo(hidden.dtype).max, + ) + hidden = torch.clamp(hidden, min=-clamp_value, max=clamp_value) + + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden = self.fc(hidden) + + # clamp inf values to enable fp16 training + if hidden.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden).any(), + torch.finfo(hidden.dtype).max - 1000, + torch.finfo(hidden.dtype).max, + ) + hidden = torch.clamp(hidden, min=-clamp_value, max=clamp_value) + + outputs = (hidden,) + attention_outputs + return outputs + + +class T5ClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config: T5Config): + super().__init__() + self.dense = nn.Linear(config.d_model, config.d_model) + self.dropout = nn.Dropout(p=config.classifier_dropout) + self.out_proj = nn.Linear(config.d_model, config.num_labels) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + +class T5Stack(nn.Module): + def __init__(self, + config: T5Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + embed_tokens=None): + super().__init__() + self.cache_config = cache_config + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + + self.block = nn.ModuleList( + [T5Block(config, cache_config, quant_config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + ) + self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + encoder_hidden_states: Optional[torch.Tensor]=None) -> torch.Tensor: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = self.dropout(inputs_embeds) + # print('t5 stack', type(hidden_states)) + for i, layer in enumerate(self.block): + layer_outputs = layer(hidden_states, + kv_caches[i], + attn_metadata, + encoder_hidden_states) + hidden_states = layer_outputs[0] + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +T5_START_DOCSTRING = r""" + + The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text + Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan + Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a + text-to-text denoising generative setting. + + This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`T5Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +T5_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + [What are input IDs?](../glossary#input-ids) + + To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training). + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5 + Training](./t5#training). + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at + the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value + of `inputs_embeds`. + + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +T5_ENCODER_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training). + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask +__HEAD_MASK_WARNING_MSG = """ +The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, +`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. +If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, +num_heads)`. +""" + +class T5Model(nn.Module): + _keys_to_ignore_on_load_unexpected = [ + "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", + ] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, + config: T5Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None): + super().__init__() + # self.shared = nn.Embedding(config.vocab_size, config.d_model) + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.shared = VocabParallelEmbedding( + self.vocab_size, + config.d_model, + org_num_embeddings=config.vocab_size, + ) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, cache_config, quant_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5Stack(decoder_config, cache_config, quant_config, self.shared) + + def forward( + self, + input_ids: torch.Tensor, positions: torch.Tensor, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata) -> torch.Tensor: + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, T5Model + + >>> tokenizer = AutoTokenizer.from_pretrained("t5-small") + >>> model = T5Model.from_pretrained("t5-small") + + >>> input_ids = tokenizer( + ... 
"Studies have been shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 + + >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model. + >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg. + >>> decoder_input_ids = model._shift_right(decoder_input_ids) + + >>> # forward pass + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + encoder_hidden_states = None + + if encoder_input_ids.numel() > 0: + # Run encoder attention if a non-zero number of encoder tokens + # are provided as input + encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, + positions=encoder_positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata) + decoder_outputs = self.decoder( + input_ids=input_ids, + positions=positions, + encoder_hidden_states=encoder_hidden_states, + kv_caches=kv_caches, + attn_metadata=attn_metadata) + + return decoder_outputs + +class T5ForConditionalGeneration(nn.Module): + def __init__(self, + config: T5Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None): + super().__init__() + self.config = config + self.model_dim = config.d_model + self.model = T5Model(config, + cache_config, + quant_config, + lora_config=lora_config) + print('lora_config', lora_config) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead(num_embeddings= self.unpadded_vocab_size, + embedding_dim=config.d_model, + org_num_embeddings=config.vocab_size, + bias=False) + + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + r""" + Args: + input_ids + torch.Tensor of *decoder* input token ids. + positions + torch.Tensor of *decoder* position indices. + encoder_input_ids + torch.Tensor of *encoder* input token ids. 
+ encoder_positions + torch.Tensor of *encoder* position indices + kv_caches: + Layer-wise list of KV cache tensors + attn_metadata: + vLLM Attention metadata structure + Returns: + Output torch.Tensor + """ + return self.model(input_ids, positions, encoder_input_ids, + encoder_positions, kv_caches, attn_metadata) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + stacked_params_mapping = { + "q.weight": { + "param_name": "qkv_proj.weight", + "shard_id": "q", + }, + "k.weight": { + "param_name": "qkv_proj.weight", + "shard_id": "k", + }, + "v.weight": { + "param_name": "qkv_proj.weight", + "shard_id": "v", + }, + "o.weight": { + "param_name": "out_proj.weight", + "shard_id": None, + } + } + + + + params_mapping = { + "beta": "bias", + "gamma": "weight", + "LayerNorm": "layernorm", + } + + def _rename_key(self, key: str): + prefix = f"{self.base_model_prefix}." + key = key[len(prefix):] if key.startswith(prefix) else key + + for src, dst in self.params_mapping.items(): + key = key.replace(src, dst) + + return key + + layer_type_mapping = { + "encoder": { + "layer.0": "self_attn", + "layer.1": "fc", + }, + "decoder": { + "layer.0": "self_attn", + "layer.1": "cross_attn", + "layer.2": "fc", + } + } + + def _rename_layer_types( + self, + name: str, + ) -> str: + for enc_dec, mapping in self.layer_type_mapping.items(): + if enc_dec in name: + for layer_num in mapping.keys(): + if layer_num in name: + name = name.replace(layer_num, mapping[layer_num]) + return name + + def _rename_stacked_param( + self, + name: str, + ) -> Tuple[str, Optional[str]]: + for key, mapping in self.stacked_params_mapping.items(): + if key in name and '.wo.' 
not in name: + name = name.replace(key, mapping["param_name"]) + return name, mapping["shard_id"] + return name, None + + # def get_set(self, model_params_dict): + # out = set() + # for key in model_params_dict.keys(): + # if "bias" in key: + # print('BBBIIIAAASSSSS..................') + # if 'decoder' not in key and 'encoder' not in key: + # print(key) + # # print(key.split('.')) + # lst = key.split('.') + # if len(lst)>=4: + # out.add(lst[3]) + # return out + + def match_weight_name(self, weights_tuple_list): + out = set() + for name, _ in weights_tuple_list: + # print(name) + if 'decoder' in name and 'layer_norm' not in name: + if 'layer.0' in name and 'SelfAttention' not in name: + print(name) + out.add(False) + elif 'layer.1' in name and 'EncDecAttention' not in name: + print(name) + out.add(False) + elif 'layer.2' in name and 'DenseReluDense' not in name: + print(name) + out.add(False) + else: + out.add(True) + elif 'encoder' in name and 'layer_norm' not in name: + if 'layer.0' in name and 'SelfAttention' not in name: + print(name) + out.add(False) + elif 'layer.1' in name and 'DenseReluDense' not in name: + print(name) + out.add(False) + else: + out.add(True) + elif 'decoder' not in name and 'encoder' not in name: + print(name) + return out + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + model_params_dict = dict(self.model.named_parameters()) + # types = self.get_set(model_params_dict) + top_params_dict = dict(self.named_parameters()) + + weights_tuple_list = list(weights) + + shared_embedding_weight = None + shared_embedding_shard_id = None + + for name, loaded_weight in weights_tuple_list: + name = self._rename_layer_types(name) + name, shard_id = self._rename_stacked_param(name) + if ('encoder.embed_tokens.weight' in name + or 'decoder.embed_tokens.weight' in name + or 'lm_head.weight' in name): + assert shared_embedding_weight is None, ( + "Conflicting embedding weights.") + shared_embedding_weight = loaded_weight + shared_embedding_shard_id = shard_id + else: + # Skip the specific downstream task weight. + if name.startswith('cls.'): + continue + # use Pooler instead. + if name.startswith('pooler.'): + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in model_params_dict: + continue + if "bias.weight" in name and name not in model_params_dict: + continue + + param = model_params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if shard_id: + weight_loader(param, loaded_weight, shard_id) + else: + weight_loader(param, loaded_weight) From 6cc440ba112a1513c981540185aacbf93defef3f Mon Sep 17 00:00:00 2001 From: bmullick-amd Date: Tue, 17 Dec 2024 16:15:16 -0800 Subject: [PATCH 04/15] added source script link --- vllm/model_executor/models/t5.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/t5.py b/vllm/model_executor/models/t5.py index 1368939bdfe42..8ef67f9143695 100644 --- a/vllm/model_executor/models/t5.py +++ b/vllm/model_executor/models/t5.py @@ -1,4 +1,5 @@ -# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py # Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. 
# # Licensed under the Apache License, Version 2.0 (the "License"); From 124c56bbaa9c50465cd65bdef544d1c4b7b92e69 Mon Sep 17 00:00:00 2001 From: bmullick-amd Date: Tue, 17 Dec 2024 16:46:52 -0800 Subject: [PATCH 05/15] fix github actions --- vllm/model_executor/models/t5.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/models/t5.py b/vllm/model_executor/models/t5.py index 8ef67f9143695..8ae48d569dc34 100644 --- a/vllm/model_executor/models/t5.py +++ b/vllm/model_executor/models/t5.py @@ -17,20 +17,16 @@ import copy - - -import math +import os from typing import Iterable, List, Optional, Tuple - import torch from torch import nn import torch.nn.functional as F from transformers import T5Config from transformers.utils import logging -from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, @@ -43,7 +39,6 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors # from flash_attn import flash_attn_func logger = logging.get_logger(__name__) @@ -83,13 +78,13 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): ) raise tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + logger.info("Converting TensorFlow checkpoint from %s", tf_path) # Load weights from TF model init_vars = tf.train.list_variables(tf_path) names = [] tf_weights = {} for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") + logger.info("Loading TF weight name is %s", name) array = tf.train.load_variable(tf_path, name) names.append(name) tf_weights[name] = array @@ -243,7 +238,7 @@ def __init__(self, hidden_size, eps=1e-6): def forward(self, hidden_states): # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean - # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for # half-precision inputs is done in fp32 From e8b3a9d08fb7b4032cc8dfa379429e1d08a8f744 Mon Sep 17 00:00:00 2001 From: bmullick-amd Date: Tue, 17 Dec 2024 17:08:50 -0800 Subject: [PATCH 06/15] fix github actions --- vllm/model_executor/models/t5.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/t5.py b/vllm/model_executor/models/t5.py index 8ae48d569dc34..019827aeeb1b7 100644 --- a/vllm/model_executor/models/t5.py +++ b/vllm/model_executor/models/t5.py @@ -73,7 +73,7 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): import tensorflow as tf except ImportError: logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "TensorFlow is to be installed. Please see " "https://www.tensorflow.org/install/ for installation instructions." 
) raise @@ -91,17 +91,18 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): for txt_name in names: name = txt_name.split("/") - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", + "AdamWeightDecayOptimizer_1", "global_step"] for n in name ): - logger.info(f"Skipping {'/'.join(name)}") + log_name = '/'.join(name) + logger.info("Skipping %s", log_name) tf_weights.pop(txt_name, None) continue if "_slot_" in name[-1]: - logger.info(f"Skipping {'/'.join(name)}") + log_name = '/'.join(name) + logger.info("Skipping %s", log_name) tf_weights.pop(txt_name, None) continue pointer = model @@ -113,19 +114,19 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): else: scope_names = [m_name] if scope_names[0] in ["kernel", "scale", "embedding"]: - pointer = getattr(pointer, "weight") + pointer = getattr(pointer, "weight", None) elif scope_names[0] == "self_attention": - pointer = getattr(pointer, "layer") + pointer = getattr(pointer, "layer", None) pointer = pointer[0] elif scope_names[0] == "enc_dec_attention": - pointer = getattr(pointer, "layer") + pointer = getattr(pointer, "layer", None) pointer = pointer[1] elif scope_names[0] == "dense_relu_dense": - pointer = getattr(pointer, "layer") + pointer = getattr(pointer, "layer", None) pointer = pointer[2] elif scope_names[0] == "rms_norm": if hasattr(pointer, "layer_norm"): - pointer = getattr(pointer, "layer_norm") + pointer = getattr(pointer, "layer_norm", None) elif hasattr(pointer, "final_layer_norm"): pointer = getattr(pointer, "final_layer_norm") elif scope_names[0] == "scale": From 4f74d7220a4f5d7d079001e1876d778a35321702 Mon Sep 17 00:00:00 2001 From: bmullick-amd Date: Tue, 17 Dec 2024 18:26:17 -0800 Subject: [PATCH 07/15] fixed github actions --- vllm/model_executor/models/t5.py | 443 +++++++++---------------------- 1 file changed, 123 insertions(+), 320 deletions(-) diff --git a/vllm/model_executor/models/t5.py b/vllm/model_executor/models/t5.py index 019827aeeb1b7..b3ef11a441546 100644 --- a/vllm/model_executor/models/t5.py +++ b/vllm/model_executor/models/t5.py @@ -128,123 +128,76 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): if hasattr(pointer, "layer_norm"): pointer = getattr(pointer, "layer_norm", None) elif hasattr(pointer, "final_layer_norm"): - pointer = getattr(pointer, "final_layer_norm") + pointer = getattr(pointer, "final_layer_norm", None) elif scope_names[0] == "scale": - pointer = getattr(pointer, "weight") + pointer = getattr(pointer, "weight", None) elif scope_names[0] == "output_bias" or scope_names[0] == "beta": - pointer = getattr(pointer, "bias") + pointer = getattr(pointer, "bias", None) elif scope_names[0] == "squad": - pointer = getattr(pointer, "classifier") + pointer = getattr(pointer, "classifier", None) elif scope_names[0] == "decoder" and name[1] == "logits": continue elif scope_names[0] == "logits": - pointer = getattr(pointer, "lm_head") - elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit(): + pointer = getattr(pointer, "lm_head", None) + elif scope_names[0] == "wi" and + len(scope_names) > 1 and + scope_names[1].isdigit(): pointer = getattr(pointer, f"wi_{scope_names[1]}") continue else: try: pointer = getattr(pointer, scope_names[0]) except AttributeError: - logger.info(f"Skipping 
{'/'.join(name)}") + log_name = '/'.join(name) + logger.info("Skipping %s", log_name) continue if len(scope_names) >= 2: num = int(scope_names[1]) pointer = pointer[num] if scope_names[0] not in ["kernel", "scale", "embedding"]: - pointer = getattr(pointer, "weight") + pointer = getattr(pointer, "weight", None) if scope_names[0] != "embedding": - logger.info(f"Transposing numpy weight of shape {array.shape} for {name}") + logger.info("Transpose weight of + shape %s for %s", str(array.shape), name) array = np.transpose(array) try: if pointer.shape != array.shape: - raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + raise ValueError("Pointer and array shape mismatched") except AssertionError as e: e.args += (pointer.shape, array.shape) raise - logger.info(f"Initialize PyTorch weight {name}") + logger.info("Initialize PyTorch weight %s", name) pointer.data = torch.from_numpy(array.astype(np.float32)) tf_weights.pop(txt_name, None) - - logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.") + weight_not_copied = ', '.join(tf_weights.keys()) + logger.info("Weights not copied to PyTorch + model: %s.", weight_not_copied) return model - -#################################################### -# PyTorch Models are constructed by sub-classing -# - torch.nn.Module for the layers and -# - PreTrainedModel for the models (it-self a sub-class of nn.Module) -#################################################### -PARALLELIZE_DOCSTRING = r""" - This is an experimental feature and is a subject to change at a moment's notice. - - Uses a device map to distribute attention modules of the model across several devices. If no device map is given, - it will evenly distribute blocks across all devices. - - Args: - device_map (`Dict[int, list]`, optional, defaults to None): - A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always - automatically mapped to the first device (for esoteric reasons). That means that the first device should - have fewer attention modules mapped to it than other devices. For reference, the t5 models have the - following number of attention modules: - - - t5-small: 6 - - t5-base: 12 - - t5-large: 24 - - t5-3b: 24 - - t5-11b: 24 - - Example: - - ```python - # Here is an example of a device map on a machine with 4 GPUs using t5-3b, which has a total of 24 attention modules: - model = T5ForConditionalGeneration.from_pretrained("t5-3b") - device_map = { - 0: [0, 1, 2], - 1: [3, 4, 5, 6, 7, 8, 9], - 2: [10, 11, 12, 13, 14, 15, 16], - 3: [17, 18, 19, 20, 21, 22, 23], - } - model.parallelize(device_map) - ``` -""" -DEPARALLELIZE_DOCSTRING = r""" - Moves the model to cpu from a model parallel state. - - Example: - - ```python - # On a 4 GPU machine with t5-3b: - model = T5ForConditionalGeneration.from_pretrained("t5-3b") - device_map = { - 0: [0, 1, 2], - 1: [3, 4, 5, 6, 7, 8, 9], - 2: [10, 11, 12, 13, 14, 15, 16], - 3: [17, 18, 19, 20, 21, 22, 23], - } - model.parallelize(device_map) # Splits the model across several devices - model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() - ``` -""" - - class T5LayerNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ - Construct a layernorm module in the T5 style. No bias and no subtraction of mean. + Construct a layernorm module in the T5 style. + No bias and no subtraction of mean. 
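+        Equivalent to RMSNorm: y = weight * x / sqrt(mean(x**2, dim=-1) + eps).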
""" super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): - # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean - # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated - # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for - # half-precision inputs is done in fp32 + """ + T5 uses a layer_norm which only scales and doesn't + shift, which is also known as Root Mean + Square Layer Normalization https://arxiv.org/abs/1910.07467 + thus variance is calculated + w/o mean and there is no bias. Additionally we want to + make sure that the accumulation for half-precision + inputs is done in fp32 + """ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + adj_var = variance + self.variance_epsilon + hidden_states = hidden_states * torch.rsqrt(adj_var) # convert into half-precision if necessary if self.weight.dtype in [torch.float16, torch.bfloat16]: @@ -254,10 +207,13 @@ def forward(self, hidden_states): class T5DenseActDense(nn.Module): - def __init__(self, config: T5Config, quant_config: Optional[QuantizationConfig] = None): + def __init__(self, config: T5Config, + quant_config: Optional[QuantizationConfig] = None): super().__init__() - self.wi = ColumnParallelLinear(config.d_model, config.d_ff, bias=False, quant_config=quant_config) - self.wo = RowParallelLinear(config.d_ff, config.d_model, bias=False, quant_config=quant_config) + self.wi = ColumnParallelLinear(config.d_model, + config.d_ff, bias=False, quant_config=quant_config) + self.wo = RowParallelLinear(config.d_ff, config.d_model, + bias=False, quant_config=quant_config) self.dropout = nn.Dropout(config.dropout_rate) self.act = get_act_fn(config.dense_act_fn, quant_config) @@ -277,11 +233,15 @@ def forward(self, hidden_states): class T5DenseGatedActDense(nn.Module): - def __init__(self, config: T5Config, quant_config: Optional[QuantizationConfig] = None): + def __init__(self, config: T5Config, + quant_config: Optional[QuantizationConfig] = None): super().__init__() - self.wi_0 = ColumnParallelLinear(config.d_model, config.d_ff, bias=False, quant_config=quant_config) - self.wi_1 = ColumnParallelLinear(config.d_model, config.d_ff, bias=False, quant_config=quant_config) - self.wo = RowParallelLinear(config.d_ff, config.d_model, bias=False, quant_config=quant_config) + self.wi_0 = ColumnParallelLinear(config.d_model, + config.d_ff, bias=False, quant_config=quant_config) + self.wi_1 = ColumnParallelLinear(config.d_model, + config.d_ff, bias=False, quant_config=quant_config) + self.wo = RowParallelLinear(config.d_ff, config.d_model, + bias=False, quant_config=quant_config) self.dropout = nn.Dropout(config.dropout_rate) self.act = get_act_fn(config.dense_act_fn, quant_config) @@ -291,9 +251,11 @@ def forward(self, hidden_states): hidden_states = hidden_gelu * hidden_linear hidden_states = self.dropout(hidden_states) - # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32. + # To make 8bit quantization work for google/flan-t5-xxl, + # self.wo is kept in float32. 
# See https://github.com/huggingface/transformers/issues/20287 - # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`` + # we also make sure the weights are not in `int8` in case users + # will force `_keep_in_fp32_modules` to be `None`` if ( isinstance(self.wo.weight, torch.Tensor) and hidden_states.dtype != self.wo.weight.dtype @@ -306,14 +268,16 @@ def forward(self, hidden_states): class T5LayerFF(nn.Module): - def __init__(self, config: T5Config, quant_config: Optional[QuantizationConfig] = None): + def __init__(self, config: T5Config, + quant_config: Optional[QuantizationConfig] = None): super().__init__() if config.is_gated_act: self.DenseReluDense = T5DenseGatedActDense(config, quant_config) else: self.DenseReluDense = T5DenseActDense(config, quant_config) - self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) def forward(self, hidden_states): @@ -323,12 +287,17 @@ def forward(self, hidden_states): return hidden_states class T5Attention(nn.Module): - def __init__(self, config: T5Config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, has_relative_attention_bias=False): + def __init__(self, config: T5Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + has_relative_attention_bias=False): super().__init__() self.is_decoder = config.is_decoder self.has_relative_attention_bias = has_relative_attention_bias - self.relative_attention_num_buckets = config.relative_attention_num_buckets - self.relative_attention_max_distance = config.relative_attention_max_distance + rel_num_bucket = config.relative_attention_num_buckets + rel_max_dist = config.relative_attention_max_distance + self.relative_attention_num_buckets = rel_num_bucket + self.relative_attention_max_distance = rel_max_dist self.d_model = config.d_model self.key_value_proj_dim = config.d_kv self.n_heads = config.num_heads @@ -364,7 +333,9 @@ def forward( ) -> torch.Tensor: """ - Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + Self-attention (if key_value_states is None) or + attention over source sentence (provided by + key_value_states). 
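+        Returns a tuple (output, present_key_value_state); the cached (k, v)
+        pair is only populated when self.is_decoder is True, otherwise None.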
""" qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([2048, 2048, 2048], dim=-1) @@ -385,10 +356,16 @@ def forward( return output, present_key_value_state class T5LayerSelfAttention(nn.Module): - def __init__(self, config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, has_relative_attention_bias=False): + def __init__(self, config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + has_relative_attention_bias=False): super().__init__() - self.SelfAttention = T5Attention(config, cache_config, quant_config, has_relative_attention_bias=has_relative_attention_bias) - self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.SelfAttention = T5Attention(config, cache_config, + quant_config, + has_relative_attention_bias=has_relative_attention_bias) + self.layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, @@ -399,15 +376,19 @@ def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, kv_cache, attn_metadata) hidden_states = hidden_states + self.dropout(attention_output[0]) - outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + outputs = (hidden_states,) + attention_output[1:] return outputs class T5LayerCrossAttention(nn.Module): - def __init__(self, config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): + def __init__(self, config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): super().__init__() - self.EncDecAttention = T5Attention(config, cache_config, quant_config, has_relative_attention_bias=False) - self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.EncDecAttention = T5Attention(config, cache_config, + quant_config, has_relative_attention_bias=False) + self.layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) def forward(self, @@ -424,17 +405,23 @@ def forward(self, encoder_hidden_states, ) layer_output = hidden_states + self.dropout(attention_output[0]) - outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + outputs = (layer_output,) + attention_output[1:] return outputs class T5Block(nn.Module): - def __init__(self, config: T5Config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, has_relative_attention_bias=False): + def __init__(self, config: T5Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + has_relative_attention_bias=False): super().__init__() self.is_decoder = config.is_decoder - self.self_attn = T5LayerSelfAttention(config, cache_config, quant_config, has_relative_attention_bias=has_relative_attention_bias) + self.self_attn = T5LayerSelfAttention(config, cache_config, + quant_config, + has_relative_attention_bias=has_relative_attention_bias) if self.is_decoder: - self.cross_attn = T5LayerCrossAttention(config, cache_config, quant_config) + self.cross_attn = T5LayerCrossAttention(config, + cache_config, quant_config) self.fc = T5LayerFF(config, quant_config) @@ -445,9 +432,10 @@ def forward( attn_metadata: AttentionMetadata, encoder_hidden_states: Optional[torch.Tensor] = None, ) -> torch.Tensor: - self_attention_outputs = 
self.self_attn(hidden_states, kv_cache, attn_metadata) + self_attention_outputs = self.self_attn(hidden_states, + kv_cache, attn_metadata) hidden, _ = self_attention_outputs[:2] - attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + attention_outputs = self_attention_outputs[2:] # clamp inf values to enable fp16 training if hidden.dtype == torch.float16: @@ -458,9 +446,11 @@ def forward( ) hidden = torch.clamp(hidden, min=-clamp_value, max=clamp_value) - do_cross_attention = self.is_decoder and encoder_hidden_states is not None + do_cross_attention = (self.is_decoder and + encoder_hidden_states is not None) if do_cross_attention: - cross_attention_outputs = self.cross_attn(hidden, kv_cache, attn_metadata, encoder_hidden_states) + cross_attention_outputs = self.cross_attn(hidden, kv_cache, + attn_metadata, encoder_hidden_states) hidden = cross_attention_outputs[0] # clamp inf values to enable fp16 training @@ -519,16 +509,20 @@ def __init__(self, self.is_decoder = config.is_decoder self.block = nn.ModuleList( - [T5Block(config, cache_config, quant_config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + [T5Block(config, cache_config, quant_config, + has_relative_attention_bias=bool(i == 0)) + for i in range(config.num_layers)] ) - self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.final_layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - encoder_hidden_states: Optional[torch.Tensor]=None) -> torch.Tensor: + encoder_hidden_states: Optional[torch.Tensor]=None + ) -> torch.Tensor: inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.dropout(inputs_embeds) # print('t5 stack', type(hidden_states)) @@ -543,171 +537,14 @@ def forward(self, input_ids: torch.Tensor, hidden_states = self.dropout(hidden_states) return hidden_states - -T5_START_DOCSTRING = r""" - - The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text - Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan - Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a - text-to-text denoising generative setting. - - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`T5Config`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -T5_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. 
T5 is a model with relative position embeddings so you - should be able to pad the inputs on both the right and the left. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for detail. - - [What are input IDs?](../glossary#input-ids) - - To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training). - attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` - is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). - - To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5 - Training](./t5#training). - decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0, - 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, - 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in - `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): - Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) - `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at - the output of the last layer of the encoder. Used in the cross-attention of the decoder. - past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. 
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded - representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be - input (see `past_key_values`). This is useful if you want more control over how to convert - `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. - - If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value - of `inputs_embeds`. - - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - -T5_ENCODER_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you - should be able to pad the inputs on both the right and the left. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for detail. - - To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training). - attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. 
- return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - -# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask -__HEAD_MASK_WARNING_MSG = """ -The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, -`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. -If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, -num_heads)`. -""" + class T5Model(nn.Module): _keys_to_ignore_on_load_unexpected = [ "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", ] - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + _tied_weights_keys = ["encoder.embed_tokens.weight", + "decoder.embed_tokens.weight"] def __init__(self, config: T5Config, @@ -730,13 +567,15 @@ def __init__(self, encoder_config.is_decoder = False encoder_config.use_cache = False encoder_config.is_encoder_decoder = False - self.encoder = T5Stack(encoder_config, cache_config, quant_config, self.shared) + self.encoder = T5Stack(encoder_config, + cache_config, quant_config, self.shared) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True decoder_config.is_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers - self.decoder = T5Stack(decoder_config, cache_config, quant_config, self.shared) + self.decoder = T5Stack(decoder_config, + cache_config, quant_config, self.shared) def forward( self, @@ -744,30 +583,6 @@ def forward( encoder_input_ids: torch.Tensor, encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata) -> torch.Tensor: - r""" - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, T5Model - - >>> tokenizer = AutoTokenizer.from_pretrained("t5-small") - >>> model = T5Model.from_pretrained("t5-small") - - >>> input_ids = tokenizer( - ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" - ... ).input_ids # Batch size 1 - >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 - - >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model. - >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg. 
- >>> decoder_input_ids = model._shift_right(decoder_input_ids) - - >>> # forward pass - >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) - >>> last_hidden_states = outputs.last_hidden_state - ```""" encoder_hidden_states = None if encoder_input_ids.numel() > 0: @@ -914,7 +729,7 @@ def _rename_layer_types( ) -> str: for enc_dec, mapping in self.layer_type_mapping.items(): if enc_dec in name: - for layer_num in mapping.keys(): + for layer_num in mapping: if layer_num in name: name = name.replace(layer_num, mapping[layer_num]) return name @@ -929,40 +744,33 @@ def _rename_stacked_param( return name, mapping["shard_id"] return name, None - # def get_set(self, model_params_dict): - # out = set() - # for key in model_params_dict.keys(): - # if "bias" in key: - # print('BBBIIIAAASSSSS..................') - # if 'decoder' not in key and 'encoder' not in key: - # print(key) - # # print(key.split('.')) - # lst = key.split('.') - # if len(lst)>=4: - # out.add(lst[3]) - # return out def match_weight_name(self, weights_tuple_list): out = set() for name, _ in weights_tuple_list: # print(name) if 'decoder' in name and 'layer_norm' not in name: - if 'layer.0' in name and 'SelfAttention' not in name: + if 'layer.0' in name and + 'SelfAttention' not in name: print(name) out.add(False) - elif 'layer.1' in name and 'EncDecAttention' not in name: + elif 'layer.1' in name and + 'EncDecAttention' not in name: print(name) out.add(False) - elif 'layer.2' in name and 'DenseReluDense' not in name: + elif 'layer.2' in name and + 'DenseReluDense' not in name: print(name) out.add(False) else: out.add(True) elif 'encoder' in name and 'layer_norm' not in name: - if 'layer.0' in name and 'SelfAttention' not in name: + if 'layer.0' in name and + 'SelfAttention' not in name: print(name) out.add(False) - elif 'layer.1' in name and 'DenseReluDense' not in name: + elif 'layer.1' in name and + 'DenseReluDense' not in name: print(name) out.add(False) else: @@ -974,13 +782,9 @@ def match_weight_name(self, weights_tuple_list): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): model_params_dict = dict(self.model.named_parameters()) - # types = self.get_set(model_params_dict) - top_params_dict = dict(self.named_parameters()) - weights_tuple_list = list(weights) shared_embedding_weight = None - shared_embedding_shard_id = None for name, loaded_weight in weights_tuple_list: name = self._rename_layer_types(name) @@ -991,7 +795,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): assert shared_embedding_weight is None, ( "Conflicting embedding weights.") shared_embedding_weight = loaded_weight - shared_embedding_shard_id = shard_id else: # Skip the specific downstream task weight. 
if name.startswith('cls.'): From 36a16044054eff399a713947462ffd755500c8ef Mon Sep 17 00:00:00 2001 From: bmullick-amd Date: Tue, 17 Dec 2024 18:36:30 -0800 Subject: [PATCH 08/15] fix github actions --- vllm/model_executor/models/t5.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/t5.py b/vllm/model_executor/models/t5.py index b3ef11a441546..67d9e0c95a7f1 100644 --- a/vllm/model_executor/models/t5.py +++ b/vllm/model_executor/models/t5.py @@ -139,8 +139,8 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): continue elif scope_names[0] == "logits": pointer = getattr(pointer, "lm_head", None) - elif scope_names[0] == "wi" and - len(scope_names) > 1 and + elif scope_names[0] == "wi" and \ + len(scope_names) > 1 and \ scope_names[1].isdigit(): pointer = getattr(pointer, f"wi_{scope_names[1]}") continue @@ -148,7 +148,7 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): try: pointer = getattr(pointer, scope_names[0]) except AttributeError: - log_name = '/'.join(name) + log_name = '/'.join(name) logger.info("Skipping %s", log_name) continue if len(scope_names) >= 2: From a9f54f44c1a192b0c0775142ae3a32b55bd981b5 Mon Sep 17 00:00:00 2001 From: bmullick-amd Date: Tue, 17 Dec 2024 18:59:57 -0800 Subject: [PATCH 09/15] fixed github action ruff fails --- vllm/model_executor/models/t5.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/t5.py b/vllm/model_executor/models/t5.py index 67d9e0c95a7f1..f3df412f6259f 100644 --- a/vllm/model_executor/models/t5.py +++ b/vllm/model_executor/models/t5.py @@ -157,7 +157,7 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): if scope_names[0] not in ["kernel", "scale", "embedding"]: pointer = getattr(pointer, "weight", None) if scope_names[0] != "embedding": - logger.info("Transpose weight of + logger.info("Transpose weight of \ shape %s for %s", str(array.shape), name) array = np.transpose(array) try: @@ -170,7 +170,7 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): pointer.data = torch.from_numpy(array.astype(np.float32)) tf_weights.pop(txt_name, None) weight_not_copied = ', '.join(tf_weights.keys()) - logger.info("Weights not copied to PyTorch + logger.info("Weights not copied to PyTorch \ model: %s.", weight_not_copied) return model @@ -748,28 +748,28 @@ def _rename_stacked_param( def match_weight_name(self, weights_tuple_list): out = set() for name, _ in weights_tuple_list: - # print(name) + if 'decoder' in name and 'layer_norm' not in name: - if 'layer.0' in name and + if 'layer.0' in name and \ 'SelfAttention' not in name: print(name) out.add(False) - elif 'layer.1' in name and + elif 'layer.1' in name and \ 'EncDecAttention' not in name: print(name) out.add(False) - elif 'layer.2' in name and + elif 'layer.2' in name and \ 'DenseReluDense' not in name: print(name) out.add(False) else: out.add(True) elif 'encoder' in name and 'layer_norm' not in name: - if 'layer.0' in name and + if 'layer.0' in name and \ 'SelfAttention' not in name: print(name) out.add(False) - elif 'layer.1' in name and + elif 'layer.1' in name and \ 'DenseReluDense' not in name: print(name) out.add(False) From 8b78df753404dda9871d1ad3be1ee03c427e8c9e Mon Sep 17 00:00:00 2001 From: bmullick-amd Date: Tue, 17 Dec 2024 19:17:09 -0800 Subject: [PATCH 10/15] fixed ruff fails --- vllm/model_executor/models/t5.py | 38 ++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git 
a/vllm/model_executor/models/t5.py b/vllm/model_executor/models/t5.py index f3df412f6259f..435608b6c3c36 100644 --- a/vllm/model_executor/models/t5.py +++ b/vllm/model_executor/models/t5.py @@ -750,29 +750,35 @@ def match_weight_name(self, weights_tuple_list): for name, _ in weights_tuple_list: if 'decoder' in name and 'layer_norm' not in name: - if 'layer.0' in name and \ - 'SelfAttention' not in name: - print(name) - out.add(False) - elif 'layer.1' in name and \ - 'EncDecAttention' not in name: - print(name) - out.add(False) - elif 'layer.2' in name and \ - 'DenseReluDense' not in name: + if ('layer.0' in name and \ + 'SelfAttention' not in name) \ + or ('layer.1' in name and \ + 'EncDecAttention' not in name) \ + or ('layer.2' in name and \ + 'DenseReluDense' not in name): print(name) out.add(False) + # elif 'layer.1' in name and \ + # 'EncDecAttention' not in name: + # print(name) + # out.add(False) + # elif 'layer.2' in name and \ + # 'DenseReluDense' not in name: + # print(name) + # out.add(False) else: out.add(True) elif 'encoder' in name and 'layer_norm' not in name: - if 'layer.0' in name and \ - 'SelfAttention' not in name: - print(name) - out.add(False) - elif 'layer.1' in name and \ - 'DenseReluDense' not in name: + if ('layer.0' in name and \ + 'SelfAttention' not in name) \ + or ('layer.1' in name and \ + 'DenseReluDense' not in name): print(name) out.add(False) + # elif 'layer.1' in name and \ + # 'DenseReluDense' not in name: + # print(name) + # out.add(False) else: out.add(True) elif 'decoder' not in name and 'encoder' not in name: From b5542667038e4dd9750632f64dd3015932e377ca Mon Sep 17 00:00:00 2001 From: bmullick-amd Date: Tue, 17 Dec 2024 19:41:32 -0800 Subject: [PATCH 11/15] fix import errors --- vllm/model_executor/models/t5.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/t5.py b/vllm/model_executor/models/t5.py index 435608b6c3c36..f981fe7606343 100644 --- a/vllm/model_executor/models/t5.py +++ b/vllm/model_executor/models/t5.py @@ -19,9 +19,9 @@ import copy import os from typing import Iterable, List, Optional, Tuple -import torch -from torch import nn -import torch.nn.functional as F +import torch # type: ignore +from torch import nn # type: ignore +import torch.nn.functional as F # type: ignore from transformers import T5Config from transformers.utils import logging @@ -39,7 +39,6 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -# from flash_attn import flash_attn_func logger = logging.get_logger(__name__) @@ -68,9 +67,8 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): """Load tf checkpoints in a pytorch model.""" try: import re - - import numpy as np - import tensorflow as tf + import numpy as np # type: ignore + import tensorflow as tf # type: ignore except ImportError: logger.error( "TensorFlow is to be installed. 
Please see " From 2a8bb2e7d996cb77e3c8179720d398b9a65e625b Mon Sep 17 00:00:00 2001 From: bmullick-amd Date: Tue, 17 Dec 2024 19:58:04 -0800 Subject: [PATCH 12/15] fix isort error --- vllm/model_executor/models/t5.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/t5.py b/vllm/model_executor/models/t5.py index f981fe7606343..7957a2ac02ec7 100644 --- a/vllm/model_executor/models/t5.py +++ b/vllm/model_executor/models/t5.py @@ -19,9 +19,10 @@ import copy import os from typing import Iterable, List, Optional, Tuple -import torch # type: ignore -from torch import nn # type: ignore -import torch.nn.functional as F # type: ignore + +import torch # type: ignore +import torch.nn.functional as F # type: ignore +from torch import nn # type: ignore from transformers import T5Config from transformers.utils import logging @@ -67,8 +68,8 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): """Load tf checkpoints in a pytorch model.""" try: import re - import numpy as np # type: ignore - import tensorflow as tf # type: ignore + import numpy as np # type: ignore + import tensorflow as tf # type: ignore except ImportError: logger.error( "TensorFlow is to be installed. Please see " From ea736ed071b881dc48506fa238180e8c49bd6c44 Mon Sep 17 00:00:00 2001 From: bmullick-amd Date: Tue, 17 Dec 2024 20:17:21 -0800 Subject: [PATCH 13/15] fixed ruff fails --- vllm/model_executor/models/t5.py | 528 ++++++++++++++++++------------- 1 file changed, 314 insertions(+), 214 deletions(-) diff --git a/vllm/model_executor/models/t5.py b/vllm/model_executor/models/t5.py index 7957a2ac02ec7..82210ab259d19 100644 --- a/vllm/model_executor/models/t5.py +++ b/vllm/model_executor/models/t5.py @@ -13,8 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-""" PyTorch T5 model.""" - +"""PyTorch T5 model.""" import copy import os @@ -29,18 +28,25 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) + QuantizationConfig, +) from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata +# from flash_attn import flash_attn_func + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "T5Config" @@ -68,6 +74,7 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): """Load tf checkpoints in a pytorch model.""" try: import re + import numpy as np # type: ignore import tensorflow as tf # type: ignore except ImportError: @@ -91,16 +98,22 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): for txt_name in names: name = txt_name.split("/") if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", - "AdamWeightDecayOptimizer_1", "global_step"] + n + in [ + "adam_v", + "adam_m", + "AdamWeightDecayOptimizer", + "AdamWeightDecayOptimizer_1", + "global_step", + ] for n in name ): - log_name = '/'.join(name) + log_name = "/".join(name) logger.info("Skipping %s", log_name) tf_weights.pop(txt_name, None) continue if "_slot_" in name[-1]: - log_name = '/'.join(name) + log_name = "/".join(name) logger.info("Skipping %s", log_name) tf_weights.pop(txt_name, None) continue @@ -138,16 +151,18 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): continue elif scope_names[0] == "logits": pointer = getattr(pointer, "lm_head", None) - elif scope_names[0] == "wi" and \ - len(scope_names) > 1 and \ - scope_names[1].isdigit(): + elif ( + scope_names[0] == "wi" + and len(scope_names) > 1 + and scope_names[1].isdigit() + ): pointer = getattr(pointer, f"wi_{scope_names[1]}") continue else: try: pointer = getattr(pointer, scope_names[0]) except AttributeError: - log_name = '/'.join(name) + log_name = "/".join(name) logger.info("Skipping %s", log_name) continue if len(scope_names) >= 2: @@ -156,8 +171,12 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): if scope_names[0] not in ["kernel", "scale", "embedding"]: pointer = getattr(pointer, "weight", None) if scope_names[0] != "embedding": - logger.info("Transpose weight of \ - shape %s for %s", str(array.shape), name) + logger.info( + "Transpose weight of \ + shape %s for %s", + str(array.shape), + name, + ) array = np.transpose(array) try: if pointer.shape != array.shape: @@ -168,15 +187,19 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): logger.info("Initialize PyTorch weight %s", name) pointer.data = torch.from_numpy(array.astype(np.float32)) tf_weights.pop(txt_name, None) - weight_not_copied = ', '.join(tf_weights.keys()) - logger.info("Weights not copied to PyTorch \ - model: %s.", weight_not_copied) + weight_not_copied = ", ".join(tf_weights.keys()) + 
logger.info( + "Weights not copied to PyTorch \ + model: %s.", + weight_not_copied, + ) return model + class T5LayerNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ - Construct a layernorm module in the T5 style. + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. """ super().__init__() @@ -185,12 +208,12 @@ def __init__(self, hidden_size, eps=1e-6): def forward(self, hidden_states): """ - T5 uses a layer_norm which only scales and doesn't + T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated - w/o mean and there is no bias. Additionally we want to - make sure that the accumulation for half-precision + w/o mean and there is no bias. Additionally we want to + make sure that the accumulation for half-precision inputs is done in fp32 """ @@ -206,18 +229,22 @@ def forward(self, hidden_states): class T5DenseActDense(nn.Module): - def __init__(self, config: T5Config, - quant_config: Optional[QuantizationConfig] = None): + def __init__( + self, + config: T5Config, + quant_config: Optional[QuantizationConfig] = None, + ): super().__init__() - self.wi = ColumnParallelLinear(config.d_model, - config.d_ff, bias=False, quant_config=quant_config) - self.wo = RowParallelLinear(config.d_ff, config.d_model, - bias=False, quant_config=quant_config) + self.wi = ColumnParallelLinear( + config.d_model, config.d_ff, bias=False, quant_config=quant_config + ) + self.wo = RowParallelLinear( + config.d_ff, config.d_model, bias=False, quant_config=quant_config + ) self.dropout = nn.Dropout(config.dropout_rate) self.act = get_act_fn(config.dense_act_fn, quant_config) def forward(self, hidden_states): - hidden_states = self.wi(hidden_states)[0] hidden_states = self.act(hidden_states) hidden_states = self.dropout(hidden_states) @@ -232,15 +259,21 @@ def forward(self, hidden_states): class T5DenseGatedActDense(nn.Module): - def __init__(self, config: T5Config, - quant_config: Optional[QuantizationConfig] = None): + def __init__( + self, + config: T5Config, + quant_config: Optional[QuantizationConfig] = None, + ): super().__init__() - self.wi_0 = ColumnParallelLinear(config.d_model, - config.d_ff, bias=False, quant_config=quant_config) - self.wi_1 = ColumnParallelLinear(config.d_model, - config.d_ff, bias=False, quant_config=quant_config) - self.wo = RowParallelLinear(config.d_ff, config.d_model, - bias=False, quant_config=quant_config) + self.wi_0 = ColumnParallelLinear( + config.d_model, config.d_ff, bias=False, quant_config=quant_config + ) + self.wi_1 = ColumnParallelLinear( + config.d_model, config.d_ff, bias=False, quant_config=quant_config + ) + self.wo = RowParallelLinear( + config.d_ff, config.d_model, bias=False, quant_config=quant_config + ) self.dropout = nn.Dropout(config.dropout_rate) self.act = get_act_fn(config.dense_act_fn, quant_config) @@ -250,7 +283,7 @@ def forward(self, hidden_states): hidden_states = hidden_gelu * hidden_linear hidden_states = self.dropout(hidden_states) - # To make 8bit quantization work for google/flan-t5-xxl, + # To make 8bit quantization work for google/flan-t5-xxl, # self.wo is kept in float32. 
# See https://github.com/huggingface/transformers/issues/20287 # we also make sure the weights are not in `int8` in case users @@ -267,16 +300,20 @@ def forward(self, hidden_states): class T5LayerFF(nn.Module): - def __init__(self, config: T5Config, - quant_config: Optional[QuantizationConfig] = None): + def __init__( + self, + config: T5Config, + quant_config: Optional[QuantizationConfig] = None, + ): super().__init__() if config.is_gated_act: self.DenseReluDense = T5DenseGatedActDense(config, quant_config) else: self.DenseReluDense = T5DenseActDense(config, quant_config) - self.layer_norm = T5LayerNorm(config.d_model, - eps=config.layer_norm_epsilon) + self.layer_norm = T5LayerNorm( + config.d_model, eps=config.layer_norm_epsilon + ) self.dropout = nn.Dropout(config.dropout_rate) def forward(self, hidden_states): @@ -285,11 +322,15 @@ def forward(self, hidden_states): hidden_states = hidden_states + self.dropout(forwarded_states) return hidden_states + class T5Attention(nn.Module): - def __init__(self, config: T5Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - has_relative_attention_bias=False): + def __init__( + self, + config: T5Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + has_relative_attention_bias=False, + ): super().__init__() self.is_decoder = config.is_decoder self.has_relative_attention_bias = has_relative_attention_bias @@ -303,7 +344,6 @@ def __init__(self, config: T5Config, self.dropout = config.dropout_rate self.inner_dim = self.n_heads * self.key_value_proj_dim - self.qkv_proj = QKVParallelLinear( self.d_model, self.inner_dim // self.n_heads, @@ -317,11 +357,13 @@ def __init__(self, config: T5Config, bias=False, quant_config=quant_config, ) - self.attn = Attention(self.n_heads, - self.inner_dim // self.n_heads, - scale = 1, - cache_config=cache_config, - quant_config=quant_config) + self.attn = Attention( + self.n_heads, + self.inner_dim // self.n_heads, + scale=1, + cache_config=cache_config, + quant_config=quant_config, + ) def forward( self, @@ -330,69 +372,82 @@ def forward( attn_metadata: AttentionMetadata, encoder_hidden_states: Optional[torch.Tensor] = None, ) -> torch.Tensor: - """ - Self-attention (if key_value_states is None) or - attention over source sentence (provided by + Self-attention (if key_value_states is None) or + attention over source sentence (provided by key_value_states). 
""" qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([2048, 2048, 2048], dim=-1) if encoder_hidden_states is None: - attn_output = F.scaled_dot_product_attention(q, - k, - v, - dropout_p=0.0) + attn_output = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0) else: qkv_enc, _ = self.qkv_proj(encoder_hidden_states) _, k, v = qkv.split([2048, 2048, 2048], dim=-1) - attn_output = F.scaled_dot_product_attention(q, - k, - v, - dropout_p=0.0) + attn_output = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0) output, _ = self.out_proj(attn_output) present_key_value_state = (k, v) if self.is_decoder else None return output, present_key_value_state + class T5LayerSelfAttention(nn.Module): - def __init__(self, config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - has_relative_attention_bias=False): + def __init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + has_relative_attention_bias=False, + ): super().__init__() - self.SelfAttention = T5Attention(config, cache_config, - quant_config, - has_relative_attention_bias=has_relative_attention_bias) - self.layer_norm = T5LayerNorm(config.d_model, - eps=config.layer_norm_epsilon) + self.SelfAttention = T5Attention( + config, + cache_config, + quant_config, + has_relative_attention_bias=has_relative_attention_bias, + ) + self.layer_norm = T5LayerNorm( + config.d_model, eps=config.layer_norm_epsilon + ) self.dropout = nn.Dropout(config.dropout_rate) - def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata) -> torch.Tensor: + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( - hidden_states, - kv_cache, - attn_metadata) + hidden_states, kv_cache, attn_metadata + ) hidden_states = hidden_states + self.dropout(attention_output[0]) outputs = (hidden_states,) + attention_output[1:] return outputs class T5LayerCrossAttention(nn.Module): - def __init__(self, config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + def __init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): super().__init__() - self.EncDecAttention = T5Attention(config, cache_config, - quant_config, has_relative_attention_bias=False) - self.layer_norm = T5LayerNorm(config.d_model, - eps=config.layer_norm_epsilon) + self.EncDecAttention = T5Attention( + config, + cache_config, + quant_config, + has_relative_attention_bias=False, + ) + self.layer_norm = T5LayerNorm( + config.d_model, eps=config.layer_norm_epsilon + ) self.dropout = nn.Dropout(config.dropout_rate) - def forward(self, - hidden_states: torch.Tensor, - kv_cache: torch.Tensor, + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, encoder_hidden_states: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -409,20 +464,26 @@ def forward(self, class T5Block(nn.Module): - def __init__(self, config: T5Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - has_relative_attention_bias=False): + def __init__( + self, + config: T5Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + 
has_relative_attention_bias=False, + ): super().__init__() self.is_decoder = config.is_decoder - self.self_attn = T5LayerSelfAttention(config, cache_config, - quant_config, - has_relative_attention_bias=has_relative_attention_bias) + self.self_attn = T5LayerSelfAttention( + config, + cache_config, + quant_config, + has_relative_attention_bias=has_relative_attention_bias, + ) if self.is_decoder: - self.cross_attn = T5LayerCrossAttention(config, - cache_config, quant_config) + self.cross_attn = T5LayerCrossAttention( + config, cache_config, quant_config + ) self.fc = T5LayerFF(config, quant_config) - def forward( self, @@ -431,8 +492,9 @@ def forward( attn_metadata: AttentionMetadata, encoder_hidden_states: Optional[torch.Tensor] = None, ) -> torch.Tensor: - self_attention_outputs = self.self_attn(hidden_states, - kv_cache, attn_metadata) + self_attention_outputs = self.self_attn( + hidden_states, kv_cache, attn_metadata + ) hidden, _ = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] @@ -445,11 +507,13 @@ def forward( ) hidden = torch.clamp(hidden, min=-clamp_value, max=clamp_value) - do_cross_attention = (self.is_decoder and - encoder_hidden_states is not None) + do_cross_attention = ( + self.is_decoder and encoder_hidden_states is not None + ) if do_cross_attention: - cross_attention_outputs = self.cross_attn(hidden, kv_cache, - attn_metadata, encoder_hidden_states) + cross_attention_outputs = self.cross_attn( + hidden, kv_cache, attn_metadata, encoder_hidden_states + ) hidden = cross_attention_outputs[0] # clamp inf values to enable fp16 training @@ -496,65 +560,85 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.out_proj(hidden_states) return hidden_states + class T5Stack(nn.Module): - def __init__(self, - config: T5Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - embed_tokens=None): + def __init__( + self, + config: T5Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + embed_tokens=None, + ): super().__init__() self.cache_config = cache_config self.embed_tokens = embed_tokens self.is_decoder = config.is_decoder self.block = nn.ModuleList( - [T5Block(config, cache_config, quant_config, - has_relative_attention_bias=bool(i == 0)) - for i in range(config.num_layers)] + [ + T5Block( + config, + cache_config, + quant_config, + has_relative_attention_bias=bool(i == 0), + ) + for i in range(config.num_layers) + ] + ) + self.final_layer_norm = T5LayerNorm( + config.d_model, eps=config.layer_norm_epsilon ) - self.final_layer_norm = T5LayerNorm(config.d_model, - eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) - def forward(self, input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - encoder_hidden_states: Optional[torch.Tensor]=None - ) -> torch.Tensor: + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + encoder_hidden_states: Optional[torch.Tensor] = None, + ) -> torch.Tensor: inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.dropout(inputs_embeds) # print('t5 stack', type(hidden_states)) for i, layer in enumerate(self.block): - layer_outputs = layer(hidden_states, - kv_caches[i], - attn_metadata, - encoder_hidden_states) + layer_outputs = layer( + hidden_states, + kv_caches[i], + attn_metadata, + 
encoder_hidden_states, + ) hidden_states = layer_outputs[0] hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) return hidden_states - class T5Model(nn.Module): _keys_to_ignore_on_load_unexpected = [ "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", ] - _tied_weights_keys = ["encoder.embed_tokens.weight", - "decoder.embed_tokens.weight"] - - def __init__(self, - config: T5Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None): + _tied_weights_keys = [ + "encoder.embed_tokens.weight", + "decoder.embed_tokens.weight", + ] + + def __init__( + self, + config: T5Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ): super().__init__() # self.shared = nn.Embedding(config.vocab_size, config.d_model) self.padding_idx = config.pad_token_id - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.shared = VocabParallelEmbedding( self.vocab_size, @@ -566,76 +650,88 @@ def __init__(self, encoder_config.is_decoder = False encoder_config.use_cache = False encoder_config.is_encoder_decoder = False - self.encoder = T5Stack(encoder_config, - cache_config, quant_config, self.shared) + self.encoder = T5Stack( + encoder_config, cache_config, quant_config, self.shared + ) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True decoder_config.is_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers - self.decoder = T5Stack(decoder_config, - cache_config, quant_config, self.shared) + self.decoder = T5Stack( + decoder_config, cache_config, quant_config, self.shared + ) def forward( self, - input_ids: torch.Tensor, positions: torch.Tensor, - encoder_input_ids: torch.Tensor, - encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata) -> torch.Tensor: + input_ids: torch.Tensor, + positions: torch.Tensor, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: encoder_hidden_states = None if encoder_input_ids.numel() > 0: # Run encoder attention if a non-zero number of encoder tokens # are provided as input - encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, - positions=encoder_positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata) + encoder_hidden_states = self.encoder( + input_ids=encoder_input_ids, + positions=encoder_positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + ) decoder_outputs = self.decoder( input_ids=input_ids, positions=positions, encoder_hidden_states=encoder_hidden_states, kv_caches=kv_caches, - attn_metadata=attn_metadata) + attn_metadata=attn_metadata, + ) return decoder_outputs + class T5ForConditionalGeneration(nn.Module): - def __init__(self, - config: T5Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None): + def __init__( + self, + config: T5Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ): 
super().__init__() self.config = config self.model_dim = config.d_model - self.model = T5Model(config, - cache_config, - quant_config, - lora_config=lora_config) - print('lora_config', lora_config) + self.model = T5Model( + config, cache_config, quant_config, lora_config=lora_config + ) + print("lora_config", lora_config) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.lm_head = ParallelLMHead(num_embeddings= self.unpadded_vocab_size, - embedding_dim=config.d_model, - org_num_embeddings=config.vocab_size, - bias=False) - - - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) + self.lm_head = ParallelLMHead( + num_embeddings=self.unpadded_vocab_size, + embedding_dim=config.d_model, + org_num_embeddings=config.vocab_size, + bias=False, + ) + + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) self.sampler = Sampler() def forward( self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - encoder_input_ids: torch.Tensor, - encoder_positions: torch.Tensor, - **kwargs, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + **kwargs, ) -> torch.Tensor: r""" Args: @@ -654,16 +750,23 @@ def forward( Returns: Output torch.Tensor """ - return self.model(input_ids, positions, encoder_input_ids, - encoder_positions, kv_caches, attn_metadata) + return self.model( + input_ids, + positions, + encoder_input_ids, + encoder_positions, + kv_caches, + attn_metadata, + ) def compute_logits( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor( + self.lm_head, hidden_states, sampling_metadata + ) return logits def sample( @@ -690,11 +793,9 @@ def sample( "o.weight": { "param_name": "out_proj.weight", "shard_id": None, - } + }, } - - params_mapping = { "beta": "bias", "gamma": "weight", @@ -703,13 +804,13 @@ def sample( def _rename_key(self, key: str): prefix = f"{self.base_model_prefix}." - key = key[len(prefix):] if key.startswith(prefix) else key + key = key[len(prefix) :] if key.startswith(prefix) else key for src, dst in self.params_mapping.items(): key = key.replace(src, dst) return key - + layer_type_mapping = { "encoder": { "layer.0": "self_attn", @@ -719,42 +820,39 @@ def _rename_key(self, key: str): "layer.0": "self_attn", "layer.1": "cross_attn", "layer.2": "fc", - } + }, } - + def _rename_layer_types( self, - name: str, + name: str, ) -> str: for enc_dec, mapping in self.layer_type_mapping.items(): if enc_dec in name: for layer_num in mapping: if layer_num in name: name = name.replace(layer_num, mapping[layer_num]) - return name + return name def _rename_stacked_param( self, name: str, ) -> Tuple[str, Optional[str]]: for key, mapping in self.stacked_params_mapping.items(): - if key in name and '.wo.' not in name: + if key in name and ".wo." 
not in name: name = name.replace(key, mapping["param_name"]) return name, mapping["shard_id"] return name, None - - + def match_weight_name(self, weights_tuple_list): out = set() for name, _ in weights_tuple_list: - - if 'decoder' in name and 'layer_norm' not in name: - if ('layer.0' in name and \ - 'SelfAttention' not in name) \ - or ('layer.1' in name and \ - 'EncDecAttention' not in name) \ - or ('layer.2' in name and \ - 'DenseReluDense' not in name): + if "decoder" in name and "layer_norm" not in name: + if ( + ("layer.0" in name and "SelfAttention" not in name) + or ("layer.1" in name and "EncDecAttention" not in name) + or ("layer.2" in name and "DenseReluDense" not in name) + ): print(name) out.add(False) # elif 'layer.1' in name and \ @@ -767,11 +865,10 @@ def match_weight_name(self, weights_tuple_list): # out.add(False) else: out.add(True) - elif 'encoder' in name and 'layer_norm' not in name: - if ('layer.0' in name and \ - 'SelfAttention' not in name) \ - or ('layer.1' in name and \ - 'DenseReluDense' not in name): + elif "encoder" in name and "layer_norm" not in name: + if ("layer.0" in name and "SelfAttention" not in name) or ( + "layer.1" in name and "DenseReluDense" not in name + ): print(name) out.add(False) # elif 'layer.1' in name and \ @@ -780,12 +877,11 @@ def match_weight_name(self, weights_tuple_list): # out.add(False) else: out.add(True) - elif 'decoder' not in name and 'encoder' not in name: + elif "decoder" not in name and "encoder" not in name: print(name) return out def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - model_params_dict = dict(self.model.named_parameters()) weights_tuple_list = list(weights) @@ -794,18 +890,21 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for name, loaded_weight in weights_tuple_list: name = self._rename_layer_types(name) name, shard_id = self._rename_stacked_param(name) - if ('encoder.embed_tokens.weight' in name - or 'decoder.embed_tokens.weight' in name - or 'lm_head.weight' in name): - assert shared_embedding_weight is None, ( - "Conflicting embedding weights.") + if ( + "encoder.embed_tokens.weight" in name + or "decoder.embed_tokens.weight" in name + or "lm_head.weight" in name + ): + assert ( + shared_embedding_weight is None + ), "Conflicting embedding weights." shared_embedding_weight = loaded_weight else: # Skip the specific downstream task weight. - if name.startswith('cls.'): + if name.startswith("cls."): continue # use Pooler instead. - if name.startswith('pooler.'): + if name.startswith("pooler."): continue # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in model_params_dict: @@ -814,8 +913,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue param = model_params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) if shard_id: weight_loader(param, loaded_weight, shard_id) else: From a9d0d8d99a495ae75bebed1beec67b08d00826a5 Mon Sep 17 00:00:00 2001 From: bmullick-amd Date: Tue, 17 Dec 2024 20:30:17 -0800 Subject: [PATCH 14/15] yapf format --- vllm/model_executor/models/t5.py | 228 +++++++++++++++---------------- 1 file changed, 107 insertions(+), 121 deletions(-) diff --git a/vllm/model_executor/models/t5.py b/vllm/model_executor/models/t5.py index 82210ab259d19..095c282a72411 100644 --- a/vllm/model_executor/models/t5.py +++ b/vllm/model_executor/models/t5.py @@ -35,8 +35,7 @@ ) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, -) + QuantizationConfig, ) from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, @@ -97,17 +96,13 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): for txt_name in names: name = txt_name.split("/") - if any( - n - in [ + if any(n in [ "adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step", - ] - for n in name - ): + ] for n in name): log_name = "/".join(name) logger.info("Skipping %s", log_name) tf_weights.pop(txt_name, None) @@ -151,11 +146,8 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): continue elif scope_names[0] == "logits": pointer = getattr(pointer, "lm_head", None) - elif ( - scope_names[0] == "wi" - and len(scope_names) > 1 - and scope_names[1].isdigit() - ): + elif (scope_names[0] == "wi" and len(scope_names) > 1 + and scope_names[1].isdigit()): pointer = getattr(pointer, f"wi_{scope_names[1]}") continue else: @@ -197,6 +189,7 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): class T5LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): """ Construct a layernorm module in the T5 style. 
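For readers skimming the series: the hunk that follows only re-wraps `T5LayerNorm.forward`, whose underlying computation is plain RMS normalization (scale only, no mean subtraction, no bias, with the variance accumulated in fp32), exactly as the module's docstring describes. A minimal standalone sketch of that math, separate from the patch and using the illustrative helper name `rms_norm`:

```python
# Illustrative sketch of the RMS-norm math behind T5LayerNorm; not part of
# this patch, and the helper name `rms_norm` is made up for the example.
import torch


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # accumulate the mean of squares in fp32, even for fp16/bf16 inputs
    variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
    x = x.to(torch.float32) * torch.rsqrt(variance + eps)
    # cast back so the learned scale is applied in the parameter dtype
    if weight.dtype in (torch.float16, torch.bfloat16):
        x = x.to(weight.dtype)
    return weight * x


if __name__ == "__main__":
    h = torch.randn(2, 4, 512, dtype=torch.bfloat16)
    w = torch.ones(512, dtype=torch.bfloat16)
    print(rms_norm(h, w).shape)  # torch.Size([2, 4, 512])
```

vLLM already ships a fused RMSNorm layer (`vllm.model_executor.layers.layernorm.RMSNorm`); whether this module should eventually reuse it instead of the HF-style implementation is a reasonable follow-up for review, not something this patch changes.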
@@ -217,7 +210,8 @@ def forward(self, hidden_states): inputs is done in fp32 """ - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + variance = hidden_states.to(torch.float32).pow(2).mean(-1, + keepdim=True) adj_var = variance + self.variance_epsilon hidden_states = hidden_states * torch.rsqrt(adj_var) @@ -229,18 +223,21 @@ def forward(self, hidden_states): class T5DenseActDense(nn.Module): + def __init__( self, config: T5Config, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.wi = ColumnParallelLinear( - config.d_model, config.d_ff, bias=False, quant_config=quant_config - ) - self.wo = RowParallelLinear( - config.d_ff, config.d_model, bias=False, quant_config=quant_config - ) + self.wi = ColumnParallelLinear(config.d_model, + config.d_ff, + bias=False, + quant_config=quant_config) + self.wo = RowParallelLinear(config.d_ff, + config.d_model, + bias=False, + quant_config=quant_config) self.dropout = nn.Dropout(config.dropout_rate) self.act = get_act_fn(config.dense_act_fn, quant_config) @@ -248,32 +245,34 @@ def forward(self, hidden_states): hidden_states = self.wi(hidden_states)[0] hidden_states = self.act(hidden_states) hidden_states = self.dropout(hidden_states) - if ( - isinstance(self.wo.weight, torch.Tensor) - and hidden_states.dtype != self.wo.weight.dtype - and self.wo.weight.dtype != torch.int8 - ): + if (isinstance(self.wo.weight, torch.Tensor) + and hidden_states.dtype != self.wo.weight.dtype + and self.wo.weight.dtype != torch.int8): hidden_states = hidden_states.to(self.wo.weight.dtype) hidden_states = self.wo(hidden_states)[0] return hidden_states class T5DenseGatedActDense(nn.Module): + def __init__( self, config: T5Config, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.wi_0 = ColumnParallelLinear( - config.d_model, config.d_ff, bias=False, quant_config=quant_config - ) - self.wi_1 = ColumnParallelLinear( - config.d_model, config.d_ff, bias=False, quant_config=quant_config - ) - self.wo = RowParallelLinear( - config.d_ff, config.d_model, bias=False, quant_config=quant_config - ) + self.wi_0 = ColumnParallelLinear(config.d_model, + config.d_ff, + bias=False, + quant_config=quant_config) + self.wi_1 = ColumnParallelLinear(config.d_model, + config.d_ff, + bias=False, + quant_config=quant_config) + self.wo = RowParallelLinear(config.d_ff, + config.d_model, + bias=False, + quant_config=quant_config) self.dropout = nn.Dropout(config.dropout_rate) self.act = get_act_fn(config.dense_act_fn, quant_config) @@ -288,11 +287,9 @@ def forward(self, hidden_states): # See https://github.com/huggingface/transformers/issues/20287 # we also make sure the weights are not in `int8` in case users # will force `_keep_in_fp32_modules` to be `None`` - if ( - isinstance(self.wo.weight, torch.Tensor) - and hidden_states.dtype != self.wo.weight.dtype - and self.wo.weight.dtype != torch.int8 - ): + if (isinstance(self.wo.weight, torch.Tensor) + and hidden_states.dtype != self.wo.weight.dtype + and self.wo.weight.dtype != torch.int8): hidden_states = hidden_states.to(self.wo.weight.dtype) hidden_states = self.wo(hidden_states) @@ -300,6 +297,7 @@ def forward(self, hidden_states): class T5LayerFF(nn.Module): + def __init__( self, config: T5Config, @@ -311,9 +309,8 @@ def __init__( else: self.DenseReluDense = T5DenseActDense(config, quant_config) - self.layer_norm = T5LayerNorm( - config.d_model, eps=config.layer_norm_epsilon - ) + self.layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) 
self.dropout = nn.Dropout(config.dropout_rate) def forward(self, hidden_states): @@ -324,6 +321,7 @@ def forward(self, hidden_states): class T5Attention(nn.Module): + def __init__( self, config: T5Config, @@ -380,17 +378,24 @@ def forward( qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([2048, 2048, 2048], dim=-1) if encoder_hidden_states is None: - attn_output = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0) + attn_output = F.scaled_dot_product_attention(q, + k, + v, + dropout_p=0.0) else: qkv_enc, _ = self.qkv_proj(encoder_hidden_states) _, k, v = qkv.split([2048, 2048, 2048], dim=-1) - attn_output = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0) + attn_output = F.scaled_dot_product_attention(q, + k, + v, + dropout_p=0.0) output, _ = self.out_proj(attn_output) present_key_value_state = (k, v) if self.is_decoder else None return output, present_key_value_state class T5LayerSelfAttention(nn.Module): + def __init__( self, config, @@ -405,9 +410,8 @@ def __init__( quant_config, has_relative_attention_bias=has_relative_attention_bias, ) - self.layer_norm = T5LayerNorm( - config.d_model, eps=config.layer_norm_epsilon - ) + self.layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) def forward( @@ -417,15 +421,15 @@ def forward( attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.layer_norm(hidden_states) - attention_output = self.SelfAttention( - hidden_states, kv_cache, attn_metadata - ) + attention_output = self.SelfAttention(hidden_states, kv_cache, + attn_metadata) hidden_states = hidden_states + self.dropout(attention_output[0]) - outputs = (hidden_states,) + attention_output[1:] + outputs = (hidden_states, ) + attention_output[1:] return outputs class T5LayerCrossAttention(nn.Module): + def __init__( self, config, @@ -439,9 +443,8 @@ def __init__( quant_config, has_relative_attention_bias=False, ) - self.layer_norm = T5LayerNorm( - config.d_model, eps=config.layer_norm_epsilon - ) + self.layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) def forward( @@ -459,11 +462,12 @@ def forward( encoder_hidden_states, ) layer_output = hidden_states + self.dropout(attention_output[0]) - outputs = (layer_output,) + attention_output[1:] + outputs = (layer_output, ) + attention_output[1:] return outputs class T5Block(nn.Module): + def __init__( self, config: T5Config, @@ -480,9 +484,8 @@ def __init__( has_relative_attention_bias=has_relative_attention_bias, ) if self.is_decoder: - self.cross_attn = T5LayerCrossAttention( - config, cache_config, quant_config - ) + self.cross_attn = T5LayerCrossAttention(config, cache_config, + quant_config) self.fc = T5LayerFF(config, quant_config) def forward( @@ -492,9 +495,8 @@ def forward( attn_metadata: AttentionMetadata, encoder_hidden_states: Optional[torch.Tensor] = None, ) -> torch.Tensor: - self_attention_outputs = self.self_attn( - hidden_states, kv_cache, attn_metadata - ) + self_attention_outputs = self.self_attn(hidden_states, kv_cache, + attn_metadata) hidden, _ = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] @@ -507,13 +509,12 @@ def forward( ) hidden = torch.clamp(hidden, min=-clamp_value, max=clamp_value) - do_cross_attention = ( - self.is_decoder and encoder_hidden_states is not None - ) + do_cross_attention = (self.is_decoder + and encoder_hidden_states is not None) if do_cross_attention: - cross_attention_outputs = self.cross_attn( - 
hidden, kv_cache, attn_metadata, encoder_hidden_states - ) + cross_attention_outputs = self.cross_attn(hidden, kv_cache, + attn_metadata, + encoder_hidden_states) hidden = cross_attention_outputs[0] # clamp inf values to enable fp16 training @@ -539,7 +540,7 @@ def forward( ) hidden = torch.clamp(hidden, min=-clamp_value, max=clamp_value) - outputs = (hidden,) + attention_outputs + outputs = (hidden, ) + attention_outputs return outputs @@ -562,6 +563,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class T5Stack(nn.Module): + def __init__( self, config: T5Config, @@ -574,20 +576,16 @@ def __init__( self.embed_tokens = embed_tokens self.is_decoder = config.is_decoder - self.block = nn.ModuleList( - [ - T5Block( - config, - cache_config, - quant_config, - has_relative_attention_bias=bool(i == 0), - ) - for i in range(config.num_layers) - ] - ) - self.final_layer_norm = T5LayerNorm( - config.d_model, eps=config.layer_norm_epsilon - ) + self.block = nn.ModuleList([ + T5Block( + config, + cache_config, + quant_config, + has_relative_attention_bias=bool(i == 0), + ) for i in range(config.num_layers) + ]) + self.final_layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) def forward( @@ -634,11 +632,8 @@ def __init__( super().__init__() # self.shared = nn.Embedding(config.vocab_size, config.d_model) self.padding_idx = config.pad_token_id - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) self.vocab_size = config.vocab_size + lora_vocab self.shared = VocabParallelEmbedding( self.vocab_size, @@ -650,17 +645,15 @@ def __init__( encoder_config.is_decoder = False encoder_config.use_cache = False encoder_config.is_encoder_decoder = False - self.encoder = T5Stack( - encoder_config, cache_config, quant_config, self.shared - ) + self.encoder = T5Stack(encoder_config, cache_config, quant_config, + self.shared) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True decoder_config.is_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers - self.decoder = T5Stack( - decoder_config, cache_config, quant_config, self.shared - ) + self.decoder = T5Stack(decoder_config, cache_config, quant_config, + self.shared) def forward( self, @@ -694,6 +687,7 @@ def forward( class T5ForConditionalGeneration(nn.Module): + def __init__( self, config: T5Config, @@ -704,9 +698,10 @@ def __init__( super().__init__() self.config = config self.model_dim = config.d_model - self.model = T5Model( - config, cache_config, quant_config, lora_config=lora_config - ) + self.model = T5Model(config, + cache_config, + quant_config, + lora_config=lora_config) print("lora_config", lora_config) self.unpadded_vocab_size = config.vocab_size if lora_config: @@ -718,9 +713,8 @@ def __init__( bias=False, ) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) self.sampler = Sampler() def forward( @@ -764,9 +758,8 @@ def compute_logits( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor( - self.lm_head, hidden_states, sampling_metadata - ) + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) return logits def sample( @@ -804,7 
+797,7 @@ def sample( def _rename_key(self, key: str): prefix = f"{self.base_model_prefix}." - key = key[len(prefix) :] if key.startswith(prefix) else key + key = key[len(prefix):] if key.startswith(prefix) else key for src, dst in self.params_mapping.items(): key = key.replace(src, dst) @@ -848,11 +841,9 @@ def match_weight_name(self, weights_tuple_list): out = set() for name, _ in weights_tuple_list: if "decoder" in name and "layer_norm" not in name: - if ( - ("layer.0" in name and "SelfAttention" not in name) - or ("layer.1" in name and "EncDecAttention" not in name) - or ("layer.2" in name and "DenseReluDense" not in name) - ): + if (("layer.0" in name and "SelfAttention" not in name) or + ("layer.1" in name and "EncDecAttention" not in name) or + ("layer.2" in name and "DenseReluDense" not in name)): print(name) out.add(False) # elif 'layer.1' in name and \ @@ -867,8 +858,7 @@ def match_weight_name(self, weights_tuple_list): out.add(True) elif "encoder" in name and "layer_norm" not in name: if ("layer.0" in name and "SelfAttention" not in name) or ( - "layer.1" in name and "DenseReluDense" not in name - ): + "layer.1" in name and "DenseReluDense" not in name): print(name) out.add(False) # elif 'layer.1' in name and \ @@ -890,14 +880,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for name, loaded_weight in weights_tuple_list: name = self._rename_layer_types(name) name, shard_id = self._rename_stacked_param(name) - if ( - "encoder.embed_tokens.weight" in name - or "decoder.embed_tokens.weight" in name - or "lm_head.weight" in name - ): - assert ( - shared_embedding_weight is None - ), "Conflicting embedding weights." + if ("encoder.embed_tokens.weight" in name + or "decoder.embed_tokens.weight" in name + or "lm_head.weight" in name): + assert (shared_embedding_weight is + None), "Conflicting embedding weights." shared_embedding_weight = loaded_weight else: # Skip the specific downstream task weight. 
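Note, outside the diff: the load_weights hunk above only collects the tied matrix. T5 shares one embedding table between encoder.embed_tokens, decoder.embed_tokens and lm_head, so all three checkpoint names are funneled into shared_embedding_weight. The code that applies that tensor after the loop is not visible in this excerpt; the sketch below shows one common vLLM pattern for it, reusing the names from the hunk (model_params_dict, shared_embedding_weight, default_weight_loader). The target parameter names "model.shared.weight" and "lm_head.weight" match the modules defined earlier in t5.py, but they are assumptions about what the patch actually does, not a quote from it.

    # Sketch only: push the tied embedding into every parameter that shares it.
    # T5Model.shared backs both encoder and decoder embed_tokens, so loading
    # "model.shared.weight" covers both; "lm_head.weight" is loaded separately.
    if shared_embedding_weight is not None:
        for tied_name in ("model.shared.weight", "lm_head.weight"):
            param = model_params_dict.get(tied_name)
            if param is None:
                continue
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, shared_embedding_weight)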
@@ -913,9 +900,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue param = model_params_dict[name] - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) if shard_id: weight_loader(param, loaded_weight, shard_id) else: From fbaa5fea03d0eb7b33c45ab37476b8dded604af7 Mon Sep 17 00:00:00 2001 From: bmullick-amd Date: Tue, 17 Dec 2024 20:39:17 -0800 Subject: [PATCH 15/15] fix isort check --- vllm/model_executor/models/t5.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/t5.py b/vllm/model_executor/models/t5.py index 095c282a72411..90826cc46f456 100644 --- a/vllm/model_executor/models/t5.py +++ b/vllm/model_executor/models/t5.py @@ -28,19 +28,15 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import ( - ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear, -) +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, ) + QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, - VocabParallelEmbedding, -) + ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata
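Note, outside the diff: after these formatting fixes, a quick end-to-end check of the T5 port is the offline LLM API. The snippet below is a hedged smoke test, not part of the series: the checkpoint name and prompt are placeholders, and it assumes a vLLM build that includes this branch so that the T5 architectures resolve to the model added here.

    # Minimal sketch: run one greedy generation through the new T5 path.
    from vllm import LLM, SamplingParams

    llm = LLM(model="t5-small")  # placeholder checkpoint name
    params = SamplingParams(temperature=0.0, max_tokens=32)
    outputs = llm.generate(
        "translate English to German: The house is wonderful.", params)
    print(outputs[0].outputs[0].text)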