diff --git a/integration-tests/scripts/dynamic_adapter_loading.py b/integration-tests/scripts/dynamic_adapter_loading.py
index 03c37f1da..1fa784f6d 100644
--- a/integration-tests/scripts/dynamic_adapter_loading.py
+++ b/integration-tests/scripts/dynamic_adapter_loading.py
@@ -47,7 +47,7 @@ def query_lorax(args):
     prompt, adapter_id = args
     start_t = time.time()
     request_params = {
-        "max_new_tokens": 128,
+        "max_new_tokens": 64,
         "temperature": None,
         "details": True,
     }
@@ -73,6 +73,7 @@ def query_lorax(args):
         response_body = json.loads(response.read().decode("utf-8"))
         ntokens = response_body["details"]["generated_tokens"]
         duration_s = time.time() - start_t
+        # print(adapter_id, response_body["generated_text"])
     except Exception:
         print(f"exception in request: {adapter_id}")
         return adapter_id, 0, None
@@ -81,9 +82,9 @@
         adapter_id,
         ntokens,
         duration_s,
-        (ntokens / duration_s)
+        (ntokens / duration_s),
     ))
-    return adapter_id, ntokens, duration_s
+    return adapter_id, ntokens, duration_s, response_body["generated_text"]
 
 
 def get_local_path(model_id):
@@ -105,18 +106,19 @@ def main():
     ### Response:
     """
     NUM_REQUESTS = 500
-    N = 128
-    adapters = [get_local_path("arnavgrg/codealpaca_v3")] + [
-        get_local_path(f"arnavgrg/codealpaca_v3_{i}")
-        for i in range(1, N)
-    ]
+    # N = 0
+    # adapters = [get_local_path("arnavgrg/codealpaca_v3")] + [
+    #     get_local_path(f"arnavgrg/codealpaca_v3_{i}")
+    #     for i in range(1, N)
+    # ]
 
     # Mistral
-    # adapters = [
-    #     "alexsherstinsky/mistralai-7B-v01-based-finetuned-using-ludwig-with-samsum-T4-sharded-4bit-notmerged",
-    # ]
+    prompt = "[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? [/INST]"
+    adapters = [
+        "vineetsharma/qlora-adapter-Mistral-7B-Instruct-v0.1-gsm8k",
+    ]
 
-    # adapters += [None]
+    adapters += [None]
     # adapters = [None]
 
     # adapters += [
@@ -171,15 +173,24 @@ def main():
 
     total_tokens = 0
     total_duration_s = 0
-    for adapter_id, ntokens, duration_s in results:
+    responses = collections.defaultdict(set)
+    for adapter_id, ntokens, duration_s, resp in results:
         if duration_s is None:
             continue
         total_tokens += ntokens
         total_duration_s += duration_s
+        responses[adapter_id].add(resp)
 
     print(f"Avg Latency: {total_duration_s / total_tokens} s / tokens")
     print(f"Throughput: {total_tokens / span_s} tokens / s")
 
+    for adapter_id, resp in responses.items():
+        print("----")
+        print(f"{adapter_id}: {len(resp)}")
+        for r in resp:
+            print(" * " + r)
+    print("----")
+
     # d = collections.defaultdict(list)
     # for adapter_id, ntokens, duration_s in results:
     #     d[str(adapter_id)].append(end_t)
diff --git a/server/Makefile b/server/Makefile
index 7ae2ff9e9..50d50275e 100644
--- a/server/Makefile
+++ b/server/Makefile
@@ -25,7 +25,7 @@ install: gen-server install-torch
 
 run-dev:
 	# SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 lorax_server/cli.py serve meta-llama/Llama-2-7b-hf --sharded
-	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 lorax_server/cli.py serve mistralai/Mistral-7B-Instruct-v0.1 --sharded
+	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 lorax_server/cli.py serve mistralai/Mistral-7B-Instruct-v0.1 --sharded
 	# SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 lorax_server/cli.py serve flozi00/Mistral-7B-german-assistant-v5-4bit-autogptq --quantize gptq
 
 export-requirements:
diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py
index d7b50989d..be0865320 100644
--- a/server/lorax_server/models/flash_causal_lm.py
+++ b/server/lorax_server/models/flash_causal_lm.py
@@ -31,6 +31,7 @@
 from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID, load_module_map
 from lorax_server.utils.dist import MEMORY_FRACTION
 from lorax_server.utils.lora import K_PROJ, O_PROJ, Q_PROJ, V_PROJ, AdapterBatchData, AdapterBatchMetadata, BatchedLoraWeights, MergedLoraWeights
+from lorax_server.utils.segments import SegmentConcatBuilder, find_segments
 
 tracer = trace.get_tracer(__name__)
 
@@ -147,9 +148,6 @@ def from_pb(
 
         adapter_indices_list = []
         adapter_set = set()
-        adapter_segment_indices = []
-        adapter_segments = [0]
-        adapter_segment_length = 0
 
         # Cumulative length
         cumulative_length = 0
@@ -196,12 +194,6 @@ def from_pb(
             adapter_indices_list.append(torch.full((input_length,), r.adapter_index))
             adapter_set.add(r.adapter_index)
 
-            adapter_segment_length += input_length
-            if not adapter_segment_indices or adapter_segment_indices[-1] != r.adapter_index:
-                adapter_segment_indices.append(r.adapter_index)
-                adapter_segments.append(adapter_segments[-1] + adapter_segment_length)
-                adapter_segment_length = 0
-
             # Paged attention
             # Remove one as the first token does not have a past
             total_tokens = input_length + max_new_tokens - 1
@@ -283,6 +275,7 @@ def from_pb(
             input_lengths, dtype=torch.int32, device=device
         )
 
+        adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
         adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32, device=device)
 
         if all_prefill_logprobs:
@@ -376,10 +369,6 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
         # Cumulative length
         cumulative_max_length = 0
 
-        adapter_segment_indices = []
-        adapter_segments = [0]
-        adapter_segment_length = 0
-
         for i, request_id in enumerate(request_ids):
             idx = self.requests_idx_mapping[request_id]
             indices.append(idx)
@@ -402,12 +391,6 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
 
             adapter_set.add(self.requests[idx].adapter_index)
 
-            adapter_segment_length += 1
-            if not adapter_segment_indices or adapter_segment_indices[-1] != self.requests[idx].adapter_index:
-                adapter_segment_indices.append(self.requests[idx].adapter_index)
-                adapter_segments.append(adapter_segments[-1] + adapter_segment_length)
-                adapter_segment_length = 0
-
             remaining_tokens = (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )
@@ -458,6 +441,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
         # Move to GPU now that we have the whole tensor
         slot_indices = slot_indices.to(device)
 
+        adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
         adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32, device=device)
 
         return type(self)(
@@ -544,8 +528,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
 
         adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(total_indices_size)
         adapter_set = set()
-        adapter_segment_indices = []
-        adapter_segment_tensors = []
+        adapter_segment_builder = SegmentConcatBuilder()
 
         start_slots = []
         block_tables = []
@@ -593,25 +576,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
             adapter_set.update(batch.adapter_meta.adapter_set)
 
             # Update adapter segments
-            adapter_segments = batch.adapter_meta.adapter_segments
-            if adapter_segment_tensors:
-                # Because we have already processed at least one batch, remove the 0 start index
-                # from this batch denoting the beginning of the segment, then offset all segment
-                # positions by the value of the last segment in the previous batch to account for
-                # the concatenation.
-                adapter_segments = adapter_segments[1:] + adapter_segment_tensors[-1][-1]
-
-            segment_indices = batch.adapter_meta.segment_indices
-            if adapter_segment_indices and adapter_segment_indices[-1] == segment_indices[-1]:
-                # If the last segment in the previous batch is the same as the first segment in this batch,
-                # then we merge them together into a single segment. In effect, this means removing it from
-                # the segment indices of this batch, and extending the segment span by removing the segment
-                # end index from the previous batch.
-                segment_indices = segment_indices[1:]
-                adapter_segment_tensors[-1] = adapter_segment_tensors[-1][:-1]
-
-            adapter_segment_indices.extend(segment_indices)
-            adapter_segment_tensors.append(adapter_segments)
+            adapter_segment_builder.concat(batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices)
 
             all_input_ids_tensor[
                 start_index:end_index, : batch.all_input_ids_tensor.shape[1]
@@ -645,7 +610,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
             device=batches[0].next_token_chooser.device,
         )
 
-        adapter_segments = torch.concat(adapter_segment_tensors)
+        adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
 
         # Needed to avoid dropping blocks when the batches will go out of scope
         for b in batches:
@@ -1052,16 +1017,7 @@ def generate_token(
 
         if prefill:
             # adjust segment lengths to account for all request lengths being 1 during decoding
-            adapter_segments = [0]
-            adapter_segment_length = 0
-            last_adapter_index = None
-            for r in batch.requests:
-                adapter_segment_length += 1
-                if last_adapter_index != r.adapter_index:
-                    adapter_segments.append(adapter_segments[-1] + adapter_segment_length)
-                    adapter_segment_length = 0
-                    last_adapter_index = r.adapter_index
-
+            adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
             batch.adapter_meta.adapter_segments = torch.tensor(
                 adapter_segments,
                 dtype=torch.int32,
diff --git a/server/lorax_server/models/flash_mistral.py b/server/lorax_server/models/flash_mistral.py
index f8438f571..5ef4c5c0f 100644
--- a/server/lorax_server/models/flash_mistral.py
+++ b/server/lorax_server/models/flash_mistral.py
@@ -35,6 +35,7 @@
 )
 from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID
 from lorax_server.utils.lora import AdapterBatchData, AdapterBatchMetadata
+from lorax_server.utils.segments import find_segments
 
 tracer = trace.get_tracer(__name__)
 
@@ -97,9 +98,6 @@ def from_pb(
 
         adapter_indices_list = []
         adapter_set = set()
-        adapter_segment_indices = []
-        adapter_segments = [0]
-        adapter_segment_length = 0
 
         # Cumulative length
         cumulative_length = 0
@@ -147,12 +145,6 @@ def from_pb(
             adapter_indices_list.append(torch.full((input_length,), r.adapter_index))
             adapter_set.add(r.adapter_index)
 
-            adapter_segment_length += input_length
-            if not adapter_segment_indices or adapter_segment_indices[-1] != r.adapter_index:
-                adapter_segment_indices.append(r.adapter_index)
-                adapter_segments.append(adapter_segments[-1] + adapter_segment_length)
-                adapter_segment_length = 0
-
             # Paged attention
             # Remove one as the first token does not have a past
             total_tokens = input_length + max_new_tokens - 1
@@ -250,6 +242,7 @@ def from_pb(
             input_lengths, dtype=torch.int32, device=device
        )
 
+        adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
         adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32, device=device)
 
         if all_prefill_logprobs:
diff --git a/server/lorax_server/utils/segments.py b/server/lorax_server/utils/segments.py
new file mode 100644
index 000000000..841ee3f6f
--- /dev/null
+++ b/server/lorax_server/utils/segments.py
@@ -0,0 +1,54 @@
+from typing import List, Tuple
+
+import torch
+
+
+def find_segments(adapter_indices: torch.Tensor) -> Tuple[List[int], List[int]]:
+    segments = [0]
+    segment_indices = []
+
+    # Calling .item() repeatedly on a CUDA tensor is very slow, so we move it to the CPU first
+    adapter_indices = adapter_indices.cpu()
+
+    start_index = 0
+    for i in range(1, adapter_indices.shape[0]):
+        if adapter_indices[i] != adapter_indices[i - 1]:
+            segments.append(i)
+            segment_indices.append(adapter_indices[i - 1].item())
+            start_index = i
+
+    # Handle the last segment
+    if start_index < len(adapter_indices):
+        segments.append(len(adapter_indices))
+        segment_indices.append(adapter_indices[-1].item())
+
+    return segments, segment_indices
+
+
+class SegmentConcatBuilder:
+    def __init__(self):
+        self.adapter_segment_indices = []
+        self.adapter_segment_tensors = []
+
+    def concat(self, adapter_segments: torch.Tensor, segment_indices: List[int]):
+        # Update adapter segments
+        if self.adapter_segment_tensors:
+            # Because we have already processed at least one batch, remove the 0 start index
+            # from this batch denoting the beginning of the segment, then offset all segment
+            # positions by the value of the last segment in the previous batch to account for
+            # the concatenation.
+            adapter_segments = adapter_segments[1:] + self.adapter_segment_tensors[-1][-1]
+
+        if self.adapter_segment_indices and self.adapter_segment_indices[-1] == segment_indices[0]:
+            # If the last segment in the previous batch is the same as the first segment in this batch,
+            # then we merge them together into a single segment. In effect, this means removing it from
+            # the segment indices of this batch, and extending the segment span by removing the segment
+            # end index from the previous batch.
+            segment_indices = segment_indices[1:]
+            self.adapter_segment_tensors[-1] = self.adapter_segment_tensors[-1][:-1]
+
+        self.adapter_segment_indices.extend(segment_indices)
+        self.adapter_segment_tensors.append(adapter_segments)
+
+    def build(self) -> Tuple[torch.Tensor, List[int]]:
+        return torch.concat(self.adapter_segment_tensors), self.adapter_segment_indices
diff --git a/server/tests/utils/test_segments.py b/server/tests/utils/test_segments.py
new file mode 100644
index 000000000..c81cd0892
--- /dev/null
+++ b/server/tests/utils/test_segments.py
@@ -0,0 +1,55 @@
+import pytest
+import torch
+
+from lorax_server.utils.segments import find_segments, SegmentConcatBuilder
+
+
+
+@pytest.mark.parametrize(
+    "adapter_indices,expected_segments,expected_segment_indices",
+    [
+        (
+            torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 1, 1]),
+            [0, 3, 5, 10, 12],
+            [0, 1, 2, 1],
+        ),
+        (torch.tensor([]), [0], []),
+        (torch.tensor([0]), [0, 1], [0]),
+        (torch.tensor([1]), [0, 1], [1]),
+    ],
+)
+def test_find_segments(adapter_indices, expected_segments, expected_segment_indices):
+    segments, segment_indices = find_segments(adapter_indices)
+    assert segments == expected_segments
+    assert segment_indices == expected_segment_indices
+
+
+@pytest.mark.parametrize(
+    "batches,expected_segments,expected_segment_indices",
+    [
+        (
+            [
+                (torch.tensor([0, 1, 4, 7, 8]), [2, 1, 2, 1]),
+                (torch.tensor([0, 2, 5]), [1, 2]),
+            ],
+            [0, 1, 4, 7, 10, 13],
+            [2, 1, 2, 1, 2],
+        ),
+        (
+            [
+                (torch.tensor([0, 1, 4, 7]), [2, 1, 2]),
+                (torch.tensor([0, 2, 5]), [1, 2]),
+            ],
+            [0, 1, 4, 7, 9, 12],
+            [2, 1, 2, 1, 2],
+        ),
+    ],
+)
+def test_concat_segments(batches, expected_segments, expected_segment_indices):
+    builder = SegmentConcatBuilder()
+    for segment, indices in batches:
+        builder.concat(segment, indices)
+
+    segments, segment_indices = builder.build()
+    assert segments.tolist() == expected_segments
+    assert segment_indices == expected_segment_indices
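Note (not part of the patch): a minimal sketch of how the new `find_segments` helper from `lorax_server.utils.segments` behaves. The first returned list holds segment boundary offsets, starting at 0 and ending at `len(adapter_indices)`; the second holds the adapter index covering each span:

import torch

from lorax_server.utils.segments import find_segments

# Tokens for three requests: adapter 0 covers positions [0, 3),
# adapter 1 covers [3, 5), and adapter 0 again covers [5, 6).
# Non-adjacent runs of the same adapter remain separate segments.
segments, segment_indices = find_segments(torch.tensor([0, 0, 0, 1, 1, 0]))
assert segments == [0, 3, 5, 6]
assert segment_indices == [0, 1, 0]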
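Likewise, a sketch of the `SegmentConcatBuilder` merge behavior exercised by `test_concat_segments`: offsets from a later batch are shifted by the last offset of the earlier batch, and when the adapter on either side of the seam matches, the two boundary segments collapse into a single span:

import torch

from lorax_server.utils.segments import SegmentConcatBuilder

builder = SegmentConcatBuilder()
# Batch 1: adapter 0 on [0, 2), adapter 1 on [2, 4)
builder.concat(torch.tensor([0, 2, 4]), [0, 1])
# Batch 2: adapter 1 on [0, 3). Its offsets are shifted by 4, and the
# adapter-1 segments on both sides of the seam merge into [2, 7).
builder.concat(torch.tensor([0, 3]), [1])

segments, segment_indices = builder.build()
assert segments.tolist() == [0, 2, 7]
assert segment_indices == [0, 1]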