Fixed adapter segments when batches contain multiple distinct adapters
tgaddair authored Nov 26, 2023
1 parent 5e6215b commit 8c8109c
Showing 6 changed files with 143 additions and 74 deletions.
37 changes: 24 additions & 13 deletions integration-tests/scripts/dynamic_adapter_loading.py
@@ -47,7 +47,7 @@ def query_lorax(args):
prompt, adapter_id = args
start_t = time.time()
request_params = {
"max_new_tokens": 128,
"max_new_tokens": 64,
"temperature": None,
"details": True,
}
@@ -73,6 +73,7 @@ def query_lorax(args):
response_body = json.loads(response.read().decode("utf-8"))
ntokens = response_body["details"]["generated_tokens"]
duration_s = time.time() - start_t
# print(adapter_id, response_body["generated_text"])
except Exception:
print(f"exception in request: {adapter_id}")
return adapter_id, 0, None
@@ -81,9 +82,9 @@
adapter_id,
ntokens,
duration_s,
(ntokens / duration_s)
(ntokens / duration_s),
))
return adapter_id, ntokens, duration_s
return adapter_id, ntokens, duration_s, response_body["generated_text"]


def get_local_path(model_id):
@@ -105,18 +106,19 @@ def main():
### Response:
"""
NUM_REQUESTS = 500
N = 128
adapters = [get_local_path("arnavgrg/codealpaca_v3")] + [
get_local_path(f"arnavgrg/codealpaca_v3_{i}")
for i in range(1, N)
]
# N = 0
# adapters = [get_local_path("arnavgrg/codealpaca_v3")] + [
# get_local_path(f"arnavgrg/codealpaca_v3_{i}")
# for i in range(1, N)
# ]

# Mistral
# adapters = [
# "alexsherstinsky/mistralai-7B-v01-based-finetuned-using-ludwig-with-samsum-T4-sharded-4bit-notmerged",
# ]
prompt = "[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? [/INST]"
adapters = [
"vineetsharma/qlora-adapter-Mistral-7B-Instruct-v0.1-gsm8k",
]

# adapters += [None]
adapters += [None]
# adapters = [None]

# adapters += [
@@ -171,15 +173,24 @@ def main():

total_tokens = 0
total_duration_s = 0
for adapter_id, ntokens, duration_s in results:
responses = collections.defaultdict(set)
for adapter_id, ntokens, duration_s, resp in results:
if duration_s is None:
continue
total_tokens += ntokens
total_duration_s += duration_s
responses[adapter_id].add(resp)

print(f"Avg Latency: {total_duration_s / total_tokens} s / tokens")
print(f"Throughput: {total_tokens / span_s} tokens / s")

for adapter_id, resp in responses.items():
print("----")
print(f"{adapter_id}: {len(resp)}")
for r in resp:
print(" * " + r)
print("----")

# d = collections.defaultdict(list)
# for adapter_id, ntokens, duration_s in results:
# d[str(adapter_id)].append(end_t)
2 changes: 1 addition & 1 deletion server/Makefile
@@ -25,7 +25,7 @@ install: gen-server install-torch

run-dev:
# SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 lorax_server/cli.py serve meta-llama/Llama-2-7b-hf --sharded
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 lorax_server/cli.py serve mistralai/Mistral-7B-Instruct-v0.1 --sharded
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 lorax_server/cli.py serve mistralai/Mistral-7B-Instruct-v0.1 --sharded
# SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 lorax_server/cli.py serve flozi00/Mistral-7B-german-assistant-v5-4bit-autogptq --quantize gptq

export-requirements:
58 changes: 7 additions & 51 deletions server/lorax_server/models/flash_causal_lm.py
@@ -31,6 +31,7 @@
from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID, load_module_map
from lorax_server.utils.dist import MEMORY_FRACTION
from lorax_server.utils.lora import K_PROJ, O_PROJ, Q_PROJ, V_PROJ, AdapterBatchData, AdapterBatchMetadata, BatchedLoraWeights, MergedLoraWeights
from lorax_server.utils.segments import SegmentConcatBuilder, find_segments

tracer = trace.get_tracer(__name__)

@@ -147,9 +148,6 @@ def from_pb(

adapter_indices_list = []
adapter_set = set()
adapter_segment_indices = []
adapter_segments = [0]
adapter_segment_length = 0

# Cumulative length
cumulative_length = 0
@@ -196,12 +194,6 @@ def from_pb(
adapter_indices_list.append(torch.full((input_length,), r.adapter_index))
adapter_set.add(r.adapter_index)

adapter_segment_length += input_length
if not adapter_segment_indices or adapter_segment_indices[-1] != r.adapter_index:
adapter_segment_indices.append(r.adapter_index)
adapter_segments.append(adapter_segments[-1] + adapter_segment_length)
adapter_segment_length = 0

# Paged attention
# Remove one as the first token does not have a past
total_tokens = input_length + max_new_tokens - 1
@@ -283,6 +275,7 @@ def from_pb(
input_lengths, dtype=torch.int32, device=device
)

adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32, device=device)

if all_prefill_logprobs:
Expand Down Expand Up @@ -376,10 +369,6 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
# Cumulative length
cumulative_max_length = 0

adapter_segment_indices = []
adapter_segments = [0]
adapter_segment_length = 0

for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
indices.append(idx)
@@ -402,12 +391,6 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":

adapter_set.add(self.requests[idx].adapter_index)

adapter_segment_length += 1
if not adapter_segment_indices or adapter_segment_indices[-1] != self.requests[idx].adapter_index:
adapter_segment_indices.append(self.requests[idx].adapter_index)
adapter_segments.append(adapter_segments[-1] + adapter_segment_length)
adapter_segment_length = 0

remaining_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
@@ -458,6 +441,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
# Move to GPU now that we have the whole tensor
slot_indices = slot_indices.to(device)

adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32, device=device)

return type(self)(
@@ -544,8 +528,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch

adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(total_indices_size)
adapter_set = set()
adapter_segment_indices = []
adapter_segment_tensors = []
adapter_segment_builder = SegmentConcatBuilder()

start_slots = []
block_tables = []
@@ -593,25 +576,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
adapter_set.update(batch.adapter_meta.adapter_set)

# Update adapter segments
adapter_segments = batch.adapter_meta.adapter_segments
if adapter_segment_tensors:
# Because we have already processed at least one batch, remove the 0 start index
# from this batch denoting the beginning of the segment, then offset all segment
# positions by the value of the last segment in the previous batch to account for
# the concatenation.
adapter_segments = adapter_segments[1:] + adapter_segment_tensors[-1][-1]

segment_indices = batch.adapter_meta.segment_indices
if adapter_segment_indices and adapter_segment_indices[-1] == segment_indices[-1]:
# If the last segment in the previous batch is the same as the first segment in this batch,
# then we merge them together into a single segment. In effect, this means removing it from
# the segment indices of this batch, and extending the segment span by removing the segment
# end index from the previous batch.
segment_indices = segment_indices[1:]
adapter_segment_tensors[-1] = adapter_segment_tensors[-1][:-1]

adapter_segment_indices.extend(segment_indices)
adapter_segment_tensors.append(adapter_segments)
adapter_segment_builder.concat(batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices)

all_input_ids_tensor[
start_index:end_index, : batch.all_input_ids_tensor.shape[1]
@@ -645,7 +610,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
device=batches[0].next_token_chooser.device,
)

adapter_segments = torch.concat(adapter_segment_tensors)
adapter_segments, adapter_segment_indices = adapter_segment_builder.build()

# Needed to avoid dropping blocks when the batches will go out of scope
for b in batches:
@@ -1052,16 +1017,7 @@ def generate_token(

if prefill:
# adjust segment lengths to account for all request lengths being 1 during decoding
adapter_segments = [0]
adapter_segment_length = 0
last_adapter_index = None
for r in batch.requests:
adapter_segment_length += 1
if last_adapter_index != r.adapter_index:
adapter_segments.append(adapter_segments[-1] + adapter_segment_length)
adapter_segment_length = 0
last_adapter_index = r.adapter_index

adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
batch.adapter_meta.adapter_segments = torch.tensor(
adapter_segments,
dtype=torch.int32,
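A note on the from_pb/filter refactor above: as the removed loop reads, it appended a segment boundary as soon as the adapter index first changed, using only the length accumulated up to that point, so consecutive requests sharing an adapter could get a boundary placed after the first request of the run, and a trailing run was never closed. find_segments instead derives the boundaries from the expanded per-token indices. A minimal sketch with hypothetical request lengths and adapter ids (not taken from the diff):

import torch
from lorax_server.utils.segments import find_segments

# Three requests: lengths 4, 2, 3 with adapter ids 0, 0, 7 (hypothetical values).
input_lengths = [4, 2, 3]
adapter_ids = [0, 0, 7]
adapter_indices = torch.cat(
    [torch.full((n,), a) for n, a in zip(input_lengths, adapter_ids)]
)

# Contiguous runs of the same adapter collapse into a single segment:
# tokens [0, 6) use adapter 0 and tokens [6, 9) use adapter 7.
segments, segment_indices = find_segments(adapter_indices)
assert segments == [0, 6, 9]
assert segment_indices == [0, 7]

# During decode each request occupies a single slot, so the same helper applied
# to a one-entry-per-request tensor yields per-request segments, which appears
# to be what the recomputation at the end of prefill in generate_token accounts for.
assert find_segments(torch.tensor([0, 0, 7])) == ([0, 2, 3], [0, 7])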
11 changes: 2 additions & 9 deletions server/lorax_server/models/flash_mistral.py
@@ -35,6 +35,7 @@
)
from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID
from lorax_server.utils.lora import AdapterBatchData, AdapterBatchMetadata
from lorax_server.utils.segments import find_segments

tracer = trace.get_tracer(__name__)

@@ -97,9 +98,6 @@ def from_pb(

adapter_indices_list = []
adapter_set = set()
adapter_segment_indices = []
adapter_segments = [0]
adapter_segment_length = 0

# Cumulative length
cumulative_length = 0
@@ -147,12 +145,6 @@ def from_pb(
adapter_indices_list.append(torch.full((input_length,), r.adapter_index))
adapter_set.add(r.adapter_index)

adapter_segment_length += input_length
if not adapter_segment_indices or adapter_segment_indices[-1] != r.adapter_index:
adapter_segment_indices.append(r.adapter_index)
adapter_segments.append(adapter_segments[-1] + adapter_segment_length)
adapter_segment_length = 0

# Paged attention
# Remove one as the first token des not have a past
total_tokens = input_length + max_new_tokens - 1
@@ -250,6 +242,7 @@ def from_pb(
input_lengths, dtype=torch.int32, device=device
)

adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32, device=device)

if all_prefill_logprobs:
54 changes: 54 additions & 0 deletions server/lorax_server/utils/segments.py
@@ -0,0 +1,54 @@
from typing import List, Tuple

import torch


def find_segments(adapter_indices: torch.Tensor) -> Tuple[List[int], List[int]]:
segments = [0]
segment_indices = []

# Calling .item() repeatedly on a CUDA tensor is very slow, so we move it to CPU first
adapter_indices = adapter_indices.cpu()

start_index = 0
for i in range(1, adapter_indices.shape[0]):
if adapter_indices[i] != adapter_indices[i - 1]:
segments.append(i)
segment_indices.append(adapter_indices[i - 1].item())
start_index = i

# Handle the last segment
if start_index < len(adapter_indices):
segments.append(len(adapter_indices))
segment_indices.append(adapter_indices[-1].item())

return segments, segment_indices


class SegmentConcatBuilder:
def __init__(self):
self.adapter_segment_indices = []
self.adapter_segment_tensors = []

def concat(self, adapter_segments: torch.Tensor, segment_indices: List[int]):
# Update adapter segments
if self.adapter_segment_tensors:
# Because we have already processed at least one batch, remove the 0 start index
# from this batch denoting the beginning of the segment, then offset all segment
# positions by the value of the last segment in the previous batch to account for
# the concatenation.
adapter_segments = adapter_segments[1:] + self.adapter_segment_tensors[-1][-1]

if self.adapter_segment_indices and self.adapter_segment_indices[-1] == segment_indices[0]:
# If the last segment in the previous batch is the same as the first segment in this batch,
# then we merge them together into a single segment. In effect, this means removing it from
# the segment indices of this batch, and extending the segment span by removing the segment
# end index from the previous batch.
segment_indices = segment_indices[1:]
self.adapter_segment_tensors[-1] = self.adapter_segment_tensors[-1][:-1]

self.adapter_segment_indices.extend(segment_indices)
self.adapter_segment_tensors.append(adapter_segments)

def build(self) -> Tuple[torch.Tensor, List[int]]:
return torch.concat(self.adapter_segment_tensors), self.adapter_segment_indices
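
A quick usage sketch of the new helpers (it assumes the lorax_server package is importable; the numbers mirror the cases exercised by the tests below):

import torch
from lorax_server.utils.segments import find_segments, SegmentConcatBuilder

# find_segments returns cumulative segment boundaries plus one adapter index per segment.
segments, segment_indices = find_segments(torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 1, 1]))
assert segments == [0, 3, 5, 10, 12]
assert segment_indices == [0, 1, 2, 1]

# SegmentConcatBuilder offsets the second batch's boundaries by the end of the
# first and merges the boundary segment when both batches meet on the same adapter.
builder = SegmentConcatBuilder()
builder.concat(torch.tensor([0, 1, 4, 7, 8]), [2, 1, 2, 1])
builder.concat(torch.tensor([0, 2, 5]), [1, 2])
merged_segments, merged_indices = builder.build()
assert merged_segments.tolist() == [0, 1, 4, 7, 10, 13]
assert merged_indices == [2, 1, 2, 1, 2]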
55 changes: 55 additions & 0 deletions server/tests/utils/test_segments.py
@@ -0,0 +1,55 @@
import pytest
import torch

from lorax_server.utils.segments import find_segments, SegmentConcatBuilder



@pytest.mark.parametrize(
"adapter_indices,expected_segments,expected_segment_indices",
[
(
torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 1, 1]),
[0, 3, 5, 10, 12],
[0, 1, 2, 1],
),
(torch.tensor([]), [0], []),
(torch.tensor([0]), [0, 1], [0]),
(torch.tensor([1]), [0, 1], [1]),
],
)
def test_find_segments(adapter_indices, expected_segments, expected_segment_indices):
segments, segment_indices = find_segments(adapter_indices)
assert segments == expected_segments
assert segment_indices == expected_segment_indices


@pytest.mark.parametrize(
"batches,expected_segments,expected_segment_indices",
[
(
[
(torch.tensor([0, 1, 4, 7, 8]), [2, 1, 2, 1]),
(torch.tensor([0, 2, 5]), [1, 2]),
],
[0, 1, 4, 7, 10, 13],
[2, 1, 2, 1, 2],
),
(
[
(torch.tensor([0, 1, 4, 7]), [2, 1, 2]),
(torch.tensor([0, 2, 5]), [1, 2]),
],
[0, 1, 4, 7, 9, 12],
[2, 1, 2, 1, 2],
),
],
)
def test_concat_segments(batches, expected_segments, expected_segment_indices):
builder = SegmentConcatBuilder()
for segment, indices in batches:
builder.concat(segment, indices)

segments, segment_indices = builder.build()
assert segments.tolist() == expected_segments
assert segment_indices == expected_segment_indices
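
These cases should be runnable with something like pytest server/tests/utils/test_segments.py from the repository root, assuming pytest and the lorax_server package are installed in the environment.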
