From 6226aaf2bb073b83bcb0a10e83a69d612b217c80 Mon Sep 17 00:00:00 2001 From: flozi00 Date: Mon, 18 Dec 2023 17:39:44 +0100 Subject: [PATCH] WIP --- clients/python/lorax/client.py | 1 - clients/python/tests/conftest.py | 4 +- .../scripts/dynamic_adapter_loading.py | 29 +- server/awq_kernels/setup.py | 38 ++- server/lorax_server/cli.py | 15 +- server/lorax_server/models/__init__.py | 40 ++- server/lorax_server/models/cache_manager.py | 2 +- server/lorax_server/models/causal_lm.py | 5 +- .../custom_modeling/flash_gpt2_modeling.py | 92 +++--- .../custom_modeling/flash_llama_modeling.py | 100 ++++-- .../custom_modeling/flash_mistral_modeling.py | 104 ++++-- .../custom_modeling/flash_mixtral_modeling.py | 229 ++++++++------ .../custom_modeling/flash_phi_modeling.py | 104 ++++-- .../custom_modeling/flash_qwen_modeling.py | 100 ++++-- .../custom_modeling/flash_rw_modeling.py | 2 +- .../models/custom_modeling/mpt_modeling.py | 9 +- .../models/custom_modeling/neox_modeling.py | 12 +- server/lorax_server/models/flash_causal_lm.py | 296 +++++++++++++----- server/lorax_server/models/flash_gpt2.py | 51 +-- server/lorax_server/models/flash_llama.py | 113 +++++-- server/lorax_server/models/flash_mistral.py | 125 ++++++-- server/lorax_server/models/flash_mixtral.py | 68 ++-- server/lorax_server/models/flash_phi.py | 54 ++-- server/lorax_server/models/flash_qwen.py | 74 +++-- .../lorax_server/models/flash_santacoder.py | 3 - server/lorax_server/models/model.py | 9 +- server/lorax_server/server.py | 40 ++- server/lorax_server/tracing.py | 4 +- server/lorax_server/utils/__init__.py | 3 +- server/lorax_server/utils/adapter.py | 114 ++++--- server/lorax_server/utils/awq/awq.py | 25 +- server/lorax_server/utils/convert.py | 2 +- server/lorax_server/utils/globals.py | 11 + server/lorax_server/utils/gptq/exllamav2.py | 128 ++++---- .../lorax_server/utils/gptq/quant_linear.py | 4 +- server/lorax_server/utils/layers.py | 211 ++++++++----- server/lorax_server/utils/lora.py | 36 ++- server/lorax_server/utils/medusa.py | 59 ++++ server/lorax_server/utils/paged_attn.py | 22 +- server/lorax_server/utils/segments.py | 15 +- server/lorax_server/utils/sgmv.py | 7 +- server/lorax_server/utils/sources/__init__.py | 28 +- server/lorax_server/utils/sources/hub.py | 11 +- server/lorax_server/utils/sources/local.py | 29 +- server/lorax_server/utils/sources/s3.py | 39 ++- server/lorax_server/utils/sources/source.py | 16 +- server/lorax_server/utils/tokens.py | 97 +++++- server/lorax_server/utils/weights.py | 47 ++- server/punica_kernels/setup.py | 11 +- server/tests/utils/test_adapter.py | 53 ++-- server/tests/utils/test_segments.py | 1 - server/tests/utils/test_sgmv.py | 18 +- 52 files changed, 1796 insertions(+), 914 deletions(-) create mode 100644 server/lorax_server/utils/globals.py create mode 100644 server/lorax_server/utils/medusa.py diff --git a/clients/python/lorax/client.py b/clients/python/lorax/client.py index d7c559662..c5ff9a2e5 100644 --- a/clients/python/lorax/client.py +++ b/clients/python/lorax/client.py @@ -498,7 +498,6 @@ async def generate_stream( headers=self.headers, cookies=self.cookies, timeout=self.timeout ) as session: async with session.post(self.base_url, json=request.dict()) as resp: - if resp.status != 200: raise parse_error(resp.status, await resp.json()) diff --git a/clients/python/tests/conftest.py b/clients/python/tests/conftest.py index 76889680c..c02e63bf3 100644 --- a/clients/python/tests/conftest.py +++ b/clients/python/tests/conftest.py @@ -46,6 +46,4 @@ def unsupported_url(base_url, 
unsupported_model): @pytest.fixture(scope="session") def hf_headers(): - return build_hf_headers( - library_name="lorax-tests", library_version=__version__ - ) + return build_hf_headers(library_name="lorax-tests", library_version=__version__) diff --git a/integration-tests/scripts/dynamic_adapter_loading.py b/integration-tests/scripts/dynamic_adapter_loading.py index 1fa784f6d..1851197e6 100644 --- a/integration-tests/scripts/dynamic_adapter_loading.py +++ b/integration-tests/scripts/dynamic_adapter_loading.py @@ -54,8 +54,8 @@ def query_lorax(args): if adapter_id is not None: # request_params["adapter_source"] = "local" request_params["adapter_id"] = adapter_id - - print("request_params", request_params) + + print("request_params", request_params) url = "http://localhost:8080/generate" headers = { "Content-Type": "application/json", @@ -67,7 +67,7 @@ def query_lorax(args): }, ).encode("utf-8") request = Request(url, headers=headers, data=data) - + try: with urlopen(request) as response: response_body = json.loads(response.read().decode("utf-8")) @@ -78,19 +78,22 @@ def query_lorax(args): print(f"exception in request: {adapter_id}") return adapter_id, 0, None - print("adapter_id: {}\nCompleted {} in {:3f} seconds ({:3f} tokens / s)\n----".format( - adapter_id, - ntokens, - duration_s, - (ntokens / duration_s), - )) + print( + "adapter_id: {}\nCompleted {} in {:3f} seconds ({:3f} tokens / s)\n----".format( + adapter_id, + ntokens, + duration_s, + (ntokens / duration_s), + ) + ) return adapter_id, ntokens, duration_s, response_body["generated_text"] def get_local_path(model_id): model_id = model_id.replace("/", "--") - return f"/data/models--{model_id}/snapshots/834b33af35ff5965ea3e4bc18b51ad5d65da7466" - + return ( + f"/data/models--{model_id}/snapshots/834b33af35ff5965ea3e4bc18b51ad5d65da7466" + ) def main(): @@ -145,7 +148,6 @@ def main(): # # # "hessertaboada/ludwig-webinar", # # # "AmlanSamanta/ludwig-webinar", - # # # None, # # # # download error: bad adapter name @@ -197,6 +199,5 @@ def main(): # print({k: np.mean(v) for k, v in d.items()}) -if __name__ == '__main__': +if __name__ == "__main__": main() - diff --git a/server/awq_kernels/setup.py b/server/awq_kernels/setup.py index a53d073a5..6c626b2f4 100644 --- a/server/awq_kernels/setup.py +++ b/server/awq_kernels/setup.py @@ -3,16 +3,17 @@ from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension -common_setup_kwargs = { -} +common_setup_kwargs = {} def get_generator_flag(): generator_flag = [] torch_dir = torch.__path__[0] - if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): + if os.path.exists( + os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h") + ): generator_flag = ["-DOLD_GENERATOR_PATH"] - + return generator_flag @@ -23,7 +24,9 @@ def get_compute_capabilities(): cc = major * 10 + minor if cc < 75: - raise RuntimeError("GPUs with compute capability less than 7.5 are not supported.") + raise RuntimeError( + "GPUs with compute capability less than 7.5 are not supported." 
+ ) # figure out compute capability compute_capabilities = {75, 80, 86, 89, 90} @@ -34,14 +37,15 @@ def get_compute_capabilities(): return capability_flags + generator_flags = get_generator_flag() arch_flags = get_compute_capabilities() -extra_compile_args={ +extra_compile_args = { "cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"], "nvcc": [ - "-O3", + "-O3", "-std=c++17", "-DENABLE_BF16", "-U__CUDA_NO_HALF_OPERATORS__", @@ -53,7 +57,9 @@ def get_compute_capabilities(): "--expt-relaxed-constexpr", "--expt-extended-lambda", "--use_fast_math", - ] + arch_flags + generator_flags + ] + + arch_flags + + generator_flags, } extensions = [ @@ -64,8 +70,9 @@ def get_compute_capabilities(): "awq_cuda/quantization/gemm_cuda_gen.cu", "awq_cuda/layernorm/layernorm.cu", "awq_cuda/position_embedding/pos_encoding_kernels.cu", - "awq_cuda/quantization/gemv_cuda.cu" - ], extra_compile_args=extra_compile_args + "awq_cuda/quantization/gemv_cuda.cu", + ], + extra_compile_args=extra_compile_args, ) ] @@ -75,18 +82,17 @@ def get_compute_capabilities(): [ "awq_cuda/pybind_ft.cpp", "awq_cuda/attention/ft_attention.cpp", - "awq_cuda/attention/decoder_masked_multihead_attention.cu" - ], extra_compile_args=extra_compile_args + "awq_cuda/attention/decoder_masked_multihead_attention.cu", + ], + extra_compile_args=extra_compile_args, ) ) additional_setup_kwargs = { "ext_modules": extensions, - "cmdclass": {'build_ext': BuildExtension} + "cmdclass": {"build_ext": BuildExtension}, } common_setup_kwargs.update(additional_setup_kwargs) -setup( - **common_setup_kwargs -) \ No newline at end of file +setup(**common_setup_kwargs) diff --git a/server/lorax_server/cli.py b/server/lorax_server/cli.py index 987f0096c..00911954c 100644 --- a/server/lorax_server/cli.py +++ b/server/lorax_server/cli.py @@ -82,9 +82,19 @@ def serve( "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model." ) server.serve( - model_id, adapter_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path, source, adapter_source + model_id, + adapter_id, + revision, + sharded, + quantize, + dtype, + trust_remote_code, + uds_path, + source, + adapter_source, ) + def _download_weights( model_id: str, revision: Optional[str] = None, @@ -95,6 +105,7 @@ def _download_weights( # Import here after the logger is added to log potential import exceptions from lorax_server import utils from lorax_server.utils import sources + model_source = sources.get_model_source(source, model_id, revision, extension) # Test if files were already download @@ -168,7 +179,7 @@ def _download_weights( discard_names = getattr(class_, "_tied_weights_keys", []) discard_names.extend(getattr(class_, "_keys_to_ignore_on_load_missing", [])) - except Exception as e: + except Exception: discard_names = [] # Convert pytorch weights to safetensors utils.convert_files(local_pt_files, local_st_files, discard_names) diff --git a/server/lorax_server/models/__init__.py b/server/lorax_server/models/__init__.py index f6a12135b..ac06f57d8 100644 --- a/server/lorax_server/models/__init__.py +++ b/server/lorax_server/models/__init__.py @@ -1,4 +1,3 @@ -import os import torch from loguru import logger @@ -19,6 +18,7 @@ from lorax_server.models.t5 import T5Sharded from lorax_server.models.gpt_neox import GPTNeoxSharded from lorax_server.utils.sources import get_s3_model_local_dir +from lorax_server.utils.globals import set_speculation_num # The flag below controls whether to allow TF32 on matmul. 
This flag defaults to False # in PyTorch 1.12 and later. @@ -69,7 +69,7 @@ __all__.append(FlashGPT2) __all__.append(FlashQwen) __all__.append(FlashPhi) - + MISTRAL = True try: from lorax_server.models.flash_mistral import FlashMistral @@ -103,10 +103,11 @@ def get_model( adapter_source: str, ) -> Model: config_dict = None + medusa_id = None if source == "s3": # change the model id to be the local path to the folder so # we can load the config_dict locally - logger.info(f"Using the local files since we are coming from s3") + logger.info("Using the local files since we are coming from s3") model_path = get_s3_model_local_dir(model_id) logger.info(f"model_path: {model_path}") config_dict, _ = PretrainedConfig.get_config_dict( @@ -118,11 +119,18 @@ def get_model( config_dict, _ = PretrainedConfig.get_config_dict( model_id, revision=revision, trust_remote_code=trust_remote_code ) - else: + else: raise ValueError(f"Unknown source {source}") - + model_type = config_dict["model_type"] + if "medusa_num_heads" in config_dict: + medusa_id = model_id + model_id = config_dict["base_model_name_or_path"] + revision = "main" + speculate_medusa = config_dict["medusa_num_heads"] + set_speculation_num(speculate_medusa) + if dtype is None: dtype = torch.float16 elif dtype == "float16": @@ -234,6 +242,7 @@ def get_model( quantize=quantize, dtype=dtype, trust_remote_code=trust_remote_code, + medusa_id=medusa_id, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama")) @@ -271,7 +280,7 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) - raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon")) + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon")) else: if FLASH_ATTENTION and not config_dict.get("alibi", False): return FlashRWSharded( @@ -300,9 +309,10 @@ def get_model( quantize=quantize, dtype=dtype, trust_remote_code=trust_remote_code, + medusa_id=medusa_id, ) raise NotImplementedError("Mistral model requires flash attention v2") - + if model_type == "mixtral": if MIXTRAL: return FlashMixtral( @@ -314,8 +324,10 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) - raise NotImplementedError("Mixtral models requires flash attention v2, stk and megablocks") - + raise NotImplementedError( + "Mixtral models requires flash attention v2, stk and megablocks" + ) + if model_type == "qwen": if FLASH_ATTENTION: return FlashQwen( @@ -328,7 +340,7 @@ def get_model( trust_remote_code=trust_remote_code, ) raise NotImplementedError("Qwen model requires flash attention v2") - + if model_type in ["phi-msft", "phi"]: if FLASH_ATTENTION: return FlashPhi( @@ -367,13 +379,9 @@ def get_model( "gptq quantization is not supported for AutoModel, you can try to quantize it with `lorax-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" ) elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"): - raise ValueError( - "4bit quantization is not supported for AutoModel" - ) + raise ValueError("4bit quantization is not supported for AutoModel") if quantize == "awq": - raise ValueError( - "awq quantization is not supported for AutoModel" - ) + raise ValueError("awq quantization is not supported for AutoModel") if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: return CausalLM( diff --git a/server/lorax_server/models/cache_manager.py b/server/lorax_server/models/cache_manager.py index d6df44bab..4be8b1b95 100644 --- a/server/lorax_server/models/cache_manager.py +++ 
b/server/lorax_server/models/cache_manager.py @@ -132,4 +132,4 @@ def get_cache_manager() -> CacheManager: if CACHE_MANAGER is None: raise RuntimeError("cache manager was not initialized") - return CACHE_MANAGER \ No newline at end of file + return CACHE_MANAGER diff --git a/server/lorax_server/models/causal_lm.py b/server/lorax_server/models/causal_lm.py index d9d09cc40..5a3d226c8 100644 --- a/server/lorax_server/models/causal_lm.py +++ b/server/lorax_server/models/causal_lm.py @@ -1,5 +1,4 @@ import torch -import inspect from dataclasses import dataclass from opentelemetry import trace @@ -99,7 +98,9 @@ def from_pb( ) adapter_indices_list.append(r.adapter_index) - adapter_indices = torch.tensor(adapter_indices_list, dtype=torch.int64, device=device) + adapter_indices = torch.tensor( + adapter_indices_list, dtype=torch.int64, device=device + ) tokenized_inputs = tokenizer( inputs, diff --git a/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py b/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py index 0216734d1..75ea97baa 100644 --- a/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py @@ -30,16 +30,12 @@ from lorax_server.utils import flash_attn from lorax_server.utils import paged_attn from lorax_server.utils.layers import ( - FastLinear, TensorParallelAdapterRowLinear, TensorParallelMultiAdapterLinear, TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, - TensorParallelHead, FastLayerNorm, - PositionRotaryEmbedding, - get_linear, ) from lorax_server.utils.lora import AdapterBatchData @@ -61,15 +57,22 @@ def load_attention_multi(config, prefix, weights, fan_in_fan_out=False): ) -def load_attention(config, prefix, weights, layer_id, layer_names, fan_in_fan_out=False): - base_layer = load_attention_multi(config, prefix, weights, fan_in_fan_out=fan_in_fan_out) +def load_attention( + config, prefix, weights, layer_id, layer_names, fan_in_fan_out=False +): + base_layer = load_attention_multi( + config, prefix, weights, fan_in_fan_out=fan_in_fan_out + ) projection_size = config.n_embd return TensorParallelMultiAdapterLinear.load( - base_layer, layer_id, layer_names, sizes=[ + base_layer, + layer_id, + layer_names, + sizes=[ 3 * projection_size, - ], process_group=weights.process_group + ], + process_group=weights.process_group, ) - class FlashGPT2Attention(torch.nn.Module): @@ -79,9 +82,9 @@ def __init__(self, config, prefix, weights, layer_id): max_positions = config.max_position_embeddings self.register_buffer( "bias", - torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( - 1, 1, max_positions, max_positions - ), + torch.tril( + torch.ones((max_positions, max_positions), dtype=torch.bool) + ).view(1, 1, max_positions, max_positions), persistent=False, ) self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False) @@ -98,10 +101,10 @@ def __init__(self, config, prefix, weights, layer_id): self.scale_attn_weights = config.scale_attn_weights if self.scale_attn_weights: - self.softmax_scale = self.head_dim ** -0.5 + self.softmax_scale = self.head_dim**-0.5 else: self.softmax_scale = 1.0 - + if config.add_cross_attention: raise ValueError("Cross attention in GPT-2 is not supported.") @@ -110,14 +113,21 @@ def __init__(self, config, prefix, weights, layer_id): self.layer_idx = layer_id self.reorder_and_upcast_attn = config.reorder_and_upcast_attn - self.c_attn = load_attention(config, prefix, weights, layer_id, 
[ATTN_C_ATTN], fan_in_fan_out=True) - self.c_proj = TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.c_proj", - weights=weights, - bias=True, - fan_in_fan_out=True, - ), layer_id, ATTN_C_PROJ, process_group=weights.process_group) + self.c_attn = load_attention( + config, prefix, weights, layer_id, [ATTN_C_ATTN], fan_in_fan_out=True + ) + self.c_proj = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.c_proj", + weights=weights, + bias=True, + fan_in_fan_out=True, + ), + layer_id, + ATTN_C_PROJ, + process_group=weights.process_group, + ) self.pruned_heads = set() @@ -140,7 +150,6 @@ def __init__(self, config, prefix, weights, layer_id): ) self.num_key_value_heads = self.num_heads - def forward( self, hidden_states, @@ -150,7 +159,7 @@ def forward( slots, input_lengths, max_s, - adapter_data + adapter_data, ): qkv = self.c_attn(hidden_states, adapter_data) qkv = qkv.view(-1, 3, self.num_heads, self.head_size) @@ -208,25 +217,30 @@ def __init__(self, config, prefix, weights, layer_id): # https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Config.n_inner n_inner = config.n_inner if config.n_inner is not None else config.n_embd * 4 self.c_fc = TensorParallelMultiAdapterLinear.load( - c_fc, - layer_id, - [MLP_C_FC], + c_fc, + layer_id, + [MLP_C_FC], sizes=[n_inner], - process_group=weights.process_group + process_group=weights.process_group, ) - self.c_proj = TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.c_proj", - weights=weights, - bias=True, - fan_in_fan_out=True, - ), layer_id, MLP_C_PROJ, process_group=weights.process_group) + self.c_proj = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.c_proj", + weights=weights, + bias=True, + fan_in_fan_out=True, + ), + layer_id, + MLP_C_PROJ, + process_group=weights.process_group, + ) self.act = ACT2FN[config.activation_function] def forward( - self, + self, hidden_states: Optional[Tuple[torch.FloatTensor]], adapter_data: AdapterBatchData, ) -> torch.FloatTensor: @@ -253,7 +267,9 @@ def __init__(self, layer_id, config, weights): prefix=f"{prefix}.ln_2", weights=weights, eps=layer_norm_eps ) - self.mlp = GPT2MLP(config, prefix=f"{prefix}.mlp", weights=weights, layer_id=layer_id) + self.mlp = GPT2MLP( + config, prefix=f"{prefix}.mlp", weights=weights, layer_id=layer_id + ) self.process_group = weights.process_group def forward( @@ -395,5 +411,5 @@ def forward( # lm_head reuses the weights of the embedding layer # https://github.com/huggingface/transformers/issues/6291 logits = hidden_states @ self.wte_t - logits = logits[:, :self.transformer.config.vocab_size] + logits = logits[:, : self.transformer.config.vocab_size] return logits diff --git a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py index da335d98b..ec8802f1a 100644 --- a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py @@ -41,7 +41,17 @@ TensorParallelHead, get_linear, ) -from lorax_server.utils.lora import DOWN_PROJ, GATE_PROJ, K_PROJ, LM_HEAD, O_PROJ, Q_PROJ, UP_PROJ, V_PROJ, AdapterBatchData +from lorax_server.utils.lora import ( + DOWN_PROJ, + GATE_PROJ, + K_PROJ, + LM_HEAD, + O_PROJ, + Q_PROJ, + UP_PROJ, + V_PROJ, + AdapterBatchData, +) class LlamaConfig(PretrainedConfig): @@ -147,17 +157,21 @@ def 
forward(self, hidden_states, residual=None): res = hidden_states return normed_hidden_states, res - + def load_attention(config, prefix, weights, layer_id): base_layer = load_attention_multi(config, prefix, weights) head_size = config.hidden_size // config.num_attention_heads return TensorParallelMultiAdapterLinear.load( - base_layer, layer_id, [Q_PROJ, K_PROJ, V_PROJ], sizes=[ + base_layer, + layer_id, + [Q_PROJ, K_PROJ, V_PROJ], + sizes=[ head_size * config.num_attention_heads, head_size * config.num_key_value_heads, head_size * config.num_key_value_heads, - ], process_group=weights.process_group + ], + process_group=weights.process_group, ) @@ -234,12 +248,17 @@ def __init__( self.query_key_value = load_attention(config, prefix, weights, layer_id) - self.o_proj = TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.o_proj", - weights=weights, - bias=False, - ), layer_id, O_PROJ, process_group=weights.process_group) + self.o_proj = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ), + layer_id, + O_PROJ, + process_group=weights.process_group, + ) self.num_groups = self.num_heads // self.num_key_value_heads self.kv_head_mapping = torch.arange( 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device @@ -247,9 +266,9 @@ def __init__( def get_query_key_value_weights(self, clone=True): """Gets the query, key, and value weights from the attention layer. - + If `clone`, then the weights are cloned before being returned. - + NOTE: if not `clone`, then the weights are returned as views, meaning that changes to the weights will be reflected in the attention layer. """ @@ -327,7 +346,9 @@ def forward( max_s, ) - return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size), adapter_data) + return self.o_proj( + attn_output.view(-1, self.num_heads * self.head_size), adapter_data + ) class LlamaMLP(nn.Module): @@ -353,18 +374,27 @@ def __init__(self, prefix, config, weights, layer_id): bias=False, ) self.gate_up_proj = TensorParallelMultiAdapterLinear.load( - gate_up_proj, layer_id, [GATE_PROJ, UP_PROJ], sizes=[ + gate_up_proj, + layer_id, + [GATE_PROJ, UP_PROJ], + sizes=[ config.intermediate_size, config.intermediate_size, - ], process_group=weights.process_group + ], + process_group=weights.process_group, ) - self.down_proj = TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.down_proj", - weights=weights, - bias=False, - ), layer_id, DOWN_PROJ, process_group=weights.process_group) + self.down_proj = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ), + layer_id, + DOWN_PROJ, + process_group=weights.process_group, + ) self.intermediate_size = ( config.intermediate_size // weights.process_group.size() ) @@ -372,7 +402,9 @@ def __init__(self, prefix, config, weights, layer_id): def forward(self, hidden_states, adapter_data): gate_up_states = self.gate_up_proj(hidden_states, adapter_data) gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) - return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data + ) class FlashLlamaLayer(nn.Module): @@ -380,9 +412,14 @@ def __init__(self, layer_id, config, weights): super().__init__() prefix = f"model.layers.{layer_id}" 
self.self_attn = FlashLlamaAttention( - prefix=f"{prefix}.self_attn", config=config, weights=weights, layer_id=layer_id, + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + layer_id=layer_id, + ) + self.mlp = LlamaMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id ) - self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id) self.input_layernorm = LlamaRMSNorm( prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps @@ -509,11 +546,16 @@ def __init__(self, config, weights): super().__init__() self.model = FlashLlamaModel(config, weights) - self.lm_head = TensorParallelAdapterRowLinear.load(TensorParallelHead.load( - config, - prefix="lm_head", - weights=weights, - ), 0, LM_HEAD, process_group=weights.process_group) + self.lm_head = TensorParallelAdapterRowLinear.load( + TensorParallelHead.load( + config, + prefix="lm_head", + weights=weights, + ), + 0, + LM_HEAD, + process_group=weights.process_group, + ) def forward( self, diff --git a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py index 1c0902265..536ae9f7f 100644 --- a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py @@ -24,7 +24,7 @@ from torch import nn from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig -from typing import Optional, List, Set, Tuple +from typing import Optional, List, Tuple # Flash attention imports import dropout_layer_norm @@ -42,7 +42,17 @@ TensorParallelHead, get_linear, ) -from lorax_server.utils.lora import DOWN_PROJ, GATE_PROJ, K_PROJ, LM_HEAD, O_PROJ, Q_PROJ, UP_PROJ, V_PROJ, AdapterBatchData +from lorax_server.utils.lora import ( + DOWN_PROJ, + GATE_PROJ, + K_PROJ, + LM_HEAD, + O_PROJ, + Q_PROJ, + UP_PROJ, + V_PROJ, + AdapterBatchData, +) if not HAS_FLASH_ATTN_V2: raise ImportError("Mistral model requires flash attn v2") @@ -153,17 +163,21 @@ def forward(self, hidden_states, residual=None): res = hidden_states return normed_hidden_states, res - + def load_attention(config, prefix, weights, layer_id): base_layer = load_attention_multi(config, prefix, weights) head_size = config.hidden_size // config.num_attention_heads return TensorParallelMultiAdapterLinear.load( - base_layer, layer_id, [Q_PROJ, K_PROJ, V_PROJ], sizes=[ + base_layer, + layer_id, + [Q_PROJ, K_PROJ, V_PROJ], + sizes=[ head_size * config.num_attention_heads, head_size * config.num_key_value_heads, head_size * config.num_key_value_heads, - ], process_group=weights.process_group + ], + process_group=weights.process_group, ) @@ -243,12 +257,17 @@ def __init__( self.query_key_value = load_attention(config, prefix, weights, layer_id) - self.o_proj = TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.o_proj", - weights=weights, - bias=False, - ), layer_id, O_PROJ, process_group=weights.process_group) + self.o_proj = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ), + layer_id, + O_PROJ, + process_group=weights.process_group, + ) self.num_groups = self.num_heads // self.num_key_value_heads self.kv_head_mapping = torch.arange( 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device @@ -256,9 +275,9 @@ def __init__( def get_query_key_value_weights(self, 
clone=True): """Gets the query, key, and value weights from the attention layer. - + If `clone`, then the weights are cloned before being returned. - + NOTE: if not `clone`, then the weights are returned as views, meaning that changes to the weights will be reflected in the attention layer. """ @@ -343,7 +362,9 @@ def forward( max_s, ) - return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size), adapter_data) + return self.o_proj( + attn_output.view(-1, self.num_heads * self.head_size), adapter_data + ) class MistralMLP(nn.Module): @@ -369,18 +390,27 @@ def __init__(self, prefix, config, weights, layer_id): bias=False, ) self.gate_up_proj = TensorParallelMultiAdapterLinear.load( - gate_up_proj, layer_id, [GATE_PROJ, UP_PROJ], sizes=[ + gate_up_proj, + layer_id, + [GATE_PROJ, UP_PROJ], + sizes=[ config.intermediate_size, config.intermediate_size, - ], process_group=weights.process_group + ], + process_group=weights.process_group, ) - self.down_proj = TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.down_proj", - weights=weights, - bias=False, - ), layer_id, DOWN_PROJ, process_group=weights.process_group) + self.down_proj = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ), + layer_id, + DOWN_PROJ, + process_group=weights.process_group, + ) self.intermediate_size = ( config.intermediate_size // weights.process_group.size() ) @@ -388,7 +418,9 @@ def __init__(self, prefix, config, weights, layer_id): def forward(self, hidden_states, adapter_data): gate_up_states = self.gate_up_proj(hidden_states, adapter_data) gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) - return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data + ) class MistralLayer(nn.Module): @@ -396,9 +428,14 @@ def __init__(self, layer_id, config, weights): super().__init__() prefix = f"model.layers.{layer_id}" self.self_attn = MistralAttention( - prefix=f"{prefix}.self_attn", config=config, weights=weights, layer_id=layer_id, + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + layer_id=layer_id, + ) + self.mlp = MistralMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id ) - self.mlp = MistralMLP(prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id) self.input_layernorm = MistralRMSNorm( prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps @@ -529,11 +566,16 @@ def __init__(self, config, weights): super().__init__() self.model = MistralModel(config, weights) - self.lm_head = TensorParallelAdapterRowLinear.load(TensorParallelHead.load( - config, - prefix="lm_head", - weights=weights, - ), 0, LM_HEAD, process_group=weights.process_group) + self.lm_head = TensorParallelAdapterRowLinear.load( + TensorParallelHead.load( + config, + prefix="lm_head", + weights=weights, + ), + 0, + LM_HEAD, + process_group=weights.process_group, + ) self.max_past = config.sliding_window if self.max_past is None: raise ValueError("max_past cannot be None") @@ -576,4 +618,4 @@ def forward( if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits = self.lm_head(hidden_states, adapter_data) - return logits \ No newline at end of file + return logits diff --git a/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py 
b/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py index a878b0471..5a5c11172 100644 --- a/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py @@ -74,28 +74,28 @@ class MixtralConfig(PretrainedConfig): model_type = "mixtral" def __init__( - self, - vocab_size=32000, - hidden_size=4096, - intermediate_size=14336, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=8, - hidden_act="silu", - max_position_embeddings=4096 * 32, - initializer_range=0.02, - rms_norm_eps=1e-05, - use_cache=True, - pad_token_id=None, - bos_token_id=1, - eos_token_id=2, - pretraining_tp=1, - tie_word_embeddings=False, - rope_theta=10000.0, - sliding_window=4096, - num_experts_per_tok=2, - num_local_experts=8, - **kwargs, + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + sliding_window=4096, + num_experts_per_tok=2, + num_local_experts=8, + **kwargs, ): self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings @@ -136,11 +136,15 @@ def load_attention(config, prefix, weights, layer_id): base_layer = load_attention_multi(config, prefix, weights) head_size = config.hidden_size // config.num_attention_heads return TensorParallelMultiAdapterLinear.load( - base_layer, layer_id, [ATTN_Q_PROJ, ATTN_K_PROJ, ATTN_V_PROJ], sizes=[ + base_layer, + layer_id, + [ATTN_Q_PROJ, ATTN_K_PROJ, ATTN_V_PROJ], + sizes=[ head_size * config.num_attention_heads, head_size * config.num_key_value_heads, head_size * config.num_key_value_heads, - ], process_group=weights.process_group + ], + process_group=weights.process_group, ) @@ -193,16 +197,18 @@ def _load_experts(config, prefix, mat, weights): rank = weights.process_group.rank() assert ( - config.intermediate_size % world_size == 0 + config.intermediate_size % world_size == 0 ), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards" block_size = config.intermediate_size // world_size start = rank * block_size stop = (rank + 1) * block_size - tensor = torch.empty((config.num_local_experts * block_size, config.hidden_size), - dtype=weights.dtype, - device=weights.device) + tensor = torch.empty( + (config.num_local_experts * block_size, config.hidden_size), + dtype=weights.dtype, + device=weights.device, + ) for i in range(config.num_local_experts): slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight") @@ -211,7 +217,9 @@ def _load_experts(config, prefix, mat, weights): expert_slice = slice_[:, start:stop].t().contiguous() else: expert_slice = slice_[start:stop] - tensor[i * block_size:(i + 1) * block_size] = expert_slice.to(dtype=weights.dtype).to(device=weights.device) + tensor[i * block_size : (i + 1) * block_size] = expert_slice.to( + dtype=weights.dtype + ).to(device=weights.device) return tensor @@ -267,6 +275,7 @@ def forward(self, hidden_states, residual=None): return normed_hidden_states, res + class MixtralAttention(torch.nn.Module): """ MixtralAttention module performs attention computation for the Mixtral model. 
@@ -316,7 +325,7 @@ def __init__( device=weights.device, ) - self.softmax_scale = self.head_size ** -0.5 + self.softmax_scale = self.head_size**-0.5 if self.num_heads % weights.process_group.size() != 0: raise ValueError( @@ -325,35 +334,40 @@ def __init__( ) self.num_heads = self.num_heads // weights.process_group.size() self.num_key_value_heads = ( - config.num_key_value_heads // weights.process_group.size() + config.num_key_value_heads // weights.process_group.size() ) self.query_key_value = load_attention(config, prefix, weights, layer_id) - self.o_proj = TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.o_proj", - weights=weights, - bias=False, - ), layer_id, ATTN_O_PROJ, process_group=weights.process_group) + self.o_proj = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ), + layer_id, + ATTN_O_PROJ, + process_group=weights.process_group, + ) self.num_groups = self.num_heads // self.num_key_value_heads self.kv_head_mapping = torch.arange( 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device ).repeat_interleave(self.num_groups) def forward( - self, - hidden_states, - cos, - sin, - cu_seqlen_prefill, - kv_cache, - block_tables, - slots, - input_lengths, - max_s, - adapter_data, - prefill_cache_indices, + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + adapter_data, + prefill_cache_indices, ): """ Performs forward pass of the attention module. @@ -428,7 +442,9 @@ def forward( max_s, ) - return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size), adapter_data) + return self.o_proj( + attn_output.view(-1, self.num_heads * self.head_size), adapter_data + ) @torch.jit.script @@ -527,8 +543,9 @@ def topology(self, x: torch.Tensor, padded_bins: torch.Tensor): # Indices for the sparse matrix. The indices for # the intermediate matrix are dynamic depending # on the mapping of tokens to experts. - column_indices = ops.topology(padded_bins, self.blocking, block_rows, - blocks_per_row) + column_indices = ops.topology( + padded_bins, self.blocking, block_rows, blocks_per_row + ) # For now, use meta init to save the device memory. data = torch.empty( @@ -572,8 +589,7 @@ def indices_and_padded_bins(self, selected_experts: torch.Tensor): # position of each bin. # List of size num_experts - padded_tokens_per_expert = round_up(tokens_per_expert, - self.blocking) + padded_tokens_per_expert = round_up(tokens_per_expert, self.blocking) # padded_tokens_per_expert => [128, O, 128, ...] 
# Cumulative selected experts per token @@ -612,8 +628,7 @@ def sparse_forward(self, x: torch.Tensor) -> torch.Tensor: # Permute tokens and pad to prepare expert computation # (top_k * sequence_length + padding, model_dim) - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, - self.top_k) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.top_k) # Create the sparse matrix topology with torch.no_grad(): @@ -655,7 +670,7 @@ def sparse_forward(self, x: torch.Tensor) -> torch.Tensor: torch.distributed.all_reduce(x, group=self.process_group) return x.view(*input_shape) - + def dense_forward(self, x: torch.Tensor) -> torch.Tensor: """ x: (sequence_length, model_dim) @@ -757,7 +772,11 @@ def __init__(self, prefix, config: MixtralConfig, weights): ] self.w2 = [ TensorParallelRowLinear.load( - config, prefix=f"{prefix}.experts.{i}.w2", weights=weights, bias=False, all_reduce=False, + config, + prefix=f"{prefix}.experts.{i}.w2", + weights=weights, + bias=False, + all_reduce=False, ) for i in range(self.num_experts) ] @@ -813,7 +832,10 @@ def __init__(self, layer_id, config, weights): prefix = f"model.layers.{layer_id}" self.self_attn = MixtralAttention( - prefix=f"{prefix}.self_attn", config=config, weights=weights, layer_id=layer_id + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + layer_id=layer_id, ) moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE self.moe = moe_cls(f"{prefix}.block_sparse_moe", config, weights) @@ -828,19 +850,19 @@ def __init__(self, layer_id, config, weights): ) def forward( - self, - hidden_states, - residual, - cos, - sin, - cu_seqlen_prefill, - kv_cache, - block_tables, - slots, - input_lengths, - max_s, - adapter_data, - prefill_cache_indices, + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + adapter_data, + prefill_cache_indices, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -896,17 +918,17 @@ def __init__(self, config, weights): self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads def forward( - self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - cu_seqlen_prefill: Optional[torch.Tensor], - kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, - slots: torch.Tensor, - input_lengths: torch.Tensor, - max_s: int, - adapter_data: AdapterBatchData, - prefill_cache_indices: Optional[torch.Tensor], + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + adapter_data: AdapterBatchData, + prefill_cache_indices: Optional[torch.Tensor], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -943,28 +965,33 @@ def __init__(self, config, weights): super().__init__() self.model = MixtralModel(config, weights) - self.lm_head = TensorParallelAdapterRowLinear.load(TensorParallelHead.load( - config, - prefix="lm_head", - weights=weights, - ), 0, LM_HEAD, process_group=weights.process_group) + self.lm_head = TensorParallelAdapterRowLinear.load( + TensorParallelHead.load( + config, + prefix="lm_head", + weights=weights, + ), + 0, + LM_HEAD, + process_group=weights.process_group, + ) self.max_past = config.sliding_window if self.max_past is None: raise ValueError("max_past cannot be None") def forward( - self, - input_ids: 
torch.Tensor, - position_ids: torch.Tensor, - cu_seqlen_prefill: Optional[torch.Tensor], - kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, - slots: torch.Tensor, - input_lengths: torch.Tensor, - max_s: int, - adapter_data: AdapterBatchData, - prefill_cache_indices: Optional[torch.Tensor], - lm_head_indices: Optional[torch.Tensor] = None, + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + adapter_data: AdapterBatchData, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, ) -> torch.Tensor: if prefill_cache_indices is not None: # Slots also need to be sliced as it has the same size as the whole kv tensor @@ -990,4 +1017,4 @@ def forward( if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits = self.lm_head(hidden_states, adapter_data) - return logits \ No newline at end of file + return logits diff --git a/server/lorax_server/models/custom_modeling/flash_phi_modeling.py b/server/lorax_server/models/custom_modeling/flash_phi_modeling.py index 3f3586d7e..69a61630c 100644 --- a/server/lorax_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_phi_modeling.py @@ -14,7 +14,6 @@ from torch import nn from transformers.activations import ACT2FN -from transformers.models.phi import PhiConfig from typing import Optional, List, Tuple from lorax_server.utils import flash_attn @@ -36,13 +35,19 @@ ATTN_OUT_PROJ = "mixer.out_proj" MLP_FC1 = "mlp.fc1" MLP_FC2 = "mlp.fc2" - + def load_attention(config, prefix, weights, layer_id, head_dim, n_head, n_head_kv): op_size = head_dim * (n_head + 2 * n_head_kv) - base_layer = load_attention_multi(config, prefix, weights, head_dim, n_head, n_head_kv) + base_layer = load_attention_multi( + config, prefix, weights, head_dim, n_head, n_head_kv + ) return TensorParallelMultiAdapterLinear.load( - base_layer, layer_id, [ATTN_WQKV], sizes=[op_size], process_group=weights.process_group + base_layer, + layer_id, + [ATTN_WQKV], + sizes=[op_size], + process_group=weights.process_group, ) @@ -52,7 +57,10 @@ def load_attention_multi(config, prefix, weights, head_dim, n_head, n_head_kv): prefixes=[ (f"{prefix}.Wqkv", (0, head_dim * n_head)), (f"{prefix}.Wqkv", (head_dim * n_head, head_dim * n_head_kv)), - (f"{prefix}.Wqkv", ((head_dim * n_head) + (head_dim * n_head_kv), head_dim * n_head_kv)), + ( + f"{prefix}.Wqkv", + ((head_dim * n_head) + (head_dim * n_head_kv), head_dim * n_head_kv), + ), ], dim=0, weights=weights, @@ -93,17 +101,32 @@ def __init__( ) self.num_key_value_heads = getattr(config, "n_head_kv", None) or self.num_heads - self.Wqkv = load_attention(config, prefix, weights, layer_id, self.head_size, self.num_heads, self.num_key_value_heads) - self.out_proj = TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load( + self.Wqkv = load_attention( config, - prefix=f"{prefix}.out_proj", - weights=weights, - bias=True, - ), layer_id, ATTN_OUT_PROJ, process_group=weights.process_group) + prefix, + weights, + layer_id, + self.head_size, + self.num_heads, + self.num_key_value_heads, + ) + self.out_proj = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.out_proj", + weights=weights, + bias=True, + ), + layer_id, + ATTN_OUT_PROJ, + 
process_group=weights.process_group, + ) # After initializing layers, scale num heads by num shards for use in forward() to split outputs self.num_heads = self.num_heads // weights.process_group.size() - self.num_key_value_heads = self.num_key_value_heads // weights.process_group.size() + self.num_key_value_heads = ( + self.num_key_value_heads // weights.process_group.size() + ) self.num_groups = self.num_heads // self.num_key_value_heads self.kv_head_mapping = torch.arange( @@ -172,7 +195,9 @@ def forward( max_s, ) - return self.out_proj(attn_output.view(-1, self.num_heads * self.head_size), adapter_data) + return self.out_proj( + attn_output.view(-1, self.num_heads * self.head_size), adapter_data + ) class PhiMLP(nn.Module): @@ -200,15 +225,24 @@ def __init__(self, prefix, config, weights, layer_id): out_size = fc1.linear.weight.shape[-1] * weights.process_group.size() self.fc1 = TensorParallelMultiAdapterLinear.load( - fc1, layer_id, [MLP_FC1], sizes=[out_size], process_group=weights.process_group + fc1, + layer_id, + [MLP_FC1], + sizes=[out_size], + process_group=weights.process_group, ) - self.fc2 = TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.fc2", - weights=weights, - bias=True, - ), layer_id, MLP_FC2, process_group=weights.process_group) - + self.fc2 = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.fc2", + weights=weights, + bias=True, + ), + layer_id, + MLP_FC2, + process_group=weights.process_group, + ) + def forward(self, hidden_states, adapter_data): hidden_states = self.fc1(hidden_states, adapter_data) hidden_states = self.act(hidden_states) @@ -220,14 +254,19 @@ class FlashPhiLayer(nn.Module): def __init__(self, layer_id, config, weights): super().__init__() prefix = f"transformer.h.{layer_id}" - + self.ln = FastLayerNorm.load( prefix=f"{prefix}.ln", weights=weights, eps=config.layer_norm_epsilon ) self.mixer = FlashPhiAttention( - prefix=f"{prefix}.mixer", config=config, weights=weights, layer_id=layer_id, + prefix=f"{prefix}.mixer", + config=config, + weights=weights, + layer_id=layer_id, + ) + self.mlp = PhiMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id ) - self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id) self.process_group = weights.process_group def forward( @@ -342,12 +381,17 @@ def __init__(self, config, weights): self.ln = FastLayerNorm.load( prefix=f"{prefix}.ln", weights=weights, eps=config.layer_norm_epsilon ) - self.linear = TensorParallelAdapterRowLinear.load(TensorParallelHead.load( - config, - prefix=f"{prefix}.linear", - weights=weights, - ), 0, LM_HEAD, process_group=weights.process_group) - + self.linear = TensorParallelAdapterRowLinear.load( + TensorParallelHead.load( + config, + prefix=f"{prefix}.linear", + weights=weights, + ), + 0, + LM_HEAD, + process_group=weights.process_group, + ) + def forward(self, hidden_states, adapter_data): hidden_states, _ = self.ln(hidden_states) hidden_states = self.linear(hidden_states, adapter_data) diff --git a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py index 1e4a69cca..f3c02a1d0 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py @@ -139,15 +139,21 @@ def forward(self, hidden_states, residual=None): res = hidden_states return normed_hidden_states, res - + 
def load_attention(config, prefix, weights, layer_id): - projection_size = (config.hidden_size // config.num_attention_heads) * config.num_attention_heads + projection_size = ( + config.hidden_size // config.num_attention_heads + ) * config.num_attention_heads base_layer = load_attention_multi(config, prefix, weights, projection_size) return TensorParallelMultiAdapterLinear.load( - base_layer, layer_id, [ATTN_C_ATTN], sizes=[ + base_layer, + layer_id, + [ATTN_C_ATTN], + sizes=[ 3 * projection_size, - ], process_group=weights.process_group + ], + process_group=weights.process_group, ) @@ -177,7 +183,9 @@ def __init__( self.num_heads = config.num_attention_heads self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_heads - self.projection_size = (self.head_size * config.num_attention_heads) // weights.process_group.size() + self.projection_size = ( + self.head_size * config.num_attention_heads + ) // weights.process_group.size() self.process_group = weights.process_group self.rotary_emb = PositionRotaryEmbedding.static( @@ -199,12 +207,17 @@ def __init__( self.c_attn = load_attention(config, prefix, weights, layer_id) - self.c_proj = TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.c_proj", - weights=weights, - bias=False, - ), layer_id, ATTN_C_PROJ, process_group=weights.process_group) + self.c_proj = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.c_proj", + weights=weights, + bias=False, + ), + layer_id, + ATTN_C_PROJ, + process_group=weights.process_group, + ) self.num_groups = self.num_heads // self.num_key_value_heads self.kv_head_mapping = torch.arange( 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device @@ -271,7 +284,9 @@ def forward( max_s, ) - return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size), adapter_data) + return self.c_proj( + attn_output.view(-1, self.num_heads * self.head_size), adapter_data + ) class QwenMLP(nn.Module): @@ -297,18 +312,27 @@ def __init__(self, prefix, config, weights, layer_id): bias=False, ) self.gate_up_proj = TensorParallelMultiAdapterLinear.load( - gate_up_proj, layer_id, [MLP_W2, MLP_W1], sizes=[ + gate_up_proj, + layer_id, + [MLP_W2, MLP_W1], + sizes=[ config.intermediate_size // 2, config.intermediate_size // 2, - ], process_group=weights.process_group + ], + process_group=weights.process_group, ) - self.c_proj = TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.c_proj", - weights=weights, - bias=False, - ), layer_id, MLP_C_PROJ, process_group=weights.process_group) + self.c_proj = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.c_proj", + weights=weights, + bias=False, + ), + layer_id, + MLP_C_PROJ, + process_group=weights.process_group, + ) self.intermediate_size = ( config.intermediate_size // weights.process_group.size() ) @@ -316,7 +340,9 @@ def __init__(self, prefix, config, weights, layer_id): def forward(self, hidden_states, adapter_data): gate_up_states = self.gate_up_proj(hidden_states, adapter_data) gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size // 2) - return self.c_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data) + return self.c_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data + ) class FlashQwenLayer(nn.Module): @@ -324,9 +350,14 @@ def __init__(self, layer_id, config, weights): super().__init__() prefix 
= f"transformer.h.{layer_id}" self.attn = FlashQwenAttention( - prefix=f"{prefix}.attn", config=config, weights=weights, layer_id=layer_id, + prefix=f"{prefix}.attn", + config=config, + weights=weights, + layer_id=layer_id, + ) + self.mlp = QwenMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id ) - self.mlp = QwenMLP(prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id) self.ln_1 = QwenRMSNorm( prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon @@ -368,9 +399,7 @@ def forward( ) # faster post attention rms norm - normed_attn_res_output, attn_res = self.ln_2( - attn_output, res - ) + normed_attn_res_output, attn_res = self.ln_2(attn_output, res) mlp_output = self.mlp(normed_attn_res_output, adapter_data) @@ -384,9 +413,7 @@ def __init__(self, config, weights): process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - self.wte = TensorParallelEmbedding( - prefix="transformer.wte", weights=weights - ) + self.wte = TensorParallelEmbedding(prefix="transformer.wte", weights=weights) self.h = nn.ModuleList( [ FlashQwenLayer( @@ -453,11 +480,16 @@ def __init__(self, config, weights): super().__init__() self.transformer = FlashQwenModel(config, weights) - self.lm_head = TensorParallelAdapterRowLinear.load(TensorParallelHead.load( - config, - prefix="lm_head", - weights=weights, - ), 0, LM_HEAD, process_group=weights.process_group) + self.lm_head = TensorParallelAdapterRowLinear.load( + TensorParallelHead.load( + config, + prefix="lm_head", + weights=weights, + ), + 0, + LM_HEAD, + process_group=weights.process_group, + ) def forward( self, diff --git a/server/lorax_server/models/custom_modeling/flash_rw_modeling.py b/server/lorax_server/models/custom_modeling/flash_rw_modeling.py index 2426d754e..bfcb4d856 100644 --- a/server/lorax_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_rw_modeling.py @@ -253,7 +253,7 @@ def __init__( if process_group.size() > self.num_groups: raise NotImplementedError( - f"Tensor Parallelism is not implemented for world_size > n groups" + "Tensor Parallelism is not implemented for world_size > n groups" ) if self.num_groups % process_group.size() != 0: raise NotImplementedError( diff --git a/server/lorax_server/models/custom_modeling/mpt_modeling.py b/server/lorax_server/models/custom_modeling/mpt_modeling.py index 7d1ea1f1c..5932d8bcc 100644 --- a/server/lorax_server/models/custom_modeling/mpt_modeling.py +++ b/server/lorax_server/models/custom_modeling/mpt_modeling.py @@ -3,7 +3,6 @@ Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py """ import math -import os import warnings from typing import List, Optional, Tuple, Union import torch @@ -178,7 +177,7 @@ def flash_attn_fn( _s_k = max(0, attn_bias.size(3) - key.size(1)) attn_bias = attn_bias[:, :, _s_q:, _s_k:] if attn_bias is not None: - raise NotImplementedError(f"attn_bias not implemented for flash attn.") + raise NotImplementedError("attn_bias not implemented for flash attn.") (batch_size, seqlen) = query.shape[:2] if key_padding_mask is None: key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool) @@ -263,9 +262,9 @@ def triton_flash_attn_fn( _s_k = max(0, attn_bias.size(3) - key.size(1)) attn_bias = attn_bias[:, :, _s_q:, _s_k:] if dropout_p: - raise NotImplementedError(f"Dropout not implemented for attn_impl: triton.") + raise NotImplementedError("Dropout not implemented for attn_impl: 
triton.") if needs_weights: - raise NotImplementedError(f"attn_impl: triton cannot return attn weights.") + raise NotImplementedError("attn_impl: triton cannot return attn weights.") if key_padding_mask is not None: warnings.warn( "Propagating key_padding_mask to the attention module " @@ -957,7 +956,7 @@ def forward( if past_key_values is not None: if len(past_key_values) != self.config.n_layers: raise ValueError( - f"past_key_values must provide a past_key_value for each attention " + "past_key_values must provide a past_key_value for each attention " + f"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r})." ) past_position = past_key_values[0][0].size(1) diff --git a/server/lorax_server/models/custom_modeling/neox_modeling.py b/server/lorax_server/models/custom_modeling/neox_modeling.py index 1ce018ddd..520be62bb 100644 --- a/server/lorax_server/models/custom_modeling/neox_modeling.py +++ b/server/lorax_server/models/custom_modeling/neox_modeling.py @@ -21,24 +21,14 @@ import torch.distributed import torch.utils.checkpoint from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from torch.nn import CrossEntropyLoss from transformers.activations import ACT2FN -from transformers.file_utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, - QuestionAnsweringModelOutput, - SequenceClassifierOutputWithPast, - TokenClassifierOutput, ) from transformers.modeling_utils import PreTrainedModel -from transformers import GPTNeoXConfig from loguru import logger from lorax_server.utils.layers import ( TensorParallelColumnLinear, diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 306c2cabf..da9d13b78 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -10,7 +10,6 @@ from dataclasses import dataclass from opentelemetry import trace from peft import LoraConfig -from tqdm import tqdm from transformers import PreTrainedTokenizerBase from typing import Optional, Set, Tuple, List, Type, Union, Dict @@ -30,9 +29,15 @@ from lorax_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID, load_module_map from lorax_server.utils.dist import MEMORY_FRACTION -from lorax_server.utils.lora import LM_HEAD, AdapterBatchData, AdapterBatchMetadata, BatchedLoraWeights, MergedLoraWeights +from lorax_server.utils.lora import ( + AdapterBatchData, + AdapterBatchMetadata, + BatchedLoraWeights, + MergedLoraWeights, +) from lorax_server.utils.segments import SegmentConcatBuilder, find_segments from lorax_server.utils.weights import shard_on_dim +from lorax_server.utils.globals import get_speculation_num tracer = trace.get_tracer(__name__) @@ -47,6 +52,7 @@ class FlashCausalLMBatch(Batch): # Decoder values input_ids: torch.Tensor position_ids: torch.Tensor + speculation_ids: torch.Tensor # Flash Attention values @@ -131,6 +137,7 @@ def from_pb( needed_blocks_slots = [] start_slots = [] slot_indices = [] + speculation_ids = [] input_lengths = [] prefix_offsets = [] @@ -146,7 +153,7 @@ def from_pb( next_token_chooser_parameters = [] stopping_criterias = [] - + adapter_indices_list = [] adapter_set = set() @@ -197,7 +204,7 @@ def from_pb( # Paged attention # Remove 
one as the first token des not have a past - total_tokens = input_length + max_new_tokens - 1 + total_tokens = input_length + max_new_tokens - 1 + get_speculation_num() needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) blocks += needed_blocks needed_blocks_slots.append((needed_blocks, total_tokens)) @@ -235,9 +242,13 @@ def from_pb( cumulative_max_length += total_tokens max_seqlen = max(max_seqlen, input_length) max_blocks = max(max_blocks, needed_blocks) - max_length = max(max_length, input_length + max_new_tokens) + max_length = max( + max_length, input_length + max_new_tokens + get_speculation_num() + ) - adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device) + adapter_indices = torch.cat(adapter_indices_list).to( + dtype=torch.int64, device=device + ) next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, dtype, device @@ -277,7 +288,9 @@ def from_pb( ) adapter_segments, adapter_segment_indices = find_segments(adapter_indices) - adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32, device=device) + adapter_segments = torch.tensor( + adapter_segments, dtype=torch.int32, device=device + ) if all_prefill_logprobs: prefill_head_indices = None @@ -326,6 +339,7 @@ def from_pb( adapter_segments=adapter_segments, segment_indices=adapter_segment_indices, ), + speculation_ids=None, ) @tracer.start_as_current_span("filter") @@ -436,6 +450,9 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": input_lengths_tensor = self.input_lengths_tensor[indices] slots = self.slots[slot_filtering_indices] next_token_chooser = self.next_token_chooser.filter(indices) + speculation_ids = ( + self.speculation_ids[indices] if self.speculation_ids is not None else None + ) start_slots = torch.tensor(start_slots, dtype=torch.int64) @@ -443,7 +460,9 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": slot_indices = slot_indices.to(device) adapter_segments, adapter_segment_indices = find_segments(adapter_indices) - adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32, device=device) + adapter_segments = torch.tensor( + adapter_segments, dtype=torch.int32, device=device + ) return type(self)( batch_id=self.batch_id, @@ -478,6 +497,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": adapter_segments=adapter_segments, segment_indices=adapter_segment_indices, ), + speculation_ids=speculation_ids, ) @classmethod @@ -497,6 +517,9 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch total_batch_size += len(b) total_slots += len(b.slots) blocks += b.blocks + speculation_length = ( + b.speculation_ids.shape[1] if b.speculation_ids is not None else 0 + ) max_blocks = max(max_blocks, b.max_blocks) max_seqlen = max(max_seqlen, b.max_seqlen) max_length = max( @@ -504,6 +527,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch max( input_length + stopping_criteria.max_new_tokens + + speculation_length - stopping_criteria.current_tokens for input_length, stopping_criteria in zip( b.input_lengths, b.stopping_criterias @@ -525,9 +549,13 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch (total_batch_size, max_length) ) - total_indices_size = sum(b.adapter_meta.adapter_indices.shape[0] for b in batches) - - adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(total_indices_size) + total_indices_size = sum( + b.adapter_meta.adapter_indices.shape[0] for b in batches + ) + + 
adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty( + total_indices_size + ) adapter_set = set() adapter_segment_builder = SegmentConcatBuilder() @@ -571,13 +599,20 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch # Copy over adapter indices adapter_start_index = cumulative_adapter_indices_size - adapter_end_index = cumulative_adapter_indices_size + batch.adapter_meta.adapter_indices.shape[0] - adapter_indices[adapter_start_index:adapter_end_index] = batch.adapter_meta.adapter_indices + adapter_end_index = ( + cumulative_adapter_indices_size + + batch.adapter_meta.adapter_indices.shape[0] + ) + adapter_indices[ + adapter_start_index:adapter_end_index + ] = batch.adapter_meta.adapter_indices cumulative_adapter_indices_size = adapter_end_index adapter_set.update(batch.adapter_meta.adapter_set) # Update adapter segments - adapter_segment_builder.concat(batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices) + adapter_segment_builder.concat( + batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices + ) all_input_ids_tensor[ start_index:end_index, : batch.all_input_ids_tensor.shape[1] @@ -613,6 +648,12 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch adapter_segments, adapter_segment_indices = adapter_segment_builder.build() + speculation_ids = ( + torch.cat([b.speculation_ids for b in batches], dim=0) + if batches[0].speculation_ids is not None + else None + ) + # Needed to avoid dropping blocks when the batches will go out of scope for b in batches: b.block_tables = None @@ -651,6 +692,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch adapter_segments=adapter_segments, segment_indices=adapter_segment_indices, ), + speculation_ids=speculation_ids, ) def __del__(self): @@ -684,7 +726,9 @@ def __init__( # This may be set to False in the subclass constructor self.dynamic_adapter_loading_enabled = True - self.batched_lora_weights: Dict[str, BatchedLoraWeights] = defaultdict(BatchedLoraWeights) + self.batched_lora_weights: Dict[str, BatchedLoraWeights] = defaultdict( + BatchedLoraWeights + ) super(FlashCausalLM, self).__init__( model=model, @@ -702,38 +746,40 @@ def __init__( @property def supports_adapter_loading(self) -> bool: return False - + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: return {} - + @property def adapter_layers(self) -> List[str]: return [] - + def get_num_layers_for_type(self, layer_type: str) -> int: return 0 - + def is_row_parallel(self, layer_type: str) -> bool: return False def load_adapter(self, adapter_id, adapter_source, adapter_index): """Physically loads the adapter weights into the model. - adapter_id must be `BASE_MODEL_ADAPTER_ID` if adapter statically loaded - into model. Otherwise, the adapter weights are merged into the model + adapter_id must be `BASE_MODEL_ADAPTER_ID` if adapter statically loaded + into model. Otherwise, the adapter weights are merged into the model weights on the fly. """ if not self.supports_adapter_loading: raise ValueError("This model does not support adapter loading.") - + if not self.dynamic_adapter_loading_enabled: if adapter_id == BASE_MODEL_ADAPTER_ID: return else: - raise ValueError(f"This model was initialized with the adapter {self.adapter_id} " - f"and therefore does not support dynamic adapter loading. 
" - f"Please initialize a new model instance from the base model in " - f"order to use the dynamic adapter loading feature.") + raise ValueError( + f"This model was initialized with the adapter {self.adapter_id} " + f"and therefore does not support dynamic adapter loading. " + f"Please initialize a new model instance from the base model in " + f"order to use the dynamic adapter loading feature." + ) # If we are doing dynamic adapter loading, then we need to reset the weights if adapter_id == self.adapter_id: @@ -748,11 +794,17 @@ def load_adapter(self, adapter_id, adapter_source, adapter_index): unused_weight_names = adapter_weight_names.copy() for layer_name in self.adapter_layers: self.load_batched_adapter_weights( - module_map, adapter_config, adapter_index, layer_name, unused_weight_names + module_map, + adapter_config, + adapter_index, + layer_name, + unused_weight_names, ) - + if len(unused_weight_names) > 0: - logger.warning(f"{adapter_id} unused adapter weights: {unused_weight_names}") + logger.warning( + f"{adapter_id} unused adapter weights: {unused_weight_names}" + ) self.adapter_id = adapter_id @@ -771,35 +823,34 @@ def shard_lora_weights( # [r, hidden_size] weights_b = [ - shard_on_dim(w, dim=1, process_group=self.process_group) - for w in weights_b + shard_on_dim(w, dim=1, process_group=self.process_group) for w in weights_b ] return weights_a, weights_b def load_batched_adapter_weights( - self, - module_map: Dict[str, Dict], - adapter_config: LoraConfig, - adapter_index: int, + self, + module_map: Dict[str, Dict], + adapter_config: LoraConfig, + adapter_index: int, layer_type: str, unused_weight_names: Set[str], ): nlayers = self.get_num_layers_for_type(layer_type) lora_a_list = [None] * nlayers lora_b_list = [None] * nlayers - + for layer_id in range(nlayers): key = (layer_id, layer_type) weight_name, layer = self.target_to_layer[key] - + base_weight = layer.base_layer.linear.weight base_device = base_weight.device if weight_name not in module_map: # There is no LoRA weight for this layer type in the adapter return - + lora_a, lora_a_name = module_map[weight_name]["lora_A"] lora_a = lora_a.to(base_device, self.dtype) @@ -817,24 +868,27 @@ def load_batched_adapter_weights( lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale q_lora_merged = MergedLoraWeights( - *self.shard_lora_weights(lora_a_list, lora_b_list, layer_type), adapter_config, + *self.shard_lora_weights(lora_a_list, lora_b_list, layer_type), + adapter_config, ) q_lora_weights = self.batched_lora_weights[layer_type] q_lora_weights.add_adapter(adapter_index, q_lora_merged) - + def offload_adapter(self, adapter_id, adapter_source, adapter_index): """Offloads the adapter weights from GPU to CPU or disk.""" if not self.supports_adapter_loading: raise ValueError("This model does not support adapter loading.") - + if not self.dynamic_adapter_loading_enabled: if adapter_id == BASE_MODEL_ADAPTER_ID: return else: - raise ValueError(f"This model was initialized with the adapter {self.adapter_id} " - f"and therefore does not support dynamic adapter loading. " - f"Please initialize a new model instance from the base model in " - f"order to use the dynamic adapter loading feature.") + raise ValueError( + f"This model was initialized with the adapter {self.adapter_id} " + f"and therefore does not support dynamic adapter loading. " + f"Please initialize a new model instance from the base model in " + f"order to use the dynamic adapter loading feature." 
+ ) if adapter_id == BASE_MODEL_ADAPTER_ID: return @@ -863,7 +917,9 @@ def warmup(self, batch: FlashCausalLMBatch): ) _, batch = self.generate_token(batch) except RuntimeError as e: - if "CUDA out of memory" in str(e) or isinstance(e, torch.cuda.OutOfMemoryError): + if "CUDA out of memory" in str(e) or isinstance( + e, torch.cuda.OutOfMemoryError + ): raise RuntimeError( f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. " f"You need to decrease `--max-batch-prefill-tokens`" @@ -912,9 +968,38 @@ def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str: generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False ) - def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> Tuple[torch.Tensor, torch.Tensor]: + def forward( + self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData + ) -> Tuple[torch.Tensor, torch.Tensor]: global CACHE_MANAGER + if batch.speculation_ids is not None: + B, speculative_length = batch.speculation_ids.shape + new_length = speculative_length + 1 + batch.input_ids = torch.cat( + [batch.input_ids.unsqueeze(-1), batch.speculation_ids], dim=1 + ).reshape(-1) + arange = torch.arange( + new_length, device=batch.position_ids.device + ).unsqueeze(0) + arange_int = arange.to(dtype=torch.int32) + batch.position_ids = ( + batch.position_ids.unsqueeze(-1).expand(B, new_length) + arange + ).view(-1) + slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) + input_lengths = ( + input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int + ).view(-1) + + # Add Copy the block tables for all members + block_tables = ( + block_tables.unsqueeze(1) + .expand(B, new_length, -1) + .reshape(B * new_length, -1) + .contiguous() + ) + max_s = max_s + speculative_length + # Model Forward return self.model.forward( input_ids=batch.input_ids, @@ -954,7 +1039,9 @@ def generate_token( # Assign pointers to LoRA weights # TODO(travis): don't update this if indices haven't changed - adapter_data = AdapterBatchData.from_meta(batch.adapter_meta, self.batched_lora_weights) + adapter_data = AdapterBatchData.from_meta( + batch.adapter_meta, self.batched_lora_weights + ) try: out = self.forward(batch, adapter_data) @@ -962,17 +1049,40 @@ def generate_token( del batch raise e + if isinstance(out, tuple): + out, speculation_logits = out + else: + speculation_logits = None + if prefill: next_token_logits = ( out[batch.prefill_next_token_indices] if prefill_logprobs else out ) + if speculation_logits is not None: + speculation_logits = ( + speculation_logits[batch.prefill_next_token_indices] + if prefill_logprobs + else speculation_logits + ) else: next_token_logits = out - next_input_ids, next_token_logprobs = batch.next_token_chooser( - batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits + ( + next_input_ids, + next_token_logprobs, + accepted_ids, + speculation_ids, + ) = batch.next_token_chooser( + batch.all_input_ids_tensor[:, : batch.max_seqlen], + next_token_logits, + get_speculation_num(), + batch.speculation_ids, + speculation_logits, ) + speculation_length = ( + speculation_ids.shape[1] if speculation_ids is not None else 0 + ) if prefill: if len(batch) > 1 and prefill_logprobs: # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs @@ -984,7 +1094,9 @@ def generate_token( # We do not need cu_seqlen_prefill anymore batch.cu_seqlen_prefill = None - next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(len(batch)) + next_adapter_indices = 
batch.adapter_meta.adapter_indices.new_empty( + len(batch) + ) else: prefill_logprobs = None next_position_ids = batch.position_ids @@ -1001,6 +1113,7 @@ def generate_token( iterator = zip( batch.input_lengths, batch.all_input_ids, + accepted_ids, ) # We do two for loops as the first one can run completely asynchronously from the GPU while for the second @@ -1008,9 +1121,11 @@ def generate_token( # It is faster if we delay this sync for the maximum amount of time # For each member of the batch + index_count = 0 for i, ( input_length, all_input_ids, + accepted_id, ) in enumerate(iterator): # Indexing metadata start_index = cumulative_length @@ -1028,7 +1143,9 @@ def generate_token( # Initialize adapter indices # In decode, we only have one token per row in the batch, so grab last index - next_adapter_indices[i] = batch.adapter_meta.adapter_indices[end_index - 1] + next_adapter_indices[i] = batch.adapter_meta.adapter_indices[ + end_index - 1 + ] # Used to gather prefill logprobs # Copy batch.input_ids to prefill_token_indices @@ -1043,23 +1160,28 @@ def generate_token( start_index + 1 : start_index + out_length ] - batch.all_input_ids_tensor[i, input_length] = next_input_ids[i] + for xy in range(accepted_id): + batch.all_input_ids_tensor[i, input_length + xy] = next_input_ids[ + index_count + ] + index_count += 1 cumulative_length += input_length # Set values in batch - batch.input_ids = next_input_ids - batch.position_ids = next_position_ids + 1 + batch.input_ids = next_input_ids[accepted_ids.cumsum(-1) - 1] + batch.speculation_ids = speculation_ids + batch.position_ids = next_position_ids + accepted_ids batch.adapter_meta.adapter_indices = next_adapter_indices - batch.input_lengths_tensor += 1 - batch.slot_indices += 1 + batch.input_lengths_tensor += accepted_ids + batch.slot_indices += accepted_ids if prefill: # adjust segment lengths to account for all request lengths being 1 during decoding adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices) batch.adapter_meta.adapter_segments = torch.tensor( - adapter_segments, - dtype=torch.int32, + adapter_segments, + dtype=torch.int32, device=batch.adapter_meta.adapter_segments.device, ) @@ -1074,7 +1196,7 @@ def generate_token( # GPU <-> CPU sync next_token_logprobs = next_token_logprobs.tolist() - next_token_ids = batch.input_ids.tolist() + next_token_ids = next_input_ids.tolist() # Zipped iterator iterator = zip( @@ -1086,11 +1208,11 @@ def generate_token( batch.all_input_ids, batch.next_token_chooser.do_sample, batch.next_token_chooser.seeds, - next_token_ids, - next_token_logprobs, + accepted_ids, ) # For each member of the batch + index_count = 0 for i, ( request, input_length, @@ -1100,27 +1222,45 @@ def generate_token( all_input_ids, do_sample, seed, - next_token_id, - next_token_logprob, + accepted_ids_num, ) in enumerate(iterator): # Append next token to all tokens - all_input_ids.append(next_token_id) - - # Generated token - next_token_text, prefix_offset, read_offset = self.decode_token( - all_input_ids, - prefix_offset, - read_offset, - ) + next_token_texts = [] + left = 0 + before = stopping_criteria.current_tokens + + current_stopped = False + for j in range(index_count, index_count + accepted_ids_num): + # Generated token + next_token_id = next_token_ids[j] + all_input_ids.append(next_token_id) + next_token_text, prefix_offset, read_offset = self.decode_token( + all_input_ids, + prefix_offset, + read_offset, + ) + next_token_texts.append(next_token_text) - # Evaluate stopping criteria stop, reason = 
stopping_criteria( next_token_id, next_token_text, ) - if not stop: - stopped = False + if stop: + left = index_count + accepted_ids_num - j - 1 + current_stopped = True + break + else: + current_stopped = False + stopped = stopped and current_stopped + + _next_token_ids = next_token_ids[ + index_count : index_count + accepted_ids_num - left + ] + _next_token_logprobs = next_token_logprobs[ + index_count : index_count + accepted_ids_num - left + ] + index_count += accepted_ids_num # Shard generations # All generations will be appended in the rust sharded client @@ -1164,7 +1304,7 @@ def generate_token( request.id, prefill_tokens, next_token_id, - next_token_logprob, + next_token_logprobs, next_token_text, next_token_id in self.all_special_ids, generated_text, @@ -1173,7 +1313,9 @@ def generate_token( generations.append(generation) # Update values - batch.input_lengths[i] = input_length + 1 + batch.input_lengths[i] = input_length + accepted_ids_num.item() + if batch.input_lengths[i] > batch.max_seqlen: + batch.max_seqlen = batch.input_lengths[i] batch.prefix_offsets[i] = prefix_offset batch.read_offsets[i] = read_offset batch.all_input_ids[i] = all_input_ids diff --git a/server/lorax_server/models/flash_gpt2.py b/server/lorax_server/models/flash_gpt2.py index db28c5d70..a4d847ce4 100644 --- a/server/lorax_server/models/flash_gpt2.py +++ b/server/lorax_server/models/flash_gpt2.py @@ -1,11 +1,9 @@ -from collections import defaultdict import torch import torch.distributed from loguru import logger from opentelemetry import trace -from transformers import AutoTokenizer, GPT2Model -from tqdm import tqdm +from transformers import AutoTokenizer from typing import Dict, List, Optional, Tuple from lorax_server.models import FlashCausalLM @@ -19,11 +17,8 @@ LM_HEAD, ) from lorax_server.utils import ( - compute_delta_weight, create_merged_weight_files, - get_start_stop_idxs_for_rank, initialize_torch_distributed, - load_module_map, weight_files, Weights, ) @@ -70,27 +65,32 @@ def __init__( filenames = weight_files(model_id, revision=revision, extension=".safetensors") - # if adapter_id passed in as part of model instantiation, then we merge + # if adapter_id passed in as part of model instantiation, then we merge # the adapter weights with the model weights. This also disables dynamic # adapter loading, since the model is now itself initialized with an adapter. merged_weight_filenames = None self.dynamic_adapter_loading_enabled = True self.adapter_id = BASE_MODEL_ADAPTER_ID if len(adapter_id) > 0: - logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.") + logger.info( + f"Merging adapter weights from adapter_id {adapter_id} into model weights." 
+ ) # Need to pass the adapter source here merged_weight_filenames = create_merged_weight_files( - adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source + adapter_id, + model_id, + model_weight_filenames=filenames, + adapter_source=adapter_source, ) self.dynamic_adapter_loading_enabled = False self.adapter_id = adapter_id weights = Weights( - filenames, - device, - dtype, - process_group=self.process_group, - merged_weight_filenames=merged_weight_filenames + filenames, + device, + dtype, + process_group=self.process_group, + merged_weight_filenames=merged_weight_filenames, ) if config.quantize == "gptq": @@ -115,28 +115,37 @@ def __init__( @property def supports_adapter_loading(self) -> bool: return True - + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: layer_weights = {} prefix = "transformer.h" for i, layer in enumerate(self.model.transformer.h): - layer_weights[(i, ATTN_C_ATTN)] = (f"{prefix}.{i}.{ATTN_C_ATTN}", layer.attn.c_attn) - layer_weights[(i, ATTN_C_PROJ)] = (f"{prefix}.{i}.{ATTN_C_PROJ}", layer.attn.c_proj) + layer_weights[(i, ATTN_C_ATTN)] = ( + f"{prefix}.{i}.{ATTN_C_ATTN}", + layer.attn.c_attn, + ) + layer_weights[(i, ATTN_C_PROJ)] = ( + f"{prefix}.{i}.{ATTN_C_PROJ}", + layer.attn.c_proj, + ) layer_weights[(i, MLP_C_FC)] = (f"{prefix}.{i}.{MLP_C_FC}", layer.mlp.c_fc) - layer_weights[(i, MLP_C_PROJ)] = (f"{prefix}.{i}.{MLP_C_PROJ}", layer.mlp.c_proj) + layer_weights[(i, MLP_C_PROJ)] = ( + f"{prefix}.{i}.{MLP_C_PROJ}", + layer.mlp.c_proj, + ) # TODO: make Embedding layers adapter-compatible # layer_weights[(0, LM_HEAD)] = ("transformer.wte", self.model.transformer.wte) return layer_weights - + @property def adapter_layers(self) -> List[str]: return ADAPTER_LAYERS - + def get_num_layers_for_type(self, layer_type: str) -> int: return 1 if layer_type == LM_HEAD else len(self.model.transformer.h) - + def is_row_parallel(self, layer_type: str) -> bool: return layer_type in ROW_PARALLEL diff --git a/server/lorax_server/models/flash_llama.py b/server/lorax_server/models/flash_llama.py index ac108a4d0..2dbde9996 100644 --- a/server/lorax_server/models/flash_llama.py +++ b/server/lorax_server/models/flash_llama.py @@ -4,7 +4,6 @@ from loguru import logger from opentelemetry import trace from transformers import AutoTokenizer -from tqdm import tqdm from typing import Dict, List, Optional, Tuple from lorax_server.models import FlashCausalLM @@ -19,12 +18,33 @@ Weights, ) from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID -from lorax_server.utils.lora import DOWN_PROJ, GATE_PROJ, K_PROJ, LM_HEAD, O_PROJ, Q_PROJ, UP_PROJ, V_PROJ +from lorax_server.utils.lora import ( + DOWN_PROJ, + GATE_PROJ, + K_PROJ, + LM_HEAD, + O_PROJ, + Q_PROJ, + UP_PROJ, + V_PROJ, +) +from lorax_server.utils.medusa import MedusaModel +from huggingface_hub import hf_hub_download +import json tracer = trace.get_tracer(__name__) -ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, GATE_PROJ, UP_PROJ, DOWN_PROJ, LM_HEAD] +ADAPTER_LAYERS = [ + Q_PROJ, + K_PROJ, + V_PROJ, + O_PROJ, + GATE_PROJ, + UP_PROJ, + DOWN_PROJ, + LM_HEAD, +] ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD} @@ -38,6 +58,7 @@ def __init__( quantize: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, + medusa_id: Optional[str] = None, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): @@ -63,27 +84,32 @@ def __init__( filenames = weight_files(model_id, revision=revision, extension=".safetensors") - 
# if adapter_id passed in as part of model instantiation, then we merge + # if adapter_id passed in as part of model instantiation, then we merge # the adapter weights with the model weights. This also disables dynamic # adapter loading, since the model is now itself initialized with an adapter. merged_weight_filenames = None self.dynamic_adapter_loading_enabled = True self.adapter_id = BASE_MODEL_ADAPTER_ID if len(adapter_id) > 0: - logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.") + logger.info( + f"Merging adapter weights from adapter_id {adapter_id} into model weights." + ) # Need to pass the adapter source here merged_weight_filenames = create_merged_weight_files( - adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source + adapter_id, + model_id, + model_weight_filenames=filenames, + adapter_source=adapter_source, ) self.dynamic_adapter_loading_enabled = False self.adapter_id = adapter_id weights = Weights( - filenames, - device, - dtype, - process_group=self.process_group, - merged_weight_filenames=merged_weight_filenames + filenames, + device, + dtype, + process_group=self.process_group, + merged_weight_filenames=merged_weight_filenames, ) if config.quantize in ["gptq", "awq"]: @@ -92,6 +118,22 @@ def __init__( self.model_id = model_id model = FlashLlamaForCausalLM(config, weights) + if medusa_id is not None: + medusa_config = hf_hub_download( + medusa_id, revision=revision, filename="config.json" + ) + with open(medusa_config, "r") as f: + config = json.load(f) + medusa_head = hf_hub_download( + medusa_id, revision=revision, filename="medusa_lm_head.pt" + ) + medusa_sf = medusa_head[: -len(".pt")] + ".safetensors" + weights = Weights( + [medusa_sf], device, dtype, process_group=self.process_group + ) + lm_head = model.lm_head + model.lm_head = MedusaModel(config, weights, lm_head) + torch.distributed.barrier(group=self.process_group) super(FlashLlama, self).__init__( model=model, @@ -104,34 +146,55 @@ def __init__( rank=rank, world_size=world_size, ) - + @property def supports_adapter_loading(self) -> bool: return True - + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: layer_weights = {} prefix = "model.layers" for i, layer in enumerate(self.model.model.layers): - layer_weights[(i, Q_PROJ)] = (f"{prefix}.{i}.self_attn.q_proj", layer.self_attn.query_key_value) - layer_weights[(i, K_PROJ)] = (f"{prefix}.{i}.self_attn.k_proj", layer.self_attn.query_key_value) - layer_weights[(i, V_PROJ)] = (f"{prefix}.{i}.self_attn.v_proj", layer.self_attn.query_key_value) - layer_weights[(i, O_PROJ)] = (f"{prefix}.{i}.self_attn.o_proj", layer.self_attn.o_proj) - - layer_weights[(i, GATE_PROJ)] = (f"{prefix}.{i}.mlp.gate_proj", layer.mlp.gate_up_proj) - layer_weights[(i, UP_PROJ)] = (f"{prefix}.{i}.mlp.up_proj", layer.mlp.gate_up_proj) - layer_weights[(i, DOWN_PROJ)] = (f"{prefix}.{i}.mlp.down_proj", layer.mlp.down_proj) - + layer_weights[(i, Q_PROJ)] = ( + f"{prefix}.{i}.self_attn.q_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, K_PROJ)] = ( + f"{prefix}.{i}.self_attn.k_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, V_PROJ)] = ( + f"{prefix}.{i}.self_attn.v_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, O_PROJ)] = ( + f"{prefix}.{i}.self_attn.o_proj", + layer.self_attn.o_proj, + ) + + layer_weights[(i, GATE_PROJ)] = ( + f"{prefix}.{i}.mlp.gate_proj", + layer.mlp.gate_up_proj, + ) + layer_weights[(i, UP_PROJ)] = ( + f"{prefix}.{i}.mlp.up_proj", + 
layer.mlp.gate_up_proj, + ) + layer_weights[(i, DOWN_PROJ)] = ( + f"{prefix}.{i}.mlp.down_proj", + layer.mlp.down_proj, + ) + layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head) return layer_weights - + @property def adapter_layers(self) -> List[str]: return ADAPTER_LAYERS - + def get_num_layers_for_type(self, layer_type: str) -> int: return 1 if layer_type == LM_HEAD else len(self.model.model.layers) - + def is_row_parallel(self, layer_type: str) -> bool: return layer_type in ROW_PARALLEL diff --git a/server/lorax_server/models/flash_mistral.py b/server/lorax_server/models/flash_mistral.py index c150b582c..d48e12063 100644 --- a/server/lorax_server/models/flash_mistral.py +++ b/server/lorax_server/models/flash_mistral.py @@ -30,8 +30,21 @@ StoppingCriteria, ) from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID -from lorax_server.utils.lora import DOWN_PROJ, GATE_PROJ, K_PROJ, LM_HEAD, O_PROJ, Q_PROJ, UP_PROJ, V_PROJ, AdapterBatchData, AdapterBatchMetadata +from lorax_server.utils.lora import ( + DOWN_PROJ, + GATE_PROJ, + K_PROJ, + LM_HEAD, + O_PROJ, + Q_PROJ, + UP_PROJ, + V_PROJ, + AdapterBatchData, + AdapterBatchMetadata, +) from lorax_server.utils.segments import find_segments +from lorax_server.utils.medusa import MedusaModel +from huggingface_hub import hf_hub_download tracer = trace.get_tracer(__name__) @@ -39,7 +52,16 @@ SLIDING_WINDOW: Optional[int] = None SLIDING_WINDOW_BLOCKS: Optional[int] = None -ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, GATE_PROJ, UP_PROJ, DOWN_PROJ, LM_HEAD] +ADAPTER_LAYERS = [ + Q_PROJ, + K_PROJ, + V_PROJ, + O_PROJ, + GATE_PROJ, + UP_PROJ, + DOWN_PROJ, + LM_HEAD, +] ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD} @@ -199,8 +221,10 @@ def from_pb( max_blocks = max(max_blocks, needed_blocks) max_length = max(max_length, input_length + max_new_tokens) - adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device) - + adapter_indices = torch.cat(adapter_indices_list).to( + dtype=torch.int64, device=device + ) + next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, dtype, device ) @@ -242,7 +266,9 @@ def from_pb( ) adapter_segments, adapter_segment_indices = find_segments(adapter_indices) - adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32, device=device) + adapter_segments = torch.tensor( + adapter_segments, dtype=torch.int32, device=device + ) if all_prefill_logprobs: prefill_head_indices = None @@ -310,6 +336,7 @@ def __init__( quantize: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, + medusa_id: Optional[str] = None, ): global SLIDING_WINDOW global SLIDING_WINDOW_BLOCKS @@ -345,27 +372,32 @@ def __init__( filenames = weight_files(model_id, revision=revision, extension=".safetensors") - # if adapter_id passed in as part of model instantiation, then we merge + # if adapter_id passed in as part of model instantiation, then we merge # the adapter weights with the model weights. This also disables dynamic # adapter loading, since the model is now itself initialized with an adapter. merged_weight_filenames = None self.dynamic_adapter_loading_enabled = True self.adapter_id = BASE_MODEL_ADAPTER_ID if len(adapter_id) > 0: - logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.") + logger.info( + f"Merging adapter weights from adapter_id {adapter_id} into model weights." 
+ ) # Need to pass the adapter source here merged_weight_filenames = create_merged_weight_files( - adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source + adapter_id, + model_id, + model_weight_filenames=filenames, + adapter_source=adapter_source, ) self.dynamic_adapter_loading_enabled = False self.adapter_id = adapter_id weights = Weights( - filenames, - device, - dtype, - process_group=self.process_group, - merged_weight_filenames=merged_weight_filenames + filenames, + device, + dtype, + process_group=self.process_group, + merged_weight_filenames=merged_weight_filenames, ) if config.quantize in ["gptq", "awq"]: @@ -374,6 +406,22 @@ def __init__( self.model_id = model_id model = FlashMistralForCausalLM(config, weights) + if medusa_id is not None: + medusa_config = hf_hub_download( + medusa_id, revision=revision, filename="config.json" + ) + with open(medusa_config, "r") as f: + config = json.load(f) + medusa_head = hf_hub_download( + medusa_id, revision=revision, filename="medusa_lm_head.pt" + ) + medusa_sf = medusa_head[: -len(".pt")] + ".safetensors" + weights = Weights( + [medusa_sf], device, dtype, process_group=self.process_group + ) + lm_head = model.lm_head + model.lm_head = MedusaModel(config, weights, lm_head) + torch.distributed.barrier(group=self.process_group) super(FlashMistral, self).__init__( model=model, @@ -396,7 +444,9 @@ def supports_adapter_loading(self) -> bool: def batch_type(self) -> Type[FlashMistralBatch]: return FlashMistralBatch - def forward(self, batch: FlashMistralBatch, adapter_data: AdapterBatchData) -> Tuple[torch.Tensor, torch.Tensor]: + def forward( + self, batch: FlashMistralBatch, adapter_data: AdapterBatchData + ) -> Tuple[torch.Tensor, torch.Tensor]: # Model Forward logits = self.model.forward( input_ids=batch.input_ids, @@ -414,30 +464,51 @@ def forward(self, batch: FlashMistralBatch, adapter_data: AdapterBatchData) -> T if batch.prefill_cache_indices is not None: batch.prefill_cache_indices = None return logits - + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: layer_weights = {} prefix = "model.layers" for i, layer in enumerate(self.model.model.layers): - layer_weights[(i, Q_PROJ)] = (f"{prefix}.{i}.self_attn.q_proj", layer.self_attn.query_key_value) - layer_weights[(i, K_PROJ)] = (f"{prefix}.{i}.self_attn.k_proj", layer.self_attn.query_key_value) - layer_weights[(i, V_PROJ)] = (f"{prefix}.{i}.self_attn.v_proj", layer.self_attn.query_key_value) - layer_weights[(i, O_PROJ)] = (f"{prefix}.{i}.self_attn.o_proj", layer.self_attn.o_proj) - - layer_weights[(i, GATE_PROJ)] = (f"{prefix}.{i}.mlp.gate_proj", layer.mlp.gate_up_proj) - layer_weights[(i, UP_PROJ)] = (f"{prefix}.{i}.mlp.up_proj", layer.mlp.gate_up_proj) - layer_weights[(i, DOWN_PROJ)] = (f"{prefix}.{i}.mlp.down_proj", layer.mlp.down_proj) - + layer_weights[(i, Q_PROJ)] = ( + f"{prefix}.{i}.self_attn.q_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, K_PROJ)] = ( + f"{prefix}.{i}.self_attn.k_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, V_PROJ)] = ( + f"{prefix}.{i}.self_attn.v_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, O_PROJ)] = ( + f"{prefix}.{i}.self_attn.o_proj", + layer.self_attn.o_proj, + ) + + layer_weights[(i, GATE_PROJ)] = ( + f"{prefix}.{i}.mlp.gate_proj", + layer.mlp.gate_up_proj, + ) + layer_weights[(i, UP_PROJ)] = ( + f"{prefix}.{i}.mlp.up_proj", + layer.mlp.gate_up_proj, + ) + layer_weights[(i, DOWN_PROJ)] = ( + f"{prefix}.{i}.mlp.down_proj", + 
layer.mlp.down_proj, + ) + layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head) return layer_weights - + @property def adapter_layers(self) -> List[str]: return ADAPTER_LAYERS - + def get_num_layers_for_type(self, layer_type: str) -> int: return 1 if layer_type == LM_HEAD else len(self.model.model.layers) - + def is_row_parallel(self, layer_type: str) -> bool: return layer_type in ROW_PARALLEL diff --git a/server/lorax_server/models/flash_mixtral.py b/server/lorax_server/models/flash_mixtral.py index 31a005586..b2f24b511 100644 --- a/server/lorax_server/models/flash_mixtral.py +++ b/server/lorax_server/models/flash_mixtral.py @@ -22,9 +22,6 @@ ATTN_O_PROJ, ATTN_Q_PROJ, ATTN_V_PROJ, - MOE_W1, - MOE_W2, - MOE_W3, FlashMixtralForCausalLM, MixtralConfig, ) @@ -206,8 +203,10 @@ def from_pb( max_blocks = max(max_blocks, needed_blocks) max_length = max(max_length, input_length + max_new_tokens) - adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device) - + adapter_indices = torch.cat(adapter_indices_list).to( + dtype=torch.int64, device=device + ) + next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, dtype, device ) @@ -249,7 +248,9 @@ def from_pb( ) adapter_segments, adapter_segment_indices = find_segments(adapter_indices) - adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32, device=device) + adapter_segments = torch.tensor( + adapter_segments, dtype=torch.int32, device=device + ) if all_prefill_logprobs: prefill_head_indices = None @@ -352,27 +353,32 @@ def __init__( filenames = weight_files(model_id, revision=revision, extension=".safetensors") - # if adapter_id passed in as part of model instantiation, then we merge + # if adapter_id passed in as part of model instantiation, then we merge # the adapter weights with the model weights. This also disables dynamic # adapter loading, since the model is now itself initialized with an adapter. merged_weight_filenames = None self.dynamic_adapter_loading_enabled = True self.adapter_id = BASE_MODEL_ADAPTER_ID if len(adapter_id) > 0: - logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.") + logger.info( + f"Merging adapter weights from adapter_id {adapter_id} into model weights." 
+ ) # Need to pass the adapter source here merged_weight_filenames = create_merged_weight_files( - adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source + adapter_id, + model_id, + model_weight_filenames=filenames, + adapter_source=adapter_source, ) self.dynamic_adapter_loading_enabled = False self.adapter_id = adapter_id weights = Weights( - filenames, - device, - dtype, - process_group=self.process_group, - merged_weight_filenames=merged_weight_filenames + filenames, + device, + dtype, + process_group=self.process_group, + merged_weight_filenames=merged_weight_filenames, ) if config.quantize in ["gptq", "awq"]: @@ -403,7 +409,9 @@ def supports_adapter_loading(self) -> bool: def batch_type(self) -> Type[FlashMixtralBatch]: return FlashMixtralBatch - def forward(self, batch: FlashMixtralBatch, adapter_data: AdapterBatchData) -> Tuple[torch.Tensor, torch.Tensor]: + def forward( + self, batch: FlashMixtralBatch, adapter_data: AdapterBatchData + ) -> Tuple[torch.Tensor, torch.Tensor]: # Model Forward logits = self.model.forward( input_ids=batch.input_ids, @@ -421,31 +429,43 @@ def forward(self, batch: FlashMixtralBatch, adapter_data: AdapterBatchData) -> T if batch.prefill_cache_indices is not None: batch.prefill_cache_indices = None return logits - + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: layer_weights = {} prefix = "model.layers" for i, layer in enumerate(self.model.model.layers): - layer_weights[(i, ATTN_Q_PROJ)] = (f"{prefix}.{i}.self_attn.q_proj", layer.self_attn.query_key_value) - layer_weights[(i, ATTN_K_PROJ)] = (f"{prefix}.{i}.self_attn.k_proj", layer.self_attn.query_key_value) - layer_weights[(i, ATTN_V_PROJ)] = (f"{prefix}.{i}.self_attn.v_proj", layer.self_attn.query_key_value) - layer_weights[(i, ATTN_O_PROJ)] = (f"{prefix}.{i}.self_attn.o_proj", layer.self_attn.o_proj) + layer_weights[(i, ATTN_Q_PROJ)] = ( + f"{prefix}.{i}.self_attn.q_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, ATTN_K_PROJ)] = ( + f"{prefix}.{i}.self_attn.k_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, ATTN_V_PROJ)] = ( + f"{prefix}.{i}.self_attn.v_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, ATTN_O_PROJ)] = ( + f"{prefix}.{i}.self_attn.o_proj", + layer.self_attn.o_proj, + ) # TODO(travis): requires implementing this for block sparse MoE # layer_weights[(i, MOE_W1)] = (f"{prefix}.{i}.moe.w1", layer.moe.w1) # layer_weights[(i, MOE_W2)] = (f"{prefix}.{i}.moe.w2", layer.moe.w2) # layer_weights[(i, MOE_W3)] = (f"{prefix}.{i}.moe.w3", layer.moe.w3) - + layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head) return layer_weights - + @property def adapter_layers(self) -> List[str]: return ADAPTER_LAYERS - + def get_num_layers_for_type(self, layer_type: str) -> int: return 1 if layer_type == LM_HEAD else len(self.model.model.layers) - + def is_row_parallel(self, layer_type: str) -> bool: return layer_type in ROW_PARALLEL diff --git a/server/lorax_server/models/flash_phi.py b/server/lorax_server/models/flash_phi.py index 4284796f0..eeba17e3e 100644 --- a/server/lorax_server/models/flash_phi.py +++ b/server/lorax_server/models/flash_phi.py @@ -67,27 +67,32 @@ def __init__( filenames = weight_files(model_id, revision=revision, extension=".safetensors") - # if adapter_id passed in as part of model instantiation, then we merge + # if adapter_id passed in as part of model instantiation, then we merge # the adapter weights with the model weights. 
This also disables dynamic # adapter loading, since the model is now itself initialized with an adapter. merged_weight_filenames = None self.dynamic_adapter_loading_enabled = True self.adapter_id = BASE_MODEL_ADAPTER_ID if len(adapter_id) > 0: - logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.") + logger.info( + f"Merging adapter weights from adapter_id {adapter_id} into model weights." + ) # Need to pass the adapter source here merged_weight_filenames = create_merged_weight_files( - adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source + adapter_id, + model_id, + model_weight_filenames=filenames, + adapter_source=adapter_source, ) self.dynamic_adapter_loading_enabled = False self.adapter_id = adapter_id weights = Weights( - filenames, - device, - dtype, - process_group=self.process_group, - merged_weight_filenames=merged_weight_filenames + filenames, + device, + dtype, + process_group=self.process_group, + merged_weight_filenames=merged_weight_filenames, ) if config.quantize == "gptq": @@ -109,36 +114,44 @@ def __init__( rank=rank, world_size=world_size, ) - + @property def supports_adapter_loading(self) -> bool: return True - + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: layer_weights = {} prefix = "transformer.h" for i, layer in enumerate(self.model.transformer.h): - layer_weights[(i, ATTN_WQKV)] = (f"{prefix}.{i}.mixer.Wqkv", layer.mixer.Wqkv) - layer_weights[(i, ATTN_OUT_PROJ)] = (f"{prefix}.{i}.mixer.out_proj", layer.mixer.out_proj) + layer_weights[(i, ATTN_WQKV)] = ( + f"{prefix}.{i}.mixer.Wqkv", + layer.mixer.Wqkv, + ) + layer_weights[(i, ATTN_OUT_PROJ)] = ( + f"{prefix}.{i}.mixer.out_proj", + layer.mixer.out_proj, + ) layer_weights[(i, MLP_FC1)] = (f"{prefix}.{i}.mlp.fc1", layer.mlp.fc1) layer_weights[(i, MLP_FC2)] = (f"{prefix}.{i}.mlp.fc2", layer.mlp.fc2) - + layer_weights[(0, LM_HEAD)] = ("lm_head.linear", self.model.lm_head.linear) return layer_weights - + @property def adapter_layers(self) -> List[str]: return ADAPTER_LAYERS - + def get_num_layers_for_type(self, layer_type: str) -> int: return 1 if layer_type == LM_HEAD else len(self.model.transformer.h) - + def is_row_parallel(self, layer_type: str) -> bool: return layer_type in ROW_PARALLEL - - def split_lora_b_qkv(self, t: torch.Tensor, head_size: int, num_heads: int, num_key_value_heads: int) -> torch.Tensor: + + def split_lora_b_qkv( + self, t: torch.Tensor, head_size: int, num_heads: int, num_key_value_heads: int + ) -> torch.Tensor: # Because we're splitting on the hidden size dimension, we need to # account for the separate q, k, and v matrices. 
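        # Illustration (hypothetical shapes): for a rank-8 adapter with head_size=64,
        # num_heads=32 and num_key_value_heads=32, t (the LoRA B matrix for the fused
        # Wqkv projection) has shape [8, 3 * 64 * 32] = [8, 6144]; the split below
        # yields three [8, 2048] chunks (q, k, v), each sharded column-wise across
        # the tensor-parallel group and re-concatenated on dim=1.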
chunks = t.split( @@ -151,11 +164,10 @@ def split_lora_b_qkv(self, t: torch.Tensor, head_size: int, num_heads: int, num_ ) assert len(chunks) == 3 chunks = [ - shard_on_dim(w, dim=1, process_group=self.process_group) - for w in chunks + shard_on_dim(w, dim=1, process_group=self.process_group) for w in chunks ] return torch.cat(chunks, dim=1) - + def shard_lora_weights( self, weights_a: List[torch.Tensor], diff --git a/server/lorax_server/models/flash_qwen.py b/server/lorax_server/models/flash_qwen.py index f87e10f31..471bb7014 100644 --- a/server/lorax_server/models/flash_qwen.py +++ b/server/lorax_server/models/flash_qwen.py @@ -68,27 +68,32 @@ def __init__( filenames = weight_files(model_id, revision=revision, extension=".safetensors") - # if adapter_id passed in as part of model instantiation, then we merge + # if adapter_id passed in as part of model instantiation, then we merge # the adapter weights with the model weights. This also disables dynamic # adapter loading, since the model is now itself initialized with an adapter. merged_weight_filenames = None self.dynamic_adapter_loading_enabled = True self.adapter_id = BASE_MODEL_ADAPTER_ID if len(adapter_id) > 0: - logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.") + logger.info( + f"Merging adapter weights from adapter_id {adapter_id} into model weights." + ) # Need to pass the adapter source here merged_weight_filenames = create_merged_weight_files( - adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source + adapter_id, + model_id, + model_weight_filenames=filenames, + adapter_source=adapter_source, ) self.dynamic_adapter_loading_enabled = False self.adapter_id = adapter_id weights = Weights( - filenames, - device, - dtype, - process_group=self.process_group, - merged_weight_filenames=merged_weight_filenames + filenames, + device, + dtype, + process_group=self.process_group, + merged_weight_filenames=merged_weight_filenames, ) if config.quantize == "gptq": @@ -110,47 +115,61 @@ def __init__( rank=rank, world_size=world_size, ) - + @property def supports_adapter_loading(self) -> bool: return True - + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: layer_weights = {} prefix = "transformer.h" for i, layer in enumerate(self.model.transformer.h): - layer_weights[(i, ATTN_C_ATTN)] = (f"{prefix}.{i}.attn.c_attn", layer.attn.c_attn) - layer_weights[(i, ATTN_C_PROJ)] = (f"{prefix}.{i}.attn.c_proj", layer.attn.c_proj) + layer_weights[(i, ATTN_C_ATTN)] = ( + f"{prefix}.{i}.attn.c_attn", + layer.attn.c_attn, + ) + layer_weights[(i, ATTN_C_PROJ)] = ( + f"{prefix}.{i}.attn.c_proj", + layer.attn.c_proj, + ) + + layer_weights[(i, MLP_W1)] = ( + f"{prefix}.{i}.mlp.w1", + layer.mlp.gate_up_proj, + ) + layer_weights[(i, MLP_W2)] = ( + f"{prefix}.{i}.mlp.w2", + layer.mlp.gate_up_proj, + ) + layer_weights[(i, MLP_C_PROJ)] = ( + f"{prefix}.{i}.mlp.c_proj", + layer.mlp.c_proj, + ) - layer_weights[(i, MLP_W1)] = (f"{prefix}.{i}.mlp.w1", layer.mlp.gate_up_proj) - layer_weights[(i, MLP_W2)] = (f"{prefix}.{i}.mlp.w2", layer.mlp.gate_up_proj) - layer_weights[(i, MLP_C_PROJ)] = (f"{prefix}.{i}.mlp.c_proj", layer.mlp.c_proj) - layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head) return layer_weights - + @property def adapter_layers(self) -> List[str]: return ADAPTER_LAYERS - + def get_num_layers_for_type(self, layer_type: str) -> int: return 1 if layer_type == LM_HEAD else len(self.model.transformer.h) - + def is_row_parallel(self, layer_type: str) -> bool: 
return layer_type in ROW_PARALLEL - + def split_lora_b_qkv(self, t: torch.Tensor, projection_size: int) -> torch.Tensor: # Because we're splitting on the hidden size dimension, we need to # account for the separate q, k, and v matrices. chunks = torch.split(t, projection_size, dim=1) assert len(chunks) == 3 chunks = [ - shard_on_dim(w, dim=1, process_group=self.process_group) - for w in chunks + shard_on_dim(w, dim=1, process_group=self.process_group) for w in chunks ] return torch.cat(chunks, dim=1) - + def shard_lora_weights( self, weights_a: List[torch.Tensor], @@ -169,11 +188,10 @@ def shard_lora_weights( # [r, hidden_size] # Because we're splitting on the hidden size dimension, we need to # account for the separate q, k, and v matrices. - projection_size = (self.config.hidden_size // self.config.num_attention_heads) * self.config.num_attention_heads - weights_b = [ - self.split_lora_b_qkv(w, projection_size) - for w in weights_b - ] + projection_size = ( + self.config.hidden_size // self.config.num_attention_heads + ) * self.config.num_attention_heads + weights_b = [self.split_lora_b_qkv(w, projection_size) for w in weights_b] return weights_a, weights_b else: diff --git a/server/lorax_server/models/flash_santacoder.py b/server/lorax_server/models/flash_santacoder.py index 4a136683f..5f046654a 100644 --- a/server/lorax_server/models/flash_santacoder.py +++ b/server/lorax_server/models/flash_santacoder.py @@ -4,10 +4,7 @@ from opentelemetry import trace from transformers import AutoTokenizer, AutoConfig from typing import Optional, List -import json -import os -from huggingface_hub import hf_hub_download from lorax_server.models import FlashCausalLM from lorax_server.models.custom_modeling.flash_santacoder_modeling import ( FlashSantacoderForCausalLM, diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py index 1b3f1e543..21c6bc591 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ -8,6 +8,7 @@ from lorax_server.models.types import Batch, GeneratedText from lorax_server.pb.generate_pb2 import InfoResponse from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID +from lorax_server.utils.globals import get_speculation_num B = TypeVar("B", bound=Batch) @@ -23,6 +24,7 @@ def __init__( rank: int = 0, world_size: int = 1, sliding_window: Optional[int] = None, + speculation: Optional[int] = None, ): self.model = model.eval() self.tokenizer = tokenizer @@ -34,6 +36,10 @@ def __init__( self.world_size = world_size self.sliding_window = sliding_window + if speculation is None: + speculation = get_speculation_num() + self.speculation = speculation + self.has_position_ids = ( inspect.signature(model.forward).parameters.get("position_ids", None) is not None @@ -51,6 +57,7 @@ def info(self) -> InfoResponse: dtype=str(self.dtype), device_type=self.device.type, window_size=self.sliding_window, + speculation=self.speculation, ) @property @@ -102,7 +109,7 @@ def check_initialized(self): raise RuntimeError( f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}" ) - + def load_adapter(self, adapter_id, adapter_source, adapter_index): if adapter_id == BASE_MODEL_ADAPTER_ID: return diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index 5fdae98db..3da9bea01 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -17,7 +17,15 @@ from lorax_server.models import Model, get_model from lorax_server.pb import generate_pb2_grpc, generate_pb2 
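
A minimal sketch of the speculation plumbing added to Model above (the helper name
resolve_speculation is hypothetical; get_speculation_num and the InfoResponse
speculation field come from this patch):

    from typing import Optional

    from lorax_server.utils.globals import get_speculation_num

    def resolve_speculation(speculation: Optional[int] = None) -> int:
        # Mirrors Model.__init__: an explicit per-model value wins; otherwise the
        # process-wide budget is used (the default of 0 reserves no speculative slots).
        return get_speculation_num() if speculation is None else speculation
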
from lorax_server.tracing import UDSOpenTelemetryAioServerInterceptor -from lorax_server.utils import HUB, LOCAL, S3, PBASE, get_config_path, get_local_dir, map_pbase_model_id_to_s3 +from lorax_server.utils import ( + HUB, + LOCAL, + S3, + PBASE, + get_config_path, + get_local_dir, + map_pbase_model_id_to_s3, +) from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID @@ -116,7 +124,7 @@ async def Decode(self, request, context): generations=[generation.to_pb() for generation in generations], batch=next_batch.to_pb() if next_batch else None, ) - + async def DownloadAdapter(self, request, context): adapter_id = request.adapter_id if adapter_id == BASE_MODEL_ADAPTER_ID: @@ -147,14 +155,16 @@ async def DownloadAdapter(self, request, context): logger.exception("Error when downloading adapter") if adapter_source != LOCAL: - # delete safetensors files if there is an issue downloading or converting + # delete safetensors files if there is an issue downloading or converting # the weights to prevent cache hits by subsequent calls try: local_path = get_local_dir(adapter_id, adapter_source) shutil.rmtree(local_path) except Exception as e: - logger.warning(f"Error cleaning up safetensors files after " - f"download error: {e}\nIgnoring.") + logger.warning( + f"Error cleaning up safetensors files after " + f"download error: {e}\nIgnoring." + ) raise async def LoadAdapter(self, request, context): @@ -166,7 +176,7 @@ async def LoadAdapter(self, request, context): adapter_id = map_pbase_model_id_to_s3(adapter_id, request.api_token) adapter_source = S3 self.model.load_adapter(adapter_id, adapter_source, adapter_index) - + return generate_pb2.LoadAdapterResponse( adapter_id=adapter_id, adapter_source=request.adapter_source, @@ -182,7 +192,7 @@ async def OffloadAdapter(self, request, context): adapter_source = _adapter_source_enum_to_string(request.adapter_source) adapter_index = request.adapter_index self.model.offload_adapter(adapter_id, adapter_source, adapter_index) - + return generate_pb2.OffloadAdapterResponse( adapter_id=adapter_id, adapter_source=request.adapter_source, @@ -227,7 +237,15 @@ async def serve_inner( try: model = get_model( - model_id, adapter_id, revision, sharded, quantize, dtype, trust_remote_code, source, adapter_source + model_id, + adapter_id, + revision, + sharded, + quantize, + dtype, + trust_remote_code, + source, + adapter_source, ) except Exception: logger.exception("Error when initializing model") @@ -275,7 +293,9 @@ async def serve_inner( await server.stop(0) asyncio.run( - serve_inner(model_id, adapter_id, revision, sharded, quantize, dtype, trust_remote_code) + serve_inner( + model_id, adapter_id, revision, sharded, quantize, dtype, trust_remote_code + ) ) @@ -290,4 +310,4 @@ def _adapter_source_enum_to_string(adapter_source: int) -> str: elif adapter_source == generate_pb2.AdapterSource.PBASE: return PBASE else: - raise ValueError(f"Unknown adapter source {adapter_source}") \ No newline at end of file + raise ValueError(f"Unknown adapter source {adapter_source}") diff --git a/server/lorax_server/tracing.py b/server/lorax_server/tracing.py index 0d9dcccea..677c07122 100644 --- a/server/lorax_server/tracing.py +++ b/server/lorax_server/tracing.py @@ -55,9 +55,7 @@ def _start_span(self, handler_call_details, context, set_status_on_exception=Fal def setup_tracing(shard: int, otlp_endpoint: str): - resource = Resource.create( - attributes={"service.name": f"lorax.server-{shard}"} - ) + resource = Resource.create(attributes={"service.name": f"lorax.server-{shard}"}) 
span_exporter = OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True) span_processor = BatchSpanProcessor(span_exporter) diff --git a/server/lorax_server/utils/__init__.py b/server/lorax_server/utils/__init__.py index 910ae613f..64eec3694 100644 --- a/server/lorax_server/utils/__init__.py +++ b/server/lorax_server/utils/__init__.py @@ -1,5 +1,5 @@ from lorax_server.utils.adapter import ( - compute_delta_weight, + compute_delta_weight, create_merged_weight_files, load_module_map, ) @@ -13,7 +13,6 @@ download_weights, map_pbase_model_id_to_s3, weight_hub_files, - weight_files, EntryNotFoundError, HUB, PBASE, diff --git a/server/lorax_server/utils/adapter.py b/server/lorax_server/utils/adapter.py index a3ed66c12..1ad93c1fd 100644 --- a/server/lorax_server/utils/adapter.py +++ b/server/lorax_server/utils/adapter.py @@ -15,7 +15,7 @@ from tqdm import tqdm from filelock import FileLock -from lorax_server.utils.sources import get_model_source, get_config_path, weight_files +from lorax_server.utils.sources import get_model_source, get_config_path, weight_files BASE_MODEL_ADAPTER_ID = "__base_model__" @@ -24,13 +24,15 @@ @lru_cache(maxsize=128) def load_module_map(model_id, adapter_id, adapter_source, weight_names): # TODO(geoffrey): refactor this and merge parts of this function with - # lorax_server/utils/adapter.py::create_merged_weight_files + # lorax_server/utils/adapter.py::create_merged_weight_files source = get_model_source(adapter_source, adapter_id, extension=".safetensors") config_path = get_config_path(adapter_id, adapter_source) adapter_config = LoraConfig.from_pretrained(config_path) if adapter_config.base_model_name_or_path != model_id: expected_config = AutoConfig.from_pretrained(model_id) - model_config = AutoConfig.from_pretrained(adapter_config.base_model_name_or_path) + model_config = AutoConfig.from_pretrained( + adapter_config.base_model_name_or_path + ) if model_config.architectures == expected_config.architectures: warnings.warn( f"Adapter '{adapter_id}' was not trained on base model '{model_id}'. " @@ -38,16 +40,18 @@ def load_module_map(model_id, adapter_id, adapter_source, weight_names): ) else: # TODO(travis): revisit this when we support clasification heads which will not use CausalLM - raise ValueError(f"Adapter '{adapter_id}' is not compatible with model '{model_id}'. " - f"Architectures differ: {model_config.architectures} != {expected_config.architectures}. " - f"Use --model-id '{adapter_config.base_model_name_or_path}' instead.") + raise ValueError( + f"Adapter '{adapter_id}' is not compatible with model '{model_id}'. " + f"Architectures differ: {model_config.architectures} != {expected_config.architectures}. " + f"Use --model-id '{adapter_config.base_model_name_or_path}' instead." 
+ ) # load adapter weights from all shards (should have relatively small memory footprint) adapter_filenames = source.weight_files() adapter_weights = {} for filename in adapter_filenames: adapter_weights.update(load_file(filename)) - + # map the model weights to the relevant adapter weights (LoRA A and B matrices) adapter_weight_names = set() module_map = {} @@ -56,7 +60,7 @@ def load_module_map(model_id, adapter_id, adapter_source, weight_names): lora_b_name = f"base_model.model.{weight_name}.lora_B.weight" if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights: continue - + module_map[weight_name] = { "lora_A": (adapter_weights[lora_a_name], lora_a_name), "lora_B": (adapter_weights[lora_b_name], lora_b_name), @@ -67,16 +71,16 @@ def load_module_map(model_id, adapter_id, adapter_source, weight_names): def compute_delta_weight( - lora_A: torch.Tensor, - lora_B: torch.Tensor, - fan_in_fan_out: bool, - alpha: float, - r: float + lora_A: torch.Tensor, + lora_B: torch.Tensor, + fan_in_fan_out: bool, + alpha: float, + r: float, ) -> torch.Tensor: """Computes the delta weight for a Linear layer given A and B LoRA matrices. - + TODO: add logic for other module types beyond Linear layers. - + Reference: https://github.com/huggingface/peft/blob/v0.4.0/src/peft/tuners/lora.py#L799-L806 """ scaling = alpha / r @@ -85,9 +89,9 @@ def compute_delta_weight( def merge_adapter_weights( - model_weights: Dict[str, torch.Tensor], - adapter_weights: Dict[str, torch.Tensor], - adapter_config: LoraConfig + model_weights: Dict[str, torch.Tensor], + adapter_weights: Dict[str, torch.Tensor], + adapter_config: LoraConfig, ) -> Tuple[Dict[str, torch.Tensor], Set[str]]: """ Merges the adapter weights into the model weights. @@ -114,31 +118,40 @@ def merge_adapter_weights( matrix_type = adapter_weight_name.split(".")[-2] module_mapping[weight_name][matrix_type] = adapter_weight_name processed_adapter_weight_names.add(adapter_weight_name) - + # merge adapter weights into model weights merged_weights = {} for weight_name, adapter_weight_names in tqdm( - module_mapping.items(), desc="Merging adapter weights", total=len(module_mapping)): - + module_mapping.items(), + desc="Merging adapter weights", + total=len(module_mapping), + ): # TODO: support adapter types beyond LoRA # TODO: put this on GPU if it is available. This should greatly speedup compute_delta_weight lora_A = adapter_weights[adapter_weight_names["lora_A"]] lora_B = adapter_weights[adapter_weight_names["lora_B"]] delta_weight = compute_delta_weight( - lora_A, lora_B, adapter_config.fan_in_fan_out, adapter_config.lora_alpha, adapter_config.r) - + lora_A, + lora_B, + adapter_config.fan_in_fan_out, + adapter_config.lora_alpha, + adapter_config.r, + ) + # transpose delta weight if necessary # TODO(geoffrey): I believe this is required when using Conv1D layers (gpt2). # We can likely take this out once we've switched to using Linear layers. 
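
As a worked illustration of the merge path above (shapes and values are hypothetical;
the formula follows compute_delta_weight and the peft reference it cites):

    import torch

    r, alpha = 8, 16
    in_features, out_features = 4096, 4096

    lora_A = torch.randn(r, in_features)        # down-projection, [r, in_features]
    lora_B = torch.randn(out_features, r)       # up-projection, [out_features, r]
    base_weight = torch.randn(out_features, in_features)

    scaling = alpha / r                         # 2.0
    delta_weight = (lora_B @ lora_A) * scaling  # [out_features, in_features]

    # merge_adapter_weights adds this delta onto the matching base weight; the
    # transpose check below only applies to Conv1D-style (gpt2) weights stored
    # as [in_features, out_features].
    merged = base_weight + delta_weight
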
- if (delta_weight.shape != model_weights[weight_name].shape and - delta_weight.T.shape == model_weights[weight_name].shape): + if ( + delta_weight.shape != model_weights[weight_name].shape + and delta_weight.T.shape == model_weights[weight_name].shape + ): delta_weight = delta_weight.T merged_weights[weight_name] = model_weights[weight_name] + delta_weight return merged_weights, processed_adapter_weight_names def create_merged_weight_files( - adapter_id: str, + adapter_id: str, model_id: str, model_weight_filenames: List[Path], adapter_source: str = "hub", @@ -150,21 +163,27 @@ def create_merged_weight_files( adapter_path = get_config_path(adapter_id, adapter_source) adapter_config = LoraConfig.from_pretrained(adapter_path) if adapter_config.base_model_name_or_path != model_id: - raise ValueError(f"Adapter '{adapter_id}' is not compatible with model '{model_id}'. " - f"Use --model-id '{adapter_config.base_model_name_or_path}' instead.") - + raise ValueError( + f"Adapter '{adapter_id}' is not compatible with model '{model_id}'. " + f"Use --model-id '{adapter_config.base_model_name_or_path}' instead." + ) + # load adapter weights from all shards (should have relatively small memory footprint) adapter_weights = {} for filename in adapter_filenames: adapter_weights.update(load_file(filename)) remaining_adapter_weight_names = set(adapter_weights.keys()) - merged_weight_directory = Path(HUGGINGFACE_HUB_CACHE) / f"models--{adapter_id.replace('/', '--')}-merged" + merged_weight_directory = ( + Path(HUGGINGFACE_HUB_CACHE) / f"models--{adapter_id.replace('/', '--')}-merged" + ) # just grab the existing files if they already exist and return immediately - lock = FileLock(str(merged_weight_directory)+ ".lock") + lock = FileLock(str(merged_weight_directory) + ".lock") with lock: if merged_weight_directory.is_dir(): - logger.info(f"Merged weight directory {merged_weight_directory} exist, skipping merge computation.") + logger.info( + f"Merged weight directory {merged_weight_directory} exist, skipping merge computation." + ) return weight_files(merged_weight_directory) else: logger.info("Merged weight files do not exist, computing merge.") @@ -178,23 +197,30 @@ def create_merged_weight_files( ) model_weights = load_file(filename) merged_weights, processed_adapter_weight_names = merge_adapter_weights( - model_weights, adapter_weights, adapter_config) - - merged_adapter_filename = Path(merged_weight_directory, os.path.basename(filename)) + model_weights, adapter_weights, adapter_config + ) + + merged_adapter_filename = Path( + merged_weight_directory, os.path.basename(filename) + ) save_file(merged_weights, merged_adapter_filename) logger.debug(f"Saved merged weights into {merged_adapter_filename}") merged_weight_filenames.append(merged_adapter_filename) remaining_adapter_weight_names = remaining_adapter_weight_names.difference( - processed_adapter_weight_names) - + processed_adapter_weight_names + ) + if len(remaining_adapter_weight_names) > 0: - logger.warning("WARNING: The following lora weights were not merged into the model weights:") + logger.warning( + "WARNING: The following lora weights were not merged into the model weights:" + ) for lora_name in remaining_adapter_weight_names: logger.warning("\t" + lora_name) logger.info( - f"Finished merging adapter weights. Merged weight files saved to: {merged_weight_directory}") + f"Finished merging adapter weights. 
Merged weight files saved to: {merged_weight_directory}" + ) return merged_weight_filenames @@ -203,10 +229,12 @@ def main(): adapter_config = LoraConfig.from_pretrained(adapter_id) model_id = adapter_config.base_model_name_or_path model_weight_filenames = weight_files(model_id, extension=".safetensors") - - merged_adapter_filenames = create_merged_weight_files(adapter_id, model_id, model_weight_filenames) + + merged_adapter_filenames = create_merged_weight_files( + adapter_id, model_id, model_weight_filenames + ) print(merged_adapter_filenames) -if __name__ == '__main__': - main() \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/server/lorax_server/utils/awq/awq.py b/server/lorax_server/utils/awq/awq.py index 2252734b5..78b070b58 100644 --- a/server/lorax_server/utils/awq/awq.py +++ b/server/lorax_server/utils/awq/awq.py @@ -1,10 +1,10 @@ # Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py -import math import torch import torch.nn as nn import awq_inference_engine # with CUDA kernels + class AWQLinear(nn.Module): def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias): super().__init__() @@ -20,8 +20,12 @@ def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias): self.w_bit = w_bit self.group_size = group_size if group_size != -1 else self.in_features - assert self.in_features % self.group_size == 0, "in_features must be divisible by group_size" - assert self.out_features % (32 // self.w_bit) == 0, "out_features must be divisible by 32 // w_bit" + assert ( + self.in_features % self.group_size == 0 + ), "in_features must be divisible by group_size" + assert ( + self.out_features % (32 // self.w_bit) == 0 + ), "out_features must be divisible by 32 // w_bit" self.qweight = qweight self.qzeros = qzeros @@ -30,21 +34,22 @@ def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias): @torch.no_grad() def forward(self, x): - out_shape = x.shape[:-1] + (self.out_features, ) + out_shape = x.shape[:-1] + (self.out_features,) input_dtype = x.dtype if input_dtype != torch.float16: x = x.half() - - out = awq_inference_engine.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8) - + + out = awq_inference_engine.gemm_forward_cuda( + x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8 + ) + if input_dtype != torch.float16: out = out.to(dtype=input_dtype) - + out = out + self.bias if self.bias is not None else out return out.reshape(out_shape) - @property def weight(self) -> torch.Tensor: - return self.qweight \ No newline at end of file + return self.qweight diff --git a/server/lorax_server/utils/convert.py b/server/lorax_server/utils/convert.py index f911e5b5c..0b62f5208 100644 --- a/server/lorax_server/utils/convert.py +++ b/server/lorax_server/utils/convert.py @@ -111,4 +111,4 @@ def convert_files(pt_files: List[Path], sf_files: List[Path], discard_names: Lis start = datetime.datetime.now() convert_file(pt_file, sf_file, discard_names) elapsed = datetime.datetime.now() - start - logger.info(f"Convert: [{i + 1}/{N}] -- Took: {elapsed}") \ No newline at end of file + logger.info(f"Convert: [{i + 1}/{N}] -- Took: {elapsed}") diff --git a/server/lorax_server/utils/globals.py b/server/lorax_server/utils/globals.py new file mode 100644 index 000000000..a9284aa2f --- /dev/null +++ b/server/lorax_server/utils/globals.py @@ -0,0 +1,11 @@ +SPECULATION_NUM = 0 + + +def get_speculation_num() -> int: + global SPECULATION_NUM 
+ return SPECULATION_NUM + + +def set_speculation_num(speculate: int): + global SPECULATION_NUM + SPECULATION_NUM = speculate diff --git a/server/lorax_server/utils/gptq/exllamav2.py b/server/lorax_server/utils/gptq/exllamav2.py index 19d0cea0b..55887bafe 100644 --- a/server/lorax_server/utils/gptq/exllamav2.py +++ b/server/lorax_server/utils/gptq/exllamav2.py @@ -10,40 +10,44 @@ try: from exllamav2_kernels import make_q_matrix, gemm_half_q_half except ImportError: - logger.error('exllamav2_kernels not installed.') + logger.error("exllamav2_kernels not installed.") raise # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension none_tensor = torch.empty((1, 1), device="meta") + def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda): """Matrix multiplication, returns x @ q4""" output_shape = x.shape[:-1] + (q4_width,) x = x.view(-1, x.shape[-1]) - output = torch.empty((x.shape[0], q4_width), dtype = torch.half, device = x.device) + output = torch.empty((x.shape[0], q4_width), dtype=torch.half, device=x.device) gemm_half_q_half(x, q_handle, output, force_cuda) return output.view(output_shape) + def ext_make_q_matrix(w: dict, temp_dq, key: str = None): """ - Create Q matrix + Create Q matrix """ # EXL2 - # won't work as the moment because the tensors are not the same. + # won't work as the moment because the tensors are not the same. if "q_weight" in w: w["q_scale_max"] /= 256 w["q_perm"] = w["q_perm"].short() w["q_invperm"] = w["q_invperm"].short() - return make_q_matrix(w["q_weight"], - w["q_perm"], - w["q_invperm"], - w["q_scale"], - w["q_scale_max"], - w["q_groups"], - none_tensor, - none_tensor, - none_tensor, - temp_dq) + return make_q_matrix( + w["q_weight"], + w["q_perm"], + w["q_invperm"], + w["q_scale"], + w["q_scale_max"], + w["q_groups"], + none_tensor, + none_tensor, + none_tensor, + temp_dq, + ) # GPTQ elif "qweight" in w: if w["scales"].dtype == torch.float: @@ -51,36 +55,46 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None): # GPTQ with g_idx (act_order) if w.get("g_idx", None) is not None and not (w["g_idx"] == 0).all().item(): - w["q_perm"] = torch.empty((w["qweight"].shape[0] * 8,), dtype = torch.short, device = w["qweight"].device) + w["q_perm"] = torch.empty( + (w["qweight"].shape[0] * 8,), + dtype=torch.short, + device=w["qweight"].device, + ) w["q_invperm"] = torch.empty_like(w["q_perm"]) # make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx. 
- return make_q_matrix(w["qweight"], - w["q_perm"], - w["q_invperm"], - none_tensor, - none_tensor, - none_tensor, - w["qzeros"], - w["scales"], - w["g_idx"].cpu(), - temp_dq) + return make_q_matrix( + w["qweight"], + w["q_perm"], + w["q_invperm"], + none_tensor, + none_tensor, + none_tensor, + w["qzeros"], + w["scales"], + w["g_idx"].cpu(), + temp_dq, + ) # GPTQ without g_idx else: - return make_q_matrix(w["qweight"], - none_tensor, - none_tensor, - none_tensor, - none_tensor, - none_tensor, - w["qzeros"], - w["scales"], - none_tensor, - temp_dq) + return make_q_matrix( + w["qweight"], + none_tensor, + none_tensor, + none_tensor, + none_tensor, + none_tensor, + w["qzeros"], + w["scales"], + none_tensor, + temp_dq, + ) + DEVICE = None FIXED_BYTES = 0 LAYERS = [] + def set_device(device): global DEVICE DEVICE = device @@ -103,11 +117,12 @@ def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): super().__init__() if bits != 4: raise ValueError( - f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization.") + f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization." + ) self.q_handle = None self.q_tensors = None self.bits = bits - self.maxq = 2 ** self.bits - 1 + self.maxq = 2**self.bits - 1 self.infeatures = qweight.shape[0] // self.bits * 32 self.outfeatures = qweight.shape[1] + qweight.shape[1] % 32 @@ -127,39 +142,36 @@ def post_init(self, temp_dq): assert self.qweight.device.type == "cuda" assert self.qweight.device.index is not None self.q_tensors = { - "qweight":self.qweight, - "qzeros":self.qzeros, - "scales":self.scales, - "g_idx":self.g_idx + "qweight": self.qweight, + "qzeros": self.qzeros, + "scales": self.scales, + "g_idx": self.g_idx, } temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size()) - self.q_handle = ext_make_q_matrix( - self.q_tensors, temp_dq - ) - - def forward(self, x, force_cuda = False): + self.q_handle = ext_make_q_matrix(self.q_tensors, temp_dq) + + def forward(self, x, force_cuda=False): output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda) if self.bias is not None: output.add_(self.bias) return output - + def temp_dq_size(self): return self.infeatures * self.outfeatures * 2 + 128 - + def temp_fwd_size(self, max_input_len, max_batch_size): return self.outfeatures * max_input_len * max_batch_size * 4 + 128 - + def scratch_spacing(self, max_input_len=8192, max_batch_size=32): return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size) @property def weight(self) -> torch.Tensor: return self.qweight - - -class ExLlamaV2DeviceTensors: + +class ExLlamaV2DeviceTensors: device_idx: int scratch_bytes: int scratch_idx: int @@ -168,15 +180,17 @@ class ExLlamaV2DeviceTensors: def __init__(self, device, scratch_bytes): self.device = device self.scratch_bytes = scratch_bytes - + def prepare(self): - self.scratch = torch.empty((self.scratch_bytes // 2,), dtype = torch.half, device = self.device) + self.scratch = torch.empty( + (self.scratch_bytes // 2,), dtype=torch.half, device=self.device + ) def get_scratch_slice(self, size_bytes): - - if self.scratch is None: self.prepare() + if self.scratch is None: + self.prepare() size_bytes = ((size_bytes + 127) // 128) * 128 size_half = size_bytes // 2 scratch_slice = self.scratch.narrow(0, 0, size_half) - return scratch_slice \ No newline at end of file + return scratch_slice diff --git a/server/lorax_server/utils/gptq/quant_linear.py 
b/server/lorax_server/utils/gptq/quant_linear.py index 7c44f3e3c..48e7ab888 100644 --- a/server/lorax_server/utils/gptq/quant_linear.py +++ b/server/lorax_server/utils/gptq/quant_linear.py @@ -2,7 +2,7 @@ import numpy as np import torch import torch.nn as nn -from torch.cuda.amp import custom_bwd, custom_fwd +from torch.cuda.amp import custom_fwd try: import triton @@ -357,7 +357,7 @@ def forward(self, x): ) out = out + self.bias if self.bias is not None else out return out.reshape(out_shape) - + @property def weight(self) -> torch.Tensor: return self.qweight diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index 4c883c214..15422fd30 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -16,7 +16,7 @@ HAS_BITS_AND_BYTES = False HAS_AWQ = True -try: +try: from lorax_server.utils.awq.awq import AWQLinear except ImportError: HAS_AWQ = False @@ -24,7 +24,12 @@ from accelerate import init_empty_weights from lorax_server.utils.gptq.quant_linear import QuantLinear -from lorax_server.utils.sgmv import add_lora_sgmv_cutlass, lora_a_sgmv_cutlass, lora_b_sgmv_cutlass, has_sgmv, orient_for_rank +from lorax_server.utils.sgmv import ( + lora_a_sgmv_cutlass, + lora_b_sgmv_cutlass, + has_sgmv, + orient_for_rank, +) HAS_EXLLAMA = True if os.getenv("DISABLE_EXLLAMA") == "True": @@ -175,13 +180,17 @@ def forward(self, x: torch.Tensor): self.weight.data = self.state.CxB return out + class Linear4bit(nn.Module): def __init__(self, weight, bias, quant_type): super().__init__() # Initialize weight with 4-bit quantization self.weight = Params4bit( - weight.data, requires_grad=False, compress_statistics=True, quant_type=quant_type + weight.data, + requires_grad=False, + compress_statistics=True, + quant_type=quant_type, ) self.weight.cuda(weight.device) @@ -196,7 +205,9 @@ def forward(self, x: torch.Tensor): # Check if quantization state is initialized if getattr(self.weight, "quant_state", None) is None: - print("FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.") + print( + "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first." + ) # Convert input to compute_dtype if specified inp_dtype = x.dtype @@ -207,13 +218,16 @@ def forward(self, x: torch.Tensor): bias = None if self.bias is None else self.bias.to(self.compute_dtype) # Perform 4-bit matrix multiplication - out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state) + out = bnb.matmul_4bit( + x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state + ) # Convert output back to the input dtype out = out.to(inp_dtype) return out + def get_linear(weight, bias, quantize, fan_in_fan_out=False): # https://huggingface.co/docs/peft/package_reference/tuners#peft.LoraConfig.fan_in_fan_out # Set to True if replacing a Conv1D layer with a Linear layer @@ -248,11 +262,13 @@ def get_linear(weight, bias, quantize, fan_in_fan_out=False): qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight except Exception: raise NotImplementedError( - f"The passed weight is not `gptq` compatible, loader needs to be updated." + "The passed weight is not `gptq` compatible, loader needs to be updated." 
) if use_exllama: - linear = exllamav2QuantLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize) + linear = exllamav2QuantLinear( + qweight, qzeros, scales, g_idx, bias, bits, groupsize + ) else: linear = QuantLinear( qweight, @@ -267,10 +283,15 @@ def get_linear(weight, bias, quantize, fan_in_fan_out=False): try: qweight, qzeros, scales, _, bits, groupsize, _ = weight except Exception: - raise NotImplementedError( - f"The passed weight is not compatible with `awq`" - ) - linear = AWQLinear(w_bit=bits, group_size=groupsize, qweight=qweight, qzeros=qzeros, scales=scales, bias=bias is not None) + raise NotImplementedError("The passed weight is not compatible with `awq`") + linear = AWQLinear( + w_bit=bits, + group_size=groupsize, + qweight=qweight, + qzeros=qzeros, + scales=scales, + bias=bias is not None, + ) else: raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.") return linear @@ -356,23 +377,28 @@ def load_qkv(cls, config, prefix: str, weights, bias: bool, fan_in_fan_out=False raise NotImplementedError("packed_qkv only implemented for baichuan") else: bias = None - linear = get_linear(weight, bias, config.quantize, fan_in_fan_out=fan_in_fan_out) + linear = get_linear( + weight, bias, config.quantize, fan_in_fan_out=fan_in_fan_out + ) return cls(linear) @classmethod - def load(cls, config, prefix: str, weights, bias: bool, fan_in_fan_out: bool = False): + def load( + cls, config, prefix: str, weights, bias: bool, fan_in_fan_out: bool = False + ): return cls.load_multi( - config, [prefix], weights, bias, dim=0, fan_in_fan_out=fan_in_fan_out) + config, [prefix], weights, bias, dim=0, fan_in_fan_out=fan_in_fan_out + ) @classmethod def load_multi( - cls, - config, - prefixes: List[Union[str, Tuple]], - weights, - bias: bool, - dim: int, - fan_in_fan_out=False + cls, + config, + prefixes: List[Union[str, Tuple]], + weights, + bias: bool, + dim: int, + fan_in_fan_out=False, ): weight = weights.get_multi_weights_col( prefixes, quantize=config.quantize, dim=dim @@ -383,9 +409,11 @@ def load_multi( bias = torch.cat(b, dim=dim) else: bias = None - linear = get_linear(weight, bias, config.quantize, fan_in_fan_out=fan_in_fan_out) + linear = get_linear( + weight, bias, config.quantize, fan_in_fan_out=fan_in_fan_out + ) return cls(linear) - + class TensorParallelAdapterLinear(nn.Module): def __init__(self, base_layer, layer_id, process_group): @@ -432,17 +460,23 @@ def forward_layer_type( rank_segments.segment_ends, self.layer_id, ) - + result[:, start_idx:end_idx] += proj else: for adapter_index in adapter_data.meta.adapter_set: if data is not None and data.has_adapter(adapter_index): - adapter_mask = (adapter_data.meta.adapter_indices == adapter_index).to(input.dtype).view(-1, 1) - layer_result = self.forward_lora(input, data, adapter_index, adapter_mask) + adapter_mask = ( + (adapter_data.meta.adapter_indices == adapter_index) + .to(input.dtype) + .view(-1, 1) + ) + layer_result = self.forward_lora( + input, data, adapter_index, adapter_mask + ) result[:, start_idx:end_idx] += layer_result return result - + def forward_lora( self, input: torch.Tensor, @@ -456,14 +490,14 @@ def forward_lora( if self.process_group.size() > 1: a_out = self.collect_lora_a(a_out) - + lora_b = data.lora_b[adapter_index][self.layer_id, :, :] result = (a_out @ lora_b) * adapter_mask return result def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor: raise NotImplementedError("Implemented in subclasses") - + class TensorParallelMultiAdapterLinear(TensorParallelAdapterLinear): def 
__init__(self, base_layer, layer_id, layer_names, sizes, process_group): @@ -476,8 +510,10 @@ def load(cls, base_layer, layer_id, layer_names, sizes, process_group): return TensorParallelMultiAdapterLinear( base_layer, layer_id, layer_names, sizes, process_group ) - - def forward(self, input: torch.Tensor, adapter_data: AdapterBatchData) -> torch.Tensor: + + def forward( + self, input: torch.Tensor, adapter_data: AdapterBatchData + ) -> torch.Tensor: result = self.base_layer(input) offset = 0 @@ -486,8 +522,10 @@ def forward(self, input: torch.Tensor, adapter_data: AdapterBatchData) -> torch. offset += self.sizes[i] end_idx = offset // self.process_group.size() - - result = self.forward_layer_type(result, input, adapter_data, layer_name, start_idx, end_idx) + + result = self.forward_layer_type( + result, input, adapter_data, layer_name, start_idx, end_idx + ) return result @@ -499,7 +537,9 @@ def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor: # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same # rank, compute `a_out` on each, and then slice them into the buffer as shown here: # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609 - gathered_tensors = [torch.empty_like(a_out) for _ in range(self.process_group.size())] + gathered_tensors = [ + torch.empty_like(a_out) for _ in range(self.process_group.size()) + ] torch.distributed.all_gather(gathered_tensors, a_out) return torch.cat(gathered_tensors, dim=1) @@ -511,19 +551,25 @@ def __init__(self, base_layer, layer_id, layer_name, process_group): @classmethod def load(cls, base_layer, layer_id, layer_name, process_group): - return TensorParallelAdapterRowLinear(base_layer, layer_id, layer_name, process_group) - - def forward(self, input: torch.Tensor, adapter_data: AdapterBatchData) -> torch.Tensor: + return TensorParallelAdapterRowLinear( + base_layer, layer_id, layer_name, process_group + ) + + def forward( + self, input: torch.Tensor, adapter_data: AdapterBatchData + ) -> torch.Tensor: result = self.base_layer(input) - + # Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285 stride = result.shape[-1] // self.process_group.size() start_idx = self.process_group.rank() * stride end_idx = (self.process_group.rank() + 1) * stride - self.forward_layer_type(result, input, adapter_data, self.layer_name, start_idx, end_idx) + self.forward_layer_type( + result, input, adapter_data, self.layer_name, start_idx, end_idx + ) return result - + def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor: # Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise. # We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks. 
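# A minimal single-process sketch of the identity that the all-reduce in
# collect_lora_a above relies on: when the inner dimension of X @ A is sharded
# across ranks, summing the per-rank partial products reproduces the full X @ A,
# after which the LoRA B matrix can be applied. This is illustrative only and not
# part of the patched module; the shapes and world_size of 2 below are assumptions
# chosen for the example.
import torch

torch.manual_seed(0)
hidden, lora_rank, world_size = 16, 8, 2
x = torch.randn(4, hidden)
lora_a = torch.randn(hidden, lora_rank)
lora_b = torch.randn(lora_rank, hidden)

# Shard the inner (hidden) dimension: each "rank" holds a slice of x's columns
# and the matching rows of lora_a.
x_shards = torch.chunk(x, world_size, dim=1)
a_shards = torch.chunk(lora_a, world_size, dim=0)

# Per-rank partial products; the sum here plays the role of
# torch.distributed.all_reduce across the process group.
a_out = sum(xs @ a_s for xs, a_s in zip(x_shards, a_shards))
result = a_out @ lora_b

assert torch.allclose(result, x @ lora_a @ lora_b, atol=1e-4)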
@@ -534,7 +580,7 @@ def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor: # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609 torch.distributed.all_reduce(a_out, group=self.process_group) return a_out - + class TensorParallelRowLinear(SuperLayer): def __init__(self, linear, process_group, all_reduce: bool = True): @@ -544,11 +590,11 @@ def __init__(self, linear, process_group, all_reduce: bool = True): @classmethod def load( - cls, - config, - prefix: str, - weights, - bias: bool, + cls, + config, + prefix: str, + weights, + bias: bool, fan_in_fan_out: bool = False, all_reduce: bool = True, ): @@ -829,14 +875,14 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): def __init__( self, - dim, - max_position_embeddings=2048, - base=10000, - factor=1, - original_max_position_embeddings=2048, - extrapolation_factor=1, - attn_factor=1, - beta_fast=32, + dim, + max_position_embeddings=2048, + base=10000, + factor=1, + original_max_position_embeddings=2048, + extrapolation_factor=1, + attn_factor=1, + beta_fast=32, beta_slow=1, finetuned=True, device=None, @@ -862,35 +908,60 @@ def _update_cos_sin_cache(self, dtype, device, seqlen): ): self._seq_len_cached = seqlen - t = torch.arange(self._seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = torch.arange( + self._seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) freqs = torch.outer(t, self.inv_freq.to(device=t.device)) self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype) self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype) - + def yarn(self, device): - pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + pos_freqs = self.base ** ( + torch.arange(0, self.dim, 2).float().to(device) / self.dim + ) inv_freq_extrapolation = 1.0 / pos_freqs inv_freq_interpolation = 1.0 / (self.scaling_factor * pos_freqs) - low, high = find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings) - inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation - inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + low, high = find_correction_range( + self.beta_fast, + self.beta_slow, + self.dim, + self.base, + self.original_max_position_embeddings, + ) + inv_freq_mask = ( + 1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device) + ) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_mask) + + inv_freq_extrapolation * inv_freq_mask + ) self.inv_freq = inv_freq - self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation + self.mscale = float( + get_mscale(self.scaling_factor) * self.attn_factor + ) # Get n-d magnitude scaling corrected for interpolation # Inverse dim formula to find dim based on number of rotations - def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048): - return (dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base)) + def find_correction_dim( + num_rotations, dim, base=10000, max_position_embeddings=2048 + ): + return ( + dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi)) + ) / (2 * math.log(base)) # Find dim range bounds based on rotations - def find_correction_range(low_rot, 
high_rot, dim, base=10000, max_position_embeddings=2048): - low = math.floor(find_correction_dim( - low_rot, dim, base, max_position_embeddings)) - high = math.ceil(find_correction_dim( - high_rot, dim, base, max_position_embeddings)) - return max(low, 0), min(high, dim-1) # Clamp values just in case + def find_correction_range( + low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 + ): + low = math.floor( + find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) # Clamp values just in case def linear_ramp_mask(min, max, dim): if min == max: @@ -906,4 +977,4 @@ def get_mscale(scale=1): return 0.1 * math.log(scale) + 1.0 except ImportError: - pass \ No newline at end of file + pass diff --git a/server/lorax_server/utils/lora.py b/server/lorax_server/utils/lora.py index 05b2e1bc3..fa29e84b4 100644 --- a/server/lorax_server/utils/lora.py +++ b/server/lorax_server/utils/lora.py @@ -7,7 +7,6 @@ from torch.distributed import ProcessGroup from lorax_server.utils.sgmv import MIN_SGMV_RANK, orient_for_rank -from lorax_server.utils.weights import shard_on_dim # Constants @@ -40,16 +39,16 @@ class AdapterWeightData: lora_b: Dict[int, torch.Tensor] adapter_index_configs: Dict[int, LoraConfig] rank_data: Dict[int, RankSegments] - + def has_adapter(self, adapter_index: int) -> bool: return adapter_index in self.adapter_index_configs - + def can_vectorize(self, pg: ProcessGroup) -> bool: return all( rank_data.rank // pg.size() >= MIN_SGMV_RANK for rank_data in self.rank_data.values() ) - + @dataclass class AdapterBatchMetadata: @@ -65,7 +64,9 @@ class AdapterBatchData: data: Dict[str, AdapterWeightData] @staticmethod - def from_meta(meta: AdapterBatchMetadata, weights: Dict[str, "BatchedLoraWeights"]) -> "AdapterBatchData": + def from_meta( + meta: AdapterBatchMetadata, weights: Dict[str, "BatchedLoraWeights"] + ) -> "AdapterBatchData": data = {} for k, v in weights.items(): if v.is_empty(): @@ -84,10 +85,7 @@ def __init__( adapter_config: LoraConfig, ): # [num_layers, hidden_size, r] - weights_a = [ - orient_for_rank(w, adapter_config.r) - for w in weights_a - ] + weights_a = [orient_for_rank(w, adapter_config.r) for w in weights_a] self.weights_a = torch.stack(weights_a) # [num_layers, r, hidden_size] @@ -136,9 +134,10 @@ def get_data(self, meta: AdapterBatchMetadata) -> AdapterWeightData: [ ( self.lora_weights[idx].weights_a.data_ptr() - if idx in self.lora_weights + if idx in self.lora_weights else EMPTY_TENSOR.data_ptr() - ) for idx in segment_indices + ) + for idx in segment_indices ], dtype=torch.int64, device=device, @@ -151,10 +150,11 @@ def get_data(self, meta: AdapterBatchMetadata) -> AdapterWeightData: lora_b_ptr = torch.tensor( [ ( - self.lora_weights[idx].weights_b.data_ptr() - if idx in self.lora_weights + self.lora_weights[idx].weights_b.data_ptr() + if idx in self.lora_weights else EMPTY_TENSOR.data_ptr() - ) for idx in segment_indices + ) + for idx in segment_indices ], dtype=torch.int64, device=device, @@ -170,7 +170,9 @@ def get_data(self, meta: AdapterBatchMetadata) -> AdapterWeightData: for segment_idx, adapter_idx in enumerate(segment_indices): if adapter_idx not in self.lora_weights: continue - rank_indices[self.lora_weights[adapter_idx].adapter_config.r].append(segment_idx) + rank_indices[self.lora_weights[adapter_idx].adapter_config.r].append( + segment_idx + ) rank_data = {} for rank, indices in rank_indices.items(): @@ 
-179,11 +181,11 @@ def get_data(self, meta: AdapterBatchMetadata) -> AdapterWeightData: lora_a_ptr=lora_a_ptr[indices], lora_b_ptr=lora_b_ptr[indices], segment_starts=meta.adapter_segments[indices], - segment_ends=meta.adapter_segments[[i+1 for i in indices]], + segment_ends=meta.adapter_segments[[i + 1 for i in indices]], ) return AdapterWeightData( - lora_a=lora_a, + lora_a=lora_a, lora_b=lora_b, adapter_index_configs=adapter_index_configs, rank_data=rank_data, diff --git a/server/lorax_server/utils/medusa.py b/server/lorax_server/utils/medusa.py new file mode 100644 index 000000000..8f795b34d --- /dev/null +++ b/server/lorax_server/utils/medusa.py @@ -0,0 +1,59 @@ +import torch +from dataclasses import dataclass +from lorax_server.utils.layers import FastLinear + + +@dataclass +class Output: + logits: torch.FloatTensor = None + speculative_logits: torch.FloatTensor = None + + +class ResBlock(torch.nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + self.linear = FastLinear.load( + config, prefix=f"{prefix}.linear", weights=weights, bias=True + ) + self.act = torch.nn.SiLU() + + def forward(self, x): + return x + self.act(self.linear(x)) + + +class MedusaModel(torch.nn.Module): + def __init__(self, config, weights, lm_head): + super().__init__() + self.heads = torch.nn.ModuleList( + [ + MedusaHead(config, prefix=f"{i}", weights=weights) + for i in range(config["medusa_num_heads"]) + ] + ) + self.lm_head = lm_head + + def forward(self, x): + logits = self.lm_head(x) + speculative_logits = torch.stack([head(x) for head in self.heads], dim=1) + return logits, speculative_logits + + +class MedusaHead(torch.nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + self.blocks = torch.nn.ModuleList( + [ + ResBlock(config, prefix=f"{prefix}.{i}", weights=weights) + for i in range(config["medusa_num_layers"]) + ] + ) + n = len(self.blocks) + self.out = FastLinear.load( + config, prefix=f"{prefix}.{n}", weights=weights, bias=False + ) + + def forward(self, x): + for block in self.blocks: + x = block(x) + x = self.out(x) + return x diff --git a/server/lorax_server/utils/paged_attn.py b/server/lorax_server/utils/paged_attn.py index 6b9d4b28e..f4c52d58e 100644 --- a/server/lorax_server/utils/paged_attn.py +++ b/server/lorax_server/utils/paged_attn.py @@ -14,10 +14,10 @@ def reshape_and_cache( - key: torch.Tensor, # [num_tokens, num_heads, head_size] - value: torch.Tensor, # [num_tokens, num_heads, head_size] - key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x] - value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] + key: torch.Tensor, # [num_tokens, num_heads, head_size] + value: torch.Tensor, # [num_tokens, num_heads, head_size] + key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x] + value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] slot_mapping: torch.Tensor, # [num_tokens] ): cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping) @@ -25,20 +25,20 @@ def reshape_and_cache( # Source: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/attention.py def single_query_cached_kv_attention( - output: torch.Tensor, # [num_tokens, num_heads, head_size] - query: torch.Tensor, # [num_tokens, num_heads, head_size] - key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x] - value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] + output: torch.Tensor, # [num_tokens, 
num_heads, head_size] + query: torch.Tensor, # [num_tokens, num_heads, head_size] + key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x] + value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] kv_head_mapping: torch.Tensor, softmax_scale: float, - block_tables: torch.Tensor, # [num_blocks, block_size] + block_tables: torch.Tensor, # [num_blocks, block_size] input_lengths: torch.Tensor, # [num_blocks] max_s: int, ): block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE - + # NOTE(woosuk): We use a simple heuristic to decide whether to use # PagedAttention V1 or V2. If the number of partitions is 1, we use # V1 to avoid the overhead of reduction. Also, if the number of @@ -91,4 +91,4 @@ def single_query_cached_kv_attention( block_size, max_s, None, - ) \ No newline at end of file + ) diff --git a/server/lorax_server/utils/segments.py b/server/lorax_server/utils/segments.py index 841ee3f6f..8761cb1cb 100644 --- a/server/lorax_server/utils/segments.py +++ b/server/lorax_server/utils/segments.py @@ -37,18 +37,23 @@ def concat(self, adapter_segments: torch.Tensor, segment_indices: List[int]): # from this batch denoting the beginning of the segment, then offset all segment # positions by the value of the last segment in the previous batch to account for # the concatenation. - adapter_segments = adapter_segments[1:] + self.adapter_segment_tensors[-1][-1] - - if self.adapter_segment_indices and self.adapter_segment_indices[-1] == segment_indices[0]: + adapter_segments = ( + adapter_segments[1:] + self.adapter_segment_tensors[-1][-1] + ) + + if ( + self.adapter_segment_indices + and self.adapter_segment_indices[-1] == segment_indices[0] + ): # If the last segment in the previous batch is the same as the first segment in this batch, # then we merge them together into a single segment. In effect, this means removing it from # the segment indices of this batch, and extending the segment span by removing the segment # end index from the previous batch. 
segment_indices = segment_indices[1:] self.adapter_segment_tensors[-1] = self.adapter_segment_tensors[-1][:-1] - + self.adapter_segment_indices.extend(segment_indices) self.adapter_segment_tensors.append(adapter_segments) - + def build(self) -> Tuple[torch.Tensor, List[int]]: return torch.concat(self.adapter_segment_tensors), self.adapter_segment_indices diff --git a/server/lorax_server/utils/sgmv.py b/server/lorax_server/utils/sgmv.py index 99853cc48..6e1e6f55d 100644 --- a/server/lorax_server/utils/sgmv.py +++ b/server/lorax_server/utils/sgmv.py @@ -6,6 +6,7 @@ try: import punica_kernels as _kernels + HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", "")) except ImportError: warnings.warn("Could not import SGMV kernel from Punica, falling back to loop.") @@ -56,9 +57,11 @@ def add_lora_sgmv_cutlass( """ if lora_rank < MIN_RANK_CUSTOM or lora_rank > MAX_RANK_CUSTOM: # Custom SGMV shrink only supports rank 16, 32, 64, 128 - _add_lora_sgmv_cutlass_legacy(y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank) + _add_lora_sgmv_cutlass_legacy( + y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank + ) return - + tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device) tmp2_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0)) tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device) diff --git a/server/lorax_server/utils/sources/__init__.py b/server/lorax_server/utils/sources/__init__.py index 33e342d5b..8a207075c 100644 --- a/server/lorax_server/utils/sources/__init__.py +++ b/server/lorax_server/utils/sources/__init__.py @@ -4,7 +4,16 @@ import requests -from .hub import EntryNotFoundError, LocalEntryNotFoundError, RevisionNotFoundError, get_hub_model_local_dir, weight_files, download_weights, weight_hub_files, HubModelSource +from .hub import ( + EntryNotFoundError, + LocalEntryNotFoundError, + RevisionNotFoundError, + get_hub_model_local_dir, + weight_files, + download_weights, + weight_hub_files, + HubModelSource, +) from .local import LocalModelSource, get_model_local_dir from .s3 import S3ModelSource, get_s3_model_local_dir @@ -15,7 +24,9 @@ PREDIBASE_MODEL_URL_ENDPOINT = "/v1/models/version/name/{}" PREDIBASE_MODEL_VERSION_URL_ENDPOINT = "/v1/models/version/name/{}?version={}" -PREDIBASE_GATEWAY_ENDPOINT = os.getenv("PREDIBASE_GATEWAY_ENDPOINT", "https://api.predibase.com") +PREDIBASE_GATEWAY_ENDPOINT = os.getenv( + "PREDIBASE_GATEWAY_ENDPOINT", "https://api.predibase.com" +) @lru_cache(maxsize=256) @@ -30,7 +41,9 @@ def map_pbase_model_id_to_s3(model_id: str, api_token: str) -> str: url = PREDIBASE_GATEWAY_ENDPOINT + PREDIBASE_MODEL_URL_ENDPOINT.format(name) elif len(name_components) == 2: name, version = name_components - url = PREDIBASE_GATEWAY_ENDPOINT + PREDIBASE_MODEL_VERSION_URL_ENDPOINT.format(name, version) + url = PREDIBASE_GATEWAY_ENDPOINT + PREDIBASE_MODEL_VERSION_URL_ENDPOINT.format( + name, version + ) else: raise ValueError(f"Invalid model id {model_id}") resp = requests.get(url, headers=headers) @@ -40,7 +53,12 @@ def map_pbase_model_id_to_s3(model_id: str, api_token: str) -> str: # TODO(travis): refactor into registry pattern -def get_model_source(source: str, model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"): +def get_model_source( + source: str, + model_id: str, + revision: Optional[str] = None, + extension: str = ".safetensors", +): if source == HUB: return HubModelSource(model_id, revision, extension) elif source == S3: @@ -84,4 +102,4 @@ def get_local_dir(model_id: str, source: str): 
"get_hub_model_local_dir", "get_s3_model_local_dir", "map_pbase_model_id_to_s3", -] \ No newline at end of file +] diff --git a/server/lorax_server/utils/sources/hub.py b/server/lorax_server/utils/sources/hub.py index ac4ee377d..799a2a171 100644 --- a/server/lorax_server/utils/sources/hub.py +++ b/server/lorax_server/utils/sources/hub.py @@ -10,8 +10,7 @@ from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from huggingface_hub.utils import ( LocalEntryNotFoundError, - EntryNotFoundError, - RevisionNotFoundError, # Import here to ease try/except in other part of the lib + EntryNotFoundError, # Import here to ease try/except in other part of the lib ) from .source import BaseModelSource, try_to_load_from_cache @@ -50,7 +49,6 @@ def weight_hub_files( return filenames - def weight_files( model_id: str, revision: Optional[str] = None, extension: str = ".safetensors" ) -> List[Path]: @@ -157,7 +155,12 @@ def download_file(filename, tries=5, backoff: int = 5): class HubModelSource(BaseModelSource): - def __init__(self, model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + extension: str = ".safetensors", + ): self.model_id = model_id self.revision = revision self.extension = extension diff --git a/server/lorax_server/utils/sources/local.py b/server/lorax_server/utils/sources/local.py index e89ffe3f5..fa45adb4f 100644 --- a/server/lorax_server/utils/sources/local.py +++ b/server/lorax_server/utils/sources/local.py @@ -1,34 +1,29 @@ import os -import time -from datetime import timedelta -from typing import Optional, List, Any +from typing import Optional, List -from loguru import logger from pathlib import Path -import boto3 -from botocore.config import Config from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE -from huggingface_hub.utils import ( - LocalEntryNotFoundError, - EntryNotFoundError, -) -from .s3 import get_s3_model_local_dir -from .source import BaseModelSource, try_to_load_from_cache +from .source import BaseModelSource def get_model_local_dir(model_id: str): if os.path.isabs(model_id): return Path(model_id) - + repo_cache = Path(HUGGINGFACE_HUB_CACHE) / model_id return repo_cache class LocalModelSource(BaseModelSource): - def __init__(self, model_id: str, revision: Optional[str] = "", extension: str = ".safetensors"): + def __init__( + self, + model_id: str, + revision: Optional[str] = "", + extension: str = ".safetensors", + ): if len(model_id) < 5: raise ValueError(f"model_id '{model_id}' is too short for prefix filtering") @@ -52,11 +47,11 @@ def weight_files(self, extension: str = None): f"No local weights found in {model_id} with extension {extension}" ) return local_files - + raise FileNotFoundError( f"No local weights found in {model_id} with extension {extension}" ) - + def download_weights(self, filenames: List[str]): return [] @@ -64,4 +59,4 @@ def download_model_assets(self): return [] def get_local_path(self, model_id: str): - return get_model_local_dir(model_id) \ No newline at end of file + return get_model_local_dir(model_id) diff --git a/server/lorax_server/utils/sources/s3.py b/server/lorax_server/utils/sources/s3.py index 241c165f7..c2bedc00f 100644 --- a/server/lorax_server/utils/sources/s3.py +++ b/server/lorax_server/utils/sources/s3.py @@ -26,7 +26,7 @@ def _get_bucket_resource(): mode="standard", ) ) - s3 = boto3.resource('s3', config=config) + s3 = boto3.resource("s3", config=config) bucket = os.getenv("PREDIBASE_MODEL_BUCKET") if not bucket: 
raise ValueError("PREDIBASE_MODEL_BUCKET environment variable is not set") @@ -44,7 +44,11 @@ def weight_s3_files( ) -> List[str]: """Get the weights filenames from s3""" model_files = bucket.objects.filter(Prefix=model_id) - filenames = [f.key.removeprefix(model_id).lstrip("/") for f in model_files if f.key.endswith(extension)] + filenames = [ + f.key.removeprefix(model_id).lstrip("/") + for f in model_files + if f.key.endswith(extension) + ] if not filenames: raise EntryNotFoundError( f"No {extension} weights found for model {model_id}", @@ -54,9 +58,13 @@ def weight_s3_files( def download_files_from_s3( - bucket: Any, filenames: List[str], model_id: str, revision: str = "", + bucket: Any, + filenames: List[str], + model_id: str, + revision: str = "", ) -> List[Path]: """Download the safetensors files from the s3""" + def download_file(filename): repo_cache = get_s3_model_local_dir(model_id) local_file = try_to_load_from_cache(repo_cache, revision, filename) @@ -75,7 +83,7 @@ def download_file(filename): # TODO: add support for revision logger.info( f"Downloaded {local_file_path} in {timedelta(seconds=int(time.time() - start_time))}." - ) + ) if not local_file_path.is_file(): raise FileNotFoundError(f"File {local_file_path} not found") return local_file_path @@ -129,9 +137,7 @@ def weight_files_s3( repo_cache = get_s3_model_local_dir(model_id) files = [] for filename in filenames: - cache_file = try_to_load_from_cache( - repo_cache, revision, filename - ) + cache_file = try_to_load_from_cache(repo_cache, revision, filename) if cache_file is None: raise LocalEntryNotFoundError( f"File {filename} of model {model_id} not found in " @@ -155,7 +161,9 @@ def download_model_from_s3(bucket: Any, model_id: str, extension: str = ".safete logger.info(filenames) download_files_from_s3(bucket, filenames, model_id) logger.info(f"Downloaded {len(filenames)} files") - logger.info(f"Contents of the cache folder: {os.listdir(get_s3_model_local_dir(model_id))}") + logger.info( + f"Contents of the cache folder: {os.listdir(get_s3_model_local_dir(model_id))}" + ) # Raise an error if none of the files we downloaded have the correct extension filenames_with_extension = [f for f in model_files if f.key.endswith(extension)] @@ -167,7 +175,12 @@ def download_model_from_s3(bucket: Any, model_id: str, extension: str = ".safete class S3ModelSource(BaseModelSource): - def __init__(self, model_id: str, revision: Optional[str] = "", extension: str = ".safetensors"): + def __init__( + self, + model_id: str, + revision: Optional[str] = "", + extension: str = ".safetensors", + ): if len(model_id) < 5: raise ValueError(f"model_id '{model_id}' is too short for prefix filtering") @@ -184,12 +197,14 @@ def remote_weight_files(self, extension: str = None): def weight_files(self, extension: str = None): extension = extension or self.extension return weight_files_s3(self.bucket, self.model_id, self.revision, extension) - + def download_weights(self, filenames: List[str]): - return download_files_from_s3(self.bucket, filenames, self.model_id, self.revision) + return download_files_from_s3( + self.bucket, filenames, self.model_id, self.revision + ) def download_model_assets(self): return download_model_from_s3(self.bucket, self.model_id, self.extension) def get_local_path(self, model_id: str): - return get_s3_model_local_dir(model_id) \ No newline at end of file + return get_s3_model_local_dir(model_id) diff --git a/server/lorax_server/utils/sources/source.py b/server/lorax_server/utils/sources/source.py index 38bac0b48..4a7a54225 
100644 --- a/server/lorax_server/utils/sources/source.py +++ b/server/lorax_server/utils/sources/source.py @@ -43,15 +43,15 @@ def remote_weight_files(self, extension: str = None): def weight_files(self, extension: str = None): raise NotImplementedError - + def download_weights(self, filenames: List[str]): raise NotImplementedError - + def download_model_assets(self): - """ The reason we need this function is that for s3 - we need to download all the model files whereas for - hub we only need to download the weight files. And maybe - for other future sources we might need something different. + """The reason we need this function is that for s3 + we need to download all the model files whereas for + hub we only need to download the weight files. And maybe + for other future sources we might need something different. So this function will take the necessary steps to download - the needed files for any source """ - raise NotImplementedError \ No newline at end of file + the needed files for any source""" + raise NotImplementedError diff --git a/server/lorax_server/utils/tokens.py b/server/lorax_server/utils/tokens.py index 1d4a483d3..43c91b732 100644 --- a/server/lorax_server/utils/tokens.py +++ b/server/lorax_server/utils/tokens.py @@ -284,32 +284,99 @@ def __init__( self.dtype = dtype self.device = device - def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor): + def __call__( + self, + input_ids: torch.Tensor, + scores: torch.Tensor, + speculate: int, + speculation_ids: Optional[torch.Tensor] = None, + speculation_scores: Optional[torch.Tensor] = None, + verbose=False, + ): """ - Chooses the next tokens based on the input IDs and scores. + Perform token processing and selection based on input scores. Args: - input_ids (torch.Tensor): The input tensor containing the token IDs. - scores (torch.Tensor): The tensor containing the scores for each token. + input_ids (torch.Tensor): The input tensor of token IDs. + scores (torch.Tensor): The scores tensor representing the likelihood of each token. + speculate (int): The number of speculative tokens to generate. + speculation_ids (Optional[torch.Tensor]): The tensor of speculated token IDs. + speculation_scores (Optional[torch.Tensor]): The scores tensor for speculated tokens. + verbose (bool): Whether to enable verbose mode. Returns: - torch.Tensor: The tensor containing the next token IDs. - torch.Tensor: The tensor containing the log probabilities of the next tokens. + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: A tuple containing the following: + - next_ids (torch.Tensor): The selected token IDs for the next step. + - next_logprobs (torch.Tensor): The log probabilities of the selected token IDs. + - logprobs (torch.Tensor): The log probabilities of all token IDs. + - accepted_ids (torch.Tensor): The accepted tokens for each input sequence. + - speculative_ids (Optional[torch.Tensor]): The selected speculative token IDs. 
""" - if self.watermark_processor is not None: - scores = self.watermark_processor(input_ids, scores) - if self.repetition_processor is not None: - scores = self.repetition_processor(input_ids, scores) - - for warper in self.warpers: - scores = warper(input_ids, scores) + if speculation_ids is not None: + B = scores.shape[0] // (speculation_ids.shape[1] + 1) if speculation_ids is not None else scores.shape[0] + S = speculation_ids.shape[1] + 1 if speculation_ids is not None else 1 + scores = scores.view(B, S, -1) + + next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long) + for j in range(S): + _scores = scores[:, j] + if self.watermark_processor is not None: + _scores = self.watermark_processor(input_ids, _scores) + if self.repetition_processor is not None: + _scores = self.repetition_processor(input_ids, _scores) + + for warper in self.warpers: + _scores = warper(input_ids, _scores) + + _next_ids = self.choice(_scores) + scores[:, j] = _scores + next_ids[:, j] = _next_ids + next_ids = next_ids.view(B * S) + scores = scores.view(B * S, -1) + + if speculation_ids is not None: + accepted_ids = [] + B = next_ids.shape[0] // (speculation_ids.shape[1] + 1) + S = speculation_ids.shape[1] + 1 + indices = [] + for i in range(B): + _next_ids = next_ids[i * S : (i + 1) * S] + _speculated_ids = speculation_ids[i] + validate_speculative = _next_ids[:-1] == _speculated_ids + index = i * S + accepted = 1 + indices.append(index) + for valid in validate_speculative.tolist(): + if valid: + index += 1 + accepted += 1 + indices.append(index) + else: + break + accepted_ids.append(accepted) + + accepted_ids = torch.tensor( + accepted_ids, device=input_ids.device, dtype=input_ids.dtype + ) + next_ids = next_ids[indices] + scores = scores[indices] + indices = torch.arange(B, device=input_ids.device) * S + if speculation_scores is not None: + speculation_scores = speculation_scores[indices + accepted_ids - 1] + else: + accepted_ids = torch.ones_like(next_ids) - next_ids = self.choice(scores) + logprobs = torch.log_softmax(scores, -1) next_logprobs = torch.gather( torch.log_softmax(scores, -1), 1, next_ids.view(-1, 1) ).view(-1) - return next_ids, next_logprobs + if speculate > 0: + speculative_ids = Greedy()(speculation_scores) + else: + speculative_ids = None + + return next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids def filter(self, indices): """ diff --git a/server/lorax_server/utils/weights.py b/server/lorax_server/utils/weights.py index b786751b8..87da9022c 100644 --- a/server/lorax_server/utils/weights.py +++ b/server/lorax_server/utils/weights.py @@ -9,6 +9,7 @@ import torch.distributed import os + class Weights: """ A class representing weights for a model. @@ -29,6 +30,7 @@ class Weights: process_group: The process group for distributed training. _handles (Dict[str, Any]): Dictionary of file handles for opened weight files. """ + def __init__( self, filenames: List[Path], @@ -50,7 +52,7 @@ def __init__( f"Key {k} was found in multiple adapter files: {filename} and {routing[k]}" ) routing[k] = filename - + # set of keys that point to adapter files. Duplicates for these keys found # in main model files will be overridden. 
adapter_routes = set(routing.keys()) @@ -59,7 +61,9 @@ def __init__( with safe_open(filename, framework="pytorch") as f: for k in f.keys(): if k in adapter_routes: - logger.debug(f"Overriding main model weights with adapter weights for key: {k}") + logger.debug( + f"Overriding main model weights with adapter weights for key: {k}" + ) elif k in routing: raise RuntimeError( f"Key {k} was found in multiple non-adapter files: {filename} and {routing[k]}" @@ -114,7 +118,9 @@ def get_tensor(self, tensor_name: str): tensor = tensor.to(device=self.device) return tensor - def get_partial_sharded(self, tensor_name: str, dim: int, range: Optional[Tuple[int, int]] = None): + def get_partial_sharded( + self, tensor_name: str, dim: int, range: Optional[Tuple[int, int]] = None + ): """Loads tensor with the given name and shards it along the given dimension. The optional range argument can be used to load and split on only a subset of the tensor. @@ -153,7 +159,9 @@ def get_partial_sharded(self, tensor_name: str, dim: int, range: Optional[Tuple[ tensor = tensor.to(device=self.device) return tensor - def get_sharded(self, tensor_name: str, dim: int, range: Optional[Tuple[int, int]] = None): + def get_sharded( + self, tensor_name: str, dim: int, range: Optional[Tuple[int, int]] = None + ): filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) slice_ = f.get_slice(tensor_name) @@ -163,19 +171,25 @@ def get_sharded(self, tensor_name: str, dim: int, range: Optional[Tuple[int, int size % world_size == 0 ), f"The choosen size {size} is not compatible with sharding on {world_size} shards" return self.get_partial_sharded(tensor_name, dim, range=range) - + def get_sharded_prefix(self, module_name: str, prefix: Union[str, Tuple], dim: int): if isinstance(prefix, str): return self.get_sharded(f"{prefix}.{module_name}", dim=dim) else: assert isinstance(prefix, tuple) assert len(prefix) == 2 - return self.get_sharded(f"{prefix[0]}.{module_name}", dim=dim, range=prefix[1]) - - def get_sharded_list(self, module_name: str, prefixes: List[Union[str, Tuple]], dim: int): + return self.get_sharded( + f"{prefix[0]}.{module_name}", dim=dim, range=prefix[1] + ) + + def get_sharded_list( + self, module_name: str, prefixes: List[Union[str, Tuple]], dim: int + ): return [self.get_sharded_prefix(module_name, p, dim=dim) for p in prefixes] - def get_multi_weights_col(self, prefixes: List[Union[str, Tuple]], quantize: str, dim: int): + def get_multi_weights_col( + self, prefixes: List[Union[str, Tuple]], quantize: str, dim: int + ): if quantize in ["gptq", "awq"]: try: qweight = torch.cat( @@ -186,12 +200,8 @@ def get_multi_weights_col(self, prefixes: List[Union[str, Tuple]], quantize: str "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `lorax-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" ) - qzeros = torch.cat( - self.get_sharded_list("qzeros", prefixes, dim=1), dim=1 - ) - scales = torch.cat( - self.get_sharded_list("scales", prefixes, dim=1), dim=1 - ) + qzeros = torch.cat(self.get_sharded_list("qzeros", prefixes, dim=1), dim=1) + scales = torch.cat(self.get_sharded_list("scales", prefixes, dim=1), dim=1) if quantize == "gptq": # no tensor parallelism, so remove the range if provided prefixes = [p[0] if isinstance(p, tuple) else p for p in prefixes] @@ -343,6 +353,7 @@ def _set_gptq_params(self, model_id): except Exception: pass + def get_start_stop_idxs_for_rank(offset, size, rank, world_size): block_size = size // world_size start = offset + rank * 
block_size @@ -350,10 +361,12 @@ def get_start_stop_idxs_for_rank(offset, size, rank, world_size): return start, stop -def shard_on_dim(t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup): +def shard_on_dim( + t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup +): world_size = process_group.size() rank = process_group.rank() - + size = t.shape[dim] start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size) diff --git a/server/punica_kernels/setup.py b/server/punica_kernels/setup.py index 2339026a3..61fd01300 100644 --- a/server/punica_kernels/setup.py +++ b/server/punica_kernels/setup.py @@ -11,10 +11,10 @@ def remove_unwanted_pytorch_nvcc_flags(): REMOVE_NVCC_FLAGS = [ - '-D__CUDA_NO_HALF_OPERATORS__', - '-D__CUDA_NO_HALF_CONVERSIONS__', - '-D__CUDA_NO_BFLOAT16_CONVERSIONS__', - '-D__CUDA_NO_HALF2_OPERATORS__', + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", ] for flag in REMOVE_NVCC_FLAGS: try: @@ -90,7 +90,8 @@ def generate_flashinfer_cu() -> List[str]: "punica_kernels/rms_norm/rms_norm_cutlass.cu", "punica_kernels/sgmv/sgmv_cutlass.cu", "punica_kernels/sgmv_flashinfer/sgmv_all.cu", - ] + generate_flashinfer_cu(), + ] + + generate_flashinfer_cu(), include_dirs=[ str(root.resolve() / "third_party/cutlass/include"), str(root.resolve() / "third_party/flashinfer/include"), diff --git a/server/tests/utils/test_adapter.py b/server/tests/utils/test_adapter.py index 793b11d37..c8067ffb7 100644 --- a/server/tests/utils/test_adapter.py +++ b/server/tests/utils/test_adapter.py @@ -5,40 +5,37 @@ def test_merge_adapter_weights(): - W_0 = torch.tensor([ - [1, 2, 3], - [4, 5, 6], - [7, 8, 9] - ]) - model_weights = { - "model.layers.10.self_attn.q_proj.weight": W_0 - } - - A = torch.tensor([ - [1, 2, 3], - [4, 5, 6] - ]) - B = torch.tensor([ - [1, 2], - [3, 4], - [5, 6] - ]) + W_0 = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + model_weights = {"model.layers.10.self_attn.q_proj.weight": W_0} + + A = torch.tensor([[1, 2, 3], [4, 5, 6]]) + B = torch.tensor([[1, 2], [3, 4], [5, 6]]) adapter_weights = { "base_model.model.model.layers.10.self_attn.q_proj.lora_A.weight": A, - "base_model.model.model.layers.10.self_attn.q_proj.lora_B.weight": B + "base_model.model.model.layers.10.self_attn.q_proj.lora_B.weight": B, } - W_expected = torch.tensor([ - [ 5.5000, 8.0000, 10.5000], - [13.5000, 18.0000, 22.5000], - [21.5000, 28.0000, 34.5000] - ]) + W_expected = torch.tensor( + [ + [5.5000, 8.0000, 10.5000], + [13.5000, 18.0000, 22.5000], + [21.5000, 28.0000, 34.5000], + ] + ) adapter_config = LoraConfig(r=2, lora_alpha=1, fan_in_fan_out=False) - merged_weights, processed_adapter_weight_names = merge_adapter_weights(model_weights, adapter_weights, adapter_config) + merged_weights, processed_adapter_weight_names = merge_adapter_weights( + model_weights, adapter_weights, adapter_config + ) assert len(merged_weights) == 1 assert merged_weights["model.layers.10.self_attn.q_proj.weight"].equal(W_expected) - + assert len(processed_adapter_weight_names) == 2 - assert "base_model.model.model.layers.10.self_attn.q_proj.lora_A.weight" in processed_adapter_weight_names - assert "base_model.model.model.layers.10.self_attn.q_proj.lora_B.weight" in processed_adapter_weight_names \ No newline at end of file + assert ( + "base_model.model.model.layers.10.self_attn.q_proj.lora_A.weight" + in processed_adapter_weight_names + ) + assert ( + 
"base_model.model.model.layers.10.self_attn.q_proj.lora_B.weight" + in processed_adapter_weight_names + ) diff --git a/server/tests/utils/test_segments.py b/server/tests/utils/test_segments.py index c81cd0892..02631b282 100644 --- a/server/tests/utils/test_segments.py +++ b/server/tests/utils/test_segments.py @@ -4,7 +4,6 @@ from lorax_server.utils.segments import find_segments, SegmentConcatBuilder - @pytest.mark.parametrize( "adapter_indices,expected_segments,expected_segment_indices", [ diff --git a/server/tests/utils/test_sgmv.py b/server/tests/utils/test_sgmv.py index 5563b2009..590c96857 100644 --- a/server/tests/utils/test_sgmv.py +++ b/server/tests/utils/test_sgmv.py @@ -14,12 +14,12 @@ def lora_ref_impl( layer_idx: int, ): for i in range(len(wa)): - xi = x[s_start[i]:s_end[i]] + xi = x[s_start[i] : s_end[i]] wai = wa[i][layer_idx, :, :] wbi = wb[i][layer_idx, :, :] - yi = y[s_start[i]:s_end[i]] - tmp = (xi @ wai) - y[s_start[i]:s_end[i]] = (yi + tmp @ wbi) + yi = y[s_start[i] : s_end[i]] + tmp = xi @ wai + y[s_start[i] : s_end[i]] = yi + tmp @ wbi @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -32,15 +32,19 @@ def test_add_lora_sgmv_cutlass(lora_rank: int): nlayers = 2 device = torch.device("cuda:0") - + y = torch.zeros((B, H), dtype=torch.float16, device=device) x = torch.randn((B, H), dtype=torch.float16, device=device) wa = torch.randn(nlayers, H, r, dtype=torch.float16, device=device) wb = torch.randn(nlayers, r, H, dtype=torch.float16, device=device) wa_sgmv = orient_for_rank(wa, lora_rank) - wa_ptr = torch.tensor([wa_sgmv.data_ptr(), wa_sgmv.data_ptr()], dtype=torch.int64, device=device) - wb_ptr = torch.tensor([wb.data_ptr(), wb.data_ptr()], dtype=torch.int64, device=device) + wa_ptr = torch.tensor( + [wa_sgmv.data_ptr(), wa_sgmv.data_ptr()], dtype=torch.int64, device=device + ) + wb_ptr = torch.tensor( + [wb.data_ptr(), wb.data_ptr()], dtype=torch.int64, device=device + ) s_start = torch.tensor([0, 2], dtype=torch.int32, device=device) s_end = torch.tensor([1, 3], dtype=torch.int32, device=device)