From be72d33fdbb985d1c2469e646225b5f5c6a03f79 Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Wed, 3 Apr 2024 11:14:25 -0700
Subject: [PATCH] Apply black formatting

---
 server/lorax_server/adapters/__init__.py | 9 +-
 server/lorax_server/adapters/config.py | 8 +-
 server/lorax_server/adapters/lora.py | 79 +++---
 server/lorax_server/adapters/medusa.py | 38 +--
 server/lorax_server/adapters/utils.py | 12 +-
 server/lorax_server/adapters/weights.py | 23 +-
 server/lorax_server/cli.py | 24 +-
 server/lorax_server/interceptor.py | 4 +-
 server/lorax_server/models/__init__.py | 20 +-
 server/lorax_server/models/bloom.py | 57 ++--
 server/lorax_server/models/cache_manager.py | 16 +-
 server/lorax_server/models/causal_lm.py | 126 ++++-----
 .../models/custom_modeling/bloom_modeling.py | 93 +++---
 .../custom_modeling/flash_gemma_modeling.py | 108 ++++---
 .../custom_modeling/flash_gpt2_modeling.py | 85 +++---
 .../custom_modeling/flash_llama_modeling.py | 125 +++++----
 .../custom_modeling/flash_mistral_modeling.py | 125 +++++----
 .../custom_modeling/flash_mixtral_modeling.py | 264 +++++++++---------
 .../custom_modeling/flash_neox_modeling.py | 28 +-
 .../custom_modeling/flash_phi_modeling.py | 94 ++++---
 .../custom_modeling/flash_qwen2_modeling.py | 112 ++++----
 .../custom_modeling/flash_qwen_modeling.py | 110 +++++---
 .../custom_modeling/flash_rw_modeling.py | 36 +--
 .../flash_santacoder_modeling.py | 44 +--
 .../models/custom_modeling/mpt_modeling.py | 125 ++-------
 .../models/custom_modeling/neox_modeling.py | 87 ++----
 .../models/custom_modeling/opt_modeling.py | 101 ++-----
 .../models/custom_modeling/t5_modeling.py | 158 +++--------
 server/lorax_server/models/flash_causal_lm.py | 232 +++++++--------
 server/lorax_server/models/flash_gemma.py | 35 ++-
 server/lorax_server/models/flash_gpt2.py | 18 +-
 server/lorax_server/models/flash_llama.py | 46 ++-
 server/lorax_server/models/flash_mistral.py | 44 ++-
 server/lorax_server/models/flash_mixtral.py | 96 +++----
 server/lorax_server/models/flash_neox.py | 4 +-
 server/lorax_server/models/flash_phi.py | 40 ++-
 server/lorax_server/models/flash_qwen.py | 41 ++-
 server/lorax_server/models/flash_qwen2.py | 58 ++--
 server/lorax_server/models/galactica.py | 22 +-
 server/lorax_server/models/gpt_neox.py | 10 +-
 server/lorax_server/models/model.py | 106 ++++---
 server/lorax_server/models/mpt.py | 10 +-
 server/lorax_server/models/opt.py | 10 +-
 server/lorax_server/models/rw.py | 2 +-
 server/lorax_server/models/santacoder.py | 2 +-
 server/lorax_server/models/seq2seq_lm.py | 52 ++--
 server/lorax_server/models/t5.py | 2 +-
 server/lorax_server/models/types.py | 15 +-
 server/lorax_server/server.py | 68 +++--
 server/lorax_server/tracing.py | 4 +-
 server/lorax_server/utils/adapter.py | 53 ++--
 server/lorax_server/utils/awq/awq.py | 24 +-
 server/lorax_server/utils/convert.py | 20 +-
 server/lorax_server/utils/flash_attn.py | 11 +-
 .../utils/gptq/custom_autotune.py | 15 +-
 server/lorax_server/utils/gptq/exllamav2.py | 53 ++--
 .../lorax_server/utils/gptq/quant_linear.py | 17 +-
 server/lorax_server/utils/gptq/quantize.py | 62 +---
 server/lorax_server/utils/graph.py | 80 +++---
 server/lorax_server/utils/layers.py | 232 ++++++++-------
 server/lorax_server/utils/logits_process.py | 50 ++--
 .../lorax_server/utils/merges/strategies.py | 42 ++-
 server/lorax_server/utils/merges/utils.py | 9 +-
 server/lorax_server/utils/paged_attn.py | 22 +-
 server/lorax_server/utils/segments.py | 6 +-
 server/lorax_server/utils/sgmv.py | 14 +-
 server/lorax_server/utils/sources/__init__.py | 25 +-
server/lorax_server/utils/sources/hub.py | 38 ++- server/lorax_server/utils/sources/local.py | 20 +- server/lorax_server/utils/sources/s3.py | 58 ++-- server/lorax_server/utils/sources/source.py | 46 +-- server/lorax_server/utils/tokenizer.py | 10 +- server/lorax_server/utils/tokens.py | 67 ++--- server/lorax_server/utils/watermark.py | 12 +- server/lorax_server/utils/weights.py | 34 +-- server/pyproject.toml | 10 + 76 files changed, 1978 insertions(+), 2080 deletions(-) diff --git a/server/lorax_server/adapters/__init__.py b/server/lorax_server/adapters/__init__.py index 4fc62fc9b..2b8336dd3 100644 --- a/server/lorax_server/adapters/__init__.py +++ b/server/lorax_server/adapters/__init__.py @@ -15,14 +15,15 @@ def load_adapter_config( ) -> AdapterConfig: if adapter_config_path is not None and adapter_config_path.exists(): return LoraConfig.load(str(adapter_config_path.parent), api_token) - + if config_path is not None and config_path.exists(): config = json.load(config_path.open()) if "medusa_num_heads" in config: return MedusaConfig.load(config) - - raise ValueError(f"No valid adapter config file found: " - f"tried {adapter_config_path} and {config_path}") + + raise ValueError( + f"No valid adapter config file found: " f"tried {adapter_config_path} and {config_path}" + ) __all__ = [ diff --git a/server/lorax_server/adapters/config.py b/server/lorax_server/adapters/config.py index 2c5bf8d36..0ce801ecc 100644 --- a/server/lorax_server/adapters/config.py +++ b/server/lorax_server/adapters/config.py @@ -19,15 +19,17 @@ class AdapterConfig(ABC): @abstractmethod def map_weights_for_model( - self, adapter_weights: Dict, weight_names: Tuple[str], + self, + adapter_weights: Dict, + weight_names: Tuple[str], ) -> Tuple[ModuleMap, Set[str]]: pass @abstractmethod def load_batched_adapter_weights( - self, + self, model: "Model", - module_map: Dict[str, Dict], + module_map: Dict[str, Dict], layer_type: str, unused_weight_names: Set[str], dynamic: bool, diff --git a/server/lorax_server/adapters/lora.py b/server/lorax_server/adapters/lora.py index 522a1be77..a4c3d7426 100644 --- a/server/lorax_server/adapters/lora.py +++ b/server/lorax_server/adapters/lora.py @@ -26,7 +26,9 @@ class LoraConfig(AdapterConfig): use_rslora: bool def map_weights_for_model( - self, adapter_weights: Dict, weight_names: Tuple[str], + self, + adapter_weights: Dict, + weight_names: Tuple[str], ) -> Tuple[ModuleMap, Set[str]]: adapter_weight_names = set() module_map = {} @@ -35,7 +37,7 @@ def map_weights_for_model( lora_b_name = f"base_model.model.{weight_name}.lora_B.weight" if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights: continue - + module_map[weight_name] = { "lora_A": (adapter_weights[lora_a_name], lora_a_name), "lora_B": (adapter_weights[lora_b_name], lora_b_name), @@ -43,11 +45,11 @@ def map_weights_for_model( adapter_weight_names.add(lora_a_name) adapter_weight_names.add(lora_b_name) return module_map, adapter_weight_names - + def load_batched_adapter_weights( - self, + self, model: "Model", - module_map: Dict[str, Dict], + module_map: Dict[str, Dict], layer_type: str, unused_weight_names: Set[str], dynamic: bool, @@ -59,7 +61,7 @@ def load_batched_adapter_weights( layer_type, unused_weight_names, ) - + @classmethod def load(cls, adapter_id: str, api_token: str) -> "LoraConfig": hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token) @@ -86,27 +88,24 @@ def __init__( self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1 # [num_layers, hidden_size, r] - weights_a = [ - 
orient_for_rank(w, w.size(1)).contiguous() - for w in weights_a - ] + weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a] self.weights_a = torch.stack(weights_a) # [num_layers, r, hidden_size] self.weights_b = torch.stack(weights_b) self.adapter_config = adapter_config - + @classmethod def get_batch_type(cls) -> BatchAdapterWeights: return BatchLoraWeights @classmethod def load( - cls, + cls, config: LoraConfig, model: "Model", - module_map: Dict[str, Dict], + module_map: Dict[str, Dict], layer_type: str, unused_weight_names: Set[str], ) -> Optional[AdapterWeights]: @@ -155,7 +154,8 @@ def load( config.r = padded_rank return LoraWeights( - *model.shard_lora_weights(lora_a_list, lora_b_list, layer_type), config, + *model.shard_lora_weights(lora_a_list, lora_b_list, layer_type), + config, ) @@ -176,60 +176,55 @@ class BatchLoraWeights(BatchAdapterWeights): lora_b: Dict[int, torch.Tensor] adapter_index_configs: Dict[int, LoraConfig] rank_data: Dict[int, RankSegments] - + def has_adapter(self, adapter_index: int) -> bool: return adapter_index in self.adapter_index_configs - + def can_vectorize(self, pg: ProcessGroup) -> bool: return all( - rank_data.rank // pg.size() <= MAX_RANK_CUSTOM - for rank_data in self.rank_data.values() + rank_data.rank // pg.size() <= MAX_RANK_CUSTOM for rank_data in self.rank_data.values() ) - + @classmethod def key(cls) -> str: return LORA @classmethod - def load(self, adapter_weights: Dict[int, AdapterWeights], meta: AdapterBatchMetadata) -> "BatchLoraWeights": - adapter_weights = { - k: v - for k, v in adapter_weights.items() - if isinstance(v, LoraWeights) - } + def load( + self, adapter_weights: Dict[int, AdapterWeights], meta: AdapterBatchMetadata + ) -> "BatchLoraWeights": + adapter_weights = {k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)} first_weights = list(adapter_weights.values())[0] device = first_weights.weights_a.device segment_indices = meta.segment_indices lora_a = { - idx: adapter_weights[idx].weights_a - for idx in segment_indices - if idx in adapter_weights + idx: adapter_weights[idx].weights_a for idx in segment_indices if idx in adapter_weights } lora_a_ptr = torch.tensor( [ ( adapter_weights[idx].weights_a.data_ptr() - if idx in adapter_weights + if idx in adapter_weights else EMPTY_TENSOR.data_ptr() - ) for idx in segment_indices + ) + for idx in segment_indices ], dtype=torch.int64, device=device, ) lora_b = { - idx: adapter_weights[idx].weights_b - for idx in segment_indices - if idx in adapter_weights + idx: adapter_weights[idx].weights_b for idx in segment_indices if idx in adapter_weights } lora_b_ptr = torch.tensor( [ ( - adapter_weights[idx].weights_b.data_ptr() - if idx in adapter_weights + adapter_weights[idx].weights_b.data_ptr() + if idx in adapter_weights else EMPTY_TENSOR.data_ptr() - ) for idx in segment_indices + ) + for idx in segment_indices ], dtype=torch.int64, device=device, @@ -250,11 +245,7 @@ def load(self, adapter_weights: Dict[int, AdapterWeights], meta: AdapterBatchMet rank_data = {} for rank, indices in rank_indices.items(): lora_a_ptr_indices = lora_a_ptr[indices] - tmp_shrink, tmp_expand = get_tmp_tensors( - lora_a_ptr_indices.size(0), - rank, - device - ) + tmp_shrink, tmp_expand = get_tmp_tensors(lora_a_ptr_indices.size(0), rank, device) rank_data[rank] = RankSegments( rank=rank, @@ -263,11 +254,11 @@ def load(self, adapter_weights: Dict[int, AdapterWeights], meta: AdapterBatchMet lora_a_ptr=lora_a_ptr_indices, lora_b_ptr=lora_b_ptr[indices], 
segment_starts=meta.adapter_segments[indices], - segment_ends=meta.adapter_segments[[i+1 for i in indices]], + segment_ends=meta.adapter_segments[[i + 1 for i in indices]], ) return BatchLoraWeights( - lora_a=lora_a, + lora_a=lora_a, lora_b=lora_b, adapter_index_configs=adapter_index_configs, rank_data=rank_data, @@ -281,5 +272,5 @@ def get_scaling_factor( ) -> float: """Computes the scaling factor for the lora weights.""" if uses_rslora: - return lora_alpha / (r ** 0.5) + return lora_alpha / (r**0.5) return lora_alpha / r diff --git a/server/lorax_server/adapters/medusa.py b/server/lorax_server/adapters/medusa.py index 174ff224d..8fbc4fe83 100644 --- a/server/lorax_server/adapters/medusa.py +++ b/server/lorax_server/adapters/medusa.py @@ -19,15 +19,17 @@ class MedusaConfig(AdapterConfig): medusa_num_layers: int def map_weights_for_model( - self, adapter_weights: Dict, weight_names: Tuple[str], + self, + adapter_weights: Dict, + weight_names: Tuple[str], ) -> Tuple[ModuleMap, Set[str]]: # TODO(travis): this isn't technically the ModuleMap structure, make this more generic return adapter_weights, set(weight_names) - + def load_batched_adapter_weights( - self, + self, model: "Model", - module_map: Dict[str, Dict], + module_map: Dict[str, Dict], layer_type: str, unused_weight_names: Set[str], dynamic: bool, @@ -37,7 +39,7 @@ def load_batched_adapter_weights( "Dynamic adapter loading is not supported for Medusa at this time. " "Instead, initialize the LoRAX server with the Medusa adapter and it will be applied to every request." ) - + return MedusaWeights.load( self, model, @@ -58,9 +60,7 @@ def load(cls, config: dict) -> "MedusaConfig": class ResBlock(torch.nn.Module): def __init__(self, config: MedusaConfig, prefix: str, weights: AbstractWeights): super().__init__() - self.linear = FastLinear.load( - config, prefix=f"{prefix}.linear", weights=weights, bias=True - ) + self.linear = FastLinear.load(config, prefix=f"{prefix}.linear", weights=weights, bias=True) self.act = torch.nn.SiLU() def forward(self, x): @@ -77,9 +77,7 @@ def __init__(self, config: MedusaConfig, prefix: str, weights: AbstractWeights): ] ) n = len(self.blocks) - self.out = FastLinear.load( - config, prefix=f"{prefix}.{n}", weights=weights, bias=False - ) + self.out = FastLinear.load(config, prefix=f"{prefix}.{n}", weights=weights, bias=False) def forward(self, x): for block in self.blocks: @@ -107,21 +105,21 @@ class MedusaWeights(AdapterWeights): def __init__(self, config: MedusaConfig, module_map: ModuleMap, model: "Model"): self.config = config self.model = MedusaModel(config, InMemoryWeights(module_map, model.device, model.dtype)) - + @classmethod def get_batch_type(cls) -> BatchAdapterWeights: return BatchMedusaWeights - + @property def speculative_tokens(self) -> int: return self.config.medusa_num_heads @classmethod def load( - cls, + cls, config: MedusaConfig, model: "Model", - module_map: Dict[str, Dict], + module_map: Dict[str, Dict], layer_type: str, unused_weight_names: Set[str], ) -> Optional[AdapterWeights]: @@ -146,16 +144,10 @@ def key(cls) -> str: def load( cls, adapter_weights: Dict[int, AdapterWeights], meta: "AdapterBatchMetadata" ) -> "BatchMedusaWeights": - adapter_weights = { - k: v - for k, v in adapter_weights.items() - if isinstance(v, MedusaWeights) - } + adapter_weights = {k: v for k, v in adapter_weights.items() if isinstance(v, MedusaWeights)} adapter_to_medusa = { - idx: adapter_weights[idx] - for idx in meta.segment_indices - if idx in adapter_weights + idx: adapter_weights[idx] for idx in 
meta.segment_indices if idx in adapter_weights } return BatchMedusaWeights( diff --git a/server/lorax_server/adapters/utils.py b/server/lorax_server/adapters/utils.py index 88aea4a17..58ab33a28 100644 --- a/server/lorax_server/adapters/utils.py +++ b/server/lorax_server/adapters/utils.py @@ -14,18 +14,18 @@ def download_adapter( if adapter_source == PBASE: adapter_id = map_pbase_model_id_to_s3(adapter_id, api_token) adapter_source = S3 - + if adapter_source == HUB: # Quick auth check on the repo against the token HfApi(token=api_token).model_info(adapter_id, revision=None) - + # fail fast if ID is not an adapter (i.e. it is a full model) - source = get_model_source(adapter_source, adapter_id, extension=".safetensors", api_token=api_token) + source = get_model_source( + adapter_source, adapter_id, extension=".safetensors", api_token=api_token + ) source.load_config() - download_weights( - adapter_id, source=adapter_source, api_token=api_token - ) + download_weights(adapter_id, source=adapter_source, api_token=api_token) # Calculate size of adapter to be loaded return source.get_weight_bytes() diff --git a/server/lorax_server/adapters/weights.py b/server/lorax_server/adapters/weights.py index 94544888c..2cd7d8eb4 100644 --- a/server/lorax_server/adapters/weights.py +++ b/server/lorax_server/adapters/weights.py @@ -45,7 +45,9 @@ def key(cls) -> str: pass @abstractclassmethod - def load(cls, adapter_weights: Dict[int, AdapterWeights], meta: "AdapterBatchMetadata") -> "BatchAdapterWeights": + def load( + cls, adapter_weights: Dict[int, AdapterWeights], meta: "AdapterBatchMetadata" + ) -> "BatchAdapterWeights": pass @@ -62,12 +64,11 @@ def remove_adapter(self, adapter_idx: int): if adapter_idx not in self.adapter_weights: return del self.adapter_weights[adapter_idx] - + @property def max_speculative_tokens(self) -> int: return max( - adapter_weights.speculative_tokens - for adapter_weights in self.adapter_weights.values() + adapter_weights.speculative_tokens for adapter_weights in self.adapter_weights.values() ) def is_empty(self) -> bool: @@ -75,10 +76,12 @@ def is_empty(self) -> bool: def get_data(self, meta: AdapterBatchMetadata) -> Dict[str, BatchAdapterWeights]: # bucket adapters by batch class - adapter_batch_types: Dict[Type[BatchAdapterWeights], Dict[int, AdapterWeights]] = defaultdict(dict) + adapter_batch_types: Dict[ + Type[BatchAdapterWeights], Dict[int, AdapterWeights] + ] = defaultdict(dict) for adapter_index, adapter_weights in self.adapter_weights.items(): adapter_batch_types[adapter_weights.get_batch_type()][adapter_index] = adapter_weights - + batch_data = {} for batch_type, adapter_weights in adapter_batch_types.items(): batch_data[batch_type.key()] = batch_type.load(adapter_weights, meta) @@ -93,14 +96,16 @@ class AdapterBatchData: data: Dict[str, Dict[str, BatchAdapterWeights]] @staticmethod - def from_meta(meta: AdapterBatchMetadata, weights: Dict[str, LayerAdapterWeights]) -> "AdapterBatchData": + def from_meta( + meta: AdapterBatchMetadata, weights: Dict[str, LayerAdapterWeights] + ) -> "AdapterBatchData": data = {} for k, v in weights.items(): if v.is_empty(): continue data[k] = v.get_data(meta) return AdapterBatchData(meta=meta, data=data) - + def ranks(self) -> Set[int]: # TODO(travis): refactor to be less coupled to lora implementation return set( @@ -108,7 +113,7 @@ def ranks(self) -> Set[int]: for layer_data in self.data.values() for rank_data in layer_data.get(LORA, []).rank_data.values() ) - + @property def max_rank(self) -> int: ranks = self.ranks() diff --git 
a/server/lorax_server/cli.py b/server/lorax_server/cli.py index 6ffef7f7e..e3575a63d 100644 --- a/server/lorax_server/cli.py +++ b/server/lorax_server/cli.py @@ -49,9 +49,7 @@ def serve( speculative_tokens: int = 0, ): if sharded: - assert ( - os.getenv("RANK", None) is not None - ), "RANK must be set when sharded is True" + assert os.getenv("RANK", None) is not None, "RANK must be set when sharded is True" assert ( os.getenv("WORLD_SIZE", None) is not None ), "WORLD_SIZE must be set when sharded is True" @@ -90,16 +88,16 @@ def serve( "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model." ) server.serve( - model_id, - adapter_id, - revision, - sharded, - quantize, - compile, - dtype, - trust_remote_code, - uds_path, - source, + model_id, + adapter_id, + revision, + sharded, + quantize, + compile, + dtype, + trust_remote_code, + uds_path, + source, adapter_source, speculative_tokens, ) diff --git a/server/lorax_server/interceptor.py b/server/lorax_server/interceptor.py index 56733d5d0..0b72620d2 100644 --- a/server/lorax_server/interceptor.py +++ b/server/lorax_server/interceptor.py @@ -44,7 +44,5 @@ async def intercept( torch.cuda.empty_cache() await context.abort_with_status( - rpc_status.to_status( - status_pb2.Status(code=code_pb2.INTERNAL, message=str(err)) - ) + rpc_status.to_status(status_pb2.Status(code=code_pb2.INTERNAL, message=str(err))) ) diff --git a/server/lorax_server/models/__init__.py b/server/lorax_server/models/__init__.py index 606564c23..91e248f0c 100644 --- a/server/lorax_server/models/__init__.py +++ b/server/lorax_server/models/__init__.py @@ -72,9 +72,9 @@ def get_model( config_dict, _ = PretrainedConfig.get_config_dict( model_id, revision=revision, trust_remote_code=trust_remote_code ) - else: + else: raise ValueError(f"Unknown source {source}") - + model_type = config_dict["model_type"] if dtype is None: @@ -117,10 +117,14 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) - + if model_type == "mpt": return MPTSharded( - model_id, revision, quantize=quantize, compile=compile, trust_remote_code=trust_remote_code + model_id, + revision, + quantize=quantize, + compile=compile, + trust_remote_code=trust_remote_code, ) if model_type == "gpt_neox": @@ -188,7 +192,7 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) - + if model_type == "mixtral": from lorax_server.models.flash_mixtral import FlashMixtral @@ -202,7 +206,7 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) - + if model_type == "qwen": from lorax_server.models.flash_qwen import FlashQwen @@ -230,7 +234,7 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) - + if model_type in ["phi-msft", "phi"]: from lorax_server.models.flash_phi import FlashPhi @@ -244,7 +248,7 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) - + if model_type == "gemma": from lorax_server.models.flash_gemma import FlashGemma diff --git a/server/lorax_server/models/bloom.py b/server/lorax_server/models/bloom.py index 020492415..9bde0ff36 100644 --- a/server/lorax_server/models/bloom.py +++ b/server/lorax_server/models/bloom.py @@ -42,7 +42,9 @@ def from_pb( dtype: torch.dtype, device: torch.device, ) -> "CausalLMBatch": - batch = super().from_pb(pb=pb, tokenizer=tokenizer, tokenizers=tokenizers, dtype=dtype, device=device) + batch = super().from_pb( + pb=pb, tokenizer=tokenizer, tokenizers=tokenizers, dtype=dtype, device=device + ) batch.keys_head_dim_last = False return batch @@ -59,7 +61,7 @@ def 
__init__( ): if compile: raise ValueError("`--compile` is not supported with Bloom") - + self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") @@ -88,9 +90,7 @@ def __init__( torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group - ) + weights = Weights(filenames, device=device, dtype=dtype, process_group=self.process_group) if config.quantize in ["gptq", "awq", "eetq"]: weights._set_gptq_params(model_id) @@ -110,22 +110,21 @@ def __init__( self.dynamic_adapter_loading_enabled = True - @property def batch_type(self) -> Type[CausalLMBatch]: return BloomCausalLMBatch - + @property def has_adapter_data(self) -> bool: return True def forward( - self, - input_ids, - attention_mask, - position_ids, - past_key_values: Optional = None, - adapter_data: Optional[AdapterBatchData] = None + self, + input_ids, + attention_mask, + position_ids, + past_key_values: Optional = None, + adapter_data: Optional[AdapterBatchData] = None, ): outputs = self.model.forward( input_ids=input_ids, @@ -138,32 +137,44 @@ def forward( logits = outputs.logits return logits, outputs.past_key_values - + @property def supports_adapter_loading(self) -> bool: return True - + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: layer_weights = {} prefix = "transformer.h" for i, layer in enumerate(self.model.transformer.h): - layer_weights[(i, ATTN_QKV)] = (f"{prefix}.{i}.self_attention.query_key_value", layer.self_attention.query_key_value) - layer_weights[(i, ATTN_DENSE)] = (f"{prefix}.{i}.self_attention.dense", layer.self_attention.dense) - - layer_weights[(i, MLP_DENSE_H_TO_4H)] = (f"{prefix}.{i}.mlp.dense_h_to_4h", layer.mlp.dense_h_to_4h) - layer_weights[(i, MLP_DENSE_4H_TO_H)] = (f"{prefix}.{i}.mlp.dense_4h_to_h", layer.mlp.dense_4h_to_h) + layer_weights[(i, ATTN_QKV)] = ( + f"{prefix}.{i}.self_attention.query_key_value", + layer.self_attention.query_key_value, + ) + layer_weights[(i, ATTN_DENSE)] = ( + f"{prefix}.{i}.self_attention.dense", + layer.self_attention.dense, + ) + + layer_weights[(i, MLP_DENSE_H_TO_4H)] = ( + f"{prefix}.{i}.mlp.dense_h_to_4h", + layer.mlp.dense_h_to_4h, + ) + layer_weights[(i, MLP_DENSE_4H_TO_H)] = ( + f"{prefix}.{i}.mlp.dense_4h_to_h", + layer.mlp.dense_4h_to_h, + ) # TODO: make Embedding layers adapter-compatible # layer_weights[(0, LM_HEAD)] = ("transformer.wte", self.model.transformer.wte) return layer_weights - + @property def adapter_layers(self) -> List[str]: return ADAPTER_LAYERS - + def get_num_layers_for_type(self, layer_type: str) -> int: return 1 if layer_type == LM_HEAD else len(self.model.transformer.h) - + def is_row_parallel(self, layer_type: str) -> bool: return layer_type in ROW_PARALLEL diff --git a/server/lorax_server/models/cache_manager.py b/server/lorax_server/models/cache_manager.py index d6df44bab..c83a5b224 100644 --- a/server/lorax_server/models/cache_manager.py +++ b/server/lorax_server/models/cache_manager.py @@ -42,9 +42,9 @@ def __init__( for _ in range(num_layers) ] self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu") - self.slots = torch.arange( - 0, num_blocks * self.block_size, dtype=torch.int64 - ).view(num_blocks, self.block_size) + self.slots = torch.arange(0, num_blocks * self.block_size, dtype=torch.int64).view( + num_blocks, self.block_size + ) def 
allocate( self, @@ -64,9 +64,7 @@ def allocate( block_indices = block_indices.flatten() # Padded block tables - block_tables_tensor = torch.zeros( - (len(needed_blocks_slots), max_blocks), dtype=torch.int32 - ) + block_tables_tensor = torch.zeros((len(needed_blocks_slots), max_blocks), dtype=torch.int32) # Allocate paged attention blocks cumulative_blocks = 0 @@ -74,9 +72,7 @@ def allocate( block_tables = [] for i, (needed_blocks, needed_slots) in enumerate(needed_blocks_slots): # Get allocated blocks for this sequence - allocated_blocks = block_indices[ - cumulative_blocks : cumulative_blocks + needed_blocks - ] + allocated_blocks = block_indices[cumulative_blocks : cumulative_blocks + needed_blocks] # Get slots for the allocated blocks all_slots = self.slots[allocated_blocks].flatten() @@ -132,4 +128,4 @@ def get_cache_manager() -> CacheManager: if CACHE_MANAGER is None: raise RuntimeError("cache manager was not initialized") - return CACHE_MANAGER \ No newline at end of file + return CACHE_MANAGER diff --git a/server/lorax_server/models/causal_lm.py b/server/lorax_server/models/causal_lm.py index 3d8f745e7..270998c55 100644 --- a/server/lorax_server/models/causal_lm.py +++ b/server/lorax_server/models/causal_lm.py @@ -95,15 +95,11 @@ def from_pb( req_inputs = tokenizers.get_inputs(r, tokenizer) inputs.append(req_inputs) next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer)) - stopping_criteria = StoppingCriteria.from_pb( - r.stopping_parameters, tokenizer - ) + stopping_criteria = StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) stopping_criterias.append(stopping_criteria) max_truncation = max(max_truncation, r.truncate) max_decode_tokens += stopping_criteria.max_new_tokens - padding_right_offset = max( - padding_right_offset, stopping_criteria.max_new_tokens - ) + padding_right_offset = max(padding_right_offset, stopping_criteria.max_new_tokens) adapter_indices_list.append(r.adapter_index) adapter_set.add(r.adapter_index) @@ -127,9 +123,7 @@ def from_pb( input_ids = tokenized_inputs["input_ids"] # Allocate maximum attention_mask - attention_mask = input_ids.new_zeros( - (pb.size, max_input_length + padding_right_offset) - ) + attention_mask = input_ids.new_zeros((pb.size, max_input_length + padding_right_offset)) # Copy tokenizer attention_mask into fully allocated attention_mask attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"] @@ -211,12 +205,10 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]: stopping_criteria = self.stopping_criterias[idx] stopping_criterias.append(stopping_criteria) remaining_decode_tokens = ( - stopping_criteria.max_new_tokens - stopping_criteria.current_tokens + stopping_criteria.max_new_tokens - stopping_criteria.current_tokens ) total_remaining_decode_tokens += remaining_decode_tokens - new_padding_right_offset = max( - new_padding_right_offset, remaining_decode_tokens - ) + new_padding_right_offset = max(new_padding_right_offset, remaining_decode_tokens) adapter_set.add(self.requests[idx].adapter_index) @@ -225,14 +217,12 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]: position_ids = self.position_ids[keep_indices] adapter_indices = self.adapter_meta.adapter_indices[keep_indices] self.attention_mask = self.attention_mask[ - keep_indices, - -(self.padding_right_offset + max_input_length): ( - self.attention_mask.shape[ - 1] - - self.padding_right_offset - ) - + new_padding_right_offset, - ] + keep_indices, + -(self.padding_right_offset + 
max_input_length) : ( + self.attention_mask.shape[1] - self.padding_right_offset + ) + + new_padding_right_offset, + ] # Ensure that past_key_values tensors can be updated in-place if type(self.past_key_values[0]) == tuple: @@ -354,13 +344,19 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": # Copy over adapter indices adapter_start_index = cumulative_adapter_indices_size - adapter_end_index = cumulative_adapter_indices_size + batch.adapter_meta.adapter_indices.shape[0] - adapter_indices[adapter_start_index:adapter_end_index] = batch.adapter_meta.adapter_indices + adapter_end_index = ( + cumulative_adapter_indices_size + batch.adapter_meta.adapter_indices.shape[0] + ) + adapter_indices[ + adapter_start_index:adapter_end_index + ] = batch.adapter_meta.adapter_indices cumulative_adapter_indices_size = adapter_end_index adapter_set.update(batch.adapter_meta.adapter_set) # Update adapter segments - adapter_segment_builder.concat(batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices) + adapter_segment_builder.concat( + batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices + ) # Create padded tensor if attention_mask is None: @@ -372,17 +368,15 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": # and to remove unused allocated space left_offset = max_input_length - batch.max_input_length batch_left_offset = ( - batch.attention_mask.shape[1] - - batch.max_input_length - - batch.padding_right_offset + batch.attention_mask.shape[1] - batch.max_input_length - batch.padding_right_offset ) attention_mask[ - start_index:end_index, - left_offset:-padding_right_offset, + start_index:end_index, + left_offset:-padding_right_offset, ] = batch.attention_mask[ :, - batch_left_offset: -batch.padding_right_offset, - ] + batch_left_offset : -batch.padding_right_offset, + ] # Create empty tensor # position_ids is always of shape [batch_size, 1] @@ -405,9 +399,9 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": layer[k] = t.view(len(batch), -1, *t.shape[-2:]) # Add eventual padding tokens that were added while concatenating - max_tokens += batch.max_tokens + ( - max_input_length - batch.max_input_length - ) * len(batch) + max_tokens += batch.max_tokens + (max_input_length - batch.max_input_length) * len( + batch + ) start_index = end_index @@ -447,21 +441,19 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": # We slice the keys to remove the padding from previous batches past_seq_len = batch.max_input_length - 1 if batch.keys_head_dim_last: - padded_past_keys[ - start_index:end_index, :, -past_seq_len:, : - ] = past_keys[:, :, -past_seq_len:, :] + padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = past_keys[ + :, :, -past_seq_len:, : + ] else: # BLOOM case - padded_past_keys[ - start_index:end_index, :, :, -past_seq_len: - ] = past_keys[:, :, :, -past_seq_len:] + padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = past_keys[ + :, :, :, -past_seq_len: + ] del past_keys start_index = end_index - padded_past_values = first_past_kvs[j][1].new_zeros( - padded_past_values_shape - ) + padded_past_values = first_past_kvs[j][1].new_zeros(padded_past_values_shape) start_index = 0 for batch in batches: past_values = batch.past_key_values[j][1] @@ -472,9 +464,9 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": end_index = start_index + len(batch) # We slice the past values to remove the padding from previous batches past_seq_len = 
batch.max_input_length - 1 - padded_past_values[ - start_index:end_index, :, -past_seq_len:, : - ] = past_values[:, :, -past_seq_len:, :] + padded_past_values[start_index:end_index, :, -past_seq_len:, :] = past_values[ + :, :, -past_seq_len:, : + ] del past_values # Update values @@ -597,15 +589,15 @@ def forward( attention_mask, position_ids, past_key_values: Optional = None, - adapter_data: Optional[AdapterBatchData] = None + adapter_data: Optional[AdapterBatchData] = None, ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: # Model Forward kwargs = { - "input_ids" : input_ids, - "attention_mask" : attention_mask, + "input_ids": input_ids, + "attention_mask": attention_mask, "past_key_values": past_key_values, - "use_cache" : True, - "return_dict" : True, + "use_cache": True, + "return_dict": True, } if self.has_position_ids: kwargs["position_ids"] = position_ids @@ -654,19 +646,17 @@ def generate_token( # For each member of the batch for i, ( - request, - input_length, - prefix_offset, - read_offset, - logits, - next_token_chooser, - stopping_criteria, - all_input_ids, + request, + input_length, + prefix_offset, + read_offset, + logits, + next_token_chooser, + stopping_criteria, + all_input_ids, ) in enumerate(iterator): # Select next token - next_token_id, logprobs = next_token_chooser( - all_input_ids.view(1, -1), logits[-1:, :] - ) + next_token_id, logprobs = next_token_chooser(all_input_ids.view(1, -1), logits[-1:, :]) # Append next token to all tokens all_input_ids = torch.cat([all_input_ids, next_token_id]) @@ -693,9 +683,7 @@ def generate_token( if i % self.world_size == self.rank: if stop: # Decode generated tokens - output_text = self.decode( - all_input_ids[-stopping_criteria.current_tokens:, 0] - ) + output_text = self.decode(all_input_ids[-stopping_criteria.current_tokens :, 0]) # Get seed if isinstance(next_token_chooser.choice, Sampling): seed = next_token_chooser.choice.seed @@ -711,11 +699,9 @@ def generate_token( # Prefill if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: # Remove generated token to only have prefill and add nan for first prompt token - prefill_logprobs = [float("nan")] + torch.log_softmax( - logits, -1 - ).gather(1, all_input_ids[1:]).squeeze(1)[ - -new_input_length:-1 - ].tolist() + prefill_logprobs = [float("nan")] + torch.log_softmax(logits, -1).gather( + 1, all_input_ids[1:] + ).squeeze(1)[-new_input_length:-1].tolist() prefill_token_ids = all_input_ids[-new_input_length:-1] prefill_texts = self.tokenizer.batch_decode( prefill_token_ids, diff --git a/server/lorax_server/models/custom_modeling/bloom_modeling.py b/server/lorax_server/models/custom_modeling/bloom_modeling.py index 22d9a9fc3..606feae4e 100644 --- a/server/lorax_server/models/custom_modeling/bloom_modeling.py +++ b/server/lorax_server/models/custom_modeling/bloom_modeling.py @@ -199,15 +199,9 @@ def _split_heads( fused_qkv = fused_qkv.view(batch_size, seq_length, num_heads, 3 * head_dim) query_layer, key_layer, value_layer = fused_qkv.split(head_dim, dim=-1) - query_layer = query_layer.transpose(1, 2).reshape( - batch_size * num_heads, seq_length, head_dim - ) - key_layer = key_layer.permute(0, 2, 3, 1).reshape( - batch_size * num_heads, head_dim, seq_length - ) - value_layer = value_layer.transpose(1, 2).reshape( - batch_size * num_heads, seq_length, head_dim - ) + query_layer = query_layer.transpose(1, 2).reshape(batch_size * num_heads, seq_length, head_dim) + key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * num_heads, head_dim, 
seq_length) + value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_heads, seq_length, head_dim) return query_layer, key_layer, value_layer @@ -278,7 +272,7 @@ def __init__(self, prefix, config: BloomConfig, weights, layer_id): weights=weights, bias=True, ), - layer_id, + layer_id, [ATTN_QKV], sizes=None, process_group=weights.process_group, @@ -351,9 +345,7 @@ def compute_attention( attn_weights = attention_scores.masked_fill_( attention_mask, torch.finfo(attention_scores.dtype).min ) - attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to( - input_dtype - ) + attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype) # # [batch_size, num_heads, q_length, kv_length] # attention_probs = self.attention_dropout(attention_probs) @@ -365,9 +357,7 @@ def compute_attention( context_layer = torch.bmm(attention_probs, value_layer, out=query_layer) # change view [batch_size, num_heads, q_length, head_dim] - context_layer = _merge_heads( - context_layer, num_heads=num_heads, head_dim=head_dim - ) + context_layer = _merge_heads(context_layer, num_heads=num_heads, head_dim=head_dim) return context_layer, present, attention_probs @@ -384,7 +374,8 @@ def forward( output_attentions: bool = False, ): fused_qkv = self.query_key_value( - hidden_states, adapter_data, + hidden_states, + adapter_data, ) # [batch_size, seq_length, 3 x hidden_size] batch_size, q_length, _ = fused_qkv.shape @@ -397,9 +388,7 @@ def forward( if CUSTOM_KERNELS_ENABLED: assert self.training is False, "Only foward pass was implemented" - assert ( - attention_mask.shape[-1] < 4096 - ), "Custom kernel support only up to 4096 tokens" + assert attention_mask.shape[-1] < 4096, "Custom kernel support only up to 4096 tokens" ( context_layer, present, @@ -460,7 +449,7 @@ def __init__(self, prefix, config: BloomConfig, weights, layer_id): TensorParallelColumnLinear.load( config=config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True ), - layer_id, + layer_id, [MLP_DENSE_H_TO_4H], sizes=None, process_group=weights.process_group, @@ -477,24 +466,24 @@ def __init__(self, prefix, config: BloomConfig, weights, layer_id): self.hidden_dropout = config.hidden_dropout def forward( - self, - hidden_states: torch.Tensor, - residual: torch.Tensor, + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, adapter_data: Optional[AdapterBatchData] = None, ) -> torch.Tensor: hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states, adapter_data)) - if self.pretraining_tp > 1 and self.slow_but_exact and ( - adapter_data is None or adapter_data.max_rank == 0 + if ( + self.pretraining_tp > 1 + and self.slow_but_exact + and (adapter_data is None or adapter_data.max_rank == 0) ): intermediate_output = torch.zeros_like(residual) slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp for i in range(self.pretraining_tp): intermediate_output = intermediate_output + F.linear( hidden_states[:, :, int(i * slices) : int((i + 1) * slices)], - self.dense_4h_to_h.weight[ - :, int(i * slices) : int((i + 1) * slices) - ], + self.dense_4h_to_h.weight[:, int(i * slices) : int((i + 1) * slices)], ) else: intermediate_output = self.dense_4h_to_h(hidden_states, adapter_data) @@ -517,7 +506,10 @@ def __init__(self, layer_id: int, config: BloomConfig, weights): ) self.num_heads = config.n_head self.self_attention = BloomAttention( - prefix=f"{prefix}.self_attention", config=config, weights=weights, layer_id=layer_id, + prefix=f"{prefix}.self_attention", + config=config, + 
weights=weights, + layer_id=layer_id, ) self.post_attention_layernorm = LayerNorm.load( prefix=f"{prefix}.post_attention_layernorm", @@ -525,7 +517,9 @@ def __init__(self, layer_id: int, config: BloomConfig, weights): eps=config.layer_norm_epsilon, ) - self.mlp = BloomMLP(prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id) + self.mlp = BloomMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id + ) self.apply_residual_connection_post_layernorm = ( config.apply_residual_connection_post_layernorm ) @@ -645,9 +639,7 @@ def __init__(self, config: BloomConfig, weights): self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - self.word_embeddings = TensorParallelEmbedding( - prefix="word_embeddings", weights=weights - ) + self.word_embeddings = TensorParallelEmbedding(prefix="word_embeddings", weights=weights) self.word_embeddings_layernorm = LayerNorm.load( prefix="word_embeddings_layernorm", @@ -664,9 +656,7 @@ def __init__(self, config: BloomConfig, weights): ) # Final Layer Norm - self.ln_f = LayerNorm.load( - prefix="ln_f", weights=weights, eps=config.layer_norm_epsilon - ) + self.ln_f = LayerNorm.load(prefix="ln_f", weights=weights, eps=config.layer_norm_epsilon) def _prepare_attn_mask( self, @@ -725,9 +715,7 @@ def forward( raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions + output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states @@ -735,14 +723,10 @@ def forward( else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time" - ) + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: batch_size, seq_length = input_ids.shape elif inputs_embeds is not None: @@ -792,9 +776,7 @@ def forward( if hasattr(self, "tp_rank"): assert self.num_heads % self.tp_world_size == 0 block_size = self.num_heads // self.tp_world_size - alibi = alibi[ - :, self.tp_rank * block_size : (self.tp_rank + 1) * block_size - ] + alibi = alibi[:, self.tp_rank * block_size : (self.tp_rank + 1) * block_size] alibi = alibi.reshape(batch_size * block_size, 1, seq_length_with_past) causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0) else: @@ -823,9 +805,7 @@ def forward( presents = presents + (outputs[1],) if output_attentions: - all_self_attentions = all_self_attentions + ( - outputs[2 if use_cache else 1], - ) + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) # Add last hidden state hidden_states = self.ln_f(hidden_states) @@ -863,7 +843,10 @@ def __init__(self, config, weights): config, prefix="word_embeddings", weights=weights, - ), 0, LM_HEAD, process_group=weights.process_group + ), + 0, + LM_HEAD, + process_group=weights.process_group, ) def prepare_inputs_for_generation( @@ -928,9 +911,7 @@ def forward( if len(deprecated_arguments) > 0: raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") - return_dict = ( - return_dict if return_dict 
is not None else self.config.use_return_dict - ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = self.transformer( input_ids, diff --git a/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py b/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py index b5ecb6b0e..b609252ff 100644 --- a/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py @@ -37,7 +37,16 @@ TensorParallelHead, get_linear, ) -from lorax_server.utils.lora import DOWN_PROJ, GATE_PROJ, K_PROJ, LM_HEAD, O_PROJ, Q_PROJ, UP_PROJ, V_PROJ +from lorax_server.utils.lora import ( + DOWN_PROJ, + GATE_PROJ, + K_PROJ, + LM_HEAD, + O_PROJ, + Q_PROJ, + UP_PROJ, + V_PROJ, +) class GemmaConfig(PretrainedConfig): @@ -106,17 +115,21 @@ def _norm(self, x): def forward(self, hidden_states): output = self._norm(hidden_states.float()).type_as(hidden_states) return output * (1 + self.weight) - + def load_attention(config, prefix, weights, layer_id): base_layer = load_attention_multi(config, prefix, weights) head_size = config.head_dim return TensorParallelMultiAdapterLinear.load( - base_layer, layer_id, [Q_PROJ, K_PROJ, V_PROJ], sizes=[ + base_layer, + layer_id, + [Q_PROJ, K_PROJ, V_PROJ], + sizes=[ head_size * config.num_attention_heads, head_size * config.num_key_value_heads, head_size * config.num_key_value_heads, - ], process_group=weights.process_group + ], + process_group=weights.process_group, ) @@ -154,9 +167,7 @@ def _load_gqa(config, prefix: str, weights): config.hidden_size, ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" - return TensorParallelColumnLinear( - get_linear(weight, bias=None, quantize=config.quantize) - ) + return TensorParallelColumnLinear(get_linear(weight, bias=None, quantize=config.quantize)) class GemmaAttention(torch.nn.Module): @@ -188,18 +199,21 @@ def __init__( f"and `num_shards`: {weights.process_group.size()}" ) self.num_heads = self.num_heads // weights.process_group.size() - self.num_key_value_heads = ( - config.num_key_value_heads // weights.process_group.size() - ) + self.num_key_value_heads = config.num_key_value_heads // weights.process_group.size() self.query_key_value = load_attention(config, prefix, weights, layer_id) - self.o_proj = TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.o_proj", - weights=weights, - bias=False, - ), layer_id, O_PROJ, process_group=weights.process_group) + self.o_proj = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ), + layer_id, + O_PROJ, + process_group=weights.process_group, + ) self.num_groups = self.num_heads // self.num_key_value_heads self.kv_head_mapping = torch.arange( 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device @@ -207,9 +221,9 @@ def __init__( def get_query_key_value_weights(self, clone=True): """Gets the query, key, and value weights from the attention layer. - + If `clone`, then the weights are cloned before being returned. - + NOTE: if not `clone`, then the weights are returned as views, meaning that changes to the weights will be reflected in the attention layer. 
""" @@ -253,9 +267,7 @@ def forward( self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) - paged_attn.reshape_and_cache( - kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots - ) + paged_attn.reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(query) @@ -299,9 +311,7 @@ def __init__(self, prefix, config, weights, layer_id): if "gelu" not in act else lambda x: torch.nn.functional.gelu( x, - approximate="tanh" - if act in ["gelu_fast", "gelu_pytorch_tanh"] - else "none", + approximate="tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none", ) ) # Fuse gate and up proj @@ -313,7 +323,11 @@ def __init__(self, prefix, config, weights, layer_id): bias=False, ) self.gate_proj = TensorParallelMultiAdapterLinear.load( - gate_proj, layer_id, [GATE_PROJ], sizes=[config.intermediate_size], process_group=weights.process_group + gate_proj, + layer_id, + [GATE_PROJ], + sizes=[config.intermediate_size], + process_group=weights.process_group, ) up_proj = TensorParallelColumnLinear.load_multi( @@ -324,18 +338,25 @@ def __init__(self, prefix, config, weights, layer_id): bias=False, ) self.up_proj = TensorParallelMultiAdapterLinear.load( - up_proj, layer_id, [UP_PROJ], sizes=[config.intermediate_size], process_group=weights.process_group + up_proj, + layer_id, + [UP_PROJ], + sizes=[config.intermediate_size], + process_group=weights.process_group, ) - self.down_proj = TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.down_proj", - weights=weights, - bias=False, - ), layer_id, DOWN_PROJ, process_group=weights.process_group) - self.intermediate_size = ( - config.intermediate_size // weights.process_group.size() + self.down_proj = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ), + layer_id, + DOWN_PROJ, + process_group=weights.process_group, ) + self.intermediate_size = config.intermediate_size // weights.process_group.size() def forward(self, hidden_states, adapter_data): gate_states = self.gate_proj(hidden_states, adapter_data) @@ -353,9 +374,14 @@ def __init__(self, layer_id, config, weights): prefix = f"model.layers.{layer_id}" self.self_attn = GemmaAttention( - prefix=f"{prefix}.self_attn", config=config, weights=weights, layer_id=layer_id, + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + layer_id=layer_id, + ) + self.mlp = GemmaMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id ) - self.mlp = GemmaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id) self.input_layernorm = GemmaRMSNorm( prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps @@ -414,9 +440,7 @@ def __init__(self, config, weights): process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - self.embed_tokens = TensorParallelEmbedding( - prefix="model.embed_tokens", weights=weights - ) + self.embed_tokens = TensorParallelEmbedding(prefix="model.embed_tokens", weights=weights) self.layers = nn.ModuleList( [ GemmaDecoderLayer( @@ -427,9 +451,7 @@ def __init__(self, config, weights): for layer_id in range(config.num_hidden_layers) ] ) - self.norm = GemmaRMSNorm( - prefix="model.norm", weights=weights, eps=config.rms_norm_eps - ) + self.norm = GemmaRMSNorm(prefix="model.norm", weights=weights, eps=config.rms_norm_eps) 
self.gradient_checkpointing = False @@ -520,5 +542,5 @@ def forward( # lm_head reuses the weights of the embedding layer logits = hidden_states @ self.embed_t - logits = logits[:, :self.vocab_size] + logits = logits[:, : self.vocab_size] return logits, None diff --git a/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py b/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py index 11f6fbcbf..b6d3439d0 100644 --- a/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py @@ -65,11 +65,14 @@ def load_attention(config, prefix, weights, layer_id, layer_names, fan_in_fan_ou base_layer = load_attention_multi(config, prefix, weights, fan_in_fan_out=fan_in_fan_out) projection_size = config.n_embd return TensorParallelMultiAdapterLinear.load( - base_layer, layer_id, layer_names, sizes=[ + base_layer, + layer_id, + layer_names, + sizes=[ 3 * projection_size, - ], process_group=weights.process_group + ], + process_group=weights.process_group, ) - class FlashGPT2Attention(torch.nn.Module): @@ -98,10 +101,10 @@ def __init__(self, config, prefix, weights, layer_id): self.scale_attn_weights = config.scale_attn_weights if self.scale_attn_weights: - self.softmax_scale = self.head_dim ** -0.5 + self.softmax_scale = self.head_dim**-0.5 else: self.softmax_scale = 1.0 - + if config.add_cross_attention: raise ValueError("Cross attention in GPT-2 is not supported.") @@ -110,14 +113,21 @@ def __init__(self, config, prefix, weights, layer_id): self.layer_idx = layer_id self.reorder_and_upcast_attn = config.reorder_and_upcast_attn - self.c_attn = load_attention(config, prefix, weights, layer_id, [ATTN_C_ATTN], fan_in_fan_out=True) - self.c_proj = TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.c_proj", - weights=weights, - bias=True, - fan_in_fan_out=True, - ), layer_id, ATTN_C_PROJ, process_group=weights.process_group) + self.c_attn = load_attention( + config, prefix, weights, layer_id, [ATTN_C_ATTN], fan_in_fan_out=True + ) + self.c_proj = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.c_proj", + weights=weights, + bias=True, + fan_in_fan_out=True, + ), + layer_id, + ATTN_C_PROJ, + process_group=weights.process_group, + ) self.pruned_heads = set() @@ -140,7 +150,6 @@ def __init__(self, config, prefix, weights, layer_id): ) self.num_key_value_heads = self.num_heads - def forward( self, hidden_states, @@ -150,14 +159,12 @@ def forward( slots, input_lengths, max_s, - adapter_data + adapter_data, ): qkv = self.c_attn(hidden_states, adapter_data) qkv = qkv.view(-1, 3, self.num_heads, self.head_size) - paged_attn.reshape_and_cache( - qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots - ) + paged_attn.reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(qkv[:, 0]) @@ -208,25 +215,26 @@ def __init__(self, config, prefix, weights, layer_id): # https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Config.n_inner n_inner = config.n_inner if config.n_inner is not None else config.n_embd * 4 self.c_fc = TensorParallelMultiAdapterLinear.load( - c_fc, - layer_id, - [MLP_C_FC], - sizes=[n_inner], - process_group=weights.process_group + c_fc, layer_id, [MLP_C_FC], sizes=[n_inner], process_group=weights.process_group ) - self.c_proj = TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.c_proj", - 
weights=weights, - bias=True, - fan_in_fan_out=True, - ), layer_id, MLP_C_PROJ, process_group=weights.process_group) + self.c_proj = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.c_proj", + weights=weights, + bias=True, + fan_in_fan_out=True, + ), + layer_id, + MLP_C_PROJ, + process_group=weights.process_group, + ) self.act = ACT2FN[config.activation_function] def forward( - self, + self, hidden_states: Optional[Tuple[torch.FloatTensor]], adapter_data: AdapterBatchData, ) -> torch.FloatTensor: @@ -243,15 +251,11 @@ def __init__(self, layer_id, config, weights): layer_norm_eps = config.layer_norm_epsilon prefix = f"h.{layer_id}" - self.ln_1 = FastLayerNorm.load( - prefix=f"{prefix}.ln_1", weights=weights, eps=layer_norm_eps - ) + self.ln_1 = FastLayerNorm.load(prefix=f"{prefix}.ln_1", weights=weights, eps=layer_norm_eps) self.attn = FlashGPT2Attention( config, prefix=f"{prefix}.attn", weights=weights, layer_id=layer_id ) - self.ln_2 = FastLayerNorm.load( - prefix=f"{prefix}.ln_2", weights=weights, eps=layer_norm_eps - ) + self.ln_2 = FastLayerNorm.load(prefix=f"{prefix}.ln_2", weights=weights, eps=layer_norm_eps) self.mlp = GPT2MLP(config, prefix=f"{prefix}.mlp", weights=weights, layer_id=layer_id) self.process_group = weights.process_group @@ -310,10 +314,7 @@ def __init__(self, config, weights): self.wpe = TensorParallelEmbedding(prefix="wpe", weights=weights) self.h = nn.ModuleList( - [ - GPT2Block(layer_id, config, weights) - for layer_id in range(config.num_hidden_layers) - ] + [GPT2Block(layer_id, config, weights) for layer_id in range(config.num_hidden_layers)] ) self.ln_f = FastLayerNorm.load( prefix="ln_f", @@ -396,5 +397,5 @@ def forward( # lm_head reuses the weights of the embedding layer # https://github.com/huggingface/transformers/issues/6291 logits = hidden_states @ self.wte_t - logits = logits[:, :self.transformer.config.vocab_size] + logits = logits[:, : self.transformer.config.vocab_size] return logits, None diff --git a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py index 04d097c8e..7b69c3c61 100644 --- a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py @@ -43,7 +43,16 @@ TensorParallelHead, get_linear, ) -from lorax_server.utils.lora import DOWN_PROJ, GATE_PROJ, K_PROJ, LM_HEAD, O_PROJ, Q_PROJ, UP_PROJ, V_PROJ +from lorax_server.utils.lora import ( + DOWN_PROJ, + GATE_PROJ, + K_PROJ, + LM_HEAD, + O_PROJ, + Q_PROJ, + UP_PROJ, + V_PROJ, +) class LlamaConfig(PretrainedConfig): @@ -117,9 +126,7 @@ def forward(self, hidden_states, residual=None): hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt( - variance + self.variance_epsilon - ) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) # convert into half-precision if necessary if self.weight.dtype in [torch.float16, torch.bfloat16]: @@ -149,17 +156,21 @@ def forward(self, hidden_states, residual=None): res = hidden_states return normed_hidden_states, res - + def load_attention(config, prefix, weights, layer_id): base_layer = load_attention_multi(config, prefix, weights) head_size = config.hidden_size // config.num_attention_heads return TensorParallelMultiAdapterLinear.load( - base_layer, layer_id, [Q_PROJ, K_PROJ, V_PROJ], sizes=[ + base_layer, + layer_id, + [Q_PROJ, 
K_PROJ, V_PROJ], + sizes=[ head_size * config.num_attention_heads, head_size * config.num_key_value_heads, head_size * config.num_key_value_heads, - ], process_group=weights.process_group + ], + process_group=weights.process_group, ) @@ -197,9 +208,7 @@ def _load_gqa(config, prefix: str, weights): config.hidden_size, ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" - return TensorParallelColumnLinear( - get_linear(weight, bias=None, quantize=config.quantize) - ) + return TensorParallelColumnLinear(get_linear(weight, bias=None, quantize=config.quantize)) class FlashLlamaAttention(torch.nn.Module): @@ -231,18 +240,21 @@ def __init__( f"and `num_shards`: {weights.process_group.size()}" ) self.num_heads = self.num_heads // weights.process_group.size() - self.num_key_value_heads = ( - config.num_key_value_heads // weights.process_group.size() - ) + self.num_key_value_heads = config.num_key_value_heads // weights.process_group.size() self.query_key_value = load_attention(config, prefix, weights, layer_id) - self.o_proj = TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.o_proj", - weights=weights, - bias=False, - ), layer_id, O_PROJ, process_group=weights.process_group) + self.o_proj = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ), + layer_id, + O_PROJ, + process_group=weights.process_group, + ) self.num_groups = self.num_heads // self.num_key_value_heads self.kv_head_mapping = torch.arange( 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device @@ -250,9 +262,9 @@ def __init__( def get_query_key_value_weights(self, clone=True): """Gets the query, key, and value weights from the attention layer. - + If `clone`, then the weights are cloned before being returned. - + NOTE: if not `clone`, then the weights are returned as views, meaning that changes to the weights will be reflected in the attention layer. 
""" @@ -296,9 +308,7 @@ def forward( self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) - paged_attn.reshape_and_cache( - kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots - ) + paged_attn.reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(query) @@ -342,9 +352,7 @@ def __init__(self, prefix, config, weights, layer_id): if "gelu" not in act else lambda x: torch.nn.functional.gelu( x, - approximate="tanh" - if act in ["gelu_fast", "gelu_pytorch_tanh"] - else "none", + approximate="tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none", ) ) # Fuse gate and up proj @@ -356,21 +364,28 @@ def __init__(self, prefix, config, weights, layer_id): bias=False, ) self.gate_up_proj = TensorParallelMultiAdapterLinear.load( - gate_up_proj, layer_id, [GATE_PROJ, UP_PROJ], sizes=[ + gate_up_proj, + layer_id, + [GATE_PROJ, UP_PROJ], + sizes=[ config.intermediate_size, config.intermediate_size, - ], process_group=weights.process_group + ], + process_group=weights.process_group, ) - self.down_proj = TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.down_proj", - weights=weights, - bias=False, - ), layer_id, DOWN_PROJ, process_group=weights.process_group) - self.intermediate_size = ( - config.intermediate_size // weights.process_group.size() + self.down_proj = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ), + layer_id, + DOWN_PROJ, + process_group=weights.process_group, ) + self.intermediate_size = config.intermediate_size // weights.process_group.size() def forward(self, hidden_states, adapter_data): gate_up_states = self.gate_up_proj(hidden_states, adapter_data) @@ -383,9 +398,14 @@ def __init__(self, layer_id, config, weights): super().__init__() prefix = f"model.layers.{layer_id}" self.self_attn = FlashLlamaAttention( - prefix=f"{prefix}.self_attn", config=config, weights=weights, layer_id=layer_id, + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + layer_id=layer_id, + ) + self.mlp = LlamaMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id ) - self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id) self.input_layernorm = LlamaRMSNorm( prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps @@ -427,9 +447,7 @@ def forward( ) # faster post attention rms norm - normed_attn_res_output, attn_res = self.post_attention_layernorm( - attn_output, res - ) + normed_attn_res_output, attn_res = self.post_attention_layernorm(attn_output, res) mlp_output = self.mlp(normed_attn_res_output, adapter_data) @@ -443,9 +461,7 @@ def __init__(self, config, weights): process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - self.embed_tokens = TensorParallelEmbedding( - prefix="model.embed_tokens", weights=weights - ) + self.embed_tokens = TensorParallelEmbedding(prefix="model.embed_tokens", weights=weights) self.layers = nn.ModuleList( [ FlashLlamaLayer( @@ -456,9 +472,7 @@ def __init__(self, config, weights): for layer_id in range(config.num_hidden_layers) ] ) - self.norm = LlamaRMSNorm( - prefix="model.norm", weights=weights, eps=config.rms_norm_eps - ) + self.norm = LlamaRMSNorm(prefix="model.norm", weights=weights, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -512,11 +526,16 @@ 
def __init__(self, config, weights): super().__init__() self.model = FlashLlamaModel(config, weights) - self.lm_head = MultiAdapterHead.load(TensorParallelHead.load( - config, - prefix="lm_head", - weights=weights, - ), 0, LM_HEAD, process_group=weights.process_group) + self.lm_head = MultiAdapterHead.load( + TensorParallelHead.load( + config, + prefix="lm_head", + weights=weights, + ), + 0, + LM_HEAD, + process_group=weights.process_group, + ) def forward( self, diff --git a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py index 464c697f0..2b9f4d5e5 100644 --- a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py @@ -44,7 +44,16 @@ TensorParallelHead, get_linear, ) -from lorax_server.utils.lora import DOWN_PROJ, GATE_PROJ, K_PROJ, LM_HEAD, O_PROJ, Q_PROJ, UP_PROJ, V_PROJ +from lorax_server.utils.lora import ( + DOWN_PROJ, + GATE_PROJ, + K_PROJ, + LM_HEAD, + O_PROJ, + Q_PROJ, + UP_PROJ, + V_PROJ, +) if not HAS_FLASH_ATTN_V2: raise ImportError("Mistral model requires flash attn v2") @@ -123,9 +132,7 @@ def forward(self, hidden_states, residual=None): hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt( - variance + self.variance_epsilon - ) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) # convert into half-precision if necessary if self.weight.dtype in [torch.float16, torch.bfloat16]: @@ -155,17 +162,21 @@ def forward(self, hidden_states, residual=None): res = hidden_states return normed_hidden_states, res - + def load_attention(config, prefix, weights, layer_id): base_layer = load_attention_multi(config, prefix, weights) head_size = config.hidden_size // config.num_attention_heads return TensorParallelMultiAdapterLinear.load( - base_layer, layer_id, [Q_PROJ, K_PROJ, V_PROJ], sizes=[ + base_layer, + layer_id, + [Q_PROJ, K_PROJ, V_PROJ], + sizes=[ head_size * config.num_attention_heads, head_size * config.num_key_value_heads, head_size * config.num_key_value_heads, - ], process_group=weights.process_group + ], + process_group=weights.process_group, ) @@ -203,9 +214,7 @@ def _load_gqa(config, prefix: str, weights): config.hidden_size, ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" - return TensorParallelColumnLinear( - get_linear(weight, bias=None, quantize=config.quantize) - ) + return TensorParallelColumnLinear(get_linear(weight, bias=None, quantize=config.quantize)) class MistralAttention(torch.nn.Module): @@ -217,9 +226,7 @@ def __init__( layer_id: int, ): super().__init__() - self.max_past = ( - config.sliding_window if config.sliding_window is not None else -1 - ) + self.max_past = config.sliding_window if config.sliding_window is not None else -1 self.num_heads = config.num_attention_heads self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_heads @@ -240,18 +247,21 @@ def __init__( f"and `num_shards`: {weights.process_group.size()}" ) self.num_heads = self.num_heads // weights.process_group.size() - self.num_key_value_heads = ( - config.num_key_value_heads // weights.process_group.size() - ) + self.num_key_value_heads = config.num_key_value_heads // weights.process_group.size() self.query_key_value = load_attention(config, prefix, weights, layer_id) - self.o_proj = 
TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.o_proj", - weights=weights, - bias=False, - ), layer_id, O_PROJ, process_group=weights.process_group) + self.o_proj = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ), + layer_id, + O_PROJ, + process_group=weights.process_group, + ) self.num_groups = self.num_heads // self.num_key_value_heads self.kv_head_mapping = torch.arange( 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device @@ -259,9 +269,9 @@ def __init__( def get_query_key_value_weights(self, clone=True): """Gets the query, key, and value weights from the attention layer. - + If `clone`, then the weights are cloned before being returned. - + NOTE: if not `clone`, then the weights are returned as views, meaning that changes to the weights will be reflected in the attention layer. """ @@ -358,9 +368,7 @@ def __init__(self, prefix, config, weights, layer_id): if "gelu" not in act else lambda x: torch.nn.functional.gelu( x, - approximate="tanh" - if act in ["gelu_fast", "gelu_pytorch_tanh"] - else "none", + approximate="tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none", ) ) # Fuse gate and up proj @@ -372,21 +380,28 @@ def __init__(self, prefix, config, weights, layer_id): bias=False, ) self.gate_up_proj = TensorParallelMultiAdapterLinear.load( - gate_up_proj, layer_id, [GATE_PROJ, UP_PROJ], sizes=[ + gate_up_proj, + layer_id, + [GATE_PROJ, UP_PROJ], + sizes=[ config.intermediate_size, config.intermediate_size, - ], process_group=weights.process_group + ], + process_group=weights.process_group, ) - self.down_proj = TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.down_proj", - weights=weights, - bias=False, - ), layer_id, DOWN_PROJ, process_group=weights.process_group) - self.intermediate_size = ( - config.intermediate_size // weights.process_group.size() + self.down_proj = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ), + layer_id, + DOWN_PROJ, + process_group=weights.process_group, ) + self.intermediate_size = config.intermediate_size // weights.process_group.size() def forward(self, hidden_states, adapter_data): gate_up_states = self.gate_up_proj(hidden_states, adapter_data) @@ -399,9 +414,14 @@ def __init__(self, layer_id, config, weights): super().__init__() prefix = f"model.layers.{layer_id}" self.self_attn = MistralAttention( - prefix=f"{prefix}.self_attn", config=config, weights=weights, layer_id=layer_id, + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + layer_id=layer_id, + ) + self.mlp = MistralMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id ) - self.mlp = MistralMLP(prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id) self.input_layernorm = MistralRMSNorm( prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps @@ -445,9 +465,7 @@ def forward( ) # faster post attention rms norm - normed_attn_res_output, attn_res = self.post_attention_layernorm( - attn_output, res - ) + normed_attn_res_output, attn_res = self.post_attention_layernorm(attn_output, res) mlp_output = self.mlp(normed_attn_res_output, adapter_data) @@ -461,9 +479,7 @@ def __init__(self, config, weights): process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = 
process_group.size() - self.embed_tokens = TensorParallelEmbedding( - prefix="model.embed_tokens", weights=weights - ) + self.embed_tokens = TensorParallelEmbedding(prefix="model.embed_tokens", weights=weights) self.layers = nn.ModuleList( [ MistralLayer( @@ -474,9 +490,7 @@ def __init__(self, config, weights): for layer_id in range(config.num_hidden_layers) ] ) - self.norm = MistralRMSNorm( - prefix="model.norm", weights=weights, eps=config.rms_norm_eps - ) + self.norm = MistralRMSNorm(prefix="model.norm", weights=weights, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -532,11 +546,16 @@ def __init__(self, config, weights): super().__init__() self.model = MistralModel(config, weights) - self.lm_head = MultiAdapterHead.load(TensorParallelHead.load( - config, - prefix="lm_head", - weights=weights, - ), 0, LM_HEAD, process_group=weights.process_group) + self.lm_head = MultiAdapterHead.load( + TensorParallelHead.load( + config, + prefix="lm_head", + weights=weights, + ), + 0, + LM_HEAD, + process_group=weights.process_group, + ) self.max_past = config.sliding_window diff --git a/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py index 013d32102..f5e4b215d 100644 --- a/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py @@ -76,28 +76,28 @@ class MixtralConfig(PretrainedConfig): model_type = "mixtral" def __init__( - self, - vocab_size=32000, - hidden_size=4096, - intermediate_size=14336, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=8, - hidden_act="silu", - max_position_embeddings=4096 * 32, - initializer_range=0.02, - rms_norm_eps=1e-05, - use_cache=True, - pad_token_id=None, - bos_token_id=1, - eos_token_id=2, - pretraining_tp=1, - tie_word_embeddings=False, - rope_theta=10000.0, - sliding_window=4096, - num_experts_per_tok=2, - num_local_experts=8, - **kwargs, + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + sliding_window=4096, + num_experts_per_tok=2, + num_local_experts=8, + **kwargs, ): self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings @@ -138,11 +138,15 @@ def load_attention(config, prefix, weights, layer_id): base_layer = load_attention_multi(config, prefix, weights) head_size = config.hidden_size // config.num_attention_heads return TensorParallelMultiAdapterLinear.load( - base_layer, layer_id, [ATTN_Q_PROJ, ATTN_K_PROJ, ATTN_V_PROJ], sizes=[ + base_layer, + layer_id, + [ATTN_Q_PROJ, ATTN_K_PROJ, ATTN_V_PROJ], + sizes=[ head_size * config.num_attention_heads, head_size * config.num_key_value_heads, head_size * config.num_key_value_heads, - ], process_group=weights.process_group + ], + process_group=weights.process_group, ) @@ -180,9 +184,7 @@ def _load_gqa(config, prefix: str, weights): config.hidden_size, ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" - return TensorParallelColumnLinear( - get_linear(weight, bias=None, quantize=config.quantize) - ) + return TensorParallelColumnLinear(get_linear(weight, bias=None, 
quantize=config.quantize)) def _load_experts(config, prefix, mat, weights): @@ -195,16 +197,18 @@ def _load_experts(config, prefix, mat, weights): rank = weights.process_group.rank() assert ( - config.intermediate_size % world_size == 0 + config.intermediate_size % world_size == 0 ), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards" block_size = config.intermediate_size // world_size start = rank * block_size stop = (rank + 1) * block_size - tensor = torch.empty((config.num_local_experts * block_size, config.hidden_size), - dtype=weights.dtype, - device=weights.device) + tensor = torch.empty( + (config.num_local_experts * block_size, config.hidden_size), + dtype=weights.dtype, + device=weights.device, + ) for i in range(config.num_local_experts): slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight") @@ -213,7 +217,9 @@ def _load_experts(config, prefix, mat, weights): expert_slice = slice_[:, start:stop].t().contiguous() else: expert_slice = slice_[start:stop] - tensor[i * block_size:(i + 1) * block_size] = expert_slice.to(dtype=weights.dtype).to(device=weights.device) + tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(dtype=weights.dtype).to( + device=weights.device + ) return tensor @@ -236,9 +242,7 @@ def forward(self, hidden_states, residual=None): hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt( - variance + self.variance_epsilon - ) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) # convert into half-precision if necessary if self.weight.dtype in [torch.float16, torch.bfloat16]: @@ -269,6 +273,7 @@ def forward(self, hidden_states, residual=None): return normed_hidden_states, res + class MixtralAttention(torch.nn.Module): """ MixtralAttention module performs attention computation for the Mixtral model. 
@@ -304,9 +309,7 @@ def __init__( layer_id: int, ): super().__init__() - self.max_past = ( - config.sliding_window if config.sliding_window is not None else 0 - ) + self.max_past = config.sliding_window if config.sliding_window is not None else 0 self.num_heads = config.num_attention_heads self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_heads @@ -319,7 +322,7 @@ def __init__( dtype=weights.dtype, ) - self.softmax_scale = self.head_size ** -0.5 + self.softmax_scale = self.head_size**-0.5 if self.num_heads % weights.process_group.size() != 0: raise ValueError( @@ -327,36 +330,39 @@ def __init__( f"and `num_shards`: {weights.process_group.size()}" ) self.num_heads = self.num_heads // weights.process_group.size() - self.num_key_value_heads = ( - config.num_key_value_heads // weights.process_group.size() - ) + self.num_key_value_heads = config.num_key_value_heads // weights.process_group.size() self.query_key_value = load_attention(config, prefix, weights, layer_id) - self.o_proj = TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.o_proj", - weights=weights, - bias=False, - ), layer_id, ATTN_O_PROJ, process_group=weights.process_group) + self.o_proj = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ), + layer_id, + ATTN_O_PROJ, + process_group=weights.process_group, + ) self.num_groups = self.num_heads // self.num_key_value_heads self.kv_head_mapping = torch.arange( 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device ).repeat_interleave(self.num_groups) def forward( - self, - hidden_states, - cos, - sin, - cu_seqlen_prefill, - kv_cache, - block_tables, - slots, - input_lengths, - max_s, - adapter_data, - prefill_cache_indices, + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + adapter_data, + prefill_cache_indices, ): """ Performs forward pass of the attention module. @@ -476,9 +482,7 @@ def __init__(self, prefix, config: MixtralConfig, weights): if "gelu" in act: self.act = lambda x: torch.nn.functional.gelu( x, - approximate="tanh" - if act in ["gelu_fast", "gelu_pytorch_tanh"] - else "none", + approximate="tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none", ) elif "silu" in act: self.act = torch.nn.functional.silu @@ -530,8 +534,7 @@ def topology(self, x: torch.Tensor, padded_bins: torch.Tensor): # Indices for the sparse matrix. The indices for # the intermediate matrix are dynamic depending # on the mapping of tokens to experts. - column_indices = ops.topology(padded_bins, self.blocking, block_rows, - blocks_per_row) + column_indices = ops.topology(padded_bins, self.blocking, block_rows, blocks_per_row) # For now, use meta init to save the device memory. data = torch.empty( @@ -575,8 +578,7 @@ def indices_and_padded_bins(self, selected_experts: torch.Tensor): # position of each bin. # List of size num_experts - padded_tokens_per_expert = round_up(tokens_per_expert, - self.blocking) + padded_tokens_per_expert = round_up(tokens_per_expert, self.blocking) # padded_tokens_per_expert => [128, O, 128, ...] 
# Cumulative selected experts per token @@ -615,8 +617,7 @@ def sparse_forward(self, x: torch.Tensor) -> torch.Tensor: # Permute tokens and pad to prepare expert computation # (top_k * sequence_length + padding, model_dim) - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, - self.top_k) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.top_k) # Create the sparse matrix topology with torch.no_grad(): @@ -658,7 +659,7 @@ def sparse_forward(self, x: torch.Tensor) -> torch.Tensor: torch.distributed.all_reduce(x, group=self.process_group) return x.view(*input_shape) - + def dense_forward(self, x: torch.Tensor) -> torch.Tensor: """ x: (sequence_length, model_dim) @@ -691,18 +692,12 @@ def dense_forward(self, x: torch.Tensor) -> torch.Tensor: x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1]) # Permute to [num_experts, model_dim, ffn_dim] - w1 = self.w1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute( - 0, 2, 1 - ) - w3 = self.w3.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute( - 0, 2, 1 - ) + w1 = self.w1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(0, 2, 1) + w3 = self.w3.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(0, 2, 1) inter = self.act(torch.bmm(x, w1)) * torch.bmm(x, w3) - out = torch.bmm( - inter, self.w2.view(self.num_experts, self.ffn_dim, self.hidden_dim) - ) + out = torch.bmm(inter, self.w2.view(self.num_experts, self.ffn_dim, self.hidden_dim)) # Mask not selected experts out *= weights.t().view(self.num_experts, -1, 1) @@ -734,9 +729,7 @@ def __init__(self, prefix, config: MixtralConfig, weights): if "gelu" in act: self.act = lambda x: torch.nn.functional.gelu( x, - approximate="tanh" - if act in ["gelu_fast", "gelu_pytorch_tanh"] - else "none", + approximate="tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none", ) elif "silu" in act: self.act = torch.nn.functional.silu @@ -760,7 +753,11 @@ def __init__(self, prefix, config: MixtralConfig, weights): ] self.w2 = [ TensorParallelRowLinear.load( - config, prefix=f"{prefix}.experts.{i}.w2", weights=weights, bias=False, all_reduce=False, + config, + prefix=f"{prefix}.experts.{i}.w2", + weights=weights, + bias=False, + all_reduce=False, ) for i in range(self.num_experts) ] @@ -831,19 +828,19 @@ def __init__(self, layer_id, config, weights): ) def forward( - self, - hidden_states, - residual, - cos, - sin, - cu_seqlen_prefill, - kv_cache, - block_tables, - slots, - input_lengths, - max_s, - adapter_data, - prefill_cache_indices, + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + adapter_data, + prefill_cache_indices, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -863,9 +860,7 @@ def forward( ) # faster post attention rms norm - normed_attn_res_output, attn_res = self.post_attention_layernorm( - attn_output, res - ) + normed_attn_res_output, attn_res = self.post_attention_layernorm(attn_output, res) moe_output = self.moe(normed_attn_res_output) @@ -876,9 +871,7 @@ class MixtralModel(torch.nn.Module): def __init__(self, config, weights): super().__init__() - self.embed_tokens = TensorParallelEmbedding( - prefix="model.embed_tokens", weights=weights - ) + self.embed_tokens = TensorParallelEmbedding(prefix="model.embed_tokens", weights=weights) self.layers = nn.ModuleList( [ @@ -890,26 +883,24 @@ def __init__(self, config, weights): for layer_id in range(config.num_hidden_layers) ] ) - 
self.norm = MixtralRMSNorm( - prefix="model.norm", weights=weights, eps=config.rms_norm_eps - ) + self.norm = MixtralRMSNorm(prefix="model.norm", weights=weights, eps=config.rms_norm_eps) self.head_size = self.layers[0].self_attn.head_size self.num_heads = self.layers[0].self_attn.num_heads self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads def forward( - self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - cu_seqlen_prefill: Optional[torch.Tensor], - kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, - slots: torch.Tensor, - input_lengths: torch.Tensor, - max_s: int, - adapter_data: AdapterBatchData, - prefill_cache_indices: Optional[torch.Tensor], + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + adapter_data: AdapterBatchData, + prefill_cache_indices: Optional[torch.Tensor], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -946,28 +937,33 @@ def __init__(self, config, weights): super().__init__() self.model = MixtralModel(config, weights) - self.lm_head = MultiAdapterHead.load(TensorParallelHead.load( - config, - prefix="lm_head", - weights=weights, - ), 0, LM_HEAD, process_group=weights.process_group) + self.lm_head = MultiAdapterHead.load( + TensorParallelHead.load( + config, + prefix="lm_head", + weights=weights, + ), + 0, + LM_HEAD, + process_group=weights.process_group, + ) self.max_past = config.sliding_window if self.max_past is None: raise ValueError("max_past cannot be None") def forward( - self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - cu_seqlen_prefill: Optional[torch.Tensor], - kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, - slots: torch.Tensor, - input_lengths: torch.Tensor, - max_s: int, - adapter_data: AdapterBatchData, - prefill_cache_indices: Optional[torch.Tensor] = None, - lm_head_indices: Optional[torch.Tensor] = None, + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + adapter_data: AdapterBatchData, + prefill_cache_indices: Optional[torch.Tensor] = None, + lm_head_indices: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: if prefill_cache_indices is not None: # Slots also need to be sliced as it has the same size as the whole kv tensor @@ -993,4 +989,4 @@ def forward( if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits, speculative_logits = self.lm_head(hidden_states, adapter_data) - return logits, speculative_logits \ No newline at end of file + return logits, speculative_logits diff --git a/server/lorax_server/models/custom_modeling/flash_neox_modeling.py b/server/lorax_server/models/custom_modeling/flash_neox_modeling.py index cf0d91873..d0c269f24 100644 --- a/server/lorax_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_neox_modeling.py @@ -112,9 +112,7 @@ def __init__(self, config, prefix, weights): head_size=self.head_size, hidden_size=self.hidden_size, ) - self.dense = load_row( - config, prefix=f"{prefix}.dense", weights=weights, bias=True - ) + self.dense = load_row(config, 
prefix=f"{prefix}.dense", weights=weights, bias=True) self.kv_head_mapping = torch.arange( 0, self.num_heads, dtype=torch.int32, device=weights.device ) @@ -138,9 +136,7 @@ def forward( self.rotary_emb(qkv[:, 0], cos, sin) self.rotary_emb(qkv[:, 1], cos, sin) - paged_attn.reshape_and_cache( - qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots - ) + paged_attn.reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(qkv[:, 0]) @@ -184,9 +180,7 @@ def __init__(self, config, prefix, weights): if "gelu" not in act else lambda x: torch.nn.functional.gelu( x, - approximate="tanh" - if act in ["gelu_fast", "gelu_pytorch_tanh"] - else "none", + approximate="tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none", ) ) @@ -221,9 +215,7 @@ def __init__(self, layer_id, config, weights): weights=weights, eps=layer_norm_eps, ) - self.attention = FlashNeoxAttention( - config, prefix=f"{prefix}.attention", weights=weights - ) + self.attention = FlashNeoxAttention(config, prefix=f"{prefix}.attention", weights=weights) self.mlp = FlashMLP(config, prefix=f"{prefix}.mlp", weights=weights) self.process_group = weights.process_group @@ -280,9 +272,7 @@ def forward( max_s, ) - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual - ) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) mlp_output = self.mlp(hidden_states) @@ -301,9 +291,7 @@ def __init__(self, config, weights): super().__init__(config) self.config = config - self.embed_in = TensorParallelEmbedding( - prefix="gpt_neox.embed_in", weights=weights - ) + self.embed_in = TensorParallelEmbedding(prefix="gpt_neox.embed_in", weights=weights) self.layers = nn.ModuleList( [ @@ -366,9 +354,7 @@ def __init__(self, config, weights): super().__init__(config) self.gpt_neox = FlashGPTNeoXModel(config, weights) - self.embed_out = TensorParallelHead.load( - config, prefix="embed_out", weights=weights - ) + self.embed_out = TensorParallelHead.load(config, prefix="embed_out", weights=weights) def forward( self, diff --git a/server/lorax_server/models/custom_modeling/flash_phi_modeling.py b/server/lorax_server/models/custom_modeling/flash_phi_modeling.py index 361a5e8fe..da1694b3e 100644 --- a/server/lorax_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_phi_modeling.py @@ -40,16 +40,20 @@ ATTN_DENSE = "self_attn.dense" MLP_FC1 = "mlp.fc1" MLP_FC2 = "mlp.fc2" - + def load_attention(config, prefix, weights, layer_id, head_dim, n_head, n_head_kv): base_layer = load_attention_multi(config, prefix, weights, head_dim, n_head, n_head_kv) return TensorParallelMultiAdapterLinear.load( - base_layer, layer_id, [ATTN_Q_PROJ, ATTN_K_PROJ, ATTN_V_PROJ], sizes=[ + base_layer, + layer_id, + [ATTN_Q_PROJ, ATTN_K_PROJ, ATTN_V_PROJ], + sizes=[ head_dim * n_head, head_dim * n_head_kv, head_dim * n_head_kv, - ], process_group=weights.process_group + ], + process_group=weights.process_group, ) @@ -80,7 +84,9 @@ def __init__( rope_theta = 10000 config.max_position_embeddings = getattr(config, "n_positions", 2048) - rotary_dim = int(config.partial_rotary_factor * (config.hidden_size // config.num_attention_heads)) + rotary_dim = int( + config.partial_rotary_factor * (config.hidden_size // config.num_attention_heads) + ) self.rotary_emb = PositionRotaryEmbedding.static( config=config, dim=rotary_dim, @@ -98,13 +104,26 @@ def __init__( ) self.num_key_value_heads = getattr(config, "n_head_kv", None) or 
self.num_heads - self.qkv_proj = load_attention(config, prefix, weights, layer_id, self.head_size, self.num_heads, self.num_key_value_heads) - self.dense = TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load( + self.qkv_proj = load_attention( config, - prefix=f"{prefix}.dense", - weights=weights, - bias=True, - ), layer_id, ATTN_DENSE, process_group=weights.process_group) + prefix, + weights, + layer_id, + self.head_size, + self.num_heads, + self.num_key_value_heads, + ) + self.dense = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.dense", + weights=weights, + bias=True, + ), + layer_id, + ATTN_DENSE, + process_group=weights.process_group, + ) # After initializing layers, scale num heads by num shards for use in forward() to split outputs self.num_heads = self.num_heads // weights.process_group.size() @@ -143,9 +162,7 @@ def forward( self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) - paged_attn.reshape_and_cache( - kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots - ) + paged_attn.reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(query) @@ -189,9 +206,7 @@ def __init__(self, prefix, config, weights, layer_id): if "gelu" not in act else lambda x: torch.nn.functional.gelu( x, - approximate="tanh" - if act in ["gelu_fast", "gelu_pytorch_tanh"] - else "none", + approximate="tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none", ) ) @@ -207,13 +222,18 @@ def __init__(self, prefix, config, weights, layer_id): self.fc1 = TensorParallelMultiAdapterLinear.load( fc1, layer_id, [MLP_FC1], sizes=[out_size], process_group=weights.process_group ) - self.fc2 = TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.fc2", - weights=weights, - bias=True, - ), layer_id, MLP_FC2, process_group=weights.process_group) - + self.fc2 = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.fc2", + weights=weights, + bias=True, + ), + layer_id, + MLP_FC2, + process_group=weights.process_group, + ) + def forward(self, hidden_states, adapter_data): hidden_states = self.fc1(hidden_states, adapter_data) hidden_states = self.act(hidden_states) @@ -225,12 +245,15 @@ class FlashPhiLayer(nn.Module): def __init__(self, layer_id, config, weights): super().__init__() prefix = f"model.layers.{layer_id}" - + self.input_layernorm = FastLayerNorm.load( prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.layer_norm_eps ) self.self_attn = FlashPhiAttention( - prefix=f"{prefix}.self_attn", config=config, weights=weights, layer_id=layer_id, + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + layer_id=layer_id, ) self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id) self.process_group = weights.process_group @@ -280,9 +303,7 @@ def __init__(self, config, weights): process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - self.embed_tokens = TensorParallelEmbedding( - prefix="model.embed_tokens", weights=weights - ) + self.embed_tokens = TensorParallelEmbedding(prefix="model.embed_tokens", weights=weights) self.layers = nn.ModuleList( [ FlashPhiLayer( @@ -338,7 +359,7 @@ def forward( max_s, adapter_data, ) - + hidden_states, _ = self.final_layernorm(hidden_states) return hidden_states @@ -348,11 +369,16 @@ def __init__(self, config, 
weights): super().__init__() self.model = FlashPhiModel(config, weights) - self.lm_head = MultiAdapterHead.load(TensorParallelHead.load( - config, - prefix="lm_head", - weights=weights, - ), 0, LM_HEAD, process_group=weights.process_group) + self.lm_head = MultiAdapterHead.load( + TensorParallelHead.load( + config, + prefix="lm_head", + weights=weights, + ), + 0, + LM_HEAD, + process_group=weights.process_group, + ) def forward( self, diff --git a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py index 672e853f4..b3ba8479a 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py @@ -60,9 +60,7 @@ def forward(self, hidden_states, residual=None): hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt( - variance + self.variance_epsilon - ) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) # convert into half-precision if necessary if self.weight.dtype in [torch.float16, torch.bfloat16]: @@ -98,13 +96,17 @@ def load_attention(config, prefix, weights, layer_id): base_layer = load_attention_multi(config, prefix, weights) head_size = config.hidden_size // config.num_attention_heads return TensorParallelMultiAdapterLinear.load( - base_layer, layer_id, [ATTN_Q_PROJ, ATTN_K_PROJ, ATTN_V_PROJ], sizes=[ + base_layer, + layer_id, + [ATTN_Q_PROJ, ATTN_K_PROJ, ATTN_V_PROJ], + sizes=[ head_size * config.num_attention_heads, head_size * config.num_key_value_heads, head_size * config.num_key_value_heads, - ], process_group=weights.process_group + ], + process_group=weights.process_group, ) - + def load_attention_multi(config, prefix, weights): if config.num_attention_heads != config.num_key_value_heads: @@ -140,9 +142,7 @@ def _load_gqa(config, prefix: str, weights): config.hidden_size, ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" - return TensorParallelColumnLinear( - get_linear(weight, bias=True, quantize=config.quantize) - ) + return TensorParallelColumnLinear(get_linear(weight, bias=True, quantize=config.quantize)) class FlashQwen2Attention(torch.nn.Module): @@ -155,9 +155,7 @@ def __init__( ): super().__init__() - self.max_past = ( - config.sliding_window if config.sliding_window is not None else -1 - ) + self.max_past = config.sliding_window if config.sliding_window is not None else -1 self.num_heads = config.num_attention_heads self.hidden_size = config.hidden_size @@ -180,18 +178,21 @@ def __init__( f"and `num_shards`: {weights.process_group.size()}" ) self.num_heads = self.num_heads // weights.process_group.size() - self.num_key_value_heads = ( - config.num_key_value_heads // weights.process_group.size() - ) + self.num_key_value_heads = config.num_key_value_heads // weights.process_group.size() self.query_key_value = load_attention(config, prefix, weights, layer_id) - self.o_proj = TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.o_proj", - weights=weights, - bias=False, - ), layer_id, ATTN_O_PROJ, process_group=weights.process_group) + self.o_proj = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ), + layer_id, + ATTN_O_PROJ, + process_group=weights.process_group, + ) self.num_groups = 
self.num_heads // self.num_key_value_heads self.kv_head_mapping = torch.arange( 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device @@ -277,9 +278,7 @@ def __init__(self, prefix, config, weights, layer_id): if "gelu" not in act else lambda x: torch.nn.functional.gelu( x, - approximate="tanh" - if act in ["gelu_fast", "gelu_pytorch_tanh"] - else "none", + approximate="tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none", ) ) # Fuse gate and up proj @@ -291,21 +290,28 @@ def __init__(self, prefix, config, weights, layer_id): bias=False, ) self.gate_up_proj = TensorParallelMultiAdapterLinear.load( - gate_up_proj, layer_id, [MLP_GATE_PROJ, MLP_UP_PROJ], sizes=[ + gate_up_proj, + layer_id, + [MLP_GATE_PROJ, MLP_UP_PROJ], + sizes=[ config.intermediate_size // 2, config.intermediate_size // 2, - ], process_group=weights.process_group + ], + process_group=weights.process_group, ) - self.down_proj = TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.down_proj", - weights=weights, - bias=False, - ), layer_id, MLP_DOWN_PROJ, process_group=weights.process_group) - self.intermediate_size = ( - config.intermediate_size // weights.process_group.size() + self.down_proj = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ), + layer_id, + MLP_DOWN_PROJ, + process_group=weights.process_group, ) + self.intermediate_size = config.intermediate_size // weights.process_group.size() def forward(self, hidden_states, adapter_data): gate_up_states = self.gate_up_proj(hidden_states, adapter_data) @@ -318,9 +324,14 @@ def __init__(self, layer_id, config, weights): super().__init__() prefix = f"model.layers.{layer_id}" self.self_attn = FlashQwen2Attention( - prefix=f"{prefix}.self_attn", config=config, weights=weights, layer_id=layer_id, + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + layer_id=layer_id, + ) + self.mlp = Qwen2MLP( + prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id ) - self.mlp = Qwen2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id) self.input_layernorm = Qwen2RMSNorm( prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps @@ -364,9 +375,7 @@ def forward( ) # faster post attention rms norm - normed_attn_res_output, attn_res = self.post_attention_layernorm( - attn_output, res - ) + normed_attn_res_output, attn_res = self.post_attention_layernorm(attn_output, res) mlp_output = self.mlp(normed_attn_res_output, adapter_data) @@ -380,9 +389,7 @@ def __init__(self, config, weights): process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - self.embed_tokens = TensorParallelEmbedding( - prefix="model.embed_tokens", weights=weights - ) + self.embed_tokens = TensorParallelEmbedding(prefix="model.embed_tokens", weights=weights) self.layers = nn.ModuleList( [ FlashQwen2Layer( @@ -393,9 +400,7 @@ def __init__(self, config, weights): for layer_id in range(config.num_hidden_layers) ] ) - self.norm = Qwen2RMSNorm( - prefix="model.norm", weights=weights, eps=config.rms_norm_eps - ) + self.norm = Qwen2RMSNorm(prefix="model.norm", weights=weights, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -451,11 +456,16 @@ def __init__(self, config, weights): super().__init__() self.model = FlashQwen2Model(config, weights) - self.lm_head = MultiAdapterHead.load(TensorParallelHead.load( - config, 
- prefix="lm_head", - weights=weights, - ), 0, LM_HEAD, process_group=weights.process_group) + self.lm_head = MultiAdapterHead.load( + TensorParallelHead.load( + config, + prefix="lm_head", + weights=weights, + ), + 0, + LM_HEAD, + process_group=weights.process_group, + ) self.max_past = config.sliding_window @@ -481,7 +491,7 @@ def forward( # kernel requires the true values max_s = min(self.max_past, max_s) input_lengths = torch.clamp(input_lengths, max=self.max_past) - + hidden_states = self.model( input_ids, position_ids, diff --git a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py index 785aab4a3..616a71e76 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py @@ -109,9 +109,7 @@ def forward(self, hidden_states, residual=None): hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt( - variance + self.variance_epsilon - ) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) # convert into half-precision if necessary if self.weight.dtype in [torch.float16, torch.bfloat16]: @@ -141,15 +139,21 @@ def forward(self, hidden_states, residual=None): res = hidden_states return normed_hidden_states, res - + def load_attention(config, prefix, weights, layer_id): - projection_size = (config.hidden_size // config.num_attention_heads) * config.num_attention_heads + projection_size = ( + config.hidden_size // config.num_attention_heads + ) * config.num_attention_heads base_layer = load_attention_multi(config, prefix, weights, projection_size) return TensorParallelMultiAdapterLinear.load( - base_layer, layer_id, [ATTN_C_ATTN], sizes=[ + base_layer, + layer_id, + [ATTN_C_ATTN], + sizes=[ 3 * projection_size, - ], process_group=weights.process_group + ], + process_group=weights.process_group, ) @@ -179,7 +183,9 @@ def __init__( self.num_heads = config.num_attention_heads self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_heads - self.projection_size = (self.head_size * config.num_attention_heads) // weights.process_group.size() + self.projection_size = ( + self.head_size * config.num_attention_heads + ) // weights.process_group.size() self.process_group = weights.process_group self.rotary_emb = PositionRotaryEmbedding.static( @@ -202,12 +208,17 @@ def __init__( self.c_attn = load_attention(config, prefix, weights, layer_id) - self.c_proj = TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.c_proj", - weights=weights, - bias=False, - ), layer_id, ATTN_C_PROJ, process_group=weights.process_group) + self.c_proj = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.c_proj", + weights=weights, + bias=False, + ), + layer_id, + ATTN_C_PROJ, + process_group=weights.process_group, + ) self.num_groups = self.num_heads // self.num_key_value_heads self.kv_head_mapping = torch.arange( 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device @@ -240,9 +251,7 @@ def forward( self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) - paged_attn.reshape_and_cache( - kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots - ) + paged_attn.reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) # output tensor attn_output = 
torch.empty_like(query) @@ -286,9 +295,7 @@ def __init__(self, prefix, config, weights, layer_id): if "gelu" not in act else lambda x: torch.nn.functional.gelu( x, - approximate="tanh" - if act in ["gelu_fast", "gelu_pytorch_tanh"] - else "none", + approximate="tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none", ) ) # Fuse gate and up proj @@ -300,21 +307,28 @@ def __init__(self, prefix, config, weights, layer_id): bias=False, ) self.gate_up_proj = TensorParallelMultiAdapterLinear.load( - gate_up_proj, layer_id, [MLP_W2, MLP_W1], sizes=[ + gate_up_proj, + layer_id, + [MLP_W2, MLP_W1], + sizes=[ config.intermediate_size // 2, config.intermediate_size // 2, - ], process_group=weights.process_group + ], + process_group=weights.process_group, ) - self.c_proj = TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.c_proj", - weights=weights, - bias=False, - ), layer_id, MLP_C_PROJ, process_group=weights.process_group) - self.intermediate_size = ( - config.intermediate_size // weights.process_group.size() + self.c_proj = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.c_proj", + weights=weights, + bias=False, + ), + layer_id, + MLP_C_PROJ, + process_group=weights.process_group, ) + self.intermediate_size = config.intermediate_size // weights.process_group.size() def forward(self, hidden_states, adapter_data): gate_up_states = self.gate_up_proj(hidden_states, adapter_data) @@ -327,9 +341,14 @@ def __init__(self, layer_id, config, weights): super().__init__() prefix = f"transformer.h.{layer_id}" self.attn = FlashQwenAttention( - prefix=f"{prefix}.attn", config=config, weights=weights, layer_id=layer_id, + prefix=f"{prefix}.attn", + config=config, + weights=weights, + layer_id=layer_id, + ) + self.mlp = QwenMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id ) - self.mlp = QwenMLP(prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id) self.ln_1 = QwenRMSNorm( prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon @@ -371,9 +390,7 @@ def forward( ) # faster post attention rms norm - normed_attn_res_output, attn_res = self.ln_2( - attn_output, res - ) + normed_attn_res_output, attn_res = self.ln_2(attn_output, res) mlp_output = self.mlp(normed_attn_res_output, adapter_data) @@ -387,9 +404,7 @@ def __init__(self, config, weights): process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - self.wte = TensorParallelEmbedding( - prefix="transformer.wte", weights=weights - ) + self.wte = TensorParallelEmbedding(prefix="transformer.wte", weights=weights) self.h = nn.ModuleList( [ FlashQwenLayer( @@ -426,9 +441,7 @@ def forward( # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.h[0].attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.h[0].attn.rotary_emb.get_cos_sin(position_ids, max_s, hidden_states.dtype) residual = None for i, layer in enumerate(self.h): @@ -456,11 +469,16 @@ def __init__(self, config, weights): super().__init__() self.transformer = FlashQwenModel(config, weights) - self.lm_head = MultiAdapterHead.load(TensorParallelHead.load( - config, - prefix="lm_head", - weights=weights, - ), 0, LM_HEAD, process_group=weights.process_group) + self.lm_head = MultiAdapterHead.load( + TensorParallelHead.load( + config, + prefix="lm_head", + weights=weights, + ), + 0, + LM_HEAD, + 
process_group=weights.process_group, + ) def forward( self, diff --git a/server/lorax_server/models/custom_modeling/flash_rw_modeling.py b/server/lorax_server/models/custom_modeling/flash_rw_modeling.py index 865e666fe..aede9e8cd 100644 --- a/server/lorax_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_rw_modeling.py @@ -64,9 +64,7 @@ def __init__( **kwargs, ): if alibi: - raise NotImplementedError( - "alibi is not supported by this version of the model" - ) + raise NotImplementedError("alibi is not supported by this version of the model") self.model_type = model_type self.alibi = False @@ -77,14 +75,10 @@ def __init__( n_embed = kwargs.pop("n_embed", None) self.hidden_size = hidden_size if n_embed is None else n_embed self.n_layer = ( - num_hidden_layers - if num_hidden_layers is not None - else kwargs.pop("n_layer", 2) + num_hidden_layers if num_hidden_layers is not None else kwargs.pop("n_layer", 2) ) self.n_head = ( - num_attention_heads - if num_attention_heads is not None - else kwargs.pop("n_head", 8) + num_attention_heads if num_attention_heads is not None else kwargs.pop("n_head", 8) ) self.layer_norm_epsilon = layer_norm_epsilon self.initializer_range = initializer_range @@ -147,9 +141,7 @@ def __init__( weights=weights, bias=config.bias, ) - self.dense = load_row( - config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias - ) + self.dense = load_row(config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias) if self.num_heads_kv == 1: self.kv_head_mapping = torch.zeros( @@ -188,9 +180,7 @@ def forward( self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) - paged_attn.reshape_and_cache( - kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots - ) + paged_attn.reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) # output attn_output = torch.empty_like(query) @@ -267,9 +257,7 @@ def __init__( weights=weights, bias=config.bias, ) - self.dense = load_row( - config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias - ) + self.dense = load_row(config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias) self.kv_head_mapping = torch.arange( 0, self.num_groups, dtype=torch.int32, device=weights.device @@ -340,9 +328,7 @@ def forward( max_s, ) - return self.dense( - attn_output.view(-1, self.num_groups * self.num_heads * self.head_size) - ) + return self.dense(attn_output.view(-1, self.num_groups * self.num_heads * self.head_size)) class FlashMLP(nn.Module): @@ -456,9 +442,7 @@ def forward( max_s, ) - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual - ) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) mlp_output = self.mlp(hidden_states) @@ -614,9 +598,7 @@ def __init__(self, config, weights): self.transformer = FlashRWModel(config, weights) - self.lm_head = TensorParallelHead.load( - config, prefix="lm_head", weights=weights - ) + self.lm_head = TensorParallelHead.load(config, prefix="lm_head", weights=weights) def forward( self, diff --git a/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py b/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py index 70b9e6128..232d58270 100644 --- a/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py @@ -17,17 +17,13 @@ ) -def load_multi_mqa( - config, prefix: str, weights, bias: bool, head_size, 
num_heads, hidden_size -): +def load_multi_mqa(config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size): if config.quantize in ["gptq", "awq", "eetq"]: return _load_multi_mqa_gptq( config, prefix, weights, bias, head_size, num_heads, hidden_size ) else: - return _load_multi_mqa( - config, prefix, weights, bias, head_size, num_heads, hidden_size - ) + return _load_multi_mqa(config, prefix, weights, bias, head_size, num_heads, hidden_size) def _load_multi_mqa_gptq( @@ -97,9 +93,7 @@ def _load_multi_mqa_gptq( raise NotImplementedError("Gptq loading with santacoder is not implemented") -def _load_multi_mqa( - config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size -): +def _load_multi_mqa(config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size): if any("c_attn" in k for k in weights.routing.keys()): slice_ = weights._get_slice(f"{prefix}.c_attn.weight") shape = slice_.get_shape() @@ -171,9 +165,7 @@ def load_col(config, prefix: str, weights, bias: bool): if config.transpose: weight = weights.get_sharded(f"{prefix}.weight", dim=1).T else: - weight = weights.get_multi_weights_col( - [prefix], quantize=config.quantize, dim=0 - ) + weight = weights.get_multi_weights_col([prefix], quantize=config.quantize, dim=0) if bias: bias = weights.get_sharded(f"{prefix}.bias", dim=0) @@ -226,12 +218,8 @@ def __init__(self, prefix, config, weights): hidden_size=hidden_size, num_heads=self.num_heads, ) - self.c_proj = load_row( - config, prefix=f"{prefix}.c_proj", weights=weights, bias=True - ) - self.kv_head_mapping = torch.zeros( - self.num_heads, dtype=torch.int32, device=weights.device - ) + self.c_proj = load_row(config, prefix=f"{prefix}.c_proj", weights=weights, bias=True) + self.kv_head_mapping = torch.zeros(self.num_heads, dtype=torch.int32, device=weights.device) def forward( self, @@ -246,9 +234,7 @@ def forward( qkv = self.c_attn(hidden_states) # Split query from key_value - query, key_value = qkv.split( - [self.head_size * self.num_heads, 2 * self.head_size], dim=1 - ) + query, key_value = qkv.split([self.head_size * self.num_heads, 2 * self.head_size], dim=1) # Prepare query and key_value for indexing query = query.view(-1, self.num_heads, self.head_size) @@ -300,18 +286,12 @@ def __init__(self, prefix, config, weights): if "gelu" not in act else lambda x: torch.nn.functional.gelu( x, - approximate="tanh" - if act in ["gelu_fast", "gelu_pytorch_tanh"] - else "none", + approximate="tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none", ) ) - self.c_fc = load_col( - config, prefix=f"{prefix}.c_fc", weights=weights, bias=True - ) - self.c_proj = load_row( - config, prefix=f"{prefix}.c_proj", weights=weights, bias=True - ) + self.c_fc = load_col(config, prefix=f"{prefix}.c_fc", weights=weights, bias=True) + self.c_proj = load_row(config, prefix=f"{prefix}.c_proj", weights=weights, bias=True) def forward(self, hidden_states): hidden_states = self.c_fc(hidden_states) @@ -442,9 +422,7 @@ class FlashSantacoderForCausalLM(nn.Module): def __init__(self, config, weights): super().__init__() self.transformer = FlashSantacoderModel(config, weights) - self.lm_head = TensorParallelHead.load( - config, prefix="transformer.wte", weights=weights - ) + self.lm_head = TensorParallelHead.load(config, prefix="transformer.wte", weights=weights) def forward( self, diff --git a/server/lorax_server/models/custom_modeling/mpt_modeling.py b/server/lorax_server/models/custom_modeling/mpt_modeling.py index 7d1ea1f1c..bb1a421fc 100644 --- 
a/server/lorax_server/models/custom_modeling/mpt_modeling.py +++ b/server/lorax_server/models/custom_modeling/mpt_modeling.py @@ -50,9 +50,7 @@ def load_col(config, prefix, weights, bias): return TensorParallelColumnLinear(linear) -def _reset_is_causal( - num_query_tokens: int, num_key_tokens: int, original_is_causal: bool -): +def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool): if original_is_causal and num_query_tokens != num_key_tokens: if num_query_tokens != 1: raise NotImplementedError( @@ -113,9 +111,7 @@ def scaled_multihead_dot_product_attention( + "into attn_bias once and passing that to each attention " + "module instead." ) - attn_weight = attn_weight.masked_fill( - ~key_padding_mask.view((b, 1, 1, s_k)), min_val - ) + attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val) if is_causal and (not q.size(2) == 1): s = max(s_q, s_k) causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16) @@ -143,9 +139,7 @@ def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]): f"tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}." ) if not tensor.is_cuda: - raise TypeError( - f"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r})." - ) + raise TypeError(f"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).") def flash_attn_fn( @@ -187,21 +181,13 @@ def flash_attn_fn( query, query_padding_mask ) query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=n_heads) - (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input( - key, key_padding_mask - ) - key_unpad = rearrange( - key_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads - ) + (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(key, key_padding_mask) + key_unpad = rearrange(key_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads) (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask) - value_unpad = rearrange( - value_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads - ) + value_unpad = rearrange(value_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads) if multiquery: key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1)) - value_unpad = value_unpad.expand( - value_unpad.size(0), n_heads, value_unpad.size(-1) - ) + value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1)) dropout_p = dropout_p if training else 0.0 reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) output_unpad = flash_attn_interface.flash_attn_unpadded_func( @@ -287,9 +273,7 @@ def triton_flash_attn_fn( key = key.expand(*key.shape[:2], n_heads, key.size(-1)) value = value.expand(*value.shape[:2], n_heads, value.size(-1)) reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) - attn_output = flash_attn_func( - query, key, value, attn_bias, reset_is_causal, softmax_scale - ) + attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale) output = attn_output.view(*attn_output.shape[:2], -1) return (output, None, past_key_value) @@ -451,9 +435,7 @@ def forward( qkv = self.Wqkv(x) if self.clip_qkv: qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) - (query, key, value) = qkv.split( - [self.d_model, self.head_dim, self.head_dim], dim=2 - ) + (query, key, value) = qkv.split([self.d_model, self.head_dim, self.head_dim], dim=2) key_padding_mask = attention_mask if self.qk_ln: dtype = query.dtype @@ -477,9 +459,7 @@ def 
forward( return (self.out_proj(context), attn_weights, past_key_value) -def attn_bias_shape( - attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id -): +def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id): if attn_impl == "flash": return None elif attn_impl in ["torch", "triton"]: @@ -527,9 +507,7 @@ def gen_slopes(n_heads, alibi_bias_max=8, device=None): return slopes.view(1, n_heads, 1, 1) -def build_alibi_bias( - n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None -): +def build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None): alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view( 1, 1, 1, seq_len ) @@ -577,16 +555,10 @@ def __init__(self, config, prefix, weights): super().__init__() self.prefix = prefix if config.attn_config["attn_type"] != "multihead_attention": - raise NotImplementedError( - f"""Not implemented attn {config.attn_config["attn_type"]}""" - ) + raise NotImplementedError(f"""Not implemented attn {config.attn_config["attn_type"]}""") resid_pdrop = config.resid_pdrop - self.norm_1 = nn.LayerNorm.load_no_bias( - prefix=f"{prefix}.norm_1", weights=weights, eps=EPS - ) - self.norm_2 = nn.LayerNorm.load_no_bias( - prefix=f"{prefix}.norm_2", weights=weights, eps=EPS - ) + self.norm_1 = nn.LayerNorm.load_no_bias(prefix=f"{prefix}.norm_1", weights=weights, eps=EPS) + self.norm_2 = nn.LayerNorm.load_no_bias(prefix=f"{prefix}.norm_2", weights=weights, eps=EPS) self.attn = MultiheadAttention(config, prefix=f"{prefix}.attn", weights=weights) self.ffn = MPTMLP(config, prefix=f"{prefix}.ffn", weights=weights) self.resid_attn_dropout = nn.Dropout(resid_pdrop) @@ -648,13 +620,9 @@ def forward(self, x): module_device = x.device downcast_x = _cast_if_autocast_enabled(x) downcast_weight = ( - _cast_if_autocast_enabled(self.weight) - if self.weight is not None - else self.weight - ) - downcast_bias = ( - _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias + _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight ) + downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias with torch.autocast(enabled=False, device_type=module_device.type): return torch.nn.functional.layer_norm( downcast_x, @@ -673,9 +641,7 @@ def rms_norm(x, weight=None, eps=1e-05): class RMSNorm(torch.nn.Module): - def __init__( - self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None - ): + def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None): super().__init__() self.eps = eps if weight: @@ -690,9 +656,7 @@ def forward(self, x): class LPRMSNorm(RMSNorm): - def __init__( - self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None - ): + def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None): super().__init__( normalized_shape=normalized_shape, eps=eps, @@ -704,9 +668,7 @@ def __init__( def forward(self, x): downcast_x = _cast_if_autocast_enabled(x) downcast_weight = ( - _cast_if_autocast_enabled(self.weight) - if self.weight is not None - else self.weight + _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight ) with torch.autocast(enabled=False, device_type=x.device.type): return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype) @@ -806,9 +768,7 @@ def _attn_bias( ): if not self._attn_bias_initialized: if self.attn_bias_shape: - self.attn_bias = torch.zeros( - 
self.attn_bias_shape, device=device, dtype=dtype - ) + self.attn_bias = torch.zeros(self.attn_bias_shape, device=device, dtype=dtype) self.attn_bias = build_attn_bias( self.attn_impl, self.attn_bias, @@ -849,9 +809,7 @@ def _attn_bias( + f"and prefix_mask shape={prefix_mask.shape} are not equal." ) min_val = torch.finfo(attn_bias.dtype).min - attn_bias = attn_bias.masked_fill( - ~attention_mask.view(-1, 1, 1, s_k), min_val - ) + attn_bias = attn_bias.masked_fill(~attention_mask.view(-1, 1, 1, s_k), min_val) return (attn_bias, None) def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor): @@ -877,9 +835,7 @@ def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor) attn_bias = attn_bias.masked_fill(cannot_attend, min_val) return attn_bias - def _apply_sequence_id( - self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor - ): + def _apply_sequence_id(self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor): seq_len = sequence_id.shape[-1] if seq_len > self.config.max_seq_len: raise ValueError( @@ -905,18 +861,14 @@ def forward( output_hidden_states: Optional[bool] = None, use_cache: Optional[bool] = None, ): - return_dict = ( - return_dict if return_dict is not None else self.config.return_dict - ) + return_dict = return_dict if return_dict is not None else self.config.return_dict use_cache = use_cache if use_cache is not None else self.config.use_cache if attention_mask is not None: attention_mask = attention_mask.bool() if prefix_mask is not None: prefix_mask = prefix_mask.bool() if not return_dict: - raise NotImplementedError( - "return_dict False is not implemented yet for MPT" - ) + raise NotImplementedError("return_dict False is not implemented yet for MPT") if output_attentions: if self.attn_impl != "torch": raise NotImplementedError( @@ -927,9 +879,7 @@ def forward( and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training ): - raise NotImplementedError( - "MPT does not support training with left padding." - ) + raise NotImplementedError("MPT does not support training with left padding.") if self.prefix_lm and prefix_mask is None: raise ValueError( "prefix_mask is a required argument when MPT is configured with prefix_lm=True." 
@@ -975,10 +925,7 @@ def forward( ).unsqueeze(0) if attention_mask is not None: pos = torch.clamp( - pos - - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[ - :, past_position: - ], + pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0, ) pos_emb = self.wpe(pos) @@ -998,9 +945,7 @@ def forward( if output_hidden_states: assert all_hidden_states is not None all_hidden_states = all_hidden_states + (x,) - past_key_value = ( - past_key_values[b_idx] if past_key_values is not None else None - ) + past_key_value = past_key_values[b_idx] if past_key_values is not None else None (x, attn_weights, past_key_value) = block( x, past_key_value=past_key_value, @@ -1031,9 +976,7 @@ def __init__(self, config, weights): if not config.tie_word_embeddings: raise ValueError("MPTForCausalLM only supports tied word embeddings") self.transformer = MPTModel(config, weights) - self.lm_head = TensorParallelHead.load( - config, prefix="transformer.wte", weights=weights - ) + self.lm_head = TensorParallelHead.load(config, prefix="transformer.wte", weights=weights) self.logit_scale = None if config.logit_scale is not None: logit_scale = config.logit_scale @@ -1059,9 +1002,7 @@ def forward( output_hidden_states: Optional[bool] = None, use_cache: Optional[bool] = None, ): - return_dict = ( - return_dict if return_dict is not None else self.config.return_dict - ) + return_dict = return_dict if return_dict is not None else self.config.return_dict use_cache = use_cache if use_cache is not None else self.config.use_cache outputs = self.transformer( input_ids=input_ids, @@ -1103,9 +1044,7 @@ def prepare_inputs_for_generation( raise NotImplementedError("inputs_embeds is not implemented for MPT yet") attention_mask = kwargs["attention_mask"].bool() if attention_mask[:, -1].sum() != attention_mask.shape[0]: - raise NotImplementedError( - "MPT does not support generation with right padding." 
- ) + raise NotImplementedError("MPT does not support generation with right padding.") if self.transformer.attn_uses_sequence_id and self.training: sequence_id = torch.zeros_like(input_ids[:1]) else: @@ -1139,8 +1078,6 @@ def _reorder_cache(past_key_values, beam_idx): reordered_past = [] for layer_past in past_key_values: reordered_past += [ - tuple( - (past_state.index_select(0, beam_idx) for past_state in layer_past) - ) + tuple((past_state.index_select(0, beam_idx) for past_state in layer_past)) ] return reordered_past diff --git a/server/lorax_server/models/custom_modeling/neox_modeling.py b/server/lorax_server/models/custom_modeling/neox_modeling.py index 1ce018ddd..df9cae6f9 100644 --- a/server/lorax_server/models/custom_modeling/neox_modeling.py +++ b/server/lorax_server/models/custom_modeling/neox_modeling.py @@ -147,9 +147,7 @@ def __init__(self, config, prefix, weights): config.max_position_embeddings, base=config.rotary_emb_base, ) - self.rotary_emb.inv_freq = nn.Parameter( - weights.get_tensor(f"{prefix}.rotary_emb.inv_freq") - ) + self.rotary_emb.inv_freq = nn.Parameter(weights.get_tensor(f"{prefix}.rotary_emb.inv_freq")) self.inv_norm_factor = 1.0 / torch.sqrt( torch.tensor(self.head_size, dtype=torch.float32) ).to(torch.get_default_dtype()) @@ -160,9 +158,7 @@ def __init__(self, config, prefix, weights): f"(got `num_attention_heads`: {self.num_attention_heads} " f"and `num_shards`: {weights.process_group.size()}" ) - self.num_attention_heads = ( - self.num_attention_heads // weights.process_group.size() - ) + self.num_attention_heads = self.num_attention_heads // weights.process_group.size() self.query_key_value = TensorParallelColumnLinear.load( config, prefix=f"{prefix}.query_key_value", weights=weights, bias=True ) @@ -230,14 +226,10 @@ def forward( present = (key, value) if use_cache else None # Compute attention - attn_output, attn_weights = self._attn( - query, key, value, attention_mask, head_mask - ) + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) # Reshape outputs - attn_output = self._merge_heads( - attn_output, self.num_attention_heads, self.head_size - ) + attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size) attn_output = self.dense(attn_output) @@ -268,9 +260,7 @@ def _merge_heads(cls, tensor, num_attention_heads, attn_head_size): # tensor [bs, num_attention_heads, seq_len, attn_head_size] tensor = tensor.permute(0, 2, 1, 3).contiguous() # -> [bs, seq_len, num_attention_heads, attn_head_size] - tensor = tensor.view( - tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size - ) + tensor = tensor.view(tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size) # -> [bs, seq_len, hidden_size] return tensor @@ -280,9 +270,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): batch_size, num_attention_heads, query_length, attn_head_size = query.size() key_length = key.size(-2) - query = query.view( - batch_size * num_attention_heads, query_length, attn_head_size - ) + query = query.view(batch_size * num_attention_heads, query_length, attn_head_size) key = key.view(batch_size * num_attention_heads, key_length, attn_head_size) attn_scores = torch.zeros( 1, @@ -301,12 +289,8 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): input_dtype = attn_scores.dtype if input_dtype in [torch.float16, torch.bfloat16]: attn_scores = attn_scores.to(torch.float) - attn_scores = torch.where( - attention_mask, 
torch.finfo(attn_scores.dtype).min, attn_scores - ) - attn_scores = attn_scores.view( - batch_size, num_attention_heads, query_length, key_length - ) + attn_scores = torch.where(attention_mask, torch.finfo(attn_scores.dtype).min, attn_scores) + attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length) attn_weights = nn.functional.softmax(attn_scores, dim=-1) attn_weights = attn_weights.to(value.dtype) @@ -322,9 +306,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): class RotaryEmbedding(torch.nn.Module): def __init__(self, dim, max_position_embeddings, base=10000, device=None): super().__init__() - self.true_inv_freq = 1.0 / ( - base ** (torch.arange(0, dim, 2).float().to(device) / dim) - ) + self.true_inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) self.register_buffer("inv_freq", self.true_inv_freq) # Build here to make `torch.jit.trace` work. @@ -341,9 +323,7 @@ def rotate_half(x): @staticmethod def _create_cos_sin(inv_freq, max_position_embeddings, dtype, device): - t = torch.arange( - max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype - ) + t = torch.arange(max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype) freqs = torch.einsum("i,j->ij", t, inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) @@ -351,11 +331,7 @@ def _create_cos_sin(inv_freq, max_position_embeddings, dtype, device): def forward(self, q, k, position_ids, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] - if ( - seq_len > self.max_seq_len_cached - or self.cos_cached is None - or self.sin_cached is None - ): + if seq_len > self.max_seq_len_cached or self.cos_cached is None or self.sin_cached is None: if seq_len > self.max_seq_len_cached: self.max_seq_len_cached = seq_len self.cos_cached, self.sin_cached = self._create_cos_sin( @@ -420,9 +396,7 @@ def __init__(self, layer_id, config, weights): self.attention = GPTNeoXAttention( config, prefix=f"gpt_neox.layers.{layer_id}.attention", weights=weights ) - self.mlp = GPTNeoXMLP( - config, prefix=f"gpt_neox.layers.{layer_id}.mlp", weights=weights - ) + self.mlp = GPTNeoXMLP(config, prefix=f"gpt_neox.layers.{layer_id}.mlp", weights=weights) def forward( self, @@ -462,9 +436,7 @@ def forward( hidden_states = mlp_output + attn_output if use_cache: - outputs = ( - hidden_states, - ) + outputs # hidden_states, present, (attn_weights) + outputs = (hidden_states,) + outputs # hidden_states, present, (attn_weights) else: outputs = (hidden_states,) + outputs[1:] # hidden_states, (attn_weights) @@ -478,9 +450,7 @@ def __init__(self, config, weights): self.num_attention_heads = config.num_attention_heads - self.embed_in = TensorParallelEmbedding( - prefix="gpt_neox.embed_in", weights=weights - ) + self.embed_in = TensorParallelEmbedding(prefix="gpt_neox.embed_in", weights=weights) self.layers = nn.ModuleList( [ GPTNeoXLayer(layer_id, config, weights) @@ -518,24 +488,18 @@ def forward( `past_key_values`). 
""" output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions + output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict use_cache = use_cache if use_cache is not None else self.config.use_cache if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time" - ) + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.size() elif inputs_embeds is not None: @@ -643,9 +607,7 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel): def __init__(self, config, weights): super().__init__(config) self.gpt_neox = GPTNeoXModel(config, weights) - self.embed_out = TensorParallelHead.load( - config, prefix="embed_out", weights=weights - ) + self.embed_out = TensorParallelHead.load(config, prefix="embed_out", weights=weights) def forward( self, @@ -700,9 +662,7 @@ def forward( >>> prediction_logits = outputs.logits ```""" - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.gpt_neox( input_ids, @@ -728,9 +688,7 @@ def forward( shift_logits = lm_logits[:, :-1, :].contiguous() labels = labels[:, 1:].contiguous() loss_fct = CrossEntropyLoss() - lm_loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1) - ) + lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)) if not return_dict: output = (lm_logits,) + outputs[1:] @@ -790,10 +748,7 @@ def _reorder_cache(self, past_key_values, beam_idx): reordered_past = () for layer_past in past_key_values: reordered_past += ( - tuple( - past_state.index_select(0, beam_idx) - for past_state in layer_past[:2] - ) + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], ) return reordered_past diff --git a/server/lorax_server/models/custom_modeling/opt_modeling.py b/server/lorax_server/models/custom_modeling/opt_modeling.py index d14294230..961d6dd0e 100644 --- a/server/lorax_server/models/custom_modeling/opt_modeling.py +++ b/server/lorax_server/models/custom_modeling/opt_modeling.py @@ -60,16 +60,12 @@ def _make_causal_mask( if past_key_values_length > 0: mask = torch.cat( [ - torch.zeros( - tgt_len, past_key_values_length, dtype=dtype, device=device - ), + torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask, ], dim=-1, ) - return mask[None, None, :, :].expand( - bsz, 1, tgt_len, tgt_len + past_key_values_length - ) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): @@ -83,9 +79,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] inverted_mask = 1.0 - expanded_mask - return inverted_mask.masked_fill( - inverted_mask.to(torch.bool), torch.finfo(dtype).min - ) + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) class OPTLearnedPositionalEmbedding(nn.Module): @@ -96,13 
+90,9 @@ class OPTLearnedPositionalEmbedding(nn.Module): def __init__(self, weights): super().__init__() self.offset = 2 - self.weight = nn.Parameter( - weights.get_tensor("model.decoder.embed_positions.weight") - ) + self.weight = nn.Parameter(weights.get_tensor("model.decoder.embed_positions.weight")) - def forward( - self, attention_mask: torch.LongTensor, past_key_values_length: int = 0 - ): + def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0): """`input_ids_shape` is expected to be [bsz x seqlen].""" attention_mask = attention_mask.long() @@ -169,11 +159,7 @@ def __init__( ) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return ( - tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - .transpose(1, 2) - .contiguous() - ) + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() def forward( self, @@ -243,10 +229,7 @@ def forward( raise ValueError( f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" ) - attn_weights = ( - attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - + attention_mask - ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = torch.max( attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) ) @@ -254,9 +237,9 @@ def forward( # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437 if attn_weights.dtype == torch.float16: - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32 - ).to(torch.float16) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + torch.float16 + ) else: attn_weights = nn.functional.softmax(attn_weights, dim=-1) @@ -276,18 +259,12 @@ def forward( # make sure that attn_weights keeps its gradient. 
# In order to do so, attn_weights have to be reshaped # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view( - bsz, self.num_heads, tgt_len, src_len - ) - attn_weights = attn_weights_reshaped.view( - bsz * self.num_heads, tgt_len, src_len - ) + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) else: attn_weights_reshaped = None - attn_probs = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) attn_output = torch.bmm(attn_probs, value_states) @@ -347,9 +324,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, past_key_value: Optional[Tuple[torch.Tensor]] = None, - ) -> Tuple[ - torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] - ]: + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)` @@ -380,9 +355,7 @@ def forward( layer_head_mask=layer_head_mask, output_attentions=output_attentions, ) - hidden_states = nn.functional.dropout( - hidden_states, p=self.dropout, training=self.training - ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # 350m applies layer norm AFTER attention @@ -402,9 +375,7 @@ def forward( hidden_states = self.activation_fn(hidden_states) hidden_states = self.fc2(hidden_states) - hidden_states = nn.functional.dropout( - hidden_states, p=self.dropout, training=self.training - ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = (residual + hidden_states).view(hidden_states_shape) @@ -449,9 +420,7 @@ def __init__(self, config: OPTConfig, weights): self.project_out = None if config.word_embed_proj_dim != config.hidden_size: - self.project_in = FastLinear.load( - config, prefix="model.decoder.project_in", bias=False - ) + self.project_in = FastLinear.load(config, prefix="model.decoder.project_in", bias=False) else: self.project_in = None @@ -560,9 +529,7 @@ def forward( Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
""" output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions + output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states @@ -571,9 +538,7 @@ def forward( ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: @@ -602,9 +567,7 @@ def forward( # embed positions if attention_mask is None: - attention_mask = torch.ones( - batch_size, mask_seq_length, device=inputs_embeds.device - ) + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) causal_attention_mask = self._prepare_decoder_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) @@ -638,9 +601,7 @@ def forward( if self.training and (dropout_probability < self.layerdrop): continue - past_key_value = ( - past_key_values[idx] if past_key_values is not None else None - ) + past_key_value = past_key_values[idx] if past_key_values is not None else None layer_outputs = decoder_layer( hidden_states, @@ -703,9 +664,7 @@ def forward( return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions + output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states @@ -713,9 +672,7 @@ def forward( else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) decoder_outputs = self.decoder( @@ -765,18 +722,14 @@ def forward( return_dict: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions + output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model.decoder( @@ -834,8 +787,6 @@ def _reorder_cache(past_key_values, beam_idx): reordered_past = () for layer_past in past_key_values: reordered_past += ( - tuple( - past_state.index_select(0, beam_idx) for past_state in layer_past - ), + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past), ) return reordered_past diff --git a/server/lorax_server/models/custom_modeling/t5_modeling.py b/server/lorax_server/models/custom_modeling/t5_modeling.py index ebef36b6b..a4368f198 100644 --- a/server/lorax_server/models/custom_modeling/t5_modeling.py +++ b/server/lorax_server/models/custom_modeling/t5_modeling.py @@ -92,9 +92,7 @@ def forward(self, 
hidden_states): T5LayerNorm = FusedRMSNorm # noqa - logger.info( - "Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm" - ) + logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm") except ImportError: # using the normal T5LayerNorm pass @@ -219,9 +217,7 @@ def forward(self, hidden_states): class T5Attention(nn.Module): - def __init__( - self, config: T5Config, prefix, weights, has_relative_attention_bias=False - ): + def __init__(self, config: T5Config, prefix, weights, has_relative_attention_bias=False): super().__init__() self.is_decoder = config.is_decoder self.has_relative_attention_bias = has_relative_attention_bias @@ -291,9 +287,7 @@ def _relative_position_bucket( relative_buckets += (relative_position > 0).to(torch.long) * num_buckets relative_position = torch.abs(relative_position) else: - relative_position = -torch.min( - relative_position, torch.zeros_like(relative_position) - ) + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) # now relative_position is in the range [0, inf) # half of the buckets are for exact increments in positions @@ -311,24 +305,16 @@ def _relative_position_bucket( torch.full_like(relative_position_if_large, num_buckets - 1), ) - relative_buckets += torch.where( - is_small, relative_position, relative_position_if_large - ) + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) return relative_buckets def compute_bias(self, query_length, key_length, device=None): """Compute binned relative position bias""" if device is None: device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[ - :, None - ] - memory_position = torch.arange(key_length, dtype=torch.long, device=device)[ - None, : - ] - relative_position = ( - memory_position - context_position - ) # shape (query_length, key_length) + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( relative_position, # shape (query_length, key_length) bidirectional=(not self.is_decoder), @@ -370,25 +356,19 @@ def forward( assert ( len(past_key_value) == 2 ), f"past_key_value should have 2 past states: keys and values. 
Got {len(past_key_value)} past states" - real_seq_length += ( - past_key_value[0].shape[2] if query_length is None else query_length - ) + real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length - key_length = ( - real_seq_length if key_value_states is None else key_value_states.shape[1] - ) + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] def shape(states): """projection""" - return states.view( - batch_size, -1, self.n_heads, self.key_value_proj_dim - ).transpose(1, 2) + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose( + 1, 2 + ) def unshape(states): """reshape""" - return ( - states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) - ) + return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) def project(hidden_states, proj_layer, key_value_states, past_key_value): """projects hidden states correctly to key/query states""" @@ -449,9 +429,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): dtype=scores.dtype, ) else: - position_bias = self.compute_bias( - real_seq_length, key_length, device=scores.device - ) + position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) # if key and values are already calculated # we want only the last query position bias @@ -529,9 +507,7 @@ def forward( output_attentions=output_attentions, ) hidden_states = hidden_states + self.dropout(attention_output[0]) - outputs = (hidden_states,) + attention_output[ - 1: - ] # add attentions if we output them + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them return outputs @@ -576,9 +552,7 @@ def forward( output_attentions=output_attentions, ) layer_output = hidden_states + self.dropout(attention_output[0]) - outputs = (layer_output,) + attention_output[ - 1: - ] # add attentions if we output them + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them return outputs @@ -598,16 +572,12 @@ def __init__(self, config, prefix, weights, has_relative_attention_bias: bool): if self.is_decoder: i = 2 self.layer.append( - T5LayerCrossAttention( - config, prefix=f"{prefix}.layer.1", weights=weights - ) + T5LayerCrossAttention(config, prefix=f"{prefix}.layer.1", weights=weights) ) else: i = 1 - self.layer.append( - T5LayerFF(config, prefix=f"{prefix}.layer.{i}", weights=weights) - ) + self.layer.append(T5LayerFF(config, prefix=f"{prefix}.layer.{i}", weights=weights)) def forward( self, @@ -664,9 +634,7 @@ def forward( torch.finfo(hidden_states.dtype).max - 1000, torch.finfo(hidden_states.dtype).max, ) - hidden_states = torch.clamp( - hidden_states, min=-clamp_value, max=clamp_value - ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) do_cross_attention = self.is_decoder and encoder_hidden_states is not None if do_cross_attention: @@ -697,15 +665,11 @@ def forward( torch.finfo(hidden_states.dtype).max - 1000, torch.finfo(hidden_states.dtype).max, ) - hidden_states = torch.clamp( - hidden_states, min=-clamp_value, max=clamp_value - ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) # Combine self attn and cross attn key value states if present_key_value_state is not None: - present_key_value_state = ( - present_key_value_state + cross_attention_outputs[1] - ) + present_key_value_state = present_key_value_state + cross_attention_outputs[1] # Keep cross-attention outputs and relative position weights 
attention_outputs = attention_outputs + cross_attention_outputs[2:] @@ -720,9 +684,7 @@ def forward( torch.finfo(hidden_states.dtype).max - 1000, torch.finfo(hidden_states.dtype).max, ) - hidden_states = torch.clamp( - hidden_states, min=-clamp_value, max=clamp_value - ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) outputs = (hidden_states,) @@ -754,20 +716,14 @@ def _shift_right(self, input_ids): # shift inputs to the right if is_torch_fx_proxy(input_ids): # Item assignment is not supported natively for proxies. - shifted_input_ids = torch.full( - input_ids.shape[:-1] + (1,), decoder_start_token_id - ) - shifted_input_ids = torch.cat( - [shifted_input_ids, input_ids[..., :-1]], dim=-1 - ) + shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id) + shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) else: shifted_input_ids = input_ids.new_zeros(input_ids.shape) shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() shifted_input_ids[..., 0] = decoder_start_token_id - assert ( - pad_token_id is not None - ), "self.model.config.pad_token_id has to be defined." + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." # replace possible -100 values in labels by `pad_token_id` shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) @@ -817,18 +773,14 @@ def forward( # Model parallel use_cache = use_cache if use_cache is not None else self.config.use_cache output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions + output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_ids is not None and inputs_embeds is not None: err_msg_prefix = "decoder_" if self.is_decoder else "" @@ -867,14 +819,8 @@ def forward( ), f"`use_cache` can only be set to `True` if {self} is used as a decoder" if attention_mask is None: - attention_mask = torch.ones( - batch_size, mask_seq_length, device=inputs_embeds.device - ) - if ( - self.is_decoder - and encoder_attention_mask is None - and encoder_hidden_states is not None - ): + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: encoder_seq_length = encoder_hidden_states.shape[1] encoder_attention_mask = torch.ones( batch_size, @@ -889,9 +835,7 @@ def forward( # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. 
- extended_attention_mask = self.get_extended_attention_mask( - attention_mask, input_shape - ) + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -906,17 +850,13 @@ def forward( encoder_attention_mask = torch.ones( encoder_hidden_shape, device=inputs_embeds.device ) - encoder_extended_attention_mask = self.invert_attention_mask( - encoder_attention_mask - ) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = None # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) - cross_attn_head_mask = self.get_head_mask( - cross_attn_head_mask, self.config.num_layers - ) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -926,9 +866,7 @@ def forward( hidden_states = self.dropout(inputs_embeds) - for i, (layer_module, past_key_value) in enumerate( - zip(self.block, past_key_values) - ): + for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): layer_head_mask = head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i] # Model parallel @@ -961,14 +899,10 @@ def forward( # (cross-attention position bias), (cross-attention weights) position_bias = layer_outputs[2] if self.is_decoder and encoder_hidden_states is not None: - encoder_decoder_position_bias = layer_outputs[ - 4 if output_attentions else 3 - ] + encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] # append next layer key value states if use_cache: - present_key_value_states = present_key_value_states + ( - present_key_value_state, - ) + present_key_value_states = present_key_value_states + (present_key_value_state,) if output_attentions: all_attentions = all_attentions + (layer_outputs[3],) @@ -1032,9 +966,7 @@ def __init__(self, config: T5Config, weights): embed_tokens=self.shared, ) - self.lm_head = TensorParallelHead.load( - config, prefix="lm_head", weights=weights - ) + self.lm_head = TensorParallelHead.load(config, prefix="lm_head", weights=weights) def forward( self, @@ -1056,9 +988,7 @@ def forward( return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask if head_mask is not None and decoder_head_mask is None: @@ -1087,11 +1017,7 @@ def forward( hidden_states = encoder_outputs[0] - if ( - labels is not None - and decoder_input_ids is None - and decoder_inputs_embeds is None - ): + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: # get decoder inputs from shifting lm labels to the right decoder_input_ids = self._shift_right(labels) @@ -1193,15 +1119,11 @@ def _reorder_cache(self, past_key_values, beam_idx): for layer_past_state in layer_past_states: # need to set correct `past` for each of the four key / value states reordered_layer_past_states = 
reordered_layer_past_states + ( - layer_past_state.index_select( - 0, beam_idx.to(layer_past_state.device) - ), + layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), ) assert reordered_layer_past_states[0].shape == layer_past_states[0].shape assert len(reordered_layer_past_states) == len(layer_past_states) - reordered_decoder_past = reordered_decoder_past + ( - reordered_layer_past_states, - ) + reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) return reordered_decoder_past diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index c27ba2461..cbdaedd8f 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -184,13 +184,11 @@ def from_pb( max_blocks = 0 # Parse batch - for i, (r, tokenized_input) in enumerate( - zip(pb.requests, batch_tokenized_inputs) - ): + for i, (r, tokenized_input) in enumerate(zip(pb.requests, batch_tokenized_inputs)): # request id -> idx in list mapping requests_idx_mapping[r.id] = i - tokenized_input = tokenized_input[-r.truncate:] + tokenized_input = tokenized_input[-r.truncate :] input_length = len(tokenized_input) input_lengths.append(input_length) @@ -209,9 +207,7 @@ def from_pb( next_token_chooser_parameters.append(r.parameters) - stopping_criteria = StoppingCriteria.from_pb( - r.stopping_parameters, tokenizer - ) + stopping_criteria = StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) max_new_tokens = stopping_criteria.max_new_tokens stopping_criterias.append(stopping_criteria) @@ -253,16 +249,12 @@ def from_pb( if r.prefill_logprobs: prefill_head_indices.append(request_position_ids + cumulative_length) - prefill_next_token_indices.append( - prefill_out_cumulative_length + input_length - 1 - ) + prefill_next_token_indices.append(prefill_out_cumulative_length + input_length - 1) prefill_cu_outlens.append(prefill_out_cumulative_length + input_length) prefill_out_cumulative_length += input_length else: prefill_head_indices.append( - torch.tensor( - [cumulative_length + input_length - 1], dtype=torch.int32 - ) + torch.tensor([cumulative_length + input_length - 1], dtype=torch.int32) ) prefill_next_token_indices.append(prefill_out_cumulative_length) prefill_cu_outlens.append(prefill_out_cumulative_length + 1) @@ -278,8 +270,7 @@ def from_pb( adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device) request_tokenizers = [ - tokenizers.get_tokenizer(r.adapter_index, tokenizer) - for r in pb.requests + tokenizers.get_tokenizer(r.adapter_index, tokenizer) for r in pb.requests ] next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, request_tokenizers, dtype, device @@ -287,16 +278,12 @@ def from_pb( start_slots = torch.tensor(start_slots, dtype=torch.int64) # Padded all_input_ids_tensor - all_input_ids_tensor = np.zeros( - (len(all_input_ids), max_length), dtype=np.int64 - ) + all_input_ids_tensor = np.zeros((len(all_input_ids), max_length), dtype=np.int64) for i, input_ids in enumerate(all_input_ids): all_input_ids_tensor[i, : len(input_ids)] = input_ids # Create tensors on device - all_input_ids_tensor = torch.tensor( - all_input_ids_tensor, dtype=torch.int64, device=device - ) + all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64, device=device) if len(pb.requests) > 1: input_ids = np.concatenate(all_input_ids, dtype=np.int64) @@ -311,18 +298,14 @@ def from_pb( if SLIDING_WINDOW is not None: 
prefill_cache_indices = prefill_cache_indices[0] - cu_seqlen_prefill = torch.tensor( - cu_seqlen_prefill, device=device, dtype=torch.int32 - ) + cu_seqlen_prefill = torch.tensor(cu_seqlen_prefill, device=device, dtype=torch.int32) position_ids = position_ids.to(device) slot_indices = slot_indices.to(device) if SLIDING_WINDOW is not None: prefill_cache_indices = prefill_cache_indices.to(device) input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) - input_lengths_tensor = torch.tensor( - input_lengths, dtype=torch.int32, device=device - ) + input_lengths_tensor = torch.tensor(input_lengths, dtype=torch.int32, device=device) adapter_segments, adapter_segment_indices = find_segments(adapter_indices) adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32, device=device) @@ -395,9 +378,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": indices = [] # slots to keep after filtering - slot_filtering_indices = torch.zeros( - self.slots.shape[0], dtype=torch.bool, device=device - ) + slot_filtering_indices = torch.zeros(self.slots.shape[0], dtype=torch.bool, device=device) # Create on CPU to only move to GPU once instead of at every copy slot_indices = torch.empty(len(request_ids), dtype=torch.int64) @@ -442,9 +423,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": adapter_set.add(self.requests[idx].adapter_index) - remaining_tokens = ( - stopping_criteria.max_new_tokens - stopping_criteria.current_tokens - ) + remaining_tokens = stopping_criteria.max_new_tokens - stopping_criteria.current_tokens request_block_table = self.block_tables[idx] blocks += len(request_block_table) @@ -456,10 +435,10 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": # Set slice slot_filtering_indices[ - self.start_slots[idx]: self.start_slots[idx] - + request_input_length - + remaining_tokens - - 1 + self.start_slots[idx] : self.start_slots[idx] + + request_input_length + + remaining_tokens + - 1 ] = True cumulative_max_length += request_input_length + remaining_tokens - 1 @@ -486,7 +465,9 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": input_lengths_tensor = self.input_lengths_tensor[indices] slots = self.slots[slot_filtering_indices] next_token_chooser = self.next_token_chooser.filter(indices) - speculative_ids = self.speculative_ids[indices] if self.speculative_ids is not None else None + speculative_ids = ( + self.speculative_ids[indices] if self.speculative_ids is not None else None + ) start_slots = torch.tensor(start_slots, dtype=torch.int64) @@ -570,9 +551,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch position_ids = batches[0].position_ids.new_empty(total_batch_size) slots = batches[0].slots.new_empty(total_slots) slot_indices = batches[0].slot_indices.new_empty(total_batch_size) - input_lengths_tensor = batches[0].input_lengths_tensor.new_empty( - total_batch_size - ) + input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(total_batch_size) block_tables_tensor = batches[0].block_tables_tensor.new_zeros( (total_batch_size, max_blocks) ) @@ -627,20 +606,26 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch # Copy over adapter indices adapter_start_index = cumulative_adapter_indices_size - adapter_end_index = cumulative_adapter_indices_size + batch.adapter_meta.adapter_indices.shape[0] - adapter_indices[adapter_start_index:adapter_end_index] = batch.adapter_meta.adapter_indices + adapter_end_index = ( + 
cumulative_adapter_indices_size + batch.adapter_meta.adapter_indices.shape[0] + ) + adapter_indices[ + adapter_start_index:adapter_end_index + ] = batch.adapter_meta.adapter_indices cumulative_adapter_indices_size = adapter_end_index adapter_set.update(batch.adapter_meta.adapter_set) # Update adapter segments - adapter_segment_builder.concat(batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices) + adapter_segment_builder.concat( + batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices + ) all_input_ids_tensor[ - start_index:end_index, : batch.all_input_ids_tensor.shape[1] + start_index:end_index, : batch.all_input_ids_tensor.shape[1] ] = batch.all_input_ids_tensor[:, :max_length] block_tables_tensor[ - start_index:end_index, : batch.block_tables_tensor.shape[1] + start_index:end_index, : batch.block_tables_tensor.shape[1] ] = batch.block_tables_tensor[:, :max_blocks] start_slots.append(batch.start_slots + cumulative_slots) @@ -654,7 +639,9 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch next_token_chooser_parameters.extend([r.parameters for r in batch.requests]) if batch.next_token_chooser.schema_processor is not None: - sequence_processors.extend(batch.next_token_chooser.schema_processor.sequence_processors) + sequence_processors.extend( + batch.next_token_chooser.schema_processor.sequence_processors + ) else: # No sequence processors, so pad with Nones sequence_processors.extend([None for _ in batch.requests]) @@ -676,7 +663,8 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch speculative_ids = ( torch.cat([b.speculative_ids for b in batches], dim=0) - if batches[0].speculative_ids is not None else None + if batches[0].speculative_ids is not None + else None ) adapter_segments, adapter_segment_indices = adapter_segment_builder.build() @@ -725,9 +713,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch def __del__(self): if self.block_tables is not None and self.block_tables: # Free blocks - get_cache_manager().free( - list(itertools.chain.from_iterable(self.block_tables)) - ) + get_cache_manager().free(list(itertools.chain.from_iterable(self.block_tables))) def __len__(self): return len(self.requests) @@ -780,7 +766,7 @@ def __init__( self.compile = compile self.model_graph_wrapper: GraphCache = None - + @property def sliding_window_blocks(self) -> Optional[int]: return SLIDING_WINDOW_BLOCKS @@ -788,7 +774,7 @@ def sliding_window_blocks(self) -> Optional[int]: @property def batch_type(self) -> Type[FlashCausalLMBatch]: return FlashCausalLMBatch - + def adapter_memory_size(self) -> int: total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory return ADAPTER_MEMORY_FRACTION * total_gpu_memory @@ -835,11 +821,11 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int): # Needs to be estimated here rather than fully initialized as the graph cache relies on the # cache manager being set. 
self.model_graph_wrapper = GraphCache( - self.model, - self.device, - self.adapter_layers, - max_total_tokens, - self.sliding_window_blocks + self.model, + self.device, + self.adapter_layers, + max_total_tokens, + self.sliding_window_blocks, ) graph_cache_memory = self.model_graph_wrapper.get_estimated_cache_memory() logger.info("Estimated graph cache memory: {} MB", graph_cache_memory / 1024 / 1024) @@ -857,14 +843,15 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int): total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory free_memory = max( - 0, total_free_memory - (1 - MEMORY_FRACTION + ADAPTER_MEMORY_FRACTION) * total_gpu_memory + 0, + total_free_memory - (1 - MEMORY_FRACTION + ADAPTER_MEMORY_FRACTION) * total_gpu_memory, ) logger.info("Memory remaining for kv cache: {} MB", free_memory / 1024 / 1024) num_blocks = ( - int(free_memory // total_cache_size) - # Add batch.blocks as we allocated it above, so it is included in the peak memory. - + cache_manager.num_blocks + int(free_memory // total_cache_size) + # Add batch.blocks as we allocated it above, so it is included in the peak memory. + + cache_manager.num_blocks ) del batch @@ -894,14 +881,16 @@ def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str: return self.tokenizer.decode( generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False ) - - def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> Tuple[torch.Tensor, torch.Tensor]: + + def forward( + self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData + ) -> Tuple[torch.Tensor, torch.Tensor]: prefill = batch.cu_seqlen_prefill is not None model = self.model if ( - self.model_graph_wrapper is not None and - not prefill and - self.model_graph_wrapper.can_use_graph(batch, adapter_data) + self.model_graph_wrapper is not None + and not prefill + and self.model_graph_wrapper.can_use_graph(batch, adapter_data) ): model = self.model_graph_wrapper @@ -915,16 +904,23 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> if batch.speculative_ids is not None: speculative_ids = batch.speculative_ids - B, speculative_length = speculative_ids.shape + B, speculative_length = speculative_ids.shape new_length = speculative_length + 1 new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1) arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0) arange_int = arange.to(dtype=torch.int32) new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1) slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) - input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) + input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view( + -1 + ) - block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B * new_length, -1).contiguous() + block_tables = ( + block_tables.unsqueeze(1) + .expand(B, new_length, -1) + .reshape(B * new_length, -1) + .contiguous() + ) max_s = max_s + speculative_length input_ids = new_input_ids @@ -954,7 +950,9 @@ def generate_token( ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]: prefill = batch.cu_seqlen_prefill is not None prefill_logprobs = batch.prefill_next_token_indices is not None - return_alternatives = any(req.parameters.return_k_alternatives > 0 for req in batch.requests) + return_alternatives = any( + req.parameters.return_k_alternatives > 0 for req in 
batch.requests + ) if batch.needed_blocks_slots: # Allocate blocks to this batch @@ -974,7 +972,9 @@ def generate_token( if batch.speculative_ids is not None: B, speculative_length = batch.speculative_ids.shape new_length = speculative_length + 1 - adapter_indices = adapter_meta.adapter_indices.unsqueeze(-1).expand(B, new_length).reshape(-1) + adapter_indices = ( + adapter_meta.adapter_indices.unsqueeze(-1).expand(B, new_length).reshape(-1) + ) adapter_segments = adapter_meta.adapter_segments * new_length adapter_meta = AdapterBatchMetadata( adapter_indices=adapter_indices, @@ -994,27 +994,34 @@ def generate_token( raise e if prefill: - next_token_logits = ( - out[batch.prefill_next_token_indices] if prefill_logprobs else out - ) + next_token_logits = out[batch.prefill_next_token_indices] if prefill_logprobs else out if speculative_logits is not None: speculative_logits = ( - speculative_logits[batch.prefill_next_token_indices] if prefill_logprobs else speculative_logits + speculative_logits[batch.prefill_next_token_indices] + if prefill_logprobs + else speculative_logits ) else: next_token_logits = out speculative_tokens = get_speculative_tokens() - next_input_ids, next_token_logprobs, accepted_ids, speculative_ids = batch.next_token_chooser( + ( + next_input_ids, + next_token_logprobs, + accepted_ids, + speculative_ids, + ) = batch.next_token_chooser( batch.all_input_ids_tensor[:, : batch.max_seqlen], - next_token_logits, - speculative_tokens, - batch.speculative_ids, + next_token_logits, + speculative_tokens, + batch.speculative_ids, speculative_logits, ) if return_alternatives: - alternative_token_logprobs, alternative_token_ids = torch.sort(torch.log_softmax(next_token_logprobs, -1), dim=-1, stable=True, descending=True) + alternative_token_logprobs, alternative_token_ids = torch.sort( + torch.log_softmax(next_token_logprobs, -1), dim=-1, stable=True, descending=True + ) if prefill: if len(batch) > 1 and prefill_logprobs: @@ -1056,9 +1063,9 @@ def generate_token( # For each member of the batch idx = 0 for i, ( - input_length, - all_input_ids, - num_accepted_ids, + input_length, + all_input_ids, + num_accepted_ids, ) in enumerate(iterator): # Indexing metadata start_index = cumulative_length @@ -1083,13 +1090,13 @@ def generate_token( if prefill_logprobs: if len(batch) > 1: prefill_tokens_indices[ - out_start_index: out_end_index - 1 - ] = batch.input_ids[start_index + 1: start_index + out_length] + out_start_index : out_end_index - 1 + ] = batch.input_ids[start_index + 1 : start_index + out_length] else: # Set prefill_tokens_indices to the correct slice prefill_tokens_indices = batch.input_ids[ - start_index + 1: start_index + out_length - ] + start_index + 1 : start_index + out_length + ] batch.all_input_ids_tensor[i, input_length] = next_input_ids[i] @@ -1098,7 +1105,7 @@ def generate_token( idx += 1 cumulative_length += input_length - + # Set values in batch batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] batch.position_ids = next_position_ids + accepted_ids @@ -1149,15 +1156,15 @@ def generate_token( # For each member of the batch idx = 0 for i, ( - request, - input_length, - prefix_offset, - read_offset, - stopping_criteria, - all_input_ids, - do_sample, - seed, - num_accepted_ids, + request, + input_length, + prefix_offset, + read_offset, + stopping_criteria, + all_input_ids, + do_sample, + seed, + num_accepted_ids, ) in enumerate(iterator): all_alternative_tokens = [] if request.parameters.return_k_alternatives > 0 else None next_token_texts = [] @@ -1165,7 
+1172,7 @@ def generate_token( current_stopped = False for j in range(num_accepted_ids): token_idx = idx + j - + # Generated token next_token_id = next_token_ids[token_idx] all_input_ids.append(next_token_id) @@ -1178,11 +1185,18 @@ def generate_token( if request.parameters.return_k_alternatives > 0: # Limit the number of alternatives to the vocabulary size - num_alternatives = min(request.parameters.return_k_alternatives, len(alternative_token_ids[token_idx])) + num_alternatives = min( + request.parameters.return_k_alternatives, + len(alternative_token_ids[token_idx]), + ) # Select top-k logprobs - request_alternative_token_ids = alternative_token_ids[token_idx][:num_alternatives] - request_alternative_token_logprobs = alternative_token_logprobs[token_idx][:num_alternatives] + request_alternative_token_ids = alternative_token_ids[token_idx][ + :num_alternatives + ] + request_alternative_token_logprobs = alternative_token_logprobs[token_idx][ + :num_alternatives + ] # Decode tokens request_alternative_token_texts = [] @@ -1198,7 +1212,7 @@ def generate_token( alternative_tokens = AlternativeTokens( request_alternative_token_ids, request_alternative_token_logprobs, - request_alternative_token_texts + request_alternative_token_texts, ) all_alternative_tokens.append(alternative_tokens) @@ -1215,8 +1229,8 @@ def generate_token( current_stopped = False stopped = stopped and current_stopped - accepted_token_ids = next_token_ids[idx: idx + num_accepted_ids - left] - accepted_token_logprobs = next_token_logprobs[idx: idx + num_accepted_ids - left] + accepted_token_ids = next_token_ids[idx : idx + num_accepted_ids - left] + accepted_token_logprobs = next_token_logprobs[idx : idx + num_accepted_ids - left] idx += num_accepted_ids # Shard generations @@ -1224,9 +1238,7 @@ def generate_token( if i % self.world_size == self.rank: if stop: # Decode generated tokens - output_text = self.decode( - all_input_ids[-stopping_criteria.current_tokens:] - ) + output_text = self.decode(all_input_ids[-stopping_criteria.current_tokens :]) generated_text = GeneratedText( output_text, stopping_criteria.current_tokens, @@ -1243,8 +1255,8 @@ def generate_token( # Remove generated token to only have prefill and add nan for first prompt token request_prefill_logprobs = [float("nan")] + prefill_logprobs[ - out_start_index: out_end_index - 1 - ] + out_start_index : out_end_index - 1 + ] prefill_token_ids = all_input_ids[:-1] prefill_texts = self.tokenizer.batch_decode( prefill_token_ids, diff --git a/server/lorax_server/models/flash_gemma.py b/server/lorax_server/models/flash_gemma.py index 0030907d9..c68bd2782 100644 --- a/server/lorax_server/models/flash_gemma.py +++ b/server/lorax_server/models/flash_gemma.py @@ -63,10 +63,10 @@ def __init__( filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( - filenames, - device, - dtype, - process_group=self.process_group, + filenames, + device, + dtype, + process_group=self.process_group, ) if config.quantize in ["gptq", "awq", "eetq"]: @@ -90,33 +90,42 @@ def __init__( adapter_id=adapter_id, adapter_source=adapter_source, ) - + @property def supports_adapter_loading(self) -> bool: return True - + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: layer_weights = {} prefix = "model.layers" for i, layer in enumerate(self.model.model.layers): - layer_weights[(i, Q_PROJ)] = (f"{prefix}.{i}.self_attn.q_proj", layer.self_attn.query_key_value) - layer_weights[(i, K_PROJ)] = (f"{prefix}.{i}.self_attn.k_proj", 
layer.self_attn.query_key_value) - layer_weights[(i, V_PROJ)] = (f"{prefix}.{i}.self_attn.v_proj", layer.self_attn.query_key_value) + layer_weights[(i, Q_PROJ)] = ( + f"{prefix}.{i}.self_attn.q_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, K_PROJ)] = ( + f"{prefix}.{i}.self_attn.k_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, V_PROJ)] = ( + f"{prefix}.{i}.self_attn.v_proj", + layer.self_attn.query_key_value, + ) layer_weights[(i, O_PROJ)] = (f"{prefix}.{i}.self_attn.o_proj", layer.self_attn.o_proj) layer_weights[(i, GATE_PROJ)] = (f"{prefix}.{i}.mlp.gate_proj", layer.mlp.gate_proj) layer_weights[(i, UP_PROJ)] = (f"{prefix}.{i}.mlp.up_proj", layer.mlp.up_proj) layer_weights[(i, DOWN_PROJ)] = (f"{prefix}.{i}.mlp.down_proj", layer.mlp.down_proj) - + return layer_weights - + @property def adapter_layers(self) -> List[str]: return ADAPTER_LAYERS - + def get_num_layers_for_type(self, layer_type: str) -> int: return len(self.model.model.layers) - + def is_row_parallel(self, layer_type: str) -> bool: return layer_type in ROW_PARALLEL diff --git a/server/lorax_server/models/flash_gpt2.py b/server/lorax_server/models/flash_gpt2.py index 426e18fd8..9782d68ce 100644 --- a/server/lorax_server/models/flash_gpt2.py +++ b/server/lorax_server/models/flash_gpt2.py @@ -67,11 +67,11 @@ def __init__( filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( - filenames, - device, - dtype, - process_group=self.process_group, - merged_weight_filenames=merged_weight_filenames + filenames, + device, + dtype, + process_group=self.process_group, + merged_weight_filenames=merged_weight_filenames, ) if config.quantize in ["gptq", "awq", "eetq"]: @@ -99,7 +99,7 @@ def __init__( @property def supports_adapter_loading(self) -> bool: return True - + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: layer_weights = {} @@ -114,13 +114,13 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: # TODO: make Embedding layers adapter-compatible # layer_weights[(0, LM_HEAD)] = ("transformer.wte", self.model.transformer.wte) return layer_weights - + @property def adapter_layers(self) -> List[str]: return ADAPTER_LAYERS - + def get_num_layers_for_type(self, layer_type: str) -> int: return 1 if layer_type == LM_HEAD else len(self.model.transformer.h) - + def is_row_parallel(self, layer_type: str) -> bool: return layer_type in ROW_PARALLEL diff --git a/server/lorax_server/models/flash_llama.py b/server/lorax_server/models/flash_llama.py index 82dde9199..19c053dc4 100644 --- a/server/lorax_server/models/flash_llama.py +++ b/server/lorax_server/models/flash_llama.py @@ -18,7 +18,16 @@ Weights, ) from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID -from lorax_server.utils.lora import DOWN_PROJ, GATE_PROJ, K_PROJ, LM_HEAD, O_PROJ, Q_PROJ, UP_PROJ, V_PROJ +from lorax_server.utils.lora import ( + DOWN_PROJ, + GATE_PROJ, + K_PROJ, + LM_HEAD, + O_PROJ, + Q_PROJ, + UP_PROJ, + V_PROJ, +) tracer = trace.get_tracer(__name__) @@ -64,10 +73,10 @@ def __init__( filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( - filenames, - device, - dtype, - process_group=self.process_group, + filenames, + device, + dtype, + process_group=self.process_group, ) if config.quantize in ["gptq", "awq", "eetq"]: @@ -91,34 +100,43 @@ def __init__( adapter_id=adapter_id, adapter_source=adapter_source, ) - + @property def supports_adapter_loading(self) -> bool: return True - + def 
adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: layer_weights = {} prefix = "model.layers" for i, layer in enumerate(self.model.model.layers): - layer_weights[(i, Q_PROJ)] = (f"{prefix}.{i}.self_attn.q_proj", layer.self_attn.query_key_value) - layer_weights[(i, K_PROJ)] = (f"{prefix}.{i}.self_attn.k_proj", layer.self_attn.query_key_value) - layer_weights[(i, V_PROJ)] = (f"{prefix}.{i}.self_attn.v_proj", layer.self_attn.query_key_value) + layer_weights[(i, Q_PROJ)] = ( + f"{prefix}.{i}.self_attn.q_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, K_PROJ)] = ( + f"{prefix}.{i}.self_attn.k_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, V_PROJ)] = ( + f"{prefix}.{i}.self_attn.v_proj", + layer.self_attn.query_key_value, + ) layer_weights[(i, O_PROJ)] = (f"{prefix}.{i}.self_attn.o_proj", layer.self_attn.o_proj) layer_weights[(i, GATE_PROJ)] = (f"{prefix}.{i}.mlp.gate_proj", layer.mlp.gate_up_proj) layer_weights[(i, UP_PROJ)] = (f"{prefix}.{i}.mlp.up_proj", layer.mlp.gate_up_proj) layer_weights[(i, DOWN_PROJ)] = (f"{prefix}.{i}.mlp.down_proj", layer.mlp.down_proj) - + layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head) return layer_weights - + @property def adapter_layers(self) -> List[str]: return ADAPTER_LAYERS - + def get_num_layers_for_type(self, layer_type: str) -> int: return 1 if layer_type == LM_HEAD else len(self.model.model.layers) - + def is_row_parallel(self, layer_type: str) -> bool: return layer_type in ROW_PARALLEL diff --git a/server/lorax_server/models/flash_mistral.py b/server/lorax_server/models/flash_mistral.py index b4fa228b5..19f6735df 100644 --- a/server/lorax_server/models/flash_mistral.py +++ b/server/lorax_server/models/flash_mistral.py @@ -17,7 +17,16 @@ Weights, ) from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID -from lorax_server.utils.lora import DOWN_PROJ, GATE_PROJ, K_PROJ, LM_HEAD, O_PROJ, Q_PROJ, UP_PROJ, V_PROJ +from lorax_server.utils.lora import ( + DOWN_PROJ, + GATE_PROJ, + K_PROJ, + LM_HEAD, + O_PROJ, + Q_PROJ, + UP_PROJ, + V_PROJ, +) tracer = trace.get_tracer(__name__) @@ -61,10 +70,10 @@ def __init__( filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( - filenames, - device, - dtype, - process_group=self.process_group, + filenames, + device, + dtype, + process_group=self.process_group, ) if config.quantize in ["gptq", "awq", "eetq"]: @@ -93,30 +102,39 @@ def __init__( @property def supports_adapter_loading(self) -> bool: return True - + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: layer_weights = {} prefix = "model.layers" for i, layer in enumerate(self.model.model.layers): - layer_weights[(i, Q_PROJ)] = (f"{prefix}.{i}.self_attn.q_proj", layer.self_attn.query_key_value) - layer_weights[(i, K_PROJ)] = (f"{prefix}.{i}.self_attn.k_proj", layer.self_attn.query_key_value) - layer_weights[(i, V_PROJ)] = (f"{prefix}.{i}.self_attn.v_proj", layer.self_attn.query_key_value) + layer_weights[(i, Q_PROJ)] = ( + f"{prefix}.{i}.self_attn.q_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, K_PROJ)] = ( + f"{prefix}.{i}.self_attn.k_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, V_PROJ)] = ( + f"{prefix}.{i}.self_attn.v_proj", + layer.self_attn.query_key_value, + ) layer_weights[(i, O_PROJ)] = (f"{prefix}.{i}.self_attn.o_proj", layer.self_attn.o_proj) layer_weights[(i, GATE_PROJ)] = (f"{prefix}.{i}.mlp.gate_proj", layer.mlp.gate_up_proj) layer_weights[(i, UP_PROJ)] = 
(f"{prefix}.{i}.mlp.up_proj", layer.mlp.gate_up_proj) layer_weights[(i, DOWN_PROJ)] = (f"{prefix}.{i}.mlp.down_proj", layer.mlp.down_proj) - + layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head) return layer_weights - + @property def adapter_layers(self) -> List[str]: return ADAPTER_LAYERS - + def get_num_layers_for_type(self, layer_type: str) -> int: return 1 if layer_type == LM_HEAD else len(self.model.model.layers) - + def is_row_parallel(self, layer_type: str) -> bool: return layer_type in ROW_PARALLEL diff --git a/server/lorax_server/models/flash_mixtral.py b/server/lorax_server/models/flash_mixtral.py index 6c89c2b81..ead6cece7 100644 --- a/server/lorax_server/models/flash_mixtral.py +++ b/server/lorax_server/models/flash_mixtral.py @@ -48,7 +48,12 @@ SLIDING_WINDOW: Optional[int] = None SLIDING_WINDOW_BLOCKS: Optional[int] = None -ADAPTER_LAYERS = [ATTN_Q_PROJ, ATTN_K_PROJ, ATTN_V_PROJ, ATTN_O_PROJ] # TODO(travis): add back LM_HEAD following https://github.com/predibase/lorax/issues/231 +ADAPTER_LAYERS = [ + ATTN_Q_PROJ, + ATTN_K_PROJ, + ATTN_V_PROJ, + ATTN_O_PROJ, +] # TODO(travis): add back LM_HEAD following https://github.com/predibase/lorax/issues/231 ROW_PARALLEL = {ATTN_O_PROJ, LM_HEAD} @@ -120,9 +125,7 @@ def from_pb( max_blocks = 0 # Parse batch - for i, (r, tokenized_input) in enumerate( - zip(pb.requests, batch_tokenized_inputs) - ): + for i, (r, tokenized_input) in enumerate(zip(pb.requests, batch_tokenized_inputs)): # request id -> idx in list mapping requests_idx_mapping[r.id] = i @@ -145,9 +148,7 @@ def from_pb( next_token_chooser_parameters.append(r.parameters) - stopping_criteria = StoppingCriteria.from_pb( - r.stopping_parameters, tokenizer - ) + stopping_criteria = StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) max_new_tokens = stopping_criteria.max_new_tokens stopping_criterias.append(stopping_criteria) # top_n_tokens.append(r.top_n_tokens) @@ -160,9 +161,7 @@ def from_pb( total_tokens = input_length + max_new_tokens - 1 # Needed blocks can not go over SLIDING_WINDOW_BLOCKS - needed_blocks = min( - math.ceil(total_tokens / BLOCK_SIZE), SLIDING_WINDOW_BLOCKS - ) + needed_blocks = min(math.ceil(total_tokens / BLOCK_SIZE), SLIDING_WINDOW_BLOCKS) blocks += needed_blocks needed_blocks_slots.append((needed_blocks, total_tokens)) @@ -188,16 +187,12 @@ def from_pb( if r.prefill_logprobs: prefill_head_indices.append(request_position_ids + cumulative_length) - prefill_next_token_indices.append( - prefill_out_cumulative_length + input_length - 1 - ) + prefill_next_token_indices.append(prefill_out_cumulative_length + input_length - 1) prefill_cu_outlens.append(prefill_out_cumulative_length + input_length) prefill_out_cumulative_length += input_length else: prefill_head_indices.append( - torch.tensor( - [cumulative_length + input_length - 1], dtype=torch.int32 - ) + torch.tensor([cumulative_length + input_length - 1], dtype=torch.int32) ) prefill_next_token_indices.append(prefill_out_cumulative_length) prefill_cu_outlens.append(prefill_out_cumulative_length + 1) @@ -214,8 +209,7 @@ def from_pb( print("!!! 
ADAPTER INDICES", adapter_indices) request_tokenizers = [ - tokenizers.get_tokenizer(r.adapter_index, tokenizer) - for r in pb.requests + tokenizers.get_tokenizer(r.adapter_index, tokenizer) for r in pb.requests ] next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, request_tokenizers, dtype, device @@ -223,16 +217,12 @@ def from_pb( start_slots = torch.tensor(start_slots, dtype=torch.int64) # Padded all_input_ids_tensor - all_input_ids_tensor = np.zeros( - (len(all_input_ids), max_length), dtype=np.int64 - ) + all_input_ids_tensor = np.zeros((len(all_input_ids), max_length), dtype=np.int64) for i, input_ids in enumerate(all_input_ids): all_input_ids_tensor[i, : len(input_ids)] = input_ids # Create tensors on device - all_input_ids_tensor = torch.tensor( - all_input_ids_tensor, dtype=torch.int64, device=device - ) + all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64, device=device) if len(pb.requests) > 1: input_ids = np.concatenate(all_input_ids, dtype=np.int64) @@ -245,17 +235,13 @@ def from_pb( slot_indices = slot_indices[0] prefill_cache_indices = prefill_cache_indices[0] - cu_seqlen_prefill = torch.tensor( - cu_seqlen_prefill, device=device, dtype=torch.int32 - ) + cu_seqlen_prefill = torch.tensor(cu_seqlen_prefill, device=device, dtype=torch.int32) position_ids = position_ids.to(device) slot_indices = slot_indices.to(device) prefill_cache_indices = prefill_cache_indices.to(device) input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) - input_lengths_tensor = torch.tensor( - input_lengths, dtype=torch.int32, device=device - ) + input_lengths_tensor = torch.tensor(input_lengths, dtype=torch.int32, device=device) adapter_segments, adapter_segment_indices = find_segments(adapter_indices) adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32, device=device) @@ -362,10 +348,10 @@ def __init__( filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( - filenames, - device, - dtype, - process_group=self.process_group, + filenames, + device, + dtype, + process_group=self.process_group, ) if config.quantize in ["gptq", "awq", "eetq"]: @@ -399,16 +385,18 @@ def supports_adapter_loading(self) -> bool: def batch_type(self) -> Type[FlashMixtralBatch]: return FlashMixtralBatch - def forward(self, batch: FlashMixtralBatch, adapter_data: AdapterBatchData) -> Tuple[torch.Tensor, torch.Tensor]: + def forward( + self, batch: FlashMixtralBatch, adapter_data: AdapterBatchData + ) -> Tuple[torch.Tensor, torch.Tensor]: prefill = batch.cu_seqlen_prefill is not None model = self.model if ( - self.model_graph_wrapper is not None and - not prefill and - self.model_graph_wrapper.can_use_graph(batch, adapter_data) + self.model_graph_wrapper is not None + and not prefill + and self.model_graph_wrapper.can_use_graph(batch, adapter_data) ): model = self.model_graph_wrapper - + # Model Forward logits = model.forward( input_ids=batch.input_ids, @@ -426,31 +414,43 @@ def forward(self, batch: FlashMixtralBatch, adapter_data: AdapterBatchData) -> T if batch.prefill_cache_indices is not None: batch.prefill_cache_indices = None return logits - + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: layer_weights = {} prefix = "model.layers" for i, layer in enumerate(self.model.model.layers): - layer_weights[(i, ATTN_Q_PROJ)] = (f"{prefix}.{i}.self_attn.q_proj", layer.self_attn.query_key_value) - layer_weights[(i, ATTN_K_PROJ)] = (f"{prefix}.{i}.self_attn.k_proj", 
layer.self_attn.query_key_value) - layer_weights[(i, ATTN_V_PROJ)] = (f"{prefix}.{i}.self_attn.v_proj", layer.self_attn.query_key_value) - layer_weights[(i, ATTN_O_PROJ)] = (f"{prefix}.{i}.self_attn.o_proj", layer.self_attn.o_proj) + layer_weights[(i, ATTN_Q_PROJ)] = ( + f"{prefix}.{i}.self_attn.q_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, ATTN_K_PROJ)] = ( + f"{prefix}.{i}.self_attn.k_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, ATTN_V_PROJ)] = ( + f"{prefix}.{i}.self_attn.v_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, ATTN_O_PROJ)] = ( + f"{prefix}.{i}.self_attn.o_proj", + layer.self_attn.o_proj, + ) # TODO(travis): requires implementing this for block sparse MoE # layer_weights[(i, MOE_W1)] = (f"{prefix}.{i}.moe.w1", layer.moe.w1) # layer_weights[(i, MOE_W2)] = (f"{prefix}.{i}.moe.w2", layer.moe.w2) # layer_weights[(i, MOE_W3)] = (f"{prefix}.{i}.moe.w3", layer.moe.w3) - + layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head) return layer_weights - + @property def adapter_layers(self) -> List[str]: return ADAPTER_LAYERS - + def get_num_layers_for_type(self, layer_type: str) -> int: return 1 if layer_type == LM_HEAD else len(self.model.model.layers) - + def is_row_parallel(self, layer_type: str) -> bool: return layer_type in ROW_PARALLEL diff --git a/server/lorax_server/models/flash_neox.py b/server/lorax_server/models/flash_neox.py index 5eec8f946..ae94a86e6 100644 --- a/server/lorax_server/models/flash_neox.py +++ b/server/lorax_server/models/flash_neox.py @@ -50,9 +50,7 @@ def __init__( torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group - ) + weights = Weights(filenames, device=device, dtype=dtype, process_group=self.process_group) if config.quantize in ["gptq", "awq", "eetq"]: weights._set_gptq_params(model_id) diff --git a/server/lorax_server/models/flash_phi.py b/server/lorax_server/models/flash_phi.py index f7d0156d2..f2d899db5 100644 --- a/server/lorax_server/models/flash_phi.py +++ b/server/lorax_server/models/flash_phi.py @@ -69,10 +69,10 @@ def __init__( filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( - filenames, - device, - dtype, - process_group=self.process_group, + filenames, + device, + dtype, + process_group=self.process_group, ) if config.quantize in ["gptq", "awq", "eetq"]: @@ -97,33 +97,45 @@ def __init__( adapter_id=adapter_id, adapter_source=adapter_source, ) - + @property def supports_adapter_loading(self) -> bool: return True - + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: layer_weights = {} prefix = "model.layers" for i, layer in enumerate(self.model.model.layers): - layer_weights[(i, ATTN_Q_PROJ)] = (f"{prefix}.{i}.self_attn.q_proj", layer.self_attn.qkv_proj) - layer_weights[(i, ATTN_K_PROJ)] = (f"{prefix}.{i}.self_attn.k_proj", layer.self_attn.qkv_proj) - layer_weights[(i, ATTN_V_PROJ)] = (f"{prefix}.{i}.self_attn.v_proj", layer.self_attn.qkv_proj) - layer_weights[(i, ATTN_DENSE)] = (f"{prefix}.{i}.self_attn.dense", layer.self_attn.dense) + layer_weights[(i, ATTN_Q_PROJ)] = ( + f"{prefix}.{i}.self_attn.q_proj", + layer.self_attn.qkv_proj, + ) + layer_weights[(i, ATTN_K_PROJ)] = ( + f"{prefix}.{i}.self_attn.k_proj", + layer.self_attn.qkv_proj, + ) + layer_weights[(i, ATTN_V_PROJ)] = ( + f"{prefix}.{i}.self_attn.v_proj", + 
layer.self_attn.qkv_proj, + ) + layer_weights[(i, ATTN_DENSE)] = ( + f"{prefix}.{i}.self_attn.dense", + layer.self_attn.dense, + ) layer_weights[(i, MLP_FC1)] = (f"{prefix}.{i}.mlp.fc1", layer.mlp.fc1) layer_weights[(i, MLP_FC2)] = (f"{prefix}.{i}.mlp.fc2", layer.mlp.fc2) - + layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head) return layer_weights - + @property def adapter_layers(self) -> List[str]: return ADAPTER_LAYERS - + def get_num_layers_for_type(self, layer_type: str) -> int: return 1 if layer_type == LM_HEAD else len(self.model.model.layers) - + def is_row_parallel(self, layer_type: str) -> bool: return layer_type in ROW_PARALLEL diff --git a/server/lorax_server/models/flash_qwen.py b/server/lorax_server/models/flash_qwen.py index 01013b0eb..a987fdb8e 100644 --- a/server/lorax_server/models/flash_qwen.py +++ b/server/lorax_server/models/flash_qwen.py @@ -68,10 +68,10 @@ def __init__( filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( - filenames, - device, - dtype, - process_group=self.process_group, + filenames, + device, + dtype, + process_group=self.process_group, ) if config.quantize in ["gptq", "awq", "eetq"]: @@ -96,11 +96,11 @@ def __init__( adapter_id=adapter_id, adapter_source=adapter_source, ) - + @property def supports_adapter_loading(self) -> bool: return True - + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: layer_weights = {} @@ -112,31 +112,28 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: layer_weights[(i, MLP_W1)] = (f"{prefix}.{i}.mlp.w1", layer.mlp.gate_up_proj) layer_weights[(i, MLP_W2)] = (f"{prefix}.{i}.mlp.w2", layer.mlp.gate_up_proj) layer_weights[(i, MLP_C_PROJ)] = (f"{prefix}.{i}.mlp.c_proj", layer.mlp.c_proj) - + layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head) return layer_weights - + @property def adapter_layers(self) -> List[str]: return ADAPTER_LAYERS - + def get_num_layers_for_type(self, layer_type: str) -> int: return 1 if layer_type == LM_HEAD else len(self.model.transformer.h) - + def is_row_parallel(self, layer_type: str) -> bool: return layer_type in ROW_PARALLEL - + def split_lora_b_qkv(self, t: torch.Tensor, projection_size: int) -> torch.Tensor: # Because we're splitting on the hidden size dimension, we need to # account for the separate q, k, and v matrices. chunks = torch.split(t, projection_size, dim=1) assert len(chunks) == 3 - chunks = [ - shard_on_dim(w, dim=1, process_group=self.process_group) - for w in chunks - ] + chunks = [shard_on_dim(w, dim=1, process_group=self.process_group) for w in chunks] return torch.cat(chunks, dim=1) - + def shard_lora_weights( self, weights_a: List[torch.Tensor], @@ -148,18 +145,16 @@ def shard_lora_weights( # [hidden_size, r] split_dim = 0 if self.is_row_parallel(layer_type) else 1 weights_a = [ - shard_on_dim(w, dim=split_dim, process_group=self.process_group) - for w in weights_a + shard_on_dim(w, dim=split_dim, process_group=self.process_group) for w in weights_a ] # [r, hidden_size] # Because we're splitting on the hidden size dimension, we need to # account for the separate q, k, and v matrices. 
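# A minimal, runnable sketch of the fused-QKV splitting described by the comment
# above: the LoRA B columns are split into q, k and v blocks before sharding so
# that every tensor-parallel rank keeps matching column ranges. `shard_on_dim`
# here is a local stand-in that slices tensors directly instead of using a
# torch.distributed process group, and the sizes are toy values.
import torch


def shard_on_dim(t: torch.Tensor, dim: int, rank: int, world_size: int) -> torch.Tensor:
    # Take this rank's contiguous block along `dim` (assumes an even split).
    block = t.size(dim) // world_size
    return t.narrow(dim, rank * block, block)


def split_lora_b_qkv(t: torch.Tensor, projection_size: int, rank: int, world_size: int) -> torch.Tensor:
    # LoRA B for a fused QKV projection is [r, 3 * projection_size]: shard the
    # q, k and v blocks independently, then re-concatenate the local shards.
    chunks = torch.split(t, projection_size, dim=1)
    assert len(chunks) == 3
    shards = [shard_on_dim(w, dim=1, rank=rank, world_size=world_size) for w in chunks]
    return torch.cat(shards, dim=1)


if __name__ == "__main__":
    r, projection_size, world_size = 8, 16, 2
    lora_b = torch.randn(r, 3 * projection_size)
    for rank in range(world_size):
        shard = split_lora_b_qkv(lora_b, projection_size, rank, world_size)
        assert shard.shape == (r, 3 * projection_size // world_size)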
- projection_size = (self.config.hidden_size // self.config.num_attention_heads) * self.config.num_attention_heads - weights_b = [ - self.split_lora_b_qkv(w, projection_size) - for w in weights_b - ] + projection_size = ( + self.config.hidden_size // self.config.num_attention_heads + ) * self.config.num_attention_heads + weights_b = [self.split_lora_b_qkv(w, projection_size) for w in weights_b] return weights_a, weights_b else: diff --git a/server/lorax_server/models/flash_qwen2.py b/server/lorax_server/models/flash_qwen2.py index 1b3161c41..64f4ada3f 100644 --- a/server/lorax_server/models/flash_qwen2.py +++ b/server/lorax_server/models/flash_qwen2.py @@ -29,7 +29,16 @@ tracer = trace.get_tracer(__name__) -ADAPTER_LAYERS = [ATTN_Q_PROJ, ATTN_K_PROJ, ATTN_V_PROJ, ATTN_O_PROJ, MLP_GATE_PROJ, MLP_UP_PROJ, MLP_DOWN_PROJ, LM_HEAD] +ADAPTER_LAYERS = [ + ATTN_Q_PROJ, + ATTN_K_PROJ, + ATTN_V_PROJ, + ATTN_O_PROJ, + MLP_GATE_PROJ, + MLP_UP_PROJ, + MLP_DOWN_PROJ, + LM_HEAD, +] ROW_PARALLEL = {ATTN_O_PROJ, MLP_DOWN_PROJ, LM_HEAD} @@ -69,10 +78,10 @@ def __init__( filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( - filenames, - device, - dtype, - process_group=self.process_group, + filenames, + device, + dtype, + process_group=self.process_group, ) if config.quantize in ["gptq", "awq", "eetq"]: @@ -99,34 +108,49 @@ def __init__( adapter_id=adapter_id, adapter_source=adapter_source, ) - + @property def supports_adapter_loading(self) -> bool: return True - + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: layer_weights = {} prefix = "model.layers" for i, layer in enumerate(self.model.model.layers): - layer_weights[(i, ATTN_Q_PROJ)] = (f"{prefix}.{i}.self_attn.q_proj", layer.self_attn.query_key_value) - layer_weights[(i, ATTN_K_PROJ)] = (f"{prefix}.{i}.self_attn.k_proj", layer.self_attn.query_key_value) - layer_weights[(i, ATTN_V_PROJ)] = (f"{prefix}.{i}.self_attn.v_proj", layer.self_attn.query_key_value) - layer_weights[(i, ATTN_O_PROJ)] = (f"{prefix}.{i}.self_attn.o_proj", layer.self_attn.o_proj) - - layer_weights[(i, MLP_GATE_PROJ)] = (f"{prefix}.{i}.mlp.gate_proj", layer.mlp.gate_up_proj) + layer_weights[(i, ATTN_Q_PROJ)] = ( + f"{prefix}.{i}.self_attn.q_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, ATTN_K_PROJ)] = ( + f"{prefix}.{i}.self_attn.k_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, ATTN_V_PROJ)] = ( + f"{prefix}.{i}.self_attn.v_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, ATTN_O_PROJ)] = ( + f"{prefix}.{i}.self_attn.o_proj", + layer.self_attn.o_proj, + ) + + layer_weights[(i, MLP_GATE_PROJ)] = ( + f"{prefix}.{i}.mlp.gate_proj", + layer.mlp.gate_up_proj, + ) layer_weights[(i, MLP_UP_PROJ)] = (f"{prefix}.{i}.mlp.up_proj", layer.mlp.gate_up_proj) layer_weights[(i, MLP_DOWN_PROJ)] = (f"{prefix}.{i}.mlp.down_proj", layer.mlp.down_proj) - + layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head) return layer_weights - + @property def adapter_layers(self) -> List[str]: return ADAPTER_LAYERS - + def get_num_layers_for_type(self, layer_type: str) -> int: return 1 if layer_type == LM_HEAD else len(self.model.model.layers) - + def is_row_parallel(self, layer_type: str) -> bool: return layer_type in ROW_PARALLEL diff --git a/server/lorax_server/models/galactica.py b/server/lorax_server/models/galactica.py index be43d0244..8e62d0dc1 100644 --- a/server/lorax_server/models/galactica.py +++ b/server/lorax_server/models/galactica.py @@ -96,15 +96,11 @@ def from_pb( 
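# A minimal sketch of the lookup that the adapter_target_to_layer methods in the
# flash_* models above all build: keys are (layer_index, layer_type) and values
# are (weight_name, target_module), with q/k/v all pointing at the fused
# attention projection. The Toy* classes are stand-ins so the snippet runs on
# its own; they are not the real lorax modules.
from typing import Dict, Tuple

Q_PROJ, K_PROJ, V_PROJ, O_PROJ = "q_proj", "k_proj", "v_proj", "o_proj"


class ToyAttention:
    def __init__(self) -> None:
        self.query_key_value = object()  # fused q/k/v projection
        self.o_proj = object()


class ToyLayer:
    def __init__(self) -> None:
        self.self_attn = ToyAttention()


def adapter_target_to_layer(layers, prefix: str = "model.layers") -> Dict[Tuple[int, str], Tuple[str, object]]:
    layer_weights = {}
    for i, layer in enumerate(layers):
        layer_weights[(i, Q_PROJ)] = (f"{prefix}.{i}.self_attn.q_proj", layer.self_attn.query_key_value)
        layer_weights[(i, K_PROJ)] = (f"{prefix}.{i}.self_attn.k_proj", layer.self_attn.query_key_value)
        layer_weights[(i, V_PROJ)] = (f"{prefix}.{i}.self_attn.v_proj", layer.self_attn.query_key_value)
        layer_weights[(i, O_PROJ)] = (f"{prefix}.{i}.self_attn.o_proj", layer.self_attn.o_proj)
    return layer_weights


if __name__ == "__main__":
    mapping = adapter_target_to_layer([ToyLayer() for _ in range(2)])
    assert mapping[(1, Q_PROJ)][0] == "model.layers.1.self_attn.q_proj"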
req_inputs = tokenizers.get_inputs(r, tokenizer) inputs.append(escape_custom_split_sequence(req_inputs)) next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer)) - stopping_criteria = StoppingCriteria.from_pb( - r.stopping_parameters, tokenizer - ) + stopping_criteria = StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) stopping_criterias.append(stopping_criteria) max_truncation = max(max_truncation, r.truncate) max_decode_tokens += stopping_criteria.max_new_tokens - padding_right_offset = max( - padding_right_offset, stopping_criteria.max_new_tokens - ) + padding_right_offset = max(padding_right_offset, stopping_criteria.max_new_tokens) tokenized_inputs = tokenizer( inputs, @@ -124,9 +120,7 @@ def from_pb( input_ids = tokenized_inputs["input_ids"] # Allocate maximum attention_mask - attention_mask = input_ids.new_zeros( - (pb.size, max_input_length + padding_right_offset) - ) + attention_mask = input_ids.new_zeros((pb.size, max_input_length + padding_right_offset)) # Copy tokenizer attention_mask into fully allocated attention_mask attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"] @@ -168,7 +162,7 @@ def __init__( ): if compile: raise ValueError("`--compile` is not supported with GalacticaSharded") - + self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") @@ -196,9 +190,7 @@ def __init__( torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group - ) + weights = Weights(filenames, device=device, dtype=dtype, process_group=self.process_group) if config.quantize in ["gptq", "awq", "eetq"]: weights._set_gptq_params(model_id) @@ -226,9 +218,7 @@ def decode(self, generated_ids: List[int]) -> str: generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False ) - def forward( - self, input_ids, attention_mask, position_ids, past_key_values: Optional = None - ): + def forward(self, input_ids, attention_mask, position_ids, past_key_values: Optional = None): outputs = self.model.forward( input_ids=input_ids, attention_mask=attention_mask, diff --git a/server/lorax_server/models/gpt_neox.py b/server/lorax_server/models/gpt_neox.py index 675ab7fc8..fe9b86953 100644 --- a/server/lorax_server/models/gpt_neox.py +++ b/server/lorax_server/models/gpt_neox.py @@ -30,7 +30,7 @@ def __init__( ): if compile: raise ValueError("`--compile` is not supported with GPT-NeoX") - + self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") @@ -57,9 +57,7 @@ def __init__( torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group - ) + weights = Weights(filenames, device=device, dtype=dtype, process_group=self.process_group) if config.quantize in ["gptq", "awq", "eetq"]: weights._set_gptq_params(model_id) @@ -77,9 +75,7 @@ def __init__( world_size=world_size, ) - def forward( - self, input_ids, attention_mask, position_ids, past_key_values: Optional = None - ): + def forward(self, input_ids, attention_mask, position_ids, past_key_values: Optional = None): outputs = self.model.forward( input_ids=input_ids, attention_mask=attention_mask, diff --git 
a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py index 5d1c18fc1..e03661fb9 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ -58,8 +58,7 @@ def __init__( self.static_adapter_id = adapter_id self.has_position_ids = ( - inspect.signature(model.forward).parameters.get("position_ids", None) - is not None + inspect.signature(model.forward).parameters.get("position_ids", None) is not None ) if adapter_id and adapter_id != BASE_MODEL_ADAPTER_ID: @@ -85,7 +84,7 @@ def info(self) -> InfoResponse: device_type=self.device.type, window_size=self.sliding_window, ) - + @property def sliding_window_blocks(self) -> Optional[int]: return None @@ -94,7 +93,7 @@ def sliding_window_blocks(self) -> Optional[int]: @abstractmethod def batch_type(self) -> Type[B]: raise NotImplementedError - + def adapter_memory_size(self) -> int: return 0 @@ -119,9 +118,7 @@ def decode_token( prefix_text = self.tokenizer.decode( all_input_ids[prefix_offset:read_offset], skip_special_tokens=False ) - new_text = self.tokenizer.decode( - all_input_ids[prefix_offset:], skip_special_tokens=False - ) + new_text = self.tokenizer.decode(all_input_ids[prefix_offset:], skip_special_tokens=False) if len(new_text) > len(prefix_text) and not new_text.endswith("�"): # utf-8 char at the end means it's a potential unfinished byte sequence @@ -142,30 +139,33 @@ def check_initialized(self): raise RuntimeError( f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}" ) - + @property def supports_adapter_loading(self) -> bool: return False - + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: return {} - + @property def adapter_layers(self) -> List[str]: return [] - + def get_num_layers_for_type(self, layer_type: str) -> int: return 0 - + def is_row_parallel(self, layer_type: str) -> bool: return False - + @property def max_speculative_tokens(self) -> int: - return max([ - layer_weights.max_speculative_tokens for - layer_weights in self.batched_lora_weights.values() - ], default=0) + return max( + [ + layer_weights.max_speculative_tokens + for layer_weights in self.batched_lora_weights.values() + ], + default=0, + ) def load_adapter( self, @@ -177,48 +177,64 @@ def load_adapter( ): """Loads adapter weights from disk / host memory on the GPU. - adapter_id must be `BASE_MODEL_ADAPTER_ID` if adapter statically loaded + adapter_id must be `BASE_MODEL_ADAPTER_ID` if adapter statically loaded into model. Otherwise, the adapter weights are applied during the forward pass and stored separately from the base model parameters. """ if adapter_index in self.loaded_adapters: # Adapter already loaded return - + if not self.supports_adapter_loading: raise ValueError("This model does not support adapter loading.") - + if dynamic and not self.dynamic_adapter_loading_enabled: - raise ValueError(f"This model was initialized with the adapter {self.static_adapter_id} " - f"and therefore does not support dynamic adapter loading. " - f"Please initialize a new model instance from the base model in " - f"order to use the dynamic adapter loading feature.") + raise ValueError( + f"This model was initialized with the adapter {self.static_adapter_id} " + f"and therefore does not support dynamic adapter loading. " + f"Please initialize a new model instance from the base model in " + f"order to use the dynamic adapter loading feature." 
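# A minimal sketch of the incremental detokenization pattern in decode_token
# above: only the ids from prefix_offset onward are re-decoded each step, and
# text is emitted once it no longer ends in an unfinished UTF-8 sequence ("�").
# `decode` can be any callable from token ids to text; the stub in the demo just
# joins strings so the example is self-contained rather than the real tokenizer.
from typing import Callable, List, Tuple


def decode_token(
    decode: Callable[[List[int]], str],
    all_input_ids: List[int],
    prefix_offset: int,
    read_offset: int,
) -> Tuple[str, int, int]:
    prefix_text = decode(all_input_ids[prefix_offset:read_offset])
    new_text = decode(all_input_ids[prefix_offset:])
    if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
        # The freshly generated characters are the part after the prefix text.
        return new_text[len(prefix_text):], read_offset, len(all_input_ids)
    # Otherwise hold back output until the byte sequence is complete.
    return "", prefix_offset, read_offset


if __name__ == "__main__":
    vocab = {0: "Hel", 1: "lo", 2: " world"}

    def decode(ids: List[int]) -> str:
        return "".join(vocab[i] for i in ids)

    ids, prefix, read = [0], 0, 0
    text, prefix, read = decode_token(decode, ids, prefix, read)
    assert text == "Hel"
    ids.append(1)
    text, prefix, read = decode_token(decode, ids, prefix, read)
    assert text == "lo"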
+ ) - logger.info(f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}") + logger.info( + f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}" + ) weight_names = tuple([v[0] for v in self.target_to_layer.values()]) - module_map, adapter_config, adapter_weight_names, adapter_tokenizer = load_and_merge_adapters( - self.model_id, adapter_parameters, adapter_source, adapter_index, weight_names, api_token + ( + module_map, + adapter_config, + adapter_weight_names, + adapter_tokenizer, + ) = load_and_merge_adapters( + self.model_id, + adapter_parameters, + adapter_source, + adapter_index, + weight_names, + api_token, ) unused_weight_names = adapter_weight_names.copy() for layer_name in self.adapter_layers: adapter_weights = adapter_config.load_batched_adapter_weights( - self, + self, module_map, - layer_name, + layer_name, unused_weight_names, dynamic, ) if adapter_weights is None: continue - + batched_weights = self.batched_lora_weights[layer_name] batched_weights.add_adapter(adapter_index, adapter_weights) - + if len(unused_weight_names) > 0: - logger.warning(f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}") - + logger.warning( + f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}" + ) + if adapter_tokenizer is not None: self.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer) @@ -233,18 +249,14 @@ def shard_lora_weights( # [hidden_size, r] split_dim = 0 if self.is_row_parallel(layer_type) else 1 weights_a = [ - shard_on_dim(w, dim=split_dim, process_group=self.process_group) - for w in weights_a + shard_on_dim(w, dim=split_dim, process_group=self.process_group) for w in weights_a ] # [r, hidden_size] - weights_b = [ - shard_on_dim(w, dim=1, process_group=self.process_group) - for w in weights_b - ] + weights_b = [shard_on_dim(w, dim=1, process_group=self.process_group) for w in weights_b] return weights_a, weights_b - + def offload_adapter( self, adapter_parameters: AdapterParameters, @@ -255,15 +267,17 @@ def offload_adapter( if adapter_index not in self.loaded_adapters: # Adapter already offloaded return - + if not self.supports_adapter_loading: raise ValueError("This model does not support adapter loading.") - + if not self.dynamic_adapter_loading_enabled: - raise ValueError(f"This model was initialized with the adapter {self.static_adapter_id} " - f"and therefore does not support dynamic adapter loading. " - f"Please initialize a new model instance from the base model in " - f"order to use the dynamic adapter loading feature.") + raise ValueError( + f"This model was initialized with the adapter {self.static_adapter_id} " + f"and therefore does not support dynamic adapter loading. " + f"Please initialize a new model instance from the base model in " + f"order to use the dynamic adapter loading feature." 
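# A minimal sketch of the split-dimension rule in shard_lora_weights above:
# LoRA A ([hidden_size, r]) is split on dim 0 for row-parallel layer types and
# on dim 1 otherwise, while LoRA B ([r, hidden_size]) is always split on dim 1.
# `local_shard` stands in for the process-group-aware shard_on_dim, and the
# sizes are toy values.
import torch

ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}  # as declared in the models above


def local_shard(t: torch.Tensor, dim: int, rank: int, world_size: int) -> torch.Tensor:
    block = t.size(dim) // world_size
    return t.narrow(dim, rank * block, block)


def shard_lora_weights(weights_a, weights_b, layer_type: str, rank: int, world_size: int):
    split_dim = 0 if layer_type in ROW_PARALLEL else 1
    weights_a = [local_shard(w, split_dim, rank, world_size) for w in weights_a]
    weights_b = [local_shard(w, 1, rank, world_size) for w in weights_b]
    return weights_a, weights_b


if __name__ == "__main__":
    hidden, r, world_size = 32, 4, 2
    a, b = [torch.randn(hidden, r)], [torch.randn(r, hidden)]
    a_row, b_row = shard_lora_weights(a, b, "o_proj", rank=0, world_size=world_size)
    assert a_row[0].shape == (hidden // world_size, r)
    assert b_row[0].shape == (r, hidden // world_size)
    a_col, _ = shard_lora_weights(a, b, "q_proj", rank=0, world_size=world_size)
    assert a_col[0].shape == (hidden, r // world_size)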
+ ) for layer_name in self.adapter_layers: if layer_name in self.batched_lora_weights: diff --git a/server/lorax_server/models/mpt.py b/server/lorax_server/models/mpt.py index 8c0040d53..42e1d045e 100644 --- a/server/lorax_server/models/mpt.py +++ b/server/lorax_server/models/mpt.py @@ -34,7 +34,9 @@ def from_pb( dtype: torch.dtype, device: torch.device, ) -> "CausalLMBatch": - batch = super().from_pb(pb=pb, tokenizer=tokenizer, tokenizers=tokenizers, dtype=dtype, device=device) + batch = super().from_pb( + pb=pb, tokenizer=tokenizer, tokenizers=tokenizers, dtype=dtype, device=device + ) batch.keys_head_dim_last = False return batch @@ -50,7 +52,7 @@ def __init__( ): if compile: raise ValueError("`--compile` is not supported with MPT") - + self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") @@ -72,9 +74,7 @@ def __init__( if local_path.exists(): filename = str(local_path.resolve()) else: - filename = hf_hub_download( - model_id, revision=revision, filename="config.json" - ) + filename = hf_hub_download(model_id, revision=revision, filename="config.json") with open(filename, "r") as f: config = json.load(f) config = PretrainedConfig(**config) diff --git a/server/lorax_server/models/opt.py b/server/lorax_server/models/opt.py index 6ffeb1bb1..298826500 100644 --- a/server/lorax_server/models/opt.py +++ b/server/lorax_server/models/opt.py @@ -28,7 +28,7 @@ def __init__( ): if compile: raise ValueError("`--compile` is not supported with OPT") - + self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") @@ -55,9 +55,7 @@ def __init__( torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group - ) + weights = Weights(filenames, device=device, dtype=dtype, process_group=self.process_group) if config.quantize in ["gptq", "awq", "eetq"]: weights._set_gptq_params(model_id) @@ -75,9 +73,7 @@ def __init__( world_size=world_size, ) - def forward( - self, input_ids, attention_mask, position_ids, past_key_values: Optional = None - ): + def forward(self, input_ids, attention_mask, position_ids, past_key_values: Optional = None): outputs = self.model.forward( input_ids=input_ids, attention_mask=attention_mask, diff --git a/server/lorax_server/models/rw.py b/server/lorax_server/models/rw.py index 4d80f255c..634f5d069 100644 --- a/server/lorax_server/models/rw.py +++ b/server/lorax_server/models/rw.py @@ -18,7 +18,7 @@ def __init__( ): if compile: raise ValueError("`--compile` is not supported with RW") - + if torch.cuda.is_available(): device = torch.device("cuda") dtype = torch.float16 if dtype is None else dtype diff --git a/server/lorax_server/models/santacoder.py b/server/lorax_server/models/santacoder.py index ec1a0a9be..6aed48ace 100644 --- a/server/lorax_server/models/santacoder.py +++ b/server/lorax_server/models/santacoder.py @@ -25,7 +25,7 @@ def __init__( ): if compile: raise ValueError("`--compile` is not supported with SantaCoder") - + if torch.cuda.is_available(): device = torch.device("cuda") dtype = torch.float16 if dtype is None else dtype diff --git a/server/lorax_server/models/seq2seq_lm.py b/server/lorax_server/models/seq2seq_lm.py index 2040f064d..f57f01575 100644 --- a/server/lorax_server/models/seq2seq_lm.py +++ b/server/lorax_server/models/seq2seq_lm.py @@ 
-98,15 +98,11 @@ def from_pb( requests_idx_mapping[r.id] = i decoder_input_lengths.append(1) next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer)) - stopping_criteria = StoppingCriteria.from_pb( - r.stopping_parameters, tokenizer - ) + stopping_criteria = StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) stopping_criterias.append(stopping_criteria) max_truncation = max(max_truncation, r.truncate) max_decode_tokens += stopping_criteria.max_new_tokens - padding_right_offset = max( - padding_right_offset, stopping_criteria.max_new_tokens - ) + padding_right_offset = max(padding_right_offset, stopping_criteria.max_new_tokens) # Tokenize batch tokenized_inputs = tokenizer( @@ -123,9 +119,7 @@ def from_pb( # Decoder sequence only contains the bos_token decoder_input_ids = ( - torch.tensor(tokenizer.bos_token_id, device=device) - .repeat(len(pb.requests)) - .view(-1, 1) + torch.tensor(tokenizer.bos_token_id, device=device).repeat(len(pb.requests)).view(-1, 1) ) for _ in pb.requests: prefix_offsets.append(0) @@ -202,9 +196,7 @@ def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]: request_decoder_input_length = self.decoder_input_lengths[idx] decoder_input_lengths.append(request_decoder_input_length) - max_decoder_input_length = max( - max_decoder_input_length, request_decoder_input_length - ) + max_decoder_input_length = max(max_decoder_input_length, request_decoder_input_length) next_token_choosers.append(self.next_token_choosers[idx]) stopping_criteria = self.stopping_criterias[idx] @@ -233,9 +225,7 @@ def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]: # Ensure that past_key_values tensors can be updated in-place if type(self.past_key_values[0]) == tuple: - self.past_key_values = [ - [t for t in layer] for layer in self.past_key_values - ] + self.past_key_values = [[t for t in layer] for layer in self.past_key_values] decoder_past_seq_len = max_decoder_input_length - 1 for layer in self.past_key_values: @@ -283,9 +273,7 @@ def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch": for batch in batches: total_batch_size += len(batch) max_input_length = max(max_input_length, batch.max_input_length) - max_decoder_input_length = max( - max_decoder_input_length, batch.max_decoder_input_length - ) + max_decoder_input_length = max(max_decoder_input_length, batch.max_decoder_input_length) padding_right_offset = max(padding_right_offset, batch.padding_right_offset) # Batch attributes @@ -342,9 +330,9 @@ def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch": (total_batch_size, max_input_length), ) # Copy to correct indices - attention_mask[ - start_index:end_index, -batch.max_input_length : - ] = batch.attention_mask[:, -batch.max_input_length :] + attention_mask[start_index:end_index, -batch.max_input_length :] = batch.attention_mask[ + :, -batch.max_input_length : + ] # Create padded tensor if decoder_input_ids is None: @@ -401,9 +389,7 @@ def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch": # Ensure that we can update tensors in-place if type(batch.past_key_values[0]) == tuple: - batch.past_key_values = [ - [t for t in layer] for layer in batch.past_key_values - ] + batch.past_key_values = [[t for t in layer] for layer in batch.past_key_values] # Add eventual padding tokens that were added while concatenating max_tokens += batch.max_tokens + ( @@ -473,9 +459,9 @@ def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch": # Slicing end index for this 
batch end_index = start_index + len(batch) # We slice the past keys and values to remove the padding from previous batches - padded_past_values[ - start_index:end_index, :, -batch.max_input_length :, : - ] = t[:, :, -batch.max_input_length :, :] + padded_past_values[start_index:end_index, :, -batch.max_input_length :, :] = t[ + :, :, -batch.max_input_length :, : + ] del t start_index = end_index @@ -605,9 +591,7 @@ def generate_token( ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]: if batch.decoder_attention_mask is not None: # slice to the correct shape - decoder_attention_mask = batch.decoder_attention_mask[ - :, : -batch.padding_right_offset - ] + decoder_attention_mask = batch.decoder_attention_mask[:, : -batch.padding_right_offset] else: decoder_attention_mask = None @@ -662,9 +646,7 @@ def generate_token( ) # Append next token to decoder tokens - all_decoder_input_ids = torch.cat( - [all_decoder_input_ids, next_token_id.squeeze(1)] - ) + all_decoder_input_ids = torch.cat([all_decoder_input_ids, next_token_id.squeeze(1)]) new_decoder_input_length = decoder_input_length + 1 # Generated token @@ -686,9 +668,7 @@ def generate_token( if stop: # Slice with decoder_input_length to remove padding # Decode all tokens - output_text = self.decode( - all_decoder_input_ids[-decoder_input_length:] - ) + output_text = self.decode(all_decoder_input_ids[-decoder_input_length:]) # Get seed if isinstance(next_token_chooser.choice, Sampling): diff --git a/server/lorax_server/models/t5.py b/server/lorax_server/models/t5.py index a68118206..834600f8d 100644 --- a/server/lorax_server/models/t5.py +++ b/server/lorax_server/models/t5.py @@ -31,7 +31,7 @@ def __init__( ): if compile: raise ValueError("`--compile` is not supported with T5") - + self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") diff --git a/server/lorax_server/models/types.py b/server/lorax_server/models/types.py index 6344810a3..6ccf63d4a 100644 --- a/server/lorax_server/models/types.py +++ b/server/lorax_server/models/types.py @@ -72,6 +72,7 @@ def to_pb(self) -> generate_pb2.PrefillTokens: def __len__(self): return len(self.token_ids) + @dataclass class AlternativeTokens: token_ids: List[int] @@ -101,9 +102,9 @@ def to_pb(self) -> generate_pb2.PrefillTokens: logprobs=self.logprobs, texts=self.texts, is_special=self.is_special, - alternative_tokens=[ - alt_tokens.to_pb() for alt_tokens in self.alternative_tokens - ] if self.alternative_tokens is not None else None, + alternative_tokens=[alt_tokens.to_pb() for alt_tokens in self.alternative_tokens] + if self.alternative_tokens is not None + else None, ) def __len__(self): @@ -121,12 +122,8 @@ class Generation: def to_pb(self) -> generate_pb2.Generation: return generate_pb2.Generation( request_id=self.request_id, - prefill_tokens=self.prefill_tokens.to_pb() - if self.prefill_tokens is not None - else None, + prefill_tokens=self.prefill_tokens.to_pb() if self.prefill_tokens is not None else None, prefill_tokens_length=self.prefill_tokens_length, next_tokens=self.next_tokens.to_pb(), - generated_text=self.generated_text.to_pb() - if self.generated_text is not None - else None, + generated_text=self.generated_text.to_pb() if self.generated_text is not None else None, ) diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index d8b79842d..acfbbdab8 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -81,9 +81,7 @@ async def Warmup(self, request: 
generate_pb2.WarmupRequest, context): ) max_supported_total_tokens = self.model.warmup(batch, request.max_new_tokens) - return generate_pb2.WarmupResponse( - max_supported_total_tokens=max_supported_total_tokens - ) + return generate_pb2.WarmupResponse(max_supported_total_tokens=max_supported_total_tokens) async def Prefill(self, request: generate_pb2.PrefillRequest, context): batch = self.model.batch_type.from_pb( @@ -128,7 +126,7 @@ async def Decode(self, request: generate_pb2.DecodeRequest, context): generations=[generation.to_pb() for generation in generations], batch=next_batch.to_pb() if next_batch else None, ) - + async def DownloadAdapter(self, request: generate_pb2.DownloadAdapterRequest, context): adapter_parameters = request.adapter_parameters if is_base_model(adapter_parameters): @@ -142,13 +140,15 @@ async def DownloadAdapter(self, request: generate_pb2.DownloadAdapterRequest, co if adapter_id == BASE_MODEL_ADAPTER_ID: logger.info("No adapter to download for base model. Skipping.") continue - + adapter_bytes += download_adapter(adapter_id, adapter_source, api_token) - + adapter_memory_size = self.model.adapter_memory_size() if adapter_memory_size > 0: - logger.info(f"Downloaded adapter {adapter_id} memory size: {adapter_bytes} bytes " - f"(reservation: {adapter_memory_size} bytes)") + logger.info( + f"Downloaded adapter {adapter_id} memory size: {adapter_bytes} bytes " + f"(reservation: {adapter_memory_size} bytes)" + ) adapter_memory_fraction = adapter_bytes / adapter_memory_size if adapter_memory_fraction > 1: raise ValueError( @@ -157,13 +157,14 @@ async def DownloadAdapter(self, request: generate_pb2.DownloadAdapterRequest, co ) else: # Assume 0.0 memory fraction if adapter memory size is not set - logger.info(f"Downloaded adapter {adapter_id} memory size: {adapter_bytes} bytes " - f"(no reservation limit)") + logger.info( + f"Downloaded adapter {adapter_id} memory size: {adapter_bytes} bytes " + f"(no reservation limit)" + ) adapter_memory_fraction = 0.0 - + return generate_pb2.DownloadAdapterResponse( - downloaded=True, - memory_fraction=adapter_memory_fraction + downloaded=True, memory_fraction=adapter_memory_fraction ) async def LoadAdapter(self, request: generate_pb2.LoadAdapterRequest, context): @@ -171,7 +172,7 @@ async def LoadAdapter(self, request: generate_pb2.LoadAdapterRequest, context): if is_base_model(adapter_parameters): logger.info("No adapter to load for base model. Skipping.") return generate_pb2.LoadAdapterResponse(loaded=False) - + try: adapter_source = _adapter_source_enum_to_string(request.adapter_source) adapter_index = request.adapter_index @@ -183,9 +184,9 @@ async def LoadAdapter(self, request: generate_pb2.LoadAdapterRequest, context): adapter_id = map_pbase_model_id_to_s3(adapter_id, api_token) adapter_parameters.adapter_ids[i] = adapter_id adapter_source = S3 - + self.model.load_adapter(adapter_parameters, adapter_source, adapter_index, api_token) - + return generate_pb2.LoadAdapterResponse(loaded=True) except Exception: logger.exception("Error when loading adapter") @@ -196,7 +197,7 @@ async def OffloadAdapter(self, request: generate_pb2.OffloadAdapterRequest, cont if is_base_model(adapter_parameters): logger.info("No adapter to offload for base model. 
Skipping.") return generate_pb2.OffloadAdapterResponse(offloaded=False) - + try: adapter_idx = request.adapter_index adapter_source = _adapter_source_enum_to_string(request.adapter_source) @@ -206,7 +207,7 @@ async def OffloadAdapter(self, request: generate_pb2.OffloadAdapterRequest, cont # Ensure there is enough memory for the next adapter torch.cuda.empty_cache() torch.cuda.synchronize(self.model.device) - + return generate_pb2.OffloadAdapterResponse(offloaded=True) except Exception: logger.exception("Error when offloading adapter") @@ -251,7 +252,16 @@ async def serve_inner( try: model = get_model( - model_id, adapter_id, revision, sharded, quantize, compile, dtype, trust_remote_code, source, adapter_source + model_id, + adapter_id, + revision, + sharded, + quantize, + compile, + dtype, + trust_remote_code, + source, + adapter_source, ) except Exception: logger.exception("Error when initializing model") @@ -271,7 +281,7 @@ async def serve_inner( create_exllama_buffers() except ImportError: pass - + # set speculative decoding tokens speculative_tokens = max(model.max_speculative_tokens, speculative_tokens) if speculative_tokens > 0: @@ -311,14 +321,14 @@ async def serve_inner( asyncio.run( serve_inner( - model_id, - adapter_id, - revision, - sharded, - quantize, - compile, - dtype, - trust_remote_code, + model_id, + adapter_id, + revision, + sharded, + quantize, + compile, + dtype, + trust_remote_code, speculative_tokens, ) ) @@ -335,4 +345,4 @@ def _adapter_source_enum_to_string(adapter_source: int) -> str: elif adapter_source == generate_pb2.AdapterSource.PBASE: return PBASE else: - raise ValueError(f"Unknown adapter source {adapter_source}") \ No newline at end of file + raise ValueError(f"Unknown adapter source {adapter_source}") diff --git a/server/lorax_server/tracing.py b/server/lorax_server/tracing.py index 0d9dcccea..677c07122 100644 --- a/server/lorax_server/tracing.py +++ b/server/lorax_server/tracing.py @@ -55,9 +55,7 @@ def _start_span(self, handler_call_details, context, set_status_on_exception=Fal def setup_tracing(shard: int, otlp_endpoint: str): - resource = Resource.create( - attributes={"service.name": f"lorax.server-{shard}"} - ) + resource = Resource.create(attributes={"service.name": f"lorax.server-{shard}"}) span_exporter = OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True) span_processor = BatchSpanProcessor(span_exporter) diff --git a/server/lorax_server/utils/adapter.py b/server/lorax_server/utils/adapter.py index f73ab992f..951bbcb11 100644 --- a/server/lorax_server/utils/adapter.py +++ b/server/lorax_server/utils/adapter.py @@ -11,7 +11,7 @@ from lorax_server.utils.merges.strategies import merge_adapters if TYPE_CHECKING: - from lorax_server.adapters.config import AdapterConfig, ModuleMap + from lorax_server.adapters.config import AdapterConfig, ModuleMap BASE_MODEL_ADAPTER_ID = "__base_model__" @@ -22,7 +22,7 @@ class AdapterParametersContainer: adapter_parameters: generate_pb2.AdapterParameters adapter_source: str adapter_index: int - + def __hash__(self) -> int: return self.adapter_index @@ -58,18 +58,22 @@ def _load_and_merge( api_token: str, ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: params = adapter_params.adapter_parameters - + adapters_to_merge = [] merged_weight_names = set() tokenizer = None for adapter_id in params.adapter_ids: if adapter_id == BASE_MODEL_ADAPTER_ID: raise ValueError("Base model adapter cannot be merged.") - + module_map, adapter_config, adapter_weight_names, adapter_tokenizer = load_module_map( - 
model_id, adapter_id, adapter_params.adapter_source, weight_names, api_token, + model_id, + adapter_id, + adapter_params.adapter_source, + weight_names, + api_token, ) - + adapters_to_merge.append((module_map, adapter_config)) merged_weight_names = merged_weight_names.union(adapter_weight_names) if tokenizer is None: @@ -77,21 +81,26 @@ def _load_and_merge( if len(adapters_to_merge) == 0: raise ValueError("No adapters to merge.") - + module_map, adapter_config = merge_adapters(adapters_to_merge, params) return module_map, adapter_config, merged_weight_names, tokenizer -def check_architectures(model_id: str, adapter_id: str, adapter_config: "AdapterConfig", api_token: str): +def check_architectures( + model_id: str, adapter_id: str, adapter_config: "AdapterConfig", api_token: str +): try: expected_config = AutoConfig.from_pretrained(model_id, token=api_token) - model_config = AutoConfig.from_pretrained(adapter_config.base_model_name_or_path, token=api_token) + model_config = AutoConfig.from_pretrained( + adapter_config.base_model_name_or_path, token=api_token + ) except Exception as e: warnings.warn( f"Unable to check architecture compatibility for adapter '{adapter_id}' " - f"against model '{model_id}'. Assuming they are compatible. Error: {e}") + f"against model '{model_id}'. Assuming they are compatible. Error: {e}" + ) return - + if model_config.architectures == expected_config.architectures: warnings.warn( f"Adapter '{adapter_id}' was not trained on base model '{model_id}'. " @@ -99,9 +108,11 @@ def check_architectures(model_id: str, adapter_id: str, adapter_config: "Adapter ) else: # TODO(travis): revisit this when we support clasification heads which will not use CausalLM - raise ValueError(f"Adapter '{adapter_id}' is not compatible with model '{model_id}'. " - f"Architectures differ: {model_config.architectures} != {expected_config.architectures}. " - f"Use --model-id '{adapter_config.base_model_name_or_path}' instead.") + raise ValueError( + f"Adapter '{adapter_id}' is not compatible with model '{model_id}'. " + f"Architectures differ: {model_config.architectures} != {expected_config.architectures}. " + f"Use --model-id '{adapter_config.base_model_name_or_path}' instead." 
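# A minimal sketch of the architecture comparison performed by
# check_architectures above. AutoConfig.from_pretrained fetches configs from the
# Hugging Face Hub, so the ids in the demo are placeholders and the call needs
# network access (and a token for gated repositories).
from typing import Optional

from transformers import AutoConfig


def architectures_match(model_id: str, adapter_base_model_id: str, api_token: Optional[str] = None) -> bool:
    expected = AutoConfig.from_pretrained(model_id, token=api_token)
    actual = AutoConfig.from_pretrained(adapter_base_model_id, token=api_token)
    # `architectures` lists the model classes the checkpoint was exported with.
    return expected.architectures == actual.architectures


if __name__ == "__main__":
    # Placeholder ids; swap in the served base model and the adapter's declared base model.
    print(architectures_match("mistralai/Mistral-7B-v0.1", "mistralai/Mistral-7B-Instruct-v0.1"))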
+ ) @lru_cache(maxsize=128) @@ -113,8 +124,10 @@ def load_module_map( api_token: str, ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: # TODO(geoffrey): refactor this and merge parts of this function with - # lorax_server/utils/adapter.py::create_merged_weight_files - source = get_model_source(adapter_source, adapter_id, extension=".safetensors", api_token=api_token) + # lorax_server/utils/adapter.py::create_merged_weight_files + source = get_model_source( + adapter_source, adapter_id, extension=".safetensors", api_token=api_token + ) config_path = get_config_path(adapter_id, adapter_source) adapter_config = source.load_config() if adapter_config.base_model_name_or_path != model_id: @@ -125,13 +138,15 @@ def load_module_map( except Exception: # Adapter does not have a tokenizer, so fallback to base model tokenizer adapter_tokenizer = None - + # load adapter weights from all shards (should have relatively small memory footprint) adapter_filenames = source.weight_files() adapter_weights = {} for filename in adapter_filenames: adapter_weights.update(load_file(filename)) - + # map the model weights to the relevant adapter weights (LoRA A and B matrices) - module_map, adapter_weight_names = adapter_config.map_weights_for_model(adapter_weights, weight_names) + module_map, adapter_weight_names = adapter_config.map_weights_for_model( + adapter_weights, weight_names + ) return module_map, adapter_config, adapter_weight_names, adapter_tokenizer diff --git a/server/lorax_server/utils/awq/awq.py b/server/lorax_server/utils/awq/awq.py index 2252734b5..d5c84ec97 100644 --- a/server/lorax_server/utils/awq/awq.py +++ b/server/lorax_server/utils/awq/awq.py @@ -5,6 +5,7 @@ import torch.nn as nn import awq_inference_engine # with CUDA kernels + class AWQLinear(nn.Module): def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias): super().__init__() @@ -20,8 +21,12 @@ def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias): self.w_bit = w_bit self.group_size = group_size if group_size != -1 else self.in_features - assert self.in_features % self.group_size == 0, "in_features must be divisible by group_size" - assert self.out_features % (32 // self.w_bit) == 0, "out_features must be divisible by 32 // w_bit" + assert ( + self.in_features % self.group_size == 0 + ), "in_features must be divisible by group_size" + assert ( + self.out_features % (32 // self.w_bit) == 0 + ), "out_features must be divisible by 32 // w_bit" self.qweight = qweight self.qzeros = qzeros @@ -30,21 +35,22 @@ def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias): @torch.no_grad() def forward(self, x): - out_shape = x.shape[:-1] + (self.out_features, ) + out_shape = x.shape[:-1] + (self.out_features,) input_dtype = x.dtype if input_dtype != torch.float16: x = x.half() - - out = awq_inference_engine.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8) - + + out = awq_inference_engine.gemm_forward_cuda( + x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8 + ) + if input_dtype != torch.float16: out = out.to(dtype=input_dtype) - + out = out + self.bias if self.bias is not None else out return out.reshape(out_shape) - @property def weight(self) -> torch.Tensor: - return self.qweight \ No newline at end of file + return self.qweight diff --git a/server/lorax_server/utils/convert.py b/server/lorax_server/utils/convert.py index f02703bed..836f5df07 100644 --- a/server/lorax_server/utils/convert.py +++ 
b/server/lorax_server/utils/convert.py @@ -27,9 +27,7 @@ def _remove_duplicate_names( shareds = _find_shared_tensors(state_dict) to_remove = defaultdict(list) for shared in shareds: - complete_names = set( - [name for name in shared if _is_complete(state_dict[name])] - ) + complete_names = set([name for name in shared if _is_complete(state_dict[name])]) if not complete_names: if len(shared) == 1: # Force contiguous @@ -93,9 +91,13 @@ def convert_file(pt_file: Path, sf_file: Path, discard_names: List[str]): sf_tensor = reloaded[k] if not torch.equal(pt_tensor, sf_tensor): if torch.any(torch.isnan(pt_tensor)): - raise NanWeightsError(f"Weights unusuable as param {k} in file {pt_file} contains NaN values") + raise NanWeightsError( + f"Weights unusuable as param {k} in file {pt_file} contains NaN values" + ) if torch.any(torch.isinf(pt_tensor)): - raise InfWeightsError(f"Weights unusuable as param {k} in file {pt_file} contains inf values") + raise InfWeightsError( + f"Weights unusuable as param {k} in file {pt_file} contains inf values" + ) raise RuntimeError(f"The output tensors do not match for key {k}") @@ -107,14 +109,10 @@ def convert_files(pt_files: List[Path], sf_files: List[Path], discard_names: Lis for i, (pt_file, sf_file) in enumerate(zip(pt_files, sf_files)): # Skip blacklisted files - if ( - "arguments" in pt_file.name - or "args" in pt_file.name - or "training" in pt_file.name - ): + if "arguments" in pt_file.name or "args" in pt_file.name or "training" in pt_file.name: continue start = datetime.datetime.now() convert_file(pt_file, sf_file, discard_names) elapsed = datetime.datetime.now() - start - logger.info(f"Convert: [{i + 1}/{N}] -- Took: {elapsed}") \ No newline at end of file + logger.info(f"Convert: [{i + 1}/{N}] -- Took: {elapsed}") diff --git a/server/lorax_server/utils/flash_attn.py b/server/lorax_server/utils/flash_attn.py index d84bf4313..6c47b2b84 100644 --- a/server/lorax_server/utils/flash_attn.py +++ b/server/lorax_server/utils/flash_attn.py @@ -27,8 +27,7 @@ ) if not (is_sm8x or is_sm90): raise ImportError( - f"GPU with CUDA capability {major} {minor} is not supported for " - "Flash Attention V2" + f"GPU with CUDA capability {major} {minor} is not supported for " "Flash Attention V2" ) HAS_FLASH_ATTN_V2 = True except ImportError as e: @@ -42,9 +41,7 @@ ) from e if not (is_sm75 or is_sm8x or is_sm90): - raise ImportError( - f"GPU with CUDA capability {major} {minor} is not supported" - ) from e + raise ImportError(f"GPU with CUDA capability {major} {minor} is not supported") from e logger.warning(f"Unable to use Flash Attention V2: {e}") HAS_FLASH_ATTN = True @@ -81,9 +78,7 @@ def attention( if HAS_FLASH_ATTN: if window_size_left != -1: - raise NotImplementedError( - "window_size_left is only available with flash attn v2" - ) + raise NotImplementedError("window_size_left is only available with flash attn v2") # Flash attention v1 requires q, k and v to have the same number of heads if k.shape[1] != q.shape[1]: diff --git a/server/lorax_server/utils/gptq/custom_autotune.py b/server/lorax_server/utils/gptq/custom_autotune.py index 1eb40f1ed..a0345474f 100644 --- a/server/lorax_server/utils/gptq/custom_autotune.py +++ b/server/lorax_server/utils/gptq/custom_autotune.py @@ -87,9 +87,7 @@ def kernel_call(): try: # In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default - return triton.testing.do_bench( - kernel_call, quantiles=(0.5, 
0.2, 0.8), rep=40 - ) + return triton.testing.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8), rep=40) except triton.OutOfResources: return (float("inf"), float("inf"), float("inf")) @@ -108,8 +106,7 @@ def run(self, *args, **kwargs): pruned_configs = self.prune_configs(kwargs) bench_start = time.time() timings = { - config: self._bench(*args, config=config, **kwargs) - for config in pruned_configs + config: self._bench(*args, config=config, **kwargs) for config in pruned_configs } bench_end = time.time() self.bench_time = bench_end - bench_start @@ -149,9 +146,7 @@ def prune_configs(self, kwargs): ) for config in pruned_configs } - pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[ - :top_k - ] + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] return pruned_configs def warmup(self, *args, **kwargs): @@ -167,9 +162,7 @@ def warmup(self, *args, **kwargs): self.nargs = None -def autotune( - configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False -): +def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False): """ Decorator for auto-tuning a :code:`triton.jit`'d function. .. highlight:: python diff --git a/server/lorax_server/utils/gptq/exllamav2.py b/server/lorax_server/utils/gptq/exllamav2.py index 734b17e1f..a0d7da9a5 100644 --- a/server/lorax_server/utils/gptq/exllamav2.py +++ b/server/lorax_server/utils/gptq/exllamav2.py @@ -10,20 +10,22 @@ try: from exllamav2_kernels import make_q_matrix, gemm_half_q_half except ImportError: - logger.error('exllamav2_kernels not installed.') + logger.error("exllamav2_kernels not installed.") raise # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension none_tensor = torch.empty((1, 1), device="meta") + def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda): """Matrix multiplication, returns x @ q4""" output_shape = x.shape[:-1] + (q4_width,) x = x.view(-1, x.shape[-1]) - output = torch.empty((x.shape[0], q4_width), dtype = torch.half, device = x.device) + output = torch.empty((x.shape[0], q4_width), dtype=torch.half, device=x.device) gemm_half_q_half(x, q_handle, output, force_cuda) return output.view(output_shape) + def make_group_map(q_groups, num_qrows): # Convert q_groups to a list gr = q_groups.tolist() @@ -112,7 +114,9 @@ def ext_make_q_matrix(w: dict, temp_dq): w["scales"] = w["scales"].half() # Check if 'g_idx' exists and is not all zeros - g_idx_exists_and_not_all_zeros = w.get("g_idx", None) is not None and not (w["g_idx"] == 0).all().item() + g_idx_exists_and_not_all_zeros = ( + w.get("g_idx", None) is not None and not (w["g_idx"] == 0).all().item() + ) if g_idx_exists_and_not_all_zeros: # Create 'q_perm' and 'q_invperm' @@ -161,6 +165,7 @@ def ext_make_q_matrix(w: dict, temp_dq): FIXED_BYTES = 0 LAYERS = [] + def set_device(device): global DEVICE DEVICE = device @@ -183,11 +188,12 @@ def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): super().__init__() if bits != 4: raise ValueError( - f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization.") + f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization." 
+ ) self.q_handle = None self.q_tensors = None self.bits = bits - self.maxq = 2 ** self.bits - 1 + self.maxq = 2**self.bits - 1 self.infeatures = qweight.shape[0] // self.bits * 32 self.outfeatures = qweight.shape[1] + qweight.shape[1] % 32 @@ -207,39 +213,36 @@ def post_init(self, temp_dq): assert self.qweight.device.type == "cuda" assert self.qweight.device.index is not None self.q_tensors = { - "qweight":self.qweight, - "qzeros":self.qzeros, - "scales":self.scales, - "g_idx":self.g_idx + "qweight": self.qweight, + "qzeros": self.qzeros, + "scales": self.scales, + "g_idx": self.g_idx, } temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size()) - self.q_handle = ext_make_q_matrix( - self.q_tensors, temp_dq - ) - - def forward(self, x, force_cuda = False): + self.q_handle = ext_make_q_matrix(self.q_tensors, temp_dq) + + def forward(self, x, force_cuda=False): output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda) if self.bias is not None: output.add_(self.bias) return output - + def temp_dq_size(self): return self.infeatures * self.outfeatures * 2 + 128 - + def temp_fwd_size(self, max_input_len, max_batch_size): return self.outfeatures * max_input_len * max_batch_size * 4 + 128 - + def scratch_spacing(self, max_input_len=8192, max_batch_size=32): return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size) @property def weight(self) -> torch.Tensor: return self.qweight - - -class ExLlamaV2DeviceTensors: + +class ExLlamaV2DeviceTensors: device_idx: int scratch_bytes: int scratch_idx: int @@ -248,15 +251,15 @@ class ExLlamaV2DeviceTensors: def __init__(self, device, scratch_bytes): self.device = device self.scratch_bytes = scratch_bytes - + def prepare(self): - self.scratch = torch.empty((self.scratch_bytes // 2,), dtype = torch.half, device = self.device) + self.scratch = torch.empty((self.scratch_bytes // 2,), dtype=torch.half, device=self.device) def get_scratch_slice(self, size_bytes): - - if self.scratch is None: self.prepare() + if self.scratch is None: + self.prepare() size_bytes = ((size_bytes + 127) // 128) * 128 size_half = size_bytes // 2 scratch_slice = self.scratch.narrow(0, 0, size_half) - return scratch_slice \ No newline at end of file + return scratch_slice diff --git a/server/lorax_server/utils/gptq/quant_linear.py b/server/lorax_server/utils/gptq/quant_linear.py index 7c44f3e3c..a811060dc 100644 --- a/server/lorax_server/utils/gptq/quant_linear.py +++ b/server/lorax_server/utils/gptq/quant_linear.py @@ -158,8 +158,7 @@ def matmul_248_kernel( a_mask = offs_am[:, None] < M # b_ptrs is set up such that it repeats elements along the K axis 8 times b_ptrs = b_ptr + ( - (offs_k[:, None] // infearure_per_bits) * stride_bk - + offs_bn[None, :] * stride_bn + (offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn ) # (BLOCK_SIZE_K, BLOCK_SIZE_N) g_ptrs = g_ptr + offs_k # shifter is used to extract the N bits of each element in the 32-bit word from B @@ -275,12 +274,8 @@ def new(cls, bits, groupsize, infeatures, outfeatures, bias): (math.ceil(infeatures / groupsize), outfeatures // 32 * bits), dtype=torch.int32, ) - scales = torch.zeros( - (math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16 - ) - g_idx = torch.tensor( - [i // groupsize for i in range(infeatures)], dtype=torch.int32 - ) + scales = torch.zeros((math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16) + g_idx = torch.tensor([i // groupsize for i in range(infeatures)], dtype=torch.int32) if bias: bias = 
torch.zeros((outfeatures), dtype=torch.float16) else: @@ -327,9 +322,7 @@ def pack(self, linear, scales, zeros, g_idx=None): zeros -= 1 zeros = zeros.numpy().astype(np.uint32) - qzeros = np.zeros( - (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32 - ) + qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32) i = 0 col = 0 while col < qzeros.shape[1]: @@ -357,7 +350,7 @@ def forward(self, x): ) out = out + self.bias if self.bias is not None else out return out.reshape(out_shape) - + @property def weight(self) -> torch.Tensor: return self.qweight diff --git a/server/lorax_server/utils/gptq/quantize.py b/server/lorax_server/utils/gptq/quantize.py index c300dcfbe..72fa65de6 100644 --- a/server/lorax_server/utils/gptq/quantize.py +++ b/server/lorax_server/utils/gptq/quantize.py @@ -104,9 +104,7 @@ def find_params(self, x, weight=False): xmax1 = p * xmax scale1 = (xmax1 - xmin1) / self.maxq zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero - q = self._quantize( - x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq - ) + q = self._quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) q -= x q.abs_() q.pow_(self.norm) @@ -180,9 +178,7 @@ def add_batch(self, inp, out): if len(inp.shape) == 2: inp = inp.unsqueeze(0) tmp = inp.shape[0] - if isinstance(self.layer, nn.Linear) or isinstance( - self.layer, transformers.Conv1D - ): + if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D): if len(inp.shape) == 3: inp = inp.reshape((-1, inp.shape[-1])) inp = inp.t() @@ -206,11 +202,7 @@ def add_batch(self, inp, out): def print_loss(self, name, q_weight, weight_error, timecost): table = Texttable() length = 28 - name = ( - (name + " " * (length - len(name))) - if len(name) <= length - else name[:length] - ) + name = (name + " " * (length - len(name))) if len(name) <= length else name[:length] table.header(["name", "weight_error", "fp_inp_SNR", "q_inp_SNR", "time"]) @@ -237,9 +229,7 @@ def print_loss(self, name, q_weight, weight_error, timecost): table.add_row([name, weight_error, fp_SNR, q_SNR, timecost]) print(table.draw().split("\n")[-2]) - def fasterquant( - self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False, name="" - ): + def fasterquant(self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False, name=""): self.layer.to(self.dev) W = self.layer.weight.data.clone() @@ -278,9 +268,7 @@ def fasterquant( H = torch.linalg.cholesky(H, upper=True) except Exception: # Addition because Falcon fails on h_to_4h - H = torch.linalg.cholesky( - H + 1e-5 * torch.eye(H.shape[0]).to(H.device), upper=True - ) + H = torch.linalg.cholesky(H + 1e-5 * torch.eye(H.shape[0]).to(H.device), upper=True) Hinv = H g_idx = [] @@ -340,9 +328,7 @@ def fasterquant( if isinstance(self.layer, transformers.Conv1D): Q = Q.t() - self.print_loss( - name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick) - ) + self.print_loss(name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick)) if scale == []: scale.append(self.quantizer.scale) @@ -583,9 +569,7 @@ def find_layers(module, layers=(nn.Conv2d, nn.Linear), name=""): res = {} for name1, child in module.named_children(): res.update( - find_layers( - child, layers=layers, name=name + "." + name1 if name != "" else name1 - ) + find_layers(child, layers=layers, name=name + "." 
+ name1 if name != "" else name1) ) return res @@ -616,9 +600,7 @@ def sequential( prefix = "transformer.h" dtype = next(iter(model.parameters())).dtype - inps = torch.zeros( - (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev - ) + inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev) cache = {"i": 0} extra = {} @@ -651,9 +633,7 @@ def forward(self, inp, **kwargs): outs = torch.zeros_like(inps) - extra = { - k: v.to(dev) if isinstance(v, torch.Tensor) else v for k, v in extra.items() - } + extra = {k: v.to(dev) if isinstance(v, torch.Tensor) else v for k, v in extra.items()} print("Ready.") @@ -674,9 +654,7 @@ def forward(self, inp, **kwargs): gptq = {} for name in subset: gptq[name] = GPTQ(subset[name]) - gptq[name].quantizer.configure( - bits, perchannel=True, sym=sym, mse=False - ) + gptq[name].quantizer.configure(bits, perchannel=True, sym=sym, mse=False) pass def add_batch(name): @@ -885,20 +863,14 @@ def _load(): def unload(module, name): def _unload(): - load_weights_post_hook(name, weights, recursive=True)( - module, None, None - ) + load_weights_post_hook(name, weights, recursive=True)(module, None, None) return _unload module.load = load(module, name) module.unload = unload(module, name) - hooks.append( - module.register_forward_pre_hook(load_weights_pre_hook(name, weights)) - ) - hooks.append( - module.register_forward_hook(load_weights_post_hook(name, weights)) - ) + hooks.append(module.register_forward_pre_hook(load_weights_pre_hook(name, weights))) + hooks.append(module.register_forward_hook(load_weights_post_hook(name, weights))) model.seqlen = 2048 dataset = "wikitext2" @@ -965,15 +937,11 @@ def _unload(): config.save_pretrained(output_dir) logger.info("Saved config") logger.info("Saving tokenizer") - tokenizer = AutoTokenizer.from_pretrained( - model_id, trust_remote_code=trust_remote_code - ) + tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=trust_remote_code) tokenizer.save_pretrained(output_dir) logger.info("Saved tokenizer") if upload_to_model_id: api = HfApi() - api.upload_folder( - folder_path=output_dir, repo_id=upload_to_model_id, repo_type="model" - ) + api.upload_folder(folder_path=output_dir, repo_id=upload_to_model_id, repo_type="model") diff --git a/server/lorax_server/utils/graph.py b/server/lorax_server/utils/graph.py index 4a1db6be1..bb3c0a077 100644 --- a/server/lorax_server/utils/graph.py +++ b/server/lorax_server/utils/graph.py @@ -57,13 +57,13 @@ def get_cached_batch_size(batch_size: int) -> int: def pad_and_fill(dest: torch.Tensor, src: torch.Tensor, pad_value: int): - dest[:src.shape[0]] = src - dest[src.shape[0]:].fill_(pad_value) + dest[: src.shape[0]] = src + dest[src.shape[0] :].fill_(pad_value) def next_pow_2(x: int) -> int: assert x > 0 - return 1 << (x-1).bit_length() + return 1 << (x - 1).bit_length() @dataclass @@ -78,8 +78,8 @@ class GraphState: @lru_cache(maxsize=1) def get_max_graph_state( - device: torch.device, - adapter_layers: Tuple[str], + device: torch.device, + adapter_layers: Tuple[str], max_total_tokens: int, sliding_window_blocks: Optional[int] = None, ) -> GraphState: @@ -150,7 +150,7 @@ def __init__( self.input_state = input_state self.output_states = output_states self.model = model - + @staticmethod def trace( model: nn.Module, @@ -162,7 +162,9 @@ def trace( max_total_tokens: int, sliding_window_blocks: Optional[int] = None, ) -> "GraphWrapper": - max_input_state = get_max_graph_state(device, adapter_layers, max_total_tokens, 
sliding_window_blocks) + max_input_state = get_max_graph_state( + device, adapter_layers, max_total_tokens, sliding_window_blocks + ) # WARNING: for some reason the SGMV kernel can hang if we don't use a power of 2 # as the segment size. This is a workaround until we can figure out why. @@ -192,10 +194,16 @@ def trace( tmp_expand=weight_data.rank_data[MAX_RANK].tmp_expand[:tmp_expand_size], lora_a_ptr=weight_data.rank_data[MAX_RANK].lora_a_ptr[:segment_size], lora_b_ptr=weight_data.rank_data[MAX_RANK].lora_b_ptr[:segment_size], - segment_starts=weight_data.rank_data[MAX_RANK].segment_starts[:segment_size], - segment_ends=weight_data.rank_data[MAX_RANK].segment_ends[:segment_size], + segment_starts=weight_data.rank_data[MAX_RANK].segment_starts[ + :segment_size + ], + segment_ends=weight_data.rank_data[MAX_RANK].segment_ends[ + :segment_size + ], ), - } if max_rank > 0 else {}, + } + if max_rank > 0 + else {}, ) } @@ -209,7 +217,9 @@ def trace( meta=AdapterBatchMetadata( adapter_indices=max_input_state.adapter_data.meta.adapter_indices[:batch_size], adapter_set=max_input_state.adapter_data.meta.adapter_set, - adapter_segments=max_input_state.adapter_data.meta.adapter_segments[:batch_size], + adapter_segments=max_input_state.adapter_data.meta.adapter_segments[ + :batch_size + ], segment_indices=max_input_state.adapter_data.meta.segment_indices, ), data=adapter_weight_data, @@ -235,10 +245,8 @@ def trace( torch.cuda.synchronize(device) - return GraphWrapper( - graph, graph.pool(), input_state, output_states, model - ) - + return GraphWrapper(graph, graph.pool(), input_state, output_states, model) + def forward( self, input_ids: torch.Tensor, @@ -258,7 +266,9 @@ def forward( pad_and_fill(self.input_state.input_lengths, input_lengths, 0) self.input_state.block_tables.zero_() - self.input_state.block_tables[:block_tables.shape[0], :block_tables.shape[1]] = block_tables + self.input_state.block_tables[ + : block_tables.shape[0], : block_tables.shape[1] + ] = block_tables for layer_name, weight_data in self.input_state.adapter_data.data.items(): lora_data = weight_data[LORA] @@ -268,7 +278,7 @@ def forward( rank_data.segment_starts.fill_(SEGMENT_PAD_VALUE) rank_data.segment_ends.fill_(SEGMENT_PAD_VALUE) continue - + source_data = adapter_data.data[layer_name] dest_data = lora_data for rank, source_rank_data in source_data.rank_data.items(): @@ -277,22 +287,28 @@ def forward( pad_and_fill(dest_rank_data.lora_a_ptr, source_rank_data.lora_a_ptr, 0) pad_and_fill(dest_rank_data.lora_b_ptr, source_rank_data.lora_b_ptr, 0) - pad_and_fill(dest_rank_data.segment_starts, source_rank_data.segment_starts, SEGMENT_PAD_VALUE) - pad_and_fill(dest_rank_data.segment_ends, source_rank_data.segment_ends, SEGMENT_PAD_VALUE) - + pad_and_fill( + dest_rank_data.segment_starts, + source_rank_data.segment_starts, + SEGMENT_PAD_VALUE, + ) + pad_and_fill( + dest_rank_data.segment_ends, source_rank_data.segment_ends, SEGMENT_PAD_VALUE + ) + self.graph.replay() - return self.output_states[:input_ids.shape[0]] - + return self.output_states[: input_ids.shape[0]] + def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) class GraphCache: def __init__( - self, - model: nn.Module, - device: torch.device, + self, + model: nn.Module, + device: torch.device, adapter_layers: List[str], max_total_tokens: int, sliding_window_blocks: Optional[int] = None, @@ -330,7 +346,7 @@ def can_use_graph( and max_rank in _allowed_ranks and all(k == LORA for k in adapter_keys) ) - + def get_estimated_cache_memory(self) -> int: # 
Store off graphs into temporary cache to discard after estimation tmp_cache = {} @@ -343,7 +359,7 @@ def get_estimated_cache_memory(self) -> int: for i, max_rank in enumerate(reversed(CACHED_MAX_RANKS)): torch.cuda.synchronize(self.device) free_memory_before, _ = torch.cuda.mem_get_info(self.device) - + key = (batch_size, max_rank) graph = GraphWrapper.trace( self.model, @@ -366,11 +382,11 @@ def get_estimated_cache_memory(self) -> int: delta_memory = free_memory_before - free_memory_after if i > 0: samples.append(delta_memory) - + # Tracing all graphs can take a while, so limit the number of samples if len(samples) == MAX_SAMPLES: break - + # Estimate memory usage for all batch sizes and ranks ngraphs = len(CACHED_BATCH_SIZES) * len(CACHED_MAX_RANKS) per_graph_memory = median(samples) @@ -381,7 +397,7 @@ def warmup(self): pool = None with tqdm(total=ngraphs, desc="Trace CUDA graphs") as pbar: for batch_size in reversed(CACHED_BATCH_SIZES): - pbar.set_postfix({'batch_size': batch_size}) + pbar.set_postfix({"batch_size": batch_size}) for max_rank in reversed(CACHED_MAX_RANKS): key = (batch_size, max_rank) graph = GraphWrapper.trace( @@ -425,7 +441,7 @@ def forward( max_rank, self.memory_pool, ) - + output_states = self.cache[key].forward( input_ids=input_ids, position_ids=position_ids, @@ -440,6 +456,6 @@ def forward( ) return output_states - + def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index 3cfbfe226..dd2a76259 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -17,7 +17,7 @@ HAS_BITS_AND_BYTES = False HAS_AWQ = True -try: +try: from lorax_server.utils.awq.awq import AWQLinear except ImportError: HAS_AWQ = False @@ -38,6 +38,7 @@ class HQQLinearLayer(HQQLinear): @property def weight(self) -> torch.Tensor: return self.W_q + except ImportError: HAS_HQQ = False @@ -45,7 +46,12 @@ def weight(self) -> torch.Tensor: from lorax_server.adapters.types import LORA, MEDUSA from lorax_server.utils.gptq.quant_linear import QuantLinear -from lorax_server.utils.sgmv import lora_a_sgmv_cutlass, lora_b_sgmv_cutlass, has_sgmv, orient_for_rank +from lorax_server.utils.sgmv import ( + lora_a_sgmv_cutlass, + lora_b_sgmv_cutlass, + has_sgmv, + orient_for_rank, +) from lorax_server.utils.state import is_warmup HAS_EXLLAMA = True @@ -141,7 +147,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: x = torch.addmm(self.bias, input.view(-1, input.size(-1)), self.weight) x = x.view(size_out) return x - + class EETQLinear(nn.Module): """ @@ -190,29 +196,28 @@ def __init__( self.bias = None def forward(self, input: torch.Tensor) -> torch.Tensor: - """ - Performs the forward pass of the layer. - - Args: - input (torch.Tensor): The input tensor. - - Returns: - torch.Tensor: The output tensor. - """ - # The function w8_a16_gemm performs a matrix multiplication operation between the input and the weight of the layer. - # The result is then scaled by a factor (self.scale). - gemm_output = w8_a16_gemm(input, self.weight, self.scale) - - # If a bias is present (i.e., self.bias is not None), it is added to the output of the matrix multiplication. - # If a bias is not present (i.e., self.bias is None), the output of the matrix multiplication is returned as is. - if self.bias is not None: - final_output = gemm_output + self.bias - else: - final_output = gemm_output - - # The final output is returned. 
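For readers unfamiliar with CUDA graphs, the GraphWrapper/GraphCache code above follows the capture-once, replay-many pattern: trace one decode forward pass into static buffers, then on each call copy the (padded) live inputs into those buffers, replay the graph, and slice the static output. A minimal sketch of that pattern with a stand-in module and hypothetical shapes, not the LoRAX wrapper API:

import torch
import torch.nn as nn

model = nn.Linear(1024, 1024).cuda().half()
static_input = torch.zeros(8, 1024, device="cuda", dtype=torch.half)  # padded max batch size

# Warm up on a side stream before capture, as the PyTorch CUDA-graphs docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)  # allocations are recorded into the graph's memory pool

def replay(batch: torch.Tensor) -> torch.Tensor:
    # Pad the live batch into the captured buffer, replay, then slice the result,
    # mirroring pad_and_fill(...) and output_states[: input_ids.shape[0]] above.
    static_input.zero_()
    static_input[: batch.shape[0]] = batch
    graph.replay()
    return static_output[: batch.shape[0]]

Sharing one memory pool across captures (graph.pool() above) and keying the cache on (batch_size, max_rank) is what get_estimated_cache_memory sizes up, by tracing a few graphs and taking the median free-memory delta per graph.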
- return final_output + """ + Performs the forward pass of the layer. + + Args: + input (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor. + """ + # The function w8_a16_gemm performs a matrix multiplication operation between the input and the weight of the layer. + # The result is then scaled by a factor (self.scale). + gemm_output = w8_a16_gemm(input, self.weight, self.scale) + + # If a bias is present (i.e., self.bias is not None), it is added to the output of the matrix multiplication. + # If a bias is not present (i.e., self.bias is None), the output of the matrix multiplication is returned as is. + if self.bias is not None: + final_output = gemm_output + self.bias + else: + final_output = gemm_output + # The final output is returned. + return final_output class Linear8bitLt(nn.Module): @@ -272,6 +277,7 @@ def forward(self, x: torch.Tensor): self.weight.data = self.state.CxB return out + class Linear4bit(nn.Module): def __init__(self, weight, bias, quant_type): super().__init__() @@ -293,7 +299,9 @@ def forward(self, x: torch.Tensor): # Check if quantization state is initialized if getattr(self.weight, "quant_state", None) is None: - print("FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.") + print( + "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first." + ) # Convert input to compute_dtype if specified inp_dtype = x.dtype @@ -311,6 +319,7 @@ def forward(self, x: torch.Tensor): return out + def get_linear(weight, bias, quantize, fan_in_fan_out=False): # https://huggingface.co/docs/peft/package_reference/tuners#peft.LoraConfig.fan_in_fan_out # Set to True if replacing a Conv1D layer with a Linear layer @@ -344,9 +353,7 @@ def get_linear(weight, bias, quantize, fan_in_fan_out=False): if HAS_EETQ: linear = EETQLinear(weight, bias) else: - raise ImportError( - "Please install EETQ from https://github.com/NetEase-FuXi/EETQ" - ) + raise ImportError("Please install EETQ from https://github.com/NetEase-FuXi/EETQ") elif quantize == "gptq": try: qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight @@ -371,17 +378,28 @@ def get_linear(weight, bias, quantize, fan_in_fan_out=False): try: qweight, qzeros, scales, _, bits, groupsize, _ = weight except Exception: - raise NotImplementedError( - f"The passed weight is not compatible with `awq`" - ) - linear = AWQLinear(w_bit=bits, group_size=groupsize, qweight=qweight, qzeros=qzeros, scales=scales, bias=bias is not None) + raise NotImplementedError(f"The passed weight is not compatible with `awq`") + linear = AWQLinear( + w_bit=bits, + group_size=groupsize, + qweight=qweight, + qzeros=qzeros, + scales=scales, + bias=bias is not None, + ) elif "hqq-" in quantize: if quantize == "hqq-4bit": - quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=True, quant_scale=False) + quant_config = BaseQuantizeConfig( + nbits=4, group_size=64, quant_zero=True, quant_scale=False + ) elif quantize == "hqq-3bit": - quant_config = BaseQuantizeConfig(nbits=3, group_size=64, quant_zero=True, quant_scale=False) + quant_config = BaseQuantizeConfig( + nbits=3, group_size=64, quant_zero=True, quant_scale=False + ) elif quantize == "hqq-2bit": - quant_config = BaseQuantizeConfig(nbits=2, group_size=16, quant_zero=True, quant_scale=False) + quant_config = BaseQuantizeConfig( + nbits=2, group_size=16, quant_zero=True, quant_scale=False + ) # init nn.linear from weight and bias layer = nn.Linear(weight.shape[1], 
weight.shape[0], bias=bias is not None) @@ -389,7 +407,7 @@ def get_linear(weight, bias, quantize, fan_in_fan_out=False): layer.weight.data = weight if bias is not None: layer.bias.data = bias - + linear = HQQLinearLayer(layer, quant_config, del_orig=True) else: raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.") @@ -452,9 +470,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: torch.mm(input, self.linear.weight.T, out=local_out) - torch.distributed.all_gather_into_tensor( - world_out, local_out, group=self.process_group - ) + torch.distributed.all_gather_into_tensor(world_out, local_out, group=self.process_group) return world_out output = super().forward(input) @@ -481,22 +497,19 @@ def load_qkv(cls, config, prefix: str, weights, bias: bool, fan_in_fan_out=False @classmethod def load(cls, config, prefix: str, weights, bias: bool, fan_in_fan_out: bool = False): - return cls.load_multi( - config, [prefix], weights, bias, dim=0, fan_in_fan_out=fan_in_fan_out) + return cls.load_multi(config, [prefix], weights, bias, dim=0, fan_in_fan_out=fan_in_fan_out) @classmethod def load_multi( - cls, - config, - prefixes: List[Union[str, Tuple]], - weights, - bias: bool, - dim: int, - fan_in_fan_out=False + cls, + config, + prefixes: List[Union[str, Tuple]], + weights, + bias: bool, + dim: int, + fan_in_fan_out=False, ): - weight = weights.get_multi_weights_col( - prefixes, quantize=config.quantize, dim=dim - ) + weight = weights.get_multi_weights_col(prefixes, quantize=config.quantize, dim=dim) if bias: b = weights.get_sharded_list("bias", prefixes, dim=0) @@ -505,7 +518,7 @@ def load_multi( bias = None linear = get_linear(weight, bias, config.quantize, fan_in_fan_out=fan_in_fan_out) return cls(linear) - + class LoraLinear(nn.Module): def __init__(self, base_layer, layer_id, process_group): @@ -548,7 +561,7 @@ def forward_layer_type( if self.process_group.size() > 1: v = self.collect_lora_a(v) - + lora_b_sgmv_cutlass( proj, v, @@ -564,12 +577,16 @@ def forward_layer_type( else: for adapter_index in adapter_data.meta.adapter_set: if data is not None and data.has_adapter(adapter_index): - adapter_mask = (adapter_data.meta.adapter_indices == adapter_index).to(input.dtype).view(-1, 1) + adapter_mask = ( + (adapter_data.meta.adapter_indices == adapter_index) + .to(input.dtype) + .view(-1, 1) + ) layer_result = self.forward_lora(input, data, adapter_index, adapter_mask) result[:, start_idx:end_idx] += layer_result return result - + def forward_lora( self, input: torch.Tensor, @@ -585,13 +602,13 @@ def forward_lora( a_out = input @ lora_a if self.process_group.size() > 1: a_out = self.collect_lora_a(a_out) - + result = (a_out @ lora_b) * adapter_mask return result def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor: raise NotImplementedError("Implemented in subclasses") - + class TensorParallelMultiAdapterLinear(LoraLinear): def __init__(self, base_layer, layer_id, layer_names, sizes, process_group): @@ -604,10 +621,10 @@ def load(cls, base_layer, layer_id, layer_names, sizes, process_group): return TensorParallelMultiAdapterLinear( base_layer, layer_id, layer_names, sizes, process_group ) - + def forward(self, input: torch.Tensor, adapter_data: "AdapterBatchData") -> torch.Tensor: result = self.base_layer(input) - + # handle models like Bloom that have inputs of shape # (batch_size, sequence_length, hidden_size) # we need to reshape them to (batch_size * sequence_length, hidden_size) @@ -627,9 +644,11 @@ def forward(self, input: torch.Tensor, adapter_data: 
"AdapterBatchData") -> torc end_idx = offset // self.process_group.size() else: end_idx = result.shape[1] - - result = self.forward_layer_type(result, input, adapter_data, layer_name, start_idx, end_idx) - + + result = self.forward_layer_type( + result, input, adapter_data, layer_name, start_idx, end_idx + ) + if is_3d: result = result.reshape(prev_shape) @@ -656,10 +675,10 @@ def __init__(self, base_layer, layer_id, layer_name, process_group): @classmethod def load(cls, base_layer, layer_id, layer_name, process_group): return cls(base_layer, layer_id, layer_name, process_group) - + def forward(self, input: torch.Tensor, adapter_data: "AdapterBatchData") -> torch.Tensor: result = self.base_layer(input) - + # Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285 stride = result.shape[-1] // self.process_group.size() start_idx = self.process_group.rank() * stride @@ -668,7 +687,7 @@ def forward(self, input: torch.Tensor, adapter_data: "AdapterBatchData") -> torc self.forward_layer_type(result, input, adapter_data, self.layer_name, start_idx, end_idx) return result - + def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor: # Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise. # We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks. @@ -682,7 +701,9 @@ def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor: class MultiAdapterHead(TensorParallelAdapterRowLinear): - def forward(self, input: torch.Tensor, adapter_data: "AdapterBatchData") -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + def forward( + self, input: torch.Tensor, adapter_data: "AdapterBatchData" + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: result = super().forward(input, adapter_data) # Medusa @@ -692,7 +713,7 @@ def forward(self, input: torch.Tensor, adapter_data: "AdapterBatchData") -> Tupl speculative_logits = None if data is not None and data.default_medusa is not None: speculative_logits = data.default_medusa.model(input) - + # TODO(travis): support multiple medusa adapters with masking: # for adapter_index in adapter_data.meta.adapter_set: # if data.has_adapter(adapter_index): @@ -711,11 +732,11 @@ def __init__(self, linear, process_group, all_reduce: bool = True): @classmethod def load( - cls, - config, - prefix: str, - weights, - bias: bool, + cls, + config, + prefix: str, + weights, + bias: bool, fan_in_fan_out: bool = False, all_reduce: bool = True, ): @@ -942,9 +963,7 @@ def _update_cos_sin_cache(self, dtype, device, seqlen): self._cos_cached = torch.cos(freqs).to(dtype) self._sin_cached = torch.sin(freqs).to(dtype) - def get_cos_sin( - self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype - ): + def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype): """ Return cos and sin for the asked position ids """ @@ -990,9 +1009,7 @@ def _update_cos_sin_cache(self, dtype, device, seqlen): (self.scaling_factor * seqlen / self.max_position_embeddings) - (self.scaling_factor - 1) ) ** (self.dim / (self.dim - 2)) - self.inv_freq = _create_inv_freq( - self.dim, newbase, self.inv_freq.device - ) + self.inv_freq = _create_inv_freq(self.dim, newbase, self.inv_freq.device) self._seq_len_cached = seqlen t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) # Don't do einsum, it converts fp32 to fp16 @@ -1007,14 +1024,14 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): def __init__( self, - dim, - max_position_embeddings=2048, - base=10000, - factor=1, - 
original_max_position_embeddings=2048, - extrapolation_factor=1, - attn_factor=1, - beta_fast=32, + dim, + max_position_embeddings=2048, + base=10000, + factor=1, + original_max_position_embeddings=2048, + extrapolation_factor=1, + attn_factor=1, + beta_fast=32, beta_slow=1, finetuned=True, device=None, @@ -1031,7 +1048,9 @@ def __init__( self.finetuned = finetuned self.yarn(device, factor) - super().__init__(_create_inv_freq(dim, base, device), factor, max_position_embeddings, device, dtype) + super().__init__( + _create_inv_freq(dim, base, device), factor, max_position_embeddings, device, dtype + ) def _update_cos_sin_cache(self, dtype, device, seqlen): if ( @@ -1046,30 +1065,43 @@ def _update_cos_sin_cache(self, dtype, device, seqlen): self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype) self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype) - + def yarn(self, device, scaling_factor): pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) inv_freq_extrapolation = 1.0 / pos_freqs inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) - low, high = find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings) - inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation - inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + low, high = find_correction_range( + self.beta_fast, + self.beta_slow, + self.dim, + self.base, + self.original_max_position_embeddings, + ) + inv_freq_mask = ( + 1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device) + ) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_mask) + + inv_freq_extrapolation * inv_freq_mask + ) self.inv_freq = inv_freq - self.mscale = float(get_mscale(scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation + self.mscale = float( + get_mscale(scaling_factor) * self.attn_factor + ) # Get n-d magnitude scaling corrected for interpolation # Inverse dim formula to find dim based on number of rotations def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048): - return (dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base)) + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) # Find dim range bounds based on rotations def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): - low = math.floor(find_correction_dim( - low_rot, dim, base, max_position_embeddings)) - high = math.ceil(find_correction_dim( - high_rot, dim, base, max_position_embeddings)) - return max(low, 0), min(high, dim-1) # Clamp values just in case + low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) # Clamp values just in case def linear_ramp_mask(min, max, dim): if min == max: @@ -1085,4 +1117,4 @@ def get_mscale(scale=1): return 0.1 * math.log(scale) + 1.0 except ImportError: - pass \ No newline at end of file + pass diff --git a/server/lorax_server/utils/logits_process.py b/server/lorax_server/utils/logits_process.py index c1fe05585..dd6bdf612 100644 --- 
a/server/lorax_server/utils/logits_process.py +++ b/server/lorax_server/utils/logits_process.py @@ -62,9 +62,7 @@ def __call__(self, scores): self.static_warped_scores = local_scores # Compute logprobs - self.static_next_logprob = torch.log_softmax( - self.static_warped_scores, -1 - ) + self.static_next_logprob = torch.log_softmax(self.static_warped_scores, -1) self.static_scores.copy_(scores) self.cuda_graph.replay() @@ -84,9 +82,7 @@ def static_warper( top_p: Optional[float], typical_p: Optional[float], ) -> StaticWarper: - return StaticWarper( - temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p - ) + return StaticWarper(temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p) class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor): @@ -103,17 +99,13 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor): def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device): self.penalty = penalty - self.penalty_tensor = torch.tensor( - penalty, dtype=dtype, device=device - ).unsqueeze(1) + self.penalty_tensor = torch.tensor(penalty, dtype=dtype, device=device).unsqueeze(1) def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: score = torch.gather(scores, 1, input_ids) # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability - score = torch.where( - score < 0, score * self.penalty_tensor, score / self.penalty_tensor - ) + score = torch.where(score < 0, score * self.penalty_tensor, score / self.penalty_tensor) scores.scatter_(1, input_ids, score) return scores @@ -137,13 +129,9 @@ class HeterogeneousTemperatureLogitsWarper(LogitsWarper): The value used to module the logits distribution. """ - def __init__( - self, temperature: List[float], dtype: torch.dtype, device: torch.device - ): + def __init__(self, temperature: List[float], dtype: torch.dtype, device: torch.device): self.temperature = temperature - self.temperature_tensor = torch.tensor( - temperature, dtype=dtype, device=device - ).unsqueeze(1) + self.temperature_tensor = torch.tensor(temperature, dtype=dtype, device=device).unsqueeze(1) def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: scores.div_(self.temperature_tensor) @@ -182,9 +170,7 @@ def __init__( min_tokens_to_keep: int = 1, ): self.top_p = top_p - self.top_p_opposite = 1 - torch.tensor( - top_p, dtype=dtype, device=device - ).unsqueeze(1) + self.top_p_opposite = 1 - torch.tensor(top_p, dtype=dtype, device=device).unsqueeze(1) self.filter_value = filter_value self.min_tokens_to_keep = min_tokens_to_keep @@ -198,7 +184,7 @@ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tenso # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) sorted_indices_to_remove = probs <= self.top_p_opposite # Keep at least min_tokens_to_keep - sorted_indices_to_remove[..., -self.min_tokens_to_keep:] = 0 + sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0 # scatter sorted tensors to original indexing indices_to_remove = sorted_indices_to_remove.scatter( @@ -251,9 +237,9 @@ def __init__( disabled = [x == 0 for x in top_k] if any(disabled): - self.top_k_disabled_mask = torch.tensor( - disabled, dtype=torch.bool, device=device - ).view(-1, 1) + self.top_k_disabled_mask = torch.tensor(disabled, dtype=torch.bool, device=device).view( + -1, 1 + ) else: self.top_k_disabled_mask = None @@ -357,9 +343,7 @@ def __call__(self, input_ids: torch.Tensor, scores: 
torch.Tensor) -> torch.Tenso if self.disabled_mask is not None: last_ind.masked_fill_(self.disabled_mask, scores.shape[-1] - 1) - sorted_indices_to_remove = sorted_scores > sorted_scores.gather( - 1, last_ind.view(-1, 1) - ) + sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1)) if self.min_tokens_to_keep > 1: # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 @@ -379,9 +363,7 @@ def filter(self, indices): self.mass_tensor = self.mass_tensor[indices] if self.disabled_mask is not None: - self.disabled_mask = ( - self.disabled_mask[indices] if any(disabled) else None - ) + self.disabled_mask = self.disabled_mask[indices] if any(disabled) else None return self return None @@ -403,7 +385,7 @@ def __init__( def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: for i, processor in self.processors.items(): - scores[i: i + 1] = processor(input_ids[i: i + 1], scores[i: i + 1]) + scores[i : i + 1] = processor(input_ids[i : i + 1], scores[i : i + 1]) return scores def filter(self, indices): @@ -437,7 +419,7 @@ def __init__( def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: for i, processor in enumerate(self.sequence_processors): if processor is not None: - scores[i:i + 1] = processor(input_ids[i:i + 1], scores[i:i + 1]) + scores[i : i + 1] = processor(input_ids[i : i + 1], scores[i : i + 1]) return scores def filter(self, indices): @@ -445,7 +427,7 @@ def filter(self, indices): if any([x is not None for x in self.sequence_processors]): return self return None - + @classmethod def from_schemas( cls, diff --git a/server/lorax_server/utils/merges/strategies.py b/server/lorax_server/utils/merges/strategies.py index e2ccfccfa..1c3819cb7 100644 --- a/server/lorax_server/utils/merges/strategies.py +++ b/server/lorax_server/utils/merges/strategies.py @@ -6,7 +6,7 @@ import torch from lorax_server.pb.generate_pb2 import ( - AdapterParameters, + AdapterParameters, MajoritySignMethod as MajoritySignMethodEnum, MergeStrategy as MergeStrategyEnum, ) @@ -17,7 +17,9 @@ from lorax_server.utils.adapter import ModuleMap -def _apply_weights(tensors: Union[torch.Tensor, List[torch.Tensor]], w: torch.Tensor) -> torch.Tensor: +def _apply_weights( + tensors: Union[torch.Tensor, List[torch.Tensor]], w: torch.Tensor +) -> torch.Tensor: if isinstance(tensors, torch.Tensor): t = tensors else: @@ -49,16 +51,18 @@ class TiesMerge(MergeStrategy): def __init__(self, density: float, majority_sign_method: str = "total", **kwargs): self.density = density self.majority_sign_method = majority_sign_method - + def merge(self, task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor: # sparsify task_tensors = [prune(tensor, self.density, method="magnitude") for tensor in task_tensors] task_tensors = torch.stack(task_tensors, dim=0) # elect sign before applying weights - majority_sign_mask = calculate_majority_sign_mask(task_tensors, method=self.majority_sign_method) + majority_sign_mask = calculate_majority_sign_mask( + task_tensors, method=self.majority_sign_method + ) weighted_task_tensors = _apply_weights(task_tensors, weights) - + # disjoint merge return disjoint_merge(weighted_task_tensors, majority_sign_mask) @@ -66,10 +70,12 @@ def merge(self, task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torc class DareLinearMerge(MergeStrategy): def __init__(self, density: float, **kwargs): self.density = density 
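The top-p warper above stores top_p_opposite = 1 - top_p so that, with the usual ascending sort, the tokens to drop are exactly those whose cumulative probability from the low end stays at or below 1 - top_p. A rough self-contained sketch of that trick for a single top_p value (the server version is heterogeneous per request and honors min_tokens_to_keep):

import torch

def top_p_filter(scores: torch.Tensor, top_p: float, filter_value: float = -float("inf")) -> torch.Tensor:
    # Sort ascending and accumulate probability mass from the smallest scores up.
    sorted_scores, sorted_idx = torch.sort(scores, dim=-1, descending=False)
    cum_probs = sorted_scores.softmax(dim=-1).cumsum(dim=-1)
    # Drop the low-probability tail whose total mass fits inside (1 - top_p).
    remove = cum_probs <= (1.0 - top_p)
    remove[..., -1:] = False  # always keep at least the highest-probability token
    # Scatter the sorted-order mask back to the original token positions.
    mask = remove.scatter(1, sorted_idx, remove)
    return scores.masked_fill(mask, filter_value)

# e.g. top_p_filter(torch.randn(2, 50), top_p=0.9)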
- + def merge(self, task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor: # sparsify - task_tensors = [prune(tensor, self.density, method="random", rescale=True) for tensor in task_tensors] + task_tensors = [ + prune(tensor, self.density, method="random", rescale=True) for tensor in task_tensors + ] weighted_task_tensors = _apply_weights(task_tensors, weights) return weighted_task_tensors.sum(dim=0) @@ -81,13 +87,17 @@ def __init__(self, density: float, majority_sign_method: str = "total", **kwargs def merge(self, task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor: # sparsify - task_tensors = [prune(tensor, self.density, method="random", rescale=True) for tensor in task_tensors] + task_tensors = [ + prune(tensor, self.density, method="random", rescale=True) for tensor in task_tensors + ] task_tensors = torch.stack(task_tensors, dim=0) # elect sign before applying weights - majority_sign_mask = calculate_majority_sign_mask(task_tensors, method=self.majority_sign_method) + majority_sign_mask = calculate_majority_sign_mask( + task_tensors, method=self.majority_sign_method + ) weighted_task_tensors = _apply_weights(task_tensors, weights) - + # disjoint merge mixed_task_tensors = disjoint_merge(weighted_task_tensors, majority_sign_mask) return mixed_task_tensors @@ -115,7 +125,9 @@ def merge_adapters( merge_config = { "density": merge_params.density, - "majority_sign_method": MajoritySignMethodEnum.Name(merge_params.majority_sign_method).lower(), + "majority_sign_method": MajoritySignMethodEnum.Name( + merge_params.majority_sign_method + ).lower(), } merge_strategy = strategy_registry[strategy_name](**merge_config) @@ -159,7 +171,7 @@ def _validate_lora_configs(lora_configs: List["LoraConfig"]): ranks = set(lora_config.r for lora_config in lora_configs) if len(ranks) > 1: raise ValueError(f"unable to merge adapters, lora configs have different ranks: {ranks}") - + if all(len(lora_config.target_modules) == 0 for lora_config in lora_configs): raise ValueError("unable to merge adapters, lora configs have no target modules") @@ -168,9 +180,9 @@ def _merge_lora_configs(lora_configs: List["LoraConfig"]) -> "LoraConfig": merged_lora_config = copy.copy(lora_configs[0]) # merge target modules as a union operation - merged_target_modules = sorted(set( - module for lora_config in lora_configs for module in lora_config.target_modules - )) + merged_target_modules = sorted( + set(module for lora_config in lora_configs for module in lora_config.target_modules) + ) merged_lora_config.target_modules = merged_target_modules return merged_lora_config diff --git a/server/lorax_server/utils/merges/utils.py b/server/lorax_server/utils/merges/utils.py index 88e2a2989..d9ad3278a 100644 --- a/server/lorax_server/utils/merges/utils.py +++ b/server/lorax_server/utils/merges/utils.py @@ -54,7 +54,10 @@ def random_pruning(tensor: torch.Tensor, density: float, rescale: bool) -> torch def prune( - tensor: torch.Tensor, density: float, method: Literal["magnitude", "random"], rescale: bool = False + tensor: torch.Tensor, + density: float, + method: Literal["magnitude", "random"], + rescale: bool = False, ) -> torch.Tensor: """ Prune the values of task tensors based on the `method`. 
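For reference, the three TIES steps marked in the comments above (sparsify, elect sign before applying weights, disjoint merge) can be written out directly. A simplified sketch assuming the "total" sign-election method and plain magnitude pruning; this is illustrative only, not the prune / calculate_majority_sign_mask / disjoint_merge helpers the server imports:

import torch
from typing import List

def ties_merge(task_tensors: List[torch.Tensor], weights: torch.Tensor, density: float = 0.5) -> torch.Tensor:
    # 1. Sparsify: keep only the top `density` fraction of entries by magnitude.
    pruned = []
    for t in task_tensors:
        k = max(int(density * t.numel()), 1)
        threshold = torch.topk(t.abs().flatten(), k).values.min()
        pruned.append(torch.where(t.abs() >= threshold, t, torch.zeros_like(t)))
    stacked = torch.stack(pruned, dim=0)  # [num_adapters, ...]

    # 2. Elect a per-parameter majority sign ("total" method: sign of the summed values).
    majority_sign = torch.sign(stacked.sum(dim=0))

    # 3. Disjoint merge: apply the adapter weights, keep only entries that agree with
    #    the elected sign, and average over the adapters that actually contributed.
    w = weights.view(-1, *([1] * (stacked.dim() - 1)))
    agree = torch.sign(stacked) == majority_sign
    kept = stacked * w * agree
    return kept.sum(dim=0) / agree.sum(dim=0).clamp(min=1)

The DARE variants in the code above differ only in step 1, where entries are dropped at random and rescaled (prune with method="random", rescale=True) instead of being pruned by magnitude.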
@@ -77,7 +80,9 @@ def prune( raise ValueError(f"Unknown method {method}") -def calculate_majority_sign_mask(tensor: torch.Tensor, method: Literal["total", "frequency"] = "total"): +def calculate_majority_sign_mask( + tensor: torch.Tensor, method: Literal["total", "frequency"] = "total" +): """ Get the mask of the majority sign across the task tensors. Task tensors are stacked on dimension 0. diff --git a/server/lorax_server/utils/paged_attn.py b/server/lorax_server/utils/paged_attn.py index 6b9d4b28e..f4c52d58e 100644 --- a/server/lorax_server/utils/paged_attn.py +++ b/server/lorax_server/utils/paged_attn.py @@ -14,10 +14,10 @@ def reshape_and_cache( - key: torch.Tensor, # [num_tokens, num_heads, head_size] - value: torch.Tensor, # [num_tokens, num_heads, head_size] - key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x] - value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] + key: torch.Tensor, # [num_tokens, num_heads, head_size] + value: torch.Tensor, # [num_tokens, num_heads, head_size] + key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x] + value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] slot_mapping: torch.Tensor, # [num_tokens] ): cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping) @@ -25,20 +25,20 @@ def reshape_and_cache( # Source: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/attention.py def single_query_cached_kv_attention( - output: torch.Tensor, # [num_tokens, num_heads, head_size] - query: torch.Tensor, # [num_tokens, num_heads, head_size] - key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x] - value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] + output: torch.Tensor, # [num_tokens, num_heads, head_size] + query: torch.Tensor, # [num_tokens, num_heads, head_size] + key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x] + value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] kv_head_mapping: torch.Tensor, softmax_scale: float, - block_tables: torch.Tensor, # [num_blocks, block_size] + block_tables: torch.Tensor, # [num_blocks, block_size] input_lengths: torch.Tensor, # [num_blocks] max_s: int, ): block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE - + # NOTE(woosuk): We use a simple heuristic to decide whether to use # PagedAttention V1 or V2. If the number of partitions is 1, we use # V1 to avoid the overhead of reduction. Also, if the number of @@ -91,4 +91,4 @@ def single_query_cached_kv_attention( block_size, max_s, None, - ) \ No newline at end of file + ) diff --git a/server/lorax_server/utils/segments.py b/server/lorax_server/utils/segments.py index 841ee3f6f..465f9da2e 100644 --- a/server/lorax_server/utils/segments.py +++ b/server/lorax_server/utils/segments.py @@ -38,7 +38,7 @@ def concat(self, adapter_segments: torch.Tensor, segment_indices: List[int]): # positions by the value of the last segment in the previous batch to account for # the concatenation. adapter_segments = adapter_segments[1:] + self.adapter_segment_tensors[-1][-1] - + if self.adapter_segment_indices and self.adapter_segment_indices[-1] == segment_indices[0]: # If the last segment in the previous batch is the same as the first segment in this batch, # then we merge them together into a single segment. 
In effect, this means removing it from @@ -46,9 +46,9 @@ def concat(self, adapter_segments: torch.Tensor, segment_indices: List[int]): # end index from the previous batch. segment_indices = segment_indices[1:] self.adapter_segment_tensors[-1] = self.adapter_segment_tensors[-1][:-1] - + self.adapter_segment_indices.extend(segment_indices) self.adapter_segment_tensors.append(adapter_segments) - + def build(self) -> Tuple[torch.Tensor, List[int]]: return torch.concat(self.adapter_segment_tensors), self.adapter_segment_indices diff --git a/server/lorax_server/utils/sgmv.py b/server/lorax_server/utils/sgmv.py index 9c399c7dc..0b19f9fe3 100644 --- a/server/lorax_server/utils/sgmv.py +++ b/server/lorax_server/utils/sgmv.py @@ -8,6 +8,7 @@ try: import punica_kernels as _kernels + HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", "")) except ImportError: warnings.warn("Could not import SGMV kernel from Punica, falling back to loop.") @@ -29,7 +30,7 @@ def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor: """Pad a tensor to the minimum rank for SGMV and the nearest multiple of the SGMV block size.""" if not has_sgmv(): return t - + # tensor parallelism will result in effective rank being divided by world_size, # so we need to scale the min rank to offset that effect min_rank = MIN_SGMV_RANK * world_size @@ -38,8 +39,9 @@ def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor: # otherwise, pad to the nearest multiple of the block size current_rank = t.size(dim) target_rank = ( - min_rank if current_rank <= min_rank else - (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE + min_rank + if current_rank <= min_rank + else (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE ) if current_rank == target_rank: return t @@ -94,7 +96,7 @@ def add_lora_sgmv_cutlass( # Custom SGMV shrink only supports rank 16, 32, 64, 128 _add_lora_sgmv_cutlass_legacy(y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank) return - + tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device) tmp2_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0)) tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device) @@ -135,7 +137,9 @@ def get_tmp_expand_size(size: int) -> int: return _kernels.sgmv_cutlass_tmp_size(size) -def get_tmp_tensors(nsegments: int, lora_rank: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]: +def get_tmp_tensors( + nsegments: int, lora_rank: int, device: torch.device +) -> Tuple[torch.Tensor, torch.Tensor]: if use_cutlass_shrink(lora_rank): tmp = get_tmp_tensor_for_size(nsegments, device) return tmp, tmp diff --git a/server/lorax_server/utils/sources/__init__.py b/server/lorax_server/utils/sources/__init__.py index 4824c6dea..5606a7c7e 100644 --- a/server/lorax_server/utils/sources/__init__.py +++ b/server/lorax_server/utils/sources/__init__.py @@ -4,7 +4,16 @@ import requests -from .hub import EntryNotFoundError, LocalEntryNotFoundError, RevisionNotFoundError, get_hub_model_local_dir, weight_files, download_weights, weight_hub_files, HubModelSource +from .hub import ( + EntryNotFoundError, + LocalEntryNotFoundError, + RevisionNotFoundError, + get_hub_model_local_dir, + weight_files, + download_weights, + weight_hub_files, + HubModelSource, +) from .local import LocalModelSource, get_model_local_dir from .s3 import S3ModelSource, get_s3_model_local_dir, _get_bucket_and_model_id @@ -30,7 +39,9 @@ def map_pbase_model_id_to_s3(model_id: str, api_token: str) -> str: url = 
PREDIBASE_GATEWAY_ENDPOINT + PREDIBASE_MODEL_URL_ENDPOINT.format(name) elif len(name_components) == 2: name, version = name_components - url = PREDIBASE_GATEWAY_ENDPOINT + PREDIBASE_MODEL_VERSION_URL_ENDPOINT.format(name, version) + url = PREDIBASE_GATEWAY_ENDPOINT + PREDIBASE_MODEL_VERSION_URL_ENDPOINT.format( + name, version + ) else: raise ValueError(f"Invalid model id {model_id}") resp = requests.get(url, headers=headers) @@ -40,7 +51,13 @@ def map_pbase_model_id_to_s3(model_id: str, api_token: str) -> str: # TODO(travis): refactor into registry pattern -def get_model_source(source: str, model_id: str, revision: Optional[str] = None, extension: str = ".safetensors", api_token: Optional[str] = None): +def get_model_source( + source: str, + model_id: str, + revision: Optional[str] = None, + extension: str = ".safetensors", + api_token: Optional[str] = None, +): if source == HUB: return HubModelSource(model_id, revision, extension, api_token) elif source == S3: @@ -85,4 +102,4 @@ def get_local_dir(model_id: str, source: str): "get_hub_model_local_dir", "get_s3_model_local_dir", "map_pbase_model_id_to_s3", -] \ No newline at end of file +] diff --git a/server/lorax_server/utils/sources/hub.py b/server/lorax_server/utils/sources/hub.py index b32155c63..39ff56203 100644 --- a/server/lorax_server/utils/sources/hub.py +++ b/server/lorax_server/utils/sources/hub.py @@ -26,7 +26,10 @@ def get_hub_model_local_dir(model_id: str) -> Path: def weight_hub_files( - model_id: str, revision: Optional[str] = None, extension: str = ".safetensors", api_token: Optional[str] = None + model_id: str, + revision: Optional[str] = None, + extension: str = ".safetensors", + api_token: Optional[str] = None, ) -> List[str]: """Get the weights filenames on the hub""" api = HfApi(token=api_token) @@ -50,9 +53,11 @@ def weight_hub_files( return filenames - def weight_files( - model_id: str, revision: Optional[str] = None, extension: str = ".safetensors", api_token: Optional[str] = None + model_id: str, + revision: Optional[str] = None, + extension: str = ".safetensors", + api_token: Optional[str] = None, ) -> List[Path]: """Get the local files""" # Local model @@ -74,27 +79,21 @@ def weight_files( # Change pytorch extension to safetensors extension # It is possible that we have safetensors weights locally even though they are not on the # hub if we converted weights locally without pushing them - filenames = [ - f"{Path(f).stem.lstrip('pytorch_')}.safetensors" for f in pt_filenames - ] + filenames = [f"{Path(f).stem.lstrip('pytorch_')}.safetensors" for f in pt_filenames] if WEIGHTS_CACHE_OVERRIDE is not None: files = [] for filename in filenames: p = Path(WEIGHTS_CACHE_OVERRIDE) / filename if not p.exists(): - raise FileNotFoundError( - f"File {p} not found in {WEIGHTS_CACHE_OVERRIDE}." 
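The registry pattern mentioned in the TODO above would replace the if/elif dispatch in get_model_source with a lookup table populated by the source classes themselves. A hypothetical sketch of that shape; the decorator and registry names are invented for illustration, and the differing per-source constructor signatures would still need to be normalized:

from typing import Callable, Dict, Optional

_MODEL_SOURCES: Dict[str, Callable] = {}

def register_model_source(name: str):
    # Class decorator: record a model source implementation under its scheme name.
    def wrap(cls):
        _MODEL_SOURCES[name] = cls
        return cls
    return wrap

def get_model_source_from_registry(
    source: str,
    model_id: str,
    revision: Optional[str] = None,
    extension: str = ".safetensors",
    api_token: Optional[str] = None,
):
    try:
        cls = _MODEL_SOURCES[source]
    except KeyError:
        raise ValueError(f"Unknown model source {source}")
    return cls(model_id, revision, extension, api_token)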
- ) + raise FileNotFoundError(f"File {p} not found in {WEIGHTS_CACHE_OVERRIDE}.") files.append(p) return files repo_cache = get_hub_model_local_dir(model_id) files = [] for filename in filenames: - cache_file = try_to_load_from_cache( - repo_cache, revision=revision, filename=filename - ) + cache_file = try_to_load_from_cache(repo_cache, revision=revision, filename=filename) if cache_file is None: raise LocalEntryNotFoundError( f"File {filename} of model {model_id} not found in " @@ -107,7 +106,10 @@ def weight_files( def download_weights( - filenames: List[str], model_id: str, revision: Optional[str] = None, api_token: Optional[str] = None + filenames: List[str], + model_id: str, + revision: Optional[str] = None, + api_token: Optional[str] = None, ) -> List[Path]: """Download the safetensors files from the hub""" @@ -158,7 +160,13 @@ def download_file(filename, tries=5, backoff: int = 5): class HubModelSource(BaseModelSource): - def __init__(self, model_id: str, revision: Optional[str] = None, extension: str = ".safetensors", api_token: Optional[str] = None): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + extension: str = ".safetensors", + api_token: Optional[str] = None, + ): self.model_id = model_id self.revision = revision self.extension = extension @@ -182,7 +190,7 @@ def download_weights(self, filenames): def download_model_assets(self): filenames = self.remote_weight_files() return self.download_weights(filenames) - + def download_file(self, filename: str, ignore_errors: bool = False) -> Optional[Path]: try: return Path(hf_hub_download(self.model_id, revision=None, filename=filename)) diff --git a/server/lorax_server/utils/sources/local.py b/server/lorax_server/utils/sources/local.py index cd55d2e5b..7006ecaef 100644 --- a/server/lorax_server/utils/sources/local.py +++ b/server/lorax_server/utils/sources/local.py @@ -22,13 +22,15 @@ def get_model_local_dir(model_id: str) -> Path: if os.path.isabs(model_id): return Path(model_id) - + repo_cache = Path(HUGGINGFACE_HUB_CACHE) / model_id return repo_cache class LocalModelSource(BaseModelSource): - def __init__(self, model_id: str, revision: Optional[str] = "", extension: str = ".safetensors"): + def __init__( + self, model_id: str, revision: Optional[str] = "", extension: str = ".safetensors" + ): if len(model_id) < 5: raise ValueError(f"model_id '{model_id}' is too short for prefix filtering") @@ -56,11 +58,9 @@ def weight_files(self, extension: str = None): f"No local weights found in {model_id} with extension {extension}" ) return local_files - - raise FileNotFoundError( - f"No local weights found in {model_id} with extension {extension}" - ) - + + raise FileNotFoundError(f"No local weights found in {model_id} with extension {extension}") + def download_weights(self, filenames: List[str]): return [] @@ -69,13 +69,11 @@ def download_model_assets(self): def get_local_path(self, model_id: str) -> Path: return get_model_local_dir(model_id) - + def download_file(self, filename: str, ignore_errors: bool = False) -> Optional[Path]: path = get_model_local_dir(self.model_id) / filename if not path.exists(): if ignore_errors: return None - raise FileNotFoundError( - f"File {filename} of model {self.model_id} not found in {path}" - ) + raise FileNotFoundError(f"File {filename} of model {self.model_id} not found in {path}") return path diff --git a/server/lorax_server/utils/sources/s3.py b/server/lorax_server/utils/sources/s3.py index 9ef33f590..90946c3de 100644 --- a/server/lorax_server/utils/sources/s3.py +++ 
b/server/lorax_server/utils/sources/s3.py @@ -33,7 +33,7 @@ def _get_bucket_and_model_id(model_id: str) -> Tuple[str, str]: ) bucket_name, model_id = model_id_no_protocol.split("/", 1) return bucket_name, model_id - + bucket = os.getenv("PREDIBASE_MODEL_BUCKET") if not bucket: # assume that the id preceding the first slash is the bucket name @@ -43,12 +43,12 @@ def _get_bucket_and_model_id(model_id: str) -> Tuple[str, str]: f"model_id should be of the form `bucket_name/model_id` " f"if PREDIBASE_MODEL_BUCKET environment variable is not set" ) - + bucket_name, model_id = model_id.split("/", 1) return bucket_name, model_id - + return bucket, model_id - + def _get_bucket_resource(bucket_name: str) -> "Bucket": """Get the s3 client""" @@ -61,21 +61,17 @@ def _get_bucket_resource(bucket_name: str) -> "Bucket": S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", None) R2_ACCOUNT_ID = os.environ.get("R2_ACCOUNT_ID", None) - + if R2_ACCOUNT_ID: - s3 = boto3.resource('s3', - endpoint_url = f'https://{R2_ACCOUNT_ID}.r2.cloudflarestorage.com', - config=config - ) + s3 = boto3.resource( + "s3", endpoint_url=f"https://{R2_ACCOUNT_ID}.r2.cloudflarestorage.com", config=config + ) return s3.Bucket(bucket_name) elif S3_ENDPOINT_URL: - s3 = boto3.resource('s3', - endpoint_url = f'{S3_ENDPOINT_URL}', - config=config - ) + s3 = boto3.resource("s3", endpoint_url=f"{S3_ENDPOINT_URL}", config=config) return s3.Bucket(bucket_name) else: - s3 = boto3.resource('s3', config=config) + s3 = boto3.resource("s3", config=config) return s3.Bucket(bucket_name) @@ -85,12 +81,12 @@ def get_s3_model_local_dir(model_id: str): return repo_cache -def weight_s3_files( - bucket: Any, model_id: str, extension: str = ".safetensors" -) -> List[str]: +def weight_s3_files(bucket: Any, model_id: str, extension: str = ".safetensors") -> List[str]: """Get the weights filenames from s3""" model_files = bucket.objects.filter(Prefix=model_id) - filenames = [f.key.removeprefix(model_id).lstrip("/") for f in model_files if f.key.endswith(extension)] + filenames = [ + f.key.removeprefix(model_id).lstrip("/") for f in model_files if f.key.endswith(extension) + ] if not filenames: raise EntryNotFoundError( f"No {extension} weights found for model {model_id}", @@ -100,9 +96,13 @@ def weight_s3_files( def download_files_from_s3( - bucket: Any, filenames: List[str], model_id: str, revision: str = "", + bucket: Any, + filenames: List[str], + model_id: str, + revision: str = "", ) -> List[Path]: """Download the safetensors files from the s3""" + def download_file(filename): repo_cache = get_s3_model_local_dir(model_id) local_file = try_to_load_from_cache(repo_cache, revision, filename) @@ -121,7 +121,7 @@ def download_file(filename): # TODO: add support for revision logger.info( f"Downloaded {local_file_path} in {timedelta(seconds=int(time.time() - start_time))}." 
- ) + ) if not local_file_path.is_file(): raise FileNotFoundError(f"File {local_file_path} not found") return local_file_path @@ -168,16 +168,12 @@ def weight_files_s3( # Change pytorch extension to safetensors extension # It is possible that we have safetensors weights locally even though they are not on the # hub if we converted weights locally without pushing them - filenames = [ - f"{Path(f).stem.lstrip('pytorch_')}.safetensors" for f in pt_filenames - ] + filenames = [f"{Path(f).stem.lstrip('pytorch_')}.safetensors" for f in pt_filenames] repo_cache = get_s3_model_local_dir(model_id) files = [] for filename in filenames: - cache_file = try_to_load_from_cache( - repo_cache, revision, filename - ) + cache_file = try_to_load_from_cache(repo_cache, revision, filename) if cache_file is None: raise LocalEntryNotFoundError( f"File {filename} of model {model_id} not found in " @@ -213,7 +209,9 @@ def download_model_from_s3(bucket: Any, model_id: str, extension: str = ".safete class S3ModelSource(BaseModelSource): - def __init__(self, model_id: str, revision: Optional[str] = "", extension: str = ".safetensors"): + def __init__( + self, model_id: str, revision: Optional[str] = "", extension: str = ".safetensors" + ): if len(model_id) < 5: raise ValueError(f"model_id '{model_id}' is too short for prefix filtering") @@ -223,7 +221,7 @@ def __init__(self, model_id: str, revision: Optional[str] = "", extension: str = self.revision = revision self.extension = extension self.bucket = _get_bucket_resource(bucket) - + @property def api_token(self) -> Optional[str]: return None @@ -235,7 +233,7 @@ def remote_weight_files(self, extension: str = None): def weight_files(self, extension: str = None): extension = extension or self.extension return weight_files_s3(self.bucket, self.model_id, self.revision, extension) - + def download_weights(self, filenames: List[str]): return download_files_from_s3(self.bucket, filenames, self.model_id, self.revision) @@ -245,7 +243,7 @@ def download_model_assets(self): def get_local_path(self, model_id: str): _, model_id = _get_bucket_and_model_id(model_id) return get_s3_model_local_dir(model_id) - + def download_file(self, filename: str, ignore_errors: bool = False) -> Optional[Path]: filenames = [filename] try: diff --git a/server/lorax_server/utils/sources/source.py b/server/lorax_server/utils/sources/source.py index ff14994a5..58896203f 100644 --- a/server/lorax_server/utils/sources/source.py +++ b/server/lorax_server/utils/sources/source.py @@ -54,28 +54,28 @@ def remote_weight_files(self, extension: str = None): @abstractmethod def weight_files(self, extension: str = None) -> List[Path]: pass - + @abstractmethod def download_weights(self, filenames: List[str]): pass - + @abstractmethod def download_model_assets(self): - """ The reason we need this function is that for s3 - we need to download all the model files whereas for - hub we only need to download the weight files. And maybe - for other future sources we might need something different. + """The reason we need this function is that for s3 + we need to download all the model files whereas for + hub we only need to download the weight files. And maybe + for other future sources we might need something different. 
So this function will take the necessary steps to download - the needed files for any source """ + the needed files for any source""" pass - + @abstractmethod def download_file(self, filename: str, ignore_errors: bool = False) -> Optional[Path]: pass - + def download_file(self, filename: str, ignore_errors: bool = False) -> Optional[Path]: raise NotImplementedError - + def get_weight_bytes(self) -> int: total_size = 0 for path in self.weight_files(): @@ -86,21 +86,21 @@ def get_weight_bytes(self) -> int: st = os.stat(fname) if st.st_size < 8: raise RuntimeError(f"Length of safetensor file less than 8 bytes: {fname}") - + with open(fname, "rb") as f: # read header size - b8 = f.read(8) + b8 = f.read(8) if len(b8) != 8: raise RuntimeError(f"Failed to read first 8 bytes of safetensor file: {fname}") - - headerlen = int.from_bytes(b8, 'little', signed=False) + + headerlen = int.from_bytes(b8, "little", signed=False) if 8 + headerlen > st.st_size: raise RuntimeError(f"Header extends past end of file: {fname}") - + hdrbuf = f.read(headerlen) header = json.loads(hdrbuf) - metadata = header.get('__metadata__', {}) - total_size_bytes = metadata.get('total_size') + metadata = header.get("__metadata__", {}) + total_size_bytes = metadata.get("total_size") if total_size_bytes is None: # Fallback to determining this value from the data offsets min_data_offset = None @@ -108,8 +108,8 @@ def get_weight_bytes(self) -> int: for v in header.values(): if not isinstance(v, dict): continue - - data_offsets = v.get('data_offsets') + + data_offsets = v.get("data_offsets") if data_offsets is None: continue @@ -117,12 +117,12 @@ def get_weight_bytes(self) -> int: min_data_offset = min(min_data_offset, data_offsets[0]) else: min_data_offset = data_offsets[0] - + if max_data_offset is not None: max_data_offset = max(max_data_offset, data_offsets[1]) else: max_data_offset = data_offsets[1] - + if min_data_offset is None or max_data_offset is None: # Fallback to determining total bytes from file size total_size_bytes = st.st_size @@ -130,9 +130,9 @@ def get_weight_bytes(self) -> int: total_size_bytes = max_data_offset - min_data_offset total_size += total_size_bytes - + return total_size - + def load_config(self) -> AdapterConfig: from lorax_server.adapters import load_adapter_config diff --git a/server/lorax_server/utils/tokenizer.py b/server/lorax_server/utils/tokenizer.py index 4a051861d..79437f357 100644 --- a/server/lorax_server/utils/tokenizer.py +++ b/server/lorax_server/utils/tokenizer.py @@ -8,11 +8,13 @@ class TokenizerManager: def __init__(self): self.tokenizers = {} - + def add_tokenizer(self, adapter_idx: int, tokenizer: PreTrainedTokenizerBase): self.tokenizers[adapter_idx] = tokenizer - def get_tokenizer(self, adapter_idx: int, default: PreTrainedTokenizerBase) -> Optional[PreTrainedTokenizerBase]: + def get_tokenizer( + self, adapter_idx: int, default: PreTrainedTokenizerBase + ) -> Optional[PreTrainedTokenizerBase]: return self.tokenizers.get(adapter_idx, default) def get_inputs( @@ -24,5 +26,7 @@ def get_inputs( if r.apply_chat_template: inputs = json.loads(inputs) tokenizer = self.get_tokenizer(r.adapter_index, base_tokenizer) - inputs = tokenizer.apply_chat_template(inputs, add_generation_prompt=True, tokenize=False) + inputs = tokenizer.apply_chat_template( + inputs, add_generation_prompt=True, tokenize=False + ) return inputs diff --git a/server/lorax_server/utils/tokens.py b/server/lorax_server/utils/tokens.py index 808ec53ff..aaa01ddc3 100644 --- a/server/lorax_server/utils/tokens.py +++ 
b/server/lorax_server/utils/tokens.py @@ -18,7 +18,8 @@ HeterogeneousTopPLogitsWarper, HeterogeneousTypicalLogitsWarper, HeterogeneousProcessorWrapper, - HeterogeneousSchemaLogitsProcessor, OutlinesLogitsProcessor, + HeterogeneousSchemaLogitsProcessor, + OutlinesLogitsProcessor, ) @@ -58,9 +59,7 @@ def __init__( device: str = "cpu", tokenizer: Optional[PreTrainedTokenizerBase] = None, ): - self.watermark_processor = ( - WatermarkLogitsProcessor(device=device) if watermark else None - ) + self.watermark_processor = WatermarkLogitsProcessor(device=device) if watermark else None self.repetition_processor = ( RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty) if repetition_penalty @@ -68,9 +67,7 @@ def __init__( ) self.schema_processor = ( - OutlinesLogitsProcessor(schema, tokenizer) - if schema and tokenizer - else None + OutlinesLogitsProcessor(schema, tokenizer) if schema and tokenizer else None ) has_warpers = ( @@ -196,9 +193,7 @@ def from_pb( pb: generate_pb2.StoppingCriteriaParameters, tokenizer: PreTrainedTokenizerBase, ) -> "StoppingCriteria": - stop_sequence_criterias = [ - StopSequenceCriteria(sequence) for sequence in pb.stop_sequences - ] + stop_sequence_criterias = [StopSequenceCriteria(sequence) for sequence in pb.stop_sequences] return StoppingCriteria( tokenizer.eos_token_id, stop_sequence_criterias, @@ -268,9 +263,7 @@ def __init__( ) self.repetition_processor = ( - HeterogeneousRepetitionPenaltyLogitsProcessor( - repetition_penalty, dtype, device - ) + HeterogeneousRepetitionPenaltyLogitsProcessor(repetition_penalty, dtype, device) if any([x != 1.0 for x in repetition_penalty]) else None ) @@ -290,12 +283,8 @@ def __init__( ) if any([x != 1.0 for x in temperature]): - do_sample = [ - sample or x != 1.0 for x, sample in zip(temperature, do_sample) - ] - warpers.append( - HeterogeneousTemperatureLogitsWarper(temperature, dtype, device) - ) + do_sample = [sample or x != 1.0 for x, sample in zip(temperature, do_sample)] + warpers.append(HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)) if any([x != 0 for x in top_k]): do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)] @@ -323,10 +312,10 @@ def __init__( def __call__( self, - input_ids: torch.Tensor, - scores: torch.Tensor, - speculate: int, - speculated_ids: Optional[torch.Tensor] = None, + input_ids: torch.Tensor, + scores: torch.Tensor, + speculate: int, + speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ @@ -349,7 +338,7 @@ def __call__( B = scores.shape[0] S = 1 scores = scores.view(B, S, -1) - + next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long) for j in range(S): scores_j = scores[:, j] @@ -366,8 +355,8 @@ def __call__( next_ids_j = self.choice(scores_j) scores[:, j] = scores_j next_ids[:, j] = next_ids_j - - next_ids = next_ids.view(B*S) + + next_ids = next_ids.view(B * S) scores = scores.view(B * S, -1) if speculated_ids is not None: @@ -376,7 +365,7 @@ def __call__( S = speculated_ids.shape[1] + 1 indices = [] for i in range(B): - next_ids_i = next_ids[i*S: (i + 1)*S] + next_ids_i = next_ids[i * S : (i + 1) * S] speculated_ids_i = speculated_ids[i] validate_speculative = next_ids_i[:-1] == speculated_ids_i index = i * S @@ -392,7 +381,9 @@ def __call__( break accepted_ids.append(accepted) - accepted_ids = torch.tensor(accepted_ids, device=input_ids.device, dtype=input_ids.dtype) + accepted_ids = torch.tensor( + accepted_ids, 
device=input_ids.device, dtype=input_ids.dtype + ) next_ids = next_ids[indices] scores = scores[indices] indices = torch.arange(B, device=input_ids.device) * S @@ -401,9 +392,9 @@ def __call__( else: accepted_ids = torch.ones_like(next_ids) - next_logprobs = torch.gather( - torch.log_softmax(scores, -1), 1, next_ids.view(-1, 1) - ).view(-1) + next_logprobs = torch.gather(torch.log_softmax(scores, -1), 1, next_ids.view(-1, 1)).view( + -1 + ) speculative_ids = None if speculate > 0: @@ -430,7 +421,7 @@ def filter(self, indices): if self.repetition_processor is not None: self.repetition_processor = self.repetition_processor.filter(indices) - + if self.schema_processor is not None: self.schema_processor = self.schema_processor.filter(indices) @@ -552,10 +543,10 @@ def filter(self, indices): def ngram_speculate( - input_ids: torch.Tensor, - next_ids: torch.Tensor, - accepted_ids: torch.Tensor, - speculate: int, + input_ids: torch.Tensor, + next_ids: torch.Tensor, + accepted_ids: torch.Tensor, + speculate: int, ) -> torch.Tensor: # Inspired by TGI implementation of: # https://github.com/apoorvumang/prompt-lookup-decoding @@ -567,7 +558,9 @@ def ngram_speculate( # Speculate out from the last match by the number of speculative tokens `speculate` # Clamp the indices to the maximum length of the input_ids to prevent out-of-bound errors - all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(speculate, device=input_ids.device) + all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange( + speculate, device=input_ids.device + ) all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1) # Gather the speculative tokens from the input_ids to form a [B, S] tensor diff --git a/server/lorax_server/utils/watermark.py b/server/lorax_server/utils/watermark.py index df7b90e39..dbab45ff6 100644 --- a/server/lorax_server/utils/watermark.py +++ b/server/lorax_server/utils/watermark.py @@ -39,9 +39,7 @@ def __init__( def _seed_rng(self, input_ids: Union[List[int], torch.LongTensor]): if isinstance(input_ids, list): - assert ( - len(input_ids) >= 1 - ), "requires at least a 1 token prefix sequence to seed rng" + assert len(input_ids) >= 1, "requires at least a 1 token prefix sequence to seed rng" prev_token = input_ids[-1] else: assert len(input_ids) == 1 @@ -67,9 +65,7 @@ def _get_greenlist_ids( return greenlist_ids @staticmethod - def _calc_greenlist_mask( - scores: torch.FloatTensor, greenlist_token_ids - ) -> torch.BoolTensor: + def _calc_greenlist_mask(scores: torch.FloatTensor, greenlist_token_ids) -> torch.BoolTensor: green_tokens_mask = torch.zeros_like(scores) green_tokens_mask[-1, greenlist_token_ids] = 1 final_mask = green_tokens_mask.bool() @@ -85,9 +81,7 @@ def _bias_greenlist_logits( def __call__( self, input_ids: Union[List[int], torch.LongTensor], scores: torch.FloatTensor ) -> torch.FloatTensor: - greenlist_ids = self._get_greenlist_ids( - input_ids, scores.shape[-1], scores.device - ) + greenlist_ids = self._get_greenlist_ids(input_ids, scores.shape[-1], scores.device) green_tokens_mask = self._calc_greenlist_mask( scores=scores, greenlist_token_ids=greenlist_ids ) diff --git a/server/lorax_server/utils/weights.py b/server/lorax_server/utils/weights.py index 5307d5db6..255e63a3f 100644 --- a/server/lorax_server/utils/weights.py +++ b/server/lorax_server/utils/weights.py @@ -56,6 +56,7 @@ class Weights(AbstractWeights): process_group: The process group for distributed training. _handles (Dict[str, Any]): Dictionary of file handles for opened weight files. 
""" + def __init__( self, filenames: List[Path], @@ -77,7 +78,7 @@ def __init__( f"Key {k} was found in multiple adapter files: {filename} and {routing[k]}" ) routing[k] = filename - + # set of keys that point to adapter files. Duplicates for these keys found # in main model files will be overridden. adapter_routes = set(routing.keys()) @@ -86,7 +87,9 @@ def __init__( with safe_open(filename, framework="pytorch") as f: for k in f.keys(): if k in adapter_routes: - logger.debug(f"Overriding main model weights with adapter weights for key: {k}") + logger.debug( + f"Overriding main model weights with adapter weights for key: {k}" + ) elif k in routing: raise RuntimeError( f"Key {k} was found in multiple non-adapter files: {filename} and {routing[k]}" @@ -141,7 +144,9 @@ def get_tensor(self, tensor_name: str): tensor = tensor.to(device=self.device) return tensor - def get_partial_sharded(self, tensor_name: str, dim: int, range: Optional[Tuple[int, int]] = None): + def get_partial_sharded( + self, tensor_name: str, dim: int, range: Optional[Tuple[int, int]] = None + ): """Loads tensor with the given name and shards it along the given dimension. The optional range argument can be used to load and split on only a subset of the tensor. @@ -190,7 +195,7 @@ def get_sharded(self, tensor_name: str, dim: int, range: Optional[Tuple[int, int size % world_size == 0 ), f"The choosen size {size} is not compatible with sharding on {world_size} shards" return self.get_partial_sharded(tensor_name, dim, range=range) - + def get_sharded_prefix(self, module_name: str, prefix: Union[str, Tuple], dim: int): if isinstance(prefix, str): return self.get_sharded(f"{prefix}.{module_name}", dim=dim) @@ -198,27 +203,21 @@ def get_sharded_prefix(self, module_name: str, prefix: Union[str, Tuple], dim: i assert isinstance(prefix, tuple) assert len(prefix) == 2 return self.get_sharded(f"{prefix[0]}.{module_name}", dim=dim, range=prefix[1]) - + def get_sharded_list(self, module_name: str, prefixes: List[Union[str, Tuple]], dim: int): return [self.get_sharded_prefix(module_name, p, dim=dim) for p in prefixes] def get_multi_weights_col(self, prefixes: List[Union[str, Tuple]], quantize: str, dim: int): if quantize in ["gptq", "awq"]: try: - qweight = torch.cat( - self.get_sharded_list("qweight", prefixes, dim=1), dim=1 - ) + qweight = torch.cat(self.get_sharded_list("qweight", prefixes, dim=1), dim=1) except RuntimeError: raise RuntimeError( "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `lorax-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" ) - qzeros = torch.cat( - self.get_sharded_list("qzeros", prefixes, dim=1), dim=1 - ) - scales = torch.cat( - self.get_sharded_list("scales", prefixes, dim=1), dim=1 - ) + qzeros = torch.cat(self.get_sharded_list("qzeros", prefixes, dim=1), dim=1) + scales = torch.cat(self.get_sharded_list("scales", prefixes, dim=1), dim=1) if quantize == "gptq": # no tensor parallelism, so remove the range if provided prefixes = [p[0] if isinstance(p, tuple) else p for p in prefixes] @@ -370,6 +369,7 @@ def _set_gptq_params(self, model_id): except Exception: pass + def get_start_stop_idxs_for_rank(offset, size, rank, world_size): block_size = size // world_size start = offset + rank * block_size @@ -380,7 +380,7 @@ def get_start_stop_idxs_for_rank(offset, size, rank, world_size): def shard_on_dim(t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup): world_size = process_group.size() rank = process_group.rank() - + size = t.shape[dim] start, stop 
= get_start_stop_idxs_for_rank(0, size, rank, world_size) @@ -405,6 +405,7 @@ def download_weights( # Import here after the logger is added to log potential import exceptions from lorax_server import utils from lorax_server.utils import sources + model_source = sources.get_model_source(source, model_id, revision, extension, api_token) # Test if files were already download @@ -458,8 +459,7 @@ def download_weights( # Safetensors final filenames local_st_files = [ - p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" - for p in local_pt_files + p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files ] try: from transformers import AutoConfig diff --git a/server/pyproject.toml b/server/pyproject.toml index 565d1edea..dfe560d4a 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -64,3 +64,13 @@ markers = ["private: marks tests as requiring an admin hf token (deselect with ' [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" + +[tool.isort] +profile = "black" +line_length = 100 +force_sort_within_sections = "False" +order_by_type = "False" + +[tool.black] +line-length = 100 +exclude = "./python/"
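
Note on the new configuration: the [tool.black] and [tool.isort] sections added above pin both tools to a 100-character line length, and profile = "black" keeps isort's output compatible with black so the two formatters do not undo each other's changes. The sketch below is a minimal, illustrative check of that line-length setting and is not part of the patch; the sample snippet and the programmatic use of black's format_str/Mode API are assumptions for demonstration, and the usual workflow is simply running the black and isort CLIs from the server/ directory where pyproject.toml lives.

# Illustrative check (assumes the `black` package is installed) that a snippet
# already conforms to the 100-column style configured in server/pyproject.toml.
import black

sample = "def add(a: int, b: int) -> int:\n    return a + b\n"
formatted = black.format_str(sample, mode=black.Mode(line_length=100))
assert formatted == sample  # snippet is already black-clean at line length 100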