CUDA graph compilation (#154)
tgaddair authored Jan 4, 2024
1 parent 6bfc3a2 commit f20789d
Showing 50 changed files with 1,006 additions and 184 deletions.
43 changes: 43 additions & 0 deletions docs/guides/cuda_graphs.md
@@ -0,0 +1,43 @@
LoRAX supports compiling the model into a static CUDA graph to speed up inference by upwards of 2x. See [Accelerating PyTorch with CUDA Graphs](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/) for more details on CUDA graphs and how they can reduce latency.
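
For intuition, CUDA graphs work by recording the full sequence of GPU kernels launched for a fixed-shape forward pass once, then replaying that recording with new data, which removes per-kernel launch overhead. Below is a minimal PyTorch sketch of the capture-and-replay pattern (illustrative only, with a toy model and shapes; this is not LoRAX's actual implementation):

```python
import torch

# Toy model and a static input buffer with a fixed shape.
model = torch.nn.Linear(1024, 1024).cuda().eval()
static_input = torch.randn(8, 1024, device="cuda")

with torch.no_grad():
    # Warm up on a side stream before capture, as recommended by the PyTorch docs.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            model(static_input)
    torch.cuda.current_stream().wait_stream(s)

    # Capture one forward pass into a CUDA graph.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = model(static_input)

# To run on new data: copy it into the static buffer, then replay.
# All kernels launch in a single shot instead of one Python call per op.
static_input.copy_(torch.randn(8, 1024, device="cuda"))
graph.replay()
result = static_output.clone()
```

Because the captured graph bakes in tensor shapes and buffer addresses, new inputs must be copied into the same static buffers before each replay; this is also why the limitations listed below constrain batch size and context length.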

## Usage

To enable this (experimental) feature:

```
lorax-launcher ... --compile
```
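
The flag can also be set through its environment variable, documented in the launcher reference below, which is often more convenient in containerized deployments (assuming the launcher follows the usual behavior of treating `true` as enabling a boolean flag):

```
COMPILE=true lorax-launcher ...
```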

## When should I use this?

CUDA graph compilation is a simple way to decrease latency for smaller LLMs (O(1b params)) that are compute bound rather than memory bound.

There is a tradeoff to be aware of when using CUDA graphs, namely that they increase memory overhead by 3-10 GB depending on model size. However, the observed decrease in latency can be as much as 50%, so if you don't need very large batch sizes and are more latency constrained than throughput constrained, this is a very compelling feature to enable.

In practice, CUDA graphs are most useful in cases where there are excess GPU FLOPS available, such as during decoding. As such, we do not use the compiled version of the model during prefill, only during the decoding steps. This means the benefits of enabling compilation are most pronounced when generating longer sequences, for which more time is spent decoding.

## Limitations

Current limitations:

- Batch size < 256
- Context length (input + output) < 8192
- LoRA rank >= 8 and <= 64
- Only one LoRA rank in the batch
- 1 GPU (no sharding)

If any of these conditions are not met, then LoRAX will fallback to using eager execution for the batch.
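
In pseudocode, the eager-fallback decision reads roughly as follows (a hypothetical sketch; the names, signature, and structure are illustrative and not taken from LoRAX's source):

```python
# Hypothetical sketch of the CUDA graph eligibility check (not LoRAX's actual code).
MAX_BATCH_SIZE = 256
MAX_CONTEXT_LENGTH = 8192   # input + output tokens
MIN_LORA_RANK, MAX_LORA_RANK = 8, 64


def can_use_cuda_graph(batch_size: int, context_length: int, lora_ranks: list[int], num_shards: int) -> bool:
    """Return True if the batch can be served by the compiled graph, else fall back to eager."""
    return (
        batch_size < MAX_BATCH_SIZE
        and context_length < MAX_CONTEXT_LENGTH
        and len(set(lora_ranks)) <= 1                                    # only one LoRA rank per batch
        and all(MIN_LORA_RANK <= r <= MAX_LORA_RANK for r in lora_ranks)
        and num_shards == 1                                              # no sharding
    )
```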

## Benchmarks

gpt2-medium, 1x A100, time to generate 100 tokens:

no adapter:

- baseline: 1.044 s
- cuda graph: 0.422 s

1 adapter (rank 16):

- baseline: 1.503 s
- cuda graph: 0.583 s
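
In both configurations the compiled graph yields roughly a 2.5x speedup: 1.044 / 0.422 ≈ 2.5 with no adapter, and 1.503 / 0.583 ≈ 2.6 with one rank-16 adapter.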
16 changes: 11 additions & 5 deletions docs/reference/launcher.md
@@ -1,6 +1,7 @@
# LoRAX Launcher

```shell
LoRAX Launcher

Usage: lorax-launcher [OPTIONS]

@@ -24,7 +25,7 @@ Options:
[default: hub]

--adapter-source <ADAPTER_SOURCE>
-The source of the model to load. Can be `hub` or `s3`. `hub` will load the model from the huggingface hub. `s3` will load the model from the predibase S3 bucket
+The source of the model to load. Can be `hub`, `s3`, or `pbase`. `hub` will load the model from the huggingface hub. `s3` will load the model from the predibase S3 bucket. `pbase` will load an s3 model but resolve the metadata from a predibase server

[env: ADAPTER_SOURCE=]
[default: hub]
@@ -55,7 +56,12 @@ Options:
Whether you want the model to be quantized. This will use `bitsandbytes` for quantization on the fly, or `gptq`

[env: QUANTIZE=]
-[possible values: bitsandbytes, bitsandbytes-nf4, bitsandbytes-fp4, gptq]
+[possible values: bitsandbytes, bitsandbytes-nf4, bitsandbytes-fp4, gptq, awq]
+
+--compile
+Whether you want to compile the model into a CUDA graph. This will speed up decoding but increase GPU memory usage
+
+[env: COMPILE=]

--dtype <DTYPE>
The dtype to be forced upon the model. This option cannot be used with `--quantize`
@@ -152,13 +158,13 @@ Options:
--hostname <HOSTNAME>
The IP address to listen on

-[env: HOSTNAME=b3687ab43244]
+[env: HOSTNAME=]
[default: 0.0.0.0]

-p, --port <PORT>
The port to listen on

-[env: PORT=80]
+[env: PORT=]
[default: 3000]

--shard-uds-path <SHARD_UDS_PATH>
@@ -182,7 +188,7 @@ Options:
--huggingface-hub-cache <HUGGINGFACE_HUB_CACHE>
The location of the huggingface hub cache. Used to override the location if you want to provide a mounted disk for instance

-[env: HUGGINGFACE_HUB_CACHE=/data]
+[env: HUGGINGFACE_HUB_CACHE=]

--weights-cache-override <WEIGHTS_CACHE_OVERRIDE>
The location of the huggingface hub cache. Used to override the location if you want to provide a mounted disk for instance
16 changes: 10 additions & 6 deletions integration-tests/scripts/dynamic_adapter_loading.py
@@ -47,12 +47,12 @@ def query_lorax(args):
prompt, adapter_id = args
start_t = time.time()
request_params = {
"max_new_tokens": 64,
"max_new_tokens": 128,
"temperature": None,
"details": True,
}
if adapter_id is not None:
# request_params["adapter_source"] = "local"
request_params["adapter_source"] = "local"
request_params["adapter_id"] = adapter_id

print("request_params", request_params)
@@ -113,10 +113,14 @@ def main():
# ]

# Mistral
prompt = "[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? [/INST]"
adapters = [
"vineetsharma/qlora-adapter-Mistral-7B-Instruct-v0.1-gsm8k",
]
# prompt = "[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? [/INST]"
# adapters = [
# "vineetsharma/qlora-adapter-Mistral-7B-Instruct-v0.1-gsm8k",
# ]

# GPT2
prompt = "Brand Name : First Aid Beauty ; Product Name : Ultra Repair Cream Intense Hydration ; Review Title :"
adapters = ["/data/adapters/9789adb7-cd03-4862-91d5-b41b6746682e_ludwig/model_weights"]

adapters += [None]
# adapters = [None]
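
For reference, the `request_params` assembled in `query_lorax` above are posted to the LoRAX server's generate endpoint. A minimal standalone sketch of such a request follows (the host, port, endpoint path, and response shape are assumptions based on the standard LoRAX REST API, not taken from this script):

```python
import requests

# Hypothetical direct call to a locally running LoRAX server (default launcher port 3000).
url = "http://localhost:3000/generate"
payload = {
    "inputs": "Brand Name : First Aid Beauty ; Product Name : Ultra Repair Cream Intense Hydration ; Review Title :",
    "parameters": {
        "max_new_tokens": 128,
        "details": True,
        "adapter_source": "local",
        "adapter_id": "/data/adapters/9789adb7-cd03-4862-91d5-b41b6746682e_ludwig/model_weights",
    },
}
response = requests.post(url, json=payload)
print(response.json()["generated_text"])
```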
13 changes: 13 additions & 0 deletions launcher/src/main.rs
@@ -135,6 +135,11 @@ struct Args {
#[clap(long, env, value_enum)]
quantize: Option<Quantization>,

/// Whether you want to compile the model into a CUDA graph.
/// This will speed up decoding but increase GPU memory usage.
#[clap(long, env, value_enum)]
compile: bool,

/// The dtype to be forced upon the model. This option cannot be used with `--quantize`.
#[clap(long, env, value_enum)]
dtype: Option<Dtype>,
@@ -342,6 +347,7 @@ fn shard_manager(
source: String,
adapter_source: String,
quantize: Option<Quantization>,
compile: bool,
dtype: Option<Dtype>,
trust_remote_code: bool,
uds_path: String,
@@ -407,6 +413,11 @@
shard_args.push(quantize.to_string())
}

// CUDA graph compilation
if compile {
shard_args.push("--compile".to_string());
}

if let Some(dtype) = dtype {
shard_args.push("--dtype".to_string());
shard_args.push(dtype.to_string())
@@ -853,6 +864,7 @@ fn spawn_shards(
let shutdown_sender = shutdown_sender.clone();
let otlp_endpoint = args.otlp_endpoint.clone();
let quantize = args.quantize;
let compile = args.compile;
let dtype = args.dtype;
let trust_remote_code = args.trust_remote_code;
let master_port = args.master_port;
@@ -868,6 +880,7 @@
source,
adapter_source,
quantize,
compile,
dtype,
trust_remote_code,
uds_path,
3 changes: 2 additions & 1 deletion mkdocs.yml
@@ -42,7 +42,8 @@ nav:
- Launcher: reference/launcher.md
- REST API: reference/rest_api.md
- Python Client: reference/python_client.md
-# - 🔬 Guides:
+- 🔬 Guides:
+- CUDA Graph Compilation: guides/cuda_graphs.md
# - GPUs: guides/gpus.md
# - Fine-Tuning: guides/fine_tuning.md
# - Quantization: guides/quantization.md
3 changes: 2 additions & 1 deletion server/lorax_server/cli.py
@@ -31,6 +31,7 @@ def serve(
revision: Optional[str] = None,
sharded: bool = False,
quantize: Optional[Quantization] = None,
compile: bool = False,
dtype: Optional[Dtype] = None,
trust_remote_code: bool = False,
uds_path: Path = "/tmp/lorax-server",
@@ -82,7 +83,7 @@ def serve(
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
)
server.serve(
-model_id, adapter_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path, source, adapter_source
+model_id, adapter_id, revision, sharded, quantize, compile, dtype, trust_remote_code, uds_path, source, adapter_source
)

def _download_weights(
