Commit

Use SGMV kernel when all adapters share the same R and alpha values (#15)
tgaddair authored Nov 15, 2023
1 parent 724a4c3 commit 2230652
Showing 11 changed files with 535 additions and 289 deletions.
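Note on the change: SGMV (segmented gather matrix-vector) batches the LoRA computation for many adapters into a single kernel launch, which is only sound when every adapter in the batch uses the same rank R and alpha, so the stacked A/B matrices have uniform shapes and a single scale applies. A minimal sketch of that dispatch condition follows; the names are illustrative, not the actual lorax_server API.

from dataclasses import dataclass
from typing import List

@dataclass
class AdapterConfig:
    r: int           # LoRA rank
    lora_alpha: int  # LoRA scaling numerator

def can_use_sgmv(adapters: List[AdapterConfig]) -> bool:
    # Fused SGMV is valid only when all adapters share one (r, alpha) pair.
    return len({(a.r, a.lora_alpha) for a in adapters}) == 1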
15 changes: 15 additions & 0 deletions Dockerfile
@@ -141,6 +141,18 @@ COPY server/Makefile-vllm Makefile
# Build specific version of vllm
RUN make build-vllm

# Build punica CUDA kernels
FROM kernel-builder as punica-builder

RUN /opt/conda/bin/conda install packaging

WORKDIR /usr/src

COPY server/Makefile-punica Makefile

# Build specific version of punica
RUN make build-punica

# Text Generation Inference base image
FROM nvidia/cuda:11.8.0-base-ubuntu20.04 as base

@@ -181,6 +193,9 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /
# Copy builds artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages

# Copy builds artifacts from punica builder
COPY --from=punica-builder /usr/src/punica/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages

# Install flash-attention dependencies
RUN pip install einops --no-cache-dir

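The punica-builder stage above compiles the punica CUDA kernels, and the final image copies the built extension into the conda site-packages. A hedged smoke test for the final image is sketched below; the exact module layout of the installed package is an assumption here.

import torch  # the extension is a torch CUDA extension, so torch must load first

try:
    import punica  # copied from the punica-builder stage
    print("punica kernels available:", punica.__file__)
except ImportError as err:
    print("punica not importable; SGMV path unavailable:", err)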
5 changes: 3 additions & 2 deletions integration-tests/scripts/dynamic_adapter_loading.py
@@ -104,8 +104,8 @@ def main():
### Response:
"""
NUM_REQUESTS = 100
N = 10
NUM_REQUESTS = 500
N = 128
adapters = [get_local_path("arnavgrg/codealpaca_v3")] + [
get_local_path(f"arnavgrg/codealpaca_v3_{i}")
for i in range(1, N)
@@ -117,6 +117,7 @@ def main():
# ]

# adapters += [None]
# adapters = [None]

# adapters += [
# # get_local_path("arnavgrg/codealpaca_v3"),
2 changes: 1 addition & 1 deletion server/Makefile
@@ -24,7 +24,7 @@ install: gen-server install-torch
pip install -e ".[bnb, accelerate]"

run-dev:
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 lorax_server/cli.py serve meta-llama/Llama-2-7b-hf --sharded
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 lorax_server/cli.py serve meta-llama/Llama-2-7b-hf --sharded
# SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 lorax_server/cli.py serve alexsherstinsky/Mistral-7B-v0.1-sharded --sharded

export-requirements:
13 changes: 13 additions & 0 deletions server/Makefile-punica
@@ -0,0 +1,13 @@
punica_commit := 5ccb1d62ede179bab6c91dfb2f6f320cc1c6b76d

punica:
# Clone punica
git clone https://github.com/predibase/punica.git --recurse

build-punica: punica
cd punica && git fetch && git checkout $(punica_commit)
cd punica && python setup.py build

install-punica: build-punica
pip uninstall punica -y || true
cd punica && python setup.py install
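For reference, the operation these kernels implement is a segmented gather matrix-vector product over stacked LoRA weights: tokens are grouped into contiguous segments, one per adapter, and each segment is multiplied by its adapter's A and B matrices. A plain-PyTorch sketch of those semantics (not the punica API) is shown below, assuming every adapter shares the same rank r and scale alpha/r, which is exactly the condition this commit checks for.

import torch

def sgmv_reference(x, lora_a, lora_b, segments, scale):
    # x:        [num_tokens, hidden_in]       input activations
    # lora_a:   [num_adapters, hidden_in, r]  stacked LoRA A matrices
    # lora_b:   [num_adapters, r, hidden_out] stacked LoRA B matrices
    # segments: list of (start, end, adapter_idx) for contiguous token groups
    # scale:    alpha / r, shared by every adapter in the batch
    y = torch.zeros(x.shape[0], lora_b.shape[-1], dtype=x.dtype, device=x.device)
    for start, end, idx in segments:
        y[start:end] = (x[start:end] @ lora_a[idx]) @ lora_b[idx] * scale
    return y

The fused kernel performs these per-segment matmuls in a single launch instead of looping in Python, which is where the speedup comes from.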
33 changes: 14 additions & 19 deletions server/lorax_server/models/custom_modeling/flash_llama_modeling.py
@@ -43,6 +43,7 @@
TensorParallelHead,
get_linear,
)
from lorax_server.utils.lora import AdapterBatchData


class LlamaConfig(PretrainedConfig):
@@ -150,10 +151,10 @@ def forward(self, hidden_states, residual=None):
return normed_hidden_states, res


def load_attention(config, prefix, weights):
def load_attention(config, prefix, weights, layer_id):
base_layer = load_attention_multi(config, prefix, weights)
head_size = config.hidden_size // config.num_attention_heads
return TensorParallelMultiAdapterLinear.load(base_layer, sizes=[
return TensorParallelMultiAdapterLinear.load(base_layer, layer_id, sizes=[
head_size * config.num_attention_heads,
head_size * config.num_key_value_heads,
head_size * config.num_key_value_heads,
@@ -205,6 +206,7 @@ def __init__(
prefix: str,
config,
weights,
layer_id: int,
):
super().__init__()
self.num_heads = config.num_attention_heads
@@ -230,7 +232,7 @@ def __init__(
config.num_key_value_heads // weights.process_group.size()
)

self.query_key_value = load_attention(config, prefix, weights)
self.query_key_value = load_attention(config, prefix, weights, layer_id)

self.o_proj = TensorParallelRowLinear.load(
config,
@@ -275,10 +277,9 @@ def forward(
slots,
input_lengths,
max_s,
adapter_indices,
adapter_set,
adapter_data,
):
qkv = self.query_key_value(hidden_states, adapter_indices, adapter_set)
qkv = self.query_key_value(hidden_states, adapter_data)
query, kv = qkv.split(
[
self.head_size * self.num_heads,
@@ -374,7 +375,7 @@ def __init__(self, layer_id, config, weights):
super().__init__()
prefix = f"model.layers.{layer_id}"
self.self_attn = FlashLlamaAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
prefix=f"{prefix}.self_attn", config=config, weights=weights, layer_id=layer_id,
)
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

@@ -399,8 +400,7 @@ def forward(
slots,
input_lengths,
max_s,
adapter_indices,
adapter_set,
adapter_data,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

@@ -415,8 +415,7 @@ def forward(
slots,
input_lengths,
max_s,
adapter_indices,
adapter_set,
adapter_data,
)

# faster post attention rms norm
@@ -469,8 +468,7 @@ def forward(
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
adapter_indices: torch.Tensor,
adapter_set: Set[int],
adapter_data: AdapterBatchData,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)

@@ -493,8 +491,7 @@ def forward(
slots,
input_lengths,
max_s,
adapter_indices,
adapter_set,
adapter_data,
)

hidden_states, _ = self.norm(hidden_states, residual)
@@ -523,8 +520,7 @@ def forward(
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
adapter_indices: torch.Tensor,
adapter_set: Set[int],
adapter_data: AdapterBatchData,
lm_head_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.model(
@@ -536,8 +532,7 @@ def forward(
slots,
input_lengths,
max_s,
adapter_indices,
adapter_set,
adapter_data,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
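The diff above replaces the separate adapter_indices / adapter_set arguments with a single AdapterBatchData object imported from lorax_server.utils.lora, and threads a layer_id into load_attention so per-layer adapter weights can be located. The real class is defined outside this diff; the fields below are an assumption, sketched only to show the kind of state such a container could carry.

from dataclasses import dataclass
from typing import Dict, Set
import torch

@dataclass
class AdapterBatchDataSketch:  # hypothetical stand-in for lorax_server.utils.lora.AdapterBatchData
    meta_indices: torch.Tensor      # per-token index of the adapter applied to that token
    adapter_set: Set[int]           # distinct adapter ids present in the batch
    layer_weights: Dict[int, dict]  # stacked LoRA weights keyed by layer_id, hence the new
                                    # layer_id argument on load_attention and the attention modules

The same refactor is applied to the Mistral model below.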
server/lorax_server/models/custom_modeling/flash_mistral_modeling.py
@@ -43,6 +43,7 @@
TensorParallelHead,
get_linear,
)
from lorax_server.utils.lora import AdapterBatchData

if not HAS_FLASH_ATTN_V2:
raise ImportError("Mistral model requires flash attn v2")
@@ -155,10 +156,10 @@ def forward(self, hidden_states, residual=None):
return normed_hidden_states, res


def load_attention(config, prefix, weights):
def load_attention(config, prefix, weights, layer_id):
base_layer = load_attention_multi(config, prefix, weights)
head_size = config.hidden_size // config.num_attention_heads
return TensorParallelMultiAdapterLinear.load(base_layer, sizes=[
return TensorParallelMultiAdapterLinear.load(base_layer, layer_id, sizes=[
head_size * config.num_attention_heads,
head_size * config.num_key_value_heads,
head_size * config.num_key_value_heads,
@@ -210,6 +211,7 @@ def __init__(
prefix: str,
config,
weights,
layer_id: int,
):
super().__init__()
self.max_past = (
@@ -238,7 +240,7 @@ def __init__(
config.num_key_value_heads // weights.process_group.size()
)

self.query_key_value = load_attention(config, prefix, weights)
self.query_key_value = load_attention(config, prefix, weights, layer_id)

self.o_proj = TensorParallelRowLinear.load(
config,
@@ -283,11 +285,10 @@ def forward(
slots,
input_lengths,
max_s,
adapter_indices,
adapter_set,
adapter_data,
prefill_cache_indices,
):
qkv = self.query_key_value(hidden_states, adapter_indices, adapter_set)
qkv = self.query_key_value(hidden_states, adapter_data)
query, kv = qkv.split(
[
self.head_size * self.num_heads,
@@ -389,7 +390,7 @@ def __init__(self, layer_id, config, weights):
super().__init__()
prefix = f"model.layers.{layer_id}"
self.self_attn = MistralAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
prefix=f"{prefix}.self_attn", config=config, weights=weights, layer_id=layer_id,
)
self.mlp = MistralMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

@@ -414,8 +415,7 @@ def forward(
slots,
input_lengths,
max_s,
adapter_indices,
adapter_set,
adapter_data,
prefill_cache_indices,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -431,8 +431,7 @@ def forward(
slots,
input_lengths,
max_s,
adapter_indices,
adapter_set,
adapter_data,
prefill_cache_indices,
)

@@ -486,8 +485,7 @@ def forward(
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
adapter_indices: torch.Tensor,
adapter_set: Set[int],
adapter_data: AdapterBatchData,
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
@@ -511,8 +509,7 @@ def forward(
slots,
input_lengths,
max_s,
adapter_indices,
adapter_set,
adapter_data,
prefill_cache_indices,
)

@@ -545,8 +542,7 @@ def forward(
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
adapter_indices: torch.Tensor,
adapter_set: Set[int],
adapter_data: AdapterBatchData,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
@@ -568,8 +564,7 @@ def forward(
slots,
input_lengths,
max_s,
adapter_indices,
adapter_set,
adapter_data,
prefill_cache_indices,
)
if lm_head_indices is not None:
(Diffs for the remaining changed files were not loaded.)
