diff --git a/docs/guides/cuda_graphs.md b/docs/guides/cuda_graphs.md new file mode 100644 index 000000000..5e8a21ac5 --- /dev/null +++ b/docs/guides/cuda_graphs.md @@ -0,0 +1,43 @@ +LoRAX supports compiling the model into a static CUDA Graph to speedup inference by upwards of 2x. See [Accelerating PyTorch with CUDA Graphs](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/) for more details on CUDA graphs and how they can reduce latency. + +## Usage + +To enable this (experimental) feature: + +``` +lorax-launcher ... --compile +``` + +## When should I use this? + +CUDA graph compilation is a simple way to decrease latency for smaller LLMs (O(1b params)) that are compute bound rather than memory bound. + +There is a tradeoff to be aware of when using CUDA graphs, namely that it increases memory overhead by 3-10GB depending on model size. However, the observed decrease in latency can be as much as 50%, so if you don't need to run with very large batch sizes and are more latency constrained than throughput, this is a very compelling feature to enable. + +In practice, CUDA graphs are most useful in cases where there are excess GPU flops available, such as during decoding. As such, we do not use the compiled version of the model during prefill, only during the decoding steps. Which means in practice that the benefits of enabling compilation will be most pronounced when generating longer sequences (for which more time is spent during decoding). + +## Limitations + +Current limitations: + +- Batch size < 256 +- Context length (input + output) < 8192 +- LoRA rank >= 8 and <= 64 +- Only one LoRA rank in the batch +- 1 GPU (no sharding) + +If any of these conditions are not met, then LoRAX will fallback to using eager execution for the batch. + +## Benchmarks + +gpt2-medium, 1x A100, time to generate 100 tokens: + +no adapter: + +- baseline: 1.044 s +- cuda graph: 0.422 s + +1 adapter (rank 16): + +- baseline: 1.503 s +- cuda graph: 0.583 s \ No newline at end of file diff --git a/docs/reference/launcher.md b/docs/reference/launcher.md index 5f8d664d9..099d49dd6 100644 --- a/docs/reference/launcher.md +++ b/docs/reference/launcher.md @@ -1,6 +1,7 @@ # LoRAX Launcher ```shell +LoRAX Launcher Usage: lorax-launcher [OPTIONS] @@ -24,7 +25,7 @@ Options: [default: hub] --adapter-source - The source of the model to load. Can be `hub` or `s3`. `hub` will load the model from the huggingface hub. `s3` will load the model from the predibase S3 bucket + The source of the model to load. Can be `hub` or `s3` or `pbase` `hub` will load the model from the huggingface hub. `s3` will load the model from the predibase S3 bucket. `pbase` will load an s3 model but resolve the metadata from a predibase server [env: ADAPTER_SOURCE=] [default: hub] @@ -55,7 +56,12 @@ Options: Whether you want the model to be quantized. This will use `bitsandbytes` for quantization on the fly, or `gptq` [env: QUANTIZE=] - [possible values: bitsandbytes, bitsandbytes-nf4, bitsandbytes-fp4, gptq] + [possible values: bitsandbytes, bitsandbytes-nf4, bitsandbytes-fp4, gptq, awq] + + --compile + Whether you want to compile the model into a CUDA graph. This will speed up decoding but increase GPU memory usage + + [env: COMPILE=] --dtype The dtype to be forced upon the model. This option cannot be used with `--quantize` @@ -152,13 +158,13 @@ Options: --hostname The IP address to listen on - [env: HOSTNAME=b3687ab43244] + [env: HOSTNAME=] [default: 0.0.0.0] -p, --port The port to listen on - [env: PORT=80] + [env: PORT=] [default: 3000] --shard-uds-path @@ -182,7 +188,7 @@ Options: --huggingface-hub-cache The location of the huggingface hub cache. Used to override the location if you want to provide a mounted disk for instance - [env: HUGGINGFACE_HUB_CACHE=/data] + [env: HUGGINGFACE_HUB_CACHE=] --weights-cache-override The location of the huggingface hub cache. Used to override the location if you want to provide a mounted disk for instance diff --git a/integration-tests/scripts/dynamic_adapter_loading.py b/integration-tests/scripts/dynamic_adapter_loading.py index 1fa784f6d..6ec3f43bc 100644 --- a/integration-tests/scripts/dynamic_adapter_loading.py +++ b/integration-tests/scripts/dynamic_adapter_loading.py @@ -47,12 +47,12 @@ def query_lorax(args): prompt, adapter_id = args start_t = time.time() request_params = { - "max_new_tokens": 64, + "max_new_tokens": 128, "temperature": None, "details": True, } if adapter_id is not None: - # request_params["adapter_source"] = "local" + request_params["adapter_source"] = "local" request_params["adapter_id"] = adapter_id print("request_params", request_params) @@ -113,10 +113,14 @@ def main(): # ] # Mistral - prompt = "[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? [/INST]" - adapters = [ - "vineetsharma/qlora-adapter-Mistral-7B-Instruct-v0.1-gsm8k", - ] + # prompt = "[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? [/INST]" + # adapters = [ + # "vineetsharma/qlora-adapter-Mistral-7B-Instruct-v0.1-gsm8k", + # ] + + # GPT2 + prompt = "Brand Name : First Aid Beauty ; Product Name : Ultra Repair Cream Intense Hydration ; Review Title :" + adapters = ["/data/adapters/9789adb7-cd03-4862-91d5-b41b6746682e_ludwig/model_weights"] adapters += [None] # adapters = [None] diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 9e9ce877e..10c6451cf 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -135,6 +135,11 @@ struct Args { #[clap(long, env, value_enum)] quantize: Option, + /// Whether you want to compile the model into a CUDA graph. + /// This will speed up decoding but increase GPU memory usage. + #[clap(long, env, value_enum)] + compile: bool, + /// The dtype to be forced upon the model. This option cannot be used with `--quantize`. #[clap(long, env, value_enum)] dtype: Option, @@ -342,6 +347,7 @@ fn shard_manager( source: String, adapter_source: String, quantize: Option, + compile: bool, dtype: Option, trust_remote_code: bool, uds_path: String, @@ -407,6 +413,11 @@ fn shard_manager( shard_args.push(quantize.to_string()) } + // CUDA graph compilation + if compile { + shard_args.push("--compile".to_string()); + } + if let Some(dtype) = dtype { shard_args.push("--dtype".to_string()); shard_args.push(dtype.to_string()) @@ -853,6 +864,7 @@ fn spawn_shards( let shutdown_sender = shutdown_sender.clone(); let otlp_endpoint = args.otlp_endpoint.clone(); let quantize = args.quantize; + let compile = args.compile; let dtype = args.dtype; let trust_remote_code = args.trust_remote_code; let master_port = args.master_port; @@ -868,6 +880,7 @@ fn spawn_shards( source, adapter_source, quantize, + compile, dtype, trust_remote_code, uds_path, diff --git a/mkdocs.yml b/mkdocs.yml index d9d5ea21c..27925f901 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -42,7 +42,8 @@ nav: - Launcher: reference/launcher.md - REST API: reference/rest_api.md - Python Client: reference/python_client.md - # - 🔬 Guides: + - 🔬 Guides: + - CUDA Graph Compilation: guides/cuda_graphs.md # - GPUs: guides/gpus.md # - Fine-Tuning: guides/fine_tuning.md # - Quantization: guides/quantization.md diff --git a/server/lorax_server/cli.py b/server/lorax_server/cli.py index 987f0096c..945f7ab56 100644 --- a/server/lorax_server/cli.py +++ b/server/lorax_server/cli.py @@ -31,6 +31,7 @@ def serve( revision: Optional[str] = None, sharded: bool = False, quantize: Optional[Quantization] = None, + compile: bool = False, dtype: Optional[Dtype] = None, trust_remote_code: bool = False, uds_path: Path = "/tmp/lorax-server", @@ -82,7 +83,7 @@ def serve( "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model." ) server.serve( - model_id, adapter_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path, source, adapter_source + model_id, adapter_id, revision, sharded, quantize, compile, dtype, trust_remote_code, uds_path, source, adapter_source ) def _download_weights( diff --git a/server/lorax_server/models/__init__.py b/server/lorax_server/models/__init__.py index f6a12135b..dcb1f027e 100644 --- a/server/lorax_server/models/__init__.py +++ b/server/lorax_server/models/__init__.py @@ -97,6 +97,7 @@ def get_model( revision: Optional[str], sharded: bool, quantize: Optional[str], + compile: bool, dtype: Optional[str], trust_remote_code: bool, source: str, @@ -137,6 +138,7 @@ def get_model( model_id, revision, quantize=quantize, + compile=compile, dtype=dtype, dtypetrust_remote_code=trust_remote_code, ) @@ -147,6 +149,7 @@ def get_model( model_id, revision, quantize=quantize, + compile=compile, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -159,6 +162,7 @@ def get_model( model_id, revision, quantize=quantize, + compile=compile, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -169,6 +173,7 @@ def get_model( model_id, revision, quantize=quantize, + compile=compile, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -181,6 +186,7 @@ def get_model( model_id, revision, quantize=quantize, + compile=compile, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -190,12 +196,13 @@ def get_model( model_id, revision, quantize=quantize, + compile=compile, dtype=dtype, trust_remote_code=trust_remote_code, ) if model_type == "mpt": return MPTSharded( - model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code + model_id, revision, quantize=quantize, compile=compile, trust_remote_code=trust_remote_code ) if model_type == "gpt_neox": @@ -204,6 +211,7 @@ def get_model( model_id, revision, quantize=quantize, + compile=compile, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -212,6 +220,7 @@ def get_model( model_id, revision, quantize=quantize, + compile=compile, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -220,6 +229,7 @@ def get_model( model_id, revision, quantize=quantize, + compile=compile, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -232,6 +242,7 @@ def get_model( adapter_source, revision, quantize=quantize, + compile=compile, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -242,6 +253,7 @@ def get_model( model_id, revision, quantize=quantize, + compile=compile, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -254,6 +266,7 @@ def get_model( adapter_source, revision, quantize=quantize, + compile=compile, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -268,6 +281,7 @@ def get_model( model_id, revision, quantize=quantize, + compile=compile, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -278,6 +292,7 @@ def get_model( model_id, revision, quantize=quantize, + compile=compile, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -286,6 +301,7 @@ def get_model( model_id, revision, quantize=quantize, + compile=compile, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -298,6 +314,7 @@ def get_model( adapter_source, revision, quantize=quantize, + compile=compile, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -311,6 +328,7 @@ def get_model( adapter_source, revision, quantize=quantize, + compile=compile, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -324,6 +342,7 @@ def get_model( adapter_source, revision, quantize=quantize, + compile=compile, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -337,6 +356,7 @@ def get_model( adapter_source, revision, quantize=quantize, + compile=compile, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -347,6 +367,7 @@ def get_model( model_id, revision, quantize=quantize, + compile=compile, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -356,6 +377,7 @@ def get_model( model_id, revision, quantize=quantize, + compile=compile, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -380,6 +402,7 @@ def get_model( model_id, revision, quantize=quantize, + compile=compile, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -388,6 +411,7 @@ def get_model( model_id, revision, quantize=quantize, + compile=compile, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -399,6 +423,7 @@ def get_model( model_id, revision, quantize=quantize, + compile=compile, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -407,6 +432,7 @@ def get_model( model_id, revision, quantize=quantize, + compile=compile, dtype=dtype, trust_remote_code=trust_remote_code, ) diff --git a/server/lorax_server/models/bloom.py b/server/lorax_server/models/bloom.py index 0d548128a..1448f9aab 100644 --- a/server/lorax_server/models/bloom.py +++ b/server/lorax_server/models/bloom.py @@ -42,9 +42,13 @@ def __init__( model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + compile: bool = False, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): + if compile: + raise ValueError("`--compile` is not supported with Bloom") + self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") diff --git a/server/lorax_server/models/causal_lm.py b/server/lorax_server/models/causal_lm.py index d9d09cc40..5e00ea809 100644 --- a/server/lorax_server/models/causal_lm.py +++ b/server/lorax_server/models/causal_lm.py @@ -471,9 +471,13 @@ def __init__( model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + compile: bool = False, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): + if compile: + raise ValueError("`--compile` is not supported with CausalLM") + if torch.cuda.is_available(): device = torch.device("cuda") dtype = torch.float16 if dtype is None else dtype diff --git a/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py b/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py index 0216734d1..b89e2ba51 100644 --- a/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py @@ -375,7 +375,7 @@ def forward( slots: torch.Tensor, input_lengths: torch.Tensor, max_s: int, - adapter_data: AdapterBatchData, # TODO: plumb this through + adapter_data: AdapterBatchData, lm_head_indices: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.transformer( diff --git a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py index da335d98b..20980bacd 100644 --- a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py @@ -218,6 +218,7 @@ def __init__( dim=self.head_size, base=config.rope_theta, device=weights.device, + dtype=weights.dtype, ) self.softmax_scale = self.head_size**-0.5 diff --git a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py index 1c0902265..9b51e6443 100644 --- a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py @@ -227,6 +227,7 @@ def __init__( dim=self.head_size, base=config.rope_theta, device=weights.device, + dtype=weights.dtype, ) self.softmax_scale = self.head_size**-0.5 @@ -549,7 +550,7 @@ def forward( input_lengths: torch.Tensor, max_s: int, adapter_data: AdapterBatchData, - prefill_cache_indices: Optional[torch.Tensor], + prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, ) -> torch.Tensor: if prefill_cache_indices is not None: diff --git a/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py index a878b0471..6a3d8b1f0 100644 --- a/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py @@ -314,6 +314,7 @@ def __init__( dim=self.head_size, base=config.rope_theta, device=weights.device, + dtype=weights.dtype, ) self.softmax_scale = self.head_size ** -0.5 @@ -963,7 +964,7 @@ def forward( input_lengths: torch.Tensor, max_s: int, adapter_data: AdapterBatchData, - prefill_cache_indices: Optional[torch.Tensor], + prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, ) -> torch.Tensor: if prefill_cache_indices is not None: diff --git a/server/lorax_server/models/custom_modeling/flash_phi_modeling.py b/server/lorax_server/models/custom_modeling/flash_phi_modeling.py index 3f3586d7e..8b1a3fdc2 100644 --- a/server/lorax_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_phi_modeling.py @@ -82,6 +82,7 @@ def __init__( dim=config.rotary_dim, base=rope_theta, device=weights.device, + dtype=weights.dtype, ) self.softmax_scale = self.head_size**-0.5 diff --git a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py index 1e4a69cca..964910a1e 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py @@ -185,6 +185,7 @@ def __init__( dim=self.head_size, base=config.rope_theta, device=weights.device, + dtype=weights.dtype, ) self.softmax_scale = self.head_size**-0.5 diff --git a/server/lorax_server/models/custom_modeling/flash_rw_modeling.py b/server/lorax_server/models/custom_modeling/flash_rw_modeling.py index 2426d754e..6bb8fea6f 100644 --- a/server/lorax_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_rw_modeling.py @@ -130,7 +130,7 @@ def __init__( self.head_size = self.hidden_size // self.num_heads self.rotary_emb = PositionRotaryEmbedding.static( - dim=self.head_size, base=10000.0, device=weights.device + dim=self.head_size, base=10000.0, device=weights.device, dtype=weights.dtype ) self.softmax_scale = self.head_size ** (-0.5) @@ -242,7 +242,7 @@ def __init__( self.head_size = hidden_size // num_heads self.rotary_emb = PositionRotaryEmbedding.static( - self.head_size, base=10000.0, device=weights.device + self.head_size, base=10000.0, device=weights.device, dtype=weights.dtype ) self.softmax_scale = self.head_size ** (-0.5) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 5d5c1f6b5..91be95a86 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -10,7 +10,6 @@ from dataclasses import dataclass from opentelemetry import trace from peft import LoraConfig -from tqdm import tqdm from transformers import PreTrainedTokenizerBase from typing import Optional, Set, Tuple, List, Type, Union, Dict @@ -33,6 +32,8 @@ from lorax_server.utils.lora import LM_HEAD, AdapterBatchData, AdapterBatchMetadata, BatchedLoraWeights, MergedLoraWeights from lorax_server.utils.segments import SegmentConcatBuilder, find_segments from lorax_server.utils.weights import shard_on_dim +from lorax_server.utils.graph import GraphCache +from lorax_server.utils.sgmv import get_tmp_tensor tracer = trace.get_tracer(__name__) @@ -677,6 +678,7 @@ def __init__( rank: int = 0, world_size: int = 1, sliding_window: Optional[int] = None, + compile: bool = False, ): self.num_layers = num_layers self.num_kv_heads = num_kv_heads @@ -697,6 +699,9 @@ def __init__( sliding_window=sliding_window, ) + self.compile = compile + self.model_graph_wrapper: GraphCache = None + self.target_to_layer = self.adapter_target_to_layer() @property @@ -873,6 +878,19 @@ def warmup(self, batch: FlashCausalLMBatch): torch.cuda.synchronize(self.device) + graph_cache_memory = 0 + if self.compile: + if self.world_size > 1: + raise ValueError("Cannot enable `--compile` when sharding across multiple GPUs") + + # Estimate the memory overhead from CUDA graphs so we can subtract it from the kv cache. + # Needs to be estimated here rather than fully initialized as the graph cache relies on the + # cache manager being set. + self.model_graph_wrapper = GraphCache(self.model, self.device, self.adapter_layers) + graph_cache_memory = self.model_graph_wrapper.get_estimated_cache_memory() + logger.info("Estimated graph cache memory: {} MB", graph_cache_memory / 1024 / 1024) + torch.cuda.synchronize(self.device) + # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm) # Calculate the number of blocks that can be allocated with the free memory dtype_size = torch.tensor([], dtype=self.dtype).element_size() @@ -880,6 +898,8 @@ def warmup(self, batch: FlashCausalLMBatch): total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size total_free_memory, _ = torch.cuda.mem_get_info(self.device) + total_free_memory -= graph_cache_memory + total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory free_memory = max( @@ -905,6 +925,14 @@ def warmup(self, batch: FlashCausalLMBatch): self.device, ) + torch.cuda.synchronize(self.device) + + if self.model_graph_wrapper is not None: + # Warmup the graph cache. Needs to be done after setting cache manager as + # tracing will use the static kv cache tensors + self.model_graph_wrapper.warmup() + torch.cuda.synchronize(self.device) + return int(num_blocks * BLOCK_SIZE) def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str: @@ -913,8 +941,17 @@ def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str: ) def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> Tuple[torch.Tensor, torch.Tensor]: + prefill = batch.cu_seqlen_prefill is not None + model = self.model + if ( + self.model_graph_wrapper is not None and + not prefill and + self.model_graph_wrapper.can_use_graph(batch, adapter_data) + ): + model = self.model_graph_wrapper + # Model Forward - return self.model.forward( + return model.forward( input_ids=batch.input_ids, position_ids=batch.position_ids, cu_seqlen_prefill=batch.cu_seqlen_prefill, diff --git a/server/lorax_server/models/flash_gpt2.py b/server/lorax_server/models/flash_gpt2.py index db28c5d70..1aea62437 100644 --- a/server/lorax_server/models/flash_gpt2.py +++ b/server/lorax_server/models/flash_gpt2.py @@ -43,6 +43,7 @@ def __init__( adapter_source: str, revision: Optional[str] = None, quantize: Optional[str] = None, + compile: bool = False, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -110,6 +111,7 @@ def __init__( device=device, rank=rank, world_size=world_size, + compile=compile, ) @property diff --git a/server/lorax_server/models/flash_llama.py b/server/lorax_server/models/flash_llama.py index ac108a4d0..e67d22336 100644 --- a/server/lorax_server/models/flash_llama.py +++ b/server/lorax_server/models/flash_llama.py @@ -36,6 +36,7 @@ def __init__( adapter_source: str, revision: Optional[str] = None, quantize: Optional[str] = None, + compile: bool = False, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -103,6 +104,7 @@ def __init__( device=device, rank=rank, world_size=world_size, + compile=compile, ) @property diff --git a/server/lorax_server/models/flash_mistral.py b/server/lorax_server/models/flash_mistral.py index c150b582c..1b4beae1a 100644 --- a/server/lorax_server/models/flash_mistral.py +++ b/server/lorax_server/models/flash_mistral.py @@ -308,6 +308,7 @@ def __init__( adapter_source: str, revision: Optional[str] = None, quantize: Optional[str] = None, + compile: bool = False, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -386,6 +387,7 @@ def __init__( rank=rank, world_size=world_size, sliding_window=config.sliding_window, + compile=compile, ) @property @@ -397,8 +399,17 @@ def batch_type(self) -> Type[FlashMistralBatch]: return FlashMistralBatch def forward(self, batch: FlashMistralBatch, adapter_data: AdapterBatchData) -> Tuple[torch.Tensor, torch.Tensor]: + prefill = batch.cu_seqlen_prefill is not None + model = self.model + if ( + self.model_graph_wrapper is not None and + not prefill and + self.model_graph_wrapper.can_use_graph(batch, adapter_data) + ): + model = self.model_graph_wrapper + # Model Forward - logits = self.model.forward( + logits = model.forward( input_ids=batch.input_ids, position_ids=batch.position_ids, cu_seqlen_prefill=batch.cu_seqlen_prefill, diff --git a/server/lorax_server/models/flash_mixtral.py b/server/lorax_server/models/flash_mixtral.py index 31a005586..18269da47 100644 --- a/server/lorax_server/models/flash_mixtral.py +++ b/server/lorax_server/models/flash_mixtral.py @@ -315,6 +315,7 @@ def __init__( adapter_source: str, revision: Optional[str] = None, quantize: Optional[str] = None, + compile: bool = False, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -393,6 +394,7 @@ def __init__( rank=rank, world_size=world_size, sliding_window=config.sliding_window, + compile=compile, ) @property @@ -404,8 +406,17 @@ def batch_type(self) -> Type[FlashMixtralBatch]: return FlashMixtralBatch def forward(self, batch: FlashMixtralBatch, adapter_data: AdapterBatchData) -> Tuple[torch.Tensor, torch.Tensor]: + prefill = batch.cu_seqlen_prefill is not None + model = self.model + if ( + self.model_graph_wrapper is not None and + not prefill and + self.model_graph_wrapper.can_use_graph(batch, adapter_data) + ): + model = self.model_graph_wrapper + # Model Forward - logits = self.model.forward( + logits = model.forward( input_ids=batch.input_ids, position_ids=batch.position_ids, cu_seqlen_prefill=batch.cu_seqlen_prefill, diff --git a/server/lorax_server/models/flash_neox.py b/server/lorax_server/models/flash_neox.py index 2b913624f..6989d9f15 100644 --- a/server/lorax_server/models/flash_neox.py +++ b/server/lorax_server/models/flash_neox.py @@ -24,6 +24,7 @@ def __init__( model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + compile: bool = False, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -68,4 +69,5 @@ def __init__( device=device, rank=rank, world_size=world_size, + compile=compile, ) diff --git a/server/lorax_server/models/flash_phi.py b/server/lorax_server/models/flash_phi.py index 4284796f0..8e9c1d9d6 100644 --- a/server/lorax_server/models/flash_phi.py +++ b/server/lorax_server/models/flash_phi.py @@ -40,6 +40,7 @@ def __init__( adapter_source: str, revision: Optional[str] = None, quantize: Optional[str] = None, + compile: bool = False, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -108,6 +109,7 @@ def __init__( device=device, rank=rank, world_size=world_size, + compile=compile, ) @property diff --git a/server/lorax_server/models/flash_qwen.py b/server/lorax_server/models/flash_qwen.py index f87e10f31..15679965d 100644 --- a/server/lorax_server/models/flash_qwen.py +++ b/server/lorax_server/models/flash_qwen.py @@ -41,6 +41,7 @@ def __init__( adapter_source: str, revision: Optional[str] = None, quantize: Optional[str] = None, + compile: bool = False, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -109,6 +110,7 @@ def __init__( device=device, rank=rank, world_size=world_size, + compile=compile, ) @property diff --git a/server/lorax_server/models/flash_rw.py b/server/lorax_server/models/flash_rw.py index 723a11e87..055887e06 100644 --- a/server/lorax_server/models/flash_rw.py +++ b/server/lorax_server/models/flash_rw.py @@ -25,6 +25,7 @@ def __init__( model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + compile: bool = False, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -74,4 +75,5 @@ def __init__( device=device, rank=rank, world_size=world_size, + compile=compile, ) diff --git a/server/lorax_server/models/flash_santacoder.py b/server/lorax_server/models/flash_santacoder.py index 4a136683f..88c3a75cb 100644 --- a/server/lorax_server/models/flash_santacoder.py +++ b/server/lorax_server/models/flash_santacoder.py @@ -27,6 +27,7 @@ def __init__( model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + compile: bool = False, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -78,6 +79,7 @@ def __init__( device=device, rank=rank, world_size=world_size, + compile=compile, ) def decode(self, generated_ids: List[int]) -> str: diff --git a/server/lorax_server/models/galactica.py b/server/lorax_server/models/galactica.py index 898abe3e0..e95340afa 100644 --- a/server/lorax_server/models/galactica.py +++ b/server/lorax_server/models/galactica.py @@ -158,9 +158,13 @@ def __init__( model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + compile: bool = False, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): + if compile: + raise ValueError("`--compile` is not supported with GalacticaSharded") + self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") diff --git a/server/lorax_server/models/gpt_neox.py b/server/lorax_server/models/gpt_neox.py index 648f5d76d..ae3a56862 100644 --- a/server/lorax_server/models/gpt_neox.py +++ b/server/lorax_server/models/gpt_neox.py @@ -24,9 +24,13 @@ def __init__( model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + compile: bool = False, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): + if compile: + raise ValueError("`--compile` is not supported with GPT-NeoX") + self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") diff --git a/server/lorax_server/models/mpt.py b/server/lorax_server/models/mpt.py index fab079376..52a06718b 100644 --- a/server/lorax_server/models/mpt.py +++ b/server/lorax_server/models/mpt.py @@ -43,8 +43,12 @@ def __init__( model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + compile: bool = False, trust_remote_code: bool = False, ): + if compile: + raise ValueError("`--compile` is not supported with MPT") + self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") diff --git a/server/lorax_server/models/opt.py b/server/lorax_server/models/opt.py index 9b545e840..86755c363 100644 --- a/server/lorax_server/models/opt.py +++ b/server/lorax_server/models/opt.py @@ -22,9 +22,13 @@ def __init__( model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + compile: bool = False, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): + if compile: + raise ValueError("`--compile` is not supported with OPT") + self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") diff --git a/server/lorax_server/models/rw.py b/server/lorax_server/models/rw.py index 56fa6cbb6..d8c5f6bfe 100644 --- a/server/lorax_server/models/rw.py +++ b/server/lorax_server/models/rw.py @@ -12,9 +12,13 @@ def __init__( model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + compile: bool = False, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): + if compile: + raise ValueError("`--compile` is not supported with RW") + if torch.cuda.is_available(): device = torch.device("cuda") dtype = torch.float16 if dtype is None else dtype diff --git a/server/lorax_server/models/santacoder.py b/server/lorax_server/models/santacoder.py index d52c89813..67c68fe12 100644 --- a/server/lorax_server/models/santacoder.py +++ b/server/lorax_server/models/santacoder.py @@ -19,9 +19,13 @@ def __init__( model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + compile: bool = False, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): + if compile: + raise ValueError("`--compile` is not supported with SantaCoder") + if torch.cuda.is_available(): device = torch.device("cuda") dtype = torch.float16 if dtype is None else dtype diff --git a/server/lorax_server/models/seq2seq_lm.py b/server/lorax_server/models/seq2seq_lm.py index be432fd33..33587fae6 100644 --- a/server/lorax_server/models/seq2seq_lm.py +++ b/server/lorax_server/models/seq2seq_lm.py @@ -504,9 +504,13 @@ def __init__( model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + compile: bool = False, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): + if compile: + raise ValueError("`--compile` is not supported with Seq2SeqLM") + if torch.cuda.is_available(): device = torch.device("cuda") dtype = torch.float16 if dtype is None else dtype diff --git a/server/lorax_server/models/t5.py b/server/lorax_server/models/t5.py index b8cd7ab8c..327dfa3a0 100644 --- a/server/lorax_server/models/t5.py +++ b/server/lorax_server/models/t5.py @@ -25,9 +25,13 @@ def __init__( model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + compile: bool = False, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): + if compile: + raise ValueError("`--compile` is not supported with T5") + self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index 5fdae98db..db2a4dfbd 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -199,6 +199,7 @@ def serve( revision: Optional[str], sharded: bool, quantize: Optional[str], + compile: bool, dtype: Optional[str], trust_remote_code: bool, uds_path: Path, @@ -211,6 +212,7 @@ async def serve_inner( revision: Optional[str], sharded: bool = False, quantize: Optional[str] = None, + compile: bool = False, dtype: Optional[str] = None, trust_remote_code: bool = False, ): @@ -227,7 +229,7 @@ async def serve_inner( try: model = get_model( - model_id, adapter_id, revision, sharded, quantize, dtype, trust_remote_code, source, adapter_source + model_id, adapter_id, revision, sharded, quantize, compile, dtype, trust_remote_code, source, adapter_source ) except Exception: logger.exception("Error when initializing model") @@ -275,7 +277,7 @@ async def serve_inner( await server.stop(0) asyncio.run( - serve_inner(model_id, adapter_id, revision, sharded, quantize, dtype, trust_remote_code) + serve_inner(model_id, adapter_id, revision, sharded, quantize, compile, dtype, trust_remote_code) ) diff --git a/server/lorax_server/utils/graph.py b/server/lorax_server/utils/graph.py new file mode 100644 index 000000000..9a53fe1e0 --- /dev/null +++ b/server/lorax_server/utils/graph.py @@ -0,0 +1,411 @@ +# CUDA Graph implementation modified from vLLM: +# https://github.com/vllm-project/vllm/blob/main/vllm/worker/model_runner.py + +from dataclasses import dataclass +from functools import lru_cache +from statistics import median +from typing import TYPE_CHECKING, List, Optional, Tuple +import numpy as np + +import torch +from torch import nn +from tqdm import tqdm + +from lorax_server.utils.lora import AdapterBatchData, AdapterBatchMetadata, AdapterWeightData, RankSegments +from lorax_server.models.cache_manager import get_cache_manager, BLOCK_SIZE +from lorax_server.utils.sgmv import get_tmp_expand_size, get_tmp_tensors, use_cutlass_shrink + +if TYPE_CHECKING: + from lorax_server.models.flash_causal_lm import FlashCausalLMBatch + + +# TODO(travis): make this configurable by model / user +MAX_BATCH_SIZE = 256 +MAX_CONTEXT_LENGTH = 8192 +MAX_RANK = 64 + +SLOT_PAD_VALUE = -1 +SEGMENT_PAD_VALUE = -1 + +# Cached batch sizes used in vLLM. This and the helper function `get_cached_batch_size` below +# must be kept in sync. +BATCH_SIZE_INCREMENT = 32 +CACHED_BATCH_SIZES = [1, 2, 4, 8, 16] + [BATCH_SIZE_INCREMENT * (i + 1) for i in range(8)] + +# Include 0 to ensure we can use cuda graphs without adapters +# TODO(travis): use padding to allow for more ranks without increasing memory usage +CACHED_MAX_RANKS = [0, 8, 16, 32, 64] +_allowed_ranks = set(CACHED_MAX_RANKS) + +MAX_SAMPLES = 3 + + +def get_cached_batch_size(batch_size: int) -> int: + if batch_size == 1: + return 1 + if batch_size == 2: + return 2 + if batch_size <= 4: + return 4 + if batch_size <= 8: + return 8 + if batch_size <= 16: + return 16 + return (batch_size + BATCH_SIZE_INCREMENT - 1) // BATCH_SIZE_INCREMENT * BATCH_SIZE_INCREMENT + + +def pad_and_fill(dest: torch.Tensor, src: torch.Tensor, pad_value: int): + dest[:src.shape[0]] = src + dest[src.shape[0]:].fill_(pad_value) + + +def next_pow_2(x: int) -> int: + assert x > 0 + return 1 << (x-1).bit_length() + + +@dataclass +class GraphState: + input_ids: torch.Tensor + position_ids: torch.Tensor + block_tables: torch.Tensor + slots: torch.Tensor + input_lengths: torch.Tensor + adapter_data: AdapterBatchData + + +@lru_cache(maxsize=1) +def get_max_graph_state(device: torch.device, adapter_layers: Tuple[str]) -> GraphState: + max_num_blocks = (MAX_CONTEXT_LENGTH + BLOCK_SIZE - 1) // BLOCK_SIZE + block_tables_arr = np.zeros((MAX_BATCH_SIZE, max_num_blocks), dtype=np.int32) + block_tables = torch.from_numpy(block_tables_arr).to(device=device) + + input_ids = torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device) + position_ids = torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int32, device=device) + slots = torch.full((MAX_BATCH_SIZE,), SLOT_PAD_VALUE, dtype=torch.int64, device=device) + input_lengths = torch.ones((MAX_BATCH_SIZE,), dtype=torch.int32, device=device) + + tmp_shrink, tmp_expand = get_tmp_tensors(MAX_BATCH_SIZE, MAX_RANK, device) + + adapter_weight_data = {} + for layer_name in adapter_layers: + adapter_weight_data[layer_name] = AdapterWeightData( + lora_a={}, + lora_b={}, + adapter_index_configs={}, + rank_data={ + MAX_RANK: RankSegments( + rank=MAX_RANK, + tmp_shrink=tmp_shrink, + tmp_expand=tmp_expand, + lora_a_ptr=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device), + lora_b_ptr=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device), + segment_starts=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int32, device=device), + segment_ends=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int32, device=device), + ), + }, + ) + + return GraphState( + input_ids=input_ids, + position_ids=position_ids, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, + adapter_data=AdapterBatchData( + meta=AdapterBatchMetadata( + adapter_indices=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device), + adapter_set=set(), + adapter_segments=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device), + segment_indices=[], + ), + data=adapter_weight_data, + ), + ) + + +class GraphWrapper: + def __init__( + self, + graph: torch.cuda.CUDAGraph, + memory_pool: Tuple[int, int], + input_state: GraphState, + output_states: torch.Tensor, + model, + ): + self.graph = graph + self.memory_pool = memory_pool + self.input_state = input_state + self.output_states = output_states + self.model = model + + @staticmethod + def trace( + model: nn.Module, + device: torch.device, + adapter_layers: Tuple[str], + batch_size: int, + max_rank: int, + memory_pool: Tuple[int, int], + ) -> "GraphWrapper": + max_input_state = get_max_graph_state(device, adapter_layers) + + # WARNING: for some reason the SGMV kernel can hang if we don't use a power of 2 + # as the segment size. This is a workaround until we can figure out why. + # Specifically, this issue has been observed with batch_size=96. + # I suspect it is related to synchronization and the chunk size (256) used in the kernel. + # But we need to investigate further. + segment_size = next_pow_2(batch_size) + + adapter_weight_data = {} + for layer_name, weight_data in max_input_state.adapter_data.data.items(): + tmp_expand_size = get_tmp_expand_size(segment_size) + + tmp_shrink = weight_data.rank_data[MAX_RANK].tmp_shrink + if use_cutlass_shrink(max_rank): + # cutlass shrink uses a custom temp buffer per rank + tmp_shrink = tmp_shrink[:tmp_expand_size] + + adapter_weight_data[layer_name] = AdapterWeightData( + lora_a={}, + lora_b={}, + adapter_index_configs={}, + rank_data={ + max_rank: RankSegments( + rank=max_rank, + tmp_shrink=tmp_shrink, + tmp_expand=weight_data.rank_data[MAX_RANK].tmp_expand[:tmp_expand_size], + lora_a_ptr=weight_data.rank_data[MAX_RANK].lora_a_ptr[:segment_size], + lora_b_ptr=weight_data.rank_data[MAX_RANK].lora_b_ptr[:segment_size], + segment_starts=weight_data.rank_data[MAX_RANK].segment_starts[:segment_size], + segment_ends=weight_data.rank_data[MAX_RANK].segment_ends[:segment_size], + ), + } if max_rank > 0 else {}, + ) + + input_state = GraphState( + input_ids=max_input_state.input_ids[:batch_size], + position_ids=max_input_state.position_ids[:batch_size], + block_tables=max_input_state.block_tables[:batch_size], + slots=max_input_state.slots[:batch_size], + input_lengths=max_input_state.input_lengths[:batch_size], + adapter_data=AdapterBatchData( + meta=AdapterBatchMetadata( + adapter_indices=max_input_state.adapter_data.meta.adapter_indices[:batch_size], + adapter_set=max_input_state.adapter_data.meta.adapter_set, + adapter_segments=max_input_state.adapter_data.meta.adapter_segments[:batch_size], + segment_indices=max_input_state.adapter_data.meta.segment_indices, + ), + data=adapter_weight_data, + ), + ) + + torch.cuda.synchronize(device) + + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, pool=memory_pool): # noqa: SIM117 + output_states = model.forward( + input_ids=input_state.input_ids, + position_ids=input_state.position_ids, + cu_seqlen_prefill=None, + kv_cache=get_cache_manager().kv_cache, + block_tables=input_state.block_tables, + slots=input_state.slots, + input_lengths=input_state.input_lengths, + max_s=MAX_CONTEXT_LENGTH, + adapter_data=input_state.adapter_data, + lm_head_indices=None, + ) + + torch.cuda.synchronize(device) + + return GraphWrapper( + graph, graph.pool(), input_state, output_states, model + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + adapter_data: AdapterBatchData, + lm_head_indices: Optional[torch.Tensor] = None, + ) -> None: + pad_and_fill(self.input_state.input_ids, input_ids, 0) + pad_and_fill(self.input_state.position_ids, position_ids, 0) + pad_and_fill(self.input_state.slots, slots, SLOT_PAD_VALUE) + pad_and_fill(self.input_state.input_lengths, input_lengths, 0) + + self.input_state.block_tables.zero_() + self.input_state.block_tables[:block_tables.shape[0], :block_tables.shape[1]] = block_tables + + for layer_name, weight_data in self.input_state.adapter_data.data.items(): + if layer_name not in adapter_data.data: + # zero out all the segments + for rank_data in weight_data.rank_data.values(): + rank_data.segment_starts.fill_(SEGMENT_PAD_VALUE) + rank_data.segment_ends.fill_(SEGMENT_PAD_VALUE) + continue + + source_data = adapter_data.data[layer_name] + dest_data = weight_data + for rank, source_rank_data in source_data.rank_data.items(): + dest_rank_data = dest_data.rank_data[rank] + + pad_and_fill(dest_rank_data.lora_a_ptr, source_rank_data.lora_a_ptr, 0) + pad_and_fill(dest_rank_data.lora_b_ptr, source_rank_data.lora_b_ptr, 0) + + pad_and_fill(dest_rank_data.segment_starts, source_rank_data.segment_starts, SEGMENT_PAD_VALUE) + pad_and_fill(dest_rank_data.segment_ends, source_rank_data.segment_ends, SEGMENT_PAD_VALUE) + + self.graph.replay() + + return self.output_states[:input_ids.shape[0]] + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + +class GraphCache: + def __init__(self, model: nn.Module, device: torch.device, adapter_layers: List[str]): + self.model = model + self.device = device + self.adapter_layers = tuple(adapter_layers) + self.memory_pool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None + self.cache = {} + + def can_use_graph( + self, + batch: "FlashCausalLMBatch", + adapter_data: AdapterBatchData, + ) -> bool: + ranks = adapter_data.ranks() + nranks = len(ranks) + max_rank = max(ranks) if len(ranks) > 0 else 0 + + batch_size = batch.input_ids.shape[0] + max_s = batch.max_seqlen + + # TODO(travis): allow using CUDA graphs with multi-rank batches + return ( + torch.cuda.is_available() + and batch_size <= MAX_BATCH_SIZE + and max_s <= MAX_CONTEXT_LENGTH + and max_rank <= MAX_RANK + and nranks <= 1 + and max_rank in _allowed_ranks + ) + + def get_estimated_cache_memory(self) -> int: + # Store off graphs into temporary cache to discard after estimation + tmp_cache = {} + pool = None + + # Use the largest batch size to overestimate memory overhead + batch_size = CACHED_BATCH_SIZES[-1] + + samples = [] + for i, max_rank in enumerate(reversed(CACHED_MAX_RANKS)): + torch.cuda.synchronize(self.device) + free_memory_before, _ = torch.cuda.mem_get_info(self.device) + + key = (batch_size, max_rank) + graph = GraphWrapper.trace( + self.model, + self.device, + self.adapter_layers, + batch_size, + max_rank, + pool, + ) + tmp_cache[key] = graph + pool = graph.memory_pool + + torch.cuda.synchronize(self.device) + free_memory_after, _ = torch.cuda.mem_get_info(self.device) + + # Measure memory difference after tracing the graph, + # discard first sample to account for global state initialization + delta_memory = free_memory_before - free_memory_after + if i > 0: + samples.append(delta_memory) + + # Tracing all graphs can take a while, so limit the number of samples + if len(samples) == MAX_SAMPLES: + break + + # Estimate memory usage for all batch sizes and ranks + ngraphs = len(CACHED_BATCH_SIZES) * len(CACHED_MAX_RANKS) + per_graph_memory = median(samples) + return ngraphs * per_graph_memory + + def warmup(self): + ngraphs = len(CACHED_BATCH_SIZES) * len(CACHED_MAX_RANKS) + pool = None + with tqdm(total=ngraphs, desc="Trace CUDA graphs") as pbar: + for batch_size in reversed(CACHED_BATCH_SIZES): + pbar.set_postfix({'batch_size': batch_size}) + for max_rank in reversed(CACHED_MAX_RANKS): + key = (batch_size, max_rank) + graph = GraphWrapper.trace( + self.model, + self.device, + self.adapter_layers, + batch_size, + max_rank, + pool, + ) + self.cache[key] = graph + pool = graph.memory_pool + pbar.update(1) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + adapter_data: AdapterBatchData, + lm_head_indices: Optional[torch.Tensor] = None, + **kwargs + ) -> None: + batch_size = get_cached_batch_size(input_ids.shape[0]) + max_rank = adapter_data.max_rank + + key = (batch_size, max_rank) + if key not in self.cache: + self.cache[key] = GraphWrapper.trace( + self.model, + self.device, + self.adapter_layers, + batch_size, + max_rank, + self.memory_pool, + ) + + output_states = self.cache[key].forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, + max_s=max_s, + adapter_data=adapter_data, + lm_head_indices=lm_head_indices, + ) + + return output_states + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index 35d39689d..2e19b3ff7 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -404,7 +404,7 @@ def forward_layer_type( end_idx: int, ) -> torch.Tensor: data = adapter_data.data.get(layer_type) - + if has_sgmv() and data is not None and data.can_vectorize(self.process_group): if end_idx - start_idx != result.shape[1]: proj = torch.zeros_like(result[:, start_idx:end_idx]) @@ -415,8 +415,9 @@ def forward_layer_type( lora_a_ptr = rank_segments.lora_a_ptr lora_b_ptr = rank_segments.lora_b_ptr if lora_a_ptr is not None and lora_b_ptr is not None: - v, tmp = lora_a_sgmv_cutlass( + v = lora_a_sgmv_cutlass( input, + rank_segments.tmp_shrink, lora_a_ptr, rank_segments.segment_starts, rank_segments.segment_ends, @@ -430,7 +431,7 @@ def forward_layer_type( lora_b_sgmv_cutlass( proj, v, - tmp, + rank_segments.tmp_expand, lora_b_ptr, rank_segments.segment_starts, rank_segments.segment_ends, @@ -674,7 +675,7 @@ def _get_rope_config(config): return getattr(config, "rope_scaling", None) class PositionRotaryEmbedding(nn.Module): - def __init__(self, inv_freq, scaling_factor): + def __init__(self, inv_freq, scaling_factor, max_position_embeddings, device, dtype): super().__init__() self.inv_freq = inv_freq self._seq_len_cached = 0 @@ -684,9 +685,10 @@ def __init__(self, inv_freq, scaling_factor): self._sin_k_cached = None self.scaling_factor = scaling_factor self.dynamic_args = None + self._update_cos_sin_cache(dtype, device, max_position_embeddings) @classmethod - def static(cls, config, dim, base, device): + def static(cls, config, dim, base, device, dtype): inv_freq = _create_inv_freq(dim, base, device) scaling_factor = None rope_scaling = _get_rope_config(config) @@ -716,7 +718,7 @@ def static(cls, config, dim, base, device): raise NotImplementedError( f"rope scaling type {rope_type} is not implemented or invalid" ) - return cls(inv_freq, scaling_factor) + return cls(inv_freq, scaling_factor, config.max_position_embeddings, device, dtype) @classmethod def load(cls, config, prefix, weights): @@ -781,9 +783,6 @@ def get_cos_sin( """ Return cos and sin for the asked position ids """ - - self._update_cos_sin_cache(dtype, position_ids.device, max_s) - cos = torch.index_select(self._cos_cached, 0, position_ids) sin = torch.index_select(self._sin_cached, 0, position_ids) return cos.unsqueeze(1), sin.unsqueeze(1) diff --git a/server/lorax_server/utils/lora.py b/server/lorax_server/utils/lora.py index 72eea7e2e..da8587db2 100644 --- a/server/lorax_server/utils/lora.py +++ b/server/lorax_server/utils/lora.py @@ -6,8 +6,7 @@ from peft import LoraConfig from torch.distributed import ProcessGroup -from lorax_server.utils.sgmv import MIN_SGMV_RANK, orient_for_rank -from lorax_server.utils.weights import shard_on_dim +from lorax_server.utils.sgmv import MIN_SGMV_RANK, get_tmp_tensors, orient_for_rank # Constants @@ -28,6 +27,8 @@ @dataclass class RankSegments: rank: int + tmp_shrink: torch.Tensor + tmp_expand: torch.Tensor lora_a_ptr: torch.Tensor lora_b_ptr: torch.Tensor segment_starts: torch.Tensor @@ -72,6 +73,18 @@ def from_meta(meta: AdapterBatchMetadata, weights: Dict[str, "BatchedLoraWeights continue data[k] = v.get_data(meta) return AdapterBatchData(meta=meta, data=data) + + def ranks(self) -> Set[int]: + return set( + rank_data.rank + for layer in self.data.values() + for rank_data in layer.rank_data.values() + ) + + @property + def max_rank(self) -> int: + ranks = self.ranks() + return max(ranks) if len(ranks) > 0 else 0 class MergedLoraWeights: @@ -124,7 +137,9 @@ def get_data(self, meta: AdapterBatchMetadata) -> AdapterWeightData: AdapterWeightData: The adapter weight data. """ - device = list(self.lora_weights.values())[0].weights_a.device + first_weights = list(self.lora_weights.values())[0] + device = first_weights.weights_a.device + dtype = first_weights.weights_a.dtype segment_indices = meta.segment_indices lora_a = { @@ -174,9 +189,14 @@ def get_data(self, meta: AdapterBatchMetadata) -> AdapterWeightData: rank_data = {} for rank, indices in rank_indices.items(): + lora_a_ptr_indices = lora_a_ptr[indices] + tmp_shrink, tmp_expand = get_tmp_tensors(lora_a_ptr_indices.size(0), rank, device) + rank_data[rank] = RankSegments( rank=rank, - lora_a_ptr=lora_a_ptr[indices], + tmp_shrink=tmp_shrink, + tmp_expand=tmp_expand, + lora_a_ptr=lora_a_ptr_indices, lora_b_ptr=lora_b_ptr[indices], segment_starts=meta.adapter_segments[indices], segment_ends=meta.adapter_segments[[i+1 for i in indices]], diff --git a/server/lorax_server/utils/sgmv.py b/server/lorax_server/utils/sgmv.py index 874ca20a8..26b383376 100644 --- a/server/lorax_server/utils/sgmv.py +++ b/server/lorax_server/utils/sgmv.py @@ -23,6 +23,10 @@ def has_sgmv() -> bool: return HAS_SGMV +def use_cutlass_shrink(lora_rank: int) -> bool: + return lora_rank < MIN_RANK_CUSTOM + + def orient_for_rank(t: torch.Tensor, rank: int) -> torch.Tensor: if MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM: return t.transpose(0, 1) @@ -90,29 +94,41 @@ def get_tmp_tensor(device: torch.device) -> torch.Tensor: return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device) -@lru_cache(maxsize=1) +@lru_cache(maxsize=32) def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor: tmp_size = _kernels.sgmv_cutlass_tmp_size(size) return torch.empty((tmp_size,), dtype=torch.uint8, device=device) +def get_tmp_expand_size(size: int) -> int: + return _kernels.sgmv_cutlass_tmp_size(size) + + +def get_tmp_tensors(nsegments: int, lora_rank: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]: + if use_cutlass_shrink(lora_rank): + tmp = get_tmp_tensor_for_size(nsegments, device) + return tmp, tmp + else: + tmp_shrink = get_tmp_tensor(device) + tmp_expand = get_tmp_tensor_for_size(nsegments, device) + return tmp_shrink, tmp_expand + + def lora_a_sgmv_cutlass( x: torch.Tensor, + tmp: torch.Tensor, wa_ptr: torch.Tensor, s_start: torch.IntTensor, s_end: torch.IntTensor, layer_idx: int, lora_rank: int, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> torch.Tensor: v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device) if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM: - tmp1 = get_tmp_tensor(x.device) - tmp = get_tmp_tensor_for_size(wa_ptr.size(0), x.device) - _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx) + _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp, layer_idx) else: - tmp = get_tmp_tensor_for_size(wa_ptr.size(0), x.device) _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx) - return v, tmp + return v def lora_b_sgmv_cutlass( @@ -125,3 +141,44 @@ def lora_b_sgmv_cutlass( layer_idx: int, ): _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx) + + +""" +Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ wa_T_all[indices[i], layer_idx, :, :].transpose(-1, -2) + @ wb_T_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + +Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + v: Shape: `[B, R]`. Temporary vector. + x: Shape: `[B, H1]`. Input vectors. + wa_T_all: Shape: `[None, L, R, H1]`. All of the transposed LoRA A matrices. + wb_T_all: Shape: `[None, L, H2, R]`. All of the transposed LoRA B matrices. + indicies: Shape: `[B]`. Indices of the LoRA weights. + layer_idx: Layer index of LoRA weights. + scale: Scaling factor. +""" + + +def add_lora_a_bgmv( + v: torch.Tensor, + x: torch.Tensor, + wa_T_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, +): + _kernels.dispatch_bgmv(v, x, wa_T_all, indicies, layer_idx, 1.0) + + +def add_lora_b_bgmv( + y: torch.Tensor, + v: torch.Tensor, + wb_T_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, +): + _kernels.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, 1.0) diff --git a/server/punica_kernels/README.md b/server/punica_kernels/README.md index 82e537cbc..60dcc2ffe 100644 --- a/server/punica_kernels/README.md +++ b/server/punica_kernels/README.md @@ -1,3 +1,5 @@ These kernels are forked from the [Punica](https://github.com/punica-ai/punica) project. -Forked from commit: https://github.com/punica-ai/punica/commit/87cb9f504cf4e97eb1339b0fbfcca15b3273a5d6 \ No newline at end of file +Forked from commit: https://github.com/punica-ai/punica/commit/07a40b9d30e98d88963e8a7e140120a25ac0d518 + +Modifications to BGMV kernel from vLLM: https://github.com/vllm-project/vllm \ No newline at end of file diff --git a/server/punica_kernels/punica_kernels/bgmv/bgmv_config.h b/server/punica_kernels/punica_kernels/bgmv/bgmv_config.h index 26edcf486..460eb2735 100644 --- a/server/punica_kernels/punica_kernels/bgmv/bgmv_config.h +++ b/server/punica_kernels/punica_kernels/bgmv/bgmv_config.h @@ -2,9 +2,10 @@ template void bgmv_kernel(T* __restrict__ Y, const T* __restrict__ X, - const T* __restrict__ W, const int64_t* __restrict__ indicies, - int64_t batch_size, int64_t num_layers, int64_t layer_idx, - float scale); + T** __restrict__ W, + const int64_t* __restrict__ indicies, int64_t y_offset, + int64_t full_y_size, int64_t batch_size, + int64_t layer_idx, float scale); // clang-format off diff --git a/server/punica_kernels/punica_kernels/bgmv/bgmv_impl.cuh b/server/punica_kernels/punica_kernels/bgmv/bgmv_impl.cuh index d2ad68722..c30181e86 100644 --- a/server/punica_kernels/punica_kernels/bgmv/bgmv_impl.cuh +++ b/server/punica_kernels/punica_kernels/bgmv/bgmv_impl.cuh @@ -1,54 +1,60 @@ #pragma once +#include #include -#include - #include +#include #include +#include #include "flashinfer/vec_dtypes.cuh" namespace cg = cooperative_groups; // nthrs = (32, 4) -template -__global__ void bgmv_shrink_kernel(T* __restrict__ Y, const T* __restrict__ X, - const T* __restrict__ W, - const int64_t* __restrict__ indicies, - int64_t num_layers, int64_t layer_idx, - float scale) { +template +__global__ void +bgmv_shrink_kernel(T* __restrict__ Y, const T* __restrict__ X, + T** __restrict__ W, + const int64_t* __restrict__ indicies, int64_t y_offset, + int64_t full_y_size, int64_t layer_idx, + float scale) { + size_t batch_idx = blockIdx.y; + int64_t idx = indicies[batch_idx]; + if (idx < 0) { + return; + } + auto block = cg::this_thread_block(); size_t j = blockIdx.x; - size_t batch_idx = blockIdx.y; - constexpr size_t vec_size = 16 / sizeof(T); - constexpr size_t tx = 32; - constexpr size_t ty = 4; constexpr size_t num_pipeline_stages = 2; constexpr size_t tile_size = tx * ty * vec_size; __shared__ T W_shared[num_pipeline_stages * tile_size]; __shared__ T X_shared[num_pipeline_stages * tile_size]; __shared__ float y_warpwise[ty]; - int64_t idx = indicies[batch_idx] * num_layers + layer_idx; - size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size}; size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size}; auto pipe = cuda::make_pipeline(); + const T* W_ptr = W[idx]; + // pipeline load W/X and compute WX; pipe.producer_acquire(); cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, - W + (idx * feat_out + j) * feat_in + + W_ptr + (layer_idx * feat_out + j) * feat_in + + (threadIdx.y * tx + threadIdx.x) * vec_size, + cuda::aligned_size_t(W_copy_size), pipe); + cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, + X + (batch_idx * feat_in) + (threadIdx.y * tx + threadIdx.x) * vec_size, - cuda::aligned_size_t<16>(16), pipe); - cuda::memcpy_async( - X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, - X + (batch_idx * feat_in) + (threadIdx.y * tx + threadIdx.x) * vec_size, - cuda::aligned_size_t<16>(16), pipe); + cuda::aligned_size_t(X_copy_size), pipe); pipe.producer_commit(); size_t copy_idx, compute_idx; float y = 0.f; - flashinfer::vec_t x_vec, w_vec; + flashinfer::vec_t x_vec; + flashinfer::vec_t w_vec; size_t tile_idx; #pragma unroll @@ -60,15 +66,15 @@ __global__ void bgmv_shrink_kernel(T* __restrict__ Y, const T* __restrict__ X, if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) { cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] + (threadIdx.y * tx + threadIdx.x) * vec_size, - W + (idx * feat_out + j) * feat_in + + W_ptr + (layer_idx * feat_out + j) * feat_in + tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x) * vec_size, - cuda::aligned_size_t<16>(16), pipe); + cuda::aligned_size_t(W_copy_size), pipe); cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] + (threadIdx.y * tx + threadIdx.x) * vec_size, X + (batch_idx * feat_in) + tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x) * vec_size, - cuda::aligned_size_t<16>(16), pipe); + cuda::aligned_size_t(X_copy_size), pipe); } pipe.producer_commit(); @@ -132,27 +138,30 @@ __global__ void bgmv_shrink_kernel(T* __restrict__ Y, const T* __restrict__ X, // write Y; if (block.thread_rank() == 0) { - Y[batch_idx * feat_out + j] += y; + Y[batch_idx * full_y_size + y_offset + j] += static_cast(y); } } // nthrs = (2, 16, 4) -template -__global__ void bgmv_expand_kernel(T* __restrict__ Y, const T* __restrict__ X, - const T* __restrict__ W, - const int64_t* __restrict__ indicies, - int64_t num_layers, int64_t layer_idx, - float scale) { +template +__global__ void +bgmv_expand_kernel(T* __restrict__ Y, const T* __restrict__ X, + T** __restrict__ W, + const int64_t* __restrict__ indicies, int64_t y_offset, + int64_t full_y_size, int64_t layer_idx, + float scale) { + size_t batch_idx = blockIdx.y; + int64_t idx = indicies[batch_idx]; + + if (idx < 0) { + return; + } + auto block = cg::this_thread_block(); - constexpr size_t vec_size = 16 / sizeof(T); - constexpr size_t tx = feat_in / vec_size; - static_assert(feat_in % vec_size == 0); - constexpr size_t ty = 32 / tx; - static_assert(32 % tx == 0); - constexpr size_t tz = 4; size_t tile_idx = blockIdx.x; - size_t batch_idx = blockIdx.y; - int64_t idx = indicies[batch_idx] * num_layers + layer_idx; + + const T* W_ptr = W[idx]; // load X; flashinfer::vec_t x_vec; @@ -160,7 +169,7 @@ __global__ void bgmv_expand_kernel(T* __restrict__ Y, const T* __restrict__ X, // load W; flashinfer::vec_t w_vec; - w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in + + w_vec.load(W_ptr + (layer_idx * feat_out + tile_idx * tz * ty) * feat_in + block.thread_rank() * vec_size); float sum = 0.f; @@ -177,41 +186,111 @@ __global__ void bgmv_expand_kernel(T* __restrict__ Y, const T* __restrict__ X, sum = g.shfl(sum, 0); if (threadIdx.x == 0) { - Y[batch_idx * feat_out + tile_idx * (tz * ty) + threadIdx.z * ty + - threadIdx.y] += sum; + Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + + threadIdx.z * ty + threadIdx.y] += static_cast(sum); } } template void bgmv_kernel(T* __restrict__ Y, const T* __restrict__ X, - const T* __restrict__ W, const int64_t* __restrict__ indicies, - int64_t batch_size, int64_t num_layers, int64_t layer_idx, - float scale) { - size_t vec_size = 16 / sizeof(T); + T** __restrict__ W, + const int64_t* __restrict__ indicies, int64_t y_offset, + int64_t full_y_size, int64_t batch_size, + int64_t layer_idx, float scale) { + constexpr size_t vec_size = 8; + constexpr int tz = 4; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if constexpr (feat_in < feat_out) { - size_t tx = feat_in / vec_size; - size_t ty = 32 / tx; - size_t tz = 4; - dim3 nblks(feat_out / (ty * tz), batch_size); - dim3 nthrs(tx, ty, tz); - - bgmv_expand_kernel - <<>>(Y, X, W, indicies, num_layers, layer_idx, scale); + static_assert(feat_in % vec_size == 0); + constexpr int tx = feat_in / vec_size; + + static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) || + (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) || + (8 % tx == 0 && feat_out % (8 / tx * tz) == 0)); + + if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) { + constexpr int ty = 32 / tx; + dim3 nblks(feat_out / (ty * tz), batch_size); + dim3 nthrs(tx, ty, tz); + + bgmv_expand_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, layer_idx, + scale); + } else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) { + constexpr int ty = 16 / tx; + dim3 nblks(feat_out / (ty * tz), batch_size); + dim3 nthrs(tx, ty, tz); + + bgmv_expand_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, layer_idx, + scale); + } else { + constexpr int ty = 8 / tx; + dim3 nblks(feat_out / (ty * tz), batch_size); + dim3 nthrs(tx, ty, tz); + + bgmv_expand_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, layer_idx, + scale); + } } else { - assert(feat_in % (vec_size * 32) == 0); - dim3 nblks(feat_out, batch_size); - dim3 nthrs(32, 4); - bgmv_shrink_kernel - <<>>(Y, X, W, indicies, num_layers, layer_idx, scale); + static_assert(feat_in % (vec_size * 32) == 0 || + feat_in % (vec_size * 16) == 0 || + feat_in % (vec_size * 8) == 0); + + if constexpr (feat_in % (vec_size * 32) == 0) { + constexpr int tx = 32; + constexpr int ty = 4; + + dim3 nblks(feat_out, batch_size); + dim3 nthrs(tx, ty); + + bgmv_shrink_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, layer_idx, + scale); + } else if constexpr (feat_in % (vec_size / 2 * 32) == 0) { + constexpr int tx = 32; + constexpr int ty = 4; + + dim3 nblks(feat_out, batch_size); + dim3 nthrs(tx, ty); + + bgmv_shrink_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, layer_idx, + scale); + } else if constexpr (feat_in % (vec_size / 2 * 16) == 0) { + constexpr int tx = 16; + constexpr int ty = 4; + + dim3 nblks(feat_out, batch_size); + dim3 nthrs(tx, ty); + + bgmv_shrink_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, layer_idx, + scale); + } } } -#define INST_BGMV(feat_in, feat_out, T) \ - template void bgmv_kernel( \ - T* __restrict__ Y, const T* __restrict__ X, const T* __restrict__ W, \ - const int64_t* __restrict__ indicies, int64_t batch_size, \ - int64_t num_layers, int64_t layer_idx, float scale); +#define INST_BGMV(feat_in, feat_out, T) \ + template void bgmv_kernel( \ + T* __restrict__ Y, const T* __restrict__ X, \ + T** __restrict__ W, const int64_t* __restrict__ indicies, \ + int64_t y_offset, int64_t full_y_size, int64_t batch_size, \ + int64_t layer_idx, float scale); -#define INST_BGMV_TWOSIDE(T, narrow, wide) \ - INST_BGMV(narrow, wide, T) \ - INST_BGMV(wide, narrow, T) +#define INST_BGMV_TWOSIDE(T, narrow, wide) \ + INST_BGMV(narrow, wide, T) \ + INST_BGMV(wide, narrow, T) \ No newline at end of file diff --git a/server/punica_kernels/punica_kernels/punica_ops.cc b/server/punica_kernels/punica_kernels/punica_ops.cc index 4f0a84c29..3722fb9a8 100644 --- a/server/punica_kernels/punica_kernels/punica_ops.cc +++ b/server/punica_kernels/punica_kernels/punica_ops.cc @@ -1,3 +1,4 @@ +#include #include #include #include @@ -250,16 +251,18 @@ void append_kv(torch::Tensor kv_ptrs, torch::Tensor kv_indptr, //====== bgmv ====== template -inline bool launch_bgmv_kernel(T* Y, const T* X, const T* W, +inline bool launch_bgmv_kernel(T* Y, const T* X, T** W, const int64_t* lora_indices, uint16_t in_features, uint16_t out_features, - int64_t batch_size, int64_t num_layers, + int64_t y_offset, int64_t full_y_size, + int64_t batch_size, int64_t layer_idx, float scale) { switch (pack_u16(in_features, out_features)) { #define CASE_ONESIDE(_T, feat_in, feat_out) \ case pack_u16(feat_in, feat_out): \ - bgmv_kernel(Y, X, W, lora_indices, batch_size, \ - num_layers, layer_idx, scale); \ + bgmv_kernel(Y, X, W, lora_indices, y_offset, \ + full_y_size, batch_size, \ + layer_idx, scale); \ break; #define CASE(_T, narrow, wide) \ CASE_ONESIDE(T, narrow, wide) \ @@ -275,24 +278,21 @@ inline bool launch_bgmv_kernel(T* Y, const T* X, const T* W, return true; } -void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, +void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr, torch::Tensor indicies, int64_t layer_idx, float scale) { CHECK_INPUT(y); CHECK_INPUT(x); - CHECK_INPUT(w); + CHECK_INPUT(w_ptr); CHECK_INPUT(indicies); CHECK_DIM(2, y); CHECK_DIM(2, x); - CHECK_DIM(4, w); + CHECK_DIM(1, w_ptr); CHECK_DIM(1, indicies); int64_t B = x.size(0); int64_t h_in = x.size(1); int64_t h_out = y.size(1); - int64_t num_layers = w.size(1); - CHECK_EQ(w.size(3), h_in); - CHECK_EQ(w.size(2), h_out); CHECK_EQ(indicies.size(0), x.size(0)); CHECK_EQ(y.size(0), x.size(0)); bool ok = false; @@ -301,16 +301,16 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, case at::ScalarType::Half: ok = launch_bgmv_kernel(static_cast(y.data_ptr()), static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, B, - num_layers, layer_idx, scale); + static_cast(w_ptr.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, h_out, B, + layer_idx, scale); break; case at::ScalarType::BFloat16: ok = launch_bgmv_kernel(static_cast(y.data_ptr()), static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, B, - num_layers, layer_idx, scale); + static_cast(w_ptr.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, h_out, B, + layer_idx, scale); break; default: break; @@ -343,11 +343,12 @@ void dispatch_sgmv_cutlass(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr int d_in = x.size(1); int d_out = y.size(1); CHECK_EQ(tmp.size(0), static_cast(sgmv_tmp_size(num_problems))); + cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); bool ok = DISPATCH_TORCH_DTYPE(x.scalar_type(), [&] { return sgmv((c_type*)y.data_ptr(), (c_type*)x.data_ptr(), (c_type**)w_ptr.data_ptr(), s_start.data_ptr(), s_end.data_ptr(), tmp.data_ptr(), num_problems, d_in, d_out, - layer_idx); + layer_idx, stream); }); TORCH_CHECK(ok, "No suitable kernel.", " dtype=", x.scalar_type()); } @@ -373,13 +374,14 @@ void dispatch_sgmv_shrink(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr, uint32_t d_out = y.size(1); CHECK_EQ(tmp.scalar_type(), at::ScalarType::Byte); CHECK_EQ(tmp.size(0), 8 * 1024 * 1024); + cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); #define CASE(_T, D_OUT) \ case D_OUT: \ return sgmv_shrink( \ (c_type*)y.data_ptr(), (c_type*)x.data_ptr(), \ (c_type**)w_ptr.data_ptr(), s_start.data_ptr(), s_end.data_ptr(), \ - tmp.data_ptr(), num_problems, d_in, layer_idx); + tmp.data_ptr(), num_problems, d_in, layer_idx, stream); bool ok = DISPATCH_TORCH_DTYPE(x.scalar_type(), [&] { switch (d_out) { diff --git a/server/punica_kernels/punica_kernels/sgmv/sgmv.h b/server/punica_kernels/punica_kernels/sgmv/sgmv.h index b3a92ba49..5c2215f90 100644 --- a/server/punica_kernels/punica_kernels/sgmv/sgmv.h +++ b/server/punica_kernels/punica_kernels/sgmv/sgmv.h @@ -1,5 +1,10 @@ +#pragma once +#include + +#include + template bool sgmv(DType *y, DType *x, DType **w, int32_t *s_start, int32_t *s_end, - void *tmp_d, int num_problems, int d_in, int d_out, int layer_idx); + void *tmp_d, int num_problems, int d_in, int d_out, int layer_idx, cudaStream_t stream); size_t sgmv_tmp_size(int num_problems); diff --git a/server/punica_kernels/punica_kernels/sgmv/sgmv_cutlass.cu b/server/punica_kernels/punica_kernels/sgmv/sgmv_cutlass.cu index 0cd68bcd9..96917c0d6 100644 --- a/server/punica_kernels/punica_kernels/sgmv/sgmv_cutlass.cu +++ b/server/punica_kernels/punica_kernels/sgmv/sgmv_cutlass.cu @@ -6,9 +6,9 @@ template bool sgmv(nv_half *y, nv_half *x, nv_half **w, int32_t *s_start, int32_t *s_end, void *tmp_d, int num_problems, int d_in, int d_out, - int layer_idx); + int layer_idx, cudaStream_t stream); template bool sgmv(nv_bfloat16 *y, nv_bfloat16 *x, nv_bfloat16 **w, int32_t *s_start, int32_t *s_end, void *tmp_d, int num_problems, int d_in, int d_out, - int layer_idx); + int layer_idx, cudaStream_t stream); diff --git a/server/punica_kernels/punica_kernels/sgmv/sgmv_cutlass.cuh b/server/punica_kernels/punica_kernels/sgmv/sgmv_cutlass.cuh index 69e180c41..906fd2a80 100644 --- a/server/punica_kernels/punica_kernels/sgmv/sgmv_cutlass.cuh +++ b/server/punica_kernels/punica_kernels/sgmv/sgmv_cutlass.cuh @@ -37,6 +37,11 @@ __global__ void precompute_sgmv_args(cutlass::gemm::GemmCoord *all_problems, int layer_idx) { int i = blockIdx.x; int m = s_end[i] - s_start[i], k = d_in, n = d_out; + if (m <= 0) { + m = 0; + n = 0; + k = 0; + } all_problems[i] = cutlass::gemm::GemmCoord(m, n, k); ptr_w[i] = w[i] + layer_idx * d_in * d_out; ptr_x[i] = x + s_start[i] * d_in; @@ -61,7 +66,8 @@ inline T *alloc_from_buf(void **buf, int n) { template bool sgmv(DType *y, DType *x, DType **w, int32_t *s_start, int32_t *s_end, - void *tmp_d, int num_problems, int d_in, int d_out, int layer_idx) { + void *tmp_d, int num_problems, int d_in, int d_out, int layer_idx, + cudaStream_t stream) { using cutlass_t = typename cutlass_dtype::type; auto ptr_Y = alloc_from_buf(&tmp_d, num_problems); @@ -73,7 +79,7 @@ bool sgmv(DType *y, DType *x, DType **w, int32_t *s_start, int32_t *s_end, auto all_problems = alloc_from_buf(&tmp_d, num_problems); - precompute_sgmv_args<<>>( + precompute_sgmv_args<<>>( all_problems, ptr_Y, ptr_X, ptr_W, ld_Y, ld_X, ld_W, (cutlass_t *)y, (cutlass_t *)x, (cutlass_t **)w, s_start, s_end, d_in, d_out, layer_idx); @@ -112,13 +118,13 @@ bool sgmv(DType *y, DType *x, DType **w, int32_t *s_start, int32_t *s_end, ptr_Y, ld_X, ld_W, ld_Y, ld_Y); GemmGrouped gemm; - auto status = gemm.initialize(args); + auto status = gemm.initialize(args, nullptr, stream); if (status != cutlass::Status::kSuccess) { fprintf(stderr, "sgmv_cutlass gemm.initialize failed: %s\n", cutlassGetStatusString(status)); return false; } - status = gemm.run(); + status = gemm.run(stream); if (status != cutlass::Status::kSuccess) { fprintf(stderr, "sgmv_cutlass gemm.run failed: %s\n", cutlassGetStatusString(status)); @@ -157,13 +163,13 @@ bool sgmv(DType *y, DType *x, DType **w, int32_t *s_start, int32_t *s_end, ptr_Y, ld_X, ld_W, ld_Y, ld_Y); GemmGrouped gemm; - auto status = gemm.initialize(args); + auto status = gemm.initialize(args, nullptr, stream); if (status != cutlass::Status::kSuccess) { fprintf(stderr, "sgmv_cutlass gemm.initialize failed: %s\n", cutlassGetStatusString(status)); return false; } - status = gemm.run(); + status = gemm.run(stream); if (status != cutlass::Status::kSuccess) { fprintf(stderr, "sgmv_cutlass gemm.run failed: %s\n", cutlassGetStatusString(status)); diff --git a/server/punica_kernels/punica_kernels/sgmv_flashinfer/sgmv_all.cu b/server/punica_kernels/punica_kernels/sgmv_flashinfer/sgmv_all.cu index b4ce71df0..99e8b827a 100644 --- a/server/punica_kernels/punica_kernels/sgmv_flashinfer/sgmv_all.cu +++ b/server/punica_kernels/punica_kernels/sgmv_flashinfer/sgmv_all.cu @@ -9,7 +9,7 @@ template bool sgmv_shrink(T* y, T* x, T** w, int32_t* s_start, int32_t* s_end, void* tmp, - uint32_t num_problems, uint32_t d_in, uint32_t layer_idx) { + uint32_t num_problems, uint32_t d_in, uint32_t layer_idx, cudaStream_t stream) { static_assert(d_out % 16 == 0); constexpr uint32_t num_warps = 4; @@ -18,15 +18,15 @@ bool sgmv_shrink(T* y, T* x, T** w, int32_t* s_start, int32_t* s_end, void* tmp, constexpr uint32_t num_blocks_n = d_out / 16; uint32_t smem = num_stages * sizeof(T) * num_k_frags_per_stage * 16 * 16 * (num_warps + num_blocks_n); - cudaStream_t stream = nullptr; auto cooperative_kernel = flashinfer::sgmv::sgmv_shrink; auto kernel = flashinfer::sgmv::sgmv_shrink; - uint32_t dev_id = 0; + int dev_id = 0; int num_blocks_per_sm = 0; int num_sm = 0; bool use_cooperative = true; + cudaGetDevice(&dev_id); cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id); cudaOccupancyMaxActiveBlocksPerMultiprocessor( &num_blocks_per_sm, cooperative_kernel, num_warps * 32, smem); @@ -67,7 +67,7 @@ bool sgmv_shrink(T* y, T* x, T** w, int32_t* s_start, int32_t* s_end, void* tmp, #define INST(T, d_out) \ template bool sgmv_shrink(T * y, T * x, T * *w, int32_t * s_start, int32_t * s_end, \ void* tmp, uint32_t num_problems, \ - uint32_t d_in, uint32_t layer_idx); + uint32_t d_in, uint32_t layer_idx, cudaStream_t stream); FOR_SGMV_NARROW(INST, nv_half); FOR_SGMV_NARROW(INST, nv_bfloat16); \ No newline at end of file diff --git a/server/punica_kernels/punica_kernels/sgmv_flashinfer/sgmv_config.h b/server/punica_kernels/punica_kernels/sgmv_flashinfer/sgmv_config.h index f69876c9d..19044c946 100644 --- a/server/punica_kernels/punica_kernels/sgmv_flashinfer/sgmv_config.h +++ b/server/punica_kernels/punica_kernels/sgmv_flashinfer/sgmv_config.h @@ -3,7 +3,7 @@ template bool sgmv_shrink(T* y, T* x, T** w, int32_t* s_start, int32_t* s_end, void* tmp, - uint32_t num_problems, uint32_t d_in, uint32_t layer_idx); + uint32_t num_problems, uint32_t d_in, uint32_t layer_idx, cudaStream_t stream); // clang-format off diff --git a/server/punica_kernels/punica_kernels/sgmv_flashinfer/sgmv_flashinfer.cuh b/server/punica_kernels/punica_kernels/sgmv_flashinfer/sgmv_flashinfer.cuh index a1a4fd938..4f1424b23 100644 --- a/server/punica_kernels/punica_kernels/sgmv_flashinfer/sgmv_flashinfer.cuh +++ b/server/punica_kernels/punica_kernels/sgmv_flashinfer/sgmv_flashinfer.cuh @@ -20,6 +20,7 @@ __global__ void sgmv_shrink(T* y, T* x, T** w, IdType* s_starts, IdType* s_ends, constexpr auto fill_mode = cp_async::SharedMemFillMode::kFillZero; const uint32_t problem_id = blockIdx.y; const uint32_t bx = blockIdx.x; + constexpr uint32_t num_stages = 2; constexpr uint32_t num_k_frags = 8; constexpr uint32_t num_cells_k = (num_k_frags * 16) / cell_capacity(); diff --git a/server/tests/utils/test_sgmv.py b/server/tests/utils/test_sgmv.py index 5563b2009..d13a89c5b 100644 --- a/server/tests/utils/test_sgmv.py +++ b/server/tests/utils/test_sgmv.py @@ -1,22 +1,37 @@ +from typing import List, Tuple import pytest import torch -from lorax_server.utils.sgmv import add_lora_sgmv_cutlass, has_sgmv, orient_for_rank +from lorax_server.utils.sgmv import ( + get_tmp_tensors, + lora_a_sgmv_cutlass, + lora_b_sgmv_cutlass, + has_sgmv, + use_cutlass_shrink, +) def lora_ref_impl( y: torch.Tensor, x: torch.Tensor, - wa: torch.Tensor, - wb: torch.Tensor, + wa: List[torch.Tensor], + wb: List[torch.Tensor], s_start: torch.IntTensor, s_end: torch.IntTensor, layer_idx: int, + lora_rank: int, ): for i in range(len(wa)): + if s_end[i] - s_start[i] <= 0: + continue + xi = x[s_start[i]:s_end[i]] wai = wa[i][layer_idx, :, :] wbi = wb[i][layer_idx, :, :] + + if not use_cutlass_shrink(lora_rank): + wai = wai.t() + yi = y[s_start[i]:s_end[i]] tmp = (xi @ wai) y[s_start[i]:s_end[i]] = (yi + tmp @ wbi) @@ -24,8 +39,14 @@ def lora_ref_impl( @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(not has_sgmv(), reason="SGMV not available") -@pytest.mark.parametrize("lora_rank", [4, 8, 16, 32, 64, 128]) -def test_add_lora_sgmv_cutlass(lora_rank: int): +@pytest.mark.parametrize("segments", [ + ([0, 2], [1, 3]), + ([0, -1], [1, -1]), +]) +@pytest.mark.parametrize("lora_rank", [8, 16, 32, 64, 128]) +def test_add_lora_sgmv(lora_rank: int, segments: Tuple[List[int], List[int]]): + torch.manual_seed(42) + B = 3 H = 1024 r = lora_rank @@ -35,44 +56,48 @@ def test_add_lora_sgmv_cutlass(lora_rank: int): y = torch.zeros((B, H), dtype=torch.float16, device=device) x = torch.randn((B, H), dtype=torch.float16, device=device) - wa = torch.randn(nlayers, H, r, dtype=torch.float16, device=device) + wa = torch.randn(nlayers, r, H, dtype=torch.float16, device=device) + if use_cutlass_shrink(r): + # cutlass uses (H, r) layout + wa = wa.transpose(1, 2).contiguous() + + # TODO(travis): transpose (r, H) -> (H, r) when not using cutlass wb = torch.randn(nlayers, r, H, dtype=torch.float16, device=device) - wa_sgmv = orient_for_rank(wa, lora_rank) - wa_ptr = torch.tensor([wa_sgmv.data_ptr(), wa_sgmv.data_ptr()], dtype=torch.int64, device=device) - wb_ptr = torch.tensor([wb.data_ptr(), wb.data_ptr()], dtype=torch.int64, device=device) + s1, s2 = segments + s_start = torch.tensor(s1, dtype=torch.int32, device=device) + s_end = torch.tensor(s2, dtype=torch.int32, device=device) + + wa_list = [wa if y - x > 0 else None for x, y in zip(s1, s2)] + wb_list = [wb if y - x > 0 else None for x, y in zip(s1, s2)] - s_start = torch.tensor([0, 2], dtype=torch.int32, device=device) - s_end = torch.tensor([1, 3], dtype=torch.int32, device=device) + wa_ptr = torch.tensor([wa.data_ptr() if wa is not None else 0 for wa in wa_list], dtype=torch.int64, device=device) + wb_ptr = torch.tensor([wb.data_ptr() if wb is not None else 0 for wb in wb_list], dtype=torch.int64, device=device) layer_idx = 0 y_ref = y.clone() - lora_ref_impl(y_ref, x, [wa, wa], [wb, wb], s_start, s_end, layer_idx) - # print(y_ref) + lora_ref_impl(y_ref, x, wa_list, wb_list, s_start, s_end, layer_idx, r) - lora_a = wa[0, :, :] - lora_b = wb[0, :, :] - mask = torch.tensor([1.0, 0.0, 1.0], dtype=torch.float16, device=device) - # out = torch.matmul(torch.matmul(x, lora_a), lora_b) - out = ((x @ lora_a) @ lora_b) * mask.view(-1, 1) - # print(x @ lora_a) - # print(out) + tmp_shrink, tmp_expand = get_tmp_tensors(wa_ptr.size(0), r, x.device) + y_ours = torch.zeros((B, H), dtype=torch.float16, device=device) - # assert torch.allclose(y_ref, out) + v = lora_a_sgmv_cutlass(x, tmp_shrink, wa_ptr, s_start, s_end, layer_idx, r) + lora_b_sgmv_cutlass(y_ours, v, tmp_expand, wb_ptr, s_start, s_end, layer_idx) - y_ours = torch.zeros((B, H), dtype=torch.float16, device=device) - add_lora_sgmv_cutlass( - y_ours, - x, - wa_ptr, - wb_ptr, - s_start, - s_end, - layer_idx, - r, - ) - # print(y_ours) - - # assert torch.allclose(y_ref, y_ours) - assert torch.allclose(out, y_ours) + assert torch.allclose(y_ref, y_ours, rtol=1e-2, atol=1e-3) + + # graph trace + tmp_shrink, tmp_expand = get_tmp_tensors(wa_ptr.size(0), r, x.device) + y_ours_graph = torch.zeros((B, H), dtype=torch.float16, device=device) + + torch.cuda.synchronize(device) + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, pool=None): + v = lora_a_sgmv_cutlass(x, tmp_shrink, wa_ptr, s_start, s_end, layer_idx, r) + lora_b_sgmv_cutlass(y_ours_graph, v, tmp_expand, wb_ptr, s_start, s_end, layer_idx) + + torch.cuda.synchronize(device) + graph.replay() + + assert torch.allclose(y_ours, y_ours_graph, rtol=1e-2, atol=1e-3) \ No newline at end of file