From 832d9053f03de81adc0fc3ef238698ac13b8ac69 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Wed, 23 Oct 2024 12:11:53 -0700 Subject: [PATCH 01/76] Collect timings --- server/lorax_server/models/flash_causal_lm.py | 973 +++++++++--------- server/lorax_server/server.py | 104 +- 2 files changed, 545 insertions(+), 532 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 2e667ef84..d244f8af7 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -28,6 +28,7 @@ from lorax_server.utils.dist import MEMORY_FRACTION, MEMORY_WIGGLE_ROOM, initialize_torch_distributed from lorax_server.utils.graph import GraphCache from lorax_server.utils.import_utils import get_cuda_free_memory +from lorax_server.utils.profiler import timer from lorax_server.utils.segments import SegmentConcatBuilder, find_segments from lorax_server.utils.sources import HUB from lorax_server.utils.sources.hub import weight_files @@ -1408,517 +1409,521 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> def generate_token( self, batch: FlashCausalLMBatch, is_warmup: bool = False ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]: - prefill = batch.prefilling - if prefill: - batch.prepare_for_prefill() - prefill_logprobs = batch.prefill_next_token_indices is not None - return_alternatives = any(req.parameters.return_k_alternatives > 0 for req in batch.requests) - - # Update adapter indices for speculative tokens (if present) - adapter_meta = batch.adapter_meta - if batch.speculative_ids is not None: - B, speculative_length = batch.speculative_ids.shape - new_length = speculative_length + 1 - adapter_indices = adapter_meta.adapter_indices.unsqueeze(-1).expand(B, new_length).reshape(-1) - adapter_segments = adapter_meta.adapter_segments * new_length - adapter_meta = AdapterBatchMetadata( - adapter_indices=adapter_indices, - adapter_set=adapter_meta.adapter_set, - adapter_segments=adapter_segments, - segment_indices=adapter_meta.segment_indices, - ) - - # Assign pointers to adapter weights - # TODO(travis): don't update this if indices haven't changed - adapter_data = AdapterBatchData.from_meta( - adapter_meta, self.layer_to_adapter_weights, prefill, batch.prefill_head_indices - ) - - out, speculative_logits = self.forward(batch, adapter_data) - - if prefill: - next_token_logits = out[batch.prefill_next_token_indices] if prefill_logprobs else out - if speculative_logits is not None: - speculative_logits = ( - speculative_logits[batch.prefill_next_token_indices] if prefill_logprobs else speculative_logits + stage_str = "prefill" if batch.prefilling else "decode" + with timer(f"{stage_str}::generate_token::pre_forward"): + prefill = batch.prefilling + if prefill: + batch.prepare_for_prefill() + prefill_logprobs = batch.prefill_next_token_indices is not None + return_alternatives = any(req.parameters.return_k_alternatives > 0 for req in batch.requests) + + # Update adapter indices for speculative tokens (if present) + adapter_meta = batch.adapter_meta + if batch.speculative_ids is not None: + B, speculative_length = batch.speculative_ids.shape + new_length = speculative_length + 1 + adapter_indices = adapter_meta.adapter_indices.unsqueeze(-1).expand(B, new_length).reshape(-1) + adapter_segments = adapter_meta.adapter_segments * new_length + adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_meta.adapter_set, + 
adapter_segments=adapter_segments, + segment_indices=adapter_meta.segment_indices, ) - if len(batch) > 1 and prefill_logprobs: - # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs - # When batch == 1, we will just use the batch.input_ids values directly - prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) - else: - prefill_logprobs = None - next_token_logits = out - next_adapter_indices = batch.adapter_meta.adapter_indices - - finished_prefilling = True - next_chunk_lengths = [] - current_prefilling_mask = batch.prefilling_mask - if prefill: - if get_supports_chunking(): - next_prefilling_mask = [] - # Budget in tokens for the next batch - # We remove (len(batch) - 1) to always have enough space for at least a single decode - # for the remaining requests -1 because the first request does not need to be removed from the budget - # (ex: you have one request in the batch, you want it to take the full budget not budget -1) - batch_budget = get_max_prefill_tokens() - (len(batch) - 1) - # We reverse to prioritize older requests - # zip() is not reversible so reverse the underlying lists instead - for cache_length, input_length, prompt_length in zip( - reversed(batch.cache_lengths), - reversed(batch.input_lengths), - reversed(batch.prompt_lengths), - ): - remaining_prefill_tokens = max(prompt_length - cache_length - input_length, 0) - if remaining_prefill_tokens > 0: - next_chunk_length = max(min(remaining_prefill_tokens, batch_budget), 1) - batch_budget -= next_chunk_length - finished_prefilling = False - next_prefilling_mask.append(True) - else: - # FIXME: use true number of accepted tokens instead of 1 - # Since speculation will be turned off, this is always true - next_chunk_length = 1 - next_prefilling_mask.append(False) - next_chunk_lengths.append(next_chunk_length) - - # Reverse back the obtained values² - next_chunk_lengths.reverse() - next_prefilling_mask.reverse() - else: - # The model does not support chunking - # We know we only do a single prefill - finished_prefilling = True - next_prefilling_mask = [False] * len(batch) - batch.prefilling = not finished_prefilling - batch.prefilling_mask = next_prefilling_mask + # Assign pointers to adapter weights + # TODO(travis): don't update this if indices haven't changed + adapter_data = AdapterBatchData.from_meta( + adapter_meta, self.layer_to_adapter_weights, prefill, batch.prefill_head_indices + ) - speculative_tokens = get_speculative_tokens() - ( - next_input_ids, - next_token_logprobs, - accepted_ids, - speculative_ids, - ) = batch.next_token_chooser( - batch.all_input_ids_tensor[:, : batch.max_current_length], - next_token_logits, - speculative_tokens, - batch.speculative_ids, - speculative_logits, - ) + with timer(f"{stage_str}::generate_token::forward"): + out, speculative_logits = self.forward(batch, adapter_data) - if return_alternatives: - alternative_token_logprobs, alternative_token_ids = torch.sort( - torch.log_softmax(next_token_logits, -1), dim=-1, stable=True, descending=True + with timer(f"{stage_str}::generate_token::post_forward"): + if prefill: + next_token_logits = out[batch.prefill_next_token_indices] if prefill_logprobs else out + if speculative_logits is not None: + speculative_logits = ( + speculative_logits[batch.prefill_next_token_indices] if prefill_logprobs else speculative_logits + ) + if len(batch) > 1 and prefill_logprobs: + # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs + # When batch == 1, we will just use the 
batch.input_ids values directly + prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) + else: + prefill_logprobs = None + next_token_logits = out + next_adapter_indices = batch.adapter_meta.adapter_indices + + finished_prefilling = True + next_chunk_lengths = [] + current_prefilling_mask = batch.prefilling_mask + if prefill: + if get_supports_chunking(): + next_prefilling_mask = [] + # Budget in tokens for the next batch + # We remove (len(batch) - 1) to always have enough space for at least a single decode + # for the remaining requests -1 because the first request does not need to be removed from the budget + # (ex: you have one request in the batch, you want it to take the full budget not budget -1) + batch_budget = get_max_prefill_tokens() - (len(batch) - 1) + # We reverse to prioritize older requests + # zip() is not reversible so reverse the underlying lists instead + for cache_length, input_length, prompt_length in zip( + reversed(batch.cache_lengths), + reversed(batch.input_lengths), + reversed(batch.prompt_lengths), + ): + remaining_prefill_tokens = max(prompt_length - cache_length - input_length, 0) + if remaining_prefill_tokens > 0: + next_chunk_length = max(min(remaining_prefill_tokens, batch_budget), 1) + batch_budget -= next_chunk_length + finished_prefilling = False + next_prefilling_mask.append(True) + else: + # FIXME: use true number of accepted tokens instead of 1 + # Since speculation will be turned off, this is always true + next_chunk_length = 1 + next_prefilling_mask.append(False) + next_chunk_lengths.append(next_chunk_length) + + # Reverse back the obtained values² + next_chunk_lengths.reverse() + next_prefilling_mask.reverse() + else: + # The model does not support chunking + # We know we only do a single prefill + finished_prefilling = True + next_prefilling_mask = [False] * len(batch) + + batch.prefilling = not finished_prefilling + batch.prefilling_mask = next_prefilling_mask + + speculative_tokens = get_speculative_tokens() + ( + next_input_ids, + next_token_logprobs, + accepted_ids, + speculative_ids, + ) = batch.next_token_chooser( + batch.all_input_ids_tensor[:, : batch.max_current_length], + next_token_logits, + speculative_tokens, + batch.speculative_ids, + speculative_logits, ) - # Since we are done prefilling, all the tensors that were concatenating values for all the requests - # instantly become of shape [BATCH_SIZE] - if prefill and finished_prefilling: - next_position_ids = batch.position_ids.new_empty(len(batch)) - batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1] - next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(len(batch)) - elif not prefill: - next_position_ids = batch.position_ids - - # Zipped iterator - iterator = zip( - batch.requests, - batch.prompt_lengths, - batch.cache_lengths, - batch.input_lengths, - batch.all_input_ids, - accepted_ids, - current_prefilling_mask, - batch.prefilling_mask, - ) - - # We do two for loops as the first one can run completely asynchronously from the GPU while for the second - # one, we need to first do a GPU <-> CPU sync - # It is faster if we delay this sync for the maximum amount of time + if return_alternatives: + alternative_token_logprobs, alternative_token_ids = torch.sort( + torch.log_softmax(next_token_logits, -1), dim=-1, stable=True, descending=True + ) - # For each member of the batch - index = 0 - # Cumulative length - cumulative_length = 0 - for i, ( - request, - prompt_length, - cache_length, - input_length, - all_input_ids, - n_accepted_ids, - 
request_was_prefilling, - request_is_prefilling, - ) in enumerate(iterator): + # Since we are done prefilling, all the tensors that were concatenating values for all the requests + # instantly become of shape [BATCH_SIZE] if prefill and finished_prefilling: - # Indexing metadata - _start_index = cumulative_length - end_index = cumulative_length + input_length - - # Initialize position_ids - # In decode, we do not need this as we can just increment position ids - next_position_ids[i] = batch.position_ids[end_index - 1] - - # Initialize adapter indices - # In decode, we only have one token per row in the batch, so grab last index - next_adapter_indices[i] = batch.adapter_meta.adapter_indices[end_index - 1] - - # Used to gather prefill logprobs - # Copy batch.all_input_ids_tensor to prefill_token_indices - if request.prefill_logprobs and request_was_prefilling: - # Indexing metadata - out_start_index = batch.prefill_cu_outlens[i] - out_end_index = batch.prefill_cu_outlens[i + 1] - - # Logprobs generated by the model are for the next token - # So we need to translate the id tensor by 1 - ids = batch.all_input_ids_tensor[i, cache_length + 1 : cache_length + input_length + 1] - if len(batch) > 1: - prefill_tokens_indices[out_start_index:out_end_index] = ids - else: - # Set prefill_tokens_indices to the correct slice - prefill_tokens_indices = ids - - if not request_is_prefilling: - # Only save tokens if we are done prefilling for this request - for j in range(n_accepted_ids): - batch.all_input_ids_tensor[i, cache_length + input_length + j] = next_input_ids[index + j] - - batch.all_input_ids_tensor[i, input_length] = next_input_ids[i] - - index += n_accepted_ids - cumulative_length += input_length - - # Update values - # These values can be updated without a GPU -> CPU sync - if not prefill or (prefill and finished_prefilling): - batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] - batch.speculative_ids = speculative_ids - batch.position_ids = next_position_ids + accepted_ids - batch.cache_lengths_tensor += batch.input_lengths_tensor - batch.input_lengths_tensor = accepted_ids.to(dtype=torch.int32) - batch.slot_indices += accepted_ids - batch.adapter_meta.adapter_indices = next_adapter_indices - - if prefill and prefill_logprobs: - # Get prefill logprobs with inplace softmax (avoid copying the `out` tensor (max_batch_prefill_tokens * vocab_size)) - torch.log_softmax(out, -1, out=out) - prefill_logprobs_tensor = out - prefill_logprobs = torch.gather(prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)) - # GPU <-> CPU sync - prefill_logprobs = prefill_logprobs.view(-1).tolist() - - # Does a GPU <-> CPU sync internally - if prefill and finished_prefilling: - # adjust segment lengths to account for all request lengths being 1 during decoding - adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices) - batch.adapter_meta.adapter_segments = torch.tensor( - adapter_segments, - dtype=torch.int32, - device=batch.adapter_meta.adapter_segments.device, + next_position_ids = batch.position_ids.new_empty(len(batch)) + batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1] + next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(len(batch)) + elif not prefill: + next_position_ids = batch.position_ids + + # Zipped iterator + iterator = zip( + batch.requests, + batch.prompt_lengths, + batch.cache_lengths, + batch.input_lengths, + batch.all_input_ids, + accepted_ids, + current_prefilling_mask, + batch.prefilling_mask, ) - # GPU <-> CPU sync - 
next_token_logprobs = next_token_logprobs.tolist() - next_token_ids = next_input_ids.tolist() - accepted_ids = accepted_ids.tolist() - - if return_alternatives: - alternative_token_logprobs = alternative_token_logprobs.tolist() - alternative_token_ids = alternative_token_ids.tolist() - - # Update values if we need to continue prefilling - # This represents the `else` case of the `Update values` if above - # but since this require the `next_token_ids` to be on CPU, it is better to do it here - if prefill and not finished_prefilling: - # Speculation must be ignored while we prefill even with chunking - # it simplifies everything - assert batch.speculative_ids is None + # We do two for loops as the first one can run completely asynchronously from the GPU while for the second + # one, we need to first do a GPU <-> CPU sync + # It is faster if we delay this sync for the maximum amount of time - all_postfix_ids = [] + # For each member of the batch + index = 0 + # Cumulative length + cumulative_length = 0 for i, ( - request_prefilling, - next_token_id, - all_input_ids, + request, + prompt_length, cache_length, input_length, - next_chunk_length, - ) in enumerate( - zip( - batch.prefilling_mask, - next_token_ids, - batch.all_input_ids, - batch.cache_lengths, - batch.input_lengths, - next_chunk_lengths, - ) - ): - if request_prefilling: - next_cache_length = cache_length + input_length - # Get new prompt IDs to prefill - postfix_ids = all_input_ids[next_cache_length : next_cache_length + next_chunk_length] - else: - # This request is done prefilling, the new id is the one selected the sampling method - postfix_ids = [next_token_id] - - all_postfix_ids.append(postfix_ids) - - batch.input_ids = all_postfix_ids - - # Results - generations: List[Generation] = [] - stopped = not is_warmup - - # Zipped iterator - iterator = zip( - batch.requests, - batch.prompt_lengths, - batch.cache_lengths, - batch.input_lengths, - batch.prefix_offsets, - batch.read_offsets, - batch.stopping_criterias, - batch.all_input_ids, - batch.next_token_chooser.do_sample, - batch.next_token_chooser.seeds, - current_prefilling_mask, - batch.prefilling_mask, - accepted_ids, - ) + all_input_ids, + n_accepted_ids, + request_was_prefilling, + request_is_prefilling, + ) in enumerate(iterator): + if prefill and finished_prefilling: + # Indexing metadata + _start_index = cumulative_length + end_index = cumulative_length + input_length + + # Initialize position_ids + # In decode, we do not need this as we can just increment position ids + next_position_ids[i] = batch.position_ids[end_index - 1] + + # Initialize adapter indices + # In decode, we only have one token per row in the batch, so grab last index + next_adapter_indices[i] = batch.adapter_meta.adapter_indices[end_index - 1] + + # Used to gather prefill logprobs + # Copy batch.all_input_ids_tensor to prefill_token_indices + if request.prefill_logprobs and request_was_prefilling: + # Indexing metadata + out_start_index = batch.prefill_cu_outlens[i] + out_end_index = batch.prefill_cu_outlens[i + 1] + + # Logprobs generated by the model are for the next token + # So we need to translate the id tensor by 1 + ids = batch.all_input_ids_tensor[i, cache_length + 1 : cache_length + input_length + 1] + if len(batch) > 1: + prefill_tokens_indices[out_start_index:out_end_index] = ids + else: + # Set prefill_tokens_indices to the correct slice + prefill_tokens_indices = ids - # Reset max_input_length - batch.max_input_length = 0 - # For each member of the batch - index = 0 - for i, ( - request, 
- prompt_length, - cache_length, - input_length, - prefix_offset, - read_offset, - stopping_criteria, - all_input_ids, - do_sample, - seed, - request_was_prefilling, - request_is_prefilling, - n_accepted_ids, - ) in enumerate(iterator): - all_alternative_tokens = [] if request.parameters.return_k_alternatives > 0 else None - - # TODO(travis): return_k_alternatives - # if request.parameters.return_k_alternatives > 0: - # # Limit the number of alternatives to the vocabulary size - # num_alternatives = min( - # request.parameters.return_k_alternatives, - # len(alternative_token_ids[token_idx]), - # ) - - # # Select top-k logprobs - # request_alternative_token_ids = alternative_token_ids[token_idx][:num_alternatives] - # request_alternative_token_logprobs = alternative_token_logprobs[token_idx][:num_alternatives] - - # # Decode tokens - # request_alternative_token_texts = [] - # for alternative_token_id in request_alternative_token_ids: - # all_input_ids.append(alternative_token_id) - # alternative_token_text, _, _ = self.decode_token( - # all_input_ids, - # prefix_offset, - # read_offset, - # ) - # request_alternative_token_texts.append(alternative_token_text) - # all_input_ids.pop() - # alternative_tokens = AlternativeTokens( - # request_alternative_token_ids, - # request_alternative_token_logprobs, - # request_alternative_token_texts, - # ) - # all_alternative_tokens.append(alternative_tokens) - - # Compute logprobs first as, even though we might skip the token, - # it can still be required to compute the logprobs - # modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need - # this state to be stable - if request.id % self.world_size == self.rank: - # Prefill - if request_was_prefilling and request.prefill_logprobs: - out_start_index = batch.prefill_cu_outlens[i] - out_end_index = batch.prefill_cu_outlens[i + 1] if not request_is_prefilling: - # The request is dones prefilling, meaning that we started generating new tokens - # The last logprob is a logprob for a generated token that was not part of the prompt - # We need to remove it - out_end_index -= 1 - - request_prefill_logprobs = prefill_logprobs[out_start_index:out_end_index] - # Logprobs generated by the model are for the next token - # So we need to translate the id tensor by 1 - prefill_token_ids = all_input_ids[cache_length + 1 : cache_length + input_length + 1] - - past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i] - - if past_prefill_logprob_tokens is None: - # add nan for cached prompt tokens/first token - request_prefill_logprobs = [float("nan")] * (cache_length + 1) + request_prefill_logprobs - prefill_token_ids = all_input_ids[: cache_length + 1] + prefill_token_ids - - prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) + # Only save tokens if we are done prefilling for this request + for j in range(n_accepted_ids): + batch.all_input_ids_tensor[i, cache_length + input_length + j] = next_input_ids[index + j] - prefill_logprob_tokens = NextTokens( - prefill_token_ids, - request_prefill_logprobs, - prefill_texts, - [], - all_alternative_tokens, - ) - if past_prefill_logprob_tokens is not None: - prefill_logprob_tokens = past_prefill_logprob_tokens + prefill_logprob_tokens + batch.all_input_ids_tensor[i, input_length] = next_input_ids[i] - batch.prefill_logprob_tokens[i] = prefill_logprob_tokens - else: - batch.prefill_logprob_tokens[i] = None - - # If it is, the tokens we decoded should 
be ignored - if request_is_prefilling: - # Make sure that we do not stop as even though this request did not create a token, it is still - # processing - stopped = False - new_input_length = next_chunk_lengths[i] - else: - new_input_length = n_accepted_ids - # Append next token to all tokens - next_token_texts = [] - left = 0 - - if n_accepted_ids > 1: - logger.debug(f"speculated ids {n_accepted_ids - 1}") - - current_stopped = False - for j in range(index, index + n_accepted_ids): - # Generated token - next_token_id = next_token_ids[j] - all_input_ids.append(next_token_id) - next_token_text, prefix_offset, read_offset = self.decode_token( - all_input_ids, - prefix_offset, - read_offset, - ) - next_token_texts.append(next_token_text) + index += n_accepted_ids + cumulative_length += input_length - stop, reason = stopping_criteria( - next_token_id, - next_token_text, - ) + # Update values + # These values can be updated without a GPU -> CPU sync + if not prefill or (prefill and finished_prefilling): + batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] + batch.speculative_ids = speculative_ids + batch.position_ids = next_position_ids + accepted_ids + batch.cache_lengths_tensor += batch.input_lengths_tensor + batch.input_lengths_tensor = accepted_ids.to(dtype=torch.int32) + batch.slot_indices += accepted_ids + batch.adapter_meta.adapter_indices = next_adapter_indices + + if prefill and prefill_logprobs: + # Get prefill logprobs with inplace softmax (avoid copying the `out` tensor (max_batch_prefill_tokens * vocab_size)) + torch.log_softmax(out, -1, out=out) + prefill_logprobs_tensor = out + prefill_logprobs = torch.gather(prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)) + # GPU <-> CPU sync + prefill_logprobs = prefill_logprobs.view(-1).tolist() + + # Does a GPU <-> CPU sync internally + if prefill and finished_prefilling: + # adjust segment lengths to account for all request lengths being 1 during decoding + adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices) + batch.adapter_meta.adapter_segments = torch.tensor( + adapter_segments, + dtype=torch.int32, + device=batch.adapter_meta.adapter_segments.device, + ) - if stop: - left = index + n_accepted_ids - j - 1 - current_stopped = True - break + # GPU <-> CPU sync + next_token_logprobs = next_token_logprobs.tolist() + next_token_ids = next_input_ids.tolist() + accepted_ids = accepted_ids.tolist() + + if return_alternatives: + alternative_token_logprobs = alternative_token_logprobs.tolist() + alternative_token_ids = alternative_token_ids.tolist() + + # Update values if we need to continue prefilling + # This represents the `else` case of the `Update values` if above + # but since this require the `next_token_ids` to be on CPU, it is better to do it here + if prefill and not finished_prefilling: + # Speculation must be ignored while we prefill even with chunking + # it simplifies everything + assert batch.speculative_ids is None + + all_postfix_ids = [] + for i, ( + request_prefilling, + next_token_id, + all_input_ids, + cache_length, + input_length, + next_chunk_length, + ) in enumerate( + zip( + batch.prefilling_mask, + next_token_ids, + batch.all_input_ids, + batch.cache_lengths, + batch.input_lengths, + next_chunk_lengths, + ) + ): + if request_prefilling: + next_cache_length = cache_length + input_length + # Get new prompt IDs to prefill + postfix_ids = all_input_ids[next_cache_length : next_cache_length + next_chunk_length] else: - current_stopped = False - stopped = stopped and 
current_stopped - - _next_token_ids = next_token_ids[index : index + n_accepted_ids - left] - _next_token_logprobs = next_token_logprobs[index : index + n_accepted_ids - left] + # This request is done prefilling, the new id is the one selected the sampling method + postfix_ids = [next_token_id] + + all_postfix_ids.append(postfix_ids) + + batch.input_ids = all_postfix_ids + + # Results + generations: List[Generation] = [] + stopped = not is_warmup + + # Zipped iterator + iterator = zip( + batch.requests, + batch.prompt_lengths, + batch.cache_lengths, + batch.input_lengths, + batch.prefix_offsets, + batch.read_offsets, + batch.stopping_criterias, + batch.all_input_ids, + batch.next_token_chooser.do_sample, + batch.next_token_chooser.seeds, + current_prefilling_mask, + batch.prefilling_mask, + accepted_ids, + ) - # Shard generations - # All generations will be appended in the rust sharded client + # Reset max_input_length + batch.max_input_length = 0 + # For each member of the batch + index = 0 + for i, ( + request, + prompt_length, + cache_length, + input_length, + prefix_offset, + read_offset, + stopping_criteria, + all_input_ids, + do_sample, + seed, + request_was_prefilling, + request_is_prefilling, + n_accepted_ids, + ) in enumerate(iterator): + all_alternative_tokens = [] if request.parameters.return_k_alternatives > 0 else None + + # TODO(travis): return_k_alternatives + # if request.parameters.return_k_alternatives > 0: + # # Limit the number of alternatives to the vocabulary size + # num_alternatives = min( + # request.parameters.return_k_alternatives, + # len(alternative_token_ids[token_idx]), + # ) + + # # Select top-k logprobs + # request_alternative_token_ids = alternative_token_ids[token_idx][:num_alternatives] + # request_alternative_token_logprobs = alternative_token_logprobs[token_idx][:num_alternatives] + + # # Decode tokens + # request_alternative_token_texts = [] + # for alternative_token_id in request_alternative_token_ids: + # all_input_ids.append(alternative_token_id) + # alternative_token_text, _, _ = self.decode_token( + # all_input_ids, + # prefix_offset, + # read_offset, + # ) + # request_alternative_token_texts.append(alternative_token_text) + # all_input_ids.pop() + # alternative_tokens = AlternativeTokens( + # request_alternative_token_ids, + # request_alternative_token_logprobs, + # request_alternative_token_texts, + # ) + # all_alternative_tokens.append(alternative_tokens) + + # Compute logprobs first as, even though we might skip the token, + # it can still be required to compute the logprobs + # modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need + # this state to be stable if request.id % self.world_size == self.rank: - if stop: - # Decode generated tokens - output_text, _, _ = self.decode_token( - all_input_ids, - prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1, - read_offset=len(all_input_ids) - stopping_criteria.current_tokens, - skip_special_tokens=True, + # Prefill + if request_was_prefilling and request.prefill_logprobs: + out_start_index = batch.prefill_cu_outlens[i] + out_end_index = batch.prefill_cu_outlens[i + 1] + if not request_is_prefilling: + # The request is dones prefilling, meaning that we started generating new tokens + # The last logprob is a logprob for a generated token that was not part of the prompt + # We need to remove it + out_end_index -= 1 + + request_prefill_logprobs = prefill_logprobs[out_start_index:out_end_index] + # Logprobs generated by the model are 
for the next token + # So we need to translate the id tensor by 1 + prefill_token_ids = all_input_ids[cache_length + 1 : cache_length + input_length + 1] + + past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i] + + if past_prefill_logprob_tokens is None: + # add nan for cached prompt tokens/first token + request_prefill_logprobs = [float("nan")] * (cache_length + 1) + request_prefill_logprobs + prefill_token_ids = all_input_ids[: cache_length + 1] + prefill_token_ids + + prefill_texts = self.tokenizer.batch_decode( + prefill_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, ) - generated_text = GeneratedText( - output_text, - stopping_criteria.current_tokens, - reason, - seed if do_sample else None, + + prefill_logprob_tokens = NextTokens( + prefill_token_ids, + request_prefill_logprobs, + prefill_texts, + [], + all_alternative_tokens, ) + if past_prefill_logprob_tokens is not None: + prefill_logprob_tokens = past_prefill_logprob_tokens + prefill_logprob_tokens + + batch.prefill_logprob_tokens[i] = prefill_logprob_tokens else: - generated_text = None - - # TODO(travis): top tokens - # if top_n_tokens > 0: - # all_top_tokens = [] - # for top_token_ids, top_token_logprobs in zip( - # top_token_ids, top_token_logprobs - # ): - # toptoken_texts = self.tokenizer.batch_decode( - # top_token_ids, - # clean_up_tokenization_spaces=False, - # skip_special_tokens=False, - # ) - # special_toptokens = [ - # token_id in self.all_special_ids - # for token_id in top_token_ids - # ] - # top_tokens = Tokens( - # top_token_ids, - # top_token_logprobs, - # toptoken_texts, - # special_toptokens, - # ) - # all_top_tokens.append(top_tokens) - # top_tokens = all_top_tokens - # else: - # top_tokens = None - - generation = Generation( - request.id, - batch.prefill_logprob_tokens[i], - len(all_input_ids[:-1]) if prefill else 0, - NextTokens( - _next_token_ids, - _next_token_logprobs, - next_token_texts, - [nid in self.all_special_ids for nid in _next_token_ids], - all_alternative_tokens, - ), - generated_text, - ) + batch.prefill_logprob_tokens[i] = None + + # If it is, the tokens we decoded should be ignored + if request_is_prefilling: + # Make sure that we do not stop as even though this request did not create a token, it is still + # processing + stopped = False + new_input_length = next_chunk_lengths[i] + else: + new_input_length = n_accepted_ids + # Append next token to all tokens + next_token_texts = [] + left = 0 + + if n_accepted_ids > 1: + logger.debug(f"speculated ids {n_accepted_ids - 1}") + + current_stopped = False + for j in range(index, index + n_accepted_ids): + # Generated token + next_token_id = next_token_ids[j] + all_input_ids.append(next_token_id) + next_token_text, prefix_offset, read_offset = self.decode_token( + all_input_ids, + prefix_offset, + read_offset, + ) + next_token_texts.append(next_token_text) - generations.append(generation) + stop, reason = stopping_criteria( + next_token_id, + next_token_text, + ) - # advance the FSM for each accepted token (as we may have more than one from speculative decoding) - for next_token_id in _next_token_ids: - batch.next_token_chooser.next_state(i, next_token_id) + if stop: + left = index + n_accepted_ids - j - 1 + current_stopped = True + break + else: + current_stopped = False + stopped = stopped and current_stopped + + _next_token_ids = next_token_ids[index : index + n_accepted_ids - left] + _next_token_logprobs = next_token_logprobs[index : index + n_accepted_ids - left] + + # Shard generations + # All 
generations will be appended in the rust sharded client + if request.id % self.world_size == self.rank: + if stop: + # Decode generated tokens + output_text, _, _ = self.decode_token( + all_input_ids, + prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1, + read_offset=len(all_input_ids) - stopping_criteria.current_tokens, + skip_special_tokens=True, + ) + generated_text = GeneratedText( + output_text, + stopping_criteria.current_tokens, + reason, + seed if do_sample else None, + ) + else: + generated_text = None + + # TODO(travis): top tokens + # if top_n_tokens > 0: + # all_top_tokens = [] + # for top_token_ids, top_token_logprobs in zip( + # top_token_ids, top_token_logprobs + # ): + # toptoken_texts = self.tokenizer.batch_decode( + # top_token_ids, + # clean_up_tokenization_spaces=False, + # skip_special_tokens=False, + # ) + # special_toptokens = [ + # token_id in self.all_special_ids + # for token_id in top_token_ids + # ] + # top_tokens = Tokens( + # top_token_ids, + # top_token_logprobs, + # toptoken_texts, + # special_toptokens, + # ) + # all_top_tokens.append(top_tokens) + # top_tokens = all_top_tokens + # else: + # top_tokens = None + + generation = Generation( + request.id, + batch.prefill_logprob_tokens[i], + len(all_input_ids[:-1]) if prefill else 0, + NextTokens( + _next_token_ids, + _next_token_logprobs, + next_token_texts, + [nid in self.all_special_ids for nid in _next_token_ids], + all_alternative_tokens, + ), + generated_text, + ) - # Update values - index += n_accepted_ids - current_cache_length = cache_length + input_length - batch.cache_lengths[i] = current_cache_length - current_input_length = new_input_length - batch.max_input_length = max(batch.max_input_length, current_input_length) - batch.input_lengths[i] = current_input_length - current_length = current_cache_length + current_input_length - batch.max_current_length = max(batch.max_current_length, current_length) - - batch.prefix_offsets[i] = prefix_offset - batch.read_offsets[i] = read_offset - batch.all_input_ids[i] = all_input_ids - - if stopped: - # No need to return a batch if we know that all requests stopped - return generations, None - - if prefill and finished_prefilling: - # We do not need prefill tensors anymore - batch.cu_seqlen_prefill = None - batch.prefill_cache_indices = None - batch.prefill_cu_outlens = None - batch.prefill_head_indices = None - batch.prefill_next_token_indices = None + generations.append(generation) - return generations, batch + # advance the FSM for each accepted token (as we may have more than one from speculative decoding) + for next_token_id in _next_token_ids: + batch.next_token_chooser.next_state(i, next_token_id) + + # Update values + index += n_accepted_ids + current_cache_length = cache_length + input_length + batch.cache_lengths[i] = current_cache_length + current_input_length = new_input_length + batch.max_input_length = max(batch.max_input_length, current_input_length) + batch.input_lengths[i] = current_input_length + current_length = current_cache_length + current_input_length + batch.max_current_length = max(batch.max_current_length, current_length) + + batch.prefix_offsets[i] = prefix_offset + batch.read_offsets[i] = read_offset + batch.all_input_ids[i] = all_input_ids + + if stopped: + # No need to return a batch if we know that all requests stopped + return generations, None + + if prefill and finished_prefilling: + # We do not need prefill tensors anymore + batch.cu_seqlen_prefill = None + batch.prefill_cache_indices = None + 
batch.prefill_cu_outlens = None + batch.prefill_head_indices = None + batch.prefill_next_token_indices = None + + return generations, batch diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index a75906ea6..6fcc53503 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -23,6 +23,7 @@ enum_string_to_adapter_source, is_base_model, ) +from lorax_server.utils.profiler import timer from lorax_server.utils.sgmv import has_sgmv from lorax_server.utils.state import set_max_prefill_tokens, set_speculative_tokens @@ -90,30 +91,34 @@ async def Warmup(self, request: generate_pb2.WarmupRequest, context): return generate_pb2.WarmupResponse(max_supported_total_tokens=max_supported_total_tokens) async def Prefill(self, request: generate_pb2.PrefillRequest, context): - batch = self.model.batch_type.from_pb( - request.batch, - self.model.tokenizer, - self.model.tokenizers, - self.model.processor, - self.model.model.config, - self.model.dtype, - self.model.device, - ) - - if self.model.supports_chunking: - if request.HasField("cached_batch"): - cached_batch = self.cache.pop(request.cached_batch.id) - if cached_batch is None: - raise ValueError(f"Batch ID {request.cached_batch.id} not found in cache.") - batch = self.model.batch_type.concatenate([cached_batch, batch]) - - generations, next_batch = self.model.generate_token(batch) - self.cache.set(next_batch) + with timer("prefill::total"): + with timer("prefill::batch::from_pb"): + batch = self.model.batch_type.from_pb( + request.batch, + self.model.tokenizer, + self.model.tokenizers, + self.model.processor, + self.model.model.config, + self.model.dtype, + self.model.device, + ) - return generate_pb2.PrefillResponse( - generations=[generation.to_pb() for generation in generations], - batch=next_batch.to_pb() if next_batch else None, - ) + if self.model.supports_chunking: + if request.HasField("cached_batch"): + with timer("prefill::batch::concatenate"): + cached_batch = self.cache.pop(request.cached_batch.id) + if cached_batch is None: + raise ValueError(f"Batch ID {request.cached_batch.id} not found in cache.") + batch = self.model.batch_type.concatenate([cached_batch, batch]) + + with timer("prefill::generate_token"): + generations, next_batch = self.model.generate_token(batch) + self.cache.set(next_batch) + + return generate_pb2.PrefillResponse( + generations=[generation.to_pb() for generation in generations], + batch=next_batch.to_pb() if next_batch else None, + ) async def Classify(self, request: generate_pb2.ClassifyRequest, context): if not self.model.supports_classification: @@ -150,31 +155,34 @@ async def Embed(self, request: generate_pb2.EmbedRequest, context): return embeddings_pb async def Decode(self, request: generate_pb2.DecodeRequest, context): - if len(request.batches) == 0: - raise ValueError("Must provide at least one batch") - - batches = [] - for batch_pb in request.batches: - batch = self.cache.pop(batch_pb.id) - if batch is None: - raise ValueError(f"Batch ID {batch_pb.id} not found in cache.") - batches.append(batch) - - if len(batches) == 0: - raise ValueError("All batches are empty") - - if len(batches) > 1: - batch = self.model.batch_type.concatenate(batches) - else: - batch = batches[0] - - generations, next_batch = self.model.generate_token(batch) - self.cache.set(next_batch) - - return generate_pb2.DecodeResponse( - generations=[generation.to_pb() for generation in generations], - batch=next_batch.to_pb() if next_batch else None, - ) + with timer("decode::total"): + if 
len(request.batches) == 0: + raise ValueError("Must provide at least one batch") + + batches = [] + for batch_pb in request.batches: + batch = self.cache.pop(batch_pb.id) + if batch is None: + raise ValueError(f"Batch ID {batch_pb.id} not found in cache.") + batches.append(batch) + + if len(batches) == 0: + raise ValueError("All batches are empty") + + if len(batches) > 1: + with timer("decode::batch::concatenate"): + batch = self.model.batch_type.concatenate(batches) + else: + batch = batches[0] + + with timer("decode::generate_token"): + generations, next_batch = self.model.generate_token(batch) + self.cache.set(next_batch) + + return generate_pb2.DecodeResponse( + generations=[generation.to_pb() for generation in generations], + batch=next_batch.to_pb() if next_batch else None, + ) async def DownloadAdapter(self, request: generate_pb2.DownloadAdapterRequest, context): if ( From 697bf4de25cfad7ddfbea21ab2237a753c38afa4 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Wed, 23 Oct 2024 12:12:08 -0700 Subject: [PATCH 02/76] Profiler --- server/lorax_server/utils/profiler.py | 36 +++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 server/lorax_server/utils/profiler.py diff --git a/server/lorax_server/utils/profiler.py b/server/lorax_server/utils/profiler.py new file mode 100644 index 000000000..48ffe9ffc --- /dev/null +++ b/server/lorax_server/utils/profiler.py @@ -0,0 +1,36 @@ +import time +from contextlib import contextmanager + + +class TimingContextManager: + def __init__(self, name: str): + self.name = name + self.total_time = 0 + self.count = 0 + + @contextmanager + def timing(self): + start = time.time() + try: + yield + finally: + end = time.time() + self.total_time += end - start + self.count += 1 + print(f"=== {self.name}: avg={self.get_average_time():.3f} s total={self.total_time:.3f} s count={self.count}") + + def get_average_time(self): + if self.count == 0: + return 0 + return self.total_time / self.count + + +_timers = {} + + +@contextmanager +def timer(name: str): + if name not in _timers: + _timers[name] = TimingContextManager(name) + with _timers[name].timing(): + yield From d155163f8611750b7b7db7162d45c73a1b8da65b Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Wed, 23 Oct 2024 12:43:34 -0700 Subject: [PATCH 03/76] Allow max batch prefill tokens < max input length --- Cargo.lock | 1 + launcher/Cargo.toml | 1 + launcher/src/main.rs | 285 +++++++++++++++++++++++++++++++++++++------ router/src/main.rs | 24 ++-- 4 files changed, 262 insertions(+), 49 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a5d9ea898..82f5029e9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1550,6 +1550,7 @@ dependencies = [ "ctrlc", "float_eq", "h2", + "hf-hub", "nix", "openssl", "reqwest", diff --git a/launcher/Cargo.toml b/launcher/Cargo.toml index cd1f3ef2c..efaf4e0dd 100644 --- a/launcher/Cargo.toml +++ b/launcher/Cargo.toml @@ -11,6 +11,7 @@ clap = { version = "4.1.4", features = ["derive", "env"] } ctrlc = { version = "3.2.5", features = ["termination"] } nix = "0.26.2" openssl = "0.10.66" +hf-hub = { version = "0.3.0", features = ["tokio"] } h2 = "0.3.26" rustix = "0.37.25" serde = { version = "1.0.152", features = ["derive"] } diff --git a/launcher/src/main.rs b/launcher/src/main.rs index c181d898e..a65114d97 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1,5 +1,8 @@ use clap::{Parser, ValueEnum}; -use nix::libc::ip_mreq_source; +use hf_hub::{ + api::sync::{Api, ApiBuilder}, + Repo, RepoType, +}; use nix::sys::signal::{self, Signal}; use 
nix::unistd::Pid; use serde::Deserialize; @@ -20,33 +23,177 @@ use tracing_subscriber::EnvFilter; mod env_runtime; -#[derive(Clone, Copy, Debug, ValueEnum)] +fn get_config( + model_id: &str, + revision: &Option, +) -> Result> { + let mut path = std::path::Path::new(model_id).to_path_buf(); + let model_id = model_id.to_string(); + let filename = if !path.exists() { + // Assume it's a hub id + + let api = if let Ok(token) = std::env::var("HF_TOKEN") { + // env variable has precedence over on file token. + ApiBuilder::new().with_token(Some(token)).build()? + } else { + Api::new()? + }; + let repo = if let Some(ref revision) = revision { + api.repo(Repo::with_revision( + model_id, + RepoType::Model, + revision.to_string(), + )) + } else { + api.model(model_id) + }; + repo.get("config.json")? + } else { + path.push("config.json"); + path + }; + + let content = std::fs::read_to_string(filename)?; + let config: RawConfig = serde_json::from_str(&content)?; + + let config: Config = config.into(); + Ok(config) +} + +#[derive(Deserialize)] +struct RawConfig { + max_position_embeddings: Option, + n_positions: Option, + model_type: Option, + max_seq_len: Option, + quantization_config: Option, + n_embd: Option, + hidden_size: Option, + num_attention_heads: Option, + head_dim: Option, + vision_config: Option, + is_encoder_decoder: Option, +} + +#[derive(Deserialize)] +struct QuantizationConfig { + quant_method: Option, +} + +#[derive(Deserialize)] +struct VisionConfig {} + +#[derive(Deserialize)] +struct Config { + max_position_embeddings: Option, + quantize: Option, + head_dim: Option, + model_type: Option, + vision_config: Option, + is_encoder_decoder: bool, +} + +impl From for Config { + fn from(other: RawConfig) -> Self { + let max_position_embeddings = other + .max_position_embeddings + .or(other.max_seq_len) + .or(other.n_positions); + let quantize = other.quantization_config.and_then(|q| q.quant_method); + let head_dim = other.head_dim.or_else(|| { + match (other.hidden_size, other.n_embd, other.num_attention_heads) { + (Some(hidden_size), _, Some(num_attention_heads)) + if hidden_size % num_attention_heads == 0 => + { + Some(hidden_size / num_attention_heads) + } + // Legacy + (_, Some(hidden_size), Some(num_attention_heads)) + if hidden_size % num_attention_heads == 0 => + { + Some(hidden_size / num_attention_heads) + } + _ => None, + } + }); + let model_type = other.model_type; + let vision_config = other.vision_config; + let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false); + Config { + max_position_embeddings, + quantize, + head_dim, + model_type, + vision_config, + is_encoder_decoder, + } + } +} + +#[derive(Clone, Copy, Debug, ValueEnum, Deserialize)] +#[serde(rename_all = "kebab-case")] enum Quantization { - Bitsandbytes, - BitsandbytesNF4, - BitsandbytesFP4, - Gptq, + /// 4 bit quantization. Requires a specific AWQ quantized model: + /// . + /// Should replace GPTQ models wherever possible because of the better latency Awq, + /// 8 bit quantization, doesn't require specific model. + /// Should be a drop-in replacement to bitsandbytes with much better performance. + /// Kernels are from Eetq, + /// Variable bit quantization. Requires a specific EXL2 quantized model: + /// . Requires exllama2 kernels and does + /// not support tensor parallelism (num_shard > 1). + Exl2, + /// 4 bit quantization. Requires a specific GTPQ quantized model: . 
+ /// text-generation-inference will use exllama (faster) kernels wherever possible, and use + /// triton kernel (wider support) when it's not. + /// AWQ has faster kernels. + Gptq, + /// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, + /// but it is known that the model will be much slower to run than the native f16. + // #[deprecated( + // since = "1.1.0", + // note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases" + // )] + Bitsandbytes, + /// Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, + /// but it is known that the model will be much slower to run than the native f16. + BitsandbytesNf4, + /// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better + /// perplexity performance for you model + BitsandbytesFp4, + /// [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above + /// This dtype has native ops should be the fastest if available. + /// This is currently not the fastest because of local unpacking + padding to satisfy matrix + /// multiplication limitations. + Fp8, + /// 4 bit quantization. Requires a specific HQQ quantized model. Hqq_4bit, + /// 3 bit quantization. Requires a specific HQQ quantized model. Hqq_3bit, + /// 2 bit quantization. Requires a specific HQQ quantized model. Hqq_2bit, - Fp8, } impl std::fmt::Display for Quantization { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { // To keep in track with `server`. match self { + #[allow(deprecated)] + // Use `eetq` instead, which provides better latencies overall and is drop-in in most cases Quantization::Bitsandbytes => { write!(f, "bitsandbytes") } - Quantization::BitsandbytesNF4 => { + Quantization::BitsandbytesNf4 => { write!(f, "bitsandbytes-nf4") } - Quantization::BitsandbytesFP4 => { + Quantization::BitsandbytesFp4 => { write!(f, "bitsandbytes-fp4") } + Quantization::Exl2 => { + write!(f, "exl2") + } Quantization::Gptq => { write!(f, "gptq") } @@ -56,6 +203,9 @@ impl std::fmt::Display for Quantization { Quantization::Eetq => { write!(f, "eetq") } + Quantization::Fp8 => { + write!(f, "fp8") + } Quantization::Hqq_4bit => { write!(f, "hqq-4bit") } @@ -65,9 +215,6 @@ impl std::fmt::Display for Quantization { Quantization::Hqq_2bit => { write!(f, "hqq-2bit") } - Quantization::Fp8 => { - write!(f, "fp8") - } } } } @@ -250,8 +397,9 @@ struct Args { /// for users. The larger this value, the longer prompt users can send which /// can impact the overall memory required to handle the load. /// Please note that some models have a finite range of sequence they can handle. - #[clap(default_value = "1024", long, env)] - max_input_length: usize, + /// Default to min(max_position_embeddings - 1, 4095) + #[clap(long, env)] + max_input_length: Option, /// This is the most important value to set as it defines the "memory budget" /// of running clients requests. @@ -261,8 +409,9 @@ struct Args { /// `1511` max_new_tokens. /// The larger this value, the larger amount each request will be in your RAM /// and the less effective batching can be. 
- #[clap(default_value = "2048", long, env)] - max_total_tokens: usize, + /// Default to min(max_position_embeddings, 4096) + #[clap(long, env)] + max_total_tokens: Option, /// This represents the ratio of waiting queries vs running queries where /// you want to start considering pausing the running queries to include the waiting @@ -280,8 +429,9 @@ struct Args { /// Limits the number of tokens for the prefill operation. /// Since this operation take the most memory and is compute bound, it is interesting /// to limit the number of requests that can be sent. - #[clap(default_value = "4096", long, env)] - max_batch_prefill_tokens: u32, + /// Default to `max_input_tokens + 50` to give a bit of room. + #[clap(long, env)] + max_batch_prefill_tokens: Option, /// **IMPORTANT** This is one critical control to allow maximum usage /// of the available hardware. @@ -1178,6 +1328,9 @@ fn spawn_shards( fn spawn_webserver( args: Args, + max_input_tokens: usize, + max_total_tokens: usize, + max_batch_prefill_tokens: u32, shutdown: Arc, shutdown_receiver: &mpsc::Receiver<()>, ) -> Result { @@ -1192,11 +1345,11 @@ fn spawn_webserver( "--max-stop-sequences".to_string(), args.max_stop_sequences.to_string(), "--max-input-length".to_string(), - args.max_input_length.to_string(), + max_input_tokens.to_string(), "--max-total-tokens".to_string(), - args.max_total_tokens.to_string(), + max_total_tokens.to_string(), "--max-batch-prefill-tokens".to_string(), - args.max_batch_prefill_tokens.to_string(), + max_batch_prefill_tokens.to_string(), "--max-active-adapters".to_string(), args.max_active_adapters.to_string(), "--adapter-cycle-time-s".to_string(), @@ -1403,18 +1556,69 @@ fn main() -> Result<(), LauncherError> { tracing::info!("{:?}", args); + let config: Option = get_config(&args.model_id, &args.revision).ok(); + let max_default = 4096; + let max_position_embeddings = if let Some(config) = &config { + if let Some(max_position_embeddings) = config.max_position_embeddings { + if max_position_embeddings > max_default { + let max = max_position_embeddings; + if args.max_input_length.is_none() + && args.max_total_tokens.is_none() + && args.max_batch_prefill_tokens.is_none() + { + tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1); + } + max_default + } else { + max_position_embeddings + } + } else { + max_default + } + } else { + max_default + }; + + // Defaults + let max_input_tokens = { + match args.max_input_length { + Some(max_input_tokens) => max_input_tokens, + None => { + let value = max_position_embeddings - 1; + tracing::info!("Default `max_input_tokens` to {value}"); + value + } + } + }; + let max_total_tokens = { + match args.max_total_tokens { + Some(max_total_tokens) => max_total_tokens, + None => { + let value = max_position_embeddings; + tracing::info!("Default `max_total_tokens` to {value}"); + value + } + } + }; + let max_batch_prefill_tokens = { + match args.max_batch_prefill_tokens { + Some(max_batch_prefill_tokens) => max_batch_prefill_tokens, + None => { + // Adding some edge in order to account for potential block_size alignement + // issue. 
+ let value: u32 = (max_input_tokens + 50) as u32; + tracing::info!("Default `max_batch_prefill_tokens` to {value}"); + value + } + } + }; + // Validate args - if args.max_input_length >= args.max_total_tokens { + if max_input_tokens >= max_total_tokens { return Err(LauncherError::ArgumentValidation( "`max_input_length` must be < `max_total_tokens`".to_string(), )); } - if args.max_input_length as u32 > args.max_batch_prefill_tokens { - return Err(LauncherError::ArgumentValidation(format!( - "`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {} and {}", - args.max_batch_prefill_tokens, args.max_input_length - ))); - } if args.validation_workers == 0 { return Err(LauncherError::ArgumentValidation( @@ -1434,16 +1638,16 @@ fn main() -> Result<(), LauncherError> { } if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens { - if args.max_batch_prefill_tokens > *max_batch_total_tokens { + if max_batch_prefill_tokens > *max_batch_total_tokens { return Err(LauncherError::ArgumentValidation(format!( "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", - args.max_batch_prefill_tokens, max_batch_total_tokens + max_batch_prefill_tokens, max_batch_total_tokens ))); } - if args.max_total_tokens as u32 > *max_batch_total_tokens { + if max_total_tokens as u32 > *max_batch_total_tokens { return Err(LauncherError::ArgumentValidation(format!( "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", - args.max_total_tokens, max_batch_total_tokens + max_total_tokens, max_batch_total_tokens ))); } } @@ -1509,11 +1713,18 @@ fn main() -> Result<(), LauncherError> { return Ok(()); } - let mut webserver = - spawn_webserver(args, shutdown.clone(), &shutdown_receiver).map_err(|err| { - shutdown_shards(shutdown.clone(), &shutdown_receiver); - err - })?; + let mut webserver = spawn_webserver( + args, + max_input_tokens, + max_total_tokens, + max_batch_prefill_tokens, + shutdown.clone(), + &shutdown_receiver, + ) + .map_err(|err| { + shutdown_shards(shutdown.clone(), &shutdown_receiver); + err + })?; // Default exit code let mut exit_code = Ok(()); diff --git a/router/src/main.rs b/router/src/main.rs index 5d47d93ea..ec8b61aae 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -141,9 +141,6 @@ async fn main() -> Result<(), RouterError> { "`max_input_length` must be < `max_total_tokens`".to_string(), )); } - if max_input_length as u32 > max_batch_prefill_tokens { - return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}"))); - } if validation_workers == 0 { return Err(RouterError::ArgumentValidation( @@ -151,15 +148,6 @@ async fn main() -> Result<(), RouterError> { )); } - if let Some(ref max_batch_total_tokens) = max_batch_total_tokens { - if max_batch_prefill_tokens > *max_batch_total_tokens { - return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); - } - if max_total_tokens as u32 > *max_batch_total_tokens { - return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. 
Given: {max_total_tokens} and {max_batch_total_tokens}"))); - } - } - // CORS allowed origins // map to go inside the option and then map to parse from String to HeaderValue // Finally, convert to AllowOrigin @@ -445,6 +433,18 @@ async fn main() -> Result<(), RouterError> { tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}"); tracing::info!("Connected"); + let supports_chunking = shard_info.chunked_prefill; + let max_batch_total_tokens = max_supported_batch_total_tokens; + if max_input_length as u32 > max_batch_prefill_tokens && !supports_chunking { + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}"))); + } + if max_batch_prefill_tokens > max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); + } + if max_total_tokens as u32 > max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}"))); + } + let addr = match hostname.parse() { Ok(ip) => SocketAddr::new(ip, port), Err(_) => { From ca3280ce7be25ae58404060c30f57fbf148a9948 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Wed, 23 Oct 2024 12:54:06 -0700 Subject: [PATCH 04/76] Fix fallback --- server/lorax_server/models/flash_causal_lm.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index d244f8af7..140caa1ab 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -1187,9 +1187,6 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model if self.world_size > 1: raise ValueError("Cannot enable `--compile` when sharding across multiple GPUs") - # This will be recalculated in the graph step - self.decode_state = None - # Estimate the memory overhead from CUDA graphs so we can subtract it from the kv cache. # Needs to be estimated here rather than fully initialized as the graph cache relies on the # cache manager being set. 
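The `timer` helper introduced in PATCH 02/76 keeps one `TimingContextManager` per name in the module-level `_timers` dict and, on every exit, prints the running average, total, and count for that name, so nested scopes such as `decode::total` and `decode::generate_token` report independently. A minimal usage sketch, assuming the `lorax_server` package from this series is importable (the `time.sleep` call is only a stand-in for the real work timed in `server.py`):

    import time

    from lorax_server.utils.profiler import timer

    with timer("decode::total"):
        with timer("decode::generate_token"):
            time.sleep(0.01)  # stand-in for model.generate_token(batch)

    # Each exit prints a line like (inner scope first):
    # === decode::generate_token: avg=0.010 s total=0.010 s count=1
    # === decode::total: avg=0.010 s total=0.010 s count=1
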
From 830ce3d86e6ded0638699936e8f829e86418f538 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Wed, 23 Oct 2024 15:35:15 -0700 Subject: [PATCH 05/76] Vectorize test --- docs/guides/contributing/development_env.md | 13 +- server/lorax_server/models/flash_causal_lm.py | 833 ++++++++++-------- server/lorax_server/server.py | 13 +- 3 files changed, 486 insertions(+), 373 deletions(-) diff --git a/docs/guides/contributing/development_env.md b/docs/guides/contributing/development_env.md index f33c4c2d6..5f35eebb1 100644 --- a/docs/guides/contributing/development_env.md +++ b/docs/guides/contributing/development_env.md @@ -47,12 +47,12 @@ We'll be working out of three different terminals during development, each servi Install development dependencies: ```shell -DEBIAN_FRONTEND=noninteractive apt install pkg-config rsync tmux rust-gdb git -y +DEBIAN_FRONTEND=noninteractive apt install pkg-config rsync tmux rust-gdb git -y && \ PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \ - rm -f $PROTOC_ZIP + rm -f $PROTOC_ZIP && \ hash -r ``` @@ -71,8 +71,7 @@ tmux new -s server From within the `tmux` session, move into the LoRAX `server` directory within the repo (assumed to be in `/data/lorax`) and install dependencies: ```shell -cd /data/lorax/server -pip install -e . +cd /data/lorax/server && pip install -e . make gen-server ``` @@ -95,9 +94,9 @@ tmux new -s router Now move into the `router` directory within the repo and install dependencies: ```shell -cd /data/lorax/router -curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y -export PATH=$PATH:$HOME/.cargo/bin +cd /data/lorax/router && \ +curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y && \ +export PATH=$PATH:$HOME/.cargo/bin && \ touch ../proto/generate.proto ``` diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 140caa1ab..e3739797b 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -1496,33 +1496,35 @@ def generate_token( batch.prefilling = not finished_prefilling batch.prefilling_mask = next_prefilling_mask - speculative_tokens = get_speculative_tokens() - ( - next_input_ids, - next_token_logprobs, - accepted_ids, - speculative_ids, - ) = batch.next_token_chooser( - batch.all_input_ids_tensor[:, : batch.max_current_length], - next_token_logits, - speculative_tokens, - batch.speculative_ids, - speculative_logits, - ) + with timer(f"generate_token::next_token_chooser"): + speculative_tokens = get_speculative_tokens() + ( + next_input_ids, + next_token_logprobs, + accepted_ids, + speculative_ids, + ) = batch.next_token_chooser( + batch.all_input_ids_tensor[:, : batch.max_current_length], + next_token_logits, + speculative_tokens, + batch.speculative_ids, + speculative_logits, + ) if return_alternatives: alternative_token_logprobs, alternative_token_ids = torch.sort( torch.log_softmax(next_token_logits, -1), dim=-1, stable=True, descending=True ) - # Since we are done prefilling, all the tensors that were concatenating values for all the requests - # instantly become of shape [BATCH_SIZE] - if prefill and finished_prefilling: - next_position_ids = batch.position_ids.new_empty(len(batch)) - batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1] - next_adapter_indices = 
batch.adapter_meta.adapter_indices.new_empty(len(batch)) - elif not prefill: - next_position_ids = batch.position_ids + with timer(f"{stage_str}::generate_token::new_empty"): + # Since we are done prefilling, all the tensors that were concatenating values for all the requests + # instantly become of shape [BATCH_SIZE] + if prefill and finished_prefilling: + next_position_ids = batch.position_ids.new_empty(len(batch)) + batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1] + next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(len(batch)) + elif not prefill: + next_position_ids = batch.position_ids # Zipped iterator iterator = zip( @@ -1540,376 +1542,432 @@ def generate_token( # one, we need to first do a GPU <-> CPU sync # It is faster if we delay this sync for the maximum amount of time - # For each member of the batch - index = 0 - # Cumulative length - cumulative_length = 0 - for i, ( - request, - prompt_length, - cache_length, - input_length, - all_input_ids, - n_accepted_ids, - request_was_prefilling, - request_is_prefilling, - ) in enumerate(iterator): + with timer(f"generate_token::cumulative_length"): if prefill and finished_prefilling: - # Indexing metadata - _start_index = cumulative_length - end_index = cumulative_length + input_length - + current_prefilling_mask_tensor = torch.tensor(current_prefilling_mask, device=batch.all_input_ids_tensor.device) + + # Discard first elem, which is 0 + end_index = batch.cu_seqlen_prefill[1:] + # Initialize position_ids # In decode, we do not need this as we can just increment position ids - next_position_ids[i] = batch.position_ids[end_index - 1] + next_position_ids[:] = batch.position_ids[end_index - 1] # Initialize adapter indices # In decode, we only have one token per row in the batch, so grab last index - next_adapter_indices[i] = batch.adapter_meta.adapter_indices[end_index - 1] + next_adapter_indices[:] = batch.adapter_meta.adapter_indices[end_index - 1] + + # if not request_is_prefilling: + # # Only save tokens if we are done prefilling for this request + # for j in range(n_accepted_ids): + # batch.all_input_ids_tensor[i, cache_length + input_length + j] = next_input_ids[index + j] + + # Only save tokens if we are done prefilling for this request + offsets = batch.cache_lengths_tensor + batch.input_lengths_tensor + + batch.all_input_ids_tensor = update_all_input_ids_tensor( + accepted_ids, + batch.all_input_ids_tensor, + offsets, + next_input_ids, + current_prefilling_mask_tensor, + ) + + # batch_size = accepted_ids.shape[0] - # Used to gather prefill logprobs - # Copy batch.all_input_ids_tensor to prefill_token_indices - if request.prefill_logprobs and request_was_prefilling: - # Indexing metadata - out_start_index = batch.prefill_cu_outlens[i] - out_end_index = batch.prefill_cu_outlens[i + 1] - - # Logprobs generated by the model are for the next token - # So we need to translate the id tensor by 1 - ids = batch.all_input_ids_tensor[i, cache_length + 1 : cache_length + input_length + 1] - if len(batch) > 1: - prefill_tokens_indices[out_start_index:out_end_index] = ids - else: - # Set prefill_tokens_indices to the correct slice - prefill_tokens_indices = ids - - if not request_is_prefilling: - # Only save tokens if we are done prefilling for this request - for j in range(n_accepted_ids): - batch.all_input_ids_tensor[i, cache_length + input_length + j] = next_input_ids[index + j] - - batch.all_input_ids_tensor[i, input_length] = next_input_ids[i] - - index += n_accepted_ids - cumulative_length += 
input_length - - # Update values - # These values can be updated without a GPU -> CPU sync - if not prefill or (prefill and finished_prefilling): - batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] - batch.speculative_ids = speculative_ids - batch.position_ids = next_position_ids + accepted_ids - batch.cache_lengths_tensor += batch.input_lengths_tensor - batch.input_lengths_tensor = accepted_ids.to(dtype=torch.int32) - batch.slot_indices += accepted_ids - batch.adapter_meta.adapter_indices = next_adapter_indices - - if prefill and prefill_logprobs: - # Get prefill logprobs with inplace softmax (avoid copying the `out` tensor (max_batch_prefill_tokens * vocab_size)) - torch.log_softmax(out, -1, out=out) - prefill_logprobs_tensor = out - prefill_logprobs = torch.gather(prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)) - # GPU <-> CPU sync - prefill_logprobs = prefill_logprobs.view(-1).tolist() + # # Create a batch index tensor [0, 1, 2, ..., batch_size - 1] + # batch_indices = torch.arange(batch_size, device=batch.all_input_ids_tensor.device) - # Does a GPU <-> CPU sync internally - if prefill and finished_prefilling: - # adjust segment lengths to account for all request lengths being 1 during decoding - adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices) - batch.adapter_meta.adapter_segments = torch.tensor( - adapter_segments, - dtype=torch.int32, - device=batch.adapter_meta.adapter_segments.device, - ) + # # Apply the current_prefilling_mask to get only the rows that should be updated + # valid_batch_indices = batch_indices[current_prefilling_mask_tensor] - # GPU <-> CPU sync - next_token_logprobs = next_token_logprobs.tolist() - next_token_ids = next_input_ids.tolist() - accepted_ids = accepted_ids.tolist() + # # Gather only the accepted_ids, offsets, and next_input_ids for valid rows + # print("!!! offsets", offsets, offsets.shape) + # print("!!! current_prefilling_mask", current_prefilling_mask_tensor) - if return_alternatives: - alternative_token_logprobs = alternative_token_logprobs.tolist() - alternative_token_ids = alternative_token_ids.tolist() + # accepted_ids_valid = accepted_ids[current_prefilling_mask_tensor] + # offsets_valid = offsets[current_prefilling_mask_tensor] - # Update values if we need to continue prefilling - # This represents the `else` case of the `Update values` if above - # but since this require the `next_token_ids` to be on CPU, it is better to do it here - if prefill and not finished_prefilling: - # Speculation must be ignored while we prefill even with chunking - # it simplifies everything - assert batch.speculative_ids is None + # print("!!! accepted_ids_valid", accepted_ids_valid, accepted_ids_valid.shape) + # print("!!! 
offsets_valid", offsets_valid, offsets_valid.shape) - all_postfix_ids = [] - for i, ( - request_prefilling, - next_token_id, - all_input_ids, - cache_length, - input_length, - next_chunk_length, - ) in enumerate( - zip( - batch.prefilling_mask, - next_token_ids, - batch.all_input_ids, - batch.cache_lengths, - batch.input_lengths, - next_chunk_lengths, - ) - ): - if request_prefilling: - next_cache_length = cache_length + input_length - # Get new prompt IDs to prefill - postfix_ids = all_input_ids[next_cache_length : next_cache_length + next_chunk_length] - else: - # This request is done prefilling, the new id is the one selected the sampling method - postfix_ids = [next_token_id] + # # Reshape next_input_ids to [batch_size, S] + # max_candidates_per_batch = next_input_ids.shape[0] // batch_size + # next_input_ids_2d = next_input_ids.view(batch_size, max_candidates_per_batch) - all_postfix_ids.append(postfix_ids) + # # Gather only the next_input_ids for valid rows + # next_input_ids_valid = next_input_ids_2d[current_prefilling_mask_tensor] + # print("!!! next_input_ids_valid", next_input_ids_valid, next_input_ids_valid.shape) - batch.input_ids = all_postfix_ids + # # Generate a mask to gather the accepted IDs from next_input_ids + # expanded_accepted_ids = torch.repeat_interleave(accepted_ids_valid) + # print("!!! expanded_accepted_ids", expanded_accepted_ids, expanded_accepted_ids.shape) - # Results - generations: List[Generation] = [] - stopped = not is_warmup + # # Flatten the accepted next_input_ids + # accepted_next_input_ids = next_input_ids_valid[torch.arange(accepted_ids_valid.shape[0]), expanded_accepted_ids] - # Zipped iterator - iterator = zip( - batch.requests, - batch.prompt_lengths, - batch.cache_lengths, - batch.input_lengths, - batch.prefix_offsets, - batch.read_offsets, - batch.stopping_criterias, - batch.all_input_ids, - batch.next_token_chooser.do_sample, - batch.next_token_chooser.seeds, - current_prefilling_mask, - batch.prefilling_mask, - accepted_ids, - ) + # # Compute insertion points for each valid batch, from the offset + # print("!!! 
expanded_accepted_ids", expanded_accepted_ids, expanded_accepted_ids.shape) + # insertion_indices = torch.repeat_interleave(offsets_valid) + torch.arange(expanded_accepted_ids.shape[0], device=offsets.device) - # Reset max_input_length - batch.max_input_length = 0 - # For each member of the batch - index = 0 - for i, ( - request, - prompt_length, - cache_length, - input_length, - prefix_offset, - read_offset, - stopping_criteria, - all_input_ids, - do_sample, - seed, - request_was_prefilling, - request_is_prefilling, - n_accepted_ids, - ) in enumerate(iterator): - all_alternative_tokens = [] if request.parameters.return_k_alternatives > 0 else None - - # TODO(travis): return_k_alternatives - # if request.parameters.return_k_alternatives > 0: - # # Limit the number of alternatives to the vocabulary size - # num_alternatives = min( - # request.parameters.return_k_alternatives, - # len(alternative_token_ids[token_idx]), - # ) - - # # Select top-k logprobs - # request_alternative_token_ids = alternative_token_ids[token_idx][:num_alternatives] - # request_alternative_token_logprobs = alternative_token_logprobs[token_idx][:num_alternatives] - - # # Decode tokens - # request_alternative_token_texts = [] - # for alternative_token_id in request_alternative_token_ids: - # all_input_ids.append(alternative_token_id) - # alternative_token_text, _, _ = self.decode_token( - # all_input_ids, - # prefix_offset, - # read_offset, - # ) - # request_alternative_token_texts.append(alternative_token_text) - # all_input_ids.pop() - # alternative_tokens = AlternativeTokens( - # request_alternative_token_ids, - # request_alternative_token_logprobs, - # request_alternative_token_texts, - # ) - # all_alternative_tokens.append(alternative_tokens) - - # Compute logprobs first as, even though we might skip the token, - # it can still be required to compute the logprobs - # modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need - # this state to be stable - if request.id % self.world_size == self.rank: - # Prefill - if request_was_prefilling and request.prefill_logprobs: - out_start_index = batch.prefill_cu_outlens[i] - out_end_index = batch.prefill_cu_outlens[i + 1] - if not request_is_prefilling: - # The request is dones prefilling, meaning that we started generating new tokens - # The last logprob is a logprob for a generated token that was not part of the prompt - # We need to remove it - out_end_index -= 1 - - request_prefill_logprobs = prefill_logprobs[out_start_index:out_end_index] - # Logprobs generated by the model are for the next token - # So we need to translate the id tensor by 1 - prefill_token_ids = all_input_ids[cache_length + 1 : cache_length + input_length + 1] - - past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i] - - if past_prefill_logprob_tokens is None: - # add nan for cached prompt tokens/first token - request_prefill_logprobs = [float("nan")] * (cache_length + 1) + request_prefill_logprobs - prefill_token_ids = all_input_ids[: cache_length + 1] + prefill_token_ids - - prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) + # # Place the accepted tokens into the right locations in valid rows only + # batch.all_input_ids_tensor[valid_batch_indices, insertion_indices] = accepted_next_input_ids - prefill_logprob_tokens = NextTokens( - prefill_token_ids, - request_prefill_logprobs, - prefill_texts, - [], - all_alternative_tokens, - ) - if 
past_prefill_logprob_tokens is not None: - prefill_logprob_tokens = past_prefill_logprob_tokens + prefill_logprob_tokens + print("all_input_ids", batch.all_input_ids_tensor.shape) - batch.prefill_logprob_tokens[i] = prefill_logprob_tokens - else: - batch.prefill_logprob_tokens[i] = None - - # If it is, the tokens we decoded should be ignored - if request_is_prefilling: - # Make sure that we do not stop as even though this request did not create a token, it is still - # processing - stopped = False - new_input_length = next_chunk_lengths[i] - else: - new_input_length = n_accepted_ids - # Append next token to all tokens - next_token_texts = [] - left = 0 - - if n_accepted_ids > 1: - logger.debug(f"speculated ids {n_accepted_ids - 1}") - - current_stopped = False - for j in range(index, index + n_accepted_ids): - # Generated token - next_token_id = next_token_ids[j] - all_input_ids.append(next_token_id) - next_token_text, prefix_offset, read_offset = self.decode_token( - all_input_ids, - prefix_offset, - read_offset, - ) - next_token_texts.append(next_token_text) + with timer(f"generate_token::prefill_logprobs"): + if prefill and finished_prefilling: + # Used to gather prefill logprobs + # Copy batch.all_input_ids_tensor to prefill_token_indices + for i, request in enumerate(batch.requests): + request_was_prefilling = current_prefilling_mask[i] + if request.prefill_logprobs and request_was_prefilling: + # For each member of the batch + index = 0 + + # TODO(travis): tons of d2h copies here make this super slow, should vectorize or do transfer + # up front + cache_length = batch.cache_lengths[i] + input_length = batch.input_lengths[i] + n_accepted_ids = accepted_ids[index] + + print("!!! prefill_logprobs") + # Indexing metadata + out_start_index = batch.prefill_cu_outlens[i] + out_end_index = batch.prefill_cu_outlens[i + 1] + + # Logprobs generated by the model are for the next token + # So we need to translate the id tensor by 1 + ids = batch.all_input_ids_tensor[i, cache_length + 1 : cache_length + input_length + 1] + if len(batch) > 1: + prefill_tokens_indices[out_start_index:out_end_index] = ids + else: + # Set prefill_tokens_indices to the correct slice + prefill_tokens_indices = ids + + with timer(f"generate_token::update_values"): + # Update values + # These values can be updated without a GPU -> CPU sync + if not prefill or (prefill and finished_prefilling): + batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] + batch.speculative_ids = speculative_ids + batch.position_ids = next_position_ids + accepted_ids + batch.cache_lengths_tensor += batch.input_lengths_tensor + batch.input_lengths_tensor = accepted_ids.to(dtype=torch.int32) + batch.slot_indices += accepted_ids + batch.adapter_meta.adapter_indices = next_adapter_indices + + if prefill and prefill_logprobs: + # Get prefill logprobs with inplace softmax (avoid copying the `out` tensor (max_batch_prefill_tokens * vocab_size)) + torch.log_softmax(out, -1, out=out) + prefill_logprobs_tensor = out + prefill_logprobs = torch.gather(prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)) + # GPU <-> CPU sync + prefill_logprobs = prefill_logprobs.view(-1).tolist() + + with timer(f"generate_token::find_segments"): + # Does a GPU <-> CPU sync internally + if prefill and finished_prefilling: + # adjust segment lengths to account for all request lengths being 1 during decoding + adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices) + batch.adapter_meta.adapter_segments = torch.tensor( + adapter_segments, 
+ dtype=torch.int32, + device=batch.adapter_meta.adapter_segments.device, + ) - stop, reason = stopping_criteria( - next_token_id, - next_token_text, - ) + with timer(f"{stage_str}::generate_token::d2h"): + # GPU <-> CPU sync + next_token_logprobs = next_token_logprobs.tolist() + next_token_ids = next_input_ids.tolist() + accepted_ids = accepted_ids.tolist() - if stop: - left = index + n_accepted_ids - j - 1 - current_stopped = True - break - else: - current_stopped = False - stopped = stopped and current_stopped + if return_alternatives: + alternative_token_logprobs = alternative_token_logprobs.tolist() + alternative_token_ids = alternative_token_ids.tolist() - _next_token_ids = next_token_ids[index : index + n_accepted_ids - left] - _next_token_logprobs = next_token_logprobs[index : index + n_accepted_ids - left] + # Update values if we need to continue prefilling + # This represents the `else` case of the `Update values` if above + # but since this require the `next_token_ids` to be on CPU, it is better to do it here + if prefill and not finished_prefilling: + with timer(f"generate_token::update_values_continue_prefill"): + # Speculation must be ignored while we prefill even with chunking + # it simplifies everything + assert batch.speculative_ids is None + + all_postfix_ids = [] + for i, ( + request_prefilling, + next_token_id, + all_input_ids, + cache_length, + input_length, + next_chunk_length, + ) in enumerate( + zip( + batch.prefilling_mask, + next_token_ids, + batch.all_input_ids, + batch.cache_lengths, + batch.input_lengths, + next_chunk_lengths, + ) + ): + if request_prefilling: + next_cache_length = cache_length + input_length + # Get new prompt IDs to prefill + postfix_ids = all_input_ids[next_cache_length : next_cache_length + next_chunk_length] + else: + # This request is done prefilling, the new id is the one selected the sampling method + postfix_ids = [next_token_id] + + all_postfix_ids.append(postfix_ids) + + batch.input_ids = all_postfix_ids + + with timer(f"generate_token::get_results"): + # Results + generations: List[Generation] = [] + stopped = not is_warmup + + # Zipped iterator + iterator = zip( + batch.requests, + batch.prompt_lengths, + batch.cache_lengths, + batch.input_lengths, + batch.prefix_offsets, + batch.read_offsets, + batch.stopping_criterias, + batch.all_input_ids, + batch.next_token_chooser.do_sample, + batch.next_token_chooser.seeds, + current_prefilling_mask, + batch.prefilling_mask, + accepted_ids, + ) - # Shard generations - # All generations will be appended in the rust sharded client + # Reset max_input_length + batch.max_input_length = 0 + # For each member of the batch + index = 0 + for i, ( + request, + prompt_length, + cache_length, + input_length, + prefix_offset, + read_offset, + stopping_criteria, + all_input_ids, + do_sample, + seed, + request_was_prefilling, + request_is_prefilling, + n_accepted_ids, + ) in enumerate(iterator): + all_alternative_tokens = [] if request.parameters.return_k_alternatives > 0 else None + + # TODO(travis): return_k_alternatives + # if request.parameters.return_k_alternatives > 0: + # # Limit the number of alternatives to the vocabulary size + # num_alternatives = min( + # request.parameters.return_k_alternatives, + # len(alternative_token_ids[token_idx]), + # ) + + # # Select top-k logprobs + # request_alternative_token_ids = alternative_token_ids[token_idx][:num_alternatives] + # request_alternative_token_logprobs = alternative_token_logprobs[token_idx][:num_alternatives] + + # # Decode tokens + # 
request_alternative_token_texts = [] + # for alternative_token_id in request_alternative_token_ids: + # all_input_ids.append(alternative_token_id) + # alternative_token_text, _, _ = self.decode_token( + # all_input_ids, + # prefix_offset, + # read_offset, + # ) + # request_alternative_token_texts.append(alternative_token_text) + # all_input_ids.pop() + # alternative_tokens = AlternativeTokens( + # request_alternative_token_ids, + # request_alternative_token_logprobs, + # request_alternative_token_texts, + # ) + # all_alternative_tokens.append(alternative_tokens) + + # Compute logprobs first as, even though we might skip the token, + # it can still be required to compute the logprobs + # modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need + # this state to be stable if request.id % self.world_size == self.rank: - if stop: - # Decode generated tokens - output_text, _, _ = self.decode_token( - all_input_ids, - prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1, - read_offset=len(all_input_ids) - stopping_criteria.current_tokens, - skip_special_tokens=True, + # Prefill + if request_was_prefilling and request.prefill_logprobs: + out_start_index = batch.prefill_cu_outlens[i] + out_end_index = batch.prefill_cu_outlens[i + 1] + if not request_is_prefilling: + # The request is dones prefilling, meaning that we started generating new tokens + # The last logprob is a logprob for a generated token that was not part of the prompt + # We need to remove it + out_end_index -= 1 + + request_prefill_logprobs = prefill_logprobs[out_start_index:out_end_index] + # Logprobs generated by the model are for the next token + # So we need to translate the id tensor by 1 + prefill_token_ids = all_input_ids[cache_length + 1 : cache_length + input_length + 1] + + past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i] + + if past_prefill_logprob_tokens is None: + # add nan for cached prompt tokens/first token + request_prefill_logprobs = [float("nan")] * (cache_length + 1) + request_prefill_logprobs + prefill_token_ids = all_input_ids[: cache_length + 1] + prefill_token_ids + + prefill_texts = self.tokenizer.batch_decode( + prefill_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, ) - generated_text = GeneratedText( - output_text, - stopping_criteria.current_tokens, - reason, - seed if do_sample else None, + + prefill_logprob_tokens = NextTokens( + prefill_token_ids, + request_prefill_logprobs, + prefill_texts, + [], + all_alternative_tokens, ) + if past_prefill_logprob_tokens is not None: + prefill_logprob_tokens = past_prefill_logprob_tokens + prefill_logprob_tokens + + batch.prefill_logprob_tokens[i] = prefill_logprob_tokens else: - generated_text = None - - # TODO(travis): top tokens - # if top_n_tokens > 0: - # all_top_tokens = [] - # for top_token_ids, top_token_logprobs in zip( - # top_token_ids, top_token_logprobs - # ): - # toptoken_texts = self.tokenizer.batch_decode( - # top_token_ids, - # clean_up_tokenization_spaces=False, - # skip_special_tokens=False, - # ) - # special_toptokens = [ - # token_id in self.all_special_ids - # for token_id in top_token_ids - # ] - # top_tokens = Tokens( - # top_token_ids, - # top_token_logprobs, - # toptoken_texts, - # special_toptokens, - # ) - # all_top_tokens.append(top_tokens) - # top_tokens = all_top_tokens - # else: - # top_tokens = None - - generation = Generation( - request.id, - batch.prefill_logprob_tokens[i], - len(all_input_ids[:-1]) if prefill else 0, - 
NextTokens( - _next_token_ids, - _next_token_logprobs, - next_token_texts, - [nid in self.all_special_ids for nid in _next_token_ids], - all_alternative_tokens, - ), - generated_text, - ) + batch.prefill_logprob_tokens[i] = None + + # If it is, the tokens we decoded should be ignored + if request_is_prefilling: + # Make sure that we do not stop as even though this request did not create a token, it is still + # processing + stopped = False + new_input_length = next_chunk_lengths[i] + else: + new_input_length = n_accepted_ids + # Append next token to all tokens + next_token_texts = [] + left = 0 + + if n_accepted_ids > 1: + logger.debug(f"speculated ids {n_accepted_ids - 1}") + + current_stopped = False + for j in range(index, index + n_accepted_ids): + # Generated token + next_token_id = next_token_ids[j] + all_input_ids.append(next_token_id) + next_token_text, prefix_offset, read_offset = self.decode_token( + all_input_ids, + prefix_offset, + read_offset, + ) + next_token_texts.append(next_token_text) - generations.append(generation) + stop, reason = stopping_criteria( + next_token_id, + next_token_text, + ) - # advance the FSM for each accepted token (as we may have more than one from speculative decoding) - for next_token_id in _next_token_ids: - batch.next_token_chooser.next_state(i, next_token_id) + if stop: + left = index + n_accepted_ids - j - 1 + current_stopped = True + break + else: + current_stopped = False + stopped = stopped and current_stopped + + _next_token_ids = next_token_ids[index : index + n_accepted_ids - left] + _next_token_logprobs = next_token_logprobs[index : index + n_accepted_ids - left] + + # Shard generations + # All generations will be appended in the rust sharded client + if request.id % self.world_size == self.rank: + if stop: + # Decode generated tokens + output_text, _, _ = self.decode_token( + all_input_ids, + prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1, + read_offset=len(all_input_ids) - stopping_criteria.current_tokens, + skip_special_tokens=True, + ) + generated_text = GeneratedText( + output_text, + stopping_criteria.current_tokens, + reason, + seed if do_sample else None, + ) + else: + generated_text = None + + # TODO(travis): top tokens + # if top_n_tokens > 0: + # all_top_tokens = [] + # for top_token_ids, top_token_logprobs in zip( + # top_token_ids, top_token_logprobs + # ): + # toptoken_texts = self.tokenizer.batch_decode( + # top_token_ids, + # clean_up_tokenization_spaces=False, + # skip_special_tokens=False, + # ) + # special_toptokens = [ + # token_id in self.all_special_ids + # for token_id in top_token_ids + # ] + # top_tokens = Tokens( + # top_token_ids, + # top_token_logprobs, + # toptoken_texts, + # special_toptokens, + # ) + # all_top_tokens.append(top_tokens) + # top_tokens = all_top_tokens + # else: + # top_tokens = None + + generation = Generation( + request.id, + batch.prefill_logprob_tokens[i], + len(all_input_ids[:-1]) if prefill else 0, + NextTokens( + _next_token_ids, + _next_token_logprobs, + next_token_texts, + [nid in self.all_special_ids for nid in _next_token_ids], + all_alternative_tokens, + ), + generated_text, + ) - # Update values - index += n_accepted_ids - current_cache_length = cache_length + input_length - batch.cache_lengths[i] = current_cache_length - current_input_length = new_input_length - batch.max_input_length = max(batch.max_input_length, current_input_length) - batch.input_lengths[i] = current_input_length - current_length = current_cache_length + current_input_length - 
batch.max_current_length = max(batch.max_current_length, current_length) - - batch.prefix_offsets[i] = prefix_offset - batch.read_offsets[i] = read_offset - batch.all_input_ids[i] = all_input_ids + generations.append(generation) + + # advance the FSM for each accepted token (as we may have more than one from speculative decoding) + for next_token_id in _next_token_ids: + batch.next_token_chooser.next_state(i, next_token_id) + + with timer(f"generate_token::update_remaining_values"): + # Update values + index += n_accepted_ids + current_cache_length = cache_length + input_length + batch.cache_lengths[i] = current_cache_length + current_input_length = new_input_length + batch.max_input_length = max(batch.max_input_length, current_input_length) + batch.input_lengths[i] = current_input_length + current_length = current_cache_length + current_input_length + batch.max_current_length = max(batch.max_current_length, current_length) + + batch.prefix_offsets[i] = prefix_offset + batch.read_offsets[i] = read_offset + batch.all_input_ids[i] = all_input_ids if stopped: # No need to return a batch if we know that all requests stopped @@ -1924,3 +1982,56 @@ def generate_token( batch.prefill_next_token_indices = None return generations, batch + + +def update_all_input_ids_tensor( + accepted_ids, + all_input_ids_tensor, + offsets, + next_input_ids, + current_prefilling_mask +): + # Get batch size and compute S (number of candidate tokens per batch) + batch_size = accepted_ids.size(0) + S = next_input_ids.size(0) // batch_size + + # Reshape next_input_ids to [batch_size, S] + next_input_ids = next_input_ids.view(batch_size, S) + + # Select indices of batches to process based on the current_prefilling_mask + batch_indices = torch.nonzero(current_prefilling_mask, as_tuple=False).squeeze(1) + num_batches = batch_indices.size(0) + + # Gather the accepted_ids, offsets, and next_input_ids for the selected batches + accepted_ids_selected = accepted_ids[batch_indices] + offsets_selected = offsets[batch_indices] + next_input_ids_selected = next_input_ids[batch_indices] + + # Determine the maximum number of accepted IDs to pad sequences + max_accepted_ids = accepted_ids_selected.max() + + # Create sequence indices offsets for each batch + seq_indices_offsets = torch.arange(max_accepted_ids, device=accepted_ids.device).unsqueeze(0) + seq_indices_offsets = seq_indices_offsets.expand(num_batches, -1) + + # Create a mask to identify valid positions within accepted_ids for each batch + seq_mask = seq_indices_offsets < accepted_ids_selected.unsqueeze(1) + + # Calculate the sequence indices where updates will occur + seq_indices = seq_indices_offsets + offsets_selected.unsqueeze(1) + + # Expand batch indices to align with seq_indices + batch_indices_expanded = batch_indices.unsqueeze(1).expand(-1, max_accepted_ids) + + # Extract the values to be written into all_input_ids_tensor + values = next_input_ids_selected[:, :max_accepted_ids] + + # Flatten tensors and apply the mask to select valid positions + batch_indices_flat = batch_indices_expanded.reshape(-1)[seq_mask.reshape(-1)] + seq_indices_flat = seq_indices.reshape(-1)[seq_mask.reshape(-1)] + values_flat = values.reshape(-1)[seq_mask.reshape(-1)] + + # Update all_input_ids_tensor at the specified positions with the accepted IDs + all_input_ids_tensor[batch_indices_flat, seq_indices_flat] = values_flat + + return all_input_ids_tensor diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index 6fcc53503..3c9b0f46f 100644 --- 
a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -59,11 +59,14 @@ async def ServiceDiscovery(self, request, context): return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls) async def ClearCache(self, request, context): - if request.HasField("id"): - self.cache.delete(request.id) - else: - self.cache.clear() - return generate_pb2.ClearCacheResponse() + try: + if request.HasField("id"): + self.cache.delete(request.id) + else: + self.cache.clear() + return generate_pb2.ClearCacheResponse() + except: + exit(1) async def FilterBatch(self, request, context): batch = self.cache.pop(request.batch_id) From 7f250fe942be2afb7a638909c42cd472aa256896 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Wed, 23 Oct 2024 22:58:01 -0700 Subject: [PATCH 06/76] Triton punica kernels --- server/lorax_server/adapters/weights.py | 5 + server/lorax_server/models/flash_causal_lm.py | 210 ++++--- server/lorax_server/server.py | 7 + server/lorax_server/utils/ops/__init__.py | 0 server/lorax_server/utils/ops/bgmv_expand.py | 169 +++++ .../utils/ops/bgmv_expand_slice.py | 182 ++++++ server/lorax_server/utils/ops/bgmv_shrink.py | 150 +++++ server/lorax_server/utils/ops/libentry.py | 166 +++++ server/lorax_server/utils/ops/sgmv_expand.py | 192 ++++++ .../utils/ops/sgmv_expand_slice.py | 203 ++++++ server/lorax_server/utils/ops/sgmv_shrink.py | 188 ++++++ server/lorax_server/utils/ops/utils.py | 46 ++ server/lorax_server/utils/profiler.py | 5 +- server/lorax_server/utils/sgmv.py | 583 +++++++++++++++++- server/lorax_server/utils/state.py | 4 + 15 files changed, 2024 insertions(+), 86 deletions(-) create mode 100644 server/lorax_server/utils/ops/__init__.py create mode 100644 server/lorax_server/utils/ops/bgmv_expand.py create mode 100644 server/lorax_server/utils/ops/bgmv_expand_slice.py create mode 100644 server/lorax_server/utils/ops/bgmv_shrink.py create mode 100644 server/lorax_server/utils/ops/libentry.py create mode 100644 server/lorax_server/utils/ops/sgmv_expand.py create mode 100644 server/lorax_server/utils/ops/sgmv_expand_slice.py create mode 100644 server/lorax_server/utils/ops/sgmv_shrink.py create mode 100644 server/lorax_server/utils/ops/utils.py diff --git a/server/lorax_server/adapters/weights.py b/server/lorax_server/adapters/weights.py index bfa4a0bf2..668b36be7 100644 --- a/server/lorax_server/adapters/weights.py +++ b/server/lorax_server/adapters/weights.py @@ -25,6 +25,11 @@ class AdapterBatchMetadata: # segment_indices[s] == adapter_indices[i] segment_indices: List[int] + @property + def token_indices(self) -> torch.Tensor: + # Create the `token_indices` by repeating each segment index by the number of tokens in it + return torch.cat([torch.full((count,), idx, dtype=torch.long) for idx, count in enumerate(self.segment_indices)]) + class AdapterWeights(ABC): @abstractclassmethod diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index e3739797b..8abaac919 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -4,9 +4,11 @@ from dataclasses import dataclass from typing import Any, ContextManager, Dict, List, Optional, Tuple, Type, Union +from lorax_server.utils.sgmv import PunicaWrapper import numpy as np import torch import torch.distributed +import torch.profiler from loguru import logger from opentelemetry import trace from tqdm import tqdm @@ -35,6 +37,7 @@ from lorax_server.utils.state import ( BLOCK_SIZE, FLASH_INFER, + LORAX_PROFILER_DIR, 
get_max_prefill_tokens, get_speculative_tokens, get_supports_chunking, @@ -89,6 +92,7 @@ class FlashCausalLMBatch(Batch): prefilling: bool # Whether each request is prefilling prefilling_mask: List[bool] + prefilling_mask_tensor: Optional[torch.Tensor] # Prefill metadata tensors to efficiently compute logprobs # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill @@ -300,6 +304,9 @@ def from_pb( block_tables_tensor = block_tables_tensor.to(device) prompt_lengths_tensor = torch.tensor(prompt_lengths, dtype=torch.int32, device=device) + prefilling_mask = [True] * len(pb.requests) + prefilling_mask_tensor = torch.tensor(prefilling_mask, dtype=torch.bool, device=device) + return cls( batch_id=pb.id, requests=pb.requests, @@ -311,7 +318,8 @@ def from_pb( max_input_length=max_input_length, max_current_length=max_current_length, prefilling=True, - prefilling_mask=[True] * len(pb.requests), + prefilling_mask=prefilling_mask, + prefilling_mask_tensor=prefilling_mask_tensor, prefill_logprob_tokens=[None] * len(pb.requests), input_lengths=input_lengths, prompt_lengths=prompt_lengths, @@ -455,6 +463,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": cache_lengths_tensor = None input_lengths_tensor = None adapter_meta = None + prefilling_mask_tensor = self.prefilling_mask_tensor[indices] else: # Index into tensors input_ids = self.input_ids[indices] @@ -463,6 +472,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": input_lengths_tensor = self.input_lengths_tensor[indices] slots = self.slots[slot_filtering_indices] cache_lengths_tensor = self.cache_lengths_tensor[indices] + prefilling_mask_tensor = None # Move to GPU now that we have the whole tensor slot_indices = slot_indices.to(device) @@ -493,6 +503,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": max_current_length=max_current_length, prefilling=self.prefilling, prefilling_mask=prefilling_mask, + prefilling_mask_tensor=prefilling_mask_tensor, prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, @@ -556,6 +567,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch slot_indices = None cache_lengths_tensor = None input_lengths_tensor = None + prefilling_mask_tensor = batches[0].prefilling_mask_tensor.new_empty(total_batch_size) adapter_meta = None adapter_segment_builder = None else: @@ -565,6 +577,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch slot_indices = batches[0].slot_indices.new_empty(total_batch_size) input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(total_batch_size) cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty(total_batch_size) + prefilling_mask_tensor = None total_indices_size = sum(b.adapter_meta.adapter_indices.shape[0] for b in batches) adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(total_indices_size) adapter_segment_builder = SegmentConcatBuilder() @@ -647,6 +660,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch if isinstance(batch.input_ids, torch.Tensor): batch.input_ids = batch.input_ids.view(-1, 1).tolist() input_ids.extend(batch.input_ids) + prefilling_mask_tensor[start_index:end_index] = batch.prefilling_mask_tensor prefilling_mask.extend(batch.prefilling_mask) block_tables.extend(batch.block_tables) @@ -711,6 +725,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch 
max_current_length=max_current_length, prefilling=prefilling, prefilling_mask=prefilling_mask, + prefilling_mask_tensor=prefilling_mask_tensor, prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, @@ -1076,6 +1091,24 @@ def __init__( num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, ) + + self.profiler = None + if LORAX_PROFILER_DIR is not None: + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + with_stack=True, + on_trace_ready=torch.profiler.tensorboard_trace_handler(LORAX_PROFILER_DIR, use_gzip=True) + ) + self.steps = 0 + + self.punica_wrapper = PunicaWrapper( + max_num_batched_tokens=10000, + max_batches=128, + device=self.device, + ) @property def block_size(self) -> int: @@ -1244,6 +1277,9 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model self.model_graph_wrapper.warmup() torch.cuda.synchronize(self.device) + if self.profiler is not None: + self.profiler.start() + return int(num_blocks * BLOCK_SIZE) def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str: @@ -1456,6 +1492,7 @@ def generate_token( finished_prefilling = True next_chunk_lengths = [] current_prefilling_mask = batch.prefilling_mask + current_prefilling_mask_tensor = batch.prefilling_mask_tensor if prefill: if get_supports_chunking(): next_prefilling_mask = [] @@ -1487,11 +1524,14 @@ def generate_token( # Reverse back the obtained values² next_chunk_lengths.reverse() next_prefilling_mask.reverse() + + batch.prefilling_mask_tensor = torch.tensor(next_prefilling_mask, device=batch.all_input_ids_tensor.device) else: # The model does not support chunking # We know we only do a single prefill finished_prefilling = True next_prefilling_mask = [False] * len(batch) + batch.prefilling_mask_tensor = None batch.prefilling = not finished_prefilling batch.prefilling_mask = next_prefilling_mask @@ -1544,18 +1584,19 @@ def generate_token( with timer(f"generate_token::cumulative_length"): if prefill and finished_prefilling: - current_prefilling_mask_tensor = torch.tensor(current_prefilling_mask, device=batch.all_input_ids_tensor.device) - # Discard first elem, which is 0 - end_index = batch.cu_seqlen_prefill[1:] + with timer(f"generate_token::cumulative_length::end_index"): + end_index = batch.cu_seqlen_prefill[1:] # Initialize position_ids # In decode, we do not need this as we can just increment position ids - next_position_ids[:] = batch.position_ids[end_index - 1] + with timer(f"generate_token::cumulative_length::next_position_ids"): + next_position_ids[:] = batch.position_ids[end_index - 1] # Initialize adapter indices # In decode, we only have one token per row in the batch, so grab last index - next_adapter_indices[:] = batch.adapter_meta.adapter_indices[end_index - 1] + with timer(f"generate_token::cumulative_length::adapter_indices"): + next_adapter_indices[:] = batch.adapter_meta.adapter_indices[end_index - 1] # if not request_is_prefilling: # # Only save tokens if we are done prefilling for this request @@ -1563,55 +1604,17 @@ def generate_token( # batch.all_input_ids_tensor[i, cache_length + input_length + j] = next_input_ids[index + j] # Only save tokens if we are done prefilling for this request - offsets = batch.cache_lengths_tensor + batch.input_lengths_tensor - - batch.all_input_ids_tensor = update_all_input_ids_tensor( - accepted_ids, - batch.all_input_ids_tensor, - offsets, - next_input_ids, - current_prefilling_mask_tensor, - ) - - # 
batch_size = accepted_ids.shape[0] - - # # Create a batch index tensor [0, 1, 2, ..., batch_size - 1] - # batch_indices = torch.arange(batch_size, device=batch.all_input_ids_tensor.device) - - # # Apply the current_prefilling_mask to get only the rows that should be updated - # valid_batch_indices = batch_indices[current_prefilling_mask_tensor] - - # # Gather only the accepted_ids, offsets, and next_input_ids for valid rows - # print("!!! offsets", offsets, offsets.shape) - # print("!!! current_prefilling_mask", current_prefilling_mask_tensor) - - # accepted_ids_valid = accepted_ids[current_prefilling_mask_tensor] - # offsets_valid = offsets[current_prefilling_mask_tensor] - - # print("!!! accepted_ids_valid", accepted_ids_valid, accepted_ids_valid.shape) - # print("!!! offsets_valid", offsets_valid, offsets_valid.shape) - - # # Reshape next_input_ids to [batch_size, S] - # max_candidates_per_batch = next_input_ids.shape[0] // batch_size - # next_input_ids_2d = next_input_ids.view(batch_size, max_candidates_per_batch) - - # # Gather only the next_input_ids for valid rows - # next_input_ids_valid = next_input_ids_2d[current_prefilling_mask_tensor] - # print("!!! next_input_ids_valid", next_input_ids_valid, next_input_ids_valid.shape) - - # # Generate a mask to gather the accepted IDs from next_input_ids - # expanded_accepted_ids = torch.repeat_interleave(accepted_ids_valid) - # print("!!! expanded_accepted_ids", expanded_accepted_ids, expanded_accepted_ids.shape) - - # # Flatten the accepted next_input_ids - # accepted_next_input_ids = next_input_ids_valid[torch.arange(accepted_ids_valid.shape[0]), expanded_accepted_ids] - - # # Compute insertion points for each valid batch, from the offset - # print("!!! expanded_accepted_ids", expanded_accepted_ids, expanded_accepted_ids.shape) - # insertion_indices = torch.repeat_interleave(offsets_valid) + torch.arange(expanded_accepted_ids.shape[0], device=offsets.device) - - # # Place the accepted tokens into the right locations in valid rows only - # batch.all_input_ids_tensor[valid_batch_indices, insertion_indices] = accepted_next_input_ids + with timer(f"generate_token::cumulative_length::offsets"): + offsets = batch.cache_lengths_tensor + batch.input_lengths_tensor + + with timer(f"generate_token::cumulative_length::update_all_input_ids_tensor"): + batch.all_input_ids_tensor = update_all_input_ids_tensor( + accepted_ids, + batch.all_input_ids_tensor, + offsets, + next_input_ids, + current_prefilling_mask_tensor, + ) print("all_input_ids", batch.all_input_ids_tensor.shape) @@ -1991,47 +1994,86 @@ def update_all_input_ids_tensor( next_input_ids, current_prefilling_mask ): - # Get batch size and compute S (number of candidate tokens per batch) - batch_size = accepted_ids.size(0) + # Get batch size + batch_size = all_input_ids_tensor.size(0) + # Calculate S (number of candidate tokens per batch) S = next_input_ids.size(0) // batch_size # Reshape next_input_ids to [batch_size, S] next_input_ids = next_input_ids.view(batch_size, S) - # Select indices of batches to process based on the current_prefilling_mask - batch_indices = torch.nonzero(current_prefilling_mask, as_tuple=False).squeeze(1) - num_batches = batch_indices.size(0) + # Since accepted_ids is always 1, we only need the first candidate token for each batch + values = next_input_ids[:, 0] - # Gather the accepted_ids, offsets, and next_input_ids for the selected batches - accepted_ids_selected = accepted_ids[batch_indices] - offsets_selected = offsets[batch_indices] - next_input_ids_selected = 
next_input_ids[batch_indices] + # Update all_input_ids_tensor at the specified positions with the accepted IDs + all_input_ids_tensor[torch.arange(batch_size), offsets] = values - # Determine the maximum number of accepted IDs to pad sequences - max_accepted_ids = accepted_ids_selected.max() + return all_input_ids_tensor - # Create sequence indices offsets for each batch - seq_indices_offsets = torch.arange(max_accepted_ids, device=accepted_ids.device).unsqueeze(0) - seq_indices_offsets = seq_indices_offsets.expand(num_batches, -1) + # NO SPECULATION + # # Get batch size + # batch_size = all_input_ids_tensor.size(0) + # # Calculate S (number of candidate tokens per batch) + # S = next_input_ids.size(0) // batch_size - # Create a mask to identify valid positions within accepted_ids for each batch - seq_mask = seq_indices_offsets < accepted_ids_selected.unsqueeze(1) + # # Reshape next_input_ids to [batch_size, S] + # next_input_ids = next_input_ids.view(batch_size, S) - # Calculate the sequence indices where updates will occur - seq_indices = seq_indices_offsets + offsets_selected.unsqueeze(1) + # # Select indices of batches to process based on current_prefilling_mask + # batch_indices = torch.nonzero(current_prefilling_mask, as_tuple=False).squeeze(1) - # Expand batch indices to align with seq_indices - batch_indices_expanded = batch_indices.unsqueeze(1).expand(-1, max_accepted_ids) + # # Gather offsets and next_input_ids for the selected batches + # offsets_selected = offsets[batch_indices] + # # Since accepted_ids is always 1, we only need the first candidate token + # values = next_input_ids[batch_indices, 0] - # Extract the values to be written into all_input_ids_tensor - values = next_input_ids_selected[:, :max_accepted_ids] + # # Update all_input_ids_tensor at the specified positions with the accepted IDs + # all_input_ids_tensor[batch_indices, offsets_selected] = values - # Flatten tensors and apply the mask to select valid positions - batch_indices_flat = batch_indices_expanded.reshape(-1)[seq_mask.reshape(-1)] - seq_indices_flat = seq_indices.reshape(-1)[seq_mask.reshape(-1)] - values_flat = values.reshape(-1)[seq_mask.reshape(-1)] + # return all_input_ids_tensor - # Update all_input_ids_tensor at the specified positions with the accepted IDs - all_input_ids_tensor[batch_indices_flat, seq_indices_flat] = values_flat + # FULL + # # Get batch size and compute S (number of candidate tokens per batch) + # batch_size = accepted_ids.size(0) + # S = next_input_ids.size(0) // batch_size - return all_input_ids_tensor + # # Reshape next_input_ids to [batch_size, S] + # next_input_ids = next_input_ids.view(batch_size, S) + + # # Select indices of batches to process based on the current_prefilling_mask + # batch_indices = torch.nonzero(current_prefilling_mask, as_tuple=False).squeeze(1) + # num_batches = batch_indices.size(0) + + # # Gather the accepted_ids, offsets, and next_input_ids for the selected batches + # accepted_ids_selected = accepted_ids[batch_indices] + # offsets_selected = offsets[batch_indices] + # next_input_ids_selected = next_input_ids[batch_indices] + + # # Determine the maximum number of accepted IDs to pad sequences + # max_accepted_ids = accepted_ids_selected.max() + + # # Create sequence indices offsets for each batch + # seq_indices_offsets = torch.arange(max_accepted_ids, device=accepted_ids.device).unsqueeze(0) + # seq_indices_offsets = seq_indices_offsets.expand(num_batches, -1) + + # # Create a mask to identify valid positions within accepted_ids for each batch + # 
seq_mask = seq_indices_offsets < accepted_ids_selected.unsqueeze(1) + + # # Calculate the sequence indices where updates will occur + # seq_indices = seq_indices_offsets + offsets_selected.unsqueeze(1) + + # # Expand batch indices to align with seq_indices + # batch_indices_expanded = batch_indices.unsqueeze(1).expand(-1, max_accepted_ids) + + # # Extract the values to be written into all_input_ids_tensor + # values = next_input_ids_selected[:, :max_accepted_ids] + + # # Flatten tensors and apply the mask to select valid positions + # batch_indices_flat = batch_indices_expanded.reshape(-1)[seq_mask.reshape(-1)] + # seq_indices_flat = seq_indices.reshape(-1)[seq_mask.reshape(-1)] + # values_flat = values.reshape(-1)[seq_mask.reshape(-1)] + + # # Update all_input_ids_tensor at the specified positions with the accepted IDs + # all_input_ids_tensor[batch_indices_flat, seq_indices_flat] = values_flat + + # return all_input_ids_tensor diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index 3c9b0f46f..99f164856 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -64,6 +64,7 @@ async def ClearCache(self, request, context): self.cache.delete(request.id) else: self.cache.clear() + return generate_pb2.ClearCacheResponse() except: exit(1) @@ -117,6 +118,12 @@ async def Prefill(self, request: generate_pb2.PrefillRequest, context): with timer("prefill::generate_token"): generations, next_batch = self.model.generate_token(batch) self.cache.set(next_batch) + + if self.model.profiler: + self.model.steps += 1 + if self.model.steps == 10: + self.model.profiler.stop() + print(self.model.profiler.key_averages()) return generate_pb2.PrefillResponse( generations=[generation.to_pb() for generation in generations], diff --git a/server/lorax_server/utils/ops/__init__.py b/server/lorax_server/utils/ops/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/lorax_server/utils/ops/bgmv_expand.py b/server/lorax_server/utils/ops/bgmv_expand.py new file mode 100644 index 000000000..d214da0b6 --- /dev/null +++ b/server/lorax_server/utils/ops/bgmv_expand.py @@ -0,0 +1,169 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. 
+https://arxiv.org/abs/2310.18547 +""" + +from typing import Dict, Optional + +import torch +import triton +import triton.language as tl + +from .utils import get_lora_op_configs + + +@triton.jit +def _bgmv_expand_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + lora_indices, + xm_stride, + xk_stride, + l0_stride, + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + SPLIT_N: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, +): + """ + GroupGEMV, additionally, introducing SPLIT_N can improve large hidden_size's + performance + """ + pid_sn = tl.program_id(axis=0) + cur_batch = tl.program_id(axis=1) + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + offset_k = tl.arange(0, BLOCK_K) + offset_n = tl.arange(0, BLOCK_N) + if EVEN_K: + tiled_a = tl.load(input_ptr + cur_batch * xm_stride + + offset_k * xk_stride, ) # [BLOCK_K] + else: + tiled_a = tl.load( + input_ptr + cur_batch * xm_stride + offset_k * xk_stride, + mask=offset_k < K, + other=0, + ) # [BLOCK_K] + # N must be divisible by SPLIT_N + split_n_length = tl.cdiv(N, SPLIT_N) + if CAST_TYPE: + tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) + # sliding to next row-block + b_ptr = (lora_ptr + l0_stride * lora_index + + pid_sn * split_n_length * lora_k_stride) + c_ptr = out_ptr + cur_batch * cm_stride + pid_sn * split_n_length + for n in range(0, split_n_length, BLOCK_N): + current_n = n + offset_n + current_n_c = tl.max_contiguous(current_n, BLOCK_N) + b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :] + < K) + c_mask = current_n < split_n_length + tiled_b = tl.load( + b_ptr + current_n_c[:, None] * lora_k_stride + + offset_k[None, :] * lora_n_stride, + mask=b_ptr_mask, + other=0.0, + ) # [BLOCK_N,BLOCK_K] + if ADD_INPUTS: + tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask) + accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out + else: + accumulator = tl.sum(tiled_a * tiled_b, 1) + + tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask) + + +@torch.inference_mode() +def bgmv_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True, + override_config: Optional[Dict[str, int]] = None, +): + """ + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch, An index of -1 means no lora should be + applied. + batches (int): batch size + add_inputs (bool, optional): Defaults to False. adds the final lora + results to the output. + override_config (Optional[Dict[str, int]], optional): Defaults to None. 
+ Triton grid config + """ + + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_b_weights.size(-1) + + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + + if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weights.size(1) == 1 + lora_b_weights = lora_b_weights.squeeze(dim=1) + else: + assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) + assert lora_b_weights.is_contiguous() + + # TODO tuning this config + N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size + BLOCK_K = triton.next_power_of_2(K) + EVEN_K = K % BLOCK_K == 0 + ADD_INPUTS = add_inputs + CAST_TYPE = False + if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + batches = lora_indices_tensor.size(0) + if override_config: + config = override_config + else: + config = get_lora_op_configs("expand", batches, N) + grid = lambda META: ( + META["SPLIT_N"], + batches, + ) + _bgmv_expand_kernel[grid]( + inputs, + lora_b_weights, + output_tensor, + N, + K, + lora_indices_tensor, + inputs.stride(0), + inputs.stride(1), + lora_b_weights.stride(0), + lora_b_weights.stride(1), + lora_b_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + BLOCK_K=BLOCK_K, + EVEN_K=EVEN_K, + ADD_INPUTS=ADD_INPUTS, + CAST_TYPE=CAST_TYPE, + **config, + ) + return \ No newline at end of file diff --git a/server/lorax_server/utils/ops/bgmv_expand_slice.py b/server/lorax_server/utils/ops/bgmv_expand_slice.py new file mode 100644 index 000000000..1444fa8e5 --- /dev/null +++ b/server/lorax_server/utils/ops/bgmv_expand_slice.py @@ -0,0 +1,182 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. 
+https://arxiv.org/abs/2310.18547 +""" + +from typing import Dict, Optional + +import torch +import triton +import triton.language as tl + +from .utils import get_lora_op_configs + + +@triton.jit +def _bgmv_expand_slice_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + lora_indices, + xm_stride, + xk_stride, + l0_stride, + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + slice_offset, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + SPLIT_N: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, +): + """ + GroupGEMV, additionally, introducing SPLIT_N can improve large hidden_size's + performance + """ + pid_sn = tl.program_id(axis=0) + cur_batch = tl.program_id(axis=1) + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + offset_k = tl.arange(0, BLOCK_K) + offset_n = tl.arange(0, BLOCK_N) + if EVEN_K: + tiled_a = tl.load(input_ptr + cur_batch * xm_stride + + offset_k * xk_stride, ) # [BLOCK_K] + else: + tiled_a = tl.load( + input_ptr + cur_batch * xm_stride + offset_k * xk_stride, + mask=offset_k < K, + other=0, + ) # [BLOCK_K] + # N must be divisible by SPLIT_N + split_n_length = tl.cdiv(N, SPLIT_N) + if CAST_TYPE: + tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) + # sliding to next row-block + b_ptr = (lora_ptr + l0_stride * lora_index + + pid_sn * split_n_length * lora_k_stride) + c_ptr = (out_ptr + cur_batch * cm_stride + pid_sn * split_n_length + + slice_offset * cn_stride) + + for n in range(0, split_n_length, BLOCK_N): + current_n = n + offset_n + b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :] + < K) + c_mask = current_n < split_n_length + tiled_b = tl.load( + b_ptr + current_n[:, None] * lora_k_stride + + offset_k[None, :] * lora_n_stride, + mask=b_ptr_mask, + other=0.0, + ) # [BLOCK_N,BLOCK_K] + + if ADD_INPUTS: + tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask) + accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out + else: + accumulator = tl.sum(tiled_a * tiled_b, 1) + + tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask) + + +@torch.inference_mode() +def bgmv_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True, + override_config: Optional[Dict[str, int]] = None, +): + """ + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (torch.Tensor): lora'b weight + output_tensor (torch.Tensor): output tensor + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch, An index of -1 means no lora should be + applied. + slice_offst (int): output_tensor's offst + slice_size (int): current output_tensor's size + batches (int): batch size + add_inputs (bool, optional): Defaults to False. + override_config (Optional[Dict[str, int]], optional): Defaults to None. 
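The slice variant computes the same per-row GEMV but writes it into the column window [slice_offset, slice_offset + slice_size) of a wider output, which is how packed projections (for example a fused QKV) receive each sub-projection's LoRA delta. A reference sketch under the same illustrative naming as above:

import torch

def bgmv_expand_slice_reference(x, lora_b, out, lora_indices,
                                slice_offset, slice_size, add_inputs=True):
    # x: [num_tokens, rank], lora_b: [num_loras, slice_size, rank],
    # out: [num_tokens, total_out]; only the selected column window is touched
    cols = slice(slice_offset, slice_offset + slice_size)
    for i in range(x.size(0)):
        idx = int(lora_indices[i])
        if idx == -1:
            continue
        y = lora_b[idx].float() @ x[i].float()
        out[i, cols] = (out[i, cols].float() + y if add_inputs else y).to(out.dtype)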
+ Triton grid config + """ + + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_b_weights.size(-1) + + assert slice_size == lora_b_weights.size(-2) + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + + if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weights.size(1) == 1 + lora_b_weights = lora_b_weights.squeeze(dim=1) + else: + assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) + + assert lora_b_weights.is_contiguous() + + # TODO tuning this config + + N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size + BLOCK_K = triton.next_power_of_2(K) + EVEN_K = K % BLOCK_K == 0 + ADD_INPUTS = add_inputs + CAST_TYPE = False + if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + + batches = lora_indices_tensor.size(0) + + if override_config: + config = override_config + else: + config = get_lora_op_configs("expand", batches, N) + + grid = lambda META: ( + META["SPLIT_N"], + batches, + ) + _bgmv_expand_slice_kernel[grid]( + inputs, + lora_b_weights, + output_tensor, + N, + K, + lora_indices_tensor, + inputs.stride(0), + inputs.stride(1), + lora_b_weights.stride(0), + lora_b_weights.stride(1), + lora_b_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + slice_offset, + BLOCK_K=BLOCK_K, + EVEN_K=EVEN_K, + ADD_INPUTS=ADD_INPUTS, + CAST_TYPE=CAST_TYPE, + **config, + ) + return \ No newline at end of file diff --git a/server/lorax_server/utils/ops/bgmv_shrink.py b/server/lorax_server/utils/ops/bgmv_shrink.py new file mode 100644 index 000000000..c532ba526 --- /dev/null +++ b/server/lorax_server/utils/ops/bgmv_shrink.py @@ -0,0 +1,150 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. 
+https://arxiv.org/abs/2310.18547 +""" + +from typing import Dict, Optional + +import torch +import triton +import triton.language as tl + +from .utils import get_lora_op_configs + + +@triton.jit +def _bgmv_shrink_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + lora_indices, + scaling, + xm_stride, + xk_stride, + l0_stride, + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + SPLIT_K: tl.constexpr, +): + """ + GroupGEMV, additionally, introducing SPLIT-K can improve large hidden_size's + performance + """ + pid_sk = tl.program_id(axis=0) + cur_batch = tl.program_id(axis=1) + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + + offset_n = tl.arange(0, BLOCK_N) + offset_k = tl.arange(0, BLOCK_K) + pid_sk * BLOCK_K + a_ptr = input_ptr + cur_batch * xm_stride + b_ptr = lora_ptr + l0_stride * lora_index + accumulator = tl.zeros((BLOCK_N, ), dtype=tl.float32) + for k in range(0, K, BLOCK_K * SPLIT_K): + current_k = k + offset_k + current_k_c = tl.max_contiguous(current_k, BLOCK_K) + tiled_a = tl.load( + a_ptr + current_k_c, + mask=current_k < K, + other=0.0, + ) # [BLOCK_K] + b_ptr_mask = (offset_n[:, None] < N) & (current_k[None, :] < K) + + tiled_b = tl.load( + b_ptr + offset_n[:, None] * lora_k_stride + + current_k[None, :] * lora_n_stride, + mask=b_ptr_mask, + other=0.0, + ) # [BLOCK_N,BLOCK_K] + + accumulator += tl.sum(tiled_a * tiled_b, 1) + accumulator *= scaling + offset_cn = tl.arange(0, BLOCK_N) + c_ptr = out_ptr + cur_batch * cm_stride + offset_cn * cn_stride + c_mask = offset_cn < N + if SPLIT_K == 1: + tl.store(c_ptr, accumulator, mask=c_mask) + else: + tl.atomic_add(c_ptr, accumulator, mask=c_mask) + + +@torch.inference_mode() +def bgmv_shrink( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0, + override_config: Optional[Dict[str, int]] = None, +): + """ + Args: + inputs (torch.Tensor): input tensor + lora_a_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch. An index of -1 means no lora should be + applied. + batches (int): batch size + scaling (float): Scaling factor. + override_config (Optional[Dict[str, int]], optional): Defaults to None. 
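The shrink op is the mirror image: it projects each row from hidden_size down to the LoRA rank and applies the scaling factor; when SPLIT_K > 1 the partial K-reductions are combined with atomic adds, so the output is expected to start zeroed. Reference sketch (illustrative names only):

import torch

def bgmv_shrink_reference(x, lora_a, out, lora_indices, scaling=1.0):
    # x: [num_tokens, hidden_size], lora_a: [num_loras, rank, hidden_size],
    # out: [num_tokens, rank]; rows with index -1 are left untouched
    for i in range(x.size(0)):
        idx = int(lora_indices[i])
        if idx == -1:
            continue
        out[i] = (scaling * (lora_a[idx].float() @ x[i].float())).to(out.dtype)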
+ Triton grid config + """ + assert inputs.dtype == lora_a_weights.dtype + assert inputs.dtype in [torch.float16, torch.bfloat16] + assert lora_a_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_a_weights.size(-1) + assert inputs.is_contiguous() + + if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size) + assert lora_a_weights.size(1) == 1 + lora_a_weights = lora_a_weights.squeeze(dim=1) + else: + assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size) + assert lora_a_weights.is_contiguous() + assert output_tensor.is_contiguous() + # TODO tuning this config + batches = lora_indices_tensor.size(0) + N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank + BLOCK_N = triton.next_power_of_2(N) + if override_config: + config = override_config + else: + # First try to load optimal config from the file + config = get_lora_op_configs("bgmv_shrink", batches, K) + + grid = lambda META: ( + META["SPLIT_K"], + batches, + ) + _bgmv_shrink_kernel[grid]( + inputs, + lora_a_weights, + output_tensor, + N, + K, + lora_indices_tensor, + scaling, + inputs.stride(0), + inputs.stride(1), + lora_a_weights.stride(0), + lora_a_weights.stride(1), + lora_a_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + BLOCK_N=BLOCK_N, + **config, + ) + return \ No newline at end of file diff --git a/server/lorax_server/utils/ops/libentry.py b/server/lorax_server/utils/ops/libentry.py new file mode 100644 index 000000000..867d71662 --- /dev/null +++ b/server/lorax_server/utils/ops/libentry.py @@ -0,0 +1,166 @@ +# Copied From https://github.com/FlagOpen/FlagGems + +import inspect + +import triton + + +class LibEntry(triton.KernelInterface): + + def __init__( + self, + fn, + ): + self.fn = fn + self.arg_names = fn.arg_names + self.divisibility = 16 + self.kernel_cache = dict() + fn = self.fn + while not isinstance(fn, triton.runtime.JITFunction): + fn = fn.fn + self.jit_function: triton.runtime.JITFunction = fn + self.specialize_indices = [ + p.num for p in self.jit_function.params + if not p.is_constexpr and not p.do_not_specialize + ] + self.do_not_specialize_indices = [ + p.num for p in self.jit_function.params + if not p.is_constexpr and p.do_not_specialize + ] + + def key(self, spec_args, dns_args, const_args): + spec_key = [(arg.dtype, arg.data_ptr() % + self.divisibility == 0) if hasattr(arg, "data_ptr") else + (type(arg), arg) for arg in spec_args] + dns_key = [ + arg.dtype if hasattr( + arg, "data_ptr") else type(arg) if not isinstance(arg, int) + else "i32" if -(2**31) <= arg and arg <= 2**31 - + 1 else "u64" if 2**63 <= arg and arg <= 2**64 - 1 else "i64" + for arg in dns_args + ] + # const args passed by position + return tuple(spec_key + dns_key + const_args) + + def run(self, *args, **kwargs): + grid = kwargs["grid"] + # collect all the arguments + spec_args = [] # specialize arguments + dns_args = [] # do not specialize arguments + const_args = [] # constexpr arguments + k_args = [] # kernel arguments + for i, arg in enumerate(args): + if i in self.specialize_indices: + k_args.append(arg) + spec_args.append(arg) + elif i in self.do_not_specialize_indices: + k_args.append(arg) + dns_args.append(arg) + else: + const_args.append(arg) + for p in self.jit_function.params[len(args):]: + if p.name in kwargs: + val = kwargs[p.name] + elif p.default is inspect._empty: + continue + else: + val = p.default + + if p.is_constexpr: + const_args.append(val) + elif p.do_not_specialize: + dns_args.append(val) + k_args.append(val) + else: + 
spec_args.append(val) + k_args.append(val) + + entry_key = self.key(spec_args, dns_args, const_args) + + if entry_key not in self.kernel_cache: + # compile the kernel also completes the related computations + kernel = self.fn.run(*args, **kwargs) + fn = self.fn + # collect constexpr arguments for grid computation + constexprs = {} + while not isinstance(fn, triton.runtime.JITFunction): + if isinstance(fn, triton.runtime.Autotuner): + config = fn.best_config + constexprs["num_warps"] = config.num_warps + constexprs["num_stages"] = config.num_stages + constexprs["num_ctas"] = config.num_ctas + constexprs = {**constexprs, **config.kwargs} + elif isinstance(fn, triton.runtime.Heuristics): + for v, heur in fn.values.items(): + constexprs[v] = heur({ + **dict(zip(fn.arg_names, args)), + **kwargs, + **constexprs, + }) + else: + raise RuntimeError("Invalid Runtime Function") + fn = fn.fn + # In vLLM, certain kernels like fused_moe_kernel get the + # best_config(as kwargs) from a configuration json file, rather + # than using Autotuner & Heuristics. Therefore, all their constexprs + # (tl.constexpr) are assigned values through the following loop. + for p in self.jit_function.params: + if p.is_constexpr and p.name not in constexprs: + constexprs[p.name] = p.default #default=inspect._empty + self.kernel_cache[entry_key] = (kernel, constexprs) + else: + # load kernel from cache directly + kernel, constexprs = self.kernel_cache[entry_key] + + if callable(grid): + # collect all arguments to the grid fn,ie: + # 1. args, + # 2. kwargs, + # 3. all all other captured arguments in CompiledKernel from + # Autotunner & Heuristics when kwargs & captured args conflict, + # captured args have higher priority + # 4. We must filter out captured args with default value firstly + constexprs = { + k: v + for k, v in constexprs.items() if v is not inspect._empty + } + meta = { + **dict(zip(self.arg_names, args)), + **kwargs, + **constexprs, + } + grid = grid(meta) + if isinstance(grid, tuple): + grid = grid + (1, 1) + elif isinstance(grid, list): + grid = grid + [1, 1] + kernel[grid[0:3]](*k_args) + # maintaining the same return type as the JITFunction.run + return kernel + + +def libentry(): + """ + Decorator for triton library entries. + Motivation: + The runtime overhead of Triton kernels is the reason for the lower + performance of small kernels, particularly evident with smaller models. + Using this decorator can reduce Triton runtime overhead. + How: + The `run` function of JITFunction needs to accomplish: + - Parameter binding using inspect + - KernelArg type wrapping + - Cache key calculation + When dealing with small size, these steps can become bottlenecks in + Triton runtime. Libentry simplifies these steps to reduce runtime + overhead, thereby improving the runtime expenses of small kernels. + NOTE: + When Triton is upgraded to version 3.0.0, libentry can be removed, + see: https://github.com/vllm-project/vllm/pull/5036#issuecomment-2243396245 + + """ + + def decorator(fn): + return LibEntry(fn) + + return decorator \ No newline at end of file diff --git a/server/lorax_server/utils/ops/sgmv_expand.py b/server/lorax_server/utils/ops/sgmv_expand.py new file mode 100644 index 000000000..867e6406f --- /dev/null +++ b/server/lorax_server/utils/ops/sgmv_expand.py @@ -0,0 +1,192 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. 
+https://arxiv.org/abs/2310.18547 +""" + +import torch +import triton +import triton.language as tl + +from lorax_server.utils.ops import libentry + + +@libentry() +@triton.jit +def _sgmv_expand_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + b_seq_start_loc, + seq_lens, + lora_indices, + xm_stride, + xk_stride, # 1 + l0_stride, # hidden_size*max_rank + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, +): + """ + The sgmv's expand triton kernel is based on GroupGEMM. + """ + pid = tl.program_id(axis=0) + cur_batch = tl.program_id(axis=1) + cta_n_num = tl.cdiv(N, BLOCK_N) + pid_m = pid // cta_n_num + pid_n = pid % cta_n_num + M = tl.load(seq_lens + cur_batch) + if pid_m * BLOCK_M > M: + return + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + cur_seq_start = tl.load(b_seq_start_loc + cur_batch) + offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + offset_k = tl.arange(0, BLOCK_K) + ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + + a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + + offset_k[None, :] * xk_stride, ) + b_ptr = (lora_ptr + l0_stride * lora_index + + offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride) + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(tl.cdiv(K, BLOCK_K)): + if EVEN_K: + tiled_a = tl.load(a_ptr) + tiled_b = tl.load(b_ptr) + else: + tiled_a = tl.load(a_ptr, + mask=offset_k[None, :] < K - k * BLOCK_K, + other=0) + tiled_b = tl.load(b_ptr, + mask=offset_k[:, None] < K - k * BLOCK_K, + other=0) + if CAST_TYPE: + tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) + accumulator += tl.dot( + tiled_a, + tiled_b, + ) + a_ptr += BLOCK_K * xk_stride + b_ptr += BLOCK_K * lora_n_stride + tiled_c = accumulator.to(lora_ptr.dtype.element_ty) + offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + c_ptr = (out_ptr + offset_cm[:, None] * cm_stride + + offset_cn[None, :] * cn_stride) + M = tl.load(seq_lens + cur_batch) + c_mask = (offset_cm[:, None] < + (cur_seq_start + M)) & (offset_cn[None, :] < N) + if ADD_INPUTS: + tiled_out = tl.load(c_ptr, mask=c_mask) + tiled_c += tiled_out + tl.store(c_ptr, tiled_c, mask=c_mask) + + +@torch.inference_mode() +def sgmv_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + add_inputs: bool = False, +): + """ + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative + sequence lengths of the sequences in the batch, used to index + into sequence. E.g.,if the sequence length is [4, 6], it is + [0, 4, 10]. + seq_len_tensor (torch.Tensor): (batch_size,). record the sequence + length of the sequences in the batch + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch. An index of -1 means no lora should be + applied. 
+ batches (int): batch size + max_seq_length (int): The max sequence lengths of the sequences + in the batch + add_inputs (bool, optional): Defaults to False. adds the final lora + results to the output. + """ + + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_b_weights.size(-1) + assert b_seq_start_loc.size(0) == batches + assert lora_indices_tensor.size(0) == batches + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + + if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weights.size(1) == 1 + lora_b_weights = lora_b_weights.squeeze(dim=1) + else: + assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) + + assert lora_b_weights.is_contiguous() + + # TODO tuning this config + + N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size + BLOCK_M = 32 + BLOCK_N = 32 + BLOCK_K = 16 + EVEN_K = K % BLOCK_K == 0 + ADD_INPUTS = add_inputs + CAST_TYPE = False + if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + grid = ( + triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), + batches, + ) + _sgmv_expand_kernel[grid]( + inputs, + lora_b_weights, + output_tensor, + N, + K, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + inputs.stride(0), + inputs.stride(1), + lora_b_weights.stride(0), + lora_b_weights.stride(1), + lora_b_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + ADD_INPUTS, + CAST_TYPE, + ) + return \ No newline at end of file diff --git a/server/lorax_server/utils/ops/sgmv_expand_slice.py b/server/lorax_server/utils/ops/sgmv_expand_slice.py new file mode 100644 index 000000000..021d8639d --- /dev/null +++ b/server/lorax_server/utils/ops/sgmv_expand_slice.py @@ -0,0 +1,203 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +import torch +import triton +import triton.language as tl + +from lorax_server.utils.ops import libentry + + +@libentry() +@triton.jit +def _sgmv_expand_slice_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + b_seq_start_loc, + seq_lens, + lora_indices, + xm_stride, + xk_stride, # 1 + l0_stride, # hidden_size*max_rank + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + slice_offset, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, +): + """ + Similar to the 'sgmv_expand' operator, but with an added parameter + 'slice_offset'. The reason for not reusing the 'sgmv_expand' operator + might be that in the future, we could implement a fusion operator to + achieve the current functionality instead of having to call it multiple + times. 
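Both expand kernels share one grouped-GEMM contract: for every contiguous token segment that uses the same adapter, out[start:start+len] gets x[start:start+len] @ B[idx].T added (or written), with the slice variant shifting the written columns by slice_offset. A reference sketch covering both (names are illustrative only):

import torch

def sgmv_expand_reference(x, lora_b, out, seq_starts, seq_lens, lora_indices,
                          slice_offset=0, add_inputs=False):
    # x: [total_tokens, rank], lora_b: [num_loras, n_out, rank], out: [total_tokens, out_features]
    cols = slice(slice_offset, slice_offset + lora_b.size(-2))
    for start, length, idx in zip(seq_starts.tolist(), seq_lens.tolist(), lora_indices.tolist()):
        if idx == -1:
            continue
        rows = slice(start, start + length)
        y = x[rows].float() @ lora_b[idx].float().t()
        out[rows, cols] = (out[rows, cols].float() + y if add_inputs else y).to(out.dtype)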
+ """ + pid = tl.program_id(axis=0) + cur_batch = tl.program_id(axis=1) + cta_n_num = tl.cdiv(N, BLOCK_N) + pid_m = pid // cta_n_num + pid_n = pid % cta_n_num + M = tl.load(seq_lens + cur_batch) + if pid_m * BLOCK_M > M: + return + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + cur_seq_start = tl.load(b_seq_start_loc + cur_batch) + offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + offset_k = tl.arange(0, BLOCK_K) + ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + + a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + + offset_k[None, :] * xk_stride, ) + b_ptr = (lora_ptr + l0_stride * lora_index + + offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride) + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(tl.cdiv(K, BLOCK_K)): + if EVEN_K: + tiled_a = tl.load(a_ptr) + tiled_b = tl.load(b_ptr) + else: + tiled_a = tl.load(a_ptr, + mask=offset_k[None, :] < K - k * BLOCK_K, + other=0) + tiled_b = tl.load(b_ptr, + mask=offset_k[:, None] < K - k * BLOCK_K, + other=0) + if CAST_TYPE: + tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) + accumulator += tl.dot( + tiled_a, + tiled_b, + ) + a_ptr += BLOCK_K * xk_stride + b_ptr += BLOCK_K * lora_n_stride + tiled_c = accumulator.to(lora_ptr.dtype.element_ty) + offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_offset + c_ptr = (out_ptr + offset_cm[:, None] * cm_stride + + offset_cn[None, :] * cn_stride) + M = tl.load(seq_lens + cur_batch) + c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] < + (slice_offset + N)) + if ADD_INPUTS: + tiled_out = tl.load(c_ptr, mask=c_mask) + tiled_c += tiled_out + tl.store(c_ptr, tiled_c, mask=c_mask) + + +@torch.inference_mode() +def sgmv_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + slice_offset: int, + slice_size: int, + add_inputs: bool = False, +): + """_summary_ + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative + sequence lengths of the sequences in the batch, used to index + into sequence. E.g.,if the sequence length is [4, 6], it is + [0, 4, 10]. + seq_len_tensor (torch.Tensor): (batch_size,). record the sequence + length of the sequences in the batch + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch. An index of -1 means no lora should be + applied. + batches (int): batch size + max_seq_length (int): The max sequence lengths of the sequences + in the batch + slice_offst (int): output_tensor's offst + slice_size (int): current output_tensor's size + add_inputs (bool, optional): Defaults to False. adds the final lora + results to the output.. 
+ """ + + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_b_weights.size(-1) + assert b_seq_start_loc.size(0) == batches + assert lora_indices_tensor.size(0) == batches + assert slice_size == lora_b_weights.size(-2) + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + + if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weights.size(1) == 1 + lora_b_weights = lora_b_weights.squeeze(dim=1) + else: + assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) + + assert lora_b_weights.is_contiguous() + + # TODO tuning this config + N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size + + BLOCK_M = 32 + BLOCK_N = 32 + BLOCK_K = 16 + EVEN_K = K % BLOCK_K == 0 + ADD_INPUTS = add_inputs + CAST_TYPE = False + if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + grid = ( + triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), + batches, + ) + _sgmv_expand_slice_kernel[grid]( + inputs, + lora_b_weights, + output_tensor, + N, + K, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + inputs.stride(0), + inputs.stride(1), + lora_b_weights.stride(0), + lora_b_weights.stride(1), + lora_b_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + slice_offset, + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + ADD_INPUTS, + CAST_TYPE, + ) + return \ No newline at end of file diff --git a/server/lorax_server/utils/ops/sgmv_shrink.py b/server/lorax_server/utils/ops/sgmv_shrink.py new file mode 100644 index 000000000..584e51980 --- /dev/null +++ b/server/lorax_server/utils/ops/sgmv_shrink.py @@ -0,0 +1,188 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +import torch +import triton +import triton.language as tl + +from lorax_server.utils.ops import libentry + + +@libentry() +@triton.jit +def _sgmv_shrink_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + b_seq_start_loc, + seq_lens, + lora_indices, + scaling, + xm_stride, # hidden_size + xk_stride, # 1 + l0_stride, # hidden_size*max_rank + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + SPLIT_K: tl.constexpr, +): + """ + The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K. + The GEMM of Multi-LoRA can be considered as GroupGEMM. 
Additionally, + introducing SPLIT-K can improve performance + """ + pid = tl.program_id(axis=0) + pid_sk = tl.program_id(axis=1) + cur_batch = tl.program_id(axis=2) + cta_n_num = tl.cdiv(N, BLOCK_N) + pid_m = pid // cta_n_num + pid_n = pid % cta_n_num + + M = tl.load(seq_lens + cur_batch) + if pid_m * BLOCK_M > M: + return + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + cur_seq_start = tl.load(b_seq_start_loc + cur_batch) + offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K) + + ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + + a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + + offset_k[None, :] * xk_stride) + b_ptr = (lora_ptr + l0_stride * lora_index + rbn[None, :] * lora_k_stride + + offset_k[:, None] * lora_n_stride) + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + tiled_a = tl.load(a_ptr) + tiled_b = tl.load(b_ptr) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + tiled_a = tl.load(a_ptr, + mask=offset_k[None, :] < k_remaining, + other=0.0) + tiled_b = tl.load(b_ptr, + mask=offset_k[:, None] < k_remaining, + other=0.0) + accumulator += tl.dot(tiled_a, tiled_b) + + a_ptr += BLOCK_K * SPLIT_K * xk_stride + b_ptr += BLOCK_K * SPLIT_K * lora_n_stride + offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + c_ptr = (out_ptr + offset_cm[:, None] * cm_stride + + offset_cn[None, :] * cn_stride) + c_mask = (offset_cm[:, None] < + (cur_seq_start + M)) & (offset_cn[None, :] < N) + accumulator *= scaling + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(c_ptr, accumulator, mask=c_mask) + else: + tl.atomic_add(c_ptr, accumulator, mask=c_mask) + + +@torch.inference_mode() +def sgmv_shrink( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + scaling: float, +): + """ + Args: + inputs (torch.Tensor): input tensor + lora_a_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative + sequence lengths of the sequences in the batch, used to index + into sequence. E.g.,if the sequence length is [4, 6], it is + [0, 4]. + seq_len_tensor (torch.Tensor): (batch_size,). record the sequence + length of the sequences in the batch + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch. An index of -1 means no lora should be + applied. + batches (int): batch size + max_seq_length (int): The max sequence lengths of the sequences + in the batch + scaling (float): Scaling factor. 
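Per segment the shrink kernel computes out[start:start+len] = scaling * x[start:start+len] @ A[idx].T, with A laid out as [num_loras, rank, hidden_size]; the SPLIT_K partial sums are merged with atomic adds. Reference sketch (illustrative names only):

import torch

def sgmv_shrink_reference(x, lora_a, out, seq_starts, seq_lens, lora_indices, scaling):
    # x: [total_tokens, hidden_size], lora_a: [num_loras, rank, hidden_size], out: [total_tokens, rank]
    for start, length, idx in zip(seq_starts.tolist(), seq_lens.tolist(), lora_indices.tolist()):
        if idx == -1:
            continue
        rows = slice(start, start + length)
        out[rows] = (scaling * (x[rows].float() @ lora_a[idx].float().t())).to(out.dtype)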
+ """ + assert inputs.dtype == lora_a_weights.dtype + assert inputs.dtype in [torch.float16, torch.bfloat16] + assert lora_a_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_a_weights.size(-1) + assert b_seq_start_loc.size(0) == batches + assert lora_indices_tensor.size(0) == batches + assert inputs.is_contiguous() + + if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size) + assert lora_a_weights.size(1) == 1 + lora_a_weights = lora_a_weights.squeeze(dim=1) + else: + assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size) + assert lora_a_weights.is_contiguous() + assert output_tensor.is_contiguous() + # TODO tuning this config + N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank + BLOCK_M = 32 + BLOCK_N = 16 + BLOCK_K = 32 + SPLIT_K = 8 + EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 + grid = ( + triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), + SPLIT_K, + batches, + ) + + _sgmv_shrink_kernel[grid]( + inputs, + lora_a_weights, + output_tensor, + N, + K, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + scaling, + inputs.stride(0), + inputs.stride(1), + lora_a_weights.stride(0), + lora_a_weights.stride(1), + lora_a_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + SPLIT_K, + ) + return \ No newline at end of file diff --git a/server/lorax_server/utils/ops/utils.py b/server/lorax_server/utils/ops/utils.py new file mode 100644 index 000000000..c4615d40f --- /dev/null +++ b/server/lorax_server/utils/ops/utils.py @@ -0,0 +1,46 @@ +import functools +from typing import Dict + + +@functools.lru_cache +def _get_op_configs(op_type: str, batch: int, hidden_size: int): + # TODO: add optimal configurations + return None + + +def _check_divisibility(hidden_size: int): + # The bgmv_expand kernel requires that the hidden_size be divisible by + # the number below. + divisibility = [2, 4, 8, 16, 32, 64] + divisibility.sort(reverse=True) + for div in divisibility: + if hidden_size % div == 0: + return div + # hidden_size is an odd number + return 1 + + +def _get_default_config(op_type: str, batch: int, hidden_size: int): + if op_type == "expand": + return { + "BLOCK_N": 256, + "SPLIT_N": _check_divisibility(hidden_size), + "num_warps": 8 + } + else: + return {"BLOCK_K": 256, "SPLIT_K": 64, "num_warps": 8} + + +def get_lora_op_configs(op_type: str, batch: int, + hidden_size: int) -> Dict[str, int]: + """Inspired by `fused_moe_kernel` + The return value will be a dictionary mapping an irregular grid of batch + sizes and hidden_size to configurations of the bgmv-related kernel. + NOTE: It currently only supports the default configuration. We plan to + generate optimal configurations for different hardware in the future using + scripts similar to `benchmark_moe.py`. 
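Since _get_op_configs has no tuned table yet, every lookup currently falls through to _get_default_config; for example (hidden_size 4096 is an assumed, typical value):

from lorax_server.utils.ops.utils import get_lora_op_configs

cfg = get_lora_op_configs("expand", batch=8, hidden_size=4096)
# -> {"BLOCK_N": 256, "SPLIT_N": 64, "num_warps": 8}, since 64 is the largest
#    listed power of two dividing 4096; any other op_type gets the shrink defaults.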
+ """ + config = _get_op_configs(op_type, batch, hidden_size) + if not config: + config = _get_default_config(op_type, batch, hidden_size) + return config \ No newline at end of file diff --git a/server/lorax_server/utils/profiler.py b/server/lorax_server/utils/profiler.py index 48ffe9ffc..5f4a43bf0 100644 --- a/server/lorax_server/utils/profiler.py +++ b/server/lorax_server/utils/profiler.py @@ -1,6 +1,8 @@ import time from contextlib import contextmanager +import torch + class TimingContextManager: def __init__(self, name: str): @@ -17,7 +19,7 @@ def timing(self): end = time.time() self.total_time += end - start self.count += 1 - print(f"=== {self.name}: avg={self.get_average_time():.3f} s total={self.total_time:.3f} s count={self.count}") + # print(f"=== {self.name}: avg={self.get_average_time():.3f} s total={self.total_time:.3f} s count={self.count}") def get_average_time(self): if self.count == 0: @@ -34,3 +36,4 @@ def timer(name: str): _timers[name] = TimingContextManager(name) with _timers[name].timing(): yield + torch.cuda.synchronize() diff --git a/server/lorax_server/utils/sgmv.py b/server/lorax_server/utils/sgmv.py index 6efb2647f..c1abdae5f 100644 --- a/server/lorax_server/utils/sgmv.py +++ b/server/lorax_server/utils/sgmv.py @@ -1,11 +1,15 @@ import os import warnings from functools import lru_cache -from typing import List, Tuple +from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union +from lorax_server.utils.ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink, sgmv_expand, sgmv_expand_slice, sgmv_shrink import torch import torch.nn.functional as F +if TYPE_CHECKING: + from lorax_server.adapters.weights import AdapterBatchData + try: import punica_kernels as _kernels @@ -234,3 +238,580 @@ def segmented_matmul( wi = w[i] bi = b[i] y[s_start[i] : s_end[i]] = F.linear(xi, wi, bi) + + +def compute_meta( + token_lora_tensor: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, bool]: + """ + Get the information required for the sgmv kernel. With the features: + 1. If consecutive requests in the batch use the same LoRA, this function + will combine them into a single request, improving sgmv kernel inference + performance. + 2. At the beginning of each prefill stage inference, recalculations are + needed based on the input, but only once. + """ + + lora_indices_tensor, seq_length_tensor = torch.unique_consecutive( + token_lora_tensor, return_counts=True) + cum_result = torch.cumsum(seq_length_tensor, dim=0) + b_seq_start_tensor = torch.zeros_like(seq_length_tensor) + b_seq_start_tensor[1:].copy_(cum_result[:-1]) + max_length = seq_length_tensor.max().item() + + batch_size = lora_indices_tensor.size(0) + no_lora = False + # -1 means no lora should be applied. Use `no_lora` to determine whether + # the current step requires LoRA. If LoRA is not needed, the prefill stage + # does not need to launch the triton kernel, which can improve performance + if batch_size == 1 and lora_indices_tensor == -1: + no_lora = True + return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, + batch_size, max_length, no_lora) + + +# TODO see if this can be vectorized +def convert_mapping( + meta: "AdapterBatchData", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor], List[int]]: + """Converts LoRAMapping to index tensors. 
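For a concrete picture of the segment metadata that feeds the sgmv path, compute_meta above collapses consecutive tokens that share an adapter; for example:

import torch
from lorax_server.utils.sgmv import compute_meta

token_lora = torch.tensor([0, 0, 1, 1, 1, -1])
starts, lens, loras, batch_size, max_len, no_lora = compute_meta(token_lora)
# starts = [0, 2, 5], lens = [2, 3, 1], loras = [0, 1, -1],
# batch_size = 3, max_len = 3, no_lora = False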
+ Args: + mapping: LoRAMapping mapping rows in a batch to LoRA ids. + lora_index_to_id: List mapping LoRA ids to LoRA indices. + max_loras: Maximum number of LoRAs. + vocab_size: Model vocab size. + extra_vocab_size: Extra vocab size each LoRA can have. + long_lora_context: Passed if there are long context lora in a batch. + Returns: + A tuple of tensors: + base_indices: Tensor of shape [batch_size] mapping batch rows to + LoRA indices. + sampler_indices: Tensor of shape [batch_size] mapping requests to + LoRA indices for sampler. For generation, this will be the + same as base_indicies. For prefill, this will map requests + to LoRA indices. + sampler_indices_padded: Tensor of shape [batch_size] mapping + requests to LoRA indices for sampler with padding. + Same as sampler_indicies, but -1 is replaced with + max_loras. + embeddings_indices: Tensor of shape [2, batch_size] mapping + requests to embedding indices. First row is for embeddings + added by the LoRAs, second row is for the LoRA.lora_a + embeddings. + long_lora_indices: Tensor of shape [batch_size] mapping + requests to RoPE offsets and rot dims for long LoRAs. + None if long context lora doesn't exist. + indices_len: List of lengths of the above tensors. It contains + (base_indices, sampler_indices, sampler_indices_padded, + embeddings_indices, long_lora_indices). + """ + index_mapping_indices: List[int] = meta.meta.token_indices.tolist() + embedding_indices = index_mapping_indices.copy() + lora_indices = index_mapping_indices.copy() + long_lora_offsets: Optional[torch.Tensor] = None + if long_lora_context: + long_lora_offsets = torch.zeros(len(index_mapping_indices), + device="cuda", + dtype=torch.long) + prompt_mapping: List[int] = [ + lora_index_to_id.index(x) if x > 0 else -1 + for x in meta.meta.adapter_indices + ] + lora_idx = None + for i in range(len(index_mapping_indices)): + # TODO index can be slow. optimize + lora_idx = (lora_index_to_id.index(index_mapping_indices[i]) + if index_mapping_indices[i] > 0 else -1) + embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 + lora_indices[i] = lora_idx + if long_lora_context: + assert long_lora_offsets is not None + lora_offset: int = long_lora_context.offsets_by_lora_id.get( + index_mapping_indices[i], 0) + long_lora_offsets[i] = lora_offset + + indices_list: List[Union[List[int], torch.Tensor]] = [ + index_mapping_indices, + lora_indices, + embedding_indices, + ] + if long_lora_context: + assert long_lora_offsets is not None + indices_list.append(long_lora_offsets) + indices = torch.tensor(indices_list, dtype=torch.long, device="cuda") + prompt_mapping_tensor = torch.tensor(prompt_mapping, + device="cuda", + dtype=torch.long) + embeddings_indices = torch.stack([ + indices[2] * extra_vocab_size, + indices[2] * (vocab_size + extra_vocab_size), + ]) + embeddings_indices[embeddings_indices == -1] = max_loras - 1 + base_indices = indices[1] + sampler_indices = prompt_mapping_tensor + sampler_indices_padded = sampler_indices.clone() + sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 + sampler_indices_padded = torch.arange( + 0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + ( + sampler_indices_padded * len(sampler_indices_padded)) + long_lora_indices = None + long_lora_indices_len: Optional[int] = None + if long_lora_context: + long_lora_indices = indices[3] + long_lora_indices_len = long_lora_indices.shape[-1] + # Contain length of indices tensors. Used to index into each tensor. 
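+    # Five entries, in this order: the trailing sizes of base_indices,
+    # sampler_indices, sampler_indices_padded and embeddings_indices, then the
+    # long-LoRA length (or None when no long-context LoRA is present).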
+ indices_len = [ + base_indices.shape[-1], + sampler_indices.shape[-1], + sampler_indices_padded.shape[-1], + embeddings_indices.shape[-1], + ] + if long_lora_indices_len is not None: + indices_len.append(long_lora_indices_len) + else: + # If long_lora doesn't exist,append None + indices_len.append(None) + + return ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_indices, + indices_len, + ) + + +class PunicaWrapper: + """ + PunicaWrapper is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the punica kernel. + """ + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: str): + self._token_lora_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._sampler_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._sampler_indices_padded = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._embeddings_indices = torch.empty(2, + max_num_batched_tokens, + dtype=torch.long, + device=device) + self._long_lora_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + + # 5 is the number of indicies tensors. + # base_indices, sampler_indices, sampler_indices_padded, + # embeddings_indices,long_lora_indices + self.indices_len: List[Optional[int]] = [None] * 5 + # these attributes are the information required for sgmv kernel + self._seq_start_locs = torch.empty(max_batches, + dtype=torch.long, + device=device) + self._seq_lengths = torch.empty(max_batches, + dtype=torch.long, + device=device) + self._lora_indices_per_batch = torch.empty(max_batches, + dtype=torch.long, + device=device) + self.max_length: int = 0 + self.batch_size: int = -1 + self.is_prefill = False + self.no_lora = False + + def update_metadata( + self, + meta: "AdapterBatchData", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context = None, + ): + + self._update_base_metadata(meta, lora_index_to_id, max_loras, + vocab_size, extra_vocab_size, + long_lora_context) + if meta.prefill: + # Update metadata required for prefill-related operators. + self._update_prefill_metada(self.token_lora_indices) + self.is_prefill = True + else: + self.is_prefill = False + + def _update_base_metadata( + self, + meta: "AdapterBatchData", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context = None, + ): + ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_offsets_tensor, + indices_len, + ) = convert_mapping( + meta, + lora_index_to_id, + max_loras, + vocab_size, + extra_vocab_size, + long_lora_context, + ) + self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices) + self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) + self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( + sampler_indices_padded) + self._embeddings_indices[:embeddings_indices. 
+ shape[0], :embeddings_indices.shape[1]].copy_( + embeddings_indices) + if long_lora_offsets_tensor is not None: + self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_( + long_lora_offsets_tensor) + else: + self._long_lora_indices.zero_() + + self.indices_len[:] = indices_len + + def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: + + (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, + batch_size, max_length, no_lora) = compute_meta(token_lora_tensor) + + self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_( + b_seq_start_tensor) + self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor) + self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_( + lora_indices_tensor) + self.batch_size = batch_size + self.max_length = max_length + self.no_lora = no_lora + + @property + def prefill_metadata( + self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]: + """ + This property provides a convenient way to access the necessary + metadata for prefill-related kernel computations. + 1. seq_start_locs: Tensor of sequence start positions + 2. seq_lengths: Tensor of sequence lengths + 3. lora_indices_per_batch: Tensor of lora indices, and an index of + -1 means no lora should be applied. + 4. batch_size: batch size after clustering identical lora indices + 5. max_length: The maximum sequence length in the batch + """ + return (self._seq_start_locs[:self.batch_size], + self._seq_lengths[:self.batch_size], + self._lora_indices_per_batch[:self.batch_size], + self.batch_size, self.max_length) + + @property + def token_lora_indices(self) -> torch.Tensor: + """ + This property provides the lora indices corresponding to each token + in the batch. An index of -1 means no lora should be applied. 
+ """ + token_lora_len = self.indices_len[0] + return self._token_lora_indices[:token_lora_len] + + @property + def sampler_indices(self) -> torch.Tensor: + """ + This property is used to access the lora indices specifically for + LogitsProcessorWithLoRA + """ + sampler_indices_len = self.indices_len[1] + return self._sampler_indices[:sampler_indices_len] + + @property + def sampler_indices_padded(self) -> torch.Tensor: + """ + This property provides access to padded sampler indices + """ + indices_padded_len = self.indices_len[2] + return self._sampler_indices_padded[:indices_padded_len] + + @property + def embeddings_indices(self) -> torch.Tensor: + """ + This property provides access to the indices used for lora embeddings, + specifically for VocabParallelEmbeddingWithLoRA + """ + embeddings_indices_len = self.indices_len[3] + return self._embeddings_indices[:, :embeddings_indices_len] + + @property + def long_lora_indices(self) -> torch.Tensor: + """ + This property provides access to the indices used for long context + lora, specifically for LinearScalingRotaryEmbeddingWithLora + """ + long_lora_len = self.indices_len[4] + return self._long_lora_indices[:long_lora_len] + + def shrink_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_shrink( + x, + w_t_all, + y, + *self.prefill_metadata, + scale, + ) + + def shrink_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) + + def expand_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_input: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand( + x, + w_t_all, + y, + *self.prefill_metadata, + add_input, + ) + + def expand_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_input: bool, + ): + bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input) + + def expand_slice_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_input: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand_slice( + x, + w_t_all, + y, + *self.prefill_metadata, + y_offset, + y_slice_size, + add_input, + ) + + def expand_slice_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_input: bool, + ): + bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, + y_slice_size, add_input) + + def add_shrink( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + """ + Perform the ` y+=x@w_t_all` computation, which is suitable for the + GEMM of lora'a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the shrink_decode function + should be called. + """ + shrink_fun: Callable = (self.shrink_prefill + if self.is_prefill else self.shrink_decode) + shrink_fun(y, x, w_t_all, scale) + + def add_expand( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_input: bool = True, + ): + """ + Perform the ` y+=x@w_t_all` computation, which is suitable for the + GEMM of lora'b. 
+ When `is_prefill` is true, it indicates that it is currently the + prefill stage, and the `expand_prefill` function should be called. + Otherwise, it is the decode stage, and the expand_decode function + should be called. + """ + + expand_fun: Callable = (self.expand_prefill + if self.is_prefill else self.expand_decode) + expand_fun(y, x, w_t_all, add_input) + + def add_expand_slice(self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_input: bool = True): + """ + Similar to `add_expand` + """ + + expand_slice_fun: Callable = (self.expand_slice_prefill + if self.is_prefill else + self.expand_slice_decode) + expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input) + + def add_lora(self, + y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + scale: float, + y_offset: Optional[int] = None, + y_slice_size: Optional[int] = None, + *, + buffer: Optional[torch.Tensor] = None) -> None: + """ + Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + Args: + y (torch.Tensor): Output tensor. Will be changed in-place. + x (torch.Tensor): Input tensor + wa_t_all (torch.Tensor): lora_a's weight + wb_t_all (torch.Tensor): lora_b's weight + scale (float): Scaling factor. + y_offset (Optional[int], optional): Offset to apply to the starting + column of y. + y_slice_size (Optional[int], optional): Size of the y column slice.. + buffer (Optional[torch.Tensor], optional): Defaults to None. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + r = wb_t_all.size(-1) + if buffer is None: + # We set the buffer to be float32 by default ,refer to: + # https://github.com/triton-lang/triton/issues/1387 + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + + self.add_shrink(buffer, x, wa_t_all, scale) + if y_offset is None and y_slice_size is None: + self.add_expand(y, buffer, wb_t_all, add_input=True) + else: + self.add_expand_slice(y, + buffer, + wb_t_all, + y_offset, + y_slice_size, + add_input=True) + y = y.view_as(y_org) + + def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, + torch.Tensor, + torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, + torch.Tensor, + torch.Tensor], + scale: float, + output_slices: Tuple[int, ...]) -> None: + """ + Applies lora to each input. Similar to add_lora, This method is + used for layers that are composed of multiple sublayers + (slices) packed together. 
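Taken together, a LoRA-augmented linear layer only needs the stacked A/B weights plus the wrapper; everything batch-specific (prefill vs. decode, segments, token indices) lives in the wrapper's metadata. A hedged sketch of a caller, with tensor names and shapes assumed for illustration rather than taken from the LoRAX layer code:

def apply_lora(punica_wrapper, base_out, hidden, lora_a_stacked, lora_b_stacked, scale=1.0):
    # base_out: [num_tokens, out_features], output of the frozen projection (updated in place)
    # hidden:   [num_tokens, in_features], layer input
    # lora_a_stacked: [num_loras, rank, in_features]; lora_b_stacked: [num_loras, out_features, rank]
    # Assumes punica_wrapper.update_metadata(...) was already called for this batch.
    punica_wrapper.add_lora(base_out, hidden, lora_a_stacked, lora_b_stacked, scale)
    return base_out

For packed projections, add_lora_packed_nslice walks output_slices left to right, so slice j lands at column offset sum(output_slices[:j]); e.g. widths (4096, 1024, 1024) give offsets 0, 4096 and 5120.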
+ """ + y_org = y + x = x.view(-1, x.shape[-1]) + y = y.view(-1, y.shape[-1]) + offset_left = 0 + # TODO fuse these kernels + for slice_idx in range(len(output_slices)): + self.add_lora(y, x, lora_a_stacked[slice_idx], + lora_b_stacked[slice_idx], scale, offset_left, + output_slices[slice_idx]) + offset_left += output_slices[slice_idx] + + y = y.view_as(y_org) + + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None) -> None: + """ + LogitsProcessorWithLoRA always using bgmv + """ + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + r = wb_t_all.size(-1) + if buffer is None: + # We set the buffer to be float32 by default ,refer to: + # https://github.com/triton-lang/triton/issues/1387 + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + + bgmv_shrink(x, wa_t_all, buffer, self.sampler_indices, scale) + bgmv_expand(buffer, wb_t_all, y, self.sampler_indices, add_inputs=True) + y = y.view_as(y_org) diff --git a/server/lorax_server/utils/state.py b/server/lorax_server/utils/state.py index 130cdc6e5..d6b3cea9f 100644 --- a/server/lorax_server/utils/state.py +++ b/server/lorax_server/utils/state.py @@ -8,6 +8,7 @@ SPECULATIVE_TOKENS = 0 +LORAX_PROFILER_DIR = os.environ.get("LORAX_PROFILER_DIR", None) PREFIX_CACHING = bool(os.environ.get("PREFIX_CACHING", "")) CHUNKED_PREFILL = bool(os.environ.get("CHUNKED_PREFILL", "")) @@ -21,6 +22,9 @@ logger.info(f"Prefix caching = {PREFIX_CACHING}") logger.info(f"Chunked prefill = {CHUNKED_PREFILL}") +if LORAX_PROFILER_DIR: + logger.info(f"Torch profiling enabled, output dir = {LORAX_PROFILER_DIR}") + SUPPORTS_CHUNKING: Optional[bool] = None MAX_PREFILL_TOKENS: Optional[int] = None From e4fb765b5b822d28b7f685a3f0d2acf9a66c2a62 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 24 Oct 2024 10:48:43 -0700 Subject: [PATCH 07/76] Use triton punica --- server/lorax_server/adapters/weights.py | 17 ++- server/lorax_server/models/flash_causal_lm.py | 29 +++-- server/lorax_server/utils/layers.py | 117 ++++++++++-------- server/lorax_server/utils/profiler.py | 2 +- server/lorax_server/utils/sgmv.py | 85 +++++++------ 5 files changed, 150 insertions(+), 100 deletions(-) diff --git a/server/lorax_server/adapters/weights.py b/server/lorax_server/adapters/weights.py index 668b36be7..4035fa16e 100644 --- a/server/lorax_server/adapters/weights.py +++ b/server/lorax_server/adapters/weights.py @@ -1,19 +1,25 @@ from abc import ABC, abstractclassmethod from collections import defaultdict from dataclasses import dataclass -from typing import Dict, List, Optional, Set, Type +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Type import torch from lorax_server.adapters.types import LORA from lorax_server.utils.lora import LM_HEAD +if TYPE_CHECKING: + from lorax_server.utils.sgmv import PunicaWrapper + @dataclass class AdapterBatchMetadata: - # [batch_size] + # [num_tokens] adapter_indices: torch.Tensor + # [batch_size] + adapter_list: List[int] + # [num_adapters] adapter_set: Set[int] @@ -28,7 +34,7 @@ class AdapterBatchMetadata: @property def token_indices(self) -> torch.Tensor: # Create the `token_indices` by repeating each segment index by the number of tokens in it - return torch.cat([torch.full((count,), idx, dtype=torch.long) for idx, count in enumerate(self.segment_indices)]) + return torch.cat([torch.full((count,), self.adapter_indices[idx], dtype=torch.long) for idx, count in 
enumerate(self.segment_indices)]) class AdapterWeights(ABC): @@ -111,12 +117,15 @@ class AdapterBatchData: # layer type -> adapter type -> batch weight data data: Dict[str, Dict[str, BatchAdapterWeights]] + punica_wrapper: "PunicaWrapper" + prefill: bool @staticmethod def from_meta( meta: AdapterBatchMetadata, weights: Dict[str, LayerAdapterWeights], + punica_wrapper: "PunicaWrapper", prefill: bool, prefill_head_indices: Optional[torch.Tensor], ) -> "AdapterBatchData": @@ -127,7 +136,7 @@ def from_meta( layer_weights = v.get_data(meta, k, prefill, prefill_head_indices if k == LM_HEAD else None) if layer_weights: data[k] = layer_weights - return AdapterBatchData(meta=meta, data=data, prefill=prefill) + return AdapterBatchData(meta=meta, data=data, punica_wrapper=punica_wrapper, prefill=prefill) def ranks(self) -> Set[int]: # TODO(travis): refactor to be less coupled to lora implementation diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 8abaac919..617241b7e 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -386,7 +386,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": prefill_logprob_tokens = [] stopping_criterias = [] - adapter_set = set() + adapter_list = [] num_blocks = 0 max_blocks = 0 @@ -423,7 +423,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": prefill_logprob_tokens.append(self.prefill_logprob_tokens[idx]) - adapter_set.add(self.requests[idx].adapter_index) + adapter_list.append(self.requests[idx].adapter_index) request_block_table = self.block_tables[idx] num_blocks += len(request_block_table) @@ -481,7 +481,8 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32, device=device) adapter_meta = AdapterBatchMetadata( adapter_indices=adapter_indices, - adapter_set=adapter_set, + adapter_list=adapter_list, + adapter_set=set(adapter_list), adapter_segments=adapter_segments, segment_indices=adapter_segment_indices, ) @@ -581,6 +582,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch total_indices_size = sum(b.adapter_meta.adapter_indices.shape[0] for b in batches) adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(total_indices_size) adapter_segment_builder = SegmentConcatBuilder() + adapter_list = [] adapter_set = set() prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(total_batch_size) @@ -648,6 +650,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch adapter_end_index = cumulative_adapter_indices_size + batch.adapter_meta.adapter_indices.shape[0] adapter_indices[adapter_start_index:adapter_end_index] = batch.adapter_meta.adapter_indices cumulative_adapter_indices_size = adapter_end_index + adapter_list.extend(batch.adapter_meta.adapter_list) adapter_set.update(batch.adapter_meta.adapter_set) adapter_segment_builder.concat( batch.adapter_meta.adapter_segments, @@ -701,6 +704,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch adapter_segments, adapter_segment_indices = adapter_segment_builder.build() adapter_meta = AdapterBatchMetadata( adapter_indices=adapter_indices, + adapter_list=adapter_list, adapter_set=adapter_set, adapter_segments=adapter_segments, segment_indices=adapter_segment_indices, @@ -771,7 +775,7 @@ def prepare_for_prefill(self): slots = [] adapter_indices_list = [] - adapter_set = 
set() + adapter_list = [] for i, ( r, @@ -853,7 +857,7 @@ def prepare_for_prefill(self): prefill_cache_indices.append(request_prefill_cache_indices) adapter_indices_list.append(torch.full((next_chunk_length,), r.adapter_index)) - adapter_set.add(r.adapter_index) + adapter_list.append(r.adapter_index) # Update cumulative_length += next_chunk_length @@ -906,7 +910,8 @@ def prepare_for_prefill(self): adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32, device=device) self.adapter_meta = AdapterBatchMetadata( adapter_indices=adapter_indices, - adapter_set=adapter_set, + adapter_list=adapter_list, + adapter_set=set(adapter_list), adapter_segments=adapter_segments, segment_indices=adapter_segment_indices, ) @@ -1459,6 +1464,7 @@ def generate_token( adapter_segments = adapter_meta.adapter_segments * new_length adapter_meta = AdapterBatchMetadata( adapter_indices=adapter_indices, + adapter_list=adapter_meta.adapter_list, adapter_set=adapter_meta.adapter_set, adapter_segments=adapter_segments, segment_indices=adapter_meta.segment_indices, @@ -1466,8 +1472,17 @@ def generate_token( # Assign pointers to adapter weights # TODO(travis): don't update this if indices haven't changed + # self.punica_wrapper.update_metadata( + # adapter_meta, + # prefill, + # len(adapter_meta.adapter_set), + # self.model.config.vocab_size, + # self.model.config.vocab_size, + # None, + # ) + self.punica_wrapper.update_metadata(adapter_meta, prefill) adapter_data = AdapterBatchData.from_meta( - adapter_meta, self.layer_to_adapter_weights, prefill, batch.prefill_head_indices + adapter_meta, self.layer_to_adapter_weights, self.punica_wrapper, prefill, batch.prefill_head_indices ) with timer(f"{stage_str}::generate_token::forward"): diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index d7f54a420..5ce80d102 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -77,64 +77,79 @@ def forward_layer_type( if has_sgmv() and data is not None and data.can_vectorize(self.process_group): if end_idx - start_idx != result.shape[1]: - proj = torch.zeros_like(result[:, start_idx:end_idx]) + # proj = torch.zeros_like(result[:, start_idx:end_idx]) + y_offset = start_idx + y_slice_size = end_idx - start_idx else: - proj = result + # proj = result + y_offset = None + y_slice_size = None for r, rank_segments in data.rank_data.items(): lora_a_ptr = rank_segments.lora_a_ptr lora_b_ptr = rank_segments.lora_b_ptr - if data.use_sgmv: - # Use SGMV for prefill - if lora_a_ptr is not None and lora_b_ptr is not None: - v = lora_a_sgmv_cutlass( - input, - rank_segments.tmp_shrink, - lora_a_ptr, - rank_segments.segment_starts, - rank_segments.segment_ends, - self.layer_id, - r, - ) - - if self.process_group.size() > 1: - v = self.collect_lora_a(v) - - lora_b_sgmv_cutlass( - proj, - v, - rank_segments.tmp_expand, - lora_b_ptr, - rank_segments.segment_starts, - rank_segments.segment_ends, - self.layer_id, - ) - else: - # Use BGMV for decode - if lora_a_ptr is not None and lora_b_ptr is not None: - v = torch.zeros((input.size(0), r), dtype=input.dtype, device=input.device) - add_lora_a_bgmv( - v, - input, - lora_a_ptr, - rank_segments.indices, - self.layer_id, - ) - - if self.process_group.size() > 1: - v = self.collect_lora_a(v) - - add_lora_b_bgmv( - proj, - v, - lora_b_ptr, - rank_segments.indices, - self.layer_id, - ) + adapter_data.punica_wrapper.add_lora( + result, + input, + lora_a_ptr, + lora_b_ptr, + 1.0, + y_offset, + y_slice_size, + 
callback=self.collect_lora_a if self.process_group.size() > 1 else None, + ) - if end_idx - start_idx != result.shape[1]: - result[:, start_idx:end_idx] += proj + # if data.use_sgmv: + # # Use SGMV for prefill + # if lora_a_ptr is not None and lora_b_ptr is not None: + # v = lora_a_sgmv_cutlass( + # input, + # rank_segments.tmp_shrink, + # lora_a_ptr, + # rank_segments.segment_starts, + # rank_segments.segment_ends, + # self.layer_id, + # r, + # ) + + # if self.process_group.size() > 1: + # v = self.collect_lora_a(v) + + # lora_b_sgmv_cutlass( + # proj, + # v, + # rank_segments.tmp_expand, + # lora_b_ptr, + # rank_segments.segment_starts, + # rank_segments.segment_ends, + # self.layer_id, + # ) + # else: + # # Use BGMV for decode + # if lora_a_ptr is not None and lora_b_ptr is not None: + # v = torch.zeros((input.size(0), r), dtype=input.dtype, device=input.device) + # add_lora_a_bgmv( + # v, + # input, + # lora_a_ptr, + # rank_segments.indices, + # self.layer_id, + # ) + + # if self.process_group.size() > 1: + # v = self.collect_lora_a(v) + + # add_lora_b_bgmv( + # proj, + # v, + # lora_b_ptr, + # rank_segments.indices, + # self.layer_id, + # ) + + # if end_idx - start_idx != result.shape[1]: + # result[:, start_idx:end_idx] += proj else: adapter_indices = adapter_data.meta.adapter_indices if data is not None and data.prefill_head_indices is not None and data.layer_name == LM_HEAD: diff --git a/server/lorax_server/utils/profiler.py b/server/lorax_server/utils/profiler.py index 5f4a43bf0..bb74be6d2 100644 --- a/server/lorax_server/utils/profiler.py +++ b/server/lorax_server/utils/profiler.py @@ -36,4 +36,4 @@ def timer(name: str): _timers[name] = TimingContextManager(name) with _timers[name].timing(): yield - torch.cuda.synchronize() + # torch.cuda.synchronize() diff --git a/server/lorax_server/utils/sgmv.py b/server/lorax_server/utils/sgmv.py index c1abdae5f..df49830ea 100644 --- a/server/lorax_server/utils/sgmv.py +++ b/server/lorax_server/utils/sgmv.py @@ -8,7 +8,7 @@ import torch.nn.functional as F if TYPE_CHECKING: - from lorax_server.adapters.weights import AdapterBatchData + from lorax_server.adapters.weights import AdapterBatchMetadata try: import punica_kernels as _kernels @@ -272,8 +272,7 @@ def compute_meta( # TODO see if this can be vectorized def convert_mapping( - meta: "AdapterBatchData", - lora_index_to_id: List[Optional[int]], + meta: "AdapterBatchMetadata", max_loras: int, vocab_size: int, extra_vocab_size: int, @@ -283,7 +282,6 @@ def convert_mapping( """Converts LoRAMapping to index tensors. Args: mapping: LoRAMapping mapping rows in a batch to LoRA ids. - lora_index_to_id: List mapping LoRA ids to LoRA indices. max_loras: Maximum number of LoRAs. vocab_size: Model vocab size. extra_vocab_size: Extra vocab size each LoRA can have. @@ -311,7 +309,7 @@ def convert_mapping( (base_indices, sampler_indices, sampler_indices_padded, embeddings_indices, long_lora_indices). 
""" - index_mapping_indices: List[int] = meta.meta.token_indices.tolist() + index_mapping_indices: List[int] = meta.adapter_indices.tolist() embedding_indices = index_mapping_indices.copy() lora_indices = index_mapping_indices.copy() long_lora_offsets: Optional[torch.Tensor] = None @@ -319,15 +317,10 @@ def convert_mapping( long_lora_offsets = torch.zeros(len(index_mapping_indices), device="cuda", dtype=torch.long) - prompt_mapping: List[int] = [ - lora_index_to_id.index(x) if x > 0 else -1 - for x in meta.meta.adapter_indices - ] + prompt_mapping = meta.adapter_list.copy() lora_idx = None for i in range(len(index_mapping_indices)): - # TODO index can be slow. optimize - lora_idx = (lora_index_to_id.index(index_mapping_indices[i]) - if index_mapping_indices[i] > 0 else -1) + lora_idx = index_mapping_indices[i] embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 lora_indices[i] = lora_idx if long_lora_context: @@ -433,30 +426,41 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, self.is_prefill = False self.no_lora = False + # def update_metadata( + # self, + # meta: "AdapterBatchMetadata", + # prefill: bool, + # max_loras: int, + # vocab_size: int, + # extra_vocab_size: int, + # long_lora_context = None, + # ): + + # self._update_base_metadata(meta, max_loras, + # vocab_size, extra_vocab_size, + # long_lora_context) + # if prefill: + # # Update metadata required for prefill-related operators. + # self._update_prefill_metada(self.token_lora_indices) + # self.is_prefill = True + # else: + # self.is_prefill = False + def update_metadata( self, - meta: "AdapterBatchData", - lora_index_to_id: List[Optional[int]], - max_loras: int, - vocab_size: int, - extra_vocab_size: int, - long_lora_context = None, + meta: "AdapterBatchMetadata", + prefill: bool, ): - - self._update_base_metadata(meta, lora_index_to_id, max_loras, - vocab_size, extra_vocab_size, - long_lora_context) - if meta.prefill: + if prefill: # Update metadata required for prefill-related operators. 
- self._update_prefill_metada(self.token_lora_indices) + self._update_prefill_metada(meta.adapter_indices) self.is_prefill = True else: self.is_prefill = False def _update_base_metadata( self, - meta: "AdapterBatchData", - lora_index_to_id: List[Optional[int]], + meta: "AdapterBatchMetadata", max_loras: int, vocab_size: int, extra_vocab_size: int, @@ -471,7 +475,6 @@ def _update_base_metadata( indices_len, ) = convert_mapping( meta, - lora_index_to_id, max_loras, vocab_size, extra_vocab_size, @@ -711,16 +714,19 @@ def add_expand_slice(self, self.expand_slice_decode) expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input) - def add_lora(self, - y: torch.Tensor, - x: torch.Tensor, - wa_t_all: torch.Tensor, - wb_t_all: torch.Tensor, - scale: float, - y_offset: Optional[int] = None, - y_slice_size: Optional[int] = None, - *, - buffer: Optional[torch.Tensor] = None) -> None: + def add_lora( + self, + y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + scale: float, + y_offset: Optional[int] = None, + y_slice_size: Optional[int] = None, + *, + buffer: Optional[torch.Tensor] = None, + callback: Optional[Callable] = None, + ): """ Semantics: y[i] += ( @@ -752,6 +758,11 @@ def add_lora(self, device=x.device) self.add_shrink(buffer, x, wa_t_all, scale) + + if callback is not None: + # callback used to aggregate intermediate results (i.e., allreduce, allgather) + buffer = callback(buffer) + if y_offset is None and y_slice_size is None: self.add_expand(y, buffer, wb_t_all, add_input=True) else: From 634c8e224c7b5fbec622a31b5eefb4708641127a Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 24 Oct 2024 12:05:48 -0700 Subject: [PATCH 08/76] Fix format --- server/lorax_server/adapters/lora.py | 3 ++- server/lorax_server/models/model.py | 32 ++++++++++++++++++++++++++++ server/lorax_server/utils/sgmv.py | 4 +++- 3 files changed, 37 insertions(+), 2 deletions(-) diff --git a/server/lorax_server/adapters/lora.py b/server/lorax_server/adapters/lora.py index fac7f5b8b..bab1119d8 100644 --- a/server/lorax_server/adapters/lora.py +++ b/server/lorax_server/adapters/lora.py @@ -100,7 +100,8 @@ def __init__( self._is_transposed = False # [num_layers, hidden_size, r] - weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a] + # TODO(travis): add this back if rank is 8 and we're not using triton + # weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a] self._weights_a = torch.stack(weights_a) # [num_layers, r, hidden_size] diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py index f5e9dbe99..87158fb14 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ -3,6 +3,7 @@ from collections import defaultdict from typing import Dict, List, Optional, Tuple, Type, TypeVar +from lorax_server.adapters.lora import LoraWeights import torch from loguru import logger from transformers import PreTrainedTokenizerBase @@ -237,6 +238,9 @@ def max_speculative_tokens(self) -> int: def register_preloaded_adapters( self, preloaded_adapters: List[generate_pb2.PreloadedAdapter], adapter_memory_fractions: List[float] ): + if preloaded_adapters is None: + return + self.preloaded_adapter_indices.update({adapter.adapter_index for adapter in preloaded_adapters}) self.preloaded_adapter_memory_fractions.update( { @@ -246,6 +250,34 @@ def register_preloaded_adapters( ) self.preloaded_adapters.extend(preloaded_adapters) + # For Triton kernels: need weights into contiguous tensor + # 
dict of layer_name -> (lora_a_weights, lora_b_weights) + # where: + # lora_a_weights = [num_adapters, r, hidden_size] + # lora_b_weights = [num_adapters, hidden_size, r] + self.layer_to_lora_weights = {} + for layer_name, layer_adapter_weights in self.layer_to_adapter_weights.items(): + lora_a_weights = [] + lora_b_weights = [] + for i, adapter in enumerate(preloaded_adapters): + adapter_index = adapter.adapter_index + adapter_weights = layer_adapter_weights.adapter_weights.get(adapter_index) + if not isinstance(adapter_weights, LoraWeights): + # Only applicable to lora for now + continue + + # transpose to ensure col major + lora_a = adapter_weights.weights_a_t + lora_b = adapter_weights.weights_b_t + + lora_a_weights.append(lora_a) + lora_b_weights.append(lora_b) + + # stack into [num_adapters, r, hidden_size] and [num_adapters, hidden_size, r] + lora_a_weights = torch.stack(lora_a_weights, device=self.device).contiguous() + lora_b_weights = torch.stack(lora_b_weights, device=self.device).contiguous() + self.layer_to_lora_weights[layer_name] = (lora_a_weights, lora_b_weights) + def load_adapter( self, adapter_parameters: AdapterParameters, diff --git a/server/lorax_server/utils/sgmv.py b/server/lorax_server/utils/sgmv.py index df49830ea..260457d63 100644 --- a/server/lorax_server/utils/sgmv.py +++ b/server/lorax_server/utils/sgmv.py @@ -451,9 +451,11 @@ def update_metadata( meta: "AdapterBatchMetadata", prefill: bool, ): + # token_lora_indices is adapter_indices - 1 to account for base model offset + self._token_lora_indices = meta.adapter_indices - 1 if prefill: # Update metadata required for prefill-related operators. - self._update_prefill_metada(meta.adapter_indices) + self._update_prefill_metada(self._token_lora_indices) self.is_prefill = True else: self.is_prefill = False From 787072990bb5f9555d81f3831832db17efd6201f Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 24 Oct 2024 12:11:13 -0700 Subject: [PATCH 09/76] Plumb weights --- server/lorax_server/adapters/weights.py | 5 ++++- server/lorax_server/models/flash_causal_lm.py | 15 ++++++--------- server/lorax_server/models/model.py | 9 +++++++-- server/lorax_server/utils/layers.py | 9 +++++---- 4 files changed, 22 insertions(+), 16 deletions(-) diff --git a/server/lorax_server/adapters/weights.py b/server/lorax_server/adapters/weights.py index 4035fa16e..bb7bbc117 100644 --- a/server/lorax_server/adapters/weights.py +++ b/server/lorax_server/adapters/weights.py @@ -1,7 +1,7 @@ from abc import ABC, abstractclassmethod from collections import defaultdict from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Type +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type import torch @@ -117,6 +117,9 @@ class AdapterBatchData: # layer type -> adapter type -> batch weight data data: Dict[str, Dict[str, BatchAdapterWeights]] + # layer type -> fused lora weights + layer_to_lora_weights: Dict[str, Tuple[torch.Tensor, torch.Tensor]] + punica_wrapper: "PunicaWrapper" prefill: bool diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 617241b7e..ef828e435 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -1472,17 +1472,14 @@ def generate_token( # Assign pointers to adapter weights # TODO(travis): don't update this if indices haven't changed - # self.punica_wrapper.update_metadata( - # adapter_meta, - # prefill, - # len(adapter_meta.adapter_set), - # 
self.model.config.vocab_size, - # self.model.config.vocab_size, - # None, - # ) self.punica_wrapper.update_metadata(adapter_meta, prefill) adapter_data = AdapterBatchData.from_meta( - adapter_meta, self.layer_to_adapter_weights, self.punica_wrapper, prefill, batch.prefill_head_indices + adapter_meta, + self.layer_to_adapter_weights, + self.layer_to_lora_weights, + self.punica_wrapper, + prefill, + batch.prefill_head_indices ) with timer(f"{stage_str}::generate_token::forward"): diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py index 87158fb14..f36b15e28 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Tuple, Type, TypeVar from lorax_server.adapters.lora import LoraWeights +from lorax_server.adapters.medusa_lora import MedusaLoraWeights import torch from loguru import logger from transformers import PreTrainedTokenizerBase @@ -263,8 +264,12 @@ def register_preloaded_adapters( adapter_index = adapter.adapter_index adapter_weights = layer_adapter_weights.adapter_weights.get(adapter_index) if not isinstance(adapter_weights, LoraWeights): - # Only applicable to lora for now - continue + if isinstance(adapter_weights, MedusaLoraWeights): + # only use lora component + adapter_weights = adapter_weights.lora_weights + else: + # only applicable to lora for now + continue # transpose to ensure col major lora_a = adapter_weights.weights_a_t diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index 5ce80d102..b9a0458e0 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -86,14 +86,15 @@ def forward_layer_type( y_slice_size = None for r, rank_segments in data.rank_data.items(): - lora_a_ptr = rank_segments.lora_a_ptr - lora_b_ptr = rank_segments.lora_b_ptr + # lora_a_ptr = rank_segments.lora_a_ptr + # lora_b_ptr = rank_segments.lora_b_ptr + lora_a_weights, lora_b_weights = adapter_data.layer_to_lora_weights[layer_type] adapter_data.punica_wrapper.add_lora( result, input, - lora_a_ptr, - lora_b_ptr, + lora_a_weights, + lora_b_weights, 1.0, y_offset, y_slice_size, From 0e057f08740345f42dd0c5e880920ee31cfa4921 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 24 Oct 2024 13:38:43 -0700 Subject: [PATCH 10/76] Fixed issues --- server/lorax_server/adapters/weights.py | 11 ++++- .../custom_modeling/flash_qwen2_modeling.py | 4 +- server/lorax_server/models/model.py | 41 +++++++++++++------ server/lorax_server/utils/layers.py | 4 +- server/lorax_server/utils/ops/sgmv_expand.py | 10 ++++- .../utils/ops/sgmv_expand_slice.py | 12 +++++- server/lorax_server/utils/ops/sgmv_shrink.py | 11 ++++- server/lorax_server/utils/sgmv.py | 11 ++++- 8 files changed, 82 insertions(+), 22 deletions(-) diff --git a/server/lorax_server/adapters/weights.py b/server/lorax_server/adapters/weights.py index bb7bbc117..f83570bac 100644 --- a/server/lorax_server/adapters/weights.py +++ b/server/lorax_server/adapters/weights.py @@ -118,7 +118,7 @@ class AdapterBatchData: data: Dict[str, Dict[str, BatchAdapterWeights]] # layer type -> fused lora weights - layer_to_lora_weights: Dict[str, Tuple[torch.Tensor, torch.Tensor]] + layer_to_lora_weights: Dict[Tuple[str, int], Tuple[torch.Tensor, torch.Tensor]] punica_wrapper: "PunicaWrapper" @@ -128,6 +128,7 @@ class AdapterBatchData: def from_meta( meta: AdapterBatchMetadata, weights: Dict[str, LayerAdapterWeights], + layer_to_lora_weights: Dict[Tuple[str, int], Tuple[torch.Tensor, 
torch.Tensor]], punica_wrapper: "PunicaWrapper", prefill: bool, prefill_head_indices: Optional[torch.Tensor], @@ -139,7 +140,13 @@ def from_meta( layer_weights = v.get_data(meta, k, prefill, prefill_head_indices if k == LM_HEAD else None) if layer_weights: data[k] = layer_weights - return AdapterBatchData(meta=meta, data=data, punica_wrapper=punica_wrapper, prefill=prefill) + return AdapterBatchData( + meta=meta, + data=data, + layer_to_lora_weights=layer_to_lora_weights, + punica_wrapper=punica_wrapper, + prefill=prefill, + ) def ranks(self) -> Set[int]: # TODO(travis): refactor to be less coupled to lora implementation diff --git a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py index 261c6cff5..fc3ac7097 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py @@ -303,8 +303,8 @@ def __init__(self, prefix, config, weights, layer_id): layer_id, [MLP_GATE_PROJ, MLP_UP_PROJ], sizes=[ - config.intermediate_size // 2, - config.intermediate_size // 2, + config.intermediate_size, + config.intermediate_size, ], process_group=weights.process_group, ) diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py index f36b15e28..e9c049b7e 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ -5,6 +5,7 @@ from lorax_server.adapters.lora import LoraWeights from lorax_server.adapters.medusa_lora import MedusaLoraWeights +from lorax_server.utils.sgmv import pad_to_min_rank import torch from loguru import logger from transformers import PreTrainedTokenizerBase @@ -252,14 +253,14 @@ def register_preloaded_adapters( self.preloaded_adapters.extend(preloaded_adapters) # For Triton kernels: need weights into contiguous tensor - # dict of layer_name -> (lora_a_weights, lora_b_weights) + # dict of (layer_name, layer_id) -> (lora_a_weights, lora_b_weights) # where: # lora_a_weights = [num_adapters, r, hidden_size] # lora_b_weights = [num_adapters, hidden_size, r] self.layer_to_lora_weights = {} for layer_name, layer_adapter_weights in self.layer_to_adapter_weights.items(): - lora_a_weights = [] - lora_b_weights = [] + layer_id_to_lora_a_weights = defaultdict(list) + layer_id_to_lora_b_weights = defaultdict(list) for i, adapter in enumerate(preloaded_adapters): adapter_index = adapter.adapter_index adapter_weights = layer_adapter_weights.adapter_weights.get(adapter_index) @@ -271,17 +272,31 @@ def register_preloaded_adapters( # only applicable to lora for now continue - # transpose to ensure col major - lora_a = adapter_weights.weights_a_t - lora_b = adapter_weights.weights_b_t - - lora_a_weights.append(lora_a) - lora_b_weights.append(lora_b) + # transpose into col major + lora_a = adapter_weights.weights_a.transpose(1, 2) + lora_b = adapter_weights.weights_b.transpose(1, 2) + + nlayers = lora_a.size(0) + for layer_id in range(nlayers): + layer_id_to_lora_a_weights[layer_id].append(lora_a[layer_id]) + layer_id_to_lora_b_weights[layer_id].append(lora_b[layer_id]) - # stack into [num_adapters, r, hidden_size] and [num_adapters, hidden_size, r] - lora_a_weights = torch.stack(lora_a_weights, device=self.device).contiguous() - lora_b_weights = torch.stack(lora_b_weights, device=self.device).contiguous() - self.layer_to_lora_weights[layer_name] = (lora_a_weights, lora_b_weights) + for layer_id, lora_a_weights in layer_id_to_lora_a_weights.items(): + lora_b_weights = 
layer_id_to_lora_b_weights[layer_id] + + # right pad every adapter to the max rank + # TODO(travis) + # r = max([w.size(-1) for w in lora_b_weights]) + # lora_a_weights = [pad_to_min_rank(w, 1, r) for w in lora_a_weights] + # lora_b_weights = [pad_to_min_rank(w, 2, r) for w in lora_b_weights] + + # stack into [num_adapters, r, hidden_size] and [num_adapters, hidden_size, r] + lora_a_weights = torch.stack(lora_a_weights).to(self.device).contiguous() + lora_b_weights = torch.stack(lora_b_weights).to(self.device).contiguous() + print("!!! lora_a_weights", lora_a_weights.shape, layer_name, layer_id) + print("!!! lora_b_weights", lora_b_weights.shape) + # ('self_attn.q_proj', 32) + self.layer_to_lora_weights[(layer_name, layer_id)] = (lora_a_weights, lora_b_weights) def load_adapter( self, diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index b9a0458e0..0f96dc38b 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -76,6 +76,7 @@ def forward_layer_type( data: Optional["BatchLoraWeights"] = data.get(LORA) if data is not None else None if has_sgmv() and data is not None and data.can_vectorize(self.process_group): + print("!!! layer_type", layer_type, "start_idx", start_idx, "end_idx", end_idx, "result", result.shape) if end_idx - start_idx != result.shape[1]: # proj = torch.zeros_like(result[:, start_idx:end_idx]) y_offset = start_idx @@ -89,7 +90,7 @@ def forward_layer_type( # lora_a_ptr = rank_segments.lora_a_ptr # lora_b_ptr = rank_segments.lora_b_ptr - lora_a_weights, lora_b_weights = adapter_data.layer_to_lora_weights[layer_type] + lora_a_weights, lora_b_weights = adapter_data.layer_to_lora_weights[(layer_type, self.layer_id)] adapter_data.punica_wrapper.add_lora( result, input, @@ -230,6 +231,7 @@ def forward(self, input: torch.Tensor, adapter_data: "AdapterBatchData") -> torc end_idx = offset // self.process_group.size() else: end_idx = result.shape[1] + print("!!! sizes", self.sizes, self.process_group.size()) result = self.forward_layer_type(result, input, adapter_data, layer_name, start_idx, end_idx) diff --git a/server/lorax_server/utils/ops/sgmv_expand.py b/server/lorax_server/utils/ops/sgmv_expand.py index 867e6406f..71d1ef9a0 100644 --- a/server/lorax_server/utils/ops/sgmv_expand.py +++ b/server/lorax_server/utils/ops/sgmv_expand.py @@ -9,7 +9,7 @@ import triton import triton.language as tl -from lorax_server.utils.ops import libentry +from lorax_server.utils.ops.libentry import libentry @libentry() @@ -128,6 +128,14 @@ def sgmv_expand( add_inputs (bool, optional): Defaults to False. adds the final lora results to the output. """ + print("!!! inputs", inputs.shape) + print("!!! lora_b_weights", lora_b_weights.shape) + print("!!! output_tensor", output_tensor.shape) + print("!!! b_seq_start_loc", b_seq_start_loc) + print("!!! seq_len_tensor", seq_len_tensor) + print("!!! lora_indices_tensor", lora_indices_tensor) + print("!!! batches", batches) + print("!!! 
max_seq_length", max_seq_length) assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] assert lora_b_weights.dtype in [ diff --git a/server/lorax_server/utils/ops/sgmv_expand_slice.py b/server/lorax_server/utils/ops/sgmv_expand_slice.py index 021d8639d..2814e26c8 100644 --- a/server/lorax_server/utils/ops/sgmv_expand_slice.py +++ b/server/lorax_server/utils/ops/sgmv_expand_slice.py @@ -9,7 +9,7 @@ import triton import triton.language as tl -from lorax_server.utils.ops import libentry +from lorax_server.utils.ops.libentry import libentry @libentry() @@ -137,6 +137,16 @@ def sgmv_expand_slice( add_inputs (bool, optional): Defaults to False. adds the final lora results to the output.. """ + print("!!! inputs", inputs.shape) + print("!!! lora_b_weights", lora_b_weights.shape) + print("!!! output_tensor", output_tensor.shape) + print("!!! b_seq_start_loc", b_seq_start_loc) + print("!!! seq_len_tensor", seq_len_tensor) + print("!!! lora_indices_tensor", lora_indices_tensor) + print("!!! batches", batches) + print("!!! max_seq_length", max_seq_length) + print("!!! slice_offset", slice_offset) + print("!!! slice_size", slice_size) assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] assert lora_b_weights.dtype in [ diff --git a/server/lorax_server/utils/ops/sgmv_shrink.py b/server/lorax_server/utils/ops/sgmv_shrink.py index 584e51980..da7fa04d3 100644 --- a/server/lorax_server/utils/ops/sgmv_shrink.py +++ b/server/lorax_server/utils/ops/sgmv_shrink.py @@ -9,7 +9,7 @@ import triton import triton.language as tl -from lorax_server.utils.ops import libentry +from lorax_server.utils.ops.libentry import libentry @libentry() @@ -131,6 +131,15 @@ def sgmv_shrink( in the batch scaling (float): Scaling factor. """ + print("!!! inputs", inputs.shape) + print("!!! lora_a_weights", lora_a_weights.shape) + print("!!! output_tensor", output_tensor.shape) + print("!!! b_seq_start_loc", b_seq_start_loc) + print("!!! seq_len_tensor", seq_len_tensor) + print("!!! lora_indices_tensor", lora_indices_tensor) + print("!!! batch_size", batches) + print("!!! max_seq_length", max_seq_length) + print("!!! 
scaling", scaling) assert inputs.dtype == lora_a_weights.dtype assert inputs.dtype in [torch.float16, torch.bfloat16] assert lora_a_weights.dtype in [ diff --git a/server/lorax_server/utils/sgmv.py b/server/lorax_server/utils/sgmv.py index 260457d63..62c0f0adc 100644 --- a/server/lorax_server/utils/sgmv.py +++ b/server/lorax_server/utils/sgmv.py @@ -3,10 +3,16 @@ from functools import lru_cache from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union -from lorax_server.utils.ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink, sgmv_expand, sgmv_expand_slice, sgmv_shrink import torch import torch.nn.functional as F +from lorax_server.utils.ops.bgmv_expand import bgmv_expand +from lorax_server.utils.ops.bgmv_expand_slice import bgmv_expand_slice +from lorax_server.utils.ops.bgmv_shrink import bgmv_shrink +from lorax_server.utils.ops.sgmv_expand import sgmv_expand +from lorax_server.utils.ops.sgmv_expand_slice import sgmv_expand_slice +from lorax_server.utils.ops.sgmv_shrink import sgmv_shrink + if TYPE_CHECKING: from lorax_server.adapters.weights import AdapterBatchMetadata @@ -39,7 +45,10 @@ def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor: # tensor parallelism will result in effective rank being divided by world_size, # so we need to scale the min rank to offset that effect min_rank = MIN_SGMV_RANK * world_size + return pad_to_min_rank(t, dim, min_rank) + +def pad_to_min_rank(t: torch.Tensor, dim: int, min_rank: int) -> torch.Tensor: # if we're at or below the min rank, pad up to the min rank # otherwise, pad to the nearest multiple of the block size current_rank = t.size(dim) From c8ad4cb37cbf00489f3d2dac0a17e899a805cbeb Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 24 Oct 2024 14:21:01 -0700 Subject: [PATCH 11/76] Fixed cuda graphs --- server/lorax_server/models/flash_causal_lm.py | 2 + server/lorax_server/utils/graph.py | 43 ++++++++++++++++--- server/lorax_server/utils/sgmv.py | 13 ++++-- 3 files changed, 47 insertions(+), 11 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index ef828e435..bad97276d 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -1239,6 +1239,8 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model self.num_heads, self.num_kv_heads, self.sliding_window_blocks, + self.layer_to_lora_weights, + self.punica_wrapper, ) graph_cache_memory = self.model_graph_wrapper.get_estimated_cache_memory() logger.info("Estimated graph cache memory: {} MB", graph_cache_memory / 1024 / 1024) diff --git a/server/lorax_server/utils/graph.py b/server/lorax_server/utils/graph.py index 387ddf86c..8d79c6abb 100644 --- a/server/lorax_server/utils/graph.py +++ b/server/lorax_server/utils/graph.py @@ -18,7 +18,7 @@ from lorax_server.adapters.types import LORA from lorax_server.utils.attention.common import Seqlen from lorax_server.utils.attention.utils import block_tables_to_ragged -from lorax_server.utils.sgmv import BGMV_MAX_RANK +from lorax_server.utils.sgmv import BGMV_MAX_RANK, PunicaWrapper from lorax_server.utils.state import BLOCK_SIZE, FLASH_INFER if TYPE_CHECKING: @@ -155,10 +155,13 @@ def get_max_graph_state( adapter_data=AdapterBatchData( meta=AdapterBatchMetadata( adapter_indices=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device), + adapter_list=[], adapter_set=set(), adapter_segments=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device), 
segment_indices=[], ), + layer_to_lora_weights={}, + punica_wrapper=None, data=adapter_weight_data, prefill=False, ), @@ -198,6 +201,8 @@ def trace( num_kv_heads: int, sliding_window_blocks: Optional[int] = None, traced_adapter_layer_names: Optional[Set[str]] = None, + layer_to_lora_weights: Dict[str, Dict[str, Any]] = {}, + punica_wrapper: Optional[PunicaWrapper] = None, ) -> "GraphWrapper": max_input_state = get_max_graph_state(device, adapter_layers, max_total_tokens, sliding_window_blocks) @@ -270,6 +275,18 @@ def trace( num_heads=num_heads, num_kv_heads=num_kv_heads, ) + + meta = AdapterBatchMetadata( + adapter_indices=max_input_state.adapter_data.meta.adapter_indices[:batch_size], + adapter_list=max_input_state.adapter_data.meta.adapter_list, + adapter_set=max_input_state.adapter_data.meta.adapter_set, + adapter_segments=max_input_state.adapter_data.meta.adapter_segments[:batch_size], + segment_indices=max_input_state.adapter_data.meta.segment_indices, + ) + punica_wrapper.update_metadata( + meta=meta, + prefill=False + ) input_state = GraphState( input_ids=max_input_state.input_ids[:batch_size], @@ -287,12 +304,9 @@ def trace( cache_lengths=cache_lengths, cache_lengths_tensor=cache_lengths_tensor, adapter_data=AdapterBatchData( - meta=AdapterBatchMetadata( - adapter_indices=max_input_state.adapter_data.meta.adapter_indices[:batch_size], - adapter_set=max_input_state.adapter_data.meta.adapter_set, - adapter_segments=max_input_state.adapter_data.meta.adapter_segments[:batch_size], - segment_indices=max_input_state.adapter_data.meta.segment_indices, - ), + meta=meta, + layer_to_lora_weights=layer_to_lora_weights, + punica_wrapper=punica_wrapper, data=adapter_weight_data, prefill=False, ), @@ -403,6 +417,11 @@ def forward( pad_and_fill(dest_rank_data.lora_b_ptr, source_rank_data.lora_b_ptr, 0) pad_and_fill(dest_rank_data.indices, source_rank_data.indices, SEGMENT_PAD_VALUE) + self.input_state.adapter_data.punica_wrapper.update_metadata( + meta=adapter_data.meta, + prefill=False + ) + with self.forward_context( block_tables=self.input_state.block_tables, cu_seqlen_prefill=None, @@ -433,6 +452,8 @@ def __init__( num_heads: int, num_kv_heads: int, sliding_window_blocks: Optional[int] = None, + layer_to_lora_weights: Dict[str, Dict[str, Any]] = {}, + punica_wrapper: Optional[PunicaWrapper] = None, ): self.model = model self.device = device @@ -446,6 +467,8 @@ def __init__( self.num_heads = num_heads self.num_kv_heads = num_kv_heads self.sliding_window_blocks = sliding_window_blocks + self.layer_to_lora_weights = layer_to_lora_weights + self.punica_wrapper = punica_wrapper def can_use_graph( self, @@ -502,6 +525,8 @@ def get_estimated_cache_memory(self) -> int: self.num_kv_heads, self.sliding_window_blocks, self.adapter_layers, # estimate memory assuming all adapters are traced + self.layer_to_lora_weights, + self.punica_wrapper, ) tmp_cache[key] = graph pool = graph.memory_pool @@ -546,6 +571,8 @@ def warmup(self): self.num_kv_heads, self.sliding_window_blocks, self.default_traced_adapter_layers, + self.layer_to_lora_weights, + self.punica_wrapper, ) self.cache[key] = graph pool = graph.memory_pool @@ -595,6 +622,8 @@ def forward( self.num_kv_heads, self.sliding_window_blocks, adapter_data.layer_names(), + self.layer_to_lora_weights, + self.punica_wrapper, ) self.cache[key] = graph diff --git a/server/lorax_server/utils/sgmv.py b/server/lorax_server/utils/sgmv.py index 62c0f0adc..93f8b20a5 100644 --- a/server/lorax_server/utils/sgmv.py +++ b/server/lorax_server/utils/sgmv.py @@ -461,10 
+461,15 @@ def update_metadata( prefill: bool, ): # token_lora_indices is adapter_indices - 1 to account for base model offset - self._token_lora_indices = meta.adapter_indices - 1 + base_indices = meta.adapter_indices - 1 + + self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices) + # self._token_lora_indices = base_indices + self.indices_len[0] = base_indices.shape[-1] + if prefill: # Update metadata required for prefill-related operators. - self._update_prefill_metada(self._token_lora_indices) + self._update_prefill_metada(self._token_lora_indices, base_indices.shape[-1]) self.is_prefill = True else: self.is_prefill = False @@ -506,10 +511,10 @@ def _update_base_metadata( self.indices_len[:] = indices_len - def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: + def _update_prefill_metada(self, token_lora_tensor: torch.Tensor, indices_len: int) -> None: (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, - batch_size, max_length, no_lora) = compute_meta(token_lora_tensor) + batch_size, max_length, no_lora) = compute_meta(token_lora_tensor[:indices_len]) self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_( b_seq_start_tensor) From a82eb64cf97a70e348ce67c002a018d8abd68669 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 24 Oct 2024 14:22:26 -0700 Subject: [PATCH 12/76] Remove debug --- server/lorax_server/models/model.py | 3 --- server/lorax_server/utils/layers.py | 2 -- server/lorax_server/utils/ops/sgmv_expand.py | 16 +++++++-------- .../utils/ops/sgmv_expand_slice.py | 20 +++++++++---------- server/lorax_server/utils/ops/sgmv_shrink.py | 19 +++++++++--------- 5 files changed, 28 insertions(+), 32 deletions(-) diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py index e9c049b7e..7d38de04d 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ -293,9 +293,6 @@ def register_preloaded_adapters( # stack into [num_adapters, r, hidden_size] and [num_adapters, hidden_size, r] lora_a_weights = torch.stack(lora_a_weights).to(self.device).contiguous() lora_b_weights = torch.stack(lora_b_weights).to(self.device).contiguous() - print("!!! lora_a_weights", lora_a_weights.shape, layer_name, layer_id) - print("!!! lora_b_weights", lora_b_weights.shape) - # ('self_attn.q_proj', 32) self.layer_to_lora_weights[(layer_name, layer_id)] = (lora_a_weights, lora_b_weights) def load_adapter( diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index 0f96dc38b..4ad18298e 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -76,7 +76,6 @@ def forward_layer_type( data: Optional["BatchLoraWeights"] = data.get(LORA) if data is not None else None if has_sgmv() and data is not None and data.can_vectorize(self.process_group): - print("!!! layer_type", layer_type, "start_idx", start_idx, "end_idx", end_idx, "result", result.shape) if end_idx - start_idx != result.shape[1]: # proj = torch.zeros_like(result[:, start_idx:end_idx]) y_offset = start_idx @@ -231,7 +230,6 @@ def forward(self, input: torch.Tensor, adapter_data: "AdapterBatchData") -> torc end_idx = offset // self.process_group.size() else: end_idx = result.shape[1] - print("!!! 
sizes", self.sizes, self.process_group.size()) result = self.forward_layer_type(result, input, adapter_data, layer_name, start_idx, end_idx) diff --git a/server/lorax_server/utils/ops/sgmv_expand.py b/server/lorax_server/utils/ops/sgmv_expand.py index 71d1ef9a0..181b92434 100644 --- a/server/lorax_server/utils/ops/sgmv_expand.py +++ b/server/lorax_server/utils/ops/sgmv_expand.py @@ -128,14 +128,14 @@ def sgmv_expand( add_inputs (bool, optional): Defaults to False. adds the final lora results to the output. """ - print("!!! inputs", inputs.shape) - print("!!! lora_b_weights", lora_b_weights.shape) - print("!!! output_tensor", output_tensor.shape) - print("!!! b_seq_start_loc", b_seq_start_loc) - print("!!! seq_len_tensor", seq_len_tensor) - print("!!! lora_indices_tensor", lora_indices_tensor) - print("!!! batches", batches) - print("!!! max_seq_length", max_seq_length) + # print("!!! inputs", inputs.shape) + # print("!!! lora_b_weights", lora_b_weights.shape) + # print("!!! output_tensor", output_tensor.shape) + # print("!!! b_seq_start_loc", b_seq_start_loc) + # print("!!! seq_len_tensor", seq_len_tensor) + # print("!!! lora_indices_tensor", lora_indices_tensor) + # print("!!! batches", batches) + # print("!!! max_seq_length", max_seq_length) assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] assert lora_b_weights.dtype in [ diff --git a/server/lorax_server/utils/ops/sgmv_expand_slice.py b/server/lorax_server/utils/ops/sgmv_expand_slice.py index 2814e26c8..1fa1d96de 100644 --- a/server/lorax_server/utils/ops/sgmv_expand_slice.py +++ b/server/lorax_server/utils/ops/sgmv_expand_slice.py @@ -137,16 +137,16 @@ def sgmv_expand_slice( add_inputs (bool, optional): Defaults to False. adds the final lora results to the output.. """ - print("!!! inputs", inputs.shape) - print("!!! lora_b_weights", lora_b_weights.shape) - print("!!! output_tensor", output_tensor.shape) - print("!!! b_seq_start_loc", b_seq_start_loc) - print("!!! seq_len_tensor", seq_len_tensor) - print("!!! lora_indices_tensor", lora_indices_tensor) - print("!!! batches", batches) - print("!!! max_seq_length", max_seq_length) - print("!!! slice_offset", slice_offset) - print("!!! slice_size", slice_size) + # print("!!! inputs", inputs.shape) + # print("!!! lora_b_weights", lora_b_weights.shape) + # print("!!! output_tensor", output_tensor.shape) + # print("!!! b_seq_start_loc", b_seq_start_loc) + # print("!!! seq_len_tensor", seq_len_tensor) + # print("!!! lora_indices_tensor", lora_indices_tensor) + # print("!!! batches", batches) + # print("!!! max_seq_length", max_seq_length) + # print("!!! slice_offset", slice_offset) + # print("!!! slice_size", slice_size) assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] assert lora_b_weights.dtype in [ diff --git a/server/lorax_server/utils/ops/sgmv_shrink.py b/server/lorax_server/utils/ops/sgmv_shrink.py index da7fa04d3..fb3a5d6ad 100644 --- a/server/lorax_server/utils/ops/sgmv_shrink.py +++ b/server/lorax_server/utils/ops/sgmv_shrink.py @@ -131,15 +131,16 @@ def sgmv_shrink( in the batch scaling (float): Scaling factor. """ - print("!!! inputs", inputs.shape) - print("!!! lora_a_weights", lora_a_weights.shape) - print("!!! output_tensor", output_tensor.shape) - print("!!! b_seq_start_loc", b_seq_start_loc) - print("!!! seq_len_tensor", seq_len_tensor) - print("!!! lora_indices_tensor", lora_indices_tensor) - print("!!! batch_size", batches) - print("!!! max_seq_length", max_seq_length) - print("!!! scaling", scaling) + # print("!!! 
inputs", inputs.shape) + # print("!!! lora_a_weights", lora_a_weights.shape) + # print("!!! output_tensor", output_tensor.shape) + # print("!!! b_seq_start_loc", b_seq_start_loc) + # print("!!! seq_len_tensor", seq_len_tensor) + # print("!!! lora_indices_tensor", lora_indices_tensor) + # print("!!! batch_size", batches) + # print("!!! max_seq_length", max_seq_length) + # print("!!! scaling", scaling) + assert inputs.dtype == lora_a_weights.dtype assert inputs.dtype in [torch.float16, torch.bfloat16] assert lora_a_weights.dtype in [ From f68d2c02e3956b381d34613a73bef1a6d1e56e45 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 24 Oct 2024 14:29:32 -0700 Subject: [PATCH 13/76] Remove debug --- server/lorax_server/models/flash_causal_lm.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index bad97276d..2faffda6a 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -1630,8 +1630,6 @@ def generate_token( current_prefilling_mask_tensor, ) - print("all_input_ids", batch.all_input_ids_tensor.shape) - with timer(f"generate_token::prefill_logprobs"): if prefill and finished_prefilling: # Used to gather prefill logprobs From 2ffc1db0503aa038bdcb09b0358f797d66be0d96 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 24 Oct 2024 14:37:32 -0700 Subject: [PATCH 14/76] Move init to warmup --- server/lorax_server/models/flash_causal_lm.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 2faffda6a..7e69fabf0 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -1109,11 +1109,7 @@ def __init__( ) self.steps = 0 - self.punica_wrapper = PunicaWrapper( - max_num_batched_tokens=10000, - max_batches=128, - device=self.device, - ) + self.punica_wrapper = None @property def block_size(self) -> int: @@ -1186,6 +1182,12 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model # The warmup batch is the biggest batch we could ever receive max_total_tokens = batch.max_input_length + max_new_tokens + get_speculative_tokens() + self.punica_wrapper = PunicaWrapper( + max_num_batched_tokens=get_max_prefill_tokens(), + max_batches=256, # TODO(travis): consider how to handle this if we exceed this limit + device=self.device, + ) + torch.cuda.empty_cache() try: self.init_kv_cache( From ea6c86dc865b5d5b7caa20c4d9ab84361467b264 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 24 Oct 2024 15:05:06 -0700 Subject: [PATCH 15/76] Fix preloaded and speculators --- server/lorax_server/models/model.py | 19 +++++++++++-------- server/lorax_server/server.py | 10 +++++----- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py index 7d38de04d..1c9c9ff3f 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ -243,6 +243,7 @@ def register_preloaded_adapters( if preloaded_adapters is None: return + self.dynamic_adapter_loading_enabled = False self.preloaded_adapter_indices.update({adapter.adapter_index for adapter in preloaded_adapters}) self.preloaded_adapter_memory_fractions.update( { @@ -271,6 +272,10 @@ def register_preloaded_adapters( else: # only applicable to lora for now continue + + if adapter_weights is None: + # no weights 
for this layer + continue # transpose into col major lora_a = adapter_weights.weights_a.transpose(1, 2) @@ -318,10 +323,9 @@ def load_adapter( if dynamic and not self.dynamic_adapter_loading_enabled: raise ValueError( - f"This model was initialized with the adapter {self.static_adapter_id} " - f"and therefore does not support dynamic adapter loading. " - f"Please initialize a new model instance from the base model in " - f"order to use the dynamic adapter loading feature." + f"This model does not support dynamic adapter loading. " + f"Please initialize a new model instance from the base model and remove preloaded adapters " + f"to use the dynamic adapter loading feature." ) logger.info(f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}") @@ -400,10 +404,9 @@ def offload_adapter( if not self.dynamic_adapter_loading_enabled: raise ValueError( - f"This model was initialized with the adapter {self.static_adapter_id} " - f"and therefore does not support dynamic adapter loading. " - f"Please initialize a new model instance from the base model in " - f"order to use the dynamic adapter loading feature." + f"This model does not support dynamic adapter loading. " + f"Please initialize a new model instance from the base model and remove preloaded adapters " + f"to use the dynamic adapter loading feature." ) for layer_name in self.adapter_layers: diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index 99f164856..5210440d7 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -336,6 +336,11 @@ async def serve_inner( create_exllama_buffers() except ImportError: pass + + # set speculative decoding tokens + speculative_tokens = max(model.max_speculative_tokens, speculative_tokens) + if speculative_tokens > 0: + set_speculative_tokens(speculative_tokens) if preloaded_adapter_ids: logger.info(f"Preloading {len(preloaded_adapter_ids)} adapters") @@ -408,11 +413,6 @@ def load_adapter(adapter_info: generate_pb2.PreloadedAdapter) -> bool: adapter_memory_fractions = [r.memory_fraction for r in download_responses] model.register_preloaded_adapters(preloaded_adapters, adapter_memory_fractions) - # set speculative decoding tokens - speculative_tokens = max(model.max_speculative_tokens, speculative_tokens) - if speculative_tokens > 0: - set_speculative_tokens(speculative_tokens) - server = aio.server( interceptors=[ ExceptionInterceptor(), From 0497a766538d646db36e1dacbdf797044001cae6 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 24 Oct 2024 15:27:17 -0700 Subject: [PATCH 16/76] Docker test --- .github/workflows/build.yaml | 6 ++---- server/lorax_server/utils/graph.py | 1 + 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 13b9e96ca..d53289240 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -5,6 +5,7 @@ on: push: branches: - 'main' + - 'optimizations' tags: - 'v*' @@ -69,10 +70,7 @@ jobs: images: | ghcr.io/predibase/lorax tags: | - type=semver,pattern={{version}} - type=semver,pattern={{major}}.{{minor}} - type=sha,prefix=,suffix=,format=short - type=raw,value=main,enable=${{ github.ref == 'refs/heads/main' }} + type=raw,value=optimizations,enable=${{ github.ref == 'refs/heads/optimizations' }} - name: Create a hash from tags env: diff --git a/server/lorax_server/utils/graph.py b/server/lorax_server/utils/graph.py index 8d79c6abb..c08d06566 100644 --- a/server/lorax_server/utils/graph.py +++ 
b/server/lorax_server/utils/graph.py @@ -267,6 +267,7 @@ def trace( block_tables_ptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) last_page_len = torch.ones(batch_size, dtype=torch.int32, device=device) + state = create_decode_state_cuda_graphs( device=max_input_state.input_ids.device, block_tables=block_tables, From 9e2a29d339ef32410c1b6106d1560ef438b92f43 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 24 Oct 2024 15:35:17 -0700 Subject: [PATCH 17/76] Profiling docs --- docs/guides/contributing/index.md | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/docs/guides/contributing/index.md b/docs/guides/contributing/index.md index 5c78b97f4..5d9eeef27 100644 --- a/docs/guides/contributing/index.md +++ b/docs/guides/contributing/index.md @@ -23,3 +23,22 @@ make export-requirements ``` Never modify `requirements.txt` directly, as it may introduce dependency conflicts. + +## Profiling + +LoRAX supports the [PyTorch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html) to measure performance of LoRAX. + +You can enable profiling when launching LoRAX by setting the `LORAX_PROFILER_DIR` environment variable to the directory +you wish to output the Tensorboard traces to. + +Once initialized, LoRAX will begin recording traces for every request to the server. Because traces can get very large, +we record only the first 10 prefill requests (plus any decode requests between them), then stop recording and write +out the results. A summary will be printed to stdout when this occurs. + +Once you have your traces written to the profiler directory, you can visualize them in Tensorboard using the +[PyTorch Profiler Tensorboard Plugin](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html). 
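+
+For example (the directory below is just a placeholder), export the variable before starting the server so the
+traces end up where Tensorboard can find them, then launch LoRAX as you normally would from the same shell:
+
+```bash
+# enable profiling; traces will be written under this directory (example path)
+export LORAX_PROFILER_DIR=/data/lorax-profiles
+```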
+ +```bash +pip install torch_tb_profiler +tensorboard --logdir=$LORAX_PROFILER_DIR +``` From 94e37424ab4542faa5512b73331bae9f89bcab85 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 25 Oct 2024 14:28:03 -0700 Subject: [PATCH 18/76] Revert timings --- .github/workflows/build.yaml | 6 +- server/lorax_server/models/flash_causal_lm.py | 1122 ++++++++--------- server/lorax_server/server.py | 110 +- server/lorax_server/utils/tokens.py | 2 + 4 files changed, 561 insertions(+), 679 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index d53289240..13b9e96ca 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -5,7 +5,6 @@ on: push: branches: - 'main' - - 'optimizations' tags: - 'v*' @@ -70,7 +69,10 @@ jobs: images: | ghcr.io/predibase/lorax tags: | - type=raw,value=optimizations,enable=${{ github.ref == 'refs/heads/optimizations' }} + type=semver,pattern={{version}} + type=semver,pattern={{major}}.{{minor}} + type=sha,prefix=,suffix=,format=short + type=raw,value=main,enable=${{ github.ref == 'refs/heads/main' }} - name: Create a hash from tags env: diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 7e69fabf0..0d618ac08 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -30,7 +30,6 @@ from lorax_server.utils.dist import MEMORY_FRACTION, MEMORY_WIGGLE_ROOM, initialize_torch_distributed from lorax_server.utils.graph import GraphCache from lorax_server.utils.import_utils import get_cuda_free_memory -from lorax_server.utils.profiler import timer from lorax_server.utils.segments import SegmentConcatBuilder, find_segments from lorax_server.utils.sources import HUB from lorax_server.utils.sources.hub import weight_files @@ -1451,643 +1450,530 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> def generate_token( self, batch: FlashCausalLMBatch, is_warmup: bool = False ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]: - stage_str = "prefill" if batch.prefilling else "decode" - with timer(f"{stage_str}::generate_token::pre_forward"): - prefill = batch.prefilling - if prefill: - batch.prepare_for_prefill() - prefill_logprobs = batch.prefill_next_token_indices is not None - return_alternatives = any(req.parameters.return_k_alternatives > 0 for req in batch.requests) - - # Update adapter indices for speculative tokens (if present) - adapter_meta = batch.adapter_meta - if batch.speculative_ids is not None: - B, speculative_length = batch.speculative_ids.shape - new_length = speculative_length + 1 - adapter_indices = adapter_meta.adapter_indices.unsqueeze(-1).expand(B, new_length).reshape(-1) - adapter_segments = adapter_meta.adapter_segments * new_length - adapter_meta = AdapterBatchMetadata( - adapter_indices=adapter_indices, - adapter_list=adapter_meta.adapter_list, - adapter_set=adapter_meta.adapter_set, - adapter_segments=adapter_segments, - segment_indices=adapter_meta.segment_indices, - ) - - # Assign pointers to adapter weights - # TODO(travis): don't update this if indices haven't changed - self.punica_wrapper.update_metadata(adapter_meta, prefill) - adapter_data = AdapterBatchData.from_meta( - adapter_meta, - self.layer_to_adapter_weights, - self.layer_to_lora_weights, - self.punica_wrapper, - prefill, - batch.prefill_head_indices + prefill = batch.prefilling + if prefill: + batch.prepare_for_prefill() + prefill_logprobs = batch.prefill_next_token_indices is not None 
+ return_alternatives = any(req.parameters.return_k_alternatives > 0 for req in batch.requests) + + # Update adapter indices for speculative tokens (if present) + adapter_meta = batch.adapter_meta + if batch.speculative_ids is not None: + B, speculative_length = batch.speculative_ids.shape + new_length = speculative_length + 1 + adapter_indices = adapter_meta.adapter_indices.unsqueeze(-1).expand(B, new_length).reshape(-1) + adapter_segments = adapter_meta.adapter_segments * new_length + adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_list=adapter_meta.adapter_list, + adapter_set=adapter_meta.adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_meta.segment_indices, ) - with timer(f"{stage_str}::generate_token::forward"): - out, speculative_logits = self.forward(batch, adapter_data) + # Assign pointers to adapter weights + # TODO(travis): don't update this if indices haven't changed + self.punica_wrapper.update_metadata(adapter_meta, prefill) + adapter_data = AdapterBatchData.from_meta( + adapter_meta, + self.layer_to_adapter_weights, + self.layer_to_lora_weights, + self.punica_wrapper, + prefill, + batch.prefill_head_indices + ) - with timer(f"{stage_str}::generate_token::post_forward"): - if prefill: - next_token_logits = out[batch.prefill_next_token_indices] if prefill_logprobs else out - if speculative_logits is not None: - speculative_logits = ( - speculative_logits[batch.prefill_next_token_indices] if prefill_logprobs else speculative_logits - ) - if len(batch) > 1 and prefill_logprobs: - # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs - # When batch == 1, we will just use the batch.input_ids values directly - prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) - else: - prefill_logprobs = None - next_token_logits = out - next_adapter_indices = batch.adapter_meta.adapter_indices - - finished_prefilling = True - next_chunk_lengths = [] - current_prefilling_mask = batch.prefilling_mask - current_prefilling_mask_tensor = batch.prefilling_mask_tensor - if prefill: - if get_supports_chunking(): - next_prefilling_mask = [] - # Budget in tokens for the next batch - # We remove (len(batch) - 1) to always have enough space for at least a single decode - # for the remaining requests -1 because the first request does not need to be removed from the budget - # (ex: you have one request in the batch, you want it to take the full budget not budget -1) - batch_budget = get_max_prefill_tokens() - (len(batch) - 1) - # We reverse to prioritize older requests - # zip() is not reversible so reverse the underlying lists instead - for cache_length, input_length, prompt_length in zip( - reversed(batch.cache_lengths), - reversed(batch.input_lengths), - reversed(batch.prompt_lengths), - ): - remaining_prefill_tokens = max(prompt_length - cache_length - input_length, 0) - if remaining_prefill_tokens > 0: - next_chunk_length = max(min(remaining_prefill_tokens, batch_budget), 1) - batch_budget -= next_chunk_length - finished_prefilling = False - next_prefilling_mask.append(True) - else: - # FIXME: use true number of accepted tokens instead of 1 - # Since speculation will be turned off, this is always true - next_chunk_length = 1 - next_prefilling_mask.append(False) - next_chunk_lengths.append(next_chunk_length) - - # Reverse back the obtained values² - next_chunk_lengths.reverse() - next_prefilling_mask.reverse() - - batch.prefilling_mask_tensor = torch.tensor(next_prefilling_mask, 
device=batch.all_input_ids_tensor.device) - else: - # The model does not support chunking - # We know we only do a single prefill - finished_prefilling = True - next_prefilling_mask = [False] * len(batch) - batch.prefilling_mask_tensor = None + # Assign pointers to adapter weights + # TODO(travis): don't update this if indices haven't changed + adapter_data = AdapterBatchData.from_meta( + adapter_meta, self.layer_to_adapter_weights, prefill, batch.prefill_head_indices + ) - batch.prefilling = not finished_prefilling - batch.prefilling_mask = next_prefilling_mask + out, speculative_logits = self.forward(batch, adapter_data) - with timer(f"generate_token::next_token_chooser"): - speculative_tokens = get_speculative_tokens() - ( - next_input_ids, - next_token_logprobs, - accepted_ids, - speculative_ids, - ) = batch.next_token_chooser( - batch.all_input_ids_tensor[:, : batch.max_current_length], - next_token_logits, - speculative_tokens, - batch.speculative_ids, - speculative_logits, + if prefill: + next_token_logits = out[batch.prefill_next_token_indices] if prefill_logprobs else out + if speculative_logits is not None: + speculative_logits = ( + speculative_logits[batch.prefill_next_token_indices] if prefill_logprobs else speculative_logits ) + if len(batch) > 1 and prefill_logprobs: + # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs + # When batch == 1, we will just use the batch.input_ids values directly + prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) + else: + prefill_logprobs = None + next_token_logits = out + next_adapter_indices = batch.adapter_meta.adapter_indices + + finished_prefilling = True + next_chunk_lengths = [] + current_prefilling_mask = batch.prefilling_mask + if prefill: + if get_supports_chunking(): + next_prefilling_mask = [] + # Budget in tokens for the next batch + # We remove (len(batch) - 1) to always have enough space for at least a single decode + # for the remaining requests -1 because the first request does not need to be removed from the budget + # (ex: you have one request in the batch, you want it to take the full budget not budget -1) + batch_budget = get_max_prefill_tokens() - (len(batch) - 1) + # We reverse to prioritize older requests + # zip() is not reversible so reverse the underlying lists instead + for cache_length, input_length, prompt_length in zip( + reversed(batch.cache_lengths), + reversed(batch.input_lengths), + reversed(batch.prompt_lengths), + ): + remaining_prefill_tokens = max(prompt_length - cache_length - input_length, 0) + if remaining_prefill_tokens > 0: + next_chunk_length = max(min(remaining_prefill_tokens, batch_budget), 1) + batch_budget -= next_chunk_length + finished_prefilling = False + next_prefilling_mask.append(True) + else: + # FIXME: use true number of accepted tokens instead of 1 + # Since speculation will be turned off, this is always true + next_chunk_length = 1 + next_prefilling_mask.append(False) + next_chunk_lengths.append(next_chunk_length) + + # Reverse back the obtained values² + next_chunk_lengths.reverse() + next_prefilling_mask.reverse() + else: + # The model does not support chunking + # We know we only do a single prefill + finished_prefilling = True + next_prefilling_mask = [False] * len(batch) - if return_alternatives: - alternative_token_logprobs, alternative_token_ids = torch.sort( - torch.log_softmax(next_token_logits, -1), dim=-1, stable=True, descending=True - ) + batch.prefilling = not finished_prefilling + batch.prefilling_mask = 
next_prefilling_mask + + speculative_tokens = get_speculative_tokens() + ( + next_input_ids, + next_token_logprobs, + accepted_ids, + speculative_ids, + ) = batch.next_token_chooser( + batch.all_input_ids_tensor[:, : batch.max_current_length], + next_token_logits, + speculative_tokens, + batch.speculative_ids, + speculative_logits, + ) - with timer(f"{stage_str}::generate_token::new_empty"): - # Since we are done prefilling, all the tensors that were concatenating values for all the requests - # instantly become of shape [BATCH_SIZE] - if prefill and finished_prefilling: - next_position_ids = batch.position_ids.new_empty(len(batch)) - batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1] - next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(len(batch)) - elif not prefill: - next_position_ids = batch.position_ids - - # Zipped iterator - iterator = zip( - batch.requests, - batch.prompt_lengths, - batch.cache_lengths, - batch.input_lengths, - batch.all_input_ids, - accepted_ids, - current_prefilling_mask, - batch.prefilling_mask, + if return_alternatives: + alternative_token_logprobs, alternative_token_ids = torch.sort( + torch.log_softmax(next_token_logits, -1), dim=-1, stable=True, descending=True ) - # We do two for loops as the first one can run completely asynchronously from the GPU while for the second - # one, we need to first do a GPU <-> CPU sync - # It is faster if we delay this sync for the maximum amount of time - - with timer(f"generate_token::cumulative_length"): - if prefill and finished_prefilling: - # Discard first elem, which is 0 - with timer(f"generate_token::cumulative_length::end_index"): - end_index = batch.cu_seqlen_prefill[1:] - - # Initialize position_ids - # In decode, we do not need this as we can just increment position ids - with timer(f"generate_token::cumulative_length::next_position_ids"): - next_position_ids[:] = batch.position_ids[end_index - 1] - - # Initialize adapter indices - # In decode, we only have one token per row in the batch, so grab last index - with timer(f"generate_token::cumulative_length::adapter_indices"): - next_adapter_indices[:] = batch.adapter_meta.adapter_indices[end_index - 1] - - # if not request_is_prefilling: - # # Only save tokens if we are done prefilling for this request - # for j in range(n_accepted_ids): - # batch.all_input_ids_tensor[i, cache_length + input_length + j] = next_input_ids[index + j] + # Since we are done prefilling, all the tensors that were concatenating values for all the requests + # instantly become of shape [BATCH_SIZE] + if prefill and finished_prefilling: + next_position_ids = batch.position_ids.new_empty(len(batch)) + batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1] + next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(len(batch)) + elif not prefill: + next_position_ids = batch.position_ids + + # Zipped iterator + iterator = zip( + batch.requests, + batch.prompt_lengths, + batch.cache_lengths, + batch.input_lengths, + batch.all_input_ids, + accepted_ids, + current_prefilling_mask, + batch.prefilling_mask, + ) - # Only save tokens if we are done prefilling for this request - with timer(f"generate_token::cumulative_length::offsets"): - offsets = batch.cache_lengths_tensor + batch.input_lengths_tensor - - with timer(f"generate_token::cumulative_length::update_all_input_ids_tensor"): - batch.all_input_ids_tensor = update_all_input_ids_tensor( - accepted_ids, - batch.all_input_ids_tensor, - offsets, - next_input_ids, - 
current_prefilling_mask_tensor, - ) + # We do two for loops as the first one can run completely asynchronously from the GPU while for the second + # one, we need to first do a GPU <-> CPU sync + # It is faster if we delay this sync for the maximum amount of time - with timer(f"generate_token::prefill_logprobs"): - if prefill and finished_prefilling: - # Used to gather prefill logprobs - # Copy batch.all_input_ids_tensor to prefill_token_indices - for i, request in enumerate(batch.requests): - request_was_prefilling = current_prefilling_mask[i] - if request.prefill_logprobs and request_was_prefilling: - # For each member of the batch - index = 0 - - # TODO(travis): tons of d2h copies here make this super slow, should vectorize or do transfer - # up front - cache_length = batch.cache_lengths[i] - input_length = batch.input_lengths[i] - n_accepted_ids = accepted_ids[index] - - print("!!! prefill_logprobs") - # Indexing metadata - out_start_index = batch.prefill_cu_outlens[i] - out_end_index = batch.prefill_cu_outlens[i + 1] - - # Logprobs generated by the model are for the next token - # So we need to translate the id tensor by 1 - ids = batch.all_input_ids_tensor[i, cache_length + 1 : cache_length + input_length + 1] - if len(batch) > 1: - prefill_tokens_indices[out_start_index:out_end_index] = ids - else: - # Set prefill_tokens_indices to the correct slice - prefill_tokens_indices = ids - - with timer(f"generate_token::update_values"): - # Update values - # These values can be updated without a GPU -> CPU sync - if not prefill or (prefill and finished_prefilling): - batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] - batch.speculative_ids = speculative_ids - batch.position_ids = next_position_ids + accepted_ids - batch.cache_lengths_tensor += batch.input_lengths_tensor - batch.input_lengths_tensor = accepted_ids.to(dtype=torch.int32) - batch.slot_indices += accepted_ids - batch.adapter_meta.adapter_indices = next_adapter_indices - - if prefill and prefill_logprobs: - # Get prefill logprobs with inplace softmax (avoid copying the `out` tensor (max_batch_prefill_tokens * vocab_size)) - torch.log_softmax(out, -1, out=out) - prefill_logprobs_tensor = out - prefill_logprobs = torch.gather(prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)) - # GPU <-> CPU sync - prefill_logprobs = prefill_logprobs.view(-1).tolist() - - with timer(f"generate_token::find_segments"): - # Does a GPU <-> CPU sync internally - if prefill and finished_prefilling: - # adjust segment lengths to account for all request lengths being 1 during decoding - adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices) - batch.adapter_meta.adapter_segments = torch.tensor( - adapter_segments, - dtype=torch.int32, - device=batch.adapter_meta.adapter_segments.device, - ) + # For each member of the batch + index = 0 + # Cumulative length + cumulative_length = 0 + for i, ( + request, + prompt_length, + cache_length, + input_length, + all_input_ids, + n_accepted_ids, + request_was_prefilling, + request_is_prefilling, + ) in enumerate(iterator): + if prefill and finished_prefilling: + # Indexing metadata + _start_index = cumulative_length + end_index = cumulative_length + input_length + + # Initialize position_ids + # In decode, we do not need this as we can just increment position ids + next_position_ids[i] = batch.position_ids[end_index - 1] + + # Initialize adapter indices + # In decode, we only have one token per row in the batch, so grab last index + next_adapter_indices[i] = 
batch.adapter_meta.adapter_indices[end_index - 1] + + # Used to gather prefill logprobs + # Copy batch.all_input_ids_tensor to prefill_token_indices + if request.prefill_logprobs and request_was_prefilling: + # Indexing metadata + out_start_index = batch.prefill_cu_outlens[i] + out_end_index = batch.prefill_cu_outlens[i + 1] + + # Logprobs generated by the model are for the next token + # So we need to translate the id tensor by 1 + ids = batch.all_input_ids_tensor[i, cache_length + 1 : cache_length + input_length + 1] + if len(batch) > 1: + prefill_tokens_indices[out_start_index:out_end_index] = ids + else: + # Set prefill_tokens_indices to the correct slice + prefill_tokens_indices = ids - with timer(f"{stage_str}::generate_token::d2h"): - # GPU <-> CPU sync - next_token_logprobs = next_token_logprobs.tolist() - next_token_ids = next_input_ids.tolist() - accepted_ids = accepted_ids.tolist() - - if return_alternatives: - alternative_token_logprobs = alternative_token_logprobs.tolist() - alternative_token_ids = alternative_token_ids.tolist() - - # Update values if we need to continue prefilling - # This represents the `else` case of the `Update values` if above - # but since this require the `next_token_ids` to be on CPU, it is better to do it here - if prefill and not finished_prefilling: - with timer(f"generate_token::update_values_continue_prefill"): - # Speculation must be ignored while we prefill even with chunking - # it simplifies everything - assert batch.speculative_ids is None - - all_postfix_ids = [] - for i, ( - request_prefilling, - next_token_id, - all_input_ids, - cache_length, - input_length, - next_chunk_length, - ) in enumerate( - zip( - batch.prefilling_mask, - next_token_ids, - batch.all_input_ids, - batch.cache_lengths, - batch.input_lengths, - next_chunk_lengths, - ) - ): - if request_prefilling: - next_cache_length = cache_length + input_length - # Get new prompt IDs to prefill - postfix_ids = all_input_ids[next_cache_length : next_cache_length + next_chunk_length] - else: - # This request is done prefilling, the new id is the one selected the sampling method - postfix_ids = [next_token_id] - - all_postfix_ids.append(postfix_ids) - - batch.input_ids = all_postfix_ids - - with timer(f"generate_token::get_results"): - # Results - generations: List[Generation] = [] - stopped = not is_warmup - - # Zipped iterator - iterator = zip( - batch.requests, - batch.prompt_lengths, + if not request_is_prefilling: + # Only save tokens if we are done prefilling for this request + for j in range(n_accepted_ids): + batch.all_input_ids_tensor[i, cache_length + input_length + j] = next_input_ids[index + j] + + batch.all_input_ids_tensor[i, input_length] = next_input_ids[i] + + index += n_accepted_ids + cumulative_length += input_length + + # Update values + # These values can be updated without a GPU -> CPU sync + if not prefill or (prefill and finished_prefilling): + batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] + batch.speculative_ids = speculative_ids + batch.position_ids = next_position_ids + accepted_ids + batch.cache_lengths_tensor += batch.input_lengths_tensor + batch.input_lengths_tensor = accepted_ids.to(dtype=torch.int32) + batch.slot_indices += accepted_ids + batch.adapter_meta.adapter_indices = next_adapter_indices + + if prefill and prefill_logprobs: + # Get prefill logprobs with inplace softmax (avoid copying the `out` tensor (max_batch_prefill_tokens * vocab_size)) + torch.log_softmax(out, -1, out=out) + prefill_logprobs_tensor = out + prefill_logprobs 
= torch.gather(prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)) + # GPU <-> CPU sync + prefill_logprobs = prefill_logprobs.view(-1).tolist() + + # Does a GPU <-> CPU sync internally + if prefill and finished_prefilling: + # adjust segment lengths to account for all request lengths being 1 during decoding + adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices) + batch.adapter_meta.adapter_segments = torch.tensor( + adapter_segments, + dtype=torch.int32, + device=batch.adapter_meta.adapter_segments.device, + ) + + # GPU <-> CPU sync + next_token_logprobs = next_token_logprobs.tolist() + next_token_ids = next_input_ids.tolist() + accepted_ids = accepted_ids.tolist() + + if return_alternatives: + alternative_token_logprobs = alternative_token_logprobs.tolist() + alternative_token_ids = alternative_token_ids.tolist() + + # Update values if we need to continue prefilling + # This represents the `else` case of the `Update values` if above + # but since this require the `next_token_ids` to be on CPU, it is better to do it here + if prefill and not finished_prefilling: + # Speculation must be ignored while we prefill even with chunking + # it simplifies everything + assert batch.speculative_ids is None + + all_postfix_ids = [] + for i, ( + request_prefilling, + next_token_id, + all_input_ids, + cache_length, + input_length, + next_chunk_length, + ) in enumerate( + zip( + batch.prefilling_mask, + next_token_ids, + batch.all_input_ids, batch.cache_lengths, batch.input_lengths, - batch.prefix_offsets, - batch.read_offsets, - batch.stopping_criterias, - batch.all_input_ids, - batch.next_token_chooser.do_sample, - batch.next_token_chooser.seeds, - current_prefilling_mask, - batch.prefilling_mask, - accepted_ids, + next_chunk_lengths, ) + ): + if request_prefilling: + next_cache_length = cache_length + input_length + # Get new prompt IDs to prefill + postfix_ids = all_input_ids[next_cache_length : next_cache_length + next_chunk_length] + else: + # This request is done prefilling, the new id is the one selected the sampling method + postfix_ids = [next_token_id] + + all_postfix_ids.append(postfix_ids) + + batch.input_ids = all_postfix_ids + + # Results + generations: List[Generation] = [] + stopped = not is_warmup + + # Zipped iterator + iterator = zip( + batch.requests, + batch.prompt_lengths, + batch.cache_lengths, + batch.input_lengths, + batch.prefix_offsets, + batch.read_offsets, + batch.stopping_criterias, + batch.all_input_ids, + batch.next_token_chooser.do_sample, + batch.next_token_chooser.seeds, + current_prefilling_mask, + batch.prefilling_mask, + accepted_ids, + ) - # Reset max_input_length - batch.max_input_length = 0 - # For each member of the batch - index = 0 - for i, ( - request, - prompt_length, - cache_length, - input_length, - prefix_offset, - read_offset, - stopping_criteria, - all_input_ids, - do_sample, - seed, - request_was_prefilling, - request_is_prefilling, - n_accepted_ids, - ) in enumerate(iterator): - all_alternative_tokens = [] if request.parameters.return_k_alternatives > 0 else None - - # TODO(travis): return_k_alternatives - # if request.parameters.return_k_alternatives > 0: - # # Limit the number of alternatives to the vocabulary size - # num_alternatives = min( - # request.parameters.return_k_alternatives, - # len(alternative_token_ids[token_idx]), - # ) - - # # Select top-k logprobs - # request_alternative_token_ids = alternative_token_ids[token_idx][:num_alternatives] - # request_alternative_token_logprobs = 
alternative_token_logprobs[token_idx][:num_alternatives] - - # # Decode tokens - # request_alternative_token_texts = [] - # for alternative_token_id in request_alternative_token_ids: - # all_input_ids.append(alternative_token_id) - # alternative_token_text, _, _ = self.decode_token( - # all_input_ids, - # prefix_offset, - # read_offset, - # ) - # request_alternative_token_texts.append(alternative_token_text) - # all_input_ids.pop() - # alternative_tokens = AlternativeTokens( - # request_alternative_token_ids, - # request_alternative_token_logprobs, - # request_alternative_token_texts, - # ) - # all_alternative_tokens.append(alternative_tokens) - - # Compute logprobs first as, even though we might skip the token, - # it can still be required to compute the logprobs - # modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need - # this state to be stable - if request.id % self.world_size == self.rank: - # Prefill - if request_was_prefilling and request.prefill_logprobs: - out_start_index = batch.prefill_cu_outlens[i] - out_end_index = batch.prefill_cu_outlens[i + 1] - if not request_is_prefilling: - # The request is dones prefilling, meaning that we started generating new tokens - # The last logprob is a logprob for a generated token that was not part of the prompt - # We need to remove it - out_end_index -= 1 - - request_prefill_logprobs = prefill_logprobs[out_start_index:out_end_index] - # Logprobs generated by the model are for the next token - # So we need to translate the id tensor by 1 - prefill_token_ids = all_input_ids[cache_length + 1 : cache_length + input_length + 1] - - past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i] - - if past_prefill_logprob_tokens is None: - # add nan for cached prompt tokens/first token - request_prefill_logprobs = [float("nan")] * (cache_length + 1) + request_prefill_logprobs - prefill_token_ids = all_input_ids[: cache_length + 1] + prefill_token_ids - - prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - - prefill_logprob_tokens = NextTokens( - prefill_token_ids, - request_prefill_logprobs, - prefill_texts, - [], - all_alternative_tokens, - ) - if past_prefill_logprob_tokens is not None: - prefill_logprob_tokens = past_prefill_logprob_tokens + prefill_logprob_tokens - - batch.prefill_logprob_tokens[i] = prefill_logprob_tokens - else: - batch.prefill_logprob_tokens[i] = None - - # If it is, the tokens we decoded should be ignored - if request_is_prefilling: - # Make sure that we do not stop as even though this request did not create a token, it is still - # processing - stopped = False - new_input_length = next_chunk_lengths[i] - else: - new_input_length = n_accepted_ids - # Append next token to all tokens - next_token_texts = [] - left = 0 - - if n_accepted_ids > 1: - logger.debug(f"speculated ids {n_accepted_ids - 1}") - - current_stopped = False - for j in range(index, index + n_accepted_ids): - # Generated token - next_token_id = next_token_ids[j] - all_input_ids.append(next_token_id) - next_token_text, prefix_offset, read_offset = self.decode_token( - all_input_ids, - prefix_offset, - read_offset, - ) - next_token_texts.append(next_token_text) - - stop, reason = stopping_criteria( - next_token_id, - next_token_text, - ) - - if stop: - left = index + n_accepted_ids - j - 1 - current_stopped = True - break - else: - current_stopped = False - stopped = stopped and current_stopped - - _next_token_ids = 
next_token_ids[index : index + n_accepted_ids - left] - _next_token_logprobs = next_token_logprobs[index : index + n_accepted_ids - left] - - # Shard generations - # All generations will be appended in the rust sharded client - if request.id % self.world_size == self.rank: - if stop: - # Decode generated tokens - output_text, _, _ = self.decode_token( - all_input_ids, - prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1, - read_offset=len(all_input_ids) - stopping_criteria.current_tokens, - skip_special_tokens=True, - ) - generated_text = GeneratedText( - output_text, - stopping_criteria.current_tokens, - reason, - seed if do_sample else None, - ) - else: - generated_text = None - - # TODO(travis): top tokens - # if top_n_tokens > 0: - # all_top_tokens = [] - # for top_token_ids, top_token_logprobs in zip( - # top_token_ids, top_token_logprobs - # ): - # toptoken_texts = self.tokenizer.batch_decode( - # top_token_ids, - # clean_up_tokenization_spaces=False, - # skip_special_tokens=False, - # ) - # special_toptokens = [ - # token_id in self.all_special_ids - # for token_id in top_token_ids - # ] - # top_tokens = Tokens( - # top_token_ids, - # top_token_logprobs, - # toptoken_texts, - # special_toptokens, - # ) - # all_top_tokens.append(top_tokens) - # top_tokens = all_top_tokens - # else: - # top_tokens = None - - generation = Generation( - request.id, - batch.prefill_logprob_tokens[i], - len(all_input_ids[:-1]) if prefill else 0, - NextTokens( - _next_token_ids, - _next_token_logprobs, - next_token_texts, - [nid in self.all_special_ids for nid in _next_token_ids], - all_alternative_tokens, - ), - generated_text, - ) - - generations.append(generation) - - # advance the FSM for each accepted token (as we may have more than one from speculative decoding) - for next_token_id in _next_token_ids: - batch.next_token_chooser.next_state(i, next_token_id) - - with timer(f"generate_token::update_remaining_values"): - # Update values - index += n_accepted_ids - current_cache_length = cache_length + input_length - batch.cache_lengths[i] = current_cache_length - current_input_length = new_input_length - batch.max_input_length = max(batch.max_input_length, current_input_length) - batch.input_lengths[i] = current_input_length - current_length = current_cache_length + current_input_length - batch.max_current_length = max(batch.max_current_length, current_length) - - batch.prefix_offsets[i] = prefix_offset - batch.read_offsets[i] = read_offset - batch.all_input_ids[i] = all_input_ids - - if stopped: - # No need to return a batch if we know that all requests stopped - return generations, None - - if prefill and finished_prefilling: - # We do not need prefill tensors anymore - batch.cu_seqlen_prefill = None - batch.prefill_cache_indices = None - batch.prefill_cu_outlens = None - batch.prefill_head_indices = None - batch.prefill_next_token_indices = None - - return generations, batch - - -def update_all_input_ids_tensor( - accepted_ids, - all_input_ids_tensor, - offsets, - next_input_ids, - current_prefilling_mask -): - # Get batch size - batch_size = all_input_ids_tensor.size(0) - # Calculate S (number of candidate tokens per batch) - S = next_input_ids.size(0) // batch_size - - # Reshape next_input_ids to [batch_size, S] - next_input_ids = next_input_ids.view(batch_size, S) - - # Since accepted_ids is always 1, we only need the first candidate token for each batch - values = next_input_ids[:, 0] - - # Update all_input_ids_tensor at the specified positions with the accepted IDs - 
all_input_ids_tensor[torch.arange(batch_size), offsets] = values - - return all_input_ids_tensor - - # NO SPECULATION - # # Get batch size - # batch_size = all_input_ids_tensor.size(0) - # # Calculate S (number of candidate tokens per batch) - # S = next_input_ids.size(0) // batch_size - - # # Reshape next_input_ids to [batch_size, S] - # next_input_ids = next_input_ids.view(batch_size, S) - - # # Select indices of batches to process based on current_prefilling_mask - # batch_indices = torch.nonzero(current_prefilling_mask, as_tuple=False).squeeze(1) - - # # Gather offsets and next_input_ids for the selected batches - # offsets_selected = offsets[batch_indices] - # # Since accepted_ids is always 1, we only need the first candidate token - # values = next_input_ids[batch_indices, 0] - - # # Update all_input_ids_tensor at the specified positions with the accepted IDs - # all_input_ids_tensor[batch_indices, offsets_selected] = values - - # return all_input_ids_tensor - - # FULL - # # Get batch size and compute S (number of candidate tokens per batch) - # batch_size = accepted_ids.size(0) - # S = next_input_ids.size(0) // batch_size - - # # Reshape next_input_ids to [batch_size, S] - # next_input_ids = next_input_ids.view(batch_size, S) - - # # Select indices of batches to process based on the current_prefilling_mask - # batch_indices = torch.nonzero(current_prefilling_mask, as_tuple=False).squeeze(1) - # num_batches = batch_indices.size(0) - - # # Gather the accepted_ids, offsets, and next_input_ids for the selected batches - # accepted_ids_selected = accepted_ids[batch_indices] - # offsets_selected = offsets[batch_indices] - # next_input_ids_selected = next_input_ids[batch_indices] - - # # Determine the maximum number of accepted IDs to pad sequences - # max_accepted_ids = accepted_ids_selected.max() - - # # Create sequence indices offsets for each batch - # seq_indices_offsets = torch.arange(max_accepted_ids, device=accepted_ids.device).unsqueeze(0) - # seq_indices_offsets = seq_indices_offsets.expand(num_batches, -1) - - # # Create a mask to identify valid positions within accepted_ids for each batch - # seq_mask = seq_indices_offsets < accepted_ids_selected.unsqueeze(1) + # Reset max_input_length + batch.max_input_length = 0 + # For each member of the batch + index = 0 + for i, ( + request, + prompt_length, + cache_length, + input_length, + prefix_offset, + read_offset, + stopping_criteria, + all_input_ids, + do_sample, + seed, + request_was_prefilling, + request_is_prefilling, + n_accepted_ids, + ) in enumerate(iterator): + all_alternative_tokens = [] if request.parameters.return_k_alternatives > 0 else None + + # TODO(travis): return_k_alternatives + # if request.parameters.return_k_alternatives > 0: + # # Limit the number of alternatives to the vocabulary size + # num_alternatives = min( + # request.parameters.return_k_alternatives, + # len(alternative_token_ids[token_idx]), + # ) + + # # Select top-k logprobs + # request_alternative_token_ids = alternative_token_ids[token_idx][:num_alternatives] + # request_alternative_token_logprobs = alternative_token_logprobs[token_idx][:num_alternatives] + + # # Decode tokens + # request_alternative_token_texts = [] + # for alternative_token_id in request_alternative_token_ids: + # all_input_ids.append(alternative_token_id) + # alternative_token_text, _, _ = self.decode_token( + # all_input_ids, + # prefix_offset, + # read_offset, + # ) + # request_alternative_token_texts.append(alternative_token_text) + # all_input_ids.pop() + # 
alternative_tokens = AlternativeTokens( + # request_alternative_token_ids, + # request_alternative_token_logprobs, + # request_alternative_token_texts, + # ) + # all_alternative_tokens.append(alternative_tokens) + + # Compute logprobs first as, even though we might skip the token, + # it can still be required to compute the logprobs + # modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need + # this state to be stable + if request.id % self.world_size == self.rank: + # Prefill + if request_was_prefilling and request.prefill_logprobs: + out_start_index = batch.prefill_cu_outlens[i] + out_end_index = batch.prefill_cu_outlens[i + 1] + if not request_is_prefilling: + # The request is dones prefilling, meaning that we started generating new tokens + # The last logprob is a logprob for a generated token that was not part of the prompt + # We need to remove it + out_end_index -= 1 + + request_prefill_logprobs = prefill_logprobs[out_start_index:out_end_index] + # Logprobs generated by the model are for the next token + # So we need to translate the id tensor by 1 + prefill_token_ids = all_input_ids[cache_length + 1 : cache_length + input_length + 1] + + past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i] + + if past_prefill_logprob_tokens is None: + # add nan for cached prompt tokens/first token + request_prefill_logprobs = [float("nan")] * (cache_length + 1) + request_prefill_logprobs + prefill_token_ids = all_input_ids[: cache_length + 1] + prefill_token_ids + + prefill_texts = self.tokenizer.batch_decode( + prefill_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) - # # Calculate the sequence indices where updates will occur - # seq_indices = seq_indices_offsets + offsets_selected.unsqueeze(1) + prefill_logprob_tokens = NextTokens( + prefill_token_ids, + request_prefill_logprobs, + prefill_texts, + [], + all_alternative_tokens, + ) + if past_prefill_logprob_tokens is not None: + prefill_logprob_tokens = past_prefill_logprob_tokens + prefill_logprob_tokens - # # Expand batch indices to align with seq_indices - # batch_indices_expanded = batch_indices.unsqueeze(1).expand(-1, max_accepted_ids) + batch.prefill_logprob_tokens[i] = prefill_logprob_tokens + else: + batch.prefill_logprob_tokens[i] = None + + # If it is, the tokens we decoded should be ignored + if request_is_prefilling: + # Make sure that we do not stop as even though this request did not create a token, it is still + # processing + stopped = False + new_input_length = next_chunk_lengths[i] + else: + new_input_length = n_accepted_ids + # Append next token to all tokens + next_token_texts = [] + left = 0 + + if n_accepted_ids > 1: + logger.debug(f"speculated ids {n_accepted_ids - 1}") + + current_stopped = False + for j in range(index, index + n_accepted_ids): + # Generated token + next_token_id = next_token_ids[j] + all_input_ids.append(next_token_id) + next_token_text, prefix_offset, read_offset = self.decode_token( + all_input_ids, + prefix_offset, + read_offset, + ) + next_token_texts.append(next_token_text) - # # Extract the values to be written into all_input_ids_tensor - # values = next_input_ids_selected[:, :max_accepted_ids] + stop, reason = stopping_criteria( + next_token_id, + next_token_text, + ) - # # Flatten tensors and apply the mask to select valid positions - # batch_indices_flat = batch_indices_expanded.reshape(-1)[seq_mask.reshape(-1)] - # seq_indices_flat = seq_indices.reshape(-1)[seq_mask.reshape(-1)] - # values_flat = 
values.reshape(-1)[seq_mask.reshape(-1)] + if stop: + left = index + n_accepted_ids - j - 1 + current_stopped = True + break + else: + current_stopped = False + stopped = stopped and current_stopped + + _next_token_ids = next_token_ids[index : index + n_accepted_ids - left] + _next_token_logprobs = next_token_logprobs[index : index + n_accepted_ids - left] + + # Shard generations + # All generations will be appended in the rust sharded client + if request.id % self.world_size == self.rank: + if stop: + # Decode generated tokens + output_text, _, _ = self.decode_token( + all_input_ids, + prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1, + read_offset=len(all_input_ids) - stopping_criteria.current_tokens, + skip_special_tokens=True, + ) + generated_text = GeneratedText( + output_text, + stopping_criteria.current_tokens, + reason, + seed if do_sample else None, + ) + else: + generated_text = None + + # TODO(travis): top tokens + # if top_n_tokens > 0: + # all_top_tokens = [] + # for top_token_ids, top_token_logprobs in zip( + # top_token_ids, top_token_logprobs + # ): + # toptoken_texts = self.tokenizer.batch_decode( + # top_token_ids, + # clean_up_tokenization_spaces=False, + # skip_special_tokens=False, + # ) + # special_toptokens = [ + # token_id in self.all_special_ids + # for token_id in top_token_ids + # ] + # top_tokens = Tokens( + # top_token_ids, + # top_token_logprobs, + # toptoken_texts, + # special_toptokens, + # ) + # all_top_tokens.append(top_tokens) + # top_tokens = all_top_tokens + # else: + # top_tokens = None + + generation = Generation( + request.id, + batch.prefill_logprob_tokens[i], + len(all_input_ids[:-1]) if prefill else 0, + NextTokens( + _next_token_ids, + _next_token_logprobs, + next_token_texts, + [nid in self.all_special_ids for nid in _next_token_ids], + all_alternative_tokens, + ), + generated_text, + ) - # # Update all_input_ids_tensor at the specified positions with the accepted IDs - # all_input_ids_tensor[batch_indices_flat, seq_indices_flat] = values_flat + generations.append(generation) + + # advance the FSM for each accepted token (as we may have more than one from speculative decoding) + for next_token_id in _next_token_ids: + batch.next_token_chooser.next_state(i, next_token_id) + + # Update values + index += n_accepted_ids + current_cache_length = cache_length + input_length + batch.cache_lengths[i] = current_cache_length + current_input_length = new_input_length + batch.max_input_length = max(batch.max_input_length, current_input_length) + batch.input_lengths[i] = current_input_length + current_length = current_cache_length + current_input_length + batch.max_current_length = max(batch.max_current_length, current_length) + + batch.prefix_offsets[i] = prefix_offset + batch.read_offsets[i] = read_offset + batch.all_input_ids[i] = all_input_ids + + if stopped: + # No need to return a batch if we know that all requests stopped + return generations, None + + if prefill and finished_prefilling: + # We do not need prefill tensors anymore + batch.cu_seqlen_prefill = None + batch.prefill_cache_indices = None + batch.prefill_cu_outlens = None + batch.prefill_head_indices = None + batch.prefill_next_token_indices = None - # return all_input_ids_tensor + return generations, batch diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index 5210440d7..bb3b99f1a 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -23,7 +23,6 @@ enum_string_to_adapter_source, is_base_model, ) -from 
lorax_server.utils.profiler import timer from lorax_server.utils.sgmv import has_sgmv from lorax_server.utils.state import set_max_prefill_tokens, set_speculative_tokens @@ -95,40 +94,36 @@ async def Warmup(self, request: generate_pb2.WarmupRequest, context): return generate_pb2.WarmupResponse(max_supported_total_tokens=max_supported_total_tokens) async def Prefill(self, request: generate_pb2.PrefillRequest, context): - with timer("prefill::total"): - with timer("prefill::batch::from_pb"): - batch = self.model.batch_type.from_pb( - request.batch, - self.model.tokenizer, - self.model.tokenizers, - self.model.processor, - self.model.model.config, - self.model.dtype, - self.model.device, - ) + batch = self.model.batch_type.from_pb( + request.batch, + self.model.tokenizer, + self.model.tokenizers, + self.model.processor, + self.model.model.config, + self.model.dtype, + self.model.device, + ) - if self.model.supports_chunking: - if request.HasField("cached_batch"): - with timer("prefill::batch::concatenate"): - cached_batch = self.cache.pop(request.cached_batch.id) - if cached_batch is None: - raise ValueError(f"Batch ID {request.cached_batch.id} not found in cache.") - batch = self.model.batch_type.concatenate([cached_batch, batch]) - - with timer("prefill::generate_token"): - generations, next_batch = self.model.generate_token(batch) - self.cache.set(next_batch) + if self.model.supports_chunking: + if request.HasField("cached_batch"): + cached_batch = self.cache.pop(request.cached_batch.id) + if cached_batch is None: + raise ValueError(f"Batch ID {request.cached_batch.id} not found in cache.") + batch = self.model.batch_type.concatenate([cached_batch, batch]) + + generations, next_batch = self.model.generate_token(batch) + self.cache.set(next_batch) - if self.model.profiler: - self.model.steps += 1 - if self.model.steps == 10: - self.model.profiler.stop() - print(self.model.profiler.key_averages()) - - return generate_pb2.PrefillResponse( - generations=[generation.to_pb() for generation in generations], - batch=next_batch.to_pb() if next_batch else None, - ) + if self.model.profiler: + self.model.steps += 1 + if self.model.steps == 10: + self.model.profiler.stop() + print(self.model.profiler.key_averages()) + + return generate_pb2.PrefillResponse( + generations=[generation.to_pb() for generation in generations], + batch=next_batch.to_pb() if next_batch else None, + ) async def Classify(self, request: generate_pb2.ClassifyRequest, context): if not self.model.supports_classification: @@ -165,34 +160,31 @@ async def Embed(self, request: generate_pb2.EmbedRequest, context): return embeddings_pb async def Decode(self, request: generate_pb2.DecodeRequest, context): - with timer("decode::total"): - if len(request.batches) == 0: - raise ValueError("Must provide at least one batch") - - batches = [] - for batch_pb in request.batches: - batch = self.cache.pop(batch_pb.id) - if batch is None: - raise ValueError(f"Batch ID {batch_pb.id} not found in cache.") - batches.append(batch) - - if len(batches) == 0: - raise ValueError("All batches are empty") - - if len(batches) > 1: - with timer("decode::batch::concatenate"): - batch = self.model.batch_type.concatenate(batches) - else: - batch = batches[0] + if len(request.batches) == 0: + raise ValueError("Must provide at least one batch") - with timer("decode::generate_token"): - generations, next_batch = self.model.generate_token(batch) - self.cache.set(next_batch) + batches = [] + for batch_pb in request.batches: + batch = self.cache.pop(batch_pb.id) + if 
batch is None: + raise ValueError(f"Batch ID {batch_pb.id} not found in cache.") + batches.append(batch) - return generate_pb2.DecodeResponse( - generations=[generation.to_pb() for generation in generations], - batch=next_batch.to_pb() if next_batch else None, - ) + if len(batches) == 0: + raise ValueError("All batches are empty") + + if len(batches) > 1: + batch = self.model.batch_type.concatenate(batches) + else: + batch = batches[0] + + generations, next_batch = self.model.generate_token(batch) + self.cache.set(next_batch) + + return generate_pb2.DecodeResponse( + generations=[generation.to_pb() for generation in generations], + batch=next_batch.to_pb() if next_batch else None, + ) async def DownloadAdapter(self, request: generate_pb2.DownloadAdapterRequest, context): if ( diff --git a/server/lorax_server/utils/tokens.py b/server/lorax_server/utils/tokens.py index f9cf935f9..477aee667 100644 --- a/server/lorax_server/utils/tokens.py +++ b/server/lorax_server/utils/tokens.py @@ -354,6 +354,8 @@ def __call__( S = 1 scores = scores.view(B, S, -1) + # print("!!! scores", scores.shape, B, S) + # print("!!! scores", scores.norm()) next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long) with self.schema_processor.restore_state() if self.schema_processor is not None else nullcontext(): for j in range(S): From 0abeccc4174b36b3483abac4f0a42d43b4b9aed6 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 25 Oct 2024 14:33:26 -0700 Subject: [PATCH 19/76] Fixed merge --- server/lorax_server/models/flash_causal_lm.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 0d618ac08..0a1b97b25 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -1483,12 +1483,6 @@ def generate_token( batch.prefill_head_indices ) - # Assign pointers to adapter weights - # TODO(travis): don't update this if indices haven't changed - adapter_data = AdapterBatchData.from_meta( - adapter_meta, self.layer_to_adapter_weights, prefill, batch.prefill_head_indices - ) - out, speculative_logits = self.forward(batch, adapter_data) if prefill: From 6f5a976f3291ac7aa58797fab5c8776a8a51e247 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 25 Oct 2024 21:46:51 -0700 Subject: [PATCH 20/76] Added LORAX_SPECULATION_MAX_BATCH_SIZE --- server/lorax_server/adapters/medusa.py | 7 ++++--- server/lorax_server/models/flash_causal_lm.py | 5 ++++- server/lorax_server/server.py | 4 +++- server/lorax_server/utils/state.py | 10 +++++++++- server/lorax_server/utils/tokens.py | 3 ++- 5 files changed, 22 insertions(+), 7 deletions(-) diff --git a/server/lorax_server/adapters/medusa.py b/server/lorax_server/adapters/medusa.py index 68838e760..9c3dcaff5 100644 --- a/server/lorax_server/adapters/medusa.py +++ b/server/lorax_server/adapters/medusa.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type +from loguru import logger import torch import torch.distributed @@ -10,7 +11,7 @@ from lorax_server.layers import FastLinear, TensorParallelColumnLinear from lorax_server.utils.segments import find_segments from lorax_server.utils.sgmv import segmented_matmul -from lorax_server.utils.state import get_speculative_tokens +from lorax_server.utils.state import LORAX_SPECULATION_MAX_BATCH_SIZE, get_speculative_tokens from lorax_server.utils.weights import AbstractWeights, InMemoryWeights if TYPE_CHECKING: @@ 
-21,7 +22,6 @@ _MEDUSA_ENABLED = False - @dataclass class MedusaConfig(AdapterConfig): medusa_num_heads: int @@ -159,7 +159,8 @@ def __init__(self, config: MedusaConfig, weights: AbstractWeights): def forward(self, x, lm_head, segments: Optional[MedusaSegments] = None): # If we have too many tokens, we skip speculative logits - if x.shape[0] > 128: + if x.shape[0] > LORAX_SPECULATION_MAX_BATCH_SIZE: + logger.info(f"Skipping speculation at batch size = {x.shape[0]}") logits = lm_head(x) return logits, None diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 0a1b97b25..fac23f539 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -695,8 +695,11 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch sequence_processors=sequence_processors, ) + # Discard speculative IDs if they are not present in all batches speculative_ids = ( - torch.cat([b.speculative_ids for b in batches], dim=0) if batches[0].speculative_ids is not None else None + torch.cat( + [b.speculative_ids for b in batches], dim=0) + if all(b.speculative_ids is not None for b in batches) else None ) if adapter_segment_builder is not None: diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index bb3b99f1a..cd41509cc 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -332,7 +332,9 @@ async def serve_inner( # set speculative decoding tokens speculative_tokens = max(model.max_speculative_tokens, speculative_tokens) if speculative_tokens > 0: - set_speculative_tokens(speculative_tokens) + # Only use ngram speculation if the model does not support speculative tokens itself + use_ngram = model.max_speculative_tokens == 0 + set_speculative_tokens(speculative_tokens, use_ngram=use_ngram) if preloaded_adapter_ids: logger.info(f"Preloading {len(preloaded_adapter_ids)} adapters") diff --git a/server/lorax_server/utils/state.py b/server/lorax_server/utils/state.py index d6b3cea9f..5566208b8 100644 --- a/server/lorax_server/utils/state.py +++ b/server/lorax_server/utils/state.py @@ -6,11 +6,13 @@ WARMUP = False SPECULATIVE_TOKENS = 0 +NGRAM = False LORAX_PROFILER_DIR = os.environ.get("LORAX_PROFILER_DIR", None) PREFIX_CACHING = bool(os.environ.get("PREFIX_CACHING", "")) CHUNKED_PREFILL = bool(os.environ.get("CHUNKED_PREFILL", "")) +LORAX_SPECULATION_MAX_BATCH_SIZE = int(os.environ.get("LORAX_SPECULATION_MAX_BATCH_SIZE", 32)) # Always use flashinfer when prefix caching is enabled FLASH_INFER = bool(os.environ.get("FLASH_INFER", "")) or PREFIX_CACHING @@ -54,15 +56,21 @@ def warmup_mode(): set_warmup(False) -def set_speculative_tokens(value: int): +def set_speculative_tokens(value: int, use_ngram: bool): global SPECULATIVE_TOKENS + global NGRAM SPECULATIVE_TOKENS = value + NGRAM = use_ngram def get_speculative_tokens() -> int: return SPECULATIVE_TOKENS +def use_ngram() -> bool: + return NGRAM + + def set_supports_chunking(supports_chunking: bool): global SUPPORTS_CHUNKING SUPPORTS_CHUNKING = supports_chunking diff --git a/server/lorax_server/utils/tokens.py b/server/lorax_server/utils/tokens.py index 477aee667..1946bbd25 100644 --- a/server/lorax_server/utils/tokens.py +++ b/server/lorax_server/utils/tokens.py @@ -3,6 +3,7 @@ from contextlib import nullcontext from typing import List, Optional, Set, Tuple, Union +from lorax_server.utils.state import use_ngram import torch from transformers import ( PreTrainedTokenizerBase, @@ -421,7 +422,7 @@ 
def __call__( if speculative_scores is not None: # Only use greedy sampling for speculative tokens speculative_ids = Greedy()(speculative_scores) - else: + elif use_ngram(): speculative_ids = ngram_speculate(input_ids, next_ids, accepted_ids, speculate) return next_ids, next_logprobs, accepted_ids, speculative_ids From f89ee8765d637d22d2e47e713c4a33ff91268536 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Sat, 26 Oct 2024 21:48:55 -0700 Subject: [PATCH 21/76] Try separate trees per adapter --- router/src/radix.rs | 80 +++++++++++-------- server/lorax_server/models/flash_causal_lm.py | 2 + 2 files changed, 47 insertions(+), 35 deletions(-) diff --git a/router/src/radix.rs b/router/src/radix.rs index 243df370a..068311c51 100644 --- a/router/src/radix.rs +++ b/router/src/radix.rs @@ -8,14 +8,10 @@ use std::{ fn hash(adapter_index: u32, slice: &[u32]) -> u64 { assert!(!slice.is_empty()); - if slice.len() == 1 && adapter_index == 0 { - slice[0] as u64 - } else { - let mut s = std::hash::DefaultHasher::new(); - adapter_index.hash(&mut s); - slice.hash(&mut s); - s.finish() - } + let mut s = std::hash::DefaultHasher::new(); + adapter_index.hash(&mut s); + slice.hash(&mut s); + s.finish() } pub struct RadixAllocator { @@ -93,13 +89,13 @@ impl Allocator for RadixAllocator { .find(adapter_index, prefill_tokens.as_slice(), &mut blocks); node_id } else { - self.cache_blocks.root_id() + self.cache_blocks.get_or_create_root(adapter_index) }; // Even if this allocation fails below, we need to increase he // refcount to ensure that the prefix that was found is not evicted. self.cache_blocks - .incref(prefix_node) + .incref(adapter_index, prefix_node) .expect("Failed to increment refcount"); let prefix_len = blocks.len() * self.block_size as usize; @@ -116,7 +112,7 @@ impl Allocator for RadixAllocator { tracing::debug!("Found {prefix_len} prefix tokens need {suffix_blocks} suffix blocks for {tokens} tokens"); tracing::debug!("Block size {}", self.block_size); self.cache_blocks - .decref(prefix_node) + .decref(adapter_index, prefix_node) .expect("Failed to decrement refcount"); return None; } @@ -164,7 +160,7 @@ impl Allocator for RadixAllocator { }; self.cache_blocks - .decref(allocation.prefix_node) + .decref(allocation.adapter_index, allocation.prefix_node) .expect("Failed to decrement refcount"); if let Some(prefill_tokens) = allocation.prefill_tokens { @@ -241,8 +237,8 @@ pub type NodeId = DefaultKey; #[derive(Debug)] pub struct RadixTrie { - /// Identifier of the root nod. - root: DefaultKey, + /// Adapter index --> Identifier of the root node. + roots: HashMap, /// Leave node identifiers ordered by increasing recency. leaves: BTreeSet<(u64, NodeId)>, @@ -261,13 +257,13 @@ pub struct RadixTrie { impl RadixTrie { /// Construct a new radix trie. pub fn new(block_size: usize) -> Self { - let root = TrieNode::new(vec![], vec![], 0, None); - let mut nodes = SlotMap::new(); - let root = nodes.insert(root); + let nodes = SlotMap::new(); + let roots = HashMap::new(); + RadixTrie { leaves: BTreeSet::new(), nodes, - root, + roots, time: 0, block_size, } @@ -284,7 +280,7 @@ impl RadixTrie { /// Using this method will update the access time of the traversed nodes. pub fn find(&mut self, adapter_index: u32, key: &[u32], blocks: &mut Vec) -> NodeId { self.time += 1; - self.find_(adapter_index, self.root, key, blocks) + self.find_(adapter_index, self.root_id(adapter_index), key, blocks) } /// Find worker. @@ -317,10 +313,10 @@ impl RadixTrie { } /// Decrease the reference count of a node. 
- pub fn decref(&mut self, node_id: NodeId) -> Result<(), TrieError> { + pub fn decref(&mut self, adapter_index: u32, node_id: NodeId) -> Result<(), TrieError> { // We don't care about refcounting for root, since it will never // be evicted. - if node_id == self.root { + if node_id == self.root_id(adapter_index) { return Ok(()); } @@ -346,8 +342,8 @@ impl RadixTrie { } /// Increase the reference count of a node. - pub fn incref(&mut self, node_id: NodeId) -> Result<(), TrieError> { - if node_id == self.root { + pub fn incref(&mut self, adapter_index: u32, node_id: NodeId) -> Result<(), TrieError> { + if node_id == self.root_id(adapter_index) { return Ok(()); } @@ -382,7 +378,7 @@ impl RadixTrie { let blocks_needed = n_blocks.saturating_sub(evicted.len()); tracing::debug!("Evicting node {node_id:?} "); - let node = self.nodes.get(node_id).expect("Leave does not exist"); + let node = self.nodes.get(node_id).expect("Leaf does not exist"); assert_eq!( node.ref_count, 0, "Leaf must have refcount of 0, got {}", @@ -401,7 +397,7 @@ impl RadixTrie { // The node has more blocks than needed, so we'll just remove // the required number of blocks and leave the remaining blocks // untouched. - let node = self.nodes.get_mut(node_id).expect("Leave does not exist"); + let node = self.nodes.get_mut(node_id).expect("Leaf does not exist"); let truncate_blocks = node.blocks.len() - blocks_needed; let truncate_tokens = truncate_blocks * self.block_size; @@ -427,7 +423,8 @@ impl RadixTrie { blocks: &[u32], ) -> Result { self.time += 1; - let common = self.insert_(adapter_index, self.root, tokens, blocks)?; + let node_id = self.get_or_create_root(adapter_index); + let common = self.insert_(adapter_index, node_id, tokens, blocks)?; Ok(common) } @@ -507,7 +504,7 @@ impl RadixTrie { let grandparent_id = node.parent.expect("Node does not have a parent"); let parent_id = self.add_node(adapter_index, grandparent_id, parent_key, parent_blocks); - self.add_node_to_parent(parent_id, node_key, node_id); + self.add_node_to_parent(adapter_index, parent_id, node_key, node_id); // Reborrow to make the borrow checker happy. let node = self @@ -534,19 +531,25 @@ impl RadixTrie { let child = TrieNode::new(key, blocks, self.time, Some(parent_id)); let child_id = self.nodes.insert(child); - self.add_node_to_parent(parent_id, first, child_id); + self.add_node_to_parent(adapter_index, parent_id, first, child_id); self.leaves.insert((self.time, child_id)); child_id } /// Add a node to the parent. - fn add_node_to_parent(&mut self, parent_id: NodeId, hash: u64, child_id: NodeId) { + fn add_node_to_parent( + &mut self, + adapter_index: u32, + parent_id: NodeId, + hash: u64, + child_id: NodeId, + ) { // Unwrap here, passing in an unknown id is a programming error. let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node"); if parent.children.insert(hash, child_id).is_none() { // Only increase reference count if child does not replace another child. - self.incref(parent_id) + self.incref(adapter_index, parent_id) .expect("Failed to increase parent refcount"); } } @@ -565,7 +568,7 @@ impl RadixTrie { let node_key = hash(adapter_index, &node.key[..self.block_size]); parent.children.remove(&node_key); - self.decref(parent_id) + self.decref(adapter_index, parent_id) .expect("Failed to decrease parent refcount"); node } @@ -587,8 +590,8 @@ impl RadixTrie { /// Print debugging output for the trie. /// /// In contrast to `Debug` nicely formatted. 
- pub fn print_debug(&self) { - self.print_debug_(self.root, 0); + pub fn print_debug(&self, adapter_index: u32) { + self.print_debug_(self.root_id(adapter_index), 0); } fn print_debug_(&self, node_id: NodeId, indent: usize) { @@ -609,8 +612,15 @@ impl RadixTrie { } } - pub(crate) fn root_id(&self) -> DefaultKey { - self.root + fn get_or_create_root(&mut self, adapter_index: u32) -> DefaultKey { + *self.roots.entry(adapter_index).or_insert_with(|| { + let root = TrieNode::new(vec![], vec![], 0, None); + self.nodes.insert(root) + }) + } + + pub(crate) fn root_id(&self, adapter_index: u32) -> DefaultKey { + self.roots[&adapter_index] } } diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index fac23f539..e3f41ec11 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -696,6 +696,8 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch ) # Discard speculative IDs if they are not present in all batches + if not all(b.speculative_ids is not None for b in batches): + print("!!! CONCATENATE -- discard speculative_ids") speculative_ids = ( torch.cat( [b.speculative_ids for b in batches], dim=0) From 23a77d2850fe9dfc50c2f673d1cc07dd028d45db Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Sat, 26 Oct 2024 22:29:34 -0700 Subject: [PATCH 22/76] Allow refcount==0 --- router/src/radix.rs | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/router/src/radix.rs b/router/src/radix.rs index 068311c51..1c0627e80 100644 --- a/router/src/radix.rs +++ b/router/src/radix.rs @@ -6,12 +6,23 @@ use std::{ sync::Arc, }; +// fn hash(adapter_index: u32, slice: &[u32]) -> u64 { +// assert!(!slice.is_empty()); +// let mut s = std::hash::DefaultHasher::new(); +// adapter_index.hash(&mut s); +// slice.hash(&mut s); +// s.finish() +// } + fn hash(adapter_index: u32, slice: &[u32]) -> u64 { assert!(!slice.is_empty()); - let mut s = std::hash::DefaultHasher::new(); - adapter_index.hash(&mut s); - slice.hash(&mut s); - s.finish() + if slice.len() == 1 { + slice[0] as u64 + } else { + let mut s = std::hash::DefaultHasher::new(); + slice.hash(&mut s); + s.finish() + } } pub struct RadixAllocator { @@ -82,6 +93,9 @@ impl Allocator for RadixAllocator { tokens: u32, prefill_tokens: Option>>, ) -> Option { + // ensure root node exists + self.cache_blocks.get_or_create_root(adapter_index); + let mut blocks = vec![]; let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() { let node_id = @@ -89,7 +103,7 @@ impl Allocator for RadixAllocator { .find(adapter_index, prefill_tokens.as_slice(), &mut blocks); node_id } else { - self.cache_blocks.get_or_create_root(adapter_index) + self.cache_blocks.root_id(adapter_index) }; // Even if this allocation fails below, we need to increase he @@ -325,7 +339,8 @@ impl RadixTrie { .get_mut(node_id) .ok_or(TrieError::InvalidNodeId)?; if node.ref_count == 0 { - return Err(TrieError::RefCountUnderflow); + // return Err(TrieError::RefCountUnderflow); + return Ok(()); } node.ref_count -= 1; From 22ed54da3e2356c2d4349ce16ac55c3e4c49daba Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 28 Oct 2024 08:32:22 -0700 Subject: [PATCH 23/76] Message --- server/lorax_server/models/flash_causal_lm.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index e3f41ec11..ed6921f35 100644 --- 
a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -696,12 +696,14 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch ) # Discard speculative IDs if they are not present in all batches - if not all(b.speculative_ids is not None for b in batches): - print("!!! CONCATENATE -- discard speculative_ids") + keep_speculative_ids = all(b.speculative_ids is not None for b in batches) + if not keep_speculative_ids: + logger.info("Discarding speculative IDs, not every batch has them") + speculative_ids = ( torch.cat( [b.speculative_ids for b in batches], dim=0) - if all(b.speculative_ids is not None for b in batches) else None + if keep_speculative_ids else None ) if adapter_segment_builder is not None: From 327bb91cf03bb8038711200a83f4c72265999c88 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 28 Oct 2024 08:32:57 -0700 Subject: [PATCH 24/76] Docker test --- .github/workflows/build.yaml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 13b9e96ca..2255cc086 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -5,6 +5,7 @@ on: push: branches: - 'main' + - 'optimizations' tags: - 'v*' @@ -69,10 +70,7 @@ jobs: images: | ghcr.io/predibase/lorax tags: | - type=semver,pattern={{version}} - type=semver,pattern={{major}}.{{minor}} - type=sha,prefix=,suffix=,format=short - type=raw,value=main,enable=${{ github.ref == 'refs/heads/main' }} + type=raw,value=optimizations-2,enable=${{ github.ref == 'refs/heads/optimizations' }} - name: Create a hash from tags env: From fbb2b3f87561762674089237a3283140e75d08b1 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 28 Oct 2024 09:09:16 -0700 Subject: [PATCH 25/76] Cleanup --- server/lorax_server/adapters/weights.py | 5 --- server/lorax_server/models/flash_causal_lm.py | 12 ------ server/lorax_server/models/model.py | 20 ++++++++-- server/lorax_server/server.py | 17 ++++---- server/lorax_server/utils/profiler.py | 39 ------------------- server/lorax_server/utils/tokens.py | 2 - 6 files changed, 23 insertions(+), 72 deletions(-) delete mode 100644 server/lorax_server/utils/profiler.py diff --git a/server/lorax_server/adapters/weights.py b/server/lorax_server/adapters/weights.py index f83570bac..d655ed865 100644 --- a/server/lorax_server/adapters/weights.py +++ b/server/lorax_server/adapters/weights.py @@ -31,11 +31,6 @@ class AdapterBatchMetadata: # segment_indices[s] == adapter_indices[i] segment_indices: List[int] - @property - def token_indices(self) -> torch.Tensor: - # Create the `token_indices` by repeating each segment index by the number of tokens in it - return torch.cat([torch.full((count,), self.adapter_indices[idx], dtype=torch.long) for idx, count in enumerate(self.segment_indices)]) - class AdapterWeights(ABC): @abstractclassmethod diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index ed6921f35..fdb2a731d 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -1103,18 +1103,6 @@ def __init__( num_kv_heads=self.num_kv_heads, ) - self.profiler = None - if LORAX_PROFILER_DIR is not None: - self.profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - with_stack=True, - on_trace_ready=torch.profiler.tensorboard_trace_handler(LORAX_PROFILER_DIR, 
use_gzip=True) - ) - self.steps = 0 - self.punica_wrapper = None @property diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py index 1c9c9ff3f..c774235aa 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ -24,6 +24,7 @@ BLOCK_SIZE, CHUNKED_PREFILL, FLASH_INFER, + LORAX_PROFILER_DIR, get_speculative_tokens, set_supports_chunking, ) @@ -116,6 +117,18 @@ def __init__( self.check_initialized() + self.profiler = None + if LORAX_PROFILER_DIR is not None: + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + with_stack=True, + on_trace_ready=torch.profiler.tensorboard_trace_handler(LORAX_PROFILER_DIR, use_gzip=True) + ) + self.profiler_steps = 0 + @property def info(self) -> InfoResponse: if self.requires_padding and self.sliding_window is not None: @@ -290,10 +303,9 @@ def register_preloaded_adapters( lora_b_weights = layer_id_to_lora_b_weights[layer_id] # right pad every adapter to the max rank - # TODO(travis) - # r = max([w.size(-1) for w in lora_b_weights]) - # lora_a_weights = [pad_to_min_rank(w, 1, r) for w in lora_a_weights] - # lora_b_weights = [pad_to_min_rank(w, 2, r) for w in lora_b_weights] + r = max([w.size(-1) for w in lora_b_weights]) + lora_a_weights = [pad_to_min_rank(w, 1, r) for w in lora_a_weights] + lora_b_weights = [pad_to_min_rank(w, 2, r) for w in lora_b_weights] # stack into [num_adapters, r, hidden_size] and [num_adapters, hidden_size, r] lora_a_weights = torch.stack(lora_a_weights).to(self.device).contiguous() diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index cd41509cc..1801f4d21 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -58,15 +58,12 @@ async def ServiceDiscovery(self, request, context): return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls) async def ClearCache(self, request, context): - try: - if request.HasField("id"): - self.cache.delete(request.id) - else: - self.cache.clear() + if request.HasField("id"): + self.cache.delete(request.id) + else: + self.cache.clear() - return generate_pb2.ClearCacheResponse() - except: - exit(1) + return generate_pb2.ClearCacheResponse() async def FilterBatch(self, request, context): batch = self.cache.pop(request.batch_id) @@ -115,8 +112,8 @@ async def Prefill(self, request: generate_pb2.PrefillRequest, context): self.cache.set(next_batch) if self.model.profiler: - self.model.steps += 1 - if self.model.steps == 10: + self.model.profiler_steps += 1 + if self.model.profiler_steps == 10: self.model.profiler.stop() print(self.model.profiler.key_averages()) diff --git a/server/lorax_server/utils/profiler.py b/server/lorax_server/utils/profiler.py deleted file mode 100644 index bb74be6d2..000000000 --- a/server/lorax_server/utils/profiler.py +++ /dev/null @@ -1,39 +0,0 @@ -import time -from contextlib import contextmanager - -import torch - - -class TimingContextManager: - def __init__(self, name: str): - self.name = name - self.total_time = 0 - self.count = 0 - - @contextmanager - def timing(self): - start = time.time() - try: - yield - finally: - end = time.time() - self.total_time += end - start - self.count += 1 - # print(f"=== {self.name}: avg={self.get_average_time():.3f} s total={self.total_time:.3f} s count={self.count}") - - def get_average_time(self): - if self.count == 0: - return 0 - return self.total_time / self.count - - -_timers = {} - - -@contextmanager -def timer(name: 
str): - if name not in _timers: - _timers[name] = TimingContextManager(name) - with _timers[name].timing(): - yield - # torch.cuda.synchronize() diff --git a/server/lorax_server/utils/tokens.py b/server/lorax_server/utils/tokens.py index 1946bbd25..2bf613a81 100644 --- a/server/lorax_server/utils/tokens.py +++ b/server/lorax_server/utils/tokens.py @@ -355,8 +355,6 @@ def __call__( S = 1 scores = scores.view(B, S, -1) - # print("!!! scores", scores.shape, B, S) - # print("!!! scores", scores.norm()) next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long) with self.schema_processor.restore_state() if self.schema_processor is not None else nullcontext(): for j in range(S): From f0693e96d01dd1d25a09bad54098a616ac7f3873 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 28 Oct 2024 09:18:18 -0700 Subject: [PATCH 26/76] Padding --- router/src/radix.rs | 13 +++---------- server/lorax_server/models/model.py | 4 ++-- 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/router/src/radix.rs b/router/src/radix.rs index 1c0627e80..3b2a2263e 100644 --- a/router/src/radix.rs +++ b/router/src/radix.rs @@ -6,15 +6,7 @@ use std::{ sync::Arc, }; -// fn hash(adapter_index: u32, slice: &[u32]) -> u64 { -// assert!(!slice.is_empty()); -// let mut s = std::hash::DefaultHasher::new(); -// adapter_index.hash(&mut s); -// slice.hash(&mut s); -// s.finish() -// } - -fn hash(adapter_index: u32, slice: &[u32]) -> u64 { +fn hash(_adapter_index: u32, slice: &[u32]) -> u64 { assert!(!slice.is_empty()); if slice.len() == 1 { slice[0] as u64 @@ -244,7 +236,7 @@ struct RadixAllocation { #[derive(Debug)] pub enum TrieError { InvalidNodeId, - RefCountUnderflow, + // RefCountUnderflow, } pub type NodeId = DefaultKey; @@ -339,6 +331,7 @@ impl RadixTrie { .get_mut(node_id) .ok_or(TrieError::InvalidNodeId)?; if node.ref_count == 0 { + // TODO(travis): figureo ut why this is happening, but should be safe to skip // return Err(TrieError::RefCountUnderflow); return Ok(()); } diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py index c774235aa..eff471ee7 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ -304,8 +304,8 @@ def register_preloaded_adapters( # right pad every adapter to the max rank r = max([w.size(-1) for w in lora_b_weights]) - lora_a_weights = [pad_to_min_rank(w, 1, r) for w in lora_a_weights] - lora_b_weights = [pad_to_min_rank(w, 2, r) for w in lora_b_weights] + lora_a_weights = [pad_to_min_rank(w, 0, r) for w in lora_a_weights] + lora_b_weights = [pad_to_min_rank(w, 1, r) for w in lora_b_weights] # stack into [num_adapters, r, hidden_size] and [num_adapters, hidden_size, r] lora_a_weights = torch.stack(lora_a_weights).to(self.device).contiguous() From e62e0f8174a829d696873c204a6ea0a49fc94add Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 28 Oct 2024 12:27:51 -0700 Subject: [PATCH 27/76] Fixed turbo lora + compile --- .../models/custom_modeling/flash_qwen2_modeling.py | 5 +++++ server/lorax_server/models/flash_causal_lm.py | 6 ++++++ server/lorax_server/utils/graph.py | 4 +++- 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py index fc3ac7097..94d7766e3 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py @@ -488,6 +488,7 @@ def forward( adapter_data: 
AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + skip_lm_head: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: if prefill_cache_indices is not None: # Slots also need to be sliced as it has the same size as the whole kv tensor @@ -513,6 +514,10 @@ def forward( if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] + + if skip_lm_head: + return hidden_states, None + logits, speculative_logits = self.lm_head(hidden_states, adapter_data) return logits, speculative_logits diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index fdb2a731d..a31a94c70 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -1419,6 +1419,8 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> lm_head_indices=batch.prefill_head_indices, ) else: + skip_lm_head = get_speculative_tokens() > 0 + # CUDA graph mode out = model.forward( input_ids=input_ids, @@ -1436,6 +1438,10 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> lm_head_indices=batch.prefill_head_indices, ) + if skip_lm_head: + # re-run through the LM head as the graph did not capture it + out = self.model.lm_head(out[0], adapter_data) + if batch.prefill_cache_indices is not None: batch.prefill_cache_indices = None diff --git a/server/lorax_server/utils/graph.py b/server/lorax_server/utils/graph.py index c08d06566..a83e22451 100644 --- a/server/lorax_server/utils/graph.py +++ b/server/lorax_server/utils/graph.py @@ -19,7 +19,7 @@ from lorax_server.utils.attention.common import Seqlen from lorax_server.utils.attention.utils import block_tables_to_ragged from lorax_server.utils.sgmv import BGMV_MAX_RANK, PunicaWrapper -from lorax_server.utils.state import BLOCK_SIZE, FLASH_INFER +from lorax_server.utils.state import BLOCK_SIZE, FLASH_INFER, get_speculative_tokens if TYPE_CHECKING: from lorax_server.models.flash_causal_lm import FlashCausalLMBatch @@ -339,6 +339,7 @@ def trace( adapter_data=input_state.adapter_data, prefill_cache_indices=None, lm_head_indices=None, + skip_lm_head=get_speculative_tokens() > 0, ) torch.cuda.synchronize() @@ -356,6 +357,7 @@ def trace( adapter_data=input_state.adapter_data, prefill_cache_indices=None, lm_head_indices=None, + skip_lm_head=get_speculative_tokens() > 0, ) torch.cuda.synchronize(device) From 66d86765ea42539b626f3ae4b83b9024ded71bb9 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 28 Oct 2024 13:00:47 -0700 Subject: [PATCH 28/76] Fix --- server/lorax_server/models/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py index eff471ee7..fcf80f906 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ -81,6 +81,7 @@ def __init__( self.preloaded_adapter_indices = set() self.preloaded_adapter_memory_fractions = {} self.preloaded_adapters = [] + self.layer_to_lora_weights = {} self.trust_remote_code = trust_remote_code @@ -271,7 +272,6 @@ def register_preloaded_adapters( # where: # lora_a_weights = [num_adapters, r, hidden_size] # lora_b_weights = [num_adapters, hidden_size, r] - self.layer_to_lora_weights = {} for layer_name, layer_adapter_weights in self.layer_to_adapter_weights.items(): layer_id_to_lora_a_weights = defaultdict(list) layer_id_to_lora_b_weights = defaultdict(list) From 
55e5c414407d0908571fcc628ad4caeaea6576fd Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Tue, 29 Oct 2024 22:17:48 -0700 Subject: [PATCH 29/76] Fix adapter root node id --- Cargo.lock | 87 +++++++++++++++++++++++++++++++----- router/Cargo.toml | 3 ++ router/src/radix.rs | 106 ++++++++++++++++++++++++++------------------ 3 files changed, 141 insertions(+), 55 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 82f5029e9..ec8bd04c9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -58,6 +58,15 @@ dependencies = [ "libc", ] +[[package]] +name = "ansi_term" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2" +dependencies = [ + "winapi", +] + [[package]] name = "anstream" version = "0.6.12" @@ -1558,7 +1567,7 @@ dependencies = [ "serde", "serde_json", "tracing", - "tracing-subscriber", + "tracing-subscriber 0.3.17", "vergen", ] @@ -1605,7 +1614,8 @@ dependencies = [ "tower-http 0.4.1", "tracing", "tracing-opentelemetry 0.19.0", - "tracing-subscriber", + "tracing-subscriber 0.3.17", + "tracing-test", "utoipa", "utoipa-swagger-ui", "vergen", @@ -1643,6 +1653,15 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ffbee8634e0d45d258acb448e7eaab3fce7a0a467395d4d9f228e3c1f01fb2e4" +[[package]] +name = "matchers" +version = "0.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f099785f7595cc4b4553a174ce30dd7589ef93391ff414dbb67f62392b9e0ce1" +dependencies = [ + "regex-automata 0.1.10", +] + [[package]] name = "matchers" version = "0.1.0" @@ -3754,11 +3773,10 @@ checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" [[package]] name = "tracing" -version = "0.1.37" +version = "0.1.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" +checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" dependencies = [ - "cfg-if", "log", "pin-project-lite", "tracing-attributes", @@ -3767,9 +3785,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.26" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f4f31f56159e98206da9efd823404b79b6ef3143b4a7ab76e67b1751b25a4ab" +checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", @@ -3778,9 +3796,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.31" +version = "0.1.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0955b8137a1df6f1a2e9a37d8a6656291ff0297c1a97c24e0d8425fe2312f79a" +checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" dependencies = [ "once_cell", "valuable", @@ -3818,7 +3836,7 @@ dependencies = [ "tracing", "tracing-core", "tracing-log", - "tracing-subscriber", + "tracing-subscriber 0.3.17", ] [[package]] @@ -3832,7 +3850,7 @@ dependencies = [ "tracing", "tracing-core", "tracing-log", - "tracing-subscriber", + "tracing-subscriber 0.3.17", ] [[package]] @@ -3845,13 +3863,35 @@ dependencies = [ "tracing-core", ] +[[package]] +name = "tracing-subscriber" +version = "0.2.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e0d2eaa99c3c2e41547cfa109e910a68ea03823cccad4a0525dcbc9b01e8c71" +dependencies = [ + "ansi_term", + "chrono", + "lazy_static", + "matchers 0.0.1", + "regex", + "serde", + 
"serde_json", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", + "tracing-serde", +] + [[package]] name = "tracing-subscriber" version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "30a651bc37f915e81f087d86e62a18eec5f79550c7faff886f7090b4ea757c77" dependencies = [ - "matchers", + "matchers 0.1.0", "nu-ansi-term", "once_cell", "regex", @@ -3866,6 +3906,29 @@ dependencies = [ "tracing-serde", ] +[[package]] +name = "tracing-test" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3b48778c2d401c6a7fcf38a0e3c55dc8e8e753cbd381044a8cdb6fd69a29f53" +dependencies = [ + "lazy_static", + "tracing-core", + "tracing-subscriber 0.2.25", + "tracing-test-macro", +] + +[[package]] +name = "tracing-test-macro" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c49adbab879d2e0dd7f75edace5f0ac2156939ecb7e6a1e8fa14e53728328c48" +dependencies = [ + "lazy_static", + "quote", + "syn 1.0.109", +] + [[package]] name = "try-lock" version = "0.2.4" diff --git a/router/Cargo.toml b/router/Cargo.toml index 27373864f..c38330683 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -68,6 +68,9 @@ base64 = "0.22.0" [build-dependencies] vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] } +[dev-dependencies] +tracing-test = "0.1" + [features] default = ["ngrok"] ngrok = ["dep:ngrok"] diff --git a/router/src/radix.rs b/router/src/radix.rs index 3b2a2263e..637ec2f65 100644 --- a/router/src/radix.rs +++ b/router/src/radix.rs @@ -6,7 +6,7 @@ use std::{ sync::Arc, }; -fn hash(_adapter_index: u32, slice: &[u32]) -> u64 { +fn hash(slice: &[u32]) -> u64 { assert!(!slice.is_empty()); if slice.len() == 1 { slice[0] as u64 @@ -49,7 +49,7 @@ impl RadixAllocator { } } - fn alloc_or_reclaim(&mut self, adapter_index: u32, n_blocks_needed: usize) -> Option> { + fn alloc_or_reclaim(&mut self, n_blocks_needed: usize) -> Option> { if self.free_blocks.len() < n_blocks_needed { // This is a bit annoying, we first extend the free list and then // split it off again below. This is because we need to put it on @@ -62,7 +62,7 @@ impl RadixAllocator { ); self.free_blocks.extend( self.cache_blocks - .evict(adapter_index, n_blocks_needed - self.free_blocks.len()), + .evict(n_blocks_needed - self.free_blocks.len()), ); } @@ -101,7 +101,7 @@ impl Allocator for RadixAllocator { // Even if this allocation fails below, we need to increase he // refcount to ensure that the prefix that was found is not evicted. 
self.cache_blocks - .incref(adapter_index, prefix_node) + .incref(prefix_node) .expect("Failed to increment refcount"); let prefix_len = blocks.len() * self.block_size as usize; @@ -111,14 +111,14 @@ impl Allocator for RadixAllocator { tracing::debug!("Prefix {prefix_len} - Suffix {suffix_len}"); - match self.alloc_or_reclaim(adapter_index, suffix_blocks as usize) { + match self.alloc_or_reclaim(suffix_blocks as usize) { Some(suffix_blocks) => blocks.extend(suffix_blocks), None => { tracing::debug!("Cannot allocate {:?}", self.cache_blocks); tracing::debug!("Found {prefix_len} prefix tokens need {suffix_blocks} suffix blocks for {tokens} tokens"); tracing::debug!("Block size {}", self.block_size); self.cache_blocks - .decref(adapter_index, prefix_node) + .decref(prefix_node) .expect("Failed to decrement refcount"); return None; } @@ -166,7 +166,7 @@ impl Allocator for RadixAllocator { }; self.cache_blocks - .decref(allocation.adapter_index, allocation.prefix_node) + .decref(allocation.prefix_node) .expect("Failed to decrement refcount"); if let Some(prefill_tokens) = allocation.prefill_tokens { @@ -236,7 +236,7 @@ struct RadixAllocation { #[derive(Debug)] pub enum TrieError { InvalidNodeId, - // RefCountUnderflow, + RefCountUnderflow, } pub type NodeId = DefaultKey; @@ -300,7 +300,7 @@ impl RadixTrie { let node = &self.nodes[node_id]; if key.len() >= self.block_size { - let node_key = hash(adapter_index, &key[..self.block_size]); + let node_key = hash(&key[..self.block_size]); if let Some(&child_id) = node.children.get(&node_key) { self.update_access_time(child_id); let child = self.nodes.get(child_id).expect("Invalid child identifier"); @@ -319,10 +319,10 @@ impl RadixTrie { } /// Decrease the reference count of a node. - pub fn decref(&mut self, adapter_index: u32, node_id: NodeId) -> Result<(), TrieError> { + pub fn decref(&mut self, node_id: NodeId) -> Result<(), TrieError> { // We don't care about refcounting for root, since it will never // be evicted. - if node_id == self.root_id(adapter_index) { + if self.is_root(node_id) { return Ok(()); } @@ -331,9 +331,7 @@ impl RadixTrie { .get_mut(node_id) .ok_or(TrieError::InvalidNodeId)?; if node.ref_count == 0 { - // TODO(travis): figureo ut why this is happening, but should be safe to skip - // return Err(TrieError::RefCountUnderflow); - return Ok(()); + return Err(TrieError::RefCountUnderflow); } node.ref_count -= 1; @@ -350,8 +348,8 @@ impl RadixTrie { } /// Increase the reference count of a node. - pub fn incref(&mut self, adapter_index: u32, node_id: NodeId) -> Result<(), TrieError> { - if node_id == self.root_id(adapter_index) { + pub fn incref(&mut self, node_id: NodeId) -> Result<(), TrieError> { + if self.is_root(node_id) { return Ok(()); } @@ -371,7 +369,7 @@ impl RadixTrie { /// /// Returns the evicted blocks. When the length is less than `n_blocks`, /// not enough blocks could be evicted. - pub fn evict(&mut self, adapter_index: u32, n_blocks: usize) -> Vec { + pub fn evict(&mut self, n_blocks: usize) -> Vec { // NOTE: we don't return Result here. If any of the unwrapping fails, // it's a programming error in the trie implementation, not a user // error caused by e.g. an invalid argument. @@ -395,7 +393,7 @@ impl RadixTrie { if blocks_needed >= node.blocks.len() { // We need to evict the whole node if we need more blocks than it has. 
- let node = self.remove_node(adapter_index, node_id); + let node = self.remove_node(node_id); evicted.extend(node.blocks); if evicted.len() >= n_blocks { @@ -450,7 +448,7 @@ impl RadixTrie { assert_eq!(tokens.len(), blocks.len() * self.block_size); - let node_key = hash(adapter_index, &tokens[..self.block_size]); + let node_key = hash(&tokens[..self.block_size]); if let Some(&child_id) = self.nodes[node_id].children.get(&node_key) { self.update_access_time(child_id); let child = self @@ -479,17 +477,17 @@ impl RadixTrie { // The node's prefix and the insertion prefix only match partially, // split the node to just contain the matching part. Then insert the // remainder of the prefix into the node again - let child_id = self.split_node(adapter_index, child_id, shared_prefix_len); + let child_id = self.split_node(child_id, shared_prefix_len); let key = &tokens[shared_prefix_len..]; let blocks = &blocks[shared_prefix_len / self.block_size..]; Ok(shared_prefix_len + self.insert_(adapter_index, child_id, key, blocks)?) } else { - self.add_node(adapter_index, node_id, tokens, blocks); + self.add_node(node_id, tokens, blocks); Ok(0) } } - fn split_node(&mut self, adapter_index: u32, node_id: NodeId, prefix_len: usize) -> NodeId { + fn split_node(&mut self, node_id: NodeId, prefix_len: usize) -> NodeId { // We have to make the current node a child to ensure that its // properties and node id stay the same. @@ -508,11 +506,11 @@ impl RadixTrie { std::mem::swap(&mut node.key, &mut parent_key); std::mem::swap(&mut node.blocks, &mut parent_blocks); - let node_key = hash(adapter_index, &node.key[..self.block_size]); + let node_key = hash(&node.key[..self.block_size]); let grandparent_id = node.parent.expect("Node does not have a parent"); - let parent_id = self.add_node(adapter_index, grandparent_id, parent_key, parent_blocks); - self.add_node_to_parent(adapter_index, parent_id, node_key, node_id); + let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks); + self.add_node_to_parent(parent_id, node_key, node_id); // Reborrow to make the borrow checker happy. let node = self @@ -527,43 +525,36 @@ impl RadixTrie { /// Create a node and add it to the parent. fn add_node( &mut self, - adapter_index: u32, parent_id: NodeId, key: impl Into>, blocks: impl Into>, ) -> NodeId { let key = key.into(); let blocks = blocks.into(); - let first = hash(adapter_index, &key[..self.block_size]); + let first = hash(&key[..self.block_size]); let child = TrieNode::new(key, blocks, self.time, Some(parent_id)); let child_id = self.nodes.insert(child); - self.add_node_to_parent(adapter_index, parent_id, first, child_id); + self.add_node_to_parent(parent_id, first, child_id); self.leaves.insert((self.time, child_id)); child_id } /// Add a node to the parent. - fn add_node_to_parent( - &mut self, - adapter_index: u32, - parent_id: NodeId, - hash: u64, - child_id: NodeId, - ) { + fn add_node_to_parent(&mut self, parent_id: NodeId, hash: u64, child_id: NodeId) { // Unwrap here, passing in an unknown id is a programming error. let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node"); if parent.children.insert(hash, child_id).is_none() { // Only increase reference count if child does not replace another child. - self.incref(adapter_index, parent_id) + self.incref(parent_id) .expect("Failed to increase parent refcount"); } } /// Remove a node from the trie. 
- fn remove_node(&mut self, adapter_index: u32, node_id: NodeId) -> TrieNode { + fn remove_node(&mut self, node_id: NodeId) -> TrieNode { // Unwrap here, passing in an unknown id is a programming error. let node = self.nodes.remove(node_id).expect("Unknown node"); assert!( @@ -574,9 +565,9 @@ impl RadixTrie { let parent_id = node.parent.expect("Attempted to remove root node"); let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node"); - let node_key = hash(adapter_index, &node.key[..self.block_size]); + let node_key = hash(&node.key[..self.block_size]); parent.children.remove(&node_key); - self.decref(adapter_index, parent_id) + self.decref(parent_id) .expect("Failed to decrease parent refcount"); node } @@ -630,6 +621,11 @@ impl RadixTrie { pub(crate) fn root_id(&self, adapter_index: u32) -> DefaultKey { self.roots[&adapter_index] } + + pub(crate) fn is_root(&self, node_id: NodeId) -> bool { + let node = self.nodes.get(node_id).expect("Unknown node"); + node.parent.is_none() + } } /// Trie node. @@ -667,6 +663,7 @@ fn shared_prefix(left: &[u32], right: &[u32], block_size: usize) -> usize { #[cfg(test)] mod tests { use std::sync::Arc; + use tracing_test::traced_test; use super::*; @@ -768,6 +765,29 @@ mod tests { assert_eq!(cache.free_blocks.len(), 5); } + #[traced_test] + #[test] + fn allocator_frees_fully_overlapping_prefills_multi_adapter() { + let mut cache = RadixAllocator::new(1, 5, None); + let allocation1 = cache + .allocate(0, 4, Some(Arc::new(vec![0, 1, 2, 3]))) + .unwrap(); + cache.free(allocation1.blocks.clone(), allocation1.allocation_id); + + let allocation2 = cache + .allocate(1, 4, Some(Arc::new(vec![0, 1, 2, 3]))) + .unwrap(); + cache.free(allocation2.blocks.clone(), allocation2.allocation_id); + + let allocation3 = cache + .allocate(0, 4, Some(Arc::new(vec![0, 1, 2, 3]))) + .unwrap(); + assert_eq!(allocation3.prefix_len, 0); + + // 5 blocks, of which 1 reserved for health checks, 4 for the cached blocks. + assert_eq!(cache.free_blocks.len(), 0); + } + #[test] fn allocator_frees_partially_overlapping_prefills() { let mut cache = RadixAllocator::new(1, 20, None); @@ -906,7 +926,7 @@ mod tests { let mut blocks = Vec::new(); // Remove less than the leave blocks. - assert_eq!(trie.evict(0, 1), vec![7]); + assert_eq!(trie.evict(1), vec![7]); trie.find(0, &[0, 1, 2, 3, 5, 6, 7], &mut blocks); assert_eq!(blocks, vec![0, 1, 2, 3, 5, 6]); @@ -915,7 +935,7 @@ mod tests { trie.find(0, &[1, 2, 3], &mut blocks); // Remove the leave blocks exactly. - assert_eq!(trie.evict(0, 2), vec![5, 6]); + assert_eq!(trie.evict(2), vec![5, 6]); blocks.clear(); trie.find(0, &[0, 1, 2, 3, 5, 6, 7], &mut blocks); assert_eq!(blocks, vec![0, 1, 2, 3]); @@ -923,12 +943,12 @@ mod tests { trie.find(0, &[1, 2, 3], &mut blocks); // Remove more than the leave blocks. - assert_eq!(trie.evict(0, 3), vec![4, 3, 2]); + assert_eq!(trie.evict(3), vec![4, 3, 2]); blocks.clear(); trie.find(0, &[0, 1, 2, 3, 4], &mut blocks); assert_eq!(blocks, vec![0, 1]); // Clear out the whole trie. 
- assert_eq!(trie.evict(0, 10), vec![1, 2, 3, 0, 1]); + assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]); } } From a6f3a17e7e94dd99fa569b69921ab93baec7a967 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Tue, 29 Oct 2024 22:25:55 -0700 Subject: [PATCH 30/76] More tests --- router/src/radix.rs | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/router/src/radix.rs b/router/src/radix.rs index 637ec2f65..21b603c4c 100644 --- a/router/src/radix.rs +++ b/router/src/radix.rs @@ -719,6 +719,35 @@ mod tests { assert_eq!(allocation.prefix_len, 4); } + #[test] + fn allocator_reuses_prefixes_multi_adapter() { + let mut cache = RadixAllocator::new(1, 20, None); + + // Allocate 8 tokens: 4 tokens in prefill + 4 slots for generation + let allocation = cache + .allocate(0, 8, Some(Arc::new(vec![0, 1, 2, 3]))) + .unwrap(); + assert_eq!(allocation.blocks, vec![12, 13, 14, 15, 16, 17, 18, 19]); + assert_eq!(allocation.blocks, allocation.slots); + assert_eq!(allocation.prefix_len, 0); + cache.free(allocation.blocks.clone(), allocation.allocation_id); + + // 4 new blocks, 4 reused blocks from unused slots that were freed above. + let allocation = cache + .allocate(1, 8, Some(Arc::new(vec![0, 1, 2, 3]))) + .unwrap(); + assert_eq!(allocation.blocks, vec![8, 9, 10, 11, 16, 17, 18, 19]); + assert_eq!(allocation.prefix_len, 0); + cache.free(allocation.blocks.clone(), allocation.allocation_id); + + // Same blocks as the first allocation, as cache was never evicted. + let allocation = cache + .allocate(0, 8, Some(Arc::new(vec![0, 1, 2, 3]))) + .unwrap(); + assert_eq!(allocation.blocks, vec![12, 13, 14, 15, 16, 17, 18, 19]); + assert_eq!(allocation.prefix_len, 4); + } + #[test] fn allocator_collects_older_prefixes_first() { let mut cache = RadixAllocator::new(1, 7, None); From 352c92a3b27a06bdf0f3bae1ca5b8ab839009a2d Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Wed, 30 Oct 2024 08:55:25 -0700 Subject: [PATCH 31/76] Docker test --- .github/workflows/build.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 2255cc086..4d19ad305 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -70,7 +70,7 @@ jobs: images: | ghcr.io/predibase/lorax tags: | - type=raw,value=optimizations-2,enable=${{ github.ref == 'refs/heads/optimizations' }} + type=raw,value=optimizations-3,enable=${{ github.ref == 'refs/heads/optimizations' }} - name: Create a hash from tags env: From 1ea8d6efcc93abc7f4646607e5381c4085d31ed6 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Wed, 30 Oct 2024 14:44:20 -0700 Subject: [PATCH 32/76] Bump flashinfer --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index e7aa5a3dd..eccefae58 100644 --- a/Dockerfile +++ b/Dockerfile @@ -216,7 +216,7 @@ COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-31 RUN pip install einops --no-cache-dir # Install flashinfer -RUN pip install --no-cache-dir flashinfer==0.1.5+cu124torch2.4 -i https://flashinfer.ai/whl/cu124/torch2.4 +RUN pip install --no-cache-dir flashinfer==0.1.6 -i https://flashinfer.ai/whl/cu124/torch2.4 # Install server COPY proto proto From c0640f28d597f02eaeec875ba5cbf6f92b5a825d Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 31 Oct 2024 09:36:42 -0700 Subject: [PATCH 33/76] Added logprobs fix --- router/src/batch.rs | 13 +++++++++++++ router/src/scheduler.rs | 10 +++++++++- 2 files changed, 22 insertions(+), 1 deletion(-) 
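Note on the change below ([PATCH 33/76] "Added logprobs fix"): prefix caching and prefill logprobs do not compose, because prompt tokens served out of the radix cache never run through the model, so there are no logits to report for them. The diff that follows adds a decoder_input_details() accessor to ValidRequest and, when it returns true, withholds the request's input_ids from the block allocator so no cached prefix is matched and the whole prompt is recomputed. A minimal sketch of that gating decision, using hypothetical Request and cacheable_input_ids names rather than the router's actual types:

// Sketch only: illustrative types, not LoRAX's actual router API.
struct Request {
    input_ids: Vec<u32>,
    decoder_input_details: bool, // true => caller wants per-token prefill logprobs
}

// Token ids offered to the prefix cache. Returning None forces a full
// recompute of the prompt so every prompt token gets a real logprob.
fn cacheable_input_ids(req: &Request) -> Option<&[u32]> {
    if req.decoder_input_details {
        None
    } else {
        Some(req.input_ids.as_slice())
    }
}

fn main() {
    let with_logprobs = Request { input_ids: vec![0, 1, 2, 3], decoder_input_details: true };
    let plain = Request { input_ids: vec![0, 1, 2, 3], decoder_input_details: false };
    assert_eq!(cacheable_input_ids(&with_logprobs), None);
    assert_eq!(cacheable_input_ids(&plain), Some(&[0u32, 1, 2, 3][..]));
    println!("prefix-cache gating behaves as expected");
}

The trade-off is deliberate: requests asking for decoder input details pay full prefill cost, in exchange for logprobs on every prompt token.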
diff --git a/router/src/batch.rs b/router/src/batch.rs index 2be4bcdbe..18d0b40ad 100644 --- a/router/src/batch.rs +++ b/router/src/batch.rs @@ -22,6 +22,7 @@ use crate::{ }; pub(crate) trait ValidRequest: Sync + Send + Debug + Any { + fn decoder_input_details(&self) -> bool; fn input_length(&self) -> u32; fn input_ids(&self) -> Option>>; fn max_new_tokens(&self) -> u32; @@ -31,6 +32,10 @@ pub(crate) trait ValidRequest: Sync + Send + Debug + Any { } impl ValidRequest for ValidGenerateRequest { + fn decoder_input_details(&self) -> bool { + self.decoder_input_details + } + fn input_length(&self) -> u32 { self.input_length } @@ -69,6 +74,10 @@ pub(crate) struct ValidEmbedRequest { } impl ValidRequest for ValidEmbedRequest { + fn decoder_input_details(&self) -> bool { + false + } + fn input_length(&self) -> u32 { self.input_length } @@ -107,6 +116,10 @@ pub(crate) struct ValidClassifyRequest { } impl ValidRequest for ValidClassifyRequest { + fn decoder_input_details(&self) -> bool { + false + } + fn input_length(&self) -> u32 { self.input_length } diff --git a/router/src/scheduler.rs b/router/src/scheduler.rs index 4231b3894..22bcb2418 100644 --- a/router/src/scheduler.rs +++ b/router/src/scheduler.rs @@ -365,6 +365,14 @@ impl AdapterSchedulerState { None } Some(block_allocator) => { + // If users wants the prefill logprobs, we cannot reuse the cache. + // So no input_ids for the radix tree. + let input_ids = if entry.request.decoder_input_details() { + None + } else { + entry.request.input_ids().clone() + }; + let tokens = entry.request.input_length() + entry.request.max_new_tokens() + self.speculate @@ -379,7 +387,7 @@ impl AdapterSchedulerState { ); let block_allocation = match block_allocator - .allocate(adapter.index(), tokens, entry.request.input_ids()) + .allocate(adapter.index(), tokens, input_ids) .await { None => { From 54c36c9e1225d248bb4ce784064124c6237ae368 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 31 Oct 2024 13:16:25 -0700 Subject: [PATCH 34/76] Fix slots --- router/src/radix.rs | 44 +++++++++++----- server/lorax_server/adapters/medusa.py | 4 +- server/lorax_server/models/flash_causal_lm.py | 50 +++++++++++++++---- server/lorax_server/server.py | 11 ++-- .../utils/flashinfer_attention.py | 14 ++++-- server/lorax_server/utils/graph.py | 5 +- 6 files changed, 92 insertions(+), 36 deletions(-) diff --git a/router/src/radix.rs b/router/src/radix.rs index 21b603c4c..75430f37f 100644 --- a/router/src/radix.rs +++ b/router/src/radix.rs @@ -85,6 +85,14 @@ impl Allocator for RadixAllocator { tokens: u32, prefill_tokens: Option>>, ) -> Option { + // print out blocks for allocation + tracing::debug!( + "!!! Allocate blocks {:?} {:?} {:?}", + adapter_index, + tokens, + prefill_tokens.as_ref().as_slice() + ); + // ensure root node exists self.cache_blocks.get_or_create_root(adapter_index); @@ -110,6 +118,7 @@ impl Allocator for RadixAllocator { let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size; tracing::debug!("Prefix {prefix_len} - Suffix {suffix_len}"); + tracing::debug!("Cached blocks: {blocks:?}"); match self.alloc_or_reclaim(suffix_blocks as usize) { Some(suffix_blocks) => blocks.extend(suffix_blocks), @@ -150,6 +159,15 @@ impl Allocator for RadixAllocator { self.allocation_id += 1; self.allocations.insert(self.allocation_id, allocation); + // log final blocks and slots + tracing::debug!( + "!!! 
BlockAllocation {:?} {:?} {:?} {:?}", + adapter_index, + blocks, + slots, + prefix_len + ); + Some(BlockAllocation { allocation_id: self.allocation_id, block_allocator: None, @@ -165,6 +183,13 @@ impl Allocator for RadixAllocator { None => unreachable!("Tried to free an unknown allocation."), }; + tracing::debug!( + "!!! Free blocks {:?} {:?} {:?}", + allocation.adapter_index, + allocation.cached_prefix_len, + allocation.prefill_tokens.as_ref().as_slice() + ); + self.cache_blocks .decref(allocation.prefix_node) .expect("Failed to decrement refcount"); @@ -286,17 +311,11 @@ impl RadixTrie { /// Using this method will update the access time of the traversed nodes. pub fn find(&mut self, adapter_index: u32, key: &[u32], blocks: &mut Vec) -> NodeId { self.time += 1; - self.find_(adapter_index, self.root_id(adapter_index), key, blocks) + self.find_(self.root_id(adapter_index), key, blocks) } /// Find worker. - fn find_( - &mut self, - adapter_index: u32, - mut node_id: NodeId, - key: &[u32], - blocks: &mut Vec, - ) -> NodeId { + fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec) -> NodeId { let node = &self.nodes[node_id]; if key.len() >= self.block_size { @@ -310,7 +329,7 @@ impl RadixTrie { let key = &key[shared_prefix_len..]; if !key.is_empty() { - node_id = self.find_(adapter_index, child_id, key, blocks); + node_id = self.find_(child_id, key, blocks); } } } @@ -430,14 +449,13 @@ impl RadixTrie { ) -> Result { self.time += 1; let node_id = self.get_or_create_root(adapter_index); - let common = self.insert_(adapter_index, node_id, tokens, blocks)?; + let common = self.insert_(node_id, tokens, blocks)?; Ok(common) } /// Insertion worker. fn insert_( &mut self, - adapter_index: u32, node_id: NodeId, tokens: &[u32], blocks: &[u32], @@ -467,7 +485,6 @@ impl RadixTrie { if shared_prefix_len == child.key.len() { return Ok(shared_prefix_len + self.insert_( - adapter_index, child_id, &tokens[shared_prefix_len..], &blocks[shared_prefix_len / self.block_size..], @@ -480,7 +497,7 @@ impl RadixTrie { let child_id = self.split_node(child_id, shared_prefix_len); let key = &tokens[shared_prefix_len..]; let blocks = &blocks[shared_prefix_len / self.block_size..]; - Ok(shared_prefix_len + self.insert_(adapter_index, child_id, key, blocks)?) + Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?) 
} else { self.add_node(node_id, tokens, blocks); Ok(0) @@ -719,6 +736,7 @@ mod tests { assert_eq!(allocation.prefix_len, 4); } + #[traced_test] #[test] fn allocator_reuses_prefixes_multi_adapter() { let mut cache = RadixAllocator::new(1, 20, None); diff --git a/server/lorax_server/adapters/medusa.py b/server/lorax_server/adapters/medusa.py index 9c3dcaff5..ec55ca608 100644 --- a/server/lorax_server/adapters/medusa.py +++ b/server/lorax_server/adapters/medusa.py @@ -312,11 +312,11 @@ def load( default_medusa=default_medusa, segments=MedusaSegments( w=[ - (adapter_weights[idx].model.medusa.linear.linear.weight if idx in adapter_weights else EMPTY_TENSOR) + (adapter_weights[idx].model.medusa.linear.linear.weight.data if idx in adapter_weights else EMPTY_TENSOR) for idx in segment_indices ], b=[ - (adapter_weights[idx].model.medusa.linear.linear.bias if idx in adapter_weights else EMPTY_TENSOR) + (adapter_weights[idx].model.medusa.linear.linear.bias.data if idx in adapter_weights else EMPTY_TENSOR) for idx in segment_indices ], s_start=segments[indices], diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index a31a94c70..1011697ae 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -440,7 +440,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": # Set slice slot_filtering_indices[ - self.slot_indices[idx] : self.slot_indices[idx] + request_input_length + remaining_tokens - 1 + self.slot_indices[idx] : self.slot_indices[idx] + request_input_length + request_cache_length + remaining_tokens - 1 ] = True cumulative_max_length += request_input_length + remaining_tokens - 1 @@ -486,6 +486,9 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": segment_indices=adapter_segment_indices, ) + logger.info("!!! FILTER slots {} -> {}", self.slots, slots) + logger.info("!!! FILTER slots_indices {} -> {}", self.slot_indices, slot_indices) + return type(self)( batch_id=self.batch_id, requests=requests, @@ -696,15 +699,18 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch ) # Discard speculative IDs if they are not present in all batches - keep_speculative_ids = all(b.speculative_ids is not None for b in batches) - if not keep_speculative_ids: - logger.info("Discarding speculative IDs, not every batch has them") - - speculative_ids = ( - torch.cat( - [b.speculative_ids for b in batches], dim=0) - if keep_speculative_ids else None - ) + if get_speculative_tokens() > 0: + keep_speculative_ids = all(b.speculative_ids is not None for b in batches) + if not keep_speculative_ids: + logger.info("Discarding speculative IDs, not every batch has them") + + speculative_ids = ( + torch.cat( + [b.speculative_ids for b in batches], dim=0) + if keep_speculative_ids else None + ) + else: + speculative_ids = None if adapter_segment_builder is not None: adapter_segments, adapter_segment_indices = adapter_segment_builder.build() @@ -715,6 +721,9 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch adapter_segments=adapter_segments, segment_indices=adapter_segment_indices, ) + + logger.info("!!! CONCATENATE slots {} -> {}", [b.slots for b in batches], slots) + logger.info("!!! CONCATENATE slots_indices {} -> {}", [b.slot_indices for b in batches], slot_indices) return cls( batch_id=batches[0].batch_id, @@ -922,6 +931,9 @@ def prepare_for_prefill(self): segment_indices=adapter_segment_indices, ) + logger.info("!!! 
PREPARE_FOR_PREFILL slots {}", self.slots) + logger.info("!!! PREPARE_FOR_PREFILL slots_indices {}", self.slot_indices) + def __len__(self): return len(self.requests) @@ -1354,6 +1366,11 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> cache_lengths_tensor = batch.cache_lengths_tensor max_s = batch.max_current_length + logger.info("!!! BLOCKS={} {}\n SLOTS={} {}\n SLOT_INDICES={} {}", + block_tables.tolist(), block_tables.shape, + batch.slots.tolist(), batch.slots.shape, + batch.slot_indices.tolist(), batch.slot_indices.shape) + if batch.speculative_ids is not None: speculative_ids = batch.speculative_ids @@ -1363,9 +1380,19 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0) arange_int = arange.to(dtype=torch.int32) new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1) - slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) + + slot_indices = (batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) + logger.info("!!! SLOT INDICES {} -> {}", batch.slot_indices.tolist(), slot_indices.tolist()) + + slots = batch.slots[slot_indices] + + # slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) + logger.info("!!! NEW SLOTS {}", slots.tolist(), slots.shape) + + logger.info("!!! BEFORE {} {}", input_lengths, batch.cache_lengths_tensor) input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) cache_lengths_tensor = (batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)).reshape(-1) + logger.info("!!! AFTER {} {}", input_lengths, cache_lengths_tensor) block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B * new_length, -1).contiguous() max_s = max_s + speculative_length @@ -1946,6 +1973,7 @@ def generate_token( batch.next_token_chooser.next_state(i, next_token_id) # Update values + logger.info(f"!!! UPDATE VALUES {i} n_accepted_ids={n_accepted_ids} new_input_length={new_input_length} input_length={input_length} cache_length={cache_length}") index += n_accepted_ids current_cache_length = cache_length + input_length batch.cache_lengths[i] = current_cache_length diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index 1801f4d21..796c377df 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -58,10 +58,13 @@ async def ServiceDiscovery(self, request, context): return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls) async def ClearCache(self, request, context): - if request.HasField("id"): - self.cache.delete(request.id) - else: - self.cache.clear() + try: + if request.HasField("id"): + self.cache.delete(request.id) + else: + self.cache.clear() + except: + exit(1) return generate_pb2.ClearCacheResponse() diff --git a/server/lorax_server/utils/flashinfer_attention.py b/server/lorax_server/utils/flashinfer_attention.py index 2accc60ca..aa3c8c118 100644 --- a/server/lorax_server/utils/flashinfer_attention.py +++ b/server/lorax_server/utils/flashinfer_attention.py @@ -53,7 +53,9 @@ def use_prefill_with_paged_kv_state( `attention` function while the context manager is active. 
""" - indptr = torch.zeros(input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32) + indptr = torch.zeros( + input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32 + ) # Round up to page size and then calculate the cumulative sum to get # the indices into the block table. torch.add(input_lengths, page_size - 1, out=indptr[1:]) @@ -62,9 +64,13 @@ def use_prefill_with_paged_kv_state( # Get the lengths of the last page in a block. if page_size == 1: - last_page_len = torch.ones(input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device) + last_page_len = torch.ones( + input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device + ) else: - last_page_len = torch.empty(input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device) + last_page_len = torch.empty( + input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device + ) torch.sub(input_lengths, 1, out=last_page_len) last_page_len.remainder_(page_size) last_page_len += 1 @@ -81,7 +87,7 @@ def use_prefill_with_paged_kv_state( head_dim=head_size, q_data_type=dtype, page_size=page_size, - # window_left=window_left, # TODO + window_left=window_left, ) yield finally: diff --git a/server/lorax_server/utils/graph.py b/server/lorax_server/utils/graph.py index a83e22451..734bcda4e 100644 --- a/server/lorax_server/utils/graph.py +++ b/server/lorax_server/utils/graph.py @@ -15,7 +15,7 @@ from lorax_server.adapters import AdapterBatchData, AdapterBatchMetadata from lorax_server.adapters.lora import BatchLoraWeights, RankSegments -from lorax_server.adapters.types import LORA +from lorax_server.adapters.types import LORA, MEDUSA from lorax_server.utils.attention.common import Seqlen from lorax_server.utils.attention.utils import block_tables_to_ragged from lorax_server.utils.sgmv import BGMV_MAX_RANK, PunicaWrapper @@ -607,7 +607,8 @@ def forward( graph.input_state.traced_adapter_layer_names if graph is not None else set() ) logger.info( - "Retrace graph with new adapter layers: {} -> {}", + "batch_size={} -- retrace graph with new adapter layers: {} -> {}", + batch_size, current_traced_adapter_layer_names, adapter_data.layer_names(), ) From 88cd932f894e701bb0bb2e51ef0411b69c2fb504 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 31 Oct 2024 13:23:12 -0700 Subject: [PATCH 35/76] No debugging --- server/lorax_server/models/flash_causal_lm.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 1011697ae..70cbaf90b 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -486,8 +486,8 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": segment_indices=adapter_segment_indices, ) - logger.info("!!! FILTER slots {} -> {}", self.slots, slots) - logger.info("!!! FILTER slots_indices {} -> {}", self.slot_indices, slot_indices) + # logger.info("!!! FILTER slots {} -> {}", self.slots, slots) + # logger.info("!!! FILTER slots_indices {} -> {}", self.slot_indices, slot_indices) return type(self)( batch_id=self.batch_id, @@ -722,8 +722,8 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch segment_indices=adapter_segment_indices, ) - logger.info("!!! CONCATENATE slots {} -> {}", [b.slots for b in batches], slots) - logger.info("!!! CONCATENATE slots_indices {} -> {}", [b.slot_indices for b in batches], slot_indices) + # logger.info("!!! 
CONCATENATE slots {} -> {}", [b.slots for b in batches], slots) + # logger.info("!!! CONCATENATE slots_indices {} -> {}", [b.slot_indices for b in batches], slot_indices) return cls( batch_id=batches[0].batch_id, @@ -931,8 +931,8 @@ def prepare_for_prefill(self): segment_indices=adapter_segment_indices, ) - logger.info("!!! PREPARE_FOR_PREFILL slots {}", self.slots) - logger.info("!!! PREPARE_FOR_PREFILL slots_indices {}", self.slot_indices) + # logger.info("!!! PREPARE_FOR_PREFILL slots {}", self.slots) + # logger.info("!!! PREPARE_FOR_PREFILL slots_indices {}", self.slot_indices) def __len__(self): return len(self.requests) @@ -1366,10 +1366,10 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> cache_lengths_tensor = batch.cache_lengths_tensor max_s = batch.max_current_length - logger.info("!!! BLOCKS={} {}\n SLOTS={} {}\n SLOT_INDICES={} {}", - block_tables.tolist(), block_tables.shape, - batch.slots.tolist(), batch.slots.shape, - batch.slot_indices.tolist(), batch.slot_indices.shape) + # logger.info("!!! BLOCKS={} {}\n SLOTS={} {}\n SLOT_INDICES={} {}", + # block_tables.tolist(), block_tables.shape, + # batch.slots.tolist(), batch.slots.shape, + # batch.slot_indices.tolist(), batch.slot_indices.shape) if batch.speculative_ids is not None: speculative_ids = batch.speculative_ids @@ -1382,17 +1382,17 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1) slot_indices = (batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) - logger.info("!!! SLOT INDICES {} -> {}", batch.slot_indices.tolist(), slot_indices.tolist()) + # logger.info("!!! SLOT INDICES {} -> {}", batch.slot_indices.tolist(), slot_indices.tolist()) slots = batch.slots[slot_indices] # slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) - logger.info("!!! NEW SLOTS {}", slots.tolist(), slots.shape) + # logger.info("!!! NEW SLOTS {}", slots.tolist(), slots.shape) - logger.info("!!! BEFORE {} {}", input_lengths, batch.cache_lengths_tensor) + # logger.info("!!! BEFORE {} {}", input_lengths, batch.cache_lengths_tensor) input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) cache_lengths_tensor = (batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)).reshape(-1) - logger.info("!!! AFTER {} {}", input_lengths, cache_lengths_tensor) + # logger.info("!!! AFTER {} {}", input_lengths, cache_lengths_tensor) block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B * new_length, -1).contiguous() max_s = max_s + speculative_length @@ -1973,7 +1973,7 @@ def generate_token( batch.next_token_chooser.next_state(i, next_token_id) # Update values - logger.info(f"!!! UPDATE VALUES {i} n_accepted_ids={n_accepted_ids} new_input_length={new_input_length} input_length={input_length} cache_length={cache_length}") + # logger.info(f"!!! 
UPDATE VALUES {i} n_accepted_ids={n_accepted_ids} new_input_length={new_input_length} input_length={input_length} cache_length={cache_length}") index += n_accepted_ids current_cache_length = cache_length + input_length batch.cache_lengths[i] = current_cache_length From 3505b52bece3c2d96d0ef6e2f9de8ff5709ad20c Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 31 Oct 2024 13:23:32 -0700 Subject: [PATCH 36/76] Docker test --- .github/workflows/build.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 4d19ad305..bf052f7d1 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -70,7 +70,7 @@ jobs: images: | ghcr.io/predibase/lorax tags: | - type=raw,value=optimizations-3,enable=${{ github.ref == 'refs/heads/optimizations' }} + type=raw,value=optimizations-4,enable=${{ github.ref == 'refs/heads/optimizations' }} - name: Create a hash from tags env: From cf3d2d90bf248d856aa40504dc8ad72e17c645c5 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 31 Oct 2024 14:44:31 -0700 Subject: [PATCH 37/76] Fixed slot filtering --- server/lorax_server/models/flash_causal_lm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 70cbaf90b..a0855f923 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -440,7 +440,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": # Set slice slot_filtering_indices[ - self.slot_indices[idx] : self.slot_indices[idx] + request_input_length + request_cache_length + remaining_tokens - 1 + self.slot_indices[idx] : self.slot_indices[idx] + request_input_length + remaining_tokens + get_speculative_tokens() - 1 ] = True cumulative_max_length += request_input_length + remaining_tokens - 1 From d1ff7b4b1548ca5f8bfb52bff49b83a6c65de7f7 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 31 Oct 2024 15:08:25 -0700 Subject: [PATCH 38/76] Triton kernels --- server/lorax_server/models/flash_causal_lm.py | 421 ++++++++++++------ .../lorax_server/models/metadata_kernels.py | 347 +++++++++++++++ 2 files changed, 627 insertions(+), 141 deletions(-) create mode 100644 server/lorax_server/models/metadata_kernels.py diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index a0855f923..ee8a42e79 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -44,6 +44,14 @@ ) from lorax_server.utils.tokenizer import TokenizerManager from lorax_server.utils.weights import Weights +from lorax_server.models.metadata_kernels import ( + has_triton, + copy_next_input_ids_inplace, + block_tables_to_ragged, + block_tables_to_padded, + prepare_position_slot_ids, + slots_filtering, +) ADAPTER_MEMORY_FRACTION = float(os.getenv("ADAPTER_MEMORY_FRACTION", "0.1")) @@ -84,6 +92,10 @@ class FlashCausalLMBatch(Batch): # Will be set by `generate_token` and reset after each prefill forward before staying set in decode slots: Optional[torch.Tensor] + # list of length b + 1 containing the cumulative sequence slot lengths of the sequences in the batch + # used for filtering + cu_slots: torch.Tensor + max_input_length: int max_current_length: int @@ -94,7 +106,7 @@ class FlashCausalLMBatch(Batch): prefilling_mask_tensor: Optional[torch.Tensor] # Prefill metadata tensors to efficiently compute logprobs - # 
tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill + # tensor of length b+1 containing the cumulative sequence lengths of the sequences in the batch, only used in prefill cu_seqlen_prefill: Optional[torch.Tensor] # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers # as we only keep SLIDING_WINDOW values instead of the whole tensor @@ -196,6 +208,8 @@ def from_pb( all_input_ids = [] all_postfix_ids = [] requests_idx_mapping = {} + slots = [] + cu_slots = [0] next_token_chooser_parameters = [] stopping_criterias = [] @@ -206,7 +220,9 @@ def from_pb( max_length = 0 max_blocks = 0 + cu_blocks = [0] block_tables = [] + block_tables_ragged = [] # Parse batch for i, (r, tokenized_input) in enumerate(zip(pb.requests, batch_tokenized_inputs)): @@ -266,10 +282,22 @@ def from_pb( if not r.blocks: needed_blocks = math.ceil(block_tokens / BLOCK_SIZE) request_blocks = [b for b in range(num_blocks, num_blocks + needed_blocks)] + request_slots = [ + s + for b in request_blocks + for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE) + ] else: request_blocks = r.blocks + request_slots = r.slots block_tables.append(request_blocks) + block_tables_ragged.extend(request_blocks) + cu_blocks.append(len(block_tables_ragged)) + + slots.extend(request_slots) + cu_slots.append(len(slots)) + cache_lengths.append(cache_length) num_blocks += len(request_blocks) @@ -297,11 +325,33 @@ def from_pb( # Create tensors on device all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64, device=device) - block_tables_tensor = torch.zeros((len(block_tables), max_blocks), dtype=torch.int32, device="cpu") - for i, request_blocks in enumerate(block_tables): - block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) - block_tables_tensor = block_tables_tensor.to(device) - prompt_lengths_tensor = torch.tensor(prompt_lengths, dtype=torch.int32, device=device) + block_tables_ragged = torch.tensor( + block_tables_ragged, device=device, dtype=torch.int32 + ) + cu_blocks = torch.tensor(cu_blocks, device=device, dtype=torch.int64) + block_tables_tensor = torch.empty( + (len(block_tables), max_blocks), + device=device, + dtype=torch.int32, + ) + + # If the device supports Triton, we can use a fused kernel + if has_triton(): + block_tables_to_padded( + max_blocks, cu_blocks, block_tables_tensor, block_tables_ragged + ) + else: + for i, request_blocks in enumerate(block_tables): + block_tables_tensor[i, : len(request_blocks)] = torch.tensor( + request_blocks + ) + + prompt_lengths_tensor = torch.tensor( + prompt_lengths, dtype=torch.int32, device=device + ) + + slots = torch.tensor(slots, dtype=torch.int64, device=device) + cu_slots = torch.tensor(cu_slots, dtype=torch.int64) prefilling_mask = [True] * len(pb.requests) prefilling_mask_tensor = torch.tensor(prefilling_mask, dtype=torch.bool, device=device) @@ -337,7 +387,8 @@ def from_pb( cu_seqlen_prefill=None, prefill_cache_indices=None, slot_indices=None, - slots=None, + slots=slots, + cu_slots=cu_slots, prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, @@ -363,7 +414,11 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": indices = [] # slots to keep after filtering - slot_filtering_indices = torch.zeros(self.slots.shape[0], dtype=torch.bool, device=device) + if not has_triton(): + # slots to keep after filtering + slot_filtering_indices = torch.zeros( + self.slots.shape[0], 
dtype=torch.bool, device=device + ) # Create on CPU to only move to GPU once instead of at every copy slot_indices = torch.empty(len(request_ids), dtype=torch.int64) @@ -380,6 +435,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": cache_lengths = [] prefix_offsets = [] read_offsets = [] + cu_slots = [0] prefilling_mask = [] prefill_logprob_tokens = [] @@ -389,8 +445,8 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": num_blocks = 0 max_blocks = 0 - # Cumulative length - cumulative_max_length = 0 + max_slots = 0 + cumulative_slot_tokens = 0 for i, request_id in enumerate(request_ids): idx = self.requests_idx_mapping[request_id] @@ -428,24 +484,27 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": num_blocks += len(request_block_table) block_tables.append(request_block_table) + start_slot = self.cu_slots[idx] + end_slot = self.cu_slots[idx + 1] + slot_length = end_slot - start_slot + + if not has_triton(): + # Set slice + slot_filtering_indices[start_slot:end_slot] = True + + cu_slots.append(cumulative_slot_tokens + slot_length) + # Input ids if the request was part of a prefilling batch # If the batch was decoding we can index into the tensor directly later if self.prefilling: input_ids.append(self.input_ids[idx]) else: # Copy to tensor (CPU) - slot_indices[i] = cumulative_max_length - - remaining_tokens = stopping_criteria.max_new_tokens - stopping_criteria.current_tokens - - # Set slice - slot_filtering_indices[ - self.slot_indices[idx] : self.slot_indices[idx] + request_input_length + remaining_tokens + get_speculative_tokens() - 1 - ] = True - - cumulative_max_length += request_input_length + remaining_tokens - 1 + slot_indices[i] = cumulative_slot_tokens + request_cache_length + cumulative_slot_tokens += slot_length max_blocks = max(max_blocks, len(request_block_table)) + max_slots = max(max_slots, slot_length) all_input_ids_tensor = self.all_input_ids_tensor[indices] @@ -453,12 +512,22 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": next_token_chooser = self.next_token_chooser.filter(indices) speculative_ids = self.speculative_ids[indices] if self.speculative_ids is not None else None prompt_lengths_tensor = self.prompt_lengths_tensor[indices] + cu_slots = torch.tensor(cu_slots, dtype=torch.int64) + + if not has_triton(): + slots = self.slots[slot_filtering_indices] + else: + slots = self.slots.new_empty(cumulative_slot_tokens) + gpu_cu_slots = cu_slots.to(device) + slots_indexing_start = self.cu_slots.to(device)[indices] + slots_filtering( + max_slots, self.slots, slots, gpu_cu_slots, slots_indexing_start + ) if self.prefilling: # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` position_ids = None slot_indices = None - slots = None cache_lengths_tensor = None input_lengths_tensor = None adapter_meta = None @@ -469,7 +538,6 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": position_ids = self.position_ids[indices] adapter_indices = self.adapter_meta.adapter_indices[indices] input_lengths_tensor = self.input_lengths_tensor[indices] - slots = self.slots[slot_filtering_indices] cache_lengths_tensor = self.cache_lengths_tensor[indices] prefilling_mask_tensor = None @@ -502,6 +570,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, + cu_slots=cu_slots, max_input_length=max_input_length, max_current_length=max_current_length, prefilling=self.prefilling, @@ -562,11 +631,12 @@ 
def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch ) prefilling = prefilling or b.prefilling + slots = batches[0].slots.new_empty(total_slots) + cu_slots = torch.zeros(total_batch_size + 1, dtype=torch.int64) if prefilling: input_ids = [] # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` position_ids = None - slots = None slot_indices = None cache_lengths_tensor = None input_lengths_tensor = None @@ -576,7 +646,6 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch else: input_ids = batches[0].input_ids.new_empty(total_batch_size) position_ids = batches[0].position_ids.new_empty(total_batch_size) - slots = batches[0].slots.new_empty(total_slots) slot_indices = batches[0].slot_indices.new_empty(total_batch_size) input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(total_batch_size) cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty(total_batch_size) @@ -636,13 +705,16 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor - if not prefilling: - slots_start_index = cumulative_slots - slots_end_index = cumulative_slots + len(batch.slots) + slots_start_index = cumulative_slots + slots_end_index = cumulative_slots + len(batch.slots) + slots[slots_start_index:slots_end_index] = batch.slots + cu_slots[start_index + 1 : end_index + 1] = ( + batch.cu_slots[1:] + cumulative_slots + ) + if not prefilling: input_ids[start_index:end_index] = batch.input_ids position_ids[start_index:end_index] = batch.position_ids - slots[slots_start_index:slots_end_index] = batch.slots slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor @@ -658,9 +730,6 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices, ) - - # Update - cumulative_slots += len(batch.slots) else: if isinstance(batch.input_ids, torch.Tensor): batch.input_ids = batch.input_ids.view(-1, 1).tolist() @@ -688,6 +757,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch stopping_criterias.extend(batch.stopping_criterias) # Update + cumulative_slots += len(batch.slots) cumulative_batch_size += len(batch) next_token_chooser = HeterogeneousNextTokenChooser.from_pb( @@ -740,6 +810,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch cache_lengths=cache_lengths, cache_lengths_tensor=cache_lengths_tensor, slots=slots, + cu_slots=cu_slots, max_input_length=max_input_length, max_current_length=max_current_length, prefilling=prefilling, @@ -773,14 +844,49 @@ def prepare_for_prefill(self): # it simplifies everything assert self.speculative_ids is None + device = self.block_tables_tensor.device + + if isinstance(self.input_ids, list): + if len(self) > 1: + input_ids = np.concatenate(self.input_ids, dtype=np.int64) + else: + input_ids = self.input_ids[0] + self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + + self.input_lengths_tensor = torch.tensor( + self.input_lengths, dtype=torch.int32, device=device + ) + self.cu_seqlen_prefill = torch.nn.functional.pad( + torch.cumsum(self.input_lengths_tensor, dim=0), (1, 0) + ).to(torch.int32) + self.cache_lengths_tensor = torch.tensor( + self.cache_lengths, 
dtype=torch.int32, device=device + ) + + # If the device supports Triton, we can use a fused kernel + if has_triton(): + self.position_ids = torch.empty( + len(self.input_ids), dtype=torch.int32, device=device + ) + self.slot_indices = torch.empty( + len(self.input_ids), dtype=torch.int64, device=device + ) + cu_slots_gpu = self.cu_slots.to(device) + + prepare_position_slot_ids( + self.max_input_length, + self.cache_lengths_tensor, + self.cu_seqlen_prefill, + cu_slots_gpu, + self.position_ids, + self.slot_indices, + ) + position_ids = [] - cu_seqlen_prefill = [0] slot_indices = [] prefill_cache_indices = [] all_prefill_logprobs = True no_prefill_logprobs = True - prefill_head_indices = [] - prefill_next_token_indices = [] prefill_cu_outlens = [0] # Cumulative length @@ -788,7 +894,6 @@ def prepare_for_prefill(self): cumulative_slot_tokens = 0 prefill_out_cumulative_length = 0 - slots = [] adapter_indices_list = [] adapter_list = [] @@ -810,24 +915,33 @@ def prepare_for_prefill(self): ) ): next_chunk_length = input_length - # Position ids - request_position_ids = torch.arange(cache_length, cache_length + input_length, dtype=torch.int32) - position_ids.append(request_position_ids) + + if not has_triton(): + # Position ids + request_position_ids = torch.arange( + cache_length, cache_length + input_length, dtype=torch.int32 + ) + position_ids.append(request_position_ids) + + if not r.slots: + request_slots = [ + s + for b in blocks + for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE) + ] + else: + request_slots = r.slots - # Add cumulative lengths of all previous inputs - cu_seqlen_prefill.append(cumulative_length + input_length) + request_slot_indices = torch.arange( + cache_length + cumulative_slot_tokens, + cache_length + cumulative_slot_tokens + input_length, + dtype=torch.int64, + ) - if not r.slots: - request_slots = [s for b in blocks for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)] - else: - request_slots = r.slots + slot_indices.append(request_slot_indices) - request_slots = request_slots[cache_length:] - request_slot_indices = torch.arange( - cumulative_slot_tokens, - cumulative_slot_tokens + input_length, - dtype=torch.int64, - ) + # Update + cumulative_slot_tokens += len(request_slots) # Create tensor to slice into the kv tensor in prefill if SLIDING_WINDOW is not None: @@ -844,30 +958,12 @@ def prepare_for_prefill(self): no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs if prefill_logprobs: - prefill_head_indices.append( - torch.arange( - cumulative_length, - cumulative_length + input_length, - dtype=torch.int64, - ) - ) - prefill_next_token_indices.append(prefill_out_cumulative_length + input_length - 1) prefill_cu_outlens.append(prefill_out_cumulative_length + input_length) prefill_out_cumulative_length += input_length else: - prefill_head_indices.append( - torch.tensor( - [cumulative_length + input_length - 1], - dtype=torch.int64, - ) - ) - prefill_next_token_indices.append(prefill_out_cumulative_length) prefill_cu_outlens.append(prefill_out_cumulative_length + 1) prefill_out_cumulative_length += 1 - slots.extend(request_slots) - slot_indices.append(request_slot_indices) - if SLIDING_WINDOW is not None: prefill_cache_indices.append(request_prefill_cache_indices) @@ -876,41 +972,81 @@ def prepare_for_prefill(self): # Update cumulative_length += next_chunk_length - cumulative_slot_tokens += len(request_slots) - device = self.block_tables_tensor.device + if not all_prefill_logprobs and not no_prefill_logprobs: + prefill_head_indices = [] + 
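            # (This second pass only runs for mixed batches: when every request
            # wants prefill logprobs, or none does, both index tensors are derived
            # directly from cu_seqlen_prefill further below.)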
prefill_next_token_indices = [] - if isinstance(self.input_ids, list): - if len(self) > 1: - input_ids = np.concatenate(self.input_ids, dtype=np.int64) - else: - input_ids = self.input_ids[0] - self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + # Cumulative length + cumulative_length = 0 + prefill_out_cumulative_length = 0 + + for i, ( + r, + input_length, + request_prefilling, + ) in enumerate( + zip( + self.requests, + self.input_lengths, + self.prefilling_mask, + ) + ): + # Prefill logprobs is ignored if the request is done prefilling + prefill_logprobs = r.prefill_logprobs and request_prefilling + + if prefill_logprobs: + prefill_head_indices.append( + torch.arange( + cumulative_length, + cumulative_length + input_length, + dtype=torch.int64, + ) + ) + prefill_next_token_indices.append( + prefill_out_cumulative_length + input_length - 1 + ) + prefill_out_cumulative_length += input_length + else: + prefill_head_indices.append( + torch.tensor( + [cumulative_length + input_length - 1], + dtype=torch.int64, + ) + ) + prefill_next_token_indices.append(prefill_out_cumulative_length) + prefill_out_cumulative_length += 1 + + # Update + cumulative_length += input_length if len(self) > 1: - position_ids = torch.cat(position_ids) - slot_indices = torch.cat(slot_indices) + if position_ids: + position_ids = torch.cat(position_ids) + if slot_indices: + slot_indices = torch.cat(slot_indices) if SLIDING_WINDOW is not None: prefill_cache_indices = torch.cat(prefill_cache_indices) else: - position_ids = position_ids[0] - slot_indices = slot_indices[0] + if position_ids: + position_ids = position_ids[0] + if slot_indices: + slot_indices = slot_indices[0] if SLIDING_WINDOW is not None: prefill_cache_indices = prefill_cache_indices[0] + if not has_triton(): + self.position_ids = position_ids.to(device) + self.slot_indices = slot_indices.to(device) + self.prefill_cu_outlens = prefill_cu_outlens - cu_seqlen_prefill = torch.tensor(cu_seqlen_prefill, device=device, dtype=torch.int32) - self.cu_seqlen_prefill = cu_seqlen_prefill - self.position_ids = position_ids.to(device) - self.slot_indices = slot_indices.to(device) self.prefill_cache_indices = prefill_cache_indices.to(device) if SLIDING_WINDOW is not None else None - self.input_lengths_tensor = torch.tensor(self.input_lengths, dtype=torch.int32, device=device) if all_prefill_logprobs: prefill_head_indices = None - prefill_next_token_indices = cu_seqlen_prefill[1:] - 1 + prefill_next_token_indices = self.cu_seqlen_prefill[1:] - 1 elif no_prefill_logprobs: - prefill_head_indices = cu_seqlen_prefill[1:] - 1 + prefill_head_indices = self.cu_seqlen_prefill[1:] - 1 prefill_next_token_indices = None else: prefill_head_indices = torch.cat(prefill_head_indices).to(device) @@ -918,8 +1054,6 @@ def prepare_for_prefill(self): self.prefill_head_indices = prefill_head_indices self.prefill_next_token_indices = prefill_next_token_indices - self.slots = torch.tensor(slots, dtype=torch.int64, device=device) - self.cache_lengths_tensor = torch.tensor(self.cache_lengths, dtype=torch.int32, device=device) adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device) adapter_segments, adapter_segment_indices = find_segments(adapter_indices) adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32, device=device) @@ -1593,11 +1727,12 @@ def generate_token( # Since we are done prefilling, all the tensors that were concatenating values for all the requests # instantly become of shape [BATCH_SIZE] if prefill and 
finished_prefilling: - next_position_ids = batch.position_ids.new_empty(len(batch)) - batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1] - next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(len(batch)) - elif not prefill: - next_position_ids = batch.position_ids + indices = batch.cu_seqlen_prefill[1:] - 1 + batch.position_ids = batch.position_ids[indices] + batch.slot_indices = batch.slot_indices[indices] + batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[ + indices + ] # Zipped iterator iterator = zip( @@ -1616,8 +1751,10 @@ def generate_token( # It is faster if we delay this sync for the maximum amount of time # For each member of the batch - index = 0 # Cumulative length + cu_accepted_ids = torch.nn.functional.pad( + torch.cumsum(accepted_ids, dim=0), (1, 0) + ) cumulative_length = 0 for i, ( request, @@ -1629,53 +1766,55 @@ def generate_token( request_was_prefilling, request_is_prefilling, ) in enumerate(iterator): - if prefill and finished_prefilling: + # Used to gather prefill logprobs + # Copy batch.all_input_ids_tensor to prefill_token_indices + if request.prefill_logprobs and request_was_prefilling: # Indexing metadata - _start_index = cumulative_length - end_index = cumulative_length + input_length - - # Initialize position_ids - # In decode, we do not need this as we can just increment position ids - next_position_ids[i] = batch.position_ids[end_index - 1] - - # Initialize adapter indices - # In decode, we only have one token per row in the batch, so grab last index - next_adapter_indices[i] = batch.adapter_meta.adapter_indices[end_index - 1] - - # Used to gather prefill logprobs - # Copy batch.all_input_ids_tensor to prefill_token_indices - if request.prefill_logprobs and request_was_prefilling: - # Indexing metadata - out_start_index = batch.prefill_cu_outlens[i] - out_end_index = batch.prefill_cu_outlens[i + 1] - - # Logprobs generated by the model are for the next token - # So we need to translate the id tensor by 1 - ids = batch.all_input_ids_tensor[i, cache_length + 1 : cache_length + input_length + 1] - if len(batch) > 1: - prefill_tokens_indices[out_start_index:out_end_index] = ids - else: - # Set prefill_tokens_indices to the correct slice - prefill_tokens_indices = ids - - if not request_is_prefilling: - # Only save tokens if we are done prefilling for this request - for j in range(n_accepted_ids): - batch.all_input_ids_tensor[i, cache_length + input_length + j] = next_input_ids[index + j] - - batch.all_input_ids_tensor[i, input_length] = next_input_ids[i] + out_start_index = batch.prefill_cu_outlens[i] + out_end_index = batch.prefill_cu_outlens[i + 1] + + # Logprobs generated by the model are for the next token + # So we need to translate the id tensor by 1 + ids = batch.all_input_ids_tensor[i, cache_length + 1 : cache_length + input_length + 1] + if len(batch) > 1: + prefill_tokens_indices[out_start_index:out_end_index] = ids + else: + # Set prefill_tokens_indices to the correct slice + prefill_tokens_indices = ids + + # If the device does not support triton, we copy one by one + if not request_is_prefilling and not has_triton(): + # Only save tokens if we are done prefilling for this request + batch.all_input_ids_tensor[ + i, + batch.cache_lengths_tensor[i] + + batch.input_lengths[i] : batch.cache_lengths_tensor[i] + + batch.input_lengths[i] + + accepted_ids[i], + ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]] - index += n_accepted_ids cumulative_length += input_length + + # If the device 
support triton, we can use a fused kernel + if has_triton(): + copy_next_input_ids_inplace( + speculate + 1, + batch.all_input_ids_tensor, + batch.cache_lengths_tensor, + batch.input_lengths_tensor, + batch.prompt_lengths_tensor, + next_input_ids, + cu_accepted_ids, + ) # Update values # These values can be updated without a GPU -> CPU sync if not prefill or (prefill and finished_prefilling): - batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] + batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1] batch.speculative_ids = speculative_ids - batch.position_ids = next_position_ids + accepted_ids - batch.cache_lengths_tensor += batch.input_lengths_tensor - batch.input_lengths_tensor = accepted_ids.to(dtype=torch.int32) + batch.position_ids += accepted_ids + batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1 + batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor) batch.slot_indices += accepted_ids batch.adapter_meta.adapter_indices = next_adapter_indices @@ -1869,8 +2008,10 @@ def generate_token( # processing stopped = False new_input_length = next_chunk_lengths[i] + new_cache_length = cache_length + input_length else: - new_input_length = n_accepted_ids + new_input_length = 1 + new_cache_length = cache_length + input_length + n_accepted_ids - 1 # Append next token to all tokens next_token_texts = [] left = 0 @@ -1975,12 +2116,10 @@ def generate_token( # Update values # logger.info(f"!!! UPDATE VALUES {i} n_accepted_ids={n_accepted_ids} new_input_length={new_input_length} input_length={input_length} cache_length={cache_length}") index += n_accepted_ids - current_cache_length = cache_length + input_length - batch.cache_lengths[i] = current_cache_length - current_input_length = new_input_length - batch.max_input_length = max(batch.max_input_length, current_input_length) - batch.input_lengths[i] = current_input_length - current_length = current_cache_length + current_input_length + batch.cache_lengths[i] = new_cache_length + batch.max_input_length = max(batch.max_input_length, new_input_length) + batch.input_lengths[i] = new_input_length + current_length = new_cache_length + new_input_length batch.max_current_length = max(batch.max_current_length, current_length) batch.prefix_offsets[i] = prefix_offset diff --git a/server/lorax_server/models/metadata_kernels.py b/server/lorax_server/models/metadata_kernels.py new file mode 100644 index 000000000..7e2c2b1ac --- /dev/null +++ b/server/lorax_server/models/metadata_kernels.py @@ -0,0 +1,347 @@ +# From: https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/models/metadata_kernels.py + +import torch +import triton + +import triton.language as tl + +from loguru import logger +from typing import List, Optional +from torch.utils._triton import has_triton as has_triton_torch + +from lorax_server.utils.import_utils import ( + SYSTEM, +) +_HAS_TRITON: Optional[bool] = None + + +def has_triton(): + global _HAS_TRITON + if _HAS_TRITON is None: + # FIXME: it seems that has_triton_torch is bugged on RocM + # For now, only accept cuda + _HAS_TRITON = has_triton_torch() if SYSTEM == "cuda" else False + if _HAS_TRITON: + logger.info("Using optimized Triton indexing kernels.") + + return _HAS_TRITON + + +def block_tables_to_padded( + max_blocks: int, + cu_seqlen: torch.Tensor, + block_tables: torch.Tensor, + block_tables_ragged: torch.Tensor, +): + def grid(meta): + return ( + triton.cdiv(max_blocks, meta["BLOCK_SIZE"]), + len(block_tables), + ) + + 
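    # Grid for the launch below: axis 0 tiles max_blocks in chunks of BLOCK_SIZE,
    # axis 1 is one program per sequence in the batch. Each program copies its
    # tile of the ragged block table into the padded [batch_size, max_blocks]
    # layout (see triton_block_tables_to_padded further down).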
triton_block_tables_to_padded[grid]( + cu_seqlen, + block_tables, + block_tables_ragged, + block_tables.shape[1], + BLOCK_SIZE=256, + ) + + +def block_tables_to_ragged( + *, + block_tables: torch.Tensor, + input_lengths: List[int], + cache_lengths: List[int], + input_lengths_tensor: torch.Tensor, + cache_lengths_tensor: torch.Tensor, + max_current_length: int +) -> torch.Tensor: + """Convert block table to ragged format compatible with FlashInfer.""" + assert len(input_lengths) == len(cache_lengths) + + total_len = sum(input_lengths) + sum(cache_lengths) + block_tables_ragged = torch.empty( + total_len, dtype=torch.int32, device=block_tables.device + ) + + if has_triton(): + cu_seqlen = torch.nn.functional.pad( + torch.cumsum(input_lengths_tensor + cache_lengths_tensor, dim=0), (1, 0) + ) + + def grid(meta): + return ( + triton.cdiv(max_current_length, meta["BLOCK_SIZE"]), + len(cache_lengths), + ) + + triton_block_tables_to_ragged[grid]( + cu_seqlen, + block_tables, + block_tables_ragged, + block_tables.shape[1], + BLOCK_SIZE=256, + ) + else: + offset = 0 + for i, (input_length, cache_length) in enumerate( + zip(input_lengths, cache_lengths) + ): + seq_len = cache_length + input_length + block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len] + offset += seq_len + + return block_tables_ragged + + +def copy_next_input_ids_inplace( + max_next_input_ids: int, + all_input_ids: torch.Tensor, + cache_lengths: torch.Tensor, + input_lengths: torch.Tensor, + prompt_lengths: torch.Tensor, + next_input_ids: torch.Tensor, + cu_accepted_ids: torch.Tensor, +): + def grid(meta): + return ( + triton.cdiv(max_next_input_ids, meta["BLOCK_SIZE"]), + len(all_input_ids), + ) + + triton_copy_next_input_ids_inplace[grid]( + all_input_ids, + cache_lengths, + input_lengths, + prompt_lengths, + next_input_ids, + cu_accepted_ids, + all_input_ids.shape[1], + BLOCK_SIZE=16, + ) + + +def prepare_position_slot_ids( + max_input_length: int, + cache_lengths: torch.Tensor, + cu_seqlen: torch.Tensor, + cu_slots: torch.Tensor, + position_ids: torch.Tensor, + slot_indices: torch.Tensor, +): + def grid(meta): + return ( + triton.cdiv(max_input_length, meta["BLOCK_SIZE"]), + len(cache_lengths), + ) + + triton_prepare_position_slot_ids[grid]( + cache_lengths, cu_seqlen, cu_slots, position_ids, slot_indices, BLOCK_SIZE=256 + ) + + +def slots_filtering( + max_slots: int, + slots: torch.Tensor, + filtered_slots: torch.Tensor, + cu_slots: torch.Tensor, + slots_start: torch.Tensor, +): + def grid(meta): + return ( + triton.cdiv(max_slots, meta["BLOCK_SIZE"]), + len(slots_start), + ) + + triton_slots_filtering[grid]( + slots, filtered_slots, slots_start, cu_slots, BLOCK_SIZE=256 + ) + + +@triton.jit +def triton_slots_filtering( + # Inputs + slots_ptr, + filtered_slots_ptr, + slots_start_ptr, + cu_slots_ptr, + # Const values + BLOCK_SIZE: "tl.constexpr", +): + # Position in block_tables_ragged.numel() / BLOCK_SIZE + pid = tl.program_id(axis=0) + # Position in batch + bid = tl.program_id(axis=1) + + block_start = pid * BLOCK_SIZE + block_arange = block_start + tl.arange(0, BLOCK_SIZE) + + filter_start = tl.load(slots_start_ptr + bid) + + slot_start = tl.load(cu_slots_ptr + bid) + slot_end = tl.load(cu_slots_ptr + bid + 1) + + mask = (slot_start + block_arange) < slot_end + + slots = tl.load(slots_ptr + filter_start + block_arange, mask=mask) + tl.store(filtered_slots_ptr + slot_start + block_arange, slots, mask=mask) + + +@triton.jit +def triton_block_tables_to_padded( + # Inputs + cu_seqlen_ptr, + # Outputs + 
block_tables_ptr, + block_tables_ragged_ptr, + # Stride + stride_block_tables, + # Const values + BLOCK_SIZE: "tl.constexpr", +): + # Position in block_tables_ragged.numel() / BLOCK_SIZE + pid = tl.program_id(axis=0) + # Position in batch + bid = tl.program_id(axis=1) + + block_start = pid * BLOCK_SIZE + block_arange = block_start + tl.arange(0, BLOCK_SIZE) + + seq_start = tl.load(cu_seqlen_ptr + bid) + seq_end = tl.load(cu_seqlen_ptr + bid + 1) + + mask = (seq_start + block_arange) < seq_end + + blocks = tl.load(block_tables_ragged_ptr + seq_start + block_arange, mask=mask) + tl.store( + block_tables_ptr + bid * stride_block_tables + block_arange, blocks, mask=mask + ) + + +@triton.jit +def triton_block_tables_to_ragged( + # Inputs + cu_seqlen_ptr, + # Outputs + block_tables_ptr, + block_tables_ragged_ptr, + # Stride + stride_block_tables, + # Const values + BLOCK_SIZE: "tl.constexpr", +): + # Position in block_tables_ragged.numel() / BLOCK_SIZE + pid = tl.program_id(axis=0) + # Position in batch + bid = tl.program_id(axis=1) + + block_start = pid * BLOCK_SIZE + block_arange = block_start + tl.arange(0, BLOCK_SIZE) + + seq_start = tl.load(cu_seqlen_ptr + bid) + seq_end = tl.load(cu_seqlen_ptr + bid + 1) + + mask = (seq_start + block_arange) < seq_end + + blocks = tl.load( + block_tables_ptr + bid * stride_block_tables + block_arange, mask=mask + ) + tl.store(block_tables_ragged_ptr + seq_start + block_arange, blocks, mask=mask) + + +@triton.jit +def triton_copy_next_input_ids_inplace( + # Inputs + all_input_ids_ptr, + cache_lengths_ptr, + input_lengths_ptr, + prompt_lengths_ptr, + next_input_ids_ptr, + cu_accepted_ids_ptr, + # Stride + stride_all_input_ids, + # Const values + BLOCK_SIZE: "tl.constexpr", +): + # Position in max_accepted_ids / BLOCK_SIZE + pid = tl.program_id(axis=0) + # Position in batch + bid = tl.program_id(axis=1) + + block_start = pid * BLOCK_SIZE + block_arange = block_start + tl.arange(0, BLOCK_SIZE) + + # Used for correctly indexing in all_input_ids + cache_length = tl.load(cache_lengths_ptr + bid) + input_length = tl.load(input_lengths_ptr + bid) + prompt_length = tl.load(prompt_lengths_ptr + bid) + + # Start/End of next_input_ids for this request + next_input_ids_start = tl.load(cu_accepted_ids_ptr + bid) + next_input_ids_end = tl.load(cu_accepted_ids_ptr + bid + 1) + + # Mask values out of range + mask = (next_input_ids_start + block_arange) < next_input_ids_end + + # Mask values for request still prefilling + decode_mask = (cache_length + input_length + block_arange) >= prompt_length + + mask = mask & decode_mask + + # Load this request next input ids + next_input_ids = tl.load( + next_input_ids_ptr + next_input_ids_start + block_arange, mask=mask + ) + + # Store in all_input_ids, since it is a 2D tensor, apply stride * bid + tl.store( + all_input_ids_ptr + + stride_all_input_ids * bid + + cache_length + + input_length + + block_arange, + next_input_ids, + mask=mask, + ) + + +@triton.jit +def triton_prepare_position_slot_ids( + # Inputs + cache_lengths_ptr, + cu_seqlen_ptr, + cu_slots_ptr, + # Outputs + position_ids_ptr, + slot_indices_ptr, + # Const values + BLOCK_SIZE: "tl.constexpr", +): + # Position in max_input_length / BLOCK_SIZE + pid = tl.program_id(axis=0) + # Position in batch + bid = tl.program_id(axis=1) + + block_start = pid * BLOCK_SIZE + block_arange = block_start + tl.arange(0, BLOCK_SIZE) + + cache_length = tl.load(cache_lengths_ptr + bid) + + seq_start = tl.load(cu_seqlen_ptr + bid) + seq_end = tl.load(cu_seqlen_ptr + bid + 1) + + slot_start = 
tl.load(cu_slots_ptr + bid) + + mask = (seq_start + block_arange) < seq_end + + tl.store( + position_ids_ptr + seq_start + block_arange, + cache_length + block_arange, + mask=mask, + ) + tl.store( + slot_indices_ptr + seq_start + block_arange, + slot_start + cache_length + block_arange, + mask=mask, + ) \ No newline at end of file From 57c33d78061f06a1e6e0136f4111035919f9af67 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 31 Oct 2024 15:17:35 -0700 Subject: [PATCH 39/76] Fix ragged --- server/lorax_server/models/flash_causal_lm.py | 4 +++- server/lorax_server/utils/attention/utils.py | 21 ------------------- server/lorax_server/utils/graph.py | 8 ++++++- 3 files changed, 10 insertions(+), 23 deletions(-) delete mode 100644 server/lorax_server/utils/attention/utils.py diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index ee8a42e79..646f586ae 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -36,7 +36,6 @@ from lorax_server.utils.state import ( BLOCK_SIZE, FLASH_INFER, - LORAX_PROFILER_DIR, get_max_prefill_tokens, get_speculative_tokens, get_supports_chunking, @@ -1556,6 +1555,9 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> block_tables=block_tables, input_lengths=batch.input_lengths, cache_lengths=batch.cache_lengths, + input_lengths_tensor=batch.input_lengths_tensor, + cache_lengths_tensor=batch.cache_lengths_tensor, + max_current_length=max_s, ) with self._forward_context( diff --git a/server/lorax_server/utils/attention/utils.py b/server/lorax_server/utils/attention/utils.py deleted file mode 100644 index 8292be916..000000000 --- a/server/lorax_server/utils/attention/utils.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import List - -import torch - - -def block_tables_to_ragged( - *, block_tables: torch.Tensor, input_lengths: List[int], cache_lengths: List[int] -) -> torch.Tensor: - """Convert block table to ragged format compatible with FlashInfer.""" - assert len(input_lengths) == len(cache_lengths) - - total_len = sum(input_lengths) + sum(cache_lengths) - block_tables_ragged = torch.empty(total_len, dtype=torch.int32, device=block_tables.device) - - offset = 0 - for i, (input_length, cache_length) in enumerate(zip(input_lengths, cache_lengths)): - seq_len = cache_length + input_length - block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len] - offset += seq_len - - return block_tables_ragged diff --git a/server/lorax_server/utils/graph.py b/server/lorax_server/utils/graph.py index 734bcda4e..0896a016e 100644 --- a/server/lorax_server/utils/graph.py +++ b/server/lorax_server/utils/graph.py @@ -7,6 +7,7 @@ from statistics import median from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple +from lorax_server.models.metadata_kernels import block_tables_to_ragged import numpy as np import torch from loguru import logger @@ -17,7 +18,6 @@ from lorax_server.adapters.lora import BatchLoraWeights, RankSegments from lorax_server.adapters.types import LORA, MEDUSA from lorax_server.utils.attention.common import Seqlen -from lorax_server.utils.attention.utils import block_tables_to_ragged from lorax_server.utils.sgmv import BGMV_MAX_RANK, PunicaWrapper from lorax_server.utils.state import BLOCK_SIZE, FLASH_INFER, get_speculative_tokens @@ -263,6 +263,9 @@ def trace( block_tables=block_tables, input_lengths=input_lengths, cache_lengths=cache_lengths, + 
input_lengths_tensor=input_lengths_tensor, + cache_lengths_tensor=cache_lengths_tensor, + max_current_length=max_total_tokens, ) block_tables_ptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) @@ -392,6 +395,9 @@ def forward( block_tables=block_tables, input_lengths=seqlen.input_lengths, cache_lengths=seqlen.cache_lengths, + input_lengths_tensor=seqlen.input_lengths, + cache_lengths_tensor=cache_lengths_tensor, + max_current_length=max_s, ) self.input_state.block_tables[: block_tables.shape[0]] = block_tables else: From ece47f7f28aef6ee89935f92b7079a47b191faff Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 31 Oct 2024 15:23:18 -0700 Subject: [PATCH 40/76] More fixes --- server/lorax_server/models/flash_causal_lm.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 646f586ae..0a95391dd 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -26,7 +26,6 @@ from lorax_server.utils import HeterogeneousNextTokenChooser, StoppingCriteria from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID, create_merged_weight_files from lorax_server.utils.attention.common import Seqlen -from lorax_server.utils.attention.utils import block_tables_to_ragged from lorax_server.utils.dist import MEMORY_FRACTION, MEMORY_WIGGLE_ROOM, initialize_torch_distributed from lorax_server.utils.graph import GraphCache from lorax_server.utils.import_utils import get_cuda_free_memory @@ -1662,7 +1661,6 @@ def generate_token( else: prefill_logprobs = None next_token_logits = out - next_adapter_indices = batch.adapter_meta.adapter_indices finished_prefilling = True next_chunk_lengths = [] @@ -1800,7 +1798,7 @@ def generate_token( # If the device support triton, we can use a fused kernel if has_triton(): copy_next_input_ids_inplace( - speculate + 1, + speculative_tokens + 1, batch.all_input_ids_tensor, batch.cache_lengths_tensor, batch.input_lengths_tensor, @@ -1818,7 +1816,6 @@ def generate_token( batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1 batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor) batch.slot_indices += accepted_ids - batch.adapter_meta.adapter_indices = next_adapter_indices if prefill and prefill_logprobs: # Get prefill logprobs with inplace softmax (avoid copying the `out` tensor (max_batch_prefill_tokens * vocab_size)) From cb99320f29a7572ab14118b8c2367171bb64aad7 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 31 Oct 2024 16:03:49 -0700 Subject: [PATCH 41/76] Revert docker --- .github/workflows/build.yaml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index bf052f7d1..13b9e96ca 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -5,7 +5,6 @@ on: push: branches: - 'main' - - 'optimizations' tags: - 'v*' @@ -70,7 +69,10 @@ jobs: images: | ghcr.io/predibase/lorax tags: | - type=raw,value=optimizations-4,enable=${{ github.ref == 'refs/heads/optimizations' }} + type=semver,pattern={{version}} + type=semver,pattern={{major}}.{{minor}} + type=sha,prefix=,suffix=,format=short + type=raw,value=main,enable=${{ github.ref == 'refs/heads/main' }} - name: Create a hash from tags env: From 466ea37a6801cf8dbc3d5e5d94855b78ab3a5601 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 31 Oct 2024 16:06:06 -0700 Subject: [PATCH 42/76] Renamed sgmv -> 
punica --- server/lorax_server/adapters/lora.py | 2 +- server/lorax_server/adapters/medusa.py | 2 +- server/lorax_server/adapters/weights.py | 2 +- server/lorax_server/models/flash_causal_lm.py | 2 +- server/lorax_server/models/model.py | 2 +- server/lorax_server/server.py | 2 +- server/lorax_server/utils/graph.py | 2 +- server/lorax_server/utils/layers.py | 2 +- server/lorax_server/utils/{sgmv.py => punica.py} | 0 server/tests/utils/test_lora.py | 2 +- server/tests/utils/test_sgmv.py | 2 +- 11 files changed, 10 insertions(+), 10 deletions(-) rename server/lorax_server/utils/{sgmv.py => punica.py} (100%) diff --git a/server/lorax_server/adapters/lora.py b/server/lorax_server/adapters/lora.py index bab1119d8..6ef1c85f9 100644 --- a/server/lorax_server/adapters/lora.py +++ b/server/lorax_server/adapters/lora.py @@ -10,7 +10,7 @@ from lorax_server.adapters.types import LORA from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights from lorax_server.utils.lora import LM_HEAD -from lorax_server.utils.sgmv import ( +from lorax_server.utils.punica import ( BGMV_MAX_RANK, MAX_RANK_CUSTOM, get_tmp_tensors, diff --git a/server/lorax_server/adapters/medusa.py b/server/lorax_server/adapters/medusa.py index ec55ca608..5ca8b5315 100644 --- a/server/lorax_server/adapters/medusa.py +++ b/server/lorax_server/adapters/medusa.py @@ -10,7 +10,7 @@ from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights from lorax_server.layers import FastLinear, TensorParallelColumnLinear from lorax_server.utils.segments import find_segments -from lorax_server.utils.sgmv import segmented_matmul +from lorax_server.utils.punica import segmented_matmul from lorax_server.utils.state import LORAX_SPECULATION_MAX_BATCH_SIZE, get_speculative_tokens from lorax_server.utils.weights import AbstractWeights, InMemoryWeights diff --git a/server/lorax_server/adapters/weights.py b/server/lorax_server/adapters/weights.py index d655ed865..a0122e4b4 100644 --- a/server/lorax_server/adapters/weights.py +++ b/server/lorax_server/adapters/weights.py @@ -9,7 +9,7 @@ from lorax_server.utils.lora import LM_HEAD if TYPE_CHECKING: - from lorax_server.utils.sgmv import PunicaWrapper + from lorax_server.utils.punica import PunicaWrapper @dataclass diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 02f72ef0c..b4d01ed75 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from typing import Any, ContextManager, Dict, List, Optional, Tuple, Type, Union -from lorax_server.utils.sgmv import PunicaWrapper +from lorax_server.utils.punica import PunicaWrapper import numpy as np import torch import torch.distributed diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py index fcf80f906..132cd5ab5 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ -5,7 +5,7 @@ from lorax_server.adapters.lora import LoraWeights from lorax_server.adapters.medusa_lora import MedusaLoraWeights -from lorax_server.utils.sgmv import pad_to_min_rank +from lorax_server.utils.punica import pad_to_min_rank import torch from loguru import logger from transformers import PreTrainedTokenizerBase diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index 796c377df..39c96552c 100644 --- a/server/lorax_server/server.py +++ 
b/server/lorax_server/server.py @@ -23,7 +23,7 @@ enum_string_to_adapter_source, is_base_model, ) -from lorax_server.utils.sgmv import has_sgmv +from lorax_server.utils.punica import has_sgmv from lorax_server.utils.state import set_max_prefill_tokens, set_speculative_tokens diff --git a/server/lorax_server/utils/graph.py b/server/lorax_server/utils/graph.py index 0896a016e..2c159d87e 100644 --- a/server/lorax_server/utils/graph.py +++ b/server/lorax_server/utils/graph.py @@ -18,7 +18,7 @@ from lorax_server.adapters.lora import BatchLoraWeights, RankSegments from lorax_server.adapters.types import LORA, MEDUSA from lorax_server.utils.attention.common import Seqlen -from lorax_server.utils.sgmv import BGMV_MAX_RANK, PunicaWrapper +from lorax_server.utils.punica import BGMV_MAX_RANK, PunicaWrapper from lorax_server.utils.state import BLOCK_SIZE, FLASH_INFER, get_speculative_tokens if TYPE_CHECKING: diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index 4ad18298e..b79252248 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -12,7 +12,7 @@ from lorax_server.layers.linear import FastLinear, get_linear # noqa: F401 from lorax_server.layers.tensor_parallel import SuperLayer, TensorParallelColumnLinear, TensorParallelHead # noqa: F401 from lorax_server.utils.lora import LM_HEAD -from lorax_server.utils.sgmv import ( +from lorax_server.utils.punica import ( add_lora_a_bgmv, add_lora_b_bgmv, has_sgmv, diff --git a/server/lorax_server/utils/sgmv.py b/server/lorax_server/utils/punica.py similarity index 100% rename from server/lorax_server/utils/sgmv.py rename to server/lorax_server/utils/punica.py diff --git a/server/tests/utils/test_lora.py b/server/tests/utils/test_lora.py index 05f810365..24ecd81ea 100644 --- a/server/tests/utils/test_lora.py +++ b/server/tests/utils/test_lora.py @@ -9,7 +9,7 @@ from lorax_server.adapters.types import LORA from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights, LayerAdapterWeights from lorax_server.utils.lora import LM_HEAD -from lorax_server.utils.sgmv import MIN_RANK_CUSTOM +from lorax_server.utils.punica import MIN_RANK_CUSTOM class FakeAdapterWeights(AdapterWeights): diff --git a/server/tests/utils/test_sgmv.py b/server/tests/utils/test_sgmv.py index 0c535f1b1..5b94270a0 100644 --- a/server/tests/utils/test_sgmv.py +++ b/server/tests/utils/test_sgmv.py @@ -3,7 +3,7 @@ import pytest import torch -from lorax_server.utils.sgmv import ( +from lorax_server.utils.punica import ( get_tmp_tensors, has_sgmv, lora_a_sgmv_cutlass, From 2f80c6a371795c67a47bef686f4843447e0a6f3d Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 31 Oct 2024 16:28:11 -0700 Subject: [PATCH 43/76] Refactor PunicaWrapper --- server/lorax_server/models/flash_causal_lm.py | 1 + server/lorax_server/utils/layers.py | 142 +++++++++--------- server/lorax_server/utils/punica.py | 24 +-- 3 files changed, 79 insertions(+), 88 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index b4d01ed75..0bf6289ed 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -1336,6 +1336,7 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model max_num_batched_tokens=get_max_prefill_tokens(), max_batches=256, # TODO(travis): consider how to handle this if we exceed this limit device=self.device, + enabled=not 
self.dynamic_adapter_loading_enabled # only supported for now with statically loaded adapters ) torch.cuda.empty_cache() diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index b79252248..b64694a17 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -75,82 +75,90 @@ def forward_layer_type( data = adapter_data.data.get(layer_type) data: Optional["BatchLoraWeights"] = data.get(LORA) if data is not None else None - if has_sgmv() and data is not None and data.can_vectorize(self.process_group): + # Triton Punica kernels + if adapter_data.punica_wrapper.enabled: if end_idx - start_idx != result.shape[1]: - # proj = torch.zeros_like(result[:, start_idx:end_idx]) y_offset = start_idx y_slice_size = end_idx - start_idx else: - # proj = result y_offset = None y_slice_size = None + + lora_a_weights, lora_b_weights = adapter_data.layer_to_lora_weights[(layer_type, self.layer_id)] + adapter_data.punica_wrapper.add_lora( + result, + input, + lora_a_weights, + lora_b_weights, + 1.0, + y_offset, + y_slice_size, + callback=self.collect_lora_a if self.process_group.size() > 1 else None, + ) + + # Legacy Punica kernels + elif has_sgmv() and data is not None and data.can_vectorize(self.process_group): + if end_idx - start_idx != result.shape[1]: + proj = torch.zeros_like(result[:, start_idx:end_idx]) + else: + proj = result for r, rank_segments in data.rank_data.items(): - # lora_a_ptr = rank_segments.lora_a_ptr - # lora_b_ptr = rank_segments.lora_b_ptr - - lora_a_weights, lora_b_weights = adapter_data.layer_to_lora_weights[(layer_type, self.layer_id)] - adapter_data.punica_wrapper.add_lora( - result, - input, - lora_a_weights, - lora_b_weights, - 1.0, - y_offset, - y_slice_size, - callback=self.collect_lora_a if self.process_group.size() > 1 else None, - ) + lora_a_ptr = rank_segments.lora_a_ptr + lora_b_ptr = rank_segments.lora_b_ptr + + if data.use_sgmv: + # Use SGMV for prefill + if lora_a_ptr is not None and lora_b_ptr is not None: + v = lora_a_sgmv_cutlass( + input, + rank_segments.tmp_shrink, + lora_a_ptr, + rank_segments.segment_starts, + rank_segments.segment_ends, + self.layer_id, + r, + ) + + if self.process_group.size() > 1: + v = self.collect_lora_a(v) + + lora_b_sgmv_cutlass( + proj, + v, + rank_segments.tmp_expand, + lora_b_ptr, + rank_segments.segment_starts, + rank_segments.segment_ends, + self.layer_id, + ) + else: + # Use BGMV for decode + if lora_a_ptr is not None and lora_b_ptr is not None: + v = torch.zeros((input.size(0), r), dtype=input.dtype, device=input.device) + add_lora_a_bgmv( + v, + input, + lora_a_ptr, + rank_segments.indices, + self.layer_id, + ) + + if self.process_group.size() > 1: + v = self.collect_lora_a(v) + + add_lora_b_bgmv( + proj, + v, + lora_b_ptr, + rank_segments.indices, + self.layer_id, + ) - # if data.use_sgmv: - # # Use SGMV for prefill - # if lora_a_ptr is not None and lora_b_ptr is not None: - # v = lora_a_sgmv_cutlass( - # input, - # rank_segments.tmp_shrink, - # lora_a_ptr, - # rank_segments.segment_starts, - # rank_segments.segment_ends, - # self.layer_id, - # r, - # ) - - # if self.process_group.size() > 1: - # v = self.collect_lora_a(v) - - # lora_b_sgmv_cutlass( - # proj, - # v, - # rank_segments.tmp_expand, - # lora_b_ptr, - # rank_segments.segment_starts, - # rank_segments.segment_ends, - # self.layer_id, - # ) - # else: - # # Use BGMV for decode - # if lora_a_ptr is not None and lora_b_ptr is not None: - # v = torch.zeros((input.size(0), r), dtype=input.dtype, 
device=input.device) - # add_lora_a_bgmv( - # v, - # input, - # lora_a_ptr, - # rank_segments.indices, - # self.layer_id, - # ) - - # if self.process_group.size() > 1: - # v = self.collect_lora_a(v) - - # add_lora_b_bgmv( - # proj, - # v, - # lora_b_ptr, - # rank_segments.indices, - # self.layer_id, - # ) - - # if end_idx - start_idx != result.shape[1]: - # result[:, start_idx:end_idx] += proj + if end_idx - start_idx != result.shape[1]: + result[:, start_idx:end_idx] += proj + + # Vanilla PyTorch else: adapter_indices = adapter_data.meta.adapter_indices if data is not None and data.prefill_head_indices is not None and data.layer_name == LM_HEAD: diff --git a/server/lorax_server/utils/punica.py b/server/lorax_server/utils/punica.py index 93f8b20a5..f924a6e59 100644 --- a/server/lorax_server/utils/punica.py +++ b/server/lorax_server/utils/punica.py @@ -390,6 +390,7 @@ def convert_mapping( ) +# Source: https://github.com/vllm-project/vllm/blob/main/vllm/lora/punica.py class PunicaWrapper: """ PunicaWrapper is designed to manage and provide metadata for the punica @@ -398,7 +399,7 @@ class PunicaWrapper: """ def __init__(self, max_num_batched_tokens: int, max_batches: int, - device: str): + device: str, enabled: bool): self._token_lora_indices = torch.empty(max_num_batched_tokens, dtype=torch.long, device=device) @@ -434,26 +435,7 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, self.batch_size: int = -1 self.is_prefill = False self.no_lora = False - - # def update_metadata( - # self, - # meta: "AdapterBatchMetadata", - # prefill: bool, - # max_loras: int, - # vocab_size: int, - # extra_vocab_size: int, - # long_lora_context = None, - # ): - - # self._update_base_metadata(meta, max_loras, - # vocab_size, extra_vocab_size, - # long_lora_context) - # if prefill: - # # Update metadata required for prefill-related operators. 
- # self._update_prefill_metada(self.token_lora_indices) - # self.is_prefill = True - # else: - # self.is_prefill = False + self.enabled = enabled def update_metadata( self, From 47bfd0c885ff9a7eeea5534c5dcbd5cd541fd3af Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 31 Oct 2024 16:31:43 -0700 Subject: [PATCH 44/76] More configuration --- server/lorax_server/models/flash_causal_lm.py | 9 ++++++--- server/lorax_server/utils/layers.py | 2 +- server/lorax_server/utils/punica.py | 7 +++++++ 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 0bf6289ed..3bc8874b6 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from typing import Any, ContextManager, Dict, List, Optional, Tuple, Type, Union -from lorax_server.utils.punica import PunicaWrapper +from lorax_server.utils.punica import LORAX_PUNICA_TRION_DISABLED, PunicaWrapper import numpy as np import torch import torch.distributed @@ -1334,9 +1334,12 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model self.punica_wrapper = PunicaWrapper( max_num_batched_tokens=get_max_prefill_tokens(), - max_batches=256, # TODO(travis): consider how to handle this if we exceed this limit + max_batches=256, # TODO(travis): find a better way to set this programmatically device=self.device, - enabled=not self.dynamic_adapter_loading_enabled # only supported for now with statically loaded adapters + enabled=( + not self.dynamic_adapter_loading_enabled and # only supported for now with statically loaded adapters + not LORAX_PUNICA_TRION_DISABLED + ) ) torch.cuda.empty_cache() diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index b64694a17..5d3ca5fe2 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -76,7 +76,7 @@ def forward_layer_type( data: Optional["BatchLoraWeights"] = data.get(LORA) if data is not None else None # Triton Punica kernels - if adapter_data.punica_wrapper.enabled: + if adapter_data.punica_wrapper.enabled and input.shape[0] <= adapter_data.punica_wrapper.max_batch_size: if end_idx - start_idx != result.shape[1]: y_offset = start_idx y_slice_size = end_idx - start_idx diff --git a/server/lorax_server/utils/punica.py b/server/lorax_server/utils/punica.py index f924a6e59..a9c276dad 100644 --- a/server/lorax_server/utils/punica.py +++ b/server/lorax_server/utils/punica.py @@ -3,6 +3,7 @@ from functools import lru_cache from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union +from loguru import logger import torch import torch.nn.functional as F @@ -26,6 +27,11 @@ HAS_SGMV = False +LORAX_PUNICA_TRION_DISABLED = bool(os.environ.get("LORAX_PUNICA_TRION_DISABLED", "")) +if LORAX_PUNICA_TRION_DISABLED: + logger.info("LORAX_PUNICA_TRION_DISABLED is set, disabling Punica Trion kernels.") + + MIN_SGMV_RANK = 8 MIN_RANK_CUSTOM = 16 MAX_RANK_CUSTOM = 128 @@ -431,6 +437,7 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, self._lora_indices_per_batch = torch.empty(max_batches, dtype=torch.long, device=device) + self.max_batch_size = max_batches self.max_length: int = 0 self.batch_size: int = -1 self.is_prefill = False From 2343d7851678f11160241870b57ee96ae0ae7736 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 31 Oct 2024 16:32:57 -0700 Subject: [PATCH 45/76] More logs --- 
server/lorax_server/server.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index 39c96552c..a15a5fde9 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -23,7 +23,7 @@ enum_string_to_adapter_source, is_base_model, ) -from lorax_server.utils.punica import has_sgmv +from lorax_server.utils.punica import LORAX_PUNICA_TRION_DISABLED, has_sgmv from lorax_server.utils.state import set_max_prefill_tokens, set_speculative_tokens @@ -424,10 +424,12 @@ def load_adapter(adapter_info: generate_pb2.PreloadedAdapter) -> bool: await server.start() # Log SGMV kernel status + if not LORAX_PUNICA_TRION_DISABLED: + logger.info("Trion kernel is enabled, multi-LoRA inference will be fast!") if has_sgmv(): logger.info("SGMV kernel is enabled, multi-LoRA inference will be fast!") else: - logger.info("SGMV kernel is disabled, multi-LoRA inference may be slow") + logger.info("Punica kernels are disabled, multi-LoRA inference may be slow") logger.info("Server started at {}".format(local_url)) From f915abe780fe5e4b9434e5ce8c4e36cc720a609c Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 31 Oct 2024 16:48:46 -0700 Subject: [PATCH 46/76] Fixes --- server/lorax_server/server.py | 2 +- server/lorax_server/utils/layers.py | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index a15a5fde9..3d7a63f22 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -424,7 +424,7 @@ def load_adapter(adapter_info: generate_pb2.PreloadedAdapter) -> bool: await server.start() # Log SGMV kernel status - if not LORAX_PUNICA_TRION_DISABLED: + if not LORAX_PUNICA_TRION_DISABLED and not model.dynamic_adapter_loading_enabled: logger.info("Trion kernel is enabled, multi-LoRA inference will be fast!") if has_sgmv(): logger.info("SGMV kernel is enabled, multi-LoRA inference will be fast!") diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index 5d3ca5fe2..61be1a82b 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -74,9 +74,14 @@ def forward_layer_type( ) -> torch.Tensor: data = adapter_data.data.get(layer_type) data: Optional["BatchLoraWeights"] = data.get(LORA) if data is not None else None + can_vectorize = data is not None and data.can_vectorize(self.process_group) # Triton Punica kernels - if adapter_data.punica_wrapper.enabled and input.shape[0] <= adapter_data.punica_wrapper.max_batch_size: + if ( + adapter_data.punica_wrapper.enabled and + input.shape[0] <= adapter_data.punica_wrapper.max_batch_size and + can_vectorize + ): if end_idx - start_idx != result.shape[1]: y_offset = start_idx y_slice_size = end_idx - start_idx @@ -97,7 +102,7 @@ def forward_layer_type( ) # Legacy Punica kernels - elif has_sgmv() and data is not None and data.can_vectorize(self.process_group): + elif has_sgmv() and can_vectorize: if end_idx - start_idx != result.shape[1]: proj = torch.zeros_like(result[:, start_idx:end_idx]) else: From ad460c0f75d25a165313c7205c2663fb077bd295 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 31 Oct 2024 22:07:01 -0700 Subject: [PATCH 47/76] Guard init --- server/lorax_server/models/flash_causal_lm.py | 4 ++-- server/lorax_server/models/model.py | 6 +++++- server/lorax_server/server.py | 4 ++-- server/lorax_server/utils/ops/__init__.py | 1 + server/lorax_server/utils/punica.py | 6 +++--- 5 files changed, 13 
insertions(+), 8 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 3bc8874b6..ef33f2f32 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from typing import Any, ContextManager, Dict, List, Optional, Tuple, Type, Union -from lorax_server.utils.punica import LORAX_PUNICA_TRION_DISABLED, PunicaWrapper +from lorax_server.utils.punica import LORAX_PUNICA_TRITON_DISABLED, PunicaWrapper import numpy as np import torch import torch.distributed @@ -1338,7 +1338,7 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model device=self.device, enabled=( not self.dynamic_adapter_loading_enabled and # only supported for now with statically loaded adapters - not LORAX_PUNICA_TRION_DISABLED + not LORAX_PUNICA_TRITON_DISABLED ) ) diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py index 132cd5ab5..c4d350e46 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ -5,7 +5,7 @@ from lorax_server.adapters.lora import LoraWeights from lorax_server.adapters.medusa_lora import MedusaLoraWeights -from lorax_server.utils.punica import pad_to_min_rank +from lorax_server.utils.punica import LORAX_PUNICA_TRITON_DISABLED, pad_to_min_rank import torch from loguru import logger from transformers import PreTrainedTokenizerBase @@ -267,6 +267,10 @@ def register_preloaded_adapters( ) self.preloaded_adapters.extend(preloaded_adapters) + if LORAX_PUNICA_TRITON_DISABLED: + # Following code is only applicable to Triton kernels + return + # For Triton kernels: need weights into contiguous tensor # dict of (layer_name, layer_id) -> (lora_a_weights, lora_b_weights) # where: diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index 3d7a63f22..9cf7340bb 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -23,7 +23,7 @@ enum_string_to_adapter_source, is_base_model, ) -from lorax_server.utils.punica import LORAX_PUNICA_TRION_DISABLED, has_sgmv +from lorax_server.utils.punica import LORAX_PUNICA_TRITON_DISABLED, has_sgmv from lorax_server.utils.state import set_max_prefill_tokens, set_speculative_tokens @@ -424,7 +424,7 @@ def load_adapter(adapter_info: generate_pb2.PreloadedAdapter) -> bool: await server.start() # Log SGMV kernel status - if not LORAX_PUNICA_TRION_DISABLED and not model.dynamic_adapter_loading_enabled: + if not LORAX_PUNICA_TRITON_DISABLED and not model.dynamic_adapter_loading_enabled: logger.info("Trion kernel is enabled, multi-LoRA inference will be fast!") if has_sgmv(): logger.info("SGMV kernel is enabled, multi-LoRA inference will be fast!") diff --git a/server/lorax_server/utils/ops/__init__.py b/server/lorax_server/utils/ops/__init__.py index e69de29bb..373f4d940 100644 --- a/server/lorax_server/utils/ops/__init__.py +++ b/server/lorax_server/utils/ops/__init__.py @@ -0,0 +1 @@ +# Source: https://github.com/vllm-project/vllm/tree/main/vllm/lora/ops \ No newline at end of file diff --git a/server/lorax_server/utils/punica.py b/server/lorax_server/utils/punica.py index a9c276dad..f41e65b6d 100644 --- a/server/lorax_server/utils/punica.py +++ b/server/lorax_server/utils/punica.py @@ -27,9 +27,9 @@ HAS_SGMV = False -LORAX_PUNICA_TRION_DISABLED = bool(os.environ.get("LORAX_PUNICA_TRION_DISABLED", "")) -if LORAX_PUNICA_TRION_DISABLED: - logger.info("LORAX_PUNICA_TRION_DISABLED is set, 
disabling Punica Trion kernels.") +LORAX_PUNICA_TRITON_DISABLED = bool(os.environ.get("LORAX_PUNICA_TRITON_DISABLED", "")) +if LORAX_PUNICA_TRITON_DISABLED: + logger.info("LORAX_PUNICA_TRITON_DISABLED is set, disabling Punica Trion kernels.") MIN_SGMV_RANK = 8 From 43c129bacb09c8cb9aaf7518c3645dcda360a3b3 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 31 Oct 2024 22:08:26 -0700 Subject: [PATCH 48/76] Guard model has lm_head --- server/lorax_server/models/flash_causal_lm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index ef33f2f32..4bacfd9cb 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -1627,7 +1627,7 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> lm_head_indices=batch.prefill_head_indices, ) - if skip_lm_head: + if skip_lm_head and hasattr(self.model, "lm_head"): # re-run through the LM head as the graph did not capture it out = self.model.lm_head(out[0], adapter_data) From 1c70ec6d6e2d83c783f7ca3bbb17d2151df4d13d Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 31 Oct 2024 22:21:50 -0700 Subject: [PATCH 49/76] Determine trace set from preloaded adapter set --- server/lorax_server/models/flash_causal_lm.py | 2 +- server/lorax_server/models/model.py | 8 +++++++- server/lorax_server/utils/graph.py | 1 + 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 4bacfd9cb..99518f2dd 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -1389,7 +1389,7 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model self.device, self.kv_cache, self.adapter_layers, - self.default_traced_adapter_layers, + self.traced_adapter_layers, self._forward_context, max_total_tokens, self.num_heads, diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py index c4d350e46..5088ee8bb 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ -233,6 +233,12 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: @property def adapter_layers(self) -> List[str]: return [] + + @property + def traced_adapter_layers(self) -> List[str]: + if self.layer_to_adapter_weights: + return list(self.layer_to_adapter_weights.keys()) + return self.default_traced_adapter_layers @property def default_traced_adapter_layers(self) -> List[str]: @@ -279,7 +285,7 @@ def register_preloaded_adapters( for layer_name, layer_adapter_weights in self.layer_to_adapter_weights.items(): layer_id_to_lora_a_weights = defaultdict(list) layer_id_to_lora_b_weights = defaultdict(list) - for i, adapter in enumerate(preloaded_adapters): + for adapter in preloaded_adapters: adapter_index = adapter.adapter_index adapter_weights = layer_adapter_weights.adapter_weights.get(adapter_index) if not isinstance(adapter_weights, LoraWeights): diff --git a/server/lorax_server/utils/graph.py b/server/lorax_server/utils/graph.py index 2c159d87e..bdb7e2882 100644 --- a/server/lorax_server/utils/graph.py +++ b/server/lorax_server/utils/graph.py @@ -561,6 +561,7 @@ def get_estimated_cache_memory(self) -> int: def warmup(self): ngraphs = len(CACHED_BATCH_SIZES) * len(CACHED_MAX_RANKS) pool = None + logger.info("Tracing CUDA graphs with initial adapter layers: {}", 
self.default_traced_adapter_layers) with tqdm(total=ngraphs, desc="Trace CUDA graphs") as pbar: for batch_size in reversed(CACHED_BATCH_SIZES): pbar.set_postfix({"batch_size": batch_size}) From 3ebcbea790dd26c7b2104ca9b5895aa0d0bc8539 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 31 Oct 2024 22:34:35 -0700 Subject: [PATCH 50/76] Plumb skip_lm_head --- .../models/custom_modeling/flash_cohere_modeling.py | 6 ++++++ .../models/custom_modeling/flash_dbrx_modeling.py | 5 +++++ .../models/custom_modeling/flash_gemma2_modeling.py | 5 +++++ .../models/custom_modeling/flash_gemma_modeling.py | 1 + .../models/custom_modeling/flash_gpt2_modeling.py | 1 + .../models/custom_modeling/flash_llama_modeling.py | 5 +++++ .../models/custom_modeling/flash_mistral_modeling.py | 5 +++++ .../models/custom_modeling/flash_mixtral_modeling.py | 5 +++++ .../models/custom_modeling/flash_neox_modeling.py | 5 +++++ .../models/custom_modeling/flash_phi3_modeling.py | 5 +++++ .../models/custom_modeling/flash_phi_modeling.py | 5 +++++ .../models/custom_modeling/flash_qwen_modeling.py | 5 +++++ .../models/custom_modeling/flash_rw_modeling.py | 5 +++++ .../models/custom_modeling/flash_santacoder_modeling.py | 5 +++++ server/lorax_server/models/custom_modeling/llava_next.py | 5 +++++ server/lorax_server/models/custom_modeling/mllama.py | 2 ++ 16 files changed, 70 insertions(+) diff --git a/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py b/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py index 64dddfc36..0e8771a11 100644 --- a/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py @@ -524,6 +524,7 @@ def forward( adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + skip_lm_head: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.model( input_ids, @@ -538,6 +539,11 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] + + if skip_lm_head: + # FIXME: simply running the LM head is not sufficient since we also need to scale the logits + return hidden_states, None + logits, speculative_logits = self.lm_head(hidden_states, adapter_data) logits *= self.logit_scale if speculative_logits is not None: diff --git a/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py b/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py index 5c16c81df..634043379 100644 --- a/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py @@ -1009,6 +1009,7 @@ def forward( adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + skip_lm_head: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.model( input_ids, @@ -1023,5 +1024,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] + + if skip_lm_head: + return hidden_states, None + logits, speculative_logits = self.lm_head(hidden_states, adapter_data) return logits, speculative_logits diff --git a/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py b/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py index eec7fddf8..7b9464110 100644 --- a/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py +++ 
b/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py @@ -539,6 +539,7 @@ def forward( adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + skip_lm_head: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: input_embeds = self.embed_tokens(input_ids) hidden_states = self.model( @@ -554,5 +555,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] + + if skip_lm_head: + return hidden_states, None + logits, speculative_logits = self.lm_head(hidden_states, adapter_data) return logits, speculative_logits diff --git a/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py b/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py index 2e8b6cba9..09be5d766 100644 --- a/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py @@ -538,6 +538,7 @@ def forward( adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + skip_lm_head: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.model( input_ids, diff --git a/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py b/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py index eeb6e8d38..3b65e969d 100644 --- a/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py @@ -367,6 +367,7 @@ def forward( adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + skip_lm_head: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.transformer( input_ids, diff --git a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py index b308edc5d..961c87f16 100644 --- a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py @@ -598,6 +598,7 @@ def forward( prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, cross_attention_states: Optional[torch.Tensor] = None, + skip_lm_head: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( @@ -615,5 +616,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] + + if skip_lm_head: + return hidden_states, None + logits, speculative_logits = self.lm_head(hidden_states, adapter_data) return logits, speculative_logits diff --git a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py index b74ad7e73..c2eb566ed 100644 --- a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py @@ -610,6 +610,7 @@ def forward( adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + skip_lm_head: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: if prefill_cache_indices is not None: # Slots also need to be sliced as it has the same size as the whole kv tensor @@ -635,5 +636,9 @@ 
def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] + + if skip_lm_head: + return hidden_states, None + logits, speculative_logits = self.lm_head(hidden_states, adapter_data) return logits, speculative_logits diff --git a/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py index 58e939932..3e17ccdb9 100644 --- a/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py @@ -963,6 +963,7 @@ def forward( adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + skip_lm_head: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: if prefill_cache_indices is not None: # Slots also need to be sliced as it has the same size as the whole kv tensor @@ -987,5 +988,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] + + if skip_lm_head: + return hidden_states, None + logits, speculative_logits = self.lm_head(hidden_states, adapter_data) return logits, speculative_logits diff --git a/server/lorax_server/models/custom_modeling/flash_neox_modeling.py b/server/lorax_server/models/custom_modeling/flash_neox_modeling.py index cc6df3382..70405a874 100644 --- a/server/lorax_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_neox_modeling.py @@ -357,6 +357,7 @@ def forward( max_s: int, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + skip_lm_head: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.gpt_neox( input_ids, @@ -370,5 +371,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] + + if skip_lm_head: + return hidden_states, None + logits = self.embed_out(hidden_states) return logits, None diff --git a/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py b/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py index b0b48688d..4edddea78 100644 --- a/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py @@ -506,6 +506,7 @@ def forward( adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + skip_lm_head: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.model( input_ids, @@ -520,5 +521,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] + + if skip_lm_head: + return hidden_states, None + logits, speculative_logits = self.lm_head(hidden_states, adapter_data) return logits, speculative_logits diff --git a/server/lorax_server/models/custom_modeling/flash_phi_modeling.py b/server/lorax_server/models/custom_modeling/flash_phi_modeling.py index d600928b9..e1b005052 100644 --- a/server/lorax_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_phi_modeling.py @@ -388,6 +388,7 @@ def forward( adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + skip_lm_head: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.model( input_ids, @@ -402,5 +403,9 @@ def 
forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] + + if skip_lm_head: + return hidden_states, None + logits, speculative_logits = self.lm_head(hidden_states, adapter_data) return logits, speculative_logits diff --git a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py index 606248af1..7efbe22f6 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py @@ -507,6 +507,7 @@ def forward( adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + skip_lm_head: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.transformer( input_ids, @@ -521,5 +522,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] + + if skip_lm_head: + return hidden_states, None + logits, speculative_logits = self.lm_head(hidden_states, adapter_data) return logits, speculative_logits diff --git a/server/lorax_server/models/custom_modeling/flash_rw_modeling.py b/server/lorax_server/models/custom_modeling/flash_rw_modeling.py index 4f3b36765..c43ba1d46 100644 --- a/server/lorax_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_rw_modeling.py @@ -592,6 +592,7 @@ def forward( max_s: int, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + skip_lm_head: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.transformer( input_ids, @@ -605,5 +606,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] + + if skip_lm_head: + return hidden_states, None + logits = self.lm_head(hidden_states) return logits, None diff --git a/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py b/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py index 4e98b97a2..f8325a864 100644 --- a/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py @@ -423,6 +423,7 @@ def forward( max_s: int, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + skip_lm_head: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.transformer( input_ids, @@ -436,5 +437,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] + + if skip_lm_head: + return hidden_states, None + logits = self.lm_head(hidden_states) return logits, None diff --git a/server/lorax_server/models/custom_modeling/llava_next.py b/server/lorax_server/models/custom_modeling/llava_next.py index bede2691e..81c170cd5 100644 --- a/server/lorax_server/models/custom_modeling/llava_next.py +++ b/server/lorax_server/models/custom_modeling/llava_next.py @@ -178,6 +178,7 @@ def forward( pixel_attention_mask=None, image_sizes: Optional[torch.LongTensor] = None, adapter_data: Optional["AdapterBatchData"] = None, + skip_lm_head: bool = False, ): inputs_embeds = self.text_model.embed_tokens(input_ids) if pixel_values is not None and len(pixel_values) > 0: @@ -264,5 +265,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] + + if skip_lm_head: + return hidden_states, None + 
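# --- Editor's sketch (not part of the patch): the skip_lm_head plumbing shown in this patch ---
# Each model forward() gains an optional flag that returns hidden states before the LM head,
# so the caller (e.g. the CUDA-graph path guarded by `if skip_lm_head and hasattr(self.model,
# "lm_head")` in patch 48) can run the head outside the captured graph. The toy module below
# is hypothetical; only the control flow mirrors the diffs.
import torch
import torch.nn as nn

class ToyCausalLM(nn.Module):
    def __init__(self, hidden_size: int = 8, vocab_size: int = 16):
        super().__init__()
        self.backbone = nn.Linear(hidden_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size)

    def forward(self, hidden_in, lm_head_indices=None, skip_lm_head: bool = False):
        hidden_states = self.backbone(hidden_in)
        if lm_head_indices is not None:
            # keep only the positions whose logits are actually needed
            hidden_states = hidden_states[lm_head_indices]
        if skip_lm_head:
            # caller applies the LM head itself after the (graph-captured) forward
            return hidden_states, None
        return self.lm_head(hidden_states), None

_m = ToyCausalLM()
_hidden, _ = _m(torch.randn(4, 8), skip_lm_head=True)  # hidden states, no logits
_logits, _ = _m(torch.randn(4, 8))                     # full logits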
logits, speculative_logits = self.text_model.lm_head(hidden_states, adapter_data) return logits, speculative_logits diff --git a/server/lorax_server/models/custom_modeling/mllama.py b/server/lorax_server/models/custom_modeling/mllama.py index c48448b36..7aa2e01e0 100644 --- a/server/lorax_server/models/custom_modeling/mllama.py +++ b/server/lorax_server/models/custom_modeling/mllama.py @@ -884,6 +884,7 @@ def forward( # XXX: Putting these as optional so that the cuda warmup calls can go through. cross_attention_states: Optional[torch.Tensor] = None, image_indices=None, + skip_lm_head: bool = False, ): if cross_attention_states is not None: seqlen_q = len(image_indices) @@ -954,6 +955,7 @@ def forward( prefill_cache_indices=prefill_cache_indices, lm_head_indices=lm_head_indices, cross_attention_states=cross_attention_states, + skip_lm_head=skip_lm_head, ) return outputs From 922c5d6f86ab8792111605fa1ccb18f65de93d66 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 31 Oct 2024 22:39:08 -0700 Subject: [PATCH 51/76] Cleanup comments --- server/lorax_server/models/flash_causal_lm.py | 41 ++++--------------- 1 file changed, 9 insertions(+), 32 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 99518f2dd..2ad5119f5 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -553,9 +553,6 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": segment_indices=adapter_segment_indices, ) - # logger.info("!!! FILTER slots {} -> {}", self.slots, slots) - # logger.info("!!! FILTER slots_indices {} -> {}", self.slot_indices, slot_indices) - return type(self)( batch_id=self.batch_id, requests=requests, @@ -767,19 +764,14 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch sequence_processors=sequence_processors, ) - # Discard speculative IDs if they are not present in all batches + # We skip computing the speculative_ids when the batch size is too large, so + # we must check that all batches have them, otherwise they must be discarded + speculative_ids = None if get_speculative_tokens() > 0: - keep_speculative_ids = all(b.speculative_ids is not None for b in batches) - if not keep_speculative_ids: + if all(b.speculative_ids is not None for b in batches): + speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0) + else: logger.info("Discarding speculative IDs, not every batch has them") - - speculative_ids = ( - torch.cat( - [b.speculative_ids for b in batches], dim=0) - if keep_speculative_ids else None - ) - else: - speculative_ids = None if adapter_segment_builder is not None: adapter_segments, adapter_segment_indices = adapter_segment_builder.build() @@ -791,9 +783,6 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch segment_indices=adapter_segment_indices, ) - # logger.info("!!! CONCATENATE slots {} -> {}", [b.slots for b in batches], slots) - # logger.info("!!! CONCATENATE slots_indices {} -> {}", [b.slot_indices for b in batches], slot_indices) - return cls( batch_id=batches[0].batch_id, requests=requests, @@ -1064,9 +1053,6 @@ def prepare_for_prefill(self): segment_indices=adapter_segment_indices, ) - # logger.info("!!! PREPARE_FOR_PREFILL slots {}", self.slots) - # logger.info("!!! 
PREPARE_FOR_PREFILL slots_indices {}", self.slot_indices) - def __len__(self): return len(self.requests) @@ -1525,11 +1511,6 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> cache_lengths_tensor = batch.cache_lengths_tensor max_s = batch.max_current_length - # logger.info("!!! BLOCKS={} {}\n SLOTS={} {}\n SLOT_INDICES={} {}", - # block_tables.tolist(), block_tables.shape, - # batch.slots.tolist(), batch.slots.shape, - # batch.slot_indices.tolist(), batch.slot_indices.shape) - if batch.speculative_ids is not None: speculative_ids = batch.speculative_ids @@ -1540,18 +1521,14 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> arange_int = arange.to(dtype=torch.int32) new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1) + # Slots can be discontiguous when prefix caching is enabled, so we need to expand the slot_indices, + # then update the slots with the additional indices to ensure we're grabbing the ones that have been + # allocated slot_indices = (batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) - # logger.info("!!! SLOT INDICES {} -> {}", batch.slot_indices.tolist(), slot_indices.tolist()) - slots = batch.slots[slot_indices] - # slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) - # logger.info("!!! NEW SLOTS {}", slots.tolist(), slots.shape) - - # logger.info("!!! BEFORE {} {}", input_lengths, batch.cache_lengths_tensor) input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) cache_lengths_tensor = (batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)).reshape(-1) - # logger.info("!!! AFTER {} {}", input_lengths, cache_lengths_tensor) block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B * new_length, -1).contiguous() max_s = max_s + speculative_length From b2de54fe1ee3c577f9c3bebe6a21f3b952ae5ed8 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 31 Oct 2024 22:57:25 -0700 Subject: [PATCH 52/76] Fixed orient for rank --- server/lorax_server/adapters/lora.py | 3 +-- server/lorax_server/models/model.py | 6 ++++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/server/lorax_server/adapters/lora.py b/server/lorax_server/adapters/lora.py index 6ef1c85f9..b2aec8a0d 100644 --- a/server/lorax_server/adapters/lora.py +++ b/server/lorax_server/adapters/lora.py @@ -100,8 +100,7 @@ def __init__( self._is_transposed = False # [num_layers, hidden_size, r] - # TODO(travis): add this back if rank is 8 and we're not using triton - # weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a] + weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a] self._weights_a = torch.stack(weights_a) # [num_layers, r, hidden_size] diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py index 5088ee8bb..4e75e27ad 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ -5,7 +5,7 @@ from lorax_server.adapters.lora import LoraWeights from lorax_server.adapters.medusa_lora import MedusaLoraWeights -from lorax_server.utils.punica import LORAX_PUNICA_TRITON_DISABLED, pad_to_min_rank +from lorax_server.utils.punica import LORAX_PUNICA_TRITON_DISABLED, pad_to_min_rank, use_cutlass_shrink import torch from loguru import logger from transformers import PreTrainedTokenizerBase @@ -301,8 +301,10 @@ def register_preloaded_adapters( continue # transpose into col major - lora_a = 
adapter_weights.weights_a.transpose(1, 2) lora_b = adapter_weights.weights_b.transpose(1, 2) + lora_a = adapter_weights.weights_a + if use_cutlass_shrink(lora_b.size(2)): + lora_a = lora_a.transpose(1, 2) nlayers = lora_a.size(0) for layer_id in range(nlayers): From 35c7de2b562c393b7452e0afc7e8482f83c10d19 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 31 Oct 2024 22:59:52 -0700 Subject: [PATCH 53/76] Format --- server/lorax_server/adapters/medusa.py | 17 +- server/lorax_server/adapters/weights.py | 8 +- .../custom_modeling/flash_cohere_modeling.py | 4 +- .../custom_modeling/flash_dbrx_modeling.py | 4 +- .../custom_modeling/flash_gemma2_modeling.py | 4 +- .../custom_modeling/flash_llama_modeling.py | 4 +- .../custom_modeling/flash_mistral_modeling.py | 4 +- .../custom_modeling/flash_mixtral_modeling.py | 4 +- .../custom_modeling/flash_neox_modeling.py | 4 +- .../custom_modeling/flash_phi3_modeling.py | 4 +- .../custom_modeling/flash_phi_modeling.py | 4 +- .../custom_modeling/flash_qwen2_modeling.py | 4 +- .../custom_modeling/flash_qwen_modeling.py | 4 +- .../custom_modeling/flash_rw_modeling.py | 4 +- .../flash_santacoder_modeling.py | 4 +- .../models/custom_modeling/llava_next.py | 4 +- server/lorax_server/models/flash_causal_lm.py | 129 +++------ .../lorax_server/models/metadata_kernels.py | 44 +-- server/lorax_server/models/model.py | 32 +-- server/lorax_server/server.py | 6 +- .../utils/flashinfer_attention.py | 12 +- server/lorax_server/utils/graph.py | 18 +- server/lorax_server/utils/layers.py | 10 +- server/lorax_server/utils/ops/__init__.py | 2 +- server/lorax_server/utils/ops/bgmv_expand.py | 30 +- .../utils/ops/bgmv_expand_slice.py | 29 +- server/lorax_server/utils/ops/bgmv_shrink.py | 15 +- server/lorax_server/utils/ops/libentry.py | 60 ++-- server/lorax_server/utils/ops/sgmv_expand.py | 32 +-- .../utils/ops/sgmv_expand_slice.py | 40 ++- server/lorax_server/utils/ops/sgmv_shrink.py | 26 +- server/lorax_server/utils/ops/utils.py | 19 +- server/lorax_server/utils/paged_attention.py | 6 +- server/lorax_server/utils/punica.py | 270 ++++++++---------- server/lorax_server/utils/tokens.py | 2 +- server/lorax_server/utils/torch_utils.py | 12 +- 36 files changed, 372 insertions(+), 503 deletions(-) diff --git a/server/lorax_server/adapters/medusa.py b/server/lorax_server/adapters/medusa.py index 5ca8b5315..276764509 100644 --- a/server/lorax_server/adapters/medusa.py +++ b/server/lorax_server/adapters/medusa.py @@ -1,16 +1,16 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type -from loguru import logger import torch import torch.distributed +from loguru import logger from lorax_server.adapters.config import AdapterConfig, ModuleMap from lorax_server.adapters.types import MEDUSA from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights from lorax_server.layers import FastLinear, TensorParallelColumnLinear -from lorax_server.utils.segments import find_segments from lorax_server.utils.punica import segmented_matmul +from lorax_server.utils.segments import find_segments from lorax_server.utils.state import LORAX_SPECULATION_MAX_BATCH_SIZE, get_speculative_tokens from lorax_server.utils.weights import AbstractWeights, InMemoryWeights @@ -22,6 +22,7 @@ _MEDUSA_ENABLED = False + @dataclass class MedusaConfig(AdapterConfig): medusa_num_heads: int @@ -312,11 +313,19 @@ def load( default_medusa=default_medusa, segments=MedusaSegments( w=[ - 
(adapter_weights[idx].model.medusa.linear.linear.weight.data if idx in adapter_weights else EMPTY_TENSOR) + ( + adapter_weights[idx].model.medusa.linear.linear.weight.data + if idx in adapter_weights + else EMPTY_TENSOR + ) for idx in segment_indices ], b=[ - (adapter_weights[idx].model.medusa.linear.linear.bias.data if idx in adapter_weights else EMPTY_TENSOR) + ( + adapter_weights[idx].model.medusa.linear.linear.bias.data + if idx in adapter_weights + else EMPTY_TENSOR + ) for idx in segment_indices ], s_start=segments[indices], diff --git a/server/lorax_server/adapters/weights.py b/server/lorax_server/adapters/weights.py index a0122e4b4..1ca77f001 100644 --- a/server/lorax_server/adapters/weights.py +++ b/server/lorax_server/adapters/weights.py @@ -136,10 +136,10 @@ def from_meta( if layer_weights: data[k] = layer_weights return AdapterBatchData( - meta=meta, - data=data, - layer_to_lora_weights=layer_to_lora_weights, - punica_wrapper=punica_wrapper, + meta=meta, + data=data, + layer_to_lora_weights=layer_to_lora_weights, + punica_wrapper=punica_wrapper, prefill=prefill, ) diff --git a/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py b/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py index 0e8771a11..1c6c23320 100644 --- a/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py @@ -539,11 +539,11 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] - + if skip_lm_head: # FIXME: simply running the LM head is not sufficient since we also need to scale the logits return hidden_states, None - + logits, speculative_logits = self.lm_head(hidden_states, adapter_data) logits *= self.logit_scale if speculative_logits is not None: diff --git a/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py b/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py index 634043379..252bdd514 100644 --- a/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py @@ -1024,9 +1024,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] - + if skip_lm_head: return hidden_states, None - + logits, speculative_logits = self.lm_head(hidden_states, adapter_data) return logits, speculative_logits diff --git a/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py b/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py index 7b9464110..2aca92c7d 100644 --- a/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py @@ -555,9 +555,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] - + if skip_lm_head: return hidden_states, None - + logits, speculative_logits = self.lm_head(hidden_states, adapter_data) return logits, speculative_logits diff --git a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py index 961c87f16..87abc494a 100644 --- a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py @@ -616,9 +616,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] - + if skip_lm_head: return hidden_states, None - + logits, speculative_logits = 
self.lm_head(hidden_states, adapter_data) return logits, speculative_logits diff --git a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py index c2eb566ed..5e244fa93 100644 --- a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py @@ -636,9 +636,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] - + if skip_lm_head: return hidden_states, None - + logits, speculative_logits = self.lm_head(hidden_states, adapter_data) return logits, speculative_logits diff --git a/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py index 3e17ccdb9..f9eac6fb3 100644 --- a/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py @@ -988,9 +988,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] - + if skip_lm_head: return hidden_states, None - + logits, speculative_logits = self.lm_head(hidden_states, adapter_data) return logits, speculative_logits diff --git a/server/lorax_server/models/custom_modeling/flash_neox_modeling.py b/server/lorax_server/models/custom_modeling/flash_neox_modeling.py index 70405a874..e798b6f6b 100644 --- a/server/lorax_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_neox_modeling.py @@ -371,9 +371,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] - + if skip_lm_head: return hidden_states, None - + logits = self.embed_out(hidden_states) return logits, None diff --git a/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py b/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py index 4edddea78..151d11641 100644 --- a/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py @@ -521,9 +521,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] - + if skip_lm_head: return hidden_states, None - + logits, speculative_logits = self.lm_head(hidden_states, adapter_data) return logits, speculative_logits diff --git a/server/lorax_server/models/custom_modeling/flash_phi_modeling.py b/server/lorax_server/models/custom_modeling/flash_phi_modeling.py index e1b005052..e6b22fe71 100644 --- a/server/lorax_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_phi_modeling.py @@ -403,9 +403,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] - + if skip_lm_head: return hidden_states, None - + logits, speculative_logits = self.lm_head(hidden_states, adapter_data) return logits, speculative_logits diff --git a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py index 4a8f23863..b04d61698 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py @@ -537,10 +537,10 @@ def forward( if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] - + if skip_lm_head: return hidden_states, None - + logits, speculative_logits = 
self.lm_head(hidden_states, adapter_data) return logits, speculative_logits diff --git a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py index 7efbe22f6..c4a46db68 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py @@ -522,9 +522,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] - + if skip_lm_head: return hidden_states, None - + logits, speculative_logits = self.lm_head(hidden_states, adapter_data) return logits, speculative_logits diff --git a/server/lorax_server/models/custom_modeling/flash_rw_modeling.py b/server/lorax_server/models/custom_modeling/flash_rw_modeling.py index c43ba1d46..4fe821039 100644 --- a/server/lorax_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_rw_modeling.py @@ -606,9 +606,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] - + if skip_lm_head: return hidden_states, None - + logits = self.lm_head(hidden_states) return logits, None diff --git a/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py b/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py index f8325a864..a3dc31da6 100644 --- a/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py @@ -437,9 +437,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] - + if skip_lm_head: return hidden_states, None - + logits = self.lm_head(hidden_states) return logits, None diff --git a/server/lorax_server/models/custom_modeling/llava_next.py b/server/lorax_server/models/custom_modeling/llava_next.py index 81c170cd5..cb0797834 100644 --- a/server/lorax_server/models/custom_modeling/llava_next.py +++ b/server/lorax_server/models/custom_modeling/llava_next.py @@ -265,9 +265,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] - + if skip_lm_head: return hidden_states, None - + logits, speculative_logits = self.text_model.lm_head(hidden_states, adapter_data) return logits, speculative_logits diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 2ad5119f5..152e1cb50 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -4,7 +4,6 @@ from dataclasses import dataclass from typing import Any, ContextManager, Dict, List, Optional, Tuple, Type, Union -from lorax_server.utils.punica import LORAX_PUNICA_TRITON_DISABLED, PunicaWrapper import numpy as np import torch import torch.distributed @@ -15,6 +14,14 @@ from transformers import AutoConfig, AutoTokenizer, GenerationConfig, PretrainedConfig, PreTrainedTokenizerBase from lorax_server.adapters import AdapterBatchData, AdapterBatchMetadata +from lorax_server.models.metadata_kernels import ( + block_tables_to_padded, + block_tables_to_ragged, + copy_next_input_ids_inplace, + has_triton, + prepare_position_slot_ids, + slots_filtering, +) from lorax_server.models.model import Model from lorax_server.models.types import ( Batch, @@ -29,6 +36,7 @@ from lorax_server.utils.dist import MEMORY_FRACTION, MEMORY_WIGGLE_ROOM, initialize_torch_distributed from lorax_server.utils.graph import GraphCache from 
lorax_server.utils.import_utils import get_cuda_free_memory +from lorax_server.utils.punica import LORAX_PUNICA_TRITON_DISABLED, PunicaWrapper from lorax_server.utils.segments import SegmentConcatBuilder, find_segments from lorax_server.utils.sources import HUB from lorax_server.utils.sources.hub import weight_files @@ -43,14 +51,6 @@ from lorax_server.utils.tokenizer import TokenizerManager from lorax_server.utils.torch_utils import is_fp8, is_fp8_kv, is_fp8_supported from lorax_server.utils.weights import Weights -from lorax_server.models.metadata_kernels import ( - has_triton, - copy_next_input_ids_inplace, - block_tables_to_ragged, - block_tables_to_padded, - prepare_position_slot_ids, - slots_filtering, -) ADAPTER_MEMORY_FRACTION = float(os.getenv("ADAPTER_MEMORY_FRACTION", "0.1")) @@ -281,11 +281,7 @@ def from_pb( if not r.blocks: needed_blocks = math.ceil(block_tokens / BLOCK_SIZE) request_blocks = [b for b in range(num_blocks, num_blocks + needed_blocks)] - request_slots = [ - s - for b in request_blocks - for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE) - ] + request_slots = [s for b in request_blocks for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)] else: request_blocks = r.blocks request_slots = r.slots @@ -324,9 +320,7 @@ def from_pb( # Create tensors on device all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64, device=device) - block_tables_ragged = torch.tensor( - block_tables_ragged, device=device, dtype=torch.int32 - ) + block_tables_ragged = torch.tensor(block_tables_ragged, device=device, dtype=torch.int32) cu_blocks = torch.tensor(cu_blocks, device=device, dtype=torch.int64) block_tables_tensor = torch.empty( (len(block_tables), max_blocks), @@ -336,18 +330,12 @@ def from_pb( # If the device supports Triton, we can use a fused kernel if has_triton(): - block_tables_to_padded( - max_blocks, cu_blocks, block_tables_tensor, block_tables_ragged - ) + block_tables_to_padded(max_blocks, cu_blocks, block_tables_tensor, block_tables_ragged) else: for i, request_blocks in enumerate(block_tables): - block_tables_tensor[i, : len(request_blocks)] = torch.tensor( - request_blocks - ) + block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) - prompt_lengths_tensor = torch.tensor( - prompt_lengths, dtype=torch.int32, device=device - ) + prompt_lengths_tensor = torch.tensor(prompt_lengths, dtype=torch.int32, device=device) slots = torch.tensor(slots, dtype=torch.int64, device=device) cu_slots = torch.tensor(cu_slots, dtype=torch.int64) @@ -415,9 +403,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": # slots to keep after filtering if not has_triton(): # slots to keep after filtering - slot_filtering_indices = torch.zeros( - self.slots.shape[0], dtype=torch.bool, device=device - ) + slot_filtering_indices = torch.zeros(self.slots.shape[0], dtype=torch.bool, device=device) # Create on CPU to only move to GPU once instead of at every copy slot_indices = torch.empty(len(request_ids), dtype=torch.int64) @@ -519,9 +505,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": slots = self.slots.new_empty(cumulative_slot_tokens) gpu_cu_slots = cu_slots.to(device) slots_indexing_start = self.cu_slots.to(device)[indices] - slots_filtering( - max_slots, self.slots, slots, gpu_cu_slots, slots_indexing_start - ) + slots_filtering(max_slots, self.slots, slots, gpu_cu_slots, slots_indexing_start) if self.prefilling: # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` @@ -704,9 +688,7 @@ def 
concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch slots_start_index = cumulative_slots slots_end_index = cumulative_slots + len(batch.slots) slots[slots_start_index:slots_end_index] = batch.slots - cu_slots[start_index + 1 : end_index + 1] = ( - batch.cu_slots[1:] + cumulative_slots - ) + cu_slots[start_index + 1 : end_index + 1] = batch.cu_slots[1:] + cumulative_slots if not prefilling: input_ids[start_index:end_index] = batch.input_ids @@ -782,7 +764,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch adapter_segments=adapter_segments, segment_indices=adapter_segment_indices, ) - + return cls( batch_id=batches[0].batch_id, requests=requests, @@ -841,24 +823,16 @@ def prepare_for_prefill(self): input_ids = self.input_ids[0] self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) - self.input_lengths_tensor = torch.tensor( - self.input_lengths, dtype=torch.int32, device=device - ) - self.cu_seqlen_prefill = torch.nn.functional.pad( - torch.cumsum(self.input_lengths_tensor, dim=0), (1, 0) - ).to(torch.int32) - self.cache_lengths_tensor = torch.tensor( - self.cache_lengths, dtype=torch.int32, device=device + self.input_lengths_tensor = torch.tensor(self.input_lengths, dtype=torch.int32, device=device) + self.cu_seqlen_prefill = torch.nn.functional.pad(torch.cumsum(self.input_lengths_tensor, dim=0), (1, 0)).to( + torch.int32 ) + self.cache_lengths_tensor = torch.tensor(self.cache_lengths, dtype=torch.int32, device=device) # If the device supports Triton, we can use a fused kernel if has_triton(): - self.position_ids = torch.empty( - len(self.input_ids), dtype=torch.int32, device=device - ) - self.slot_indices = torch.empty( - len(self.input_ids), dtype=torch.int64, device=device - ) + self.position_ids = torch.empty(len(self.input_ids), dtype=torch.int32, device=device) + self.slot_indices = torch.empty(len(self.input_ids), dtype=torch.int64, device=device) cu_slots_gpu = self.cu_slots.to(device) prepare_position_slot_ids( @@ -903,20 +877,14 @@ def prepare_for_prefill(self): ) ): next_chunk_length = input_length - + if not has_triton(): # Position ids - request_position_ids = torch.arange( - cache_length, cache_length + input_length, dtype=torch.int32 - ) + request_position_ids = torch.arange(cache_length, cache_length + input_length, dtype=torch.int32) position_ids.append(request_position_ids) if not r.slots: - request_slots = [ - s - for b in blocks - for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE) - ] + request_slots = [s for b in blocks for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)] else: request_slots = r.slots @@ -991,9 +959,7 @@ def prepare_for_prefill(self): dtype=torch.int64, ) ) - prefill_next_token_indices.append( - prefill_out_cumulative_length + input_length - 1 - ) + prefill_next_token_indices.append(prefill_out_cumulative_length + input_length - 1) prefill_out_cumulative_length += input_length else: prefill_head_indices.append( @@ -1115,13 +1081,13 @@ def __init__( config.quantize = quantize if is_fp8(config.quantize) and not is_fp8_supported(): - raise ValueError('FP8 quantization is only supported on hardware that supports FP8') + raise ValueError("FP8 quantization is only supported on hardware that supports FP8") if is_fp8_kv(config.quantize): if not FLASH_INFER: - raise ValueError('FP8 KV cache requires FLASH_INFER backend') + raise ValueError("FP8 KV cache requires FLASH_INFER backend") self.kv_dtype = torch.float8_e4m3fn - logger.info('Enabling FP8 KV cache. 
Prefix caching will not work.') + logger.info("Enabling FP8 KV cache. Prefix caching will not work.") else: self.kv_dtype = dtype @@ -1244,7 +1210,7 @@ def __init__( num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, ) - + self.punica_wrapper = None @property @@ -1323,9 +1289,9 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model max_batches=256, # TODO(travis): find a better way to set this programmatically device=self.device, enabled=( - not self.dynamic_adapter_loading_enabled and # only supported for now with statically loaded adapters - not LORAX_PUNICA_TRITON_DISABLED - ) + not self.dynamic_adapter_loading_enabled # only supported for now with statically loaded adapters + and not LORAX_PUNICA_TRITON_DISABLED + ), ) torch.cuda.empty_cache() @@ -1642,12 +1608,12 @@ def generate_token( # TODO(travis): don't update this if indices haven't changed self.punica_wrapper.update_metadata(adapter_meta, prefill) adapter_data = AdapterBatchData.from_meta( - adapter_meta, - self.layer_to_adapter_weights, - self.layer_to_lora_weights, - self.punica_wrapper, - prefill, - batch.prefill_head_indices + adapter_meta, + self.layer_to_adapter_weights, + self.layer_to_lora_weights, + self.punica_wrapper, + prefill, + batch.prefill_head_indices, ) out, speculative_logits = self.forward(batch, adapter_data) @@ -1734,9 +1700,7 @@ def generate_token( indices = batch.cu_seqlen_prefill[1:] - 1 batch.position_ids = batch.position_ids[indices] batch.slot_indices = batch.slot_indices[indices] - batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[ - indices - ] + batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[indices] # Zipped iterator iterator = zip( @@ -1756,9 +1720,7 @@ def generate_token( # For each member of the batch # Cumulative length - cu_accepted_ids = torch.nn.functional.pad( - torch.cumsum(accepted_ids, dim=0), (1, 0) - ) + cu_accepted_ids = torch.nn.functional.pad(torch.cumsum(accepted_ids, dim=0), (1, 0)) cumulative_length = 0 for i, ( request, @@ -1791,14 +1753,13 @@ def generate_token( # Only save tokens if we are done prefilling for this request batch.all_input_ids_tensor[ i, - batch.cache_lengths_tensor[i] - + batch.input_lengths[i] : batch.cache_lengths_tensor[i] + batch.cache_lengths_tensor[i] + batch.input_lengths[i] : batch.cache_lengths_tensor[i] + batch.input_lengths[i] + accepted_ids[i], ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]] cumulative_length += input_length - + # If the device support triton, we can use a fused kernel if has_triton(): copy_next_input_ids_inplace( diff --git a/server/lorax_server/models/metadata_kernels.py b/server/lorax_server/models/metadata_kernels.py index 7e2c2b1ac..830cbdca2 100644 --- a/server/lorax_server/models/metadata_kernels.py +++ b/server/lorax_server/models/metadata_kernels.py @@ -1,17 +1,17 @@ # From: https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/models/metadata_kernels.py +from typing import List, Optional + import torch import triton - import triton.language as tl - from loguru import logger -from typing import List, Optional from torch.utils._triton import has_triton as has_triton_torch from lorax_server.utils.import_utils import ( SYSTEM, ) + _HAS_TRITON: Optional[bool] = None @@ -55,20 +55,16 @@ def block_tables_to_ragged( cache_lengths: List[int], input_lengths_tensor: torch.Tensor, cache_lengths_tensor: torch.Tensor, - max_current_length: int + max_current_length: int, ) -> torch.Tensor: 
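# --- Editor's sketch (not part of the patch): what block_tables_to_ragged computes ---
# A pure-PyTorch reference of the non-Triton fallback in this function: for each sequence,
# the first (cache_length + input_length) block ids are copied back-to-back into one flat
# int32 tensor, the ragged layout FlashInfer expects. Names below are illustrative only.
import torch

def ragged_reference(block_tables, input_lengths, cache_lengths):
    pieces = []
    for row, input_length, cache_length in zip(block_tables, input_lengths, cache_lengths):
        pieces.append(row[: cache_length + input_length])
    return torch.cat(pieces).to(torch.int32)

_bt = torch.tensor([[3, 4, 5, 0], [7, 8, 0, 0]])
print(ragged_reference(_bt, input_lengths=[2, 1], cache_lengths=[1, 1]))
# -> tensor([3, 4, 5, 7, 8], dtype=torch.int32)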
"""Convert block table to ragged format compatible with FlashInfer.""" assert len(input_lengths) == len(cache_lengths) total_len = sum(input_lengths) + sum(cache_lengths) - block_tables_ragged = torch.empty( - total_len, dtype=torch.int32, device=block_tables.device - ) + block_tables_ragged = torch.empty(total_len, dtype=torch.int32, device=block_tables.device) if has_triton(): - cu_seqlen = torch.nn.functional.pad( - torch.cumsum(input_lengths_tensor + cache_lengths_tensor, dim=0), (1, 0) - ) + cu_seqlen = torch.nn.functional.pad(torch.cumsum(input_lengths_tensor + cache_lengths_tensor, dim=0), (1, 0)) def grid(meta): return ( @@ -85,9 +81,7 @@ def grid(meta): ) else: offset = 0 - for i, (input_length, cache_length) in enumerate( - zip(input_lengths, cache_lengths) - ): + for i, (input_length, cache_length) in enumerate(zip(input_lengths, cache_lengths)): seq_len = cache_length + input_length block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len] offset += seq_len @@ -154,9 +148,7 @@ def grid(meta): len(slots_start), ) - triton_slots_filtering[grid]( - slots, filtered_slots, slots_start, cu_slots, BLOCK_SIZE=256 - ) + triton_slots_filtering[grid](slots, filtered_slots, slots_start, cu_slots, BLOCK_SIZE=256) @triton.jit @@ -214,9 +206,7 @@ def triton_block_tables_to_padded( mask = (seq_start + block_arange) < seq_end blocks = tl.load(block_tables_ragged_ptr + seq_start + block_arange, mask=mask) - tl.store( - block_tables_ptr + bid * stride_block_tables + block_arange, blocks, mask=mask - ) + tl.store(block_tables_ptr + bid * stride_block_tables + block_arange, blocks, mask=mask) @triton.jit @@ -244,9 +234,7 @@ def triton_block_tables_to_ragged( mask = (seq_start + block_arange) < seq_end - blocks = tl.load( - block_tables_ptr + bid * stride_block_tables + block_arange, mask=mask - ) + blocks = tl.load(block_tables_ptr + bid * stride_block_tables + block_arange, mask=mask) tl.store(block_tables_ragged_ptr + seq_start + block_arange, blocks, mask=mask) @@ -290,17 +278,11 @@ def triton_copy_next_input_ids_inplace( mask = mask & decode_mask # Load this request next input ids - next_input_ids = tl.load( - next_input_ids_ptr + next_input_ids_start + block_arange, mask=mask - ) + next_input_ids = tl.load(next_input_ids_ptr + next_input_ids_start + block_arange, mask=mask) # Store in all_input_ids, since it is a 2D tensor, apply stride * bid tl.store( - all_input_ids_ptr - + stride_all_input_ids * bid - + cache_length - + input_length - + block_arange, + all_input_ids_ptr + stride_all_input_ids * bid + cache_length + input_length + block_arange, next_input_ids, mask=mask, ) @@ -344,4 +326,4 @@ def triton_prepare_position_slot_ids( slot_indices_ptr + seq_start + block_arange, slot_start + cache_length + block_arange, mask=mask, - ) \ No newline at end of file + ) diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py index 4e75e27ad..be80b3bdd 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ -3,13 +3,12 @@ from collections import defaultdict from typing import Dict, List, Optional, Tuple, Type, TypeVar -from lorax_server.adapters.lora import LoraWeights -from lorax_server.adapters.medusa_lora import MedusaLoraWeights -from lorax_server.utils.punica import LORAX_PUNICA_TRITON_DISABLED, pad_to_min_rank, use_cutlass_shrink import torch from loguru import logger from transformers import PreTrainedTokenizerBase +from lorax_server.adapters.lora import LoraWeights +from lorax_server.adapters.medusa_lora 
import MedusaLoraWeights from lorax_server.adapters.utils import download_adapter_weights from lorax_server.adapters.weights import LayerAdapterWeights from lorax_server.models.types import Batch, GeneratedText @@ -19,6 +18,7 @@ BASE_MODEL_ADAPTER_ID, load_and_merge_adapters, ) +from lorax_server.utils.punica import LORAX_PUNICA_TRITON_DISABLED, pad_to_min_rank, use_cutlass_shrink from lorax_server.utils.sources import HUB from lorax_server.utils.state import ( BLOCK_SIZE, @@ -126,7 +126,7 @@ def __init__( torch.profiler.ProfilerActivity.CUDA, ], with_stack=True, - on_trace_ready=torch.profiler.tensorboard_trace_handler(LORAX_PROFILER_DIR, use_gzip=True) + on_trace_ready=torch.profiler.tensorboard_trace_handler(LORAX_PROFILER_DIR, use_gzip=True), ) self.profiler_steps = 0 @@ -233,7 +233,7 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: @property def adapter_layers(self) -> List[str]: return [] - + @property def traced_adapter_layers(self) -> List[str]: if self.layer_to_adapter_weights: @@ -262,7 +262,7 @@ def register_preloaded_adapters( ): if preloaded_adapters is None: return - + self.dynamic_adapter_loading_enabled = False self.preloaded_adapter_indices.update({adapter.adapter_index for adapter in preloaded_adapters}) self.preloaded_adapter_memory_fractions.update( @@ -280,7 +280,7 @@ def register_preloaded_adapters( # For Triton kernels: need weights into contiguous tensor # dict of (layer_name, layer_id) -> (lora_a_weights, lora_b_weights) # where: - # lora_a_weights = [num_adapters, r, hidden_size] + # lora_a_weights = [num_adapters, r, hidden_size] # lora_b_weights = [num_adapters, hidden_size, r] for layer_name, layer_adapter_weights in self.layer_to_adapter_weights.items(): layer_id_to_lora_a_weights = defaultdict(list) @@ -295,11 +295,11 @@ def register_preloaded_adapters( else: # only applicable to lora for now continue - + if adapter_weights is None: # no weights for this layer continue - + # transpose into col major lora_b = adapter_weights.weights_b.transpose(1, 2) lora_a = adapter_weights.weights_a @@ -310,7 +310,7 @@ def register_preloaded_adapters( for layer_id in range(nlayers): layer_id_to_lora_a_weights[layer_id].append(lora_a[layer_id]) layer_id_to_lora_b_weights[layer_id].append(lora_b[layer_id]) - + for layer_id, lora_a_weights in layer_id_to_lora_a_weights.items(): lora_b_weights = layer_id_to_lora_b_weights[layer_id] @@ -347,9 +347,9 @@ def load_adapter( if dynamic and not self.dynamic_adapter_loading_enabled: raise ValueError( - f"This model does not support dynamic adapter loading. " - f"Please initialize a new model instance from the base model and remove preloaded adapters " - f"to use the dynamic adapter loading feature." + "This model does not support dynamic adapter loading. " + "Please initialize a new model instance from the base model and remove preloaded adapters " + "to use the dynamic adapter loading feature." ) logger.info(f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}") @@ -428,9 +428,9 @@ def offload_adapter( if not self.dynamic_adapter_loading_enabled: raise ValueError( - f"This model does not support dynamic adapter loading. " - f"Please initialize a new model instance from the base model and remove preloaded adapters " - f"to use the dynamic adapter loading feature." + "This model does not support dynamic adapter loading. " + "Please initialize a new model instance from the base model and remove preloaded adapters " + "to use the dynamic adapter loading feature." 
) for layer_name in self.adapter_layers: diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index 9cf7340bb..bb378b561 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -63,7 +63,7 @@ async def ClearCache(self, request, context): self.cache.delete(request.id) else: self.cache.clear() - except: + except Exception: exit(1) return generate_pb2.ClearCacheResponse() @@ -113,7 +113,7 @@ async def Prefill(self, request: generate_pb2.PrefillRequest, context): generations, next_batch = self.model.generate_token(batch) self.cache.set(next_batch) - + if self.model.profiler: self.model.profiler_steps += 1 if self.model.profiler_steps == 10: @@ -328,7 +328,7 @@ async def serve_inner( create_exllama_buffers() except ImportError: pass - + # set speculative decoding tokens speculative_tokens = max(model.max_speculative_tokens, speculative_tokens) if speculative_tokens > 0: diff --git a/server/lorax_server/utils/flashinfer_attention.py b/server/lorax_server/utils/flashinfer_attention.py index 2380875de..cc99b7a51 100644 --- a/server/lorax_server/utils/flashinfer_attention.py +++ b/server/lorax_server/utils/flashinfer_attention.py @@ -53,9 +53,7 @@ def use_prefill_with_paged_kv_state( `attention` function while the context manager is active. """ - indptr = torch.zeros( - input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32 - ) + indptr = torch.zeros(input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32) # Round up to page size and then calculate the cumulative sum to get # the indices into the block table. torch.add(input_lengths, page_size - 1, out=indptr[1:]) @@ -64,13 +62,9 @@ def use_prefill_with_paged_kv_state( # Get the lengths of the last page in a block. if page_size == 1: - last_page_len = torch.ones( - input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device - ) + last_page_len = torch.ones(input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device) else: - last_page_len = torch.empty( - input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device - ) + last_page_len = torch.empty(input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device) torch.sub(input_lengths, 1, out=last_page_len) last_page_len.remainder_(page_size) last_page_len += 1 diff --git a/server/lorax_server/utils/graph.py b/server/lorax_server/utils/graph.py index bdb7e2882..8baaed757 100644 --- a/server/lorax_server/utils/graph.py +++ b/server/lorax_server/utils/graph.py @@ -7,7 +7,6 @@ from statistics import median from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple -from lorax_server.models.metadata_kernels import block_tables_to_ragged import numpy as np import torch from loguru import logger @@ -16,7 +15,8 @@ from lorax_server.adapters import AdapterBatchData, AdapterBatchMetadata from lorax_server.adapters.lora import BatchLoraWeights, RankSegments -from lorax_server.adapters.types import LORA, MEDUSA +from lorax_server.adapters.types import LORA +from lorax_server.models.metadata_kernels import block_tables_to_ragged from lorax_server.utils.attention.common import Seqlen from lorax_server.utils.punica import BGMV_MAX_RANK, PunicaWrapper from lorax_server.utils.state import BLOCK_SIZE, FLASH_INFER, get_speculative_tokens @@ -279,7 +279,7 @@ def trace( num_heads=num_heads, num_kv_heads=num_kv_heads, ) - + meta = AdapterBatchMetadata( adapter_indices=max_input_state.adapter_data.meta.adapter_indices[:batch_size], 
adapter_list=max_input_state.adapter_data.meta.adapter_list, @@ -287,10 +287,7 @@ def trace( adapter_segments=max_input_state.adapter_data.meta.adapter_segments[:batch_size], segment_indices=max_input_state.adapter_data.meta.segment_indices, ) - punica_wrapper.update_metadata( - meta=meta, - prefill=False - ) + punica_wrapper.update_metadata(meta=meta, prefill=False) input_state = GraphState( input_ids=max_input_state.input_ids[:batch_size], @@ -426,11 +423,8 @@ def forward( pad_and_fill(dest_rank_data.lora_b_ptr, source_rank_data.lora_b_ptr, 0) pad_and_fill(dest_rank_data.indices, source_rank_data.indices, SEGMENT_PAD_VALUE) - self.input_state.adapter_data.punica_wrapper.update_metadata( - meta=adapter_data.meta, - prefill=False - ) - + self.input_state.adapter_data.punica_wrapper.update_metadata(meta=adapter_data.meta, prefill=False) + with self.forward_context( block_tables=self.input_state.block_tables, cu_seqlen_prefill=None, diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index 61be1a82b..8b43d89b7 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -78,9 +78,9 @@ def forward_layer_type( # Triton Punica kernels if ( - adapter_data.punica_wrapper.enabled and - input.shape[0] <= adapter_data.punica_wrapper.max_batch_size and - can_vectorize + adapter_data.punica_wrapper.enabled + and input.shape[0] <= adapter_data.punica_wrapper.max_batch_size + and can_vectorize ): if end_idx - start_idx != result.shape[1]: y_offset = start_idx @@ -88,7 +88,7 @@ def forward_layer_type( else: y_offset = None y_slice_size = None - + lora_a_weights, lora_b_weights = adapter_data.layer_to_lora_weights[(layer_type, self.layer_id)] adapter_data.punica_wrapper.add_lora( result, @@ -162,7 +162,7 @@ def forward_layer_type( if end_idx - start_idx != result.shape[1]: result[:, start_idx:end_idx] += proj - + # Vanilla PyTorch else: adapter_indices = adapter_data.meta.adapter_indices diff --git a/server/lorax_server/utils/ops/__init__.py b/server/lorax_server/utils/ops/__init__.py index 373f4d940..22f53e4d0 100644 --- a/server/lorax_server/utils/ops/__init__.py +++ b/server/lorax_server/utils/ops/__init__.py @@ -1 +1 @@ -# Source: https://github.com/vllm-project/vllm/tree/main/vllm/lora/ops \ No newline at end of file +# Source: https://github.com/vllm-project/vllm/tree/main/vllm/lora/ops diff --git a/server/lorax_server/utils/ops/bgmv_expand.py b/server/lorax_server/utils/ops/bgmv_expand.py index d214da0b6..59562cee8 100644 --- a/server/lorax_server/utils/ops/bgmv_expand.py +++ b/server/lorax_server/utils/ops/bgmv_expand.py @@ -1,7 +1,7 @@ """ Based on: -Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). -Punica: Multi-Tenant LoRA Serving. +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. 
https://arxiv.org/abs/2310.18547 """ @@ -48,8 +48,9 @@ def _bgmv_expand_kernel( offset_k = tl.arange(0, BLOCK_K) offset_n = tl.arange(0, BLOCK_N) if EVEN_K: - tiled_a = tl.load(input_ptr + cur_batch * xm_stride + - offset_k * xk_stride, ) # [BLOCK_K] + tiled_a = tl.load( + input_ptr + cur_batch * xm_stride + offset_k * xk_stride, + ) # [BLOCK_K] else: tiled_a = tl.load( input_ptr + cur_batch * xm_stride + offset_k * xk_stride, @@ -61,18 +62,15 @@ def _bgmv_expand_kernel( if CAST_TYPE: tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) # sliding to next row-block - b_ptr = (lora_ptr + l0_stride * lora_index + - pid_sn * split_n_length * lora_k_stride) + b_ptr = lora_ptr + l0_stride * lora_index + pid_sn * split_n_length * lora_k_stride c_ptr = out_ptr + cur_batch * cm_stride + pid_sn * split_n_length for n in range(0, split_n_length, BLOCK_N): current_n = n + offset_n current_n_c = tl.max_contiguous(current_n, BLOCK_N) - b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :] - < K) + b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :] < K) c_mask = current_n < split_n_length tiled_b = tl.load( - b_ptr + current_n_c[:, None] * lora_k_stride + - offset_k[None, :] * lora_n_stride, + b_ptr + current_n_c[:, None] * lora_k_stride + offset_k[None, :] * lora_n_stride, mask=b_ptr_mask, other=0.0, ) # [BLOCK_N,BLOCK_K] @@ -103,9 +101,9 @@ def bgmv_expand( corresponding to each batch, An index of -1 means no lora should be applied. batches (int): batch size - add_inputs (bool, optional): Defaults to False. adds the final lora + add_inputs (bool, optional): Defaults to False. adds the final lora results to the output. - override_config (Optional[Dict[str, int]], optional): Defaults to None. + override_config (Optional[Dict[str, int]], optional): Defaults to None. Triton grid config """ @@ -133,8 +131,8 @@ def bgmv_expand( ADD_INPUTS = add_inputs CAST_TYPE = False if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ - torch.float16, - torch.bfloat16, + torch.float16, + torch.bfloat16, ]: CAST_TYPE = True batches = lora_indices_tensor.size(0) @@ -142,7 +140,7 @@ def bgmv_expand( config = override_config else: config = get_lora_op_configs("expand", batches, N) - grid = lambda META: ( + grid = lambda META: ( # noqa: E731 META["SPLIT_N"], batches, ) @@ -166,4 +164,4 @@ def bgmv_expand( CAST_TYPE=CAST_TYPE, **config, ) - return \ No newline at end of file + return diff --git a/server/lorax_server/utils/ops/bgmv_expand_slice.py b/server/lorax_server/utils/ops/bgmv_expand_slice.py index 1444fa8e5..a4eb1b425 100644 --- a/server/lorax_server/utils/ops/bgmv_expand_slice.py +++ b/server/lorax_server/utils/ops/bgmv_expand_slice.py @@ -1,7 +1,7 @@ """ Based on: -Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). -Punica: Multi-Tenant LoRA Serving. +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. 
https://arxiv.org/abs/2310.18547 """ @@ -49,8 +49,9 @@ def _bgmv_expand_slice_kernel( offset_k = tl.arange(0, BLOCK_K) offset_n = tl.arange(0, BLOCK_N) if EVEN_K: - tiled_a = tl.load(input_ptr + cur_batch * xm_stride + - offset_k * xk_stride, ) # [BLOCK_K] + tiled_a = tl.load( + input_ptr + cur_batch * xm_stride + offset_k * xk_stride, + ) # [BLOCK_K] else: tiled_a = tl.load( input_ptr + cur_batch * xm_stride + offset_k * xk_stride, @@ -62,19 +63,15 @@ def _bgmv_expand_slice_kernel( if CAST_TYPE: tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) # sliding to next row-block - b_ptr = (lora_ptr + l0_stride * lora_index + - pid_sn * split_n_length * lora_k_stride) - c_ptr = (out_ptr + cur_batch * cm_stride + pid_sn * split_n_length + - slice_offset * cn_stride) + b_ptr = lora_ptr + l0_stride * lora_index + pid_sn * split_n_length * lora_k_stride + c_ptr = out_ptr + cur_batch * cm_stride + pid_sn * split_n_length + slice_offset * cn_stride for n in range(0, split_n_length, BLOCK_N): current_n = n + offset_n - b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :] - < K) + b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :] < K) c_mask = current_n < split_n_length tiled_b = tl.load( - b_ptr + current_n[:, None] * lora_k_stride + - offset_k[None, :] * lora_n_stride, + b_ptr + current_n[:, None] * lora_k_stride + offset_k[None, :] * lora_n_stride, mask=b_ptr_mask, other=0.0, ) # [BLOCK_N,BLOCK_K] @@ -142,8 +139,8 @@ def bgmv_expand_slice( ADD_INPUTS = add_inputs CAST_TYPE = False if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ - torch.float16, - torch.bfloat16, + torch.float16, + torch.bfloat16, ]: CAST_TYPE = True @@ -154,7 +151,7 @@ def bgmv_expand_slice( else: config = get_lora_op_configs("expand", batches, N) - grid = lambda META: ( + grid = lambda META: ( # noqa: E731 META["SPLIT_N"], batches, ) @@ -179,4 +176,4 @@ def bgmv_expand_slice( CAST_TYPE=CAST_TYPE, **config, ) - return \ No newline at end of file + return diff --git a/server/lorax_server/utils/ops/bgmv_shrink.py b/server/lorax_server/utils/ops/bgmv_shrink.py index c532ba526..0937f4fa7 100644 --- a/server/lorax_server/utils/ops/bgmv_shrink.py +++ b/server/lorax_server/utils/ops/bgmv_shrink.py @@ -1,7 +1,7 @@ """ Based on: -Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). -Punica: Multi-Tenant LoRA Serving. +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. https://arxiv.org/abs/2310.18547 """ @@ -48,7 +48,7 @@ def _bgmv_shrink_kernel( offset_k = tl.arange(0, BLOCK_K) + pid_sk * BLOCK_K a_ptr = input_ptr + cur_batch * xm_stride b_ptr = lora_ptr + l0_stride * lora_index - accumulator = tl.zeros((BLOCK_N, ), dtype=tl.float32) + accumulator = tl.zeros((BLOCK_N,), dtype=tl.float32) for k in range(0, K, BLOCK_K * SPLIT_K): current_k = k + offset_k current_k_c = tl.max_contiguous(current_k, BLOCK_K) @@ -60,8 +60,7 @@ def _bgmv_shrink_kernel( b_ptr_mask = (offset_n[:, None] < N) & (current_k[None, :] < K) tiled_b = tl.load( - b_ptr + offset_n[:, None] * lora_k_stride + - current_k[None, :] * lora_n_stride, + b_ptr + offset_n[:, None] * lora_k_stride + current_k[None, :] * lora_n_stride, mask=b_ptr_mask, other=0.0, ) # [BLOCK_N,BLOCK_K] @@ -96,7 +95,7 @@ def bgmv_shrink( applied. batches (int): batch size scaling (float): Scaling factor. - override_config (Optional[Dict[str, int]], optional): Defaults to None. + override_config (Optional[Dict[str, int]], optional): Defaults to None. 
Triton grid config """ assert inputs.dtype == lora_a_weights.dtype @@ -125,7 +124,7 @@ def bgmv_shrink( # First try to load optimal config from the file config = get_lora_op_configs("bgmv_shrink", batches, K) - grid = lambda META: ( + grid = lambda META: ( # noqa: E731 META["SPLIT_K"], batches, ) @@ -147,4 +146,4 @@ def bgmv_shrink( BLOCK_N=BLOCK_N, **config, ) - return \ No newline at end of file + return diff --git a/server/lorax_server/utils/ops/libentry.py b/server/lorax_server/utils/ops/libentry.py index 867d71662..4572688b1 100644 --- a/server/lorax_server/utils/ops/libentry.py +++ b/server/lorax_server/utils/ops/libentry.py @@ -6,7 +6,6 @@ class LibEntry(triton.KernelInterface): - def __init__( self, fn, @@ -20,23 +19,27 @@ def __init__( fn = fn.fn self.jit_function: triton.runtime.JITFunction = fn self.specialize_indices = [ - p.num for p in self.jit_function.params - if not p.is_constexpr and not p.do_not_specialize + p.num for p in self.jit_function.params if not p.is_constexpr and not p.do_not_specialize ] self.do_not_specialize_indices = [ - p.num for p in self.jit_function.params - if not p.is_constexpr and p.do_not_specialize + p.num for p in self.jit_function.params if not p.is_constexpr and p.do_not_specialize ] def key(self, spec_args, dns_args, const_args): - spec_key = [(arg.dtype, arg.data_ptr() % - self.divisibility == 0) if hasattr(arg, "data_ptr") else - (type(arg), arg) for arg in spec_args] + spec_key = [ + (arg.dtype, arg.data_ptr() % self.divisibility == 0) if hasattr(arg, "data_ptr") else (type(arg), arg) + for arg in spec_args + ] dns_key = [ - arg.dtype if hasattr( - arg, "data_ptr") else type(arg) if not isinstance(arg, int) - else "i32" if -(2**31) <= arg and arg <= 2**31 - - 1 else "u64" if 2**63 <= arg and arg <= 2**64 - 1 else "i64" + arg.dtype + if hasattr(arg, "data_ptr") + else type(arg) + if not isinstance(arg, int) + else "i32" + if -(2**31) <= arg and arg <= 2**31 - 1 + else "u64" + if 2**63 <= arg and arg <= 2**64 - 1 + else "i64" for arg in dns_args ] # const args passed by position @@ -58,7 +61,7 @@ def run(self, *args, **kwargs): dns_args.append(arg) else: const_args.append(arg) - for p in self.jit_function.params[len(args):]: + for p in self.jit_function.params[len(args) :]: if p.name in kwargs: val = kwargs[p.name] elif p.default is inspect._empty: @@ -92,11 +95,13 @@ def run(self, *args, **kwargs): constexprs = {**constexprs, **config.kwargs} elif isinstance(fn, triton.runtime.Heuristics): for v, heur in fn.values.items(): - constexprs[v] = heur({ - **dict(zip(fn.arg_names, args)), - **kwargs, - **constexprs, - }) + constexprs[v] = heur( + { + **dict(zip(fn.arg_names, args)), + **kwargs, + **constexprs, + } + ) else: raise RuntimeError("Invalid Runtime Function") fn = fn.fn @@ -106,7 +111,7 @@ def run(self, *args, **kwargs): # (tl.constexpr) are assigned values through the following loop. for p in self.jit_function.params: if p.is_constexpr and p.name not in constexprs: - constexprs[p.name] = p.default #default=inspect._empty + constexprs[p.name] = p.default # default=inspect._empty self.kernel_cache[entry_key] = (kernel, constexprs) else: # load kernel from cache directly @@ -120,10 +125,7 @@ def run(self, *args, **kwargs): # Autotunner & Heuristics when kwargs & captured args conflict, # captured args have higher priority # 4. 
We must filter out captured args with default value firstly - constexprs = { - k: v - for k, v in constexprs.items() if v is not inspect._empty - } + constexprs = {k: v for k, v in constexprs.items() if v is not inspect._empty} meta = { **dict(zip(self.arg_names, args)), **kwargs, @@ -143,24 +145,24 @@ def libentry(): """ Decorator for triton library entries. Motivation: - The runtime overhead of Triton kernels is the reason for the lower - performance of small kernels, particularly evident with smaller models. + The runtime overhead of Triton kernels is the reason for the lower + performance of small kernels, particularly evident with smaller models. Using this decorator can reduce Triton runtime overhead. How: The `run` function of JITFunction needs to accomplish: - Parameter binding using inspect - KernelArg type wrapping - Cache key calculation - When dealing with small size, these steps can become bottlenecks in - Triton runtime. Libentry simplifies these steps to reduce runtime + When dealing with small size, these steps can become bottlenecks in + Triton runtime. Libentry simplifies these steps to reduce runtime overhead, thereby improving the runtime expenses of small kernels. NOTE: When Triton is upgraded to version 3.0.0, libentry can be removed, see: https://github.com/vllm-project/vllm/pull/5036#issuecomment-2243396245 - + """ def decorator(fn): return LibEntry(fn) - return decorator \ No newline at end of file + return decorator diff --git a/server/lorax_server/utils/ops/sgmv_expand.py b/server/lorax_server/utils/ops/sgmv_expand.py index 181b92434..083c03493 100644 --- a/server/lorax_server/utils/ops/sgmv_expand.py +++ b/server/lorax_server/utils/ops/sgmv_expand.py @@ -1,7 +1,7 @@ """ Based on: -Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). -Punica: Multi-Tenant LoRA Serving. +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. 
https://arxiv.org/abs/2310.18547 """ @@ -58,22 +58,16 @@ def _sgmv_expand_kernel( ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) - a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + - offset_k[None, :] * xk_stride, ) - b_ptr = (lora_ptr + l0_stride * lora_index + - offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride) + a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + offset_k[None, :] * xk_stride,) + b_ptr = lora_ptr + l0_stride * lora_index + offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k in range(tl.cdiv(K, BLOCK_K)): if EVEN_K: tiled_a = tl.load(a_ptr) tiled_b = tl.load(b_ptr) else: - tiled_a = tl.load(a_ptr, - mask=offset_k[None, :] < K - k * BLOCK_K, - other=0) - tiled_b = tl.load(b_ptr, - mask=offset_k[:, None] < K - k * BLOCK_K, - other=0) + tiled_a = tl.load(a_ptr, mask=offset_k[None, :] < K - k * BLOCK_K, other=0) + tiled_b = tl.load(b_ptr, mask=offset_k[:, None] < K - k * BLOCK_K, other=0) if CAST_TYPE: tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) accumulator += tl.dot( @@ -85,11 +79,9 @@ def _sgmv_expand_kernel( tiled_c = accumulator.to(lora_ptr.dtype.element_ty) offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N - c_ptr = (out_ptr + offset_cm[:, None] * cm_stride + - offset_cn[None, :] * cn_stride) + c_ptr = out_ptr + offset_cm[:, None] * cm_stride + offset_cn[None, :] * cn_stride M = tl.load(seq_lens + cur_batch) - c_mask = (offset_cm[:, None] < - (cur_seq_start + M)) & (offset_cn[None, :] < N) + c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] < N) if ADD_INPUTS: tiled_out = tl.load(c_ptr, mask=c_mask) tiled_c += tiled_out @@ -125,7 +117,7 @@ def sgmv_expand( batches (int): batch size max_seq_length (int): The max sequence lengths of the sequences in the batch - add_inputs (bool, optional): Defaults to False. adds the final lora + add_inputs (bool, optional): Defaults to False. adds the final lora results to the output. """ # print("!!! inputs", inputs.shape) @@ -166,8 +158,8 @@ def sgmv_expand( ADD_INPUTS = add_inputs CAST_TYPE = False if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ - torch.float16, - torch.bfloat16, + torch.float16, + torch.bfloat16, ]: CAST_TYPE = True grid = ( @@ -197,4 +189,4 @@ def sgmv_expand( ADD_INPUTS, CAST_TYPE, ) - return \ No newline at end of file + return diff --git a/server/lorax_server/utils/ops/sgmv_expand_slice.py b/server/lorax_server/utils/ops/sgmv_expand_slice.py index 1fa1d96de..2da04b947 100644 --- a/server/lorax_server/utils/ops/sgmv_expand_slice.py +++ b/server/lorax_server/utils/ops/sgmv_expand_slice.py @@ -1,7 +1,7 @@ """ Based on: -Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). -Punica: Multi-Tenant LoRA Serving. +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. https://arxiv.org/abs/2310.18547 """ @@ -39,10 +39,10 @@ def _sgmv_expand_slice_kernel( CAST_TYPE: tl.constexpr, ): """ - Similar to the 'sgmv_expand' operator, but with an added parameter - 'slice_offset'. 
The reason for not reusing the 'sgmv_expand' operator - might be that in the future, we could implement a fusion operator to - achieve the current functionality instead of having to call it multiple + Similar to the 'sgmv_expand' operator, but with an added parameter + 'slice_offset'. The reason for not reusing the 'sgmv_expand' operator + might be that in the future, we could implement a fusion operator to + achieve the current functionality instead of having to call it multiple times. """ pid = tl.program_id(axis=0) @@ -63,22 +63,16 @@ def _sgmv_expand_slice_kernel( ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) - a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + - offset_k[None, :] * xk_stride, ) - b_ptr = (lora_ptr + l0_stride * lora_index + - offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride) + a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + offset_k[None, :] * xk_stride,) + b_ptr = lora_ptr + l0_stride * lora_index + offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k in range(tl.cdiv(K, BLOCK_K)): if EVEN_K: tiled_a = tl.load(a_ptr) tiled_b = tl.load(b_ptr) else: - tiled_a = tl.load(a_ptr, - mask=offset_k[None, :] < K - k * BLOCK_K, - other=0) - tiled_b = tl.load(b_ptr, - mask=offset_k[:, None] < K - k * BLOCK_K, - other=0) + tiled_a = tl.load(a_ptr, mask=offset_k[None, :] < K - k * BLOCK_K, other=0) + tiled_b = tl.load(b_ptr, mask=offset_k[:, None] < K - k * BLOCK_K, other=0) if CAST_TYPE: tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) accumulator += tl.dot( @@ -90,11 +84,9 @@ def _sgmv_expand_slice_kernel( tiled_c = accumulator.to(lora_ptr.dtype.element_ty) offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_offset - c_ptr = (out_ptr + offset_cm[:, None] * cm_stride + - offset_cn[None, :] * cn_stride) + c_ptr = out_ptr + offset_cm[:, None] * cm_stride + offset_cn[None, :] * cn_stride M = tl.load(seq_lens + cur_batch) - c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] < - (slice_offset + N)) + c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] < (slice_offset + N)) if ADD_INPUTS: tiled_out = tl.load(c_ptr, mask=c_mask) tiled_c += tiled_out @@ -134,7 +126,7 @@ def sgmv_expand_slice( in the batch slice_offst (int): output_tensor's offst slice_size (int): current output_tensor's size - add_inputs (bool, optional): Defaults to False. adds the final lora + add_inputs (bool, optional): Defaults to False. adds the final lora results to the output.. """ # print("!!! inputs", inputs.shape) @@ -178,8 +170,8 @@ def sgmv_expand_slice( ADD_INPUTS = add_inputs CAST_TYPE = False if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ - torch.float16, - torch.bfloat16, + torch.float16, + torch.bfloat16, ]: CAST_TYPE = True grid = ( @@ -210,4 +202,4 @@ def sgmv_expand_slice( ADD_INPUTS, CAST_TYPE, ) - return \ No newline at end of file + return diff --git a/server/lorax_server/utils/ops/sgmv_shrink.py b/server/lorax_server/utils/ops/sgmv_shrink.py index fb3a5d6ad..80ea5921c 100644 --- a/server/lorax_server/utils/ops/sgmv_shrink.py +++ b/server/lorax_server/utils/ops/sgmv_shrink.py @@ -1,7 +1,7 @@ """ Based on: -Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). -Punica: Multi-Tenant LoRA Serving. 
+Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. https://arxiv.org/abs/2310.18547 """ @@ -63,10 +63,8 @@ def _sgmv_shrink_kernel( ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) - a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + - offset_k[None, :] * xk_stride) - b_ptr = (lora_ptr + l0_stride * lora_index + rbn[None, :] * lora_k_stride + - offset_k[:, None] * lora_n_stride) + a_ptr = input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + offset_k[None, :] * xk_stride + b_ptr = lora_ptr + l0_stride * lora_index + rbn[None, :] * lora_k_stride + offset_k[:, None] * lora_n_stride accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): @@ -75,12 +73,8 @@ def _sgmv_shrink_kernel( tiled_b = tl.load(b_ptr) else: k_remaining = K - k * (BLOCK_K * SPLIT_K) - tiled_a = tl.load(a_ptr, - mask=offset_k[None, :] < k_remaining, - other=0.0) - tiled_b = tl.load(b_ptr, - mask=offset_k[:, None] < k_remaining, - other=0.0) + tiled_a = tl.load(a_ptr, mask=offset_k[None, :] < k_remaining, other=0.0) + tiled_b = tl.load(b_ptr, mask=offset_k[:, None] < k_remaining, other=0.0) accumulator += tl.dot(tiled_a, tiled_b) a_ptr += BLOCK_K * SPLIT_K * xk_stride @@ -88,10 +82,8 @@ def _sgmv_shrink_kernel( offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N - c_ptr = (out_ptr + offset_cm[:, None] * cm_stride + - offset_cn[None, :] * cn_stride) - c_mask = (offset_cm[:, None] < - (cur_seq_start + M)) & (offset_cn[None, :] < N) + c_ptr = out_ptr + offset_cm[:, None] * cm_stride + offset_cn[None, :] * cn_stride + c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] < N) accumulator *= scaling # handles write-back with reduction-splitting if SPLIT_K == 1: @@ -195,4 +187,4 @@ def sgmv_shrink( EVEN_K, SPLIT_K, ) - return \ No newline at end of file + return diff --git a/server/lorax_server/utils/ops/utils.py b/server/lorax_server/utils/ops/utils.py index c4615d40f..4460188b8 100644 --- a/server/lorax_server/utils/ops/utils.py +++ b/server/lorax_server/utils/ops/utils.py @@ -22,25 +22,20 @@ def _check_divisibility(hidden_size: int): def _get_default_config(op_type: str, batch: int, hidden_size: int): if op_type == "expand": - return { - "BLOCK_N": 256, - "SPLIT_N": _check_divisibility(hidden_size), - "num_warps": 8 - } + return {"BLOCK_N": 256, "SPLIT_N": _check_divisibility(hidden_size), "num_warps": 8} else: return {"BLOCK_K": 256, "SPLIT_K": 64, "num_warps": 8} -def get_lora_op_configs(op_type: str, batch: int, - hidden_size: int) -> Dict[str, int]: +def get_lora_op_configs(op_type: str, batch: int, hidden_size: int) -> Dict[str, int]: """Inspired by `fused_moe_kernel` - The return value will be a dictionary mapping an irregular grid of batch - sizes and hidden_size to configurations of the bgmv-related kernel. - NOTE: It currently only supports the default configuration. We plan to - generate optimal configurations for different hardware in the future using + The return value will be a dictionary mapping an irregular grid of batch + sizes and hidden_size to configurations of the bgmv-related kernel. + NOTE: It currently only supports the default configuration. We plan to + generate optimal configurations for different hardware in the future using scripts similar to `benchmark_moe.py`. 
""" config = _get_op_configs(op_type, batch, hidden_size) if not config: config = _get_default_config(op_type, batch, hidden_size) - return config \ No newline at end of file + return config diff --git a/server/lorax_server/utils/paged_attention.py b/server/lorax_server/utils/paged_attention.py index 9e41d4595..7d7b4c82c 100644 --- a/server/lorax_server/utils/paged_attention.py +++ b/server/lorax_server/utils/paged_attention.py @@ -48,7 +48,7 @@ def reshape_and_cache( elif SYSTEM == "xpu": ipex.llm.modules.PagedAttention.reshape_and_cache(key, value, key_cache, value_cache, slots) else: - torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, 'auto', 1.0, 1.0) + torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0, 1.0) def attention( @@ -138,7 +138,7 @@ def attention( block_size, max_s, None, - 'auto', + "auto", 1.0, 1.0, ) @@ -172,7 +172,7 @@ def attention( block_size, max_s, None, - 'auto', + "auto", 1.0, 1.0, ) diff --git a/server/lorax_server/utils/punica.py b/server/lorax_server/utils/punica.py index f41e65b6d..fe5869e00 100644 --- a/server/lorax_server/utils/punica.py +++ b/server/lorax_server/utils/punica.py @@ -3,9 +3,9 @@ from functools import lru_cache from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union -from loguru import logger import torch import torch.nn.functional as F +from loguru import logger from lorax_server.utils.ops.bgmv_expand import bgmv_expand from lorax_server.utils.ops.bgmv_expand_slice import bgmv_expand_slice @@ -255,9 +255,7 @@ def segmented_matmul( y[s_start[i] : s_end[i]] = F.linear(xi, wi, bi) -def compute_meta( - token_lora_tensor: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, bool]: +def compute_meta(token_lora_tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, bool]: """ Get the information required for the sgmv kernel. With the features: 1. If consecutive requests in the batch use the same LoRA, this function @@ -267,8 +265,7 @@ def compute_meta( needed based on the input, but only once. """ - lora_indices_tensor, seq_length_tensor = torch.unique_consecutive( - token_lora_tensor, return_counts=True) + lora_indices_tensor, seq_length_tensor = torch.unique_consecutive(token_lora_tensor, return_counts=True) cum_result = torch.cumsum(seq_length_tensor, dim=0) b_seq_start_tensor = torch.zeros_like(seq_length_tensor) b_seq_start_tensor[1:].copy_(cum_result[:-1]) @@ -281,8 +278,7 @@ def compute_meta( # does not need to launch the triton kernel, which can improve performance if batch_size == 1 and lora_indices_tensor == -1: no_lora = True - return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, - batch_size, max_length, no_lora) + return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, batch_size, max_length, no_lora) # TODO see if this can be vectorized @@ -291,9 +287,8 @@ def convert_mapping( max_loras: int, vocab_size: int, extra_vocab_size: int, - long_lora_context = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, - Optional[torch.Tensor], List[int]]: + long_lora_context=None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], List[int]]: """Converts LoRAMapping to index tensors. Args: mapping: LoRAMapping mapping rows in a batch to LoRA ids. 
@@ -329,9 +324,7 @@ def convert_mapping( lora_indices = index_mapping_indices.copy() long_lora_offsets: Optional[torch.Tensor] = None if long_lora_context: - long_lora_offsets = torch.zeros(len(index_mapping_indices), - device="cuda", - dtype=torch.long) + long_lora_offsets = torch.zeros(len(index_mapping_indices), device="cuda", dtype=torch.long) prompt_mapping = meta.adapter_list.copy() lora_idx = None for i in range(len(index_mapping_indices)): @@ -340,8 +333,7 @@ def convert_mapping( lora_indices[i] = lora_idx if long_lora_context: assert long_lora_offsets is not None - lora_offset: int = long_lora_context.offsets_by_lora_id.get( - index_mapping_indices[i], 0) + lora_offset: int = long_lora_context.offsets_by_lora_id.get(index_mapping_indices[i], 0) long_lora_offsets[i] = lora_offset indices_list: List[Union[List[int], torch.Tensor]] = [ @@ -353,21 +345,21 @@ def convert_mapping( assert long_lora_offsets is not None indices_list.append(long_lora_offsets) indices = torch.tensor(indices_list, dtype=torch.long, device="cuda") - prompt_mapping_tensor = torch.tensor(prompt_mapping, - device="cuda", - dtype=torch.long) - embeddings_indices = torch.stack([ - indices[2] * extra_vocab_size, - indices[2] * (vocab_size + extra_vocab_size), - ]) + prompt_mapping_tensor = torch.tensor(prompt_mapping, device="cuda", dtype=torch.long) + embeddings_indices = torch.stack( + [ + indices[2] * extra_vocab_size, + indices[2] * (vocab_size + extra_vocab_size), + ] + ) embeddings_indices[embeddings_indices == -1] = max_loras - 1 base_indices = indices[1] sampler_indices = prompt_mapping_tensor sampler_indices_padded = sampler_indices.clone() sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 - sampler_indices_padded = torch.arange( - 0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + ( - sampler_indices_padded * len(sampler_indices_padded)) + sampler_indices_padded = torch.arange(0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + ( + sampler_indices_padded * len(sampler_indices_padded) + ) long_lora_indices = None long_lora_indices_len: Optional[int] = None if long_lora_context: @@ -399,51 +391,33 @@ def convert_mapping( # Source: https://github.com/vllm-project/vllm/blob/main/vllm/lora/punica.py class PunicaWrapper: """ - PunicaWrapper is designed to manage and provide metadata for the punica - kernel. The main function is to maintain the state information for + PunicaWrapper is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for Multi-LoRA, and to provide the interface for the punica kernel. 
""" - def __init__(self, max_num_batched_tokens: int, max_batches: int, - device: str, enabled: bool): - self._token_lora_indices = torch.empty(max_num_batched_tokens, - dtype=torch.long, - device=device) - self._sampler_indices = torch.empty(max_num_batched_tokens, - dtype=torch.long, - device=device) - self._sampler_indices_padded = torch.empty(max_num_batched_tokens, - dtype=torch.long, - device=device) - self._embeddings_indices = torch.empty(2, - max_num_batched_tokens, - dtype=torch.long, - device=device) - self._long_lora_indices = torch.empty(max_num_batched_tokens, - dtype=torch.long, - device=device) + def __init__(self, max_num_batched_tokens: int, max_batches: int, device: str, enabled: bool): + self._token_lora_indices = torch.empty(max_num_batched_tokens, dtype=torch.long, device=device) + self._sampler_indices = torch.empty(max_num_batched_tokens, dtype=torch.long, device=device) + self._sampler_indices_padded = torch.empty(max_num_batched_tokens, dtype=torch.long, device=device) + self._embeddings_indices = torch.empty(2, max_num_batched_tokens, dtype=torch.long, device=device) + self._long_lora_indices = torch.empty(max_num_batched_tokens, dtype=torch.long, device=device) # 5 is the number of indicies tensors. # base_indices, sampler_indices, sampler_indices_padded, # embeddings_indices,long_lora_indices self.indices_len: List[Optional[int]] = [None] * 5 # these attributes are the information required for sgmv kernel - self._seq_start_locs = torch.empty(max_batches, - dtype=torch.long, - device=device) - self._seq_lengths = torch.empty(max_batches, - dtype=torch.long, - device=device) - self._lora_indices_per_batch = torch.empty(max_batches, - dtype=torch.long, - device=device) + self._seq_start_locs = torch.empty(max_batches, dtype=torch.long, device=device) + self._seq_lengths = torch.empty(max_batches, dtype=torch.long, device=device) + self._lora_indices_per_batch = torch.empty(max_batches, dtype=torch.long, device=device) self.max_batch_size = max_batches self.max_length: int = 0 self.batch_size: int = -1 self.is_prefill = False self.no_lora = False self.enabled = enabled - + def update_metadata( self, meta: "AdapterBatchMetadata", @@ -452,10 +426,10 @@ def update_metadata( # token_lora_indices is adapter_indices - 1 to account for base model offset base_indices = meta.adapter_indices - 1 - self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices) + self._token_lora_indices[: base_indices.shape[0]].copy_(base_indices) # self._token_lora_indices = base_indices self.indices_len[0] = base_indices.shape[-1] - + if prefill: # Update metadata required for prefill-related operators. self._update_prefill_metada(self._token_lora_indices, base_indices.shape[-1]) @@ -469,7 +443,7 @@ def _update_base_metadata( max_loras: int, vocab_size: int, extra_vocab_size: int, - long_lora_context = None, + long_lora_context=None, ): ( base_indices, @@ -485,57 +459,53 @@ def _update_base_metadata( extra_vocab_size, long_lora_context, ) - self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices) - self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) - self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( - sampler_indices_padded) - self._embeddings_indices[:embeddings_indices. 
- shape[0], :embeddings_indices.shape[1]].copy_( - embeddings_indices) + self._token_lora_indices[: base_indices.shape[0]].copy_(base_indices) + self._sampler_indices[: sampler_indices.shape[0]].copy_(sampler_indices) + self._sampler_indices_padded[: sampler_indices_padded.shape[0]].copy_(sampler_indices_padded) + self._embeddings_indices[: embeddings_indices.shape[0], : embeddings_indices.shape[1]].copy_(embeddings_indices) if long_lora_offsets_tensor is not None: - self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_( - long_lora_offsets_tensor) + self._long_lora_indices[: long_lora_offsets_tensor.shape[0]].copy_(long_lora_offsets_tensor) else: self._long_lora_indices.zero_() self.indices_len[:] = indices_len def _update_prefill_metada(self, token_lora_tensor: torch.Tensor, indices_len: int) -> None: + (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, batch_size, max_length, no_lora) = compute_meta( + token_lora_tensor[:indices_len] + ) - (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, - batch_size, max_length, no_lora) = compute_meta(token_lora_tensor[:indices_len]) - - self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_( - b_seq_start_tensor) - self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor) - self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_( - lora_indices_tensor) + self._seq_start_locs[: b_seq_start_tensor.shape[0]].copy_(b_seq_start_tensor) + self._seq_lengths[: seq_length_tensor.shape[0]].copy_(seq_length_tensor) + self._lora_indices_per_batch[: lora_indices_tensor.shape[0]].copy_(lora_indices_tensor) self.batch_size = batch_size self.max_length = max_length self.no_lora = no_lora @property - def prefill_metadata( - self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]: + def prefill_metadata(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]: """ - This property provides a convenient way to access the necessary + This property provides a convenient way to access the necessary metadata for prefill-related kernel computations. 1. seq_start_locs: Tensor of sequence start positions 2. seq_lengths: Tensor of sequence lengths - 3. lora_indices_per_batch: Tensor of lora indices, and an index of + 3. lora_indices_per_batch: Tensor of lora indices, and an index of -1 means no lora should be applied. 4. batch_size: batch size after clustering identical lora indices 5. max_length: The maximum sequence length in the batch """ - return (self._seq_start_locs[:self.batch_size], - self._seq_lengths[:self.batch_size], - self._lora_indices_per_batch[:self.batch_size], - self.batch_size, self.max_length) + return ( + self._seq_start_locs[: self.batch_size], + self._seq_lengths[: self.batch_size], + self._lora_indices_per_batch[: self.batch_size], + self.batch_size, + self.max_length, + ) @property def token_lora_indices(self) -> torch.Tensor: """ - This property provides the lora indices corresponding to each token + This property provides the lora indices corresponding to each token in the batch. An index of -1 means no lora should be applied. 
""" token_lora_len = self.indices_len[0] @@ -543,8 +513,8 @@ def token_lora_indices(self) -> torch.Tensor: @property def sampler_indices(self) -> torch.Tensor: - """ - This property is used to access the lora indices specifically for + """ + This property is used to access the lora indices specifically for LogitsProcessorWithLoRA """ sampler_indices_len = self.indices_len[1] @@ -561,7 +531,7 @@ def sampler_indices_padded(self) -> torch.Tensor: @property def embeddings_indices(self) -> torch.Tensor: """ - This property provides access to the indices used for lora embeddings, + This property provides access to the indices used for lora embeddings, specifically for VocabParallelEmbeddingWithLoRA """ embeddings_indices_len = self.indices_len[3] @@ -569,8 +539,8 @@ def embeddings_indices(self) -> torch.Tensor: @property def long_lora_indices(self) -> torch.Tensor: - """ - This property provides access to the indices used for long context + """ + This property provides access to the indices used for long context lora, specifically for LinearScalingRotaryEmbeddingWithLora """ long_lora_len = self.indices_len[4] @@ -583,7 +553,7 @@ def shrink_prefill( w_t_all: torch.Tensor, scale: float, ): - #No LoRA request, so return directly + # No LoRA request, so return directly if self.no_lora: return sgmv_shrink( @@ -610,7 +580,7 @@ def expand_prefill( w_t_all: torch.Tensor, add_input: bool, ): - #No LoRA request, so return directly + # No LoRA request, so return directly if self.no_lora: return sgmv_expand( @@ -639,7 +609,7 @@ def expand_slice_prefill( y_slice_size: Optional[int], add_input: bool, ): - #No LoRA request, so return directly + # No LoRA request, so return directly if self.no_lora: return sgmv_expand_slice( @@ -661,8 +631,7 @@ def expand_slice_decode( y_slice_size: Optional[int], add_input: bool, ): - bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, - y_slice_size, add_input) + bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_input) def add_shrink( self, @@ -679,8 +648,7 @@ def add_shrink( Otherwise, it is the decode stage, and the shrink_decode function should be called. """ - shrink_fun: Callable = (self.shrink_prefill - if self.is_prefill else self.shrink_decode) + shrink_fun: Callable = self.shrink_prefill if self.is_prefill else self.shrink_decode shrink_fun(y, x, w_t_all, scale) def add_expand( @@ -699,38 +667,37 @@ def add_expand( should be called. 
""" - expand_fun: Callable = (self.expand_prefill - if self.is_prefill else self.expand_decode) + expand_fun: Callable = self.expand_prefill if self.is_prefill else self.expand_decode expand_fun(y, x, w_t_all, add_input) - def add_expand_slice(self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - y_offset: Optional[int], - y_slice_size: Optional[int], - add_input: bool = True): + def add_expand_slice( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_input: bool = True, + ): """ Similar to `add_expand` """ - expand_slice_fun: Callable = (self.expand_slice_prefill - if self.is_prefill else - self.expand_slice_decode) + expand_slice_fun: Callable = self.expand_slice_prefill if self.is_prefill else self.expand_slice_decode expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input) def add_lora( - self, - y: torch.Tensor, - x: torch.Tensor, - wa_t_all: torch.Tensor, - wb_t_all: torch.Tensor, - scale: float, - y_offset: Optional[int] = None, - y_slice_size: Optional[int] = None, - *, - buffer: Optional[torch.Tensor] = None, - callback: Optional[Callable] = None, + self, + y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + scale: float, + y_offset: Optional[int] = None, + y_slice_size: Optional[int] = None, + *, + buffer: Optional[torch.Tensor] = None, + callback: Optional[Callable] = None, ): """ Semantics: @@ -758,38 +725,31 @@ def add_lora( if buffer is None: # We set the buffer to be float32 by default ,refer to: # https://github.com/triton-lang/triton/issues/1387 - buffer = torch.zeros((x.size(0), r), - dtype=torch.float32, - device=x.device) + buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) self.add_shrink(buffer, x, wa_t_all, scale) - + if callback is not None: # callback used to aggregate intermediate results (i.e., allreduce, allgather) buffer = callback(buffer) - + if y_offset is None and y_slice_size is None: self.add_expand(y, buffer, wb_t_all, add_input=True) else: - self.add_expand_slice(y, - buffer, - wb_t_all, - y_offset, - y_slice_size, - add_input=True) + self.add_expand_slice(y, buffer, wb_t_all, y_offset, y_slice_size, add_input=True) y = y.view_as(y_org) - def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, - torch.Tensor, - torch.Tensor], - lora_b_stacked: Tuple[torch.Tensor, - torch.Tensor, - torch.Tensor], - scale: float, - output_slices: Tuple[int, ...]) -> None: + def add_lora_packed_nslice( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + scale: float, + output_slices: Tuple[int, ...], + ) -> None: """ - Applies lora to each input. Similar to add_lora, This method is + Applies lora to each input. Similar to add_lora, This method is used for layers that are composed of multiple sublayers (slices) packed together. 
""" @@ -799,21 +759,23 @@ def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor, offset_left = 0 # TODO fuse these kernels for slice_idx in range(len(output_slices)): - self.add_lora(y, x, lora_a_stacked[slice_idx], - lora_b_stacked[slice_idx], scale, offset_left, - output_slices[slice_idx]) + self.add_lora( + y, x, lora_a_stacked[slice_idx], lora_b_stacked[slice_idx], scale, offset_left, output_slices[slice_idx] + ) offset_left += output_slices[slice_idx] y = y.view_as(y_org) - def add_lora_logits(self, - y: torch.Tensor, - x: torch.Tensor, - wa_t_all: torch.Tensor, - wb_t_all: torch.Tensor, - scale, - *, - buffer: Optional[torch.Tensor] = None) -> None: + def add_lora_logits( + self, + y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + ) -> None: """ LogitsProcessorWithLoRA always using bgmv """ @@ -824,9 +786,7 @@ def add_lora_logits(self, if buffer is None: # We set the buffer to be float32 by default ,refer to: # https://github.com/triton-lang/triton/issues/1387 - buffer = torch.zeros((x.size(0), r), - dtype=torch.float32, - device=x.device) + buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) bgmv_shrink(x, wa_t_all, buffer, self.sampler_indices, scale) bgmv_expand(buffer, wb_t_all, y, self.sampler_indices, add_inputs=True) diff --git a/server/lorax_server/utils/tokens.py b/server/lorax_server/utils/tokens.py index 2bf613a81..c3231caea 100644 --- a/server/lorax_server/utils/tokens.py +++ b/server/lorax_server/utils/tokens.py @@ -3,7 +3,6 @@ from contextlib import nullcontext from typing import List, Optional, Set, Tuple, Union -from lorax_server.utils.state import use_ngram import torch from transformers import ( PreTrainedTokenizerBase, @@ -23,6 +22,7 @@ OutlinesLogitsProcessor, static_warper, ) +from lorax_server.utils.state import use_ngram from lorax_server.utils.watermark import WatermarkLogitsProcessor diff --git a/server/lorax_server/utils/torch_utils.py b/server/lorax_server/utils/torch_utils.py index 682402cc1..4e3567720 100644 --- a/server/lorax_server/utils/torch_utils.py +++ b/server/lorax_server/utils/torch_utils.py @@ -15,14 +15,16 @@ def is_quantized(quantize): def is_fp8_supported(): - return torch.cuda.is_available() and \ - (torch.cuda.get_device_capability()[0] >= 9) or \ - (torch.cuda.get_device_capability()[0] == 8 and torch.cuda.get_device_capability()[1] >= 9) + return ( + torch.cuda.is_available() + and (torch.cuda.get_device_capability()[0] >= 9) + or (torch.cuda.get_device_capability()[0] == 8 and torch.cuda.get_device_capability()[1] >= 9) + ) def is_fp8_kv(quantize): - return quantize and quantize == 'fp8-kv' + return quantize and quantize == "fp8-kv" def is_fp8(quantize): - return quantize and quantize.startswith('fp8') + return quantize and quantize.startswith("fp8") From 295829f1c90d939de534f7ad526eb8764cc47b75 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 31 Oct 2024 23:17:07 -0700 Subject: [PATCH 54/76] Fixed tests --- server/tests/adapters/test_medusa.py | 1 + server/tests/utils/test_lora.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/server/tests/adapters/test_medusa.py b/server/tests/adapters/test_medusa.py index bc808d1a9..fe9274cf2 100644 --- a/server/tests/adapters/test_medusa.py +++ b/server/tests/adapters/test_medusa.py @@ -30,6 +30,7 @@ def test_batched_medusa_weights(default_causal_lm: CausalLM): meta = AdapterBatchMetadata( adapter_indices=torch.tensor([0, 0, 1, 1, 0, 0, 1, 
1], dtype=torch.int64), + adapter_list=[0, 1, 0, 1], adapter_set={0, 1}, adapter_segments=torch.tensor([0, 2, 4, 6, 8], dtype=torch.int64), segment_indices=[0, 1, 0, 1], diff --git a/server/tests/utils/test_lora.py b/server/tests/utils/test_lora.py index 24ecd81ea..70f3e5985 100644 --- a/server/tests/utils/test_lora.py +++ b/server/tests/utils/test_lora.py @@ -74,6 +74,7 @@ def test_batched_lora_weights(lora_ranks: List[int]): meta = AdapterBatchMetadata( adapter_indices=torch.tensor([0, 0, 1, 1, 0, 0, 1, 1], dtype=torch.int64), + adapter_list=[0, 1, 0, 1], adapter_set={0, 1}, adapter_segments=torch.tensor([0, 2, 4, 6, 8], dtype=torch.int64), segment_indices=[0, 1, 0, 1], @@ -149,6 +150,7 @@ def test_batched_lora_weights_decode( meta = AdapterBatchMetadata( adapter_indices=torch.tensor(adapter_indices, dtype=torch.int64), + adapter_list=adapter_indices, adapter_set=set(adapter_indices), adapter_segments=torch.tensor(segments, dtype=torch.int64), segment_indices=segment_indices, @@ -193,7 +195,8 @@ def test_batched_lora_weights_no_segments(): meta = AdapterBatchMetadata( adapter_indices=torch.tensor([0, 0, 0, 0], dtype=torch.int64), - adapter_set={0, 1}, + adapter_list=[0], + adapter_set={0}, adapter_segments=torch.tensor([0, 4], dtype=torch.int64), segment_indices=[0], ) From ef8607147131f92f7557328adb1e29fb24befc82 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 31 Oct 2024 23:42:34 -0700 Subject: [PATCH 55/76] Fixed CausalLM and embedding model --- server/lorax_server/models/causal_lm.py | 21 ++++++++++++--------- server/lorax_server/models/types.py | 7 ++++--- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/server/lorax_server/models/causal_lm.py b/server/lorax_server/models/causal_lm.py index 79aa029cf..598c6b380 100644 --- a/server/lorax_server/models/causal_lm.py +++ b/server/lorax_server/models/causal_lm.py @@ -90,7 +90,7 @@ def from_pb( padding_right_offset = 0 max_decode_tokens = 0 adapter_indices_list = [] - adapter_set = set() + adapter_list = [] for i, r in enumerate(pb.requests): requests_idx_mapping[r.id] = i req_inputs = tokenizers.get_inputs(r, tokenizer) @@ -102,7 +102,7 @@ def from_pb( max_decode_tokens += stopping_criteria.max_new_tokens padding_right_offset = max(padding_right_offset, stopping_criteria.max_new_tokens) adapter_indices_list.append(r.adapter_index) - adapter_set.add(r.adapter_index) + adapter_list.append(r.adapter_index) adapter_indices = torch.tensor(adapter_indices_list, dtype=torch.int64, device=device) @@ -156,7 +156,8 @@ def from_pb( max_tokens=max_tokens, adapter_meta=AdapterBatchMetadata( adapter_indices=adapter_indices, - adapter_set=adapter_set, + adapter_list=adapter_list, + adapter_set=set(adapter_list), adapter_segments=adapter_segments, segment_indices=adapter_segment_indices, ), @@ -180,7 +181,7 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]: all_input_ids = [] max_input_length = 0 - adapter_set = set() + adapter_list = [] next_token_choosers = [] stopping_criterias = [] @@ -209,7 +210,7 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]: total_remaining_decode_tokens += remaining_decode_tokens new_padding_right_offset = max(new_padding_right_offset, remaining_decode_tokens) - adapter_set.add(self.requests[idx].adapter_index) + adapter_list.append(self.requests[idx].adapter_index) # Apply indices to input_ids, attention mask, past key values and other items that need to be cached input_ids = self.input_ids[keep_indices] @@ -262,7 +263,8 @@ def filter(self, request_ids: 
List[int]) -> Optional["CausalLMBatch"]: self.max_tokens = max_tokens self.adapter_meta = AdapterBatchMetadata( adapter_indices=adapter_indices, - adapter_set=adapter_set, + adapter_list=adapter_list, + adapter_set=set(adapter_list), adapter_segments=adapter_segments, segment_indices=adapter_segment_indices, ) @@ -301,7 +303,7 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": total_indices_size = sum(b.adapter_meta.adapter_indices.shape[0] for b in batches) adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(total_indices_size) - adapter_set = set() + adapter_list = [] adapter_segment_builder = SegmentConcatBuilder() cumulative_adapter_indices_size = 0 @@ -344,7 +346,7 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": adapter_end_index = cumulative_adapter_indices_size + batch.adapter_meta.adapter_indices.shape[0] adapter_indices[adapter_start_index:adapter_end_index] = batch.adapter_meta.adapter_indices cumulative_adapter_indices_size = adapter_end_index - adapter_set.update(batch.adapter_meta.adapter_set) + adapter_list.extend(batch.adapter_meta.adapter_list) # Update adapter segments adapter_segment_builder.concat(batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices) @@ -476,7 +478,8 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": max_tokens=max_tokens, adapter_meta=AdapterBatchMetadata( adapter_indices=adapter_indices, - adapter_set=adapter_set, + adapter_list=adapter_list, + adapter_set=set(adapter_list), adapter_segments=adapter_segments, segment_indices=adapter_segment_indices, ), diff --git a/server/lorax_server/models/types.py b/server/lorax_server/models/types.py index 62ad44b02..b5e7654dd 100644 --- a/server/lorax_server/models/types.py +++ b/server/lorax_server/models/types.py @@ -181,7 +181,7 @@ def from_pb( max_s = 0 cumulative_length = 0 - adapter_set = set() + adapter_list = [] for i, (r, tokenized_input) in enumerate(zip(pb.requests, batch_tokenized_inputs)): tokenized_input = tokenized_input[-r.truncate :] @@ -199,7 +199,7 @@ def from_pb( position_ids.append(request_position_ids) adapter_indices_list.append(torch.full((input_length,), r.adapter_index)) - adapter_set.add(r.adapter_index) + adapter_list.append(r.adapter_index) cumulative_length += input_length @@ -232,7 +232,8 @@ def from_pb( size=len(batch_tokenized_inputs), adapter_meta=AdapterBatchMetadata( adapter_indices=adapter_indices, - adapter_set=adapter_set, + adapter_list=adapter_list, + adapter_set=set(adapter_list), adapter_segments=adapter_segments, segment_indices=adapter_segment_indices, ), From 0d78a0afef6c2bed43ff451dde6d7943db255977 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 1 Nov 2024 09:57:17 -0700 Subject: [PATCH 56/76] Replace flume --- Cargo.lock | 30 +---- router/Cargo.toml | 4 +- router/src/batch.rs | 5 +- router/src/infer.rs | 123 ++++++++---------- router/src/loader.rs | 13 +- router/src/scheduler.rs | 12 +- router/src/validation.rs | 37 ++++-- server/lorax_server/models/flash_causal_lm.py | 1 - 8 files changed, 105 insertions(+), 120 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ec8bd04c9..9512fa06b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -833,7 +833,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "887d93f60543e9a9362ef8a21beedd0a833c5d9610e18c67abe15a5963dcb1a4" dependencies = [ "bit_field", - "flume 0.11.0", + "flume", "half", "lebe", "miniz_oxide", @@ -884,22 +884,9 @@ checksum = 
"28a80e3145d8ad11ba0995949bbcf48b9df2be62772b3d351ef017dff6ecb853" [[package]] name = "flume" -version = "0.10.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1657b4441c3403d9f7b3409e47575237dac27b1b5726df654a6ecbf92f0f7577" -dependencies = [ - "futures-core", - "futures-sink", - "nanorand", - "pin-project", - "spin 0.9.8", -] - -[[package]] -name = "flume" -version = "0.11.0" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181" +checksum = "da0e4dd2a88388a1f4ccc7c9ce104604dab68d9f408dc34cd45823d5a9069095" dependencies = [ "spin 0.9.8", ] @@ -1581,7 +1568,6 @@ dependencies = [ "axum-tracing-opentelemetry", "base64 0.22.1", "clap", - "flume 0.10.14", "futures", "h2", "hf-hub", @@ -1611,6 +1597,7 @@ dependencies = [ "thiserror", "tokenizers", "tokio", + "tokio-stream", "tower-http 0.4.1", "tracing", "tracing-opentelemetry 0.19.0", @@ -1874,15 +1861,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "nanorand" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3" -dependencies = [ - "getrandom", -] - [[package]] name = "native-tls" version = "0.2.11" diff --git a/router/Cargo.toml b/router/Cargo.toml index c38330683..a8ed2c2e5 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -19,7 +19,6 @@ async-stream = "0.3.3" axum = { version = "0.6.4", features = ["json"] } axum-tracing-opentelemetry = "0.10.0" clap = { version = "4.1.4", features = ["derive", "env"] } -flume = "0.10.14" futures = "0.3.26" hf-hub = { version = "0.3.0", features = ["tokio"] } h2 = "0.3.26" @@ -48,7 +47,8 @@ tokio = { version = "1.32.0", features = [ "signal", "sync", ] } -tower-http = { version = "0.4.0", features = ["cors"] } +tokio-stream = "0.1.14" +RecvStreamtower-http = { version = "0.4.0", features = ["cors"] } tracing = "0.1.37" tracing-opentelemetry = "0.19.0" tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] } diff --git a/router/src/batch.rs b/router/src/batch.rs index 18d0b40ad..2363dbdf6 100644 --- a/router/src/batch.rs +++ b/router/src/batch.rs @@ -12,7 +12,7 @@ use lorax_client::{ StoppingCriteriaParameters, TokenizedInputs, }; use nohash_hasher::{BuildNoHashHasher, IntMap}; -use tokio::time::Instant; +use tokio::{sync::mpsc, time::Instant}; use tracing::{Instrument, Span}; use crate::{ @@ -167,7 +167,7 @@ pub(crate) struct Entry { /// Request pub request: Arc, /// Response sender to communicate between the Infer struct and the batching_task - pub response_tx: flume::Sender>, + pub response_tx: mpsc::UnboundedSender>, /// Span that will live as long as entry pub span: Span, /// Temporary span used as a guard when logging inference, wait times... @@ -222,6 +222,7 @@ impl BatchEntriesState { // TODO(travis): clone is not ideal, find a way to do this cleanly in place for r in self.batch_requests.clone().into_iter().rev() { let id = r.id; + tracing::info!("!!! 
drain::remove entry id={id:?}"); let entry = self.batch_entries.remove(&id).unwrap(); let adapter_index = r.adapter_index; let adapter = self.index_to_adapter.get_mut(&adapter_index).unwrap(); diff --git a/router/src/infer.rs b/router/src/infer.rs index 5626302c3..b3b608dc3 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -9,8 +9,6 @@ use crate::{ MessageChunk, TextMessage, Token, TokenizerConfigToken, Tool, }; use crate::{GenerateRequest, PrefillToken}; -use flume::r#async::RecvStream; -use flume::SendTimeoutError; use futures::future::try_join_all; use futures::stream::StreamExt; /// Batching and inference logic @@ -32,8 +30,10 @@ use std::sync::{ use std::time::Duration; use thiserror::Error; use tokenizers::Tokenizer; -use tokio::sync::{Mutex, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError}; +use tokio::sync::mpsc::error::SendError; +use tokio::sync::{mpsc, Mutex, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError}; use tokio::time::Instant; +use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::{info_span, instrument, Span}; #[derive(Clone, Serialize, Deserialize, Default)] @@ -276,7 +276,7 @@ impl Infer { ) -> Result< ( OwnedSemaphorePermit, - RecvStream>, + UnboundedReceiverStream>, ), InferError, > { @@ -330,7 +330,7 @@ impl Infer { })?; // MPSC channel to communicate with the background batching task - let (response_tx, response_rx) = flume::unbounded(); + let (response_tx, response_rx) = mpsc::unbounded_channel(); // Process the request by sending it to the queue associated with `adapter` self.adapter_scheduler.process( @@ -348,7 +348,7 @@ impl Infer { ); // Return stream - Ok((permit, response_rx.into_stream())) + Ok((permit, UnboundedReceiverStream::new(response_rx))) } /// Tokenizer the input @@ -542,7 +542,7 @@ impl Infer { }; // MPSC channel to communicate with the background batching task - let (response_tx, response_rx) = flume::unbounded(); + let (response_tx, response_rx) = mpsc::unbounded_channel(); // Process the request by sending it to the queue associated with `adapter` self.adapter_scheduler.process( @@ -562,7 +562,7 @@ impl Infer { // Return values let mut return_embeddings = None; - let mut stream = response_rx.into_stream(); + let mut stream = UnboundedReceiverStream::new(response_rx); while let Some(response) = stream.next().await { match response? { // Add prefill tokens @@ -643,7 +643,7 @@ impl Infer { }; // MPSC channel to communicate with the background batching task - let (response_tx, response_rx) = flume::unbounded(); + let (response_tx, response_rx) = mpsc::unbounded_channel(); // Process the request by sending it to the queue associated with `adapter` self.adapter_scheduler.process( @@ -665,7 +665,7 @@ impl Infer { let mut result_start = None; let mut result_queued = None; - let mut stream = response_rx.into_stream(); + let mut stream = UnboundedReceiverStream::new(response_rx); while let Some(response) = stream.next().await { match response? { // Add prefill tokens @@ -743,7 +743,7 @@ impl Infer { ); // MPSC channel to communicate with the background batching task - let (response_tx, response_rx) = flume::unbounded(); + let (response_tx, response_rx) = mpsc::unbounded_channel(); let request_id_map: HashMap = request .inputs @@ -793,7 +793,7 @@ impl Infer { // Return values let mut all_entities = HashMap::new(); - let mut stream = response_rx.into_stream(); + let mut stream = UnboundedReceiverStream::new(response_rx); while let Some(response) = stream.next().await { match response? 
{ // Add prefill tokens @@ -1198,10 +1198,7 @@ pub(crate) async fn embed( // request and we need to stop generating hence why we unwrap_or(true) let stopped = send_embeddings(embedding, entry) .map_err(|err| { - if let SendTimeoutError::Timeout(_) = *err { - tracing::error!("Entry response channel timed out.") - } - + tracing::error!("Entry response channel error."); metrics::increment_counter!("lorax_request_failure", "err" => "dropped"); err }) @@ -1258,10 +1255,7 @@ pub(crate) async fn classify( // request and we need to stop generating hence why we unwrap_or(true) let stopped = send_classifications(predictions, entry) .map_err(|err| { - if let SendTimeoutError::Timeout(_) = *err { - tracing::error!("Entry response channel timed out.") - } - + tracing::error!("Entry response channel error."); metrics::increment_counter!("lorax_request_failure", "err" => "dropped"); err }) @@ -1329,6 +1323,7 @@ fn filter_send_generations(generations: Vec, entries: &mut IntMap, entries: &mut IntMap "dropped"); - err }).unwrap_or(true); if stopped { + tracing::info!("!!! filter_send_generations::remove entry id={id:?}"); entries.remove(&id).expect("ID not found in entries. This is a bug."); } }); @@ -1356,9 +1348,10 @@ fn filter_send_generations(generations: Vec, entries: &mut IntMap Result>>> { +) -> Result>>> { // Return directly if the channel is disconnected - if entry.response_tx.is_disconnected() { + if entry.response_tx.is_closed() { + tracing::info!("!!! send_responses::disconnected"); return Ok(true); } @@ -1366,13 +1359,10 @@ fn send_responses( if generation.prefill_tokens_length > 0 { // Send message - entry.response_tx.send_timeout( - Ok(InferStreamResponse::Prefill { - tokens: generation.prefill_tokens, - tokens_length: generation.prefill_tokens_length, - }), - Duration::from_millis(10), - )?; + entry.response_tx.send(Ok(InferStreamResponse::Prefill { + tokens: generation.prefill_tokens, + tokens_length: generation.prefill_tokens_length, + }))?; } // Create last Token @@ -1420,25 +1410,24 @@ fn send_responses( match (&generation.generated_text, iterator.peek()) { (Some(generated_text), None) => { + tracing::info!( + "!!! 
send_responses::generation_ended id={id:?} generated_text={generated_text:?}" + ); // Generation has ended stopped = true; // Send message - entry.response_tx.send_timeout( - Ok(InferStreamResponse::End { - token, - generated_text: generated_text.clone(), - queued: entry.queue_time, - start: entry.batch_time.unwrap(), - }), - Duration::from_millis(10), - )?; + entry.response_tx.send(Ok(InferStreamResponse::End { + token, + generated_text: generated_text.clone(), + queued: entry.queue_time, + start: entry.batch_time.unwrap(), + }))?; } _ => { // Send message - entry.response_tx.send_timeout( - Ok(InferStreamResponse::Token(token)), - Duration::from_millis(10), - )?; + entry + .response_tx + .send(Ok(InferStreamResponse::Token(token)))?; } } } @@ -1450,20 +1439,17 @@ fn send_responses( fn send_embeddings( embedding: Embedding, entry: &Entry, -) -> Result>>> { +) -> Result>>> { // Return directly if the channel is disconnected - if entry.response_tx.is_disconnected() { + if entry.response_tx.is_closed() { return Ok(true); } - entry.response_tx.send_timeout( - Ok(InferStreamResponse::Embed { - embedding: embedding.clone(), - queued: entry.queue_time, - start: entry.batch_time.unwrap(), - }), - Duration::from_millis(10), - )?; + entry.response_tx.send(Ok(InferStreamResponse::Embed { + embedding: embedding.clone(), + queued: entry.queue_time, + start: entry.batch_time.unwrap(), + }))?; // TODO(travis): redundant as we always return true, just make it return nothing Ok(true) @@ -1473,21 +1459,18 @@ fn send_embeddings( fn send_classifications( predictions: ClassifyPredictionList, entry: &Entry, -) -> Result>>> { +) -> Result>>> { // Return directly if the channel is disconnected - if entry.response_tx.is_disconnected() { + if entry.response_tx.is_closed() { return Ok(true); } - entry.response_tx.send_timeout( - Ok(InferStreamResponse::Classify { - predictions: predictions.clone(), - queued: entry.queue_time, - start: entry.batch_time.unwrap(), - id: entry.id, - }), - Duration::from_millis(10), - )?; + entry.response_tx.send(Ok(InferStreamResponse::Classify { + predictions: predictions.clone(), + queued: entry.queue_time, + start: entry.batch_time.unwrap(), + id: entry.id, + }))?; // TODO(travis): redundant as we always return true, just make it return nothing Ok(true) @@ -1506,7 +1489,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap) { // unwrap_or is valid here as we don't care if the receiver is gone. 
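// Illustrative sketch, not code from this patch: the channel migration applied throughout
// these hunks swaps flume for tokio's unbounded MPSC. `flume::unbounded()` becomes
// `mpsc::unbounded_channel()`, `is_disconnected()` becomes `is_closed()`, `send_timeout(..)`
// becomes a plain `send(..)` (an unbounded send never blocks), and `into_stream()` becomes
// `UnboundedReceiverStream::new(rx)`. A minimal, self-contained version of that wiring,
// assuming the `tokio` ("full" feature) and `tokio-stream` crates; the names below are
// placeholders, not LoRAX identifiers.
use tokio::sync::mpsc;
use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt};

#[tokio::main]
async fn main() {
    let (response_tx, response_rx) = mpsc::unbounded_channel::<Result<u32, String>>();

    // Sending never blocks and only errors once the receiver is dropped, which is why the
    // timeout-specific error handling (`SendTimeoutError`) disappears in this patch.
    assert!(!response_tx.is_closed());
    response_tx.send(Ok(1)).unwrap();
    drop(response_tx);

    // Wrapping the receiver restores the `Stream` interface that flume's `into_stream()`
    // used to provide, so downstream `StreamExt::next()` loops stay unchanged.
    let mut stream = UnboundedReceiverStream::new(response_rx);
    while let Some(msg) = stream.next().await {
        println!("{msg:?}");
    }
}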
entry .response_tx - .send_timeout(Err(err), Duration::from_millis(10)) + .send(Err(err)) .unwrap_or(()); }); } diff --git a/router/src/loader.rs b/router/src/loader.rs index 1233c664c..274e636de 100644 --- a/router/src/loader.rs +++ b/router/src/loader.rs @@ -3,20 +3,20 @@ use crate::infer::InferError; use crate::queue::{AdapterQueuesState, AdapterStatus}; use lorax_client::ShardedClient; use std::{collections::HashMap, sync::Arc}; -use tokio::sync::{oneshot, Mutex}; +use tokio::sync::{mpsc, oneshot, Mutex}; use tracing::Span; /// Request AdapterLoader #[derive(Debug, Clone)] pub(crate) struct AdapterLoader { /// Channel to communicate with the background task - sender: flume::Sender, + sender: mpsc::UnboundedSender, } impl AdapterLoader { pub(crate) fn new(client: ShardedClient) -> Self { // Create channel - let (sender, receiver) = flume::unbounded(); + let (sender, receiver) = mpsc::unbounded_channel(); // Launch background queue task tokio::spawn(loader_task(client, receiver)); @@ -115,10 +115,13 @@ impl AdapterLoader { } // Background task responsible of the loader state -async fn loader_task(mut client: ShardedClient, receiver: flume::Receiver) { +async fn loader_task( + mut client: ShardedClient, + mut receiver: mpsc::UnboundedReceiver, +) { let mut err_msgs: HashMap = HashMap::new(); - while let Ok(cmd) = receiver.recv_async().await { + while let Some(cmd) = receiver.recv().await { match cmd { AdapterLoaderCommand::DownloadAdapter { adapter, diff --git a/router/src/scheduler.rs b/router/src/scheduler.rs index 22bcb2418..9709cc2a8 100644 --- a/router/src/scheduler.rs +++ b/router/src/scheduler.rs @@ -7,7 +7,7 @@ use crate::{ }; use lorax_client::{Batch, ShardedClient}; use std::{cmp::max, collections::HashSet, sync::Arc}; -use tokio::sync::{oneshot, Mutex}; +use tokio::sync::{mpsc, oneshot, Mutex}; use tracing::{info_span, instrument, Instrument, Span}; enum AdapterSchedulerCommand { @@ -25,7 +25,7 @@ enum AdapterSchedulerCommand { #[derive(Clone)] pub(crate) struct AdapterScheduler { - sender: flume::Sender, + sender: mpsc::UnboundedSender, } impl AdapterScheduler { @@ -43,7 +43,7 @@ impl AdapterScheduler { chunked_prefill: bool, is_causal_lm: bool, ) -> Self { - let (sender, receiver) = flume::unbounded(); + let (sender, receiver) = mpsc::unbounded_channel(); // receives requests from the infer struct and sends them to the appropriate adapter queue tokio::spawn(adapter_scheduler_task( @@ -118,7 +118,7 @@ async fn adapter_scheduler_task( requires_padding: bool, block_size: u32, window_size: Option, - receiver: flume::Receiver, + mut receiver: mpsc::UnboundedReceiver, max_active_adapters: usize, adapter_cycle_time_s: u64, speculate: u32, @@ -141,7 +141,7 @@ async fn adapter_scheduler_task( is_causal_lm, ); - while let Ok(cmd) = receiver.recv_async().await { + while let Some(cmd) = receiver.recv().await { match cmd { AdapterSchedulerCommand::Append(adapter, entry) => { state.append(adapter, adapter_event.clone(), entry).await; @@ -330,7 +330,7 @@ impl AdapterSchedulerState { 'entry_loop: while let Some((id, mut entry, adapter)) = self.next_entry().await { // Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) - if entry.response_tx.is_disconnected() { + if entry.response_tx.is_closed() { metrics::increment_counter!("lorax_request_failure", "err" => "dropped"); continue; } diff --git a/router/src/validation.rs b/router/src/validation.rs index 76ec6d7ca..4a7a91edc 100644 --- a/router/src/validation.rs +++ 
b/router/src/validation.rs @@ -12,7 +12,7 @@ use std::io::Cursor; use std::iter; use thiserror::Error; use tokenizers::tokenizer::Tokenizer; -use tokio::sync::oneshot; +use tokio::sync::{mpsc, oneshot}; use tracing::{instrument, Span}; use {once_cell::sync::Lazy, regex::Regex}; @@ -25,7 +25,7 @@ pub struct Validation { max_input_length: usize, max_total_tokens: usize, /// Channel to communicate with the background tokenization task - sender: Option>, + sender: Option>, } impl Validation { @@ -41,15 +41,17 @@ impl Validation { ) -> Self { // If we have a fast tokenizer let sender = if let Some(tokenizer) = tokenizer { - // Create channel - let (validation_sender, validation_receiver) = flume::unbounded(); + // Create round robin channel + let (validation_sender, validation_round_robin_receiver) = mpsc::unbounded_channel(); + let mut senders = Vec::with_capacity(workers); // Create workers for _ in 0..workers { let tokenizer_clone = tokenizer.clone(); let config_clone = config.clone(); let preprocessor_config_clone = preprocessor_config.clone(); - let receiver_clone = validation_receiver.clone(); + let (tokenizer_sender, tokenizer_receiver) = mpsc::unbounded_channel(); + senders.push(tokenizer_sender); // Spawn worker tokio::task::spawn_blocking(move || { @@ -57,10 +59,14 @@ impl Validation { tokenizer_clone, config_clone, preprocessor_config_clone, - receiver_clone, + tokenizer_receiver, ) }); } + + // Create tokenization round robin task + tokio::spawn(round_robin_task(validation_round_robin_receiver, senders)); + Some(validation_sender) } else { None @@ -390,15 +396,30 @@ impl Validation { } } +/// Round robin tokenization task +async fn round_robin_task( + mut receiver: mpsc::UnboundedReceiver, + senders: Vec>, +) { + loop { + for sender in &senders { + match receiver.recv().await { + None => return, + Some(request) => sender.send(request).unwrap(), + }; + } + } +} + /// Start tokenization workers fn tokenizer_worker( tokenizer: Tokenizer, config: Option, preprocessor_config: Option, - receiver: flume::Receiver, + mut receiver: mpsc::UnboundedReceiver, ) { // Loop over requests - while let Ok(((inputs, truncate), response_tx, parent_span)) = receiver.recv() { + while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() { parent_span.in_scope(|| { response_tx .send(prepare_input( diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 152e1cb50..9390bba09 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -2078,7 +2078,6 @@ def generate_token( batch.next_token_chooser.next_state(i, next_token_id) # Update values - # logger.info(f"!!! 
UPDATE VALUES {i} n_accepted_ids={n_accepted_ids} new_input_length={new_input_length} input_length={input_length} cache_length={cache_length}") index += n_accepted_ids batch.cache_lengths[i] = new_cache_length batch.max_input_length = max(batch.max_input_length, new_input_length) From 8cb79b2232b8c8f0fecdc49c959b084951883531 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 1 Nov 2024 09:58:18 -0700 Subject: [PATCH 57/76] Remove unused dep --- Cargo.lock | 21 +-------------------- router/Cargo.toml | 1 - 2 files changed, 1 insertion(+), 21 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9512fa06b..5a889d52c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -295,7 +295,7 @@ dependencies = [ "http", "opentelemetry 0.18.0", "tower", - "tower-http 0.3.5", + "tower-http", "tracing", "tracing-opentelemetry 0.18.0", ] @@ -1598,7 +1598,6 @@ dependencies = [ "tokenizers", "tokio", "tokio-stream", - "tower-http 0.4.1", "tracing", "tracing-opentelemetry 0.19.0", "tracing-subscriber 0.3.17", @@ -3719,24 +3718,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "tower-http" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8bd22a874a2d0b70452d5597b12c537331d49060824a95f49f108994f94aa4c" -dependencies = [ - "bitflags 2.3.3", - "bytes", - "futures-core", - "futures-util", - "http", - "http-body", - "http-range-header", - "pin-project-lite", - "tower-layer", - "tower-service", -] - [[package]] name = "tower-layer" version = "0.3.2" diff --git a/router/Cargo.toml b/router/Cargo.toml index a8ed2c2e5..8fd447cbb 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -48,7 +48,6 @@ tokio = { version = "1.32.0", features = [ "sync", ] } tokio-stream = "0.1.14" -RecvStreamtower-http = { version = "0.4.0", features = ["cors"] } tracing = "0.1.37" tracing-opentelemetry = "0.19.0" tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] } From 045a45a40b0a3c82dc456b9320929b385a9b6a09 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 1 Nov 2024 10:21:06 -0700 Subject: [PATCH 58/76] Update axum --- Cargo.lock | 529 ++++++++++++++++++++++++++++++++----------- router/Cargo.toml | 12 +- router/src/main.rs | 2 +- router/src/server.rs | 91 ++++---- 4 files changed, 442 insertions(+), 192 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5a889d52c..6e0630435 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -242,13 +242,13 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8175979259124331c1d7bf6586ee7e0da434155e4b2d48ec2c8386281d8df39" dependencies = [ "async-trait", - "axum-core", + "axum-core 0.3.4", "bitflags 1.3.2", "bytes", "futures-util", - "http", - "http-body", - "hyper", + "http 0.2.9", + "http-body 0.4.5", + "hyper 0.14.27", "itoa", "matchit", "memchr", @@ -260,13 +260,47 @@ dependencies = [ "serde_json", "serde_path_to_error", "serde_urlencoded", - "sync_wrapper", + "sync_wrapper 0.1.2", "tokio", "tower", "tower-layer", "tower-service", ] +[[package]] +name = "axum" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a6c9af12842a67734c9a2e355436e5d03b22383ed60cf13cd0c18fbfe3dcbcf" +dependencies = [ + "async-trait", + "axum-core 0.4.5", + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", + "hyper 1.5.0", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper 
1.0.1", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "axum-core" version = "0.3.4" @@ -276,28 +310,51 @@ dependencies = [ "async-trait", "bytes", "futures-util", - "http", - "http-body", + "http 0.2.9", + "http-body 0.4.5", + "mime", + "rustversion", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-core" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", "mime", + "pin-project-lite", "rustversion", + "sync_wrapper 1.0.1", "tower-layer", "tower-service", + "tracing", ] [[package]] name = "axum-tracing-opentelemetry" -version = "0.10.0" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "164b95427e83b79583c7699a72b4a6b485a12bbdef5b5c054ee5ff2296d82f52" +checksum = "bdad298231394729042d1f155b93f9fdf0b5ee1aea0b62404c4d7341f7d8fe08" dependencies = [ - "axum", - "futures", - "http", - "opentelemetry 0.18.0", + "axum 0.7.5", + "futures-core", + "futures-util", + "http 1.1.0", + "opentelemetry 0.21.0", + "pin-project-lite", "tower", - "tower-http", "tracing", - "tracing-opentelemetry 0.18.0", + "tracing-opentelemetry 0.22.0", + "tracing-opentelemetry-instrumentation-sdk", ] [[package]] @@ -728,33 +785,13 @@ dependencies = [ "crypto-common", ] -[[package]] -name = "dirs" -version = "4.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca3aa72a6f96ea37bbc5aa912f6788242832f75369bdfdadcb0e38423f100059" -dependencies = [ - "dirs-sys 0.3.7", -] - [[package]] name = "dirs" version = "5.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225" dependencies = [ - "dirs-sys 0.4.1", -] - -[[package]] -name = "dirs-sys" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b1d1d91c932ef41c0f2663aa8b0ca0342d444d842c06914aa0a7e352d0bada6" -dependencies = [ - "libc", - "redox_users", - "winapi", + "dirs-sys", ] [[package]] @@ -1058,6 +1095,12 @@ version = "0.27.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c80984affa11d98d1b88b66ac8853f143217b399d3c74116778ff8fdb4ed2e" +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + [[package]] name = "grpc-metadata" version = "0.1.0" @@ -1079,7 +1122,7 @@ dependencies = [ "futures-core", "futures-sink", "futures-util", - "http", + "http 0.2.9", "indexmap 2.0.0", "slab", "tokio", @@ -1142,7 +1185,7 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b780635574b3d92f036890d8373433d6f9fc7abb320ee42a5c25897fc8ed732" dependencies = [ - "dirs 5.0.1", + "dirs", "futures", "indicatif", "log", @@ -1179,6 +1222,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http-body" version = "0.4.5" @@ -1186,15 +1240,32 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1" dependencies = [ "bytes", - "http", + "http 0.2.9", "pin-project-lite", ] [[package]] -name = "http-range-header" -version = "0.3.0" +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http 1.1.0", +] + +[[package]] +name = "http-body-util" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bfe8eed0a9285ef776bb792479ea3834e8b94e13d615c2f66d03dd50a435a29" +checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" +dependencies = [ + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", + "pin-project-lite", +] [[package]] name = "httparse" @@ -1219,8 +1290,8 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", - "http-body", + "http 0.2.9", + "http-body 0.4.5", "httparse", "httpdate", "itoa", @@ -1232,13 +1303,32 @@ dependencies = [ "want", ] +[[package]] +name = "hyper" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbbff0a806a4728c99295b254c8838933b5b082d75e3cb70c8dab21fdfbcfa9a" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "smallvec", + "tokio", +] + [[package]] name = "hyper-timeout" version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" dependencies = [ - "hyper", + "hyper 0.14.27", "pin-project-lite", "tokio", "tokio-io-timeout", @@ -1251,12 +1341,27 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" dependencies = [ "bytes", - "hyper", + "hyper 0.14.27", "native-tls", "tokio", "tokio-native-tls", ] +[[package]] +name = "hyper-util" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cde7055719c54e36e95e8719f95883f22072a48ede39db7fc17a4e1d5281e9b9" +dependencies = [ + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", + "hyper 1.5.0", + "pin-project-lite", + "tokio", +] + [[package]] name = "iana-time-zone" version = "0.1.60" @@ -1369,6 +1474,19 @@ dependencies = [ "unicode-width", ] +[[package]] +name = "init-tracing-opentelemetry" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94bd26b1b737bc11f183620072e188d1c6ede67e0e78682228d66b49ec510e17" +dependencies = [ + "opentelemetry 0.20.0", + "opentelemetry-otlp 0.13.0", + "thiserror", + "tracing", + "tracing-opentelemetry 0.21.0", +] + [[package]] name = "instant" version = "0.1.12" @@ -1564,7 +1682,7 @@ version = "0.1.0" dependencies = [ "async-stream", "async-trait", - "axum", + "axum 0.7.5", "axum-tracing-opentelemetry", "base64 0.22.1", "clap", @@ -1572,6 +1690,7 @@ dependencies = [ "h2", "hf-hub", "image", + "init-tracing-opentelemetry", "itertools 0.12.1", "lorax-client", "metrics", @@ -1583,7 +1702,7 @@ dependencies = [ "once_cell", "openssl", "opentelemetry 0.19.0", - "opentelemetry-otlp", + "opentelemetry-otlp 0.12.0", "rand", "regex", "reqwest", @@ -1598,6 +1717,7 @@ dependencies = [ "tokenizers", "tokio", "tokio-stream", + "tower-http", "tracing", "tracing-opentelemetry 0.19.0", "tracing-subscriber 0.3.17", @@ -1714,7 +1834,7 @@ 
source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a4964177ddfdab1e3a2b37aec7cf320e14169abb0ed73999f558136409178d5" dependencies = [ "base64 0.21.2", - "hyper", + "hyper 0.14.27", "indexmap 1.9.3", "ipnet", "metrics", @@ -1894,12 +2014,12 @@ dependencies = [ "async-rustls", "async-trait", "awaitdrop", - "axum", + "axum 0.6.18", "base64 0.13.1", "bytes", "futures", "hostname", - "hyper", + "hyper 0.14.27", "muxado", "once_cell", "parking_lot 0.12.1", @@ -2128,22 +2248,38 @@ dependencies = [ [[package]] name = "opentelemetry" -version = "0.18.0" +version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69d6c3d7288a106c0a363e4b0e8d308058d56902adefb16f4936f417ffef086e" +checksum = "5f4b8347cc26099d3aeee044065ecc3ae11469796b4d65d065a23a584ed92a6f" dependencies = [ - "opentelemetry_api 0.18.0", - "opentelemetry_sdk 0.18.0", + "opentelemetry_api 0.19.0", + "opentelemetry_sdk 0.19.0", ] [[package]] name = "opentelemetry" -version = "0.19.0" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f4b8347cc26099d3aeee044065ecc3ae11469796b4d65d065a23a584ed92a6f" +checksum = "9591d937bc0e6d2feb6f71a559540ab300ea49955229c347a517a28d27784c54" dependencies = [ - "opentelemetry_api 0.19.0", - "opentelemetry_sdk 0.19.0", + "opentelemetry_api 0.20.0", + "opentelemetry_sdk 0.20.0", +] + +[[package]] +name = "opentelemetry" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e32339a5dc40459130b3bd269e9892439f55b33e772d2a9d402a789baaf4e8a" +dependencies = [ + "futures-core", + "futures-sink", + "indexmap 2.0.0", + "js-sys", + "once_cell", + "pin-project-lite", + "thiserror", + "urlencoding", ] [[package]] @@ -2155,15 +2291,34 @@ dependencies = [ "async-trait", "futures", "futures-util", - "http", + "http 0.2.9", "opentelemetry 0.19.0", - "opentelemetry-proto", + "opentelemetry-proto 0.2.0", "prost", "thiserror", "tokio", "tonic 0.8.3", ] +[[package]] +name = "opentelemetry-otlp" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e5e5a5c4135864099f3faafbe939eb4d7f9b80ebf68a8448da961b32a7c1275" +dependencies = [ + "async-trait", + "futures-core", + "http 0.2.9", + "opentelemetry-proto 0.3.0", + "opentelemetry-semantic-conventions", + "opentelemetry_api 0.20.0", + "opentelemetry_sdk 0.20.0", + "prost", + "thiserror", + "tokio", + "tonic 0.9.2", +] + [[package]] name = "opentelemetry-proto" version = "0.2.0" @@ -2177,32 +2332,53 @@ dependencies = [ "tonic 0.8.3", ] +[[package]] +name = "opentelemetry-proto" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1e3f814aa9f8c905d0ee4bde026afd3b2577a97c10e1699912e3e44f0c4cbeb" +dependencies = [ + "opentelemetry_api 0.20.0", + "opentelemetry_sdk 0.20.0", + "prost", + "tonic 0.9.2", +] + +[[package]] +name = "opentelemetry-semantic-conventions" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73c9f9340ad135068800e7f1b24e9e09ed9e7143f5bf8518ded3d3ec69789269" +dependencies = [ + "opentelemetry 0.20.0", +] + [[package]] name = "opentelemetry_api" -version = "0.18.0" +version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c24f96e21e7acc813c7a8394ee94978929db2bcc46cf6b5014fc612bf7760c22" +checksum = "ed41783a5bf567688eb38372f2b7a8530f5a607a4b49d38dd7573236c23ca7e2" dependencies = [ "fnv", "futures-channel", "futures-util", "indexmap 
1.9.3", - "js-sys", "once_cell", "pin-project-lite", "thiserror", + "urlencoding", ] [[package]] name = "opentelemetry_api" -version = "0.19.0" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed41783a5bf567688eb38372f2b7a8530f5a607a4b49d38dd7573236c23ca7e2" +checksum = "8a81f725323db1b1206ca3da8bb19874bbd3f57c3bcd59471bfb04525b265b9b" dependencies = [ - "fnv", "futures-channel", "futures-util", "indexmap 1.9.3", + "js-sys", "once_cell", "pin-project-lite", "thiserror", @@ -2211,9 +2387,9 @@ dependencies = [ [[package]] name = "opentelemetry_sdk" -version = "0.18.0" +version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ca41c4933371b61c2a2f214bf16931499af4ec90543604ec828f7a625c09113" +checksum = "8b3a2a91fdbfdd4d212c0dcc2ab540de2c2bcbbd90be17de7a7daf8822d010c1" dependencies = [ "async-trait", "crossbeam-channel", @@ -2223,7 +2399,7 @@ dependencies = [ "futures-executor", "futures-util", "once_cell", - "opentelemetry_api 0.18.0", + "opentelemetry_api 0.19.0", "percent-encoding", "rand", "thiserror", @@ -2233,32 +2409,71 @@ dependencies = [ [[package]] name = "opentelemetry_sdk" -version = "0.19.0" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b3a2a91fdbfdd4d212c0dcc2ab540de2c2bcbbd90be17de7a7daf8822d010c1" +checksum = "fa8e705a0612d48139799fcbaba0d4a90f06277153e43dd2bdc16c6f0edd8026" dependencies = [ "async-trait", "crossbeam-channel", - "dashmap", - "fnv", "futures-channel", "futures-executor", "futures-util", "once_cell", - "opentelemetry_api 0.19.0", + "opentelemetry_api 0.20.0", + "ordered-float 3.9.2", "percent-encoding", "rand", + "regex", + "serde_json", "thiserror", "tokio", "tokio-stream", ] +[[package]] +name = "opentelemetry_sdk" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f16aec8a98a457a52664d69e0091bac3a0abd18ead9b641cb00202ba4e0efe4" +dependencies = [ + "async-trait", + "crossbeam-channel", + "futures-channel", + "futures-executor", + "futures-util", + "glob", + "once_cell", + "opentelemetry 0.21.0", + "ordered-float 4.4.0", + "percent-encoding", + "rand", + "thiserror", +] + [[package]] name = "option-ext" version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "ordered-float" +version = "3.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1e1c390732d15f1d48471625cd92d154e66db2c56645e29a9cd26f4699f72dc" +dependencies = [ + "num-traits", +] + +[[package]] +name = "ordered-float" +version = "4.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83e7ccb95e240b7c9506a3d544f10d935e142cc90b0a1d56954fb44d89ad6b97" +dependencies = [ + "num-traits", +] + [[package]] name = "overload" version = "0.1.1" @@ -2758,9 +2973,9 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", - "http-body", - "hyper", + "http 0.2.9", + "http-body 0.4.5", + "hyper 0.14.27", "hyper-tls", "ipnet", "js-sys", @@ -2792,7 +3007,7 @@ checksum = "88a3e86aa6053e59030e7ce2d2a3b258dd08fc2d337d52f73f6cb480f5858690" dependencies = [ "anyhow", "async-trait", - "http", + "http 0.2.9", "reqwest", "serde", "task-local-extensions", @@ -2810,8 +3025,8 @@ dependencies = [ "chrono", "futures", "getrandom", - "http", - "hyper", + "http 0.2.9", + "hyper 0.14.27", "parking_lot 0.11.2", "reqwest", "reqwest-middleware", @@ 
-2874,9 +3089,9 @@ dependencies = [ [[package]] name = "rust-embed" -version = "6.8.1" +version = "8.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a36224c3276f8c4ebc8c20f158eca7ca4359c8db89991c4925132aaaf6702661" +checksum = "fa66af4a4fdd5e7ebc276f115e895611a34739a9c1c01028383d612d550953c0" dependencies = [ "rust-embed-impl", "rust-embed-utils", @@ -2885,23 +3100,22 @@ dependencies = [ [[package]] name = "rust-embed-impl" -version = "6.8.1" +version = "8.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49b94b81e5b2c284684141a2fb9e2a31be90638caf040bf9afbc5a0416afe1ac" +checksum = "6125dbc8867951125eec87294137f4e9c2c96566e61bf72c45095a7c77761478" dependencies = [ "proc-macro2", "quote", "rust-embed-utils", - "shellexpand", "syn 2.0.60", "walkdir", ] [[package]] name = "rust-embed-utils" -version = "7.8.1" +version = "8.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d38ff6bf570dc3bb7100fce9f7b60c33fa71d80e88da3f2580df4ff2bdded74" +checksum = "2e5347777e9aacb56039b0e1f28785929a8a3b709e87482e7442c72e7c12529d" dependencies = [ "sha2", "walkdir", @@ -3149,15 +3363,6 @@ dependencies = [ "lazy_static", ] -[[package]] -name = "shellexpand" -version = "2.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ccc8076840c4da029af4f87e4e8daeb0fca6b87bbb02e10cb60b791450e11e4" -dependencies = [ - "dirs 4.0.0", -] - [[package]] name = "signal-hook-registry" version = "1.4.1" @@ -3208,9 +3413,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.11.0" +version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "socket2" @@ -3311,6 +3516,12 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" +[[package]] +name = "sync_wrapper" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" + [[package]] name = "sysinfo" version = "0.30.13" @@ -3614,15 +3825,15 @@ checksum = "8f219fad3b929bef19b1f86fbc0358d35daed8f2cac972037ac0dc10bbb8d5fb" dependencies = [ "async-stream", "async-trait", - "axum", + "axum 0.6.18", "base64 0.13.1", "bytes", "futures-core", "futures-util", "h2", - "http", - "http-body", - "hyper", + "http 0.2.9", + "http-body 0.4.5", + "hyper 0.14.27", "hyper-timeout", "percent-encoding", "pin-project", @@ -3645,15 +3856,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3082666a3a6433f7f511c7192923fa1fe07c69332d3c6a2e6bb040b569199d5a" dependencies = [ "async-trait", - "axum", + "axum 0.6.18", "base64 0.21.2", "bytes", "futures-core", "futures-util", "h2", - "http", - "http-body", - "hyper", + "http 0.2.9", + "http-body 0.4.5", + "hyper 0.14.27", "hyper-timeout", "percent-encoding", "pin-project", @@ -3701,28 +3912,23 @@ dependencies = [ [[package]] name = "tower-http" -version = "0.3.5" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f873044bf02dd1e8239e9c1293ea39dad76dc594ec16185d0a1bf31d8dc8d858" +checksum = "8437150ab6bbc8c5f0f519e3d5ed4aa883a83dd4cdd3d1b21f9482936046cb97" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.3.3", "bytes", - "futures-core", - 
"futures-util", - "http", - "http-body", - "http-range-header", + "http 1.1.0", "pin-project-lite", "tower-layer", "tower-service", - "tracing", ] [[package]] name = "tower-layer" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" [[package]] name = "tower-service" @@ -3785,17 +3991,14 @@ dependencies = [ ] [[package]] -name = "tracing-opentelemetry" -version = "0.18.0" +name = "tracing-log" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21ebb87a95ea13271332df069020513ab70bdb5637ca42d6e492dc3bbbad48de" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" dependencies = [ + "log", "once_cell", - "opentelemetry 0.18.0", - "tracing", "tracing-core", - "tracing-log", - "tracing-subscriber 0.3.17", ] [[package]] @@ -3808,10 +4011,56 @@ dependencies = [ "opentelemetry 0.19.0", "tracing", "tracing-core", - "tracing-log", + "tracing-log 0.1.3", "tracing-subscriber 0.3.17", ] +[[package]] +name = "tracing-opentelemetry" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75327c6b667828ddc28f5e3f169036cb793c3f588d83bf0f262a7f062ffed3c8" +dependencies = [ + "once_cell", + "opentelemetry 0.20.0", + "opentelemetry_sdk 0.20.0", + "smallvec", + "tracing", + "tracing-core", + "tracing-log 0.1.3", + "tracing-subscriber 0.3.17", +] + +[[package]] +name = "tracing-opentelemetry" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c67ac25c5407e7b961fafc6f7e9aa5958fd297aada2d20fa2ae1737357e55596" +dependencies = [ + "js-sys", + "once_cell", + "opentelemetry 0.21.0", + "opentelemetry_sdk 0.21.2", + "smallvec", + "tracing", + "tracing-core", + "tracing-log 0.2.0", + "tracing-subscriber 0.3.17", + "web-time", +] + +[[package]] +name = "tracing-opentelemetry-instrumentation-sdk" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9920abb6a3ee3a2af7d30c9ff02900f8481935d36723c3da95cf807468218e8c" +dependencies = [ + "http 1.1.0", + "opentelemetry 0.21.0", + "tracing", + "tracing-opentelemetry 0.22.0", +] + [[package]] name = "tracing-serde" version = "0.1.3" @@ -3840,7 +4089,7 @@ dependencies = [ "thread_local", "tracing", "tracing-core", - "tracing-log", + "tracing-log 0.1.3", "tracing-serde", ] @@ -3861,7 +4110,7 @@ dependencies = [ "thread_local", "tracing", "tracing-core", - "tracing-log", + "tracing-log 0.1.3", "tracing-serde", ] @@ -4014,9 +4263,9 @@ checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" [[package]] name = "utoipa" -version = "3.4.0" +version = "4.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "520434cac5c98120177d5cc15be032703f6dca7d5ef82e725c798113b375000a" +checksum = "c5afb1a60e207dca502682537fefcfd9921e71d0b83e9576060f09abc6efab23" dependencies = [ "indexmap 2.0.0", "serde", @@ -4026,9 +4275,9 @@ dependencies = [ [[package]] name = "utoipa-gen" -version = "3.4.1" +version = "4.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e22e88a487b6e0374533871b79b1f5ded05671bd0936bd547eb42f82fb9060d" +checksum = "20c24e8ab68ff9ee746aad22d39b5535601e6416d1b0feeabf78be986a5c4392" dependencies = [ "proc-macro-error", "proc-macro2", @@ -4039,11 +4288,11 @@ dependencies = [ [[package]] name = 
"utoipa-swagger-ui" -version = "3.1.4" +version = "6.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4602d7100d3cfd8a086f30494e68532402ab662fa366c9d201d677e33cee138d" +checksum = "0b39868d43c011961e04b41623e050aedf2cc93652562ff7935ce0f819aaf2da" dependencies = [ - "axum", + "axum 0.7.5", "mime_guess", "regex", "rust-embed", @@ -4221,6 +4470,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web-time" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa30049b1c872b72c89866d458eae9f20380ab280ffd1b1e18df2d3e2d98cfe0" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "webpki" version = "0.22.4" diff --git a/router/Cargo.toml b/router/Cargo.toml index 8fd447cbb..c325ba9dc 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -16,8 +16,8 @@ path = "src/main.rs" [dependencies] async-stream = "0.3.3" -axum = { version = "0.6.4", features = ["json"] } -axum-tracing-opentelemetry = "0.10.0" +axum = { version = "0.7", features = ["json"] } +axum-tracing-opentelemetry = "0.16" clap = { version = "4.1.4", features = ["derive", "env"] } futures = "0.3.26" hf-hub = { version = "0.3.0", features = ["tokio"] } @@ -48,12 +48,16 @@ tokio = { version = "1.32.0", features = [ "sync", ] } tokio-stream = "0.1.14" +tower-http = { version = "0.6.1", features = ["cors"] } tracing = "0.1.37" tracing-opentelemetry = "0.19.0" tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] } -utoipa = { version = "3.0.1", features = ["axum_extras"] } -utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] } +utoipa = { version = "4.2.0", features = ["axum_extras"] } +utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } ngrok = { version = "0.12.3", features = ["axum"], optional = true } +init-tracing-opentelemetry = { version = "0.14.1", features = [ + "opentelemetry-otlp", +] } once_cell = "1.19.0" itertools = "0.12.1" async-trait = "0.1.80" diff --git a/router/src/main.rs b/router/src/main.rs index ec8b61aae..249031258 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -533,7 +533,7 @@ fn init_logging(otlp_endpoint: Option, json_output: bool) { if let Ok(tracer) = tracer { layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed()); - axum_tracing_opentelemetry::init_propagator().unwrap(); + init_tracing_opentelemetry::init_propagator().unwrap(); }; } diff --git a/router/src/server.rs b/router/src/server.rs index 3370114e9..afdf2dbe2 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -25,6 +25,7 @@ use axum::response::sse::{Event, KeepAlive, Sse}; use axum::response::{IntoResponse, Response}; use axum::routing::{get, post}; use axum::{http, Json, Router}; +use axum_tracing_opentelemetry::middleware::OtelAxumLayer; use axum_tracing_opentelemetry::opentelemetry_tracing_layer; use futures::stream::StreamExt; use futures::Stream; @@ -39,6 +40,7 @@ use std::net::SocketAddr; use std::sync::atomic::AtomicBool; use std::sync::Arc; use std::sync::Mutex; +use thiserror::Error; use tokenizers::Tokenizer; use tokio::signal; use tokio::sync::mpsc; @@ -480,6 +482,12 @@ async fn chat_completions_v1( } } +#[derive(Debug, Error)] +pub enum WebServerError { + #[error("Axum error: {0}")] + Axum(#[from] axum::BoxError), +} + type PreparedInput = (String, Option, bool); pub(crate) fn prepare_chat_input( @@ -1523,12 +1531,15 @@ pub async fn run( tracing::info!("REQUEST_LOGGER_URL not set, request logging is disabled"); } + #[allow(unused_mut)] // mut is 
needed for conditional compilation + let mut doc = ApiDoc::openapi(); + + // Configure Swagger UI + let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc); + // Create router - let app = Router::new() - .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi())) + let base_routes = Router::new() // Base routes - .route("/", post(compat_generate)) - .route("/info", get(get_model_info)) .route("/generate", post(generate)) .route("/embed", post(embed)) .route("/classify", post(classify)) @@ -1537,16 +1548,28 @@ pub async fn run( .route("/v1/completions", post(completions_v1)) .route("/v1/chat/completions", post(chat_completions_v1)) // AWS Sagemaker route - .route("/invocations", post(compat_generate)) - // Base Health route - .route("/health", get(health)) + .route("/invocations", post(compat_generate)); + + let info_routes = Router::new() // Inference API health route .route("/", get(health)) + // Base Health route + .route("/health", get(health)) + .route("/info", get(get_model_info)) // AWS Sagemaker health route .route("/ping", get(health)) // Prometheus metrics route .route("/metrics", get(metrics)) - .route("/tokenize", post(tokenize)) + .route("/tokenize", post(tokenize)); + + // Combine routes and layers + let mut app = Router::new() + .merge(swagger_ui) + .merge(base_routes) + .merge(info_routes); + + // add layers after routes + app = app .layer(Extension(info)) .layer(Extension(client.clone())) .layer(Extension(request_logger_sender.clone())) @@ -1554,53 +1577,16 @@ pub async fn run( .layer(Extension(compat_return_full_text)) .layer(Extension(infer)) .layer(Extension(prom_handle.clone())) - .layer(opentelemetry_tracing_layer()) + .layer(OtelAxumLayer::default()) .layer(cors_layer) .layer(Extension(cloned_tokenizer)); if ngrok { #[cfg(feature = "ngrok")] { - use ngrok::config::TunnelBuilder; - - let _ = addr; - - let authtoken = - ngrok_authtoken.expect("`ngrok-authtoken` must be set when using ngrok tunneling"); - - let edge = ngrok_edge.expect("`ngrok-edge` must be set when using ngrok tunneling"); - - let tunnel = ngrok::Session::builder() - .authtoken(authtoken) - .connect() - .await - .unwrap() - .labeled_tunnel() - .label("edge", edge); - - let listener = tunnel.listen().await.unwrap(); - - // Run prom metrics and health locally too - tokio::spawn( - axum::Server::bind(&addr) - .serve( - Router::new() - .route("/health", get(health)) - .route("/metrics", get(metrics)) - .layer(Extension(health_ext)) - .layer(Extension(prom_handle)) - .into_make_service(), - ) - //Wait until all requests are finished to shut down - .with_graceful_shutdown(shutdown_signal()), - ); + panic!("ngrok feature is not functional with axum=0.7 and hyper=1, waiting on https://github.com/ngrok/ngrok-rust/pull/137/files to re-enable."); // Run server - axum::Server::builder(listener) - .serve(app.into_make_service()) - //Wait until all requests are finished to shut down - .with_graceful_shutdown(shutdown_signal()) - .await?; } #[cfg(not(feature = "ngrok"))] { @@ -1609,15 +1595,16 @@ pub async fn run( let _ngrok_username = ngrok_username; let _ngrok_password = ngrok_password; - panic!("`lorax-router` was compiled without the `ngrok` feature"); + panic!("`text-generation-router` was compiled without the `ngrok` feature"); } } else { // Run server - axum::Server::bind(&addr) - .serve(app.into_make_service()) - // Wait until all requests are finished to shut down + + let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); + axum::serve(listener, app) 
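// Illustrative sketch, not code from this patch: the axum 0.6 -> 0.7 startup change shown
// here replaces `axum::Server::bind(&addr).serve(app.into_make_service())` (removed in
// axum 0.7) with an explicitly bound `tokio::net::TcpListener` handed to `axum::serve`.
// A minimal standalone version, assuming axum 0.7 and tokio with the "full" feature;
// the route and address are placeholders, not LoRAX's.
use axum::{routing::get, Router};

#[tokio::main]
async fn main() {
    let app = Router::new().route("/health", get(|| async { "ok" }));
    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
    axum::serve(listener, app)
        .with_graceful_shutdown(async { tokio::signal::ctrl_c().await.unwrap() })
        .await
        .unwrap();
}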
.with_graceful_shutdown(shutdown_signal()) - .await?; + .await + .map_err(|err| WebServerError::Axum(Box::new(err)))?; } Ok(()) } From 20cf752b8854ee2ef9ca9fe88e491189afce8b87 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 1 Nov 2024 10:54:43 -0700 Subject: [PATCH 59/76] Client debug mode, fixed / --- clients/python/lorax/client.py | 18 +++++++++++++++++- router/src/server.rs | 4 ++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/clients/python/lorax/client.py b/clients/python/lorax/client.py index 905df9417..a8028834c 100644 --- a/clients/python/lorax/client.py +++ b/clients/python/lorax/client.py @@ -1,4 +1,5 @@ import json +import logging import requests from requests.adapters import HTTPAdapter, Retry @@ -20,7 +21,22 @@ from lorax.errors import parse_error import os -LORAX_DEBUG_MODE = os.getenv("LORAD_DEBUG_MODE", None) is not None +LORAX_DEBUG_MODE = os.getenv("LORAX_DEBUG_MODE", None) is not None +if LORAX_DEBUG_MODE: + # https://stackoverflow.com/a/16630836/1869739 + # These two lines enable debugging at httplib level (requests->urllib3->http.client) + # You will see the REQUEST, including HEADERS and DATA, and RESPONSE with HEADERS but without DATA. + # The only thing missing will be the response.body which is not logged. + import http.client as http_client + http_client.HTTPConnection.debuglevel = 1 + + # You must initialize logging, otherwise you'll not see debug output. + logging.basicConfig() + logging.getLogger().setLevel(logging.DEBUG) + requests_log = logging.getLogger("requests.packages.urllib3") + requests_log.setLevel(logging.DEBUG) + requests_log.propagate = True + class Client: """Client to make calls to a LoRAX instance diff --git a/router/src/server.rs b/router/src/server.rs index afdf2dbe2..a7541e392 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1540,6 +1540,7 @@ pub async fn run( // Create router let base_routes = Router::new() // Base routes + .route("/", post(compat_generate)) .route("/generate", post(generate)) .route("/embed", post(embed)) .route("/classify", post(classify)) @@ -1551,7 +1552,6 @@ pub async fn run( .route("/invocations", post(compat_generate)); let info_routes = Router::new() - // Inference API health route .route("/", get(health)) // Base Health route .route("/health", get(health)) @@ -1595,7 +1595,7 @@ pub async fn run( let _ngrok_username = ngrok_username; let _ngrok_password = ngrok_password; - panic!("`text-generation-router` was compiled without the `ngrok` feature"); + panic!("`lorax-router` was compiled without the `ngrok` feature"); } } else { // Run server From 2868accc407f92158a680cc1767a79526bf138c9 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 1 Nov 2024 11:33:07 -0700 Subject: [PATCH 60/76] Docker test --- .github/workflows/build.yaml | 6 ++---- router/src/batch.rs | 1 - router/src/infer.rs | 6 ------ 3 files changed, 2 insertions(+), 11 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 13b9e96ca..3e7ac15a0 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -5,6 +5,7 @@ on: push: branches: - 'main' + - 'optimizations' tags: - 'v*' @@ -69,10 +70,7 @@ jobs: images: | ghcr.io/predibase/lorax tags: | - type=semver,pattern={{version}} - type=semver,pattern={{major}}.{{minor}} - type=sha,prefix=,suffix=,format=short - type=raw,value=main,enable=${{ github.ref == 'refs/heads/main' }} + type=raw,value=optimizations-5,enable=${{ github.ref == 'refs/heads/optimizations' }} - name: Create a hash from tags env: diff 
--git a/router/src/batch.rs b/router/src/batch.rs index 2363dbdf6..39f8d32c7 100644 --- a/router/src/batch.rs +++ b/router/src/batch.rs @@ -222,7 +222,6 @@ impl BatchEntriesState { // TODO(travis): clone is not ideal, find a way to do this cleanly in place for r in self.batch_requests.clone().into_iter().rev() { let id = r.id; - tracing::info!("!!! drain::remove entry id={id:?}"); let entry = self.batch_entries.remove(&id).unwrap(); let adapter_index = r.adapter_index; let adapter = self.index_to_adapter.get_mut(&adapter_index).unwrap(); diff --git a/router/src/infer.rs b/router/src/infer.rs index b3b608dc3..0a4c4e5cc 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -1323,7 +1323,6 @@ fn filter_send_generations(generations: Vec, entries: &mut IntMap, entries: &mut IntMap "dropped"); }).unwrap_or(true); if stopped { - tracing::info!("!!! filter_send_generations::remove entry id={id:?}"); entries.remove(&id).expect("ID not found in entries. This is a bug."); } }); @@ -1351,7 +1349,6 @@ fn send_responses( ) -> Result>>> { // Return directly if the channel is disconnected if entry.response_tx.is_closed() { - tracing::info!("!!! send_responses::disconnected"); return Ok(true); } @@ -1410,9 +1407,6 @@ fn send_responses( match (&generation.generated_text, iterator.peek()) { (Some(generated_text), None) => { - tracing::info!( - "!!! send_responses::generation_ended id={id:?} generated_text={generated_text:?}" - ); // Generation has ended stopped = true; // Send message From 2131dc137b3f75d6be17efe62d7b597d503dda7f Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 1 Nov 2024 11:38:52 -0700 Subject: [PATCH 61/76] Fixed unused imports --- router/src/infer.rs | 1 - router/src/server.rs | 5 ++--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/router/src/infer.rs b/router/src/infer.rs index 0a4c4e5cc..4bdc4dcf4 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -27,7 +27,6 @@ use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, }; -use std::time::Duration; use thiserror::Error; use tokenizers::Tokenizer; use tokio::sync::mpsc::error::SendError; diff --git a/router/src/server.rs b/router/src/server.rs index a7541e392..011b509e1 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -26,7 +26,6 @@ use axum::response::{IntoResponse, Response}; use axum::routing::{get, post}; use axum::{http, Json, Router}; use axum_tracing_opentelemetry::middleware::OtelAxumLayer; -use axum_tracing_opentelemetry::opentelemetry_tracing_layer; use futures::stream::StreamExt; use futures::Stream; use lorax_client::{ShardInfo, ShardedClient}; @@ -1292,8 +1291,8 @@ pub async fn run( cors_expose_headers: Option, tokenizer_config: HubTokenizerConfig, ngrok: bool, - ngrok_authtoken: Option, - ngrok_edge: Option, + _ngrok_authtoken: Option, + _ngrok_edge: Option, adapter_source: String, eager_prefill: bool, prefix_caching: bool, From b727a94619e418be4d9671fa53849e1ea7ab5a29 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 1 Nov 2024 11:59:32 -0700 Subject: [PATCH 62/76] Revert docker --- .github/workflows/build.yaml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 3e7ac15a0..13b9e96ca 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -5,7 +5,6 @@ on: push: branches: - 'main' - - 'optimizations' tags: - 'v*' @@ -70,7 +69,10 @@ jobs: images: | ghcr.io/predibase/lorax tags: | - type=raw,value=optimizations-5,enable=${{ github.ref == 
'refs/heads/optimizations' }} + type=semver,pattern={{version}} + type=semver,pattern={{major}}.{{minor}} + type=sha,prefix=,suffix=,format=short + type=raw,value=main,enable=${{ github.ref == 'refs/heads/main' }} - name: Create a hash from tags env: From cc17d4771fed3e6807822da336767574b9c2e8f1 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 1 Nov 2024 12:09:58 -0700 Subject: [PATCH 63/76] Add back tracing --- router/src/infer.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/router/src/infer.rs b/router/src/infer.rs index 4bdc4dcf4..9ff3dd906 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -1322,6 +1322,7 @@ fn filter_send_generations(generations: Vec, entries: &mut IntMap, entries: &mut IntMap "dropped"); }).unwrap_or(true); if stopped { + tracing::info!("!!! filter_send_generations::remove entry id={id:?}"); entries.remove(&id).expect("ID not found in entries. This is a bug."); } }); @@ -1348,6 +1350,7 @@ fn send_responses( ) -> Result>>> { // Return directly if the channel is disconnected if entry.response_tx.is_closed() { + tracing::info!("!!! send_responses::disconnected"); return Ok(true); } @@ -1406,6 +1409,9 @@ fn send_responses( match (&generation.generated_text, iterator.peek()) { (Some(generated_text), None) => { + tracing::info!( + "!!! send_responses::generation_ended id={id:?} generated_text={generated_text:?}" + ); // Generation has ended stopped = true; // Send message From 68991ba76149479342af8fa21006e3338955eb73 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 1 Nov 2024 12:44:46 -0700 Subject: [PATCH 64/76] Debug --- router/src/infer.rs | 5 +++-- router/src/scheduler.rs | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/router/src/infer.rs b/router/src/infer.rs index 9ff3dd906..4aec1d2da 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -1348,9 +1348,10 @@ fn send_responses( generation: Generation, entry: &Entry, ) -> Result>>> { - // Return directly if the channel is disconnected + // Return directly if the channel is closed if entry.response_tx.is_closed() { - tracing::info!("!!! 
send_responses::disconnected"); + tracing::error!("Entry response channel closed."); + metrics::increment_counter!("lorax_request_failure", "err" => "dropped"); return Ok(true); } diff --git a/router/src/scheduler.rs b/router/src/scheduler.rs index 9709cc2a8..d557faa89 100644 --- a/router/src/scheduler.rs +++ b/router/src/scheduler.rs @@ -331,6 +331,7 @@ impl AdapterSchedulerState { // Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) if entry.response_tx.is_closed() { + tracing::error!("Entry response channel closed."); metrics::increment_counter!("lorax_request_failure", "err" => "dropped"); continue; } From 5380426c6499028dc63f0f614bee6e6065bac16c Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 1 Nov 2024 12:44:59 -0700 Subject: [PATCH 65/76] Docker test --- .github/workflows/build.yaml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 13b9e96ca..bef72f695 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -5,6 +5,7 @@ on: push: branches: - 'main' + - 'optimizations' tags: - 'v*' @@ -69,10 +70,7 @@ jobs: images: | ghcr.io/predibase/lorax tags: | - type=semver,pattern={{version}} - type=semver,pattern={{major}}.{{minor}} - type=sha,prefix=,suffix=,format=short - type=raw,value=main,enable=${{ github.ref == 'refs/heads/main' }} + type=raw,value=optimizations-6,enable=${{ github.ref == 'refs/heads/optimizations' }} - name: Create a hash from tags env: From 89abd515efb20f1b2f364b10a0506c0840529ba7 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 1 Nov 2024 15:46:41 -0700 Subject: [PATCH 66/76] Debug registration --- router/src/infer.rs | 29 +++++++++++++++++++++++++++-- router/src/queue.rs | 10 +++++++++- router/src/server.rs | 5 +++++ 3 files changed, 41 insertions(+), 3 deletions(-) diff --git a/router/src/infer.rs b/router/src/infer.rs index 4aec1d2da..fad8ddcfb 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -408,9 +408,17 @@ impl Infer { let mut result_start = None; let mut result_queued = None; + tracing::info!("Waiting for response"); + + let mut id = None; + // Iterate on stream while let Some(response) = stream.next().await { match response? { + InferStreamResponse::Register { id_val } => { + id = Some(id_val); + tracing::info!("Register response id={id:?}"); + } // Add prefill tokens InferStreamResponse::Prefill { tokens, @@ -428,9 +436,13 @@ impl Infer { .collect(); } result_prefill_length = tokens_length; + tracing::info!("Prefill response id={id:?}"); } // Push last token - InferStreamResponse::Token(token) => result_tokens.push(token), + InferStreamResponse::Token(token) => { + tracing::info!("Token response id={id:?}"); + result_tokens.push(token) + } // Final message // Set return values InferStreamResponse::End { @@ -439,6 +451,7 @@ impl Infer { start, queued, } => { + tracing::info!("End response id={id:?}"); result_tokens.push(token); result_generated_text = Some(generated_text); result_start = Some(start); @@ -455,6 +468,8 @@ impl Infer { } } + tracing::info!("Finished response id={id:?}"); + // Check that we received a `InferStreamResponse::End` message if let (Some(generated_text), Some(queued), Some(start)) = (result_generated_text, result_queued, result_start) @@ -564,6 +579,9 @@ impl Infer { let mut stream = UnboundedReceiverStream::new(response_rx); while let Some(response) = stream.next().await { match response? { + InferStreamResponse::Register { .. 
} => { + tracing::error!("Received a Register message in embed. This is a bug."); + } // Add prefill tokens InferStreamResponse::Prefill { .. } => { tracing::error!("Received a Prefill message in embed. This is a bug."); @@ -667,6 +685,9 @@ impl Infer { let mut stream = UnboundedReceiverStream::new(response_rx); while let Some(response) = stream.next().await { match response? { + InferStreamResponse::Register { .. } => { + tracing::error!("Received a Register message in classify. This is a bug."); + } // Add prefill tokens InferStreamResponse::Prefill { .. } => { tracing::error!("Received a Prefill message in classify. This is a bug."); @@ -1350,7 +1371,8 @@ fn send_responses( ) -> Result>>> { // Return directly if the channel is closed if entry.response_tx.is_closed() { - tracing::error!("Entry response channel closed."); + let id = generation.request_id; + tracing::error!("Entry id={id:?} response channel closed."); metrics::increment_counter!("lorax_request_failure", "err" => "dropped"); return Ok(true); } @@ -1497,6 +1519,9 @@ fn send_errors(error: ClientError, entries: &mut IntMap) { #[derive(Debug)] pub(crate) enum InferStreamResponse { // Optional first message + Register { + id_val: u64, + }, Prefill { tokens: Option, tokens_length: u32, diff --git a/router/src/queue.rs b/router/src/queue.rs index 78eaaff06..04e13cd9f 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -9,7 +9,7 @@ use std::{ use tokio::{sync::Notify, time::Instant}; use tracing::info_span; -use crate::{adapter::Adapter, batch::Entry}; +use crate::{adapter::Adapter, batch::Entry, infer::InferStreamResponse}; #[derive(Debug, PartialEq)] pub(crate) enum AdapterStatus { @@ -74,6 +74,11 @@ impl QueueState { let queue_span = info_span!(parent: &entry.span, "queued"); entry.temp_span = Some(queue_span); + entry + .response_tx + .send(Ok(InferStreamResponse::Register { id_val: entry_id })) + .unwrap(); + // Push entry in the queue self.entries.push_back((entry_id, entry)); } @@ -214,9 +219,12 @@ impl AdapterQueuesState { // ensure that append completes before sending batcher message let queue = self.queue_map.get_mut(&adapter).unwrap(); + let id = self.next_id; queue.append(self.next_id, entry); self.next_id += 1; + tracing::info!("append entry id={:?} adapter={:?}", id, adapter.index()); + return download; } diff --git a/router/src/server.rs b/router/src/server.rs index 011b509e1..048ddb860 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1062,6 +1062,11 @@ async fn generate_stream_with_callback( match response { Ok(response) => { match response { + InferStreamResponse::Register { + .. 
+ } => { + // Register is ignored + } // Prefill is ignored InferStreamResponse::Prefill { tokens_length, From 3c7b69b3bad5d262c8a9edf8af35be1b0b0ed780 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 1 Nov 2024 15:47:36 -0700 Subject: [PATCH 67/76] Update tag --- .github/workflows/build.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index bef72f695..50836f6ce 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -70,7 +70,7 @@ jobs: images: | ghcr.io/predibase/lorax tags: | - type=raw,value=optimizations-6,enable=${{ github.ref == 'refs/heads/optimizations' }} + type=raw,value=optimizations-7,enable=${{ github.ref == 'refs/heads/optimizations' }} - name: Create a hash from tags env: From d52f530129b17341abfa224c37a0319e423e25ed Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 4 Nov 2024 09:26:30 -0800 Subject: [PATCH 68/76] Don't skip filter --- router/src/infer.rs | 39 +++++++------------ server/lorax_server/models/flash_causal_lm.py | 9 ++--- 2 files changed, 17 insertions(+), 31 deletions(-) diff --git a/router/src/infer.rs b/router/src/infer.rs index fad8ddcfb..af693ecf0 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -408,8 +408,6 @@ impl Infer { let mut result_start = None; let mut result_queued = None; - tracing::info!("Waiting for response"); - let mut id = None; // Iterate on stream @@ -417,7 +415,6 @@ impl Infer { match response? { InferStreamResponse::Register { id_val } => { id = Some(id_val); - tracing::info!("Register response id={id:?}"); } // Add prefill tokens InferStreamResponse::Prefill { @@ -436,13 +433,9 @@ impl Infer { .collect(); } result_prefill_length = tokens_length; - tracing::info!("Prefill response id={id:?}"); } // Push last token - InferStreamResponse::Token(token) => { - tracing::info!("Token response id={id:?}"); - result_tokens.push(token) - } + InferStreamResponse::Token(token) => result_tokens.push(token), // Final message // Set return values InferStreamResponse::End { @@ -451,7 +444,6 @@ impl Infer { start, queued, } => { - tracing::info!("End response id={id:?}"); result_tokens.push(token); result_generated_text = Some(generated_text); result_start = Some(start); @@ -468,8 +460,6 @@ impl Infer { } } - tracing::info!("Finished response id={id:?}"); - // Check that we received a `InferStreamResponse::End` message if let (Some(generated_text), Some(queued), Some(start)) = (result_generated_text, result_queued, result_start) @@ -1124,10 +1114,10 @@ pub(crate) async fn prefill( // Update health generation_health.store(true, Ordering::SeqCst); // Send generated tokens and filter stopped entries - filter_send_generations(generations, entries); + let removed = filter_send_generations(generations, entries); // Filter next batch and remove requests that were stopped - let next_batch = filter_batch(client, next_batch, entries).await; + let next_batch = filter_batch(client, next_batch, entries, removed).await; // TODO(travis) // if let Some(concat_duration) = timings.concat { @@ -1167,10 +1157,10 @@ pub(crate) async fn decode( // Update health generation_health.store(true, Ordering::SeqCst); // Send generated tokens and filter stopped entries - filter_send_generations(generations, entries); + let removed = filter_send_generations(generations, entries); // Filter next batch and remove requests that were stopped - let next_batch = filter_batch(client, next_batch, entries).await; + let next_batch = filter_batch(client, 
next_batch, entries, removed).await; metrics::histogram!("lorax_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode"); metrics::increment_counter!("lorax_batch_inference_success", "method" => "decode"); @@ -1308,11 +1298,12 @@ async fn filter_batch( client: &mut ShardedClient, next_batch: Option, entries: &IntMap, + removed: bool, ) -> Option { let mut batch = next_batch?; - // No need to filter - if batch.size as usize == entries.len() { + // No need to filter is we haven't removed any entries + if !removed { return Some(batch); } @@ -1338,12 +1329,12 @@ async fn filter_batch( /// Send one or multiple `InferStreamResponse` to Infer for all `entries` /// and filter entries #[instrument(skip_all)] -fn filter_send_generations(generations: Vec, entries: &mut IntMap) { +fn filter_send_generations(generations: Vec, entries: &mut IntMap) -> bool { + let mut removed = false; generations.into_iter().for_each(|generation| { let id = generation.request_id; // Get entry // We can `expect` here as the request id should always be in the entries - tracing::info!("!!! filter_send_generations id={id:?}"); let entry = entries .get(&id) .expect("ID not found in entries. This is a bug."); @@ -1358,10 +1349,11 @@ fn filter_send_generations(generations: Vec, entries: &mut IntMap "dropped"); }).unwrap_or(true); if stopped { - tracing::info!("!!! filter_send_generations::remove entry id={id:?}"); entries.remove(&id).expect("ID not found in entries. This is a bug."); + removed = true; } }); + removed } /// Send responses through the `entry` response channel @@ -1370,9 +1362,9 @@ fn send_responses( entry: &Entry, ) -> Result>>> { // Return directly if the channel is closed + let request_id = generation.request_id; if entry.response_tx.is_closed() { - let id = generation.request_id; - tracing::error!("Entry id={id:?} response channel closed."); + tracing::error!("Entry id={request_id:?} response channel closed."); metrics::increment_counter!("lorax_request_failure", "err" => "dropped"); return Ok(true); } @@ -1432,9 +1424,6 @@ fn send_responses( match (&generation.generated_text, iterator.peek()) { (Some(generated_text), None) => { - tracing::info!( - "!!! 
send_responses::generation_ended id={id:?} generated_text={generated_text:?}" - ); // Generation has ended stopped = true; // Send message diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 9390bba09..627512330 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -388,9 +388,6 @@ def from_pb( def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": if len(request_ids) == 0: raise ValueError("Batch must have at least one request") - # We assume that if len(requests) == len(self) then the requests are the same - if len(request_ids) == len(self): - return self device = self.block_tables_tensor.device @@ -1487,15 +1484,15 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> arange_int = arange.to(dtype=torch.int32) new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1) + input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) + cache_lengths_tensor = (batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)).reshape(-1) + # Slots can be discontiguous when prefix caching is enabled, so we need to expand the slot_indices, # then update the slots with the additional indices to ensure we're grabbing the ones that have been # allocated slot_indices = (batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) slots = batch.slots[slot_indices] - input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) - cache_lengths_tensor = (batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)).reshape(-1) - block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B * new_length, -1).contiguous() max_s = max_s + speculative_length From 45c6c535b179e25cd8a245a209b2eddd2353c686 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 4 Nov 2024 09:26:44 -0800 Subject: [PATCH 69/76] Docker test --- .github/workflows/build.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 50836f6ce..2120004db 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -70,7 +70,7 @@ jobs: images: | ghcr.io/predibase/lorax tags: | - type=raw,value=optimizations-7,enable=${{ github.ref == 'refs/heads/optimizations' }} + type=raw,value=optimizations-8,enable=${{ github.ref == 'refs/heads/optimizations' }} - name: Create a hash from tags env: From 3ad4d66e6d493da2367f47316f7bf4b51f9002a7 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 4 Nov 2024 12:10:32 -0800 Subject: [PATCH 70/76] Remove register --- router/src/infer.rs | 14 -------------- router/src/queue.rs | 7 ------- router/src/server.rs | 5 ----- 3 files changed, 26 deletions(-) diff --git a/router/src/infer.rs b/router/src/infer.rs index af693ecf0..0dbec8557 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -408,14 +408,9 @@ impl Infer { let mut result_start = None; let mut result_queued = None; - let mut id = None; - // Iterate on stream while let Some(response) = stream.next().await { match response? { - InferStreamResponse::Register { id_val } => { - id = Some(id_val); - } // Add prefill tokens InferStreamResponse::Prefill { tokens, @@ -569,9 +564,6 @@ impl Infer { let mut stream = UnboundedReceiverStream::new(response_rx); while let Some(response) = stream.next().await { match response? { - InferStreamResponse::Register { .. 
} => { - tracing::error!("Received a Register message in embed. This is a bug."); - } // Add prefill tokens InferStreamResponse::Prefill { .. } => { tracing::error!("Received a Prefill message in embed. This is a bug."); @@ -675,9 +667,6 @@ impl Infer { let mut stream = UnboundedReceiverStream::new(response_rx); while let Some(response) = stream.next().await { match response? { - InferStreamResponse::Register { .. } => { - tracing::error!("Received a Register message in classify. This is a bug."); - } // Add prefill tokens InferStreamResponse::Prefill { .. } => { tracing::error!("Received a Prefill message in classify. This is a bug."); @@ -1508,9 +1497,6 @@ fn send_errors(error: ClientError, entries: &mut IntMap) { #[derive(Debug)] pub(crate) enum InferStreamResponse { // Optional first message - Register { - id_val: u64, - }, Prefill { tokens: Option, tokens_length: u32, diff --git a/router/src/queue.rs b/router/src/queue.rs index 04e13cd9f..2a51d2779 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -74,11 +74,6 @@ impl QueueState { let queue_span = info_span!(parent: &entry.span, "queued"); entry.temp_span = Some(queue_span); - entry - .response_tx - .send(Ok(InferStreamResponse::Register { id_val: entry_id })) - .unwrap(); - // Push entry in the queue self.entries.push_back((entry_id, entry)); } @@ -223,8 +218,6 @@ impl AdapterQueuesState { queue.append(self.next_id, entry); self.next_id += 1; - tracing::info!("append entry id={:?} adapter={:?}", id, adapter.index()); - return download; } diff --git a/router/src/server.rs b/router/src/server.rs index 048ddb860..011b509e1 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1062,11 +1062,6 @@ async fn generate_stream_with_callback( match response { Ok(response) => { match response { - InferStreamResponse::Register { - .. 
- } => { - // Register is ignored - } // Prefill is ignored InferStreamResponse::Prefill { tokens_length, From b45c219d524388b08b76da95340dda63cdfbb274 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 4 Nov 2024 12:10:50 -0800 Subject: [PATCH 71/76] Revert docker --- .github/workflows/build.yaml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 2120004db..13b9e96ca 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -5,7 +5,6 @@ on: push: branches: - 'main' - - 'optimizations' tags: - 'v*' @@ -70,7 +69,10 @@ jobs: images: | ghcr.io/predibase/lorax tags: | - type=raw,value=optimizations-8,enable=${{ github.ref == 'refs/heads/optimizations' }} + type=semver,pattern={{version}} + type=semver,pattern={{major}}.{{minor}} + type=sha,prefix=,suffix=,format=short + type=raw,value=main,enable=${{ github.ref == 'refs/heads/main' }} - name: Create a hash from tags env: From a4a2d5f47c5c6d71d4e24201c56e33356270b145 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 4 Nov 2024 12:14:38 -0800 Subject: [PATCH 72/76] Fixed tests --- server/lorax_server/models/causal_lm.py | 8 +++++--- server/lorax_server/models/flash_qwen2.py | 7 ++++++- server/lorax_server/models/flash_roberta.py | 9 ++++++++- server/tests/utils/test_tokens.py | 7 ++++++- 4 files changed, 25 insertions(+), 6 deletions(-) diff --git a/server/lorax_server/models/causal_lm.py b/server/lorax_server/models/causal_lm.py index 598c6b380..c642ecf03 100644 --- a/server/lorax_server/models/causal_lm.py +++ b/server/lorax_server/models/causal_lm.py @@ -596,9 +596,11 @@ def generate_token(self, batch: CausalLMBatch) -> Tuple[List[Generation], Option # TODO(travis): don't update this if indices haven't changed # Use prefill=True in all cases to force use of SGMV, as the batch is heterogenous adapter_data = AdapterBatchData.from_meta( - batch.adapter_meta, - self.layer_to_adapter_weights, - prefill=True, + meta=batch.adapter_meta, + weights=self.layer_to_adapter_weights, + layer_to_lora_weights={}, + punica_wrapper=None, + prefill=True, prefill_head_indices=None, ) diff --git a/server/lorax_server/models/flash_qwen2.py b/server/lorax_server/models/flash_qwen2.py index b52b2ef8d..10e2d3693 100644 --- a/server/lorax_server/models/flash_qwen2.py +++ b/server/lorax_server/models/flash_qwen2.py @@ -122,7 +122,12 @@ def embed(self, batch) -> torch.Tensor: adapter_meta = batch.adapter_meta prefill = False adapter_data = AdapterBatchData.from_meta( - adapter_meta, self.layer_to_adapter_weights, prefill, batch.prefill_head_indices + meta=adapter_meta, + weights=self.layer_to_adapter_weights, + layer_to_lora_weights={}, + punica_wrapper=None, + prefill=prefill, + prefill_head_indices=batch.prefill_head_indices, ) embedding, _ = self.forward(batch, adapter_data=adapter_data) return embedding.cpu().tolist() diff --git a/server/lorax_server/models/flash_roberta.py b/server/lorax_server/models/flash_roberta.py index 8e6d41d7e..617c24191 100644 --- a/server/lorax_server/models/flash_roberta.py +++ b/server/lorax_server/models/flash_roberta.py @@ -209,7 +209,14 @@ def forward(self, batch: FlashEmbeddingClassificationBatch): @tracer.start_as_current_span("embed") def embed(self, batch: FlashEmbeddingClassificationBatch) -> Embedding: - adapter_data = AdapterBatchData.from_meta(batch.adapter_meta, self.layer_to_adapter_weights, False, None) + adapter_data = AdapterBatchData.from_meta( + meta=batch.adapter_meta, + 
weights=self.layer_to_adapter_weights, + layer_to_lora_weights={}, + punica_wrapper=None, + prefill=False, + prefill_head_indices=None, + ) with self._forward_context(cu_seqlens=batch.cu_seqlens): embedding: torch.Tensor = self.model.forward( diff --git a/server/tests/utils/test_tokens.py b/server/tests/utils/test_tokens.py index 2da49c049..1e4380c23 100644 --- a/server/tests/utils/test_tokens.py +++ b/server/tests/utils/test_tokens.py @@ -80,7 +80,12 @@ def test_deterministic_tokens_temperature_zero(default_causal_lm, default_causal attention_mask = batch.attention_mask[:, : -batch.padding_right_offset] adapter_data = AdapterBatchData.from_meta( - batch.adapter_meta, default_causal_lm.layer_to_adapter_weights, prefill=True, prefill_head_indices=None + meta=batch.adapter_meta, + weights=default_causal_lm.layer_to_adapter_weights, + layer_to_lora_weights={}, + punica_wrapper=None, + prefill=True, + prefill_head_indices=None, ) logits, _ = default_causal_lm.forward( From 4a264bcdf19023957a376fa0b3ea5aacc5130d3b Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 4 Nov 2024 12:14:47 -0800 Subject: [PATCH 73/76] ruff --- server/lorax_server/models/causal_lm.py | 6 +++--- server/lorax_server/models/flash_qwen2.py | 6 +++--- server/lorax_server/models/flash_roberta.py | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/server/lorax_server/models/causal_lm.py b/server/lorax_server/models/causal_lm.py index c642ecf03..7eeaa9f36 100644 --- a/server/lorax_server/models/causal_lm.py +++ b/server/lorax_server/models/causal_lm.py @@ -596,11 +596,11 @@ def generate_token(self, batch: CausalLMBatch) -> Tuple[List[Generation], Option # TODO(travis): don't update this if indices haven't changed # Use prefill=True in all cases to force use of SGMV, as the batch is heterogenous adapter_data = AdapterBatchData.from_meta( - meta=batch.adapter_meta, - weights=self.layer_to_adapter_weights, + meta=batch.adapter_meta, + weights=self.layer_to_adapter_weights, layer_to_lora_weights={}, punica_wrapper=None, - prefill=True, + prefill=True, prefill_head_indices=None, ) diff --git a/server/lorax_server/models/flash_qwen2.py b/server/lorax_server/models/flash_qwen2.py index 10e2d3693..f2c70687d 100644 --- a/server/lorax_server/models/flash_qwen2.py +++ b/server/lorax_server/models/flash_qwen2.py @@ -122,11 +122,11 @@ def embed(self, batch) -> torch.Tensor: adapter_meta = batch.adapter_meta prefill = False adapter_data = AdapterBatchData.from_meta( - meta=adapter_meta, - weights=self.layer_to_adapter_weights, + meta=adapter_meta, + weights=self.layer_to_adapter_weights, layer_to_lora_weights={}, punica_wrapper=None, - prefill=prefill, + prefill=prefill, prefill_head_indices=batch.prefill_head_indices, ) embedding, _ = self.forward(batch, adapter_data=adapter_data) diff --git a/server/lorax_server/models/flash_roberta.py b/server/lorax_server/models/flash_roberta.py index 617c24191..74768336e 100644 --- a/server/lorax_server/models/flash_roberta.py +++ b/server/lorax_server/models/flash_roberta.py @@ -210,11 +210,11 @@ def forward(self, batch: FlashEmbeddingClassificationBatch): @tracer.start_as_current_span("embed") def embed(self, batch: FlashEmbeddingClassificationBatch) -> Embedding: adapter_data = AdapterBatchData.from_meta( - meta=batch.adapter_meta, - weights=self.layer_to_adapter_weights, + meta=batch.adapter_meta, + weights=self.layer_to_adapter_weights, layer_to_lora_weights={}, punica_wrapper=None, - prefill=False, + prefill=False, prefill_head_indices=None, ) From 
e1067a0685d84d43e641d6b4b958a406b93dc71e Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 4 Nov 2024 12:28:30 -0800 Subject: [PATCH 74/76] Fix tests --- router/src/queue.rs | 3 +-- server/lorax_server/utils/layers.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/router/src/queue.rs b/router/src/queue.rs index 2a51d2779..78eaaff06 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -9,7 +9,7 @@ use std::{ use tokio::{sync::Notify, time::Instant}; use tracing::info_span; -use crate::{adapter::Adapter, batch::Entry, infer::InferStreamResponse}; +use crate::{adapter::Adapter, batch::Entry}; #[derive(Debug, PartialEq)] pub(crate) enum AdapterStatus { @@ -214,7 +214,6 @@ impl AdapterQueuesState { // ensure that append completes before sending batcher message let queue = self.queue_map.get_mut(&adapter).unwrap(); - let id = self.next_id; queue.append(self.next_id, entry); self.next_id += 1; diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index 8b43d89b7..d25be2128 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -78,7 +78,7 @@ def forward_layer_type( # Triton Punica kernels if ( - adapter_data.punica_wrapper.enabled + adapter_data.punica_wrapper is not None and adapter_data.punica_wrapper.enabled and input.shape[0] <= adapter_data.punica_wrapper.max_batch_size and can_vectorize ): From 848b4c70eef03cd3888681cfeecdc2526bdfe607 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 4 Nov 2024 12:35:20 -0800 Subject: [PATCH 75/76] Clear cache --- server/lorax_server/server.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index bb378b561..d8a2595cf 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -58,13 +58,10 @@ async def ServiceDiscovery(self, request, context): return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls) async def ClearCache(self, request, context): - try: - if request.HasField("id"): - self.cache.delete(request.id) - else: - self.cache.clear() - except Exception: - exit(1) + if request.HasField("id"): + self.cache.delete(request.id) + else: + self.cache.clear() return generate_pb2.ClearCacheResponse() From 107be9a9aa0da505c5854c6dbae18e0571a10488 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Tue, 5 Nov 2024 10:26:35 -0800 Subject: [PATCH 76/76] Check for key in lora weights --- server/lorax_server/utils/layers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index d25be2128..0feaae609 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -77,8 +77,10 @@ def forward_layer_type( can_vectorize = data is not None and data.can_vectorize(self.process_group) # Triton Punica kernels + key = (layer_type, self.layer_id) if ( adapter_data.punica_wrapper is not None and adapter_data.punica_wrapper.enabled + and key in adapter_data.layer_to_lora_weights and input.shape[0] <= adapter_data.punica_wrapper.max_batch_size and can_vectorize ): @@ -89,7 +91,7 @@ def forward_layer_type( y_offset = None y_slice_size = None - lora_a_weights, lora_b_weights = adapter_data.layer_to_lora_weights[(layer_type, self.layer_id)] + lora_a_weights, lora_b_weights = adapter_data.layer_to_lora_weights[key] adapter_data.punica_wrapper.add_lora( result, input,
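Note on the last three patches (74-76): together they make the Triton Punica fast path strictly opt-in. The layer only dispatches to `punica_wrapper.add_lora` when the wrapper exists and is enabled, the `(layer_type, layer_id)` key actually has LoRA weights registered in `adapter_data.layer_to_lora_weights`, and the batch fits under `max_batch_size`; in every other case it falls back to the regular per-adapter path (which is also why `AdapterBatchData.from_meta` callers in patch 72 can pass `punica_wrapper=None`). Below is a minimal sketch of that combined guard; `PunicaWrapper`, `AdapterBatchData`, and `use_punica_fast_path` here are simplified stand-ins for illustration, not the real lorax_server classes.

```python
from dataclasses import dataclass, field
from typing import Dict, Optional, Tuple


@dataclass
class PunicaWrapper:
    # Stand-in for the real wrapper: only the fields the guard reads.
    enabled: bool = True
    max_batch_size: int = 256


@dataclass
class AdapterBatchData:
    # Hypothetical, trimmed-down view of the batch-level adapter state.
    punica_wrapper: Optional[PunicaWrapper] = None
    # (layer_type, layer_id) -> (lora_a_weights, lora_b_weights)
    layer_to_lora_weights: Dict[Tuple[str, int], Tuple[object, object]] = field(default_factory=dict)


def use_punica_fast_path(
    adapter_data: AdapterBatchData,
    layer_type: str,
    layer_id: int,
    batch_size: int,
    can_vectorize: bool,
) -> bool:
    """Every condition must hold; otherwise fall back to the slower per-adapter path."""
    key = (layer_type, layer_id)
    return (
        adapter_data.punica_wrapper is not None
        and adapter_data.punica_wrapper.enabled
        and key in adapter_data.layer_to_lora_weights
        and batch_size <= adapter_data.punica_wrapper.max_batch_size
        and can_vectorize
    )


if __name__ == "__main__":
    data = AdapterBatchData(punica_wrapper=PunicaWrapper())
    # No weights registered for this (layer_type, layer_id): fall back.
    assert not use_punica_fast_path(data, "q_proj", 0, batch_size=8, can_vectorize=True)
    data.layer_to_lora_weights[("q_proj", 0)] = (object(), object())
    assert use_punica_fast_path(data, "q_proj", 0, batch_size=8, can_vectorize=True)
    # punica_wrapper=None (e.g. the non-flash CausalLM path) must also fall back.
    assert not use_punica_fast_path(AdapterBatchData(), "q_proj", 0, batch_size=8, can_vectorize=True)
```

Keeping the whole condition in one predicate like this makes the fallback behaviour straightforward to exercise in a unit test without constructing real adapter weights or a real Punica wrapper.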