From c0e5798318a6b826572c612ddd4cf44621aa4add Mon Sep 17 00:00:00 2001
From: Jeffrey Tang <810895+jeffreyftang@users.noreply.github.com>
Date: Thu, 21 Nov 2024 20:16:23 -0600
Subject: [PATCH 1/5] fix: Make logprob field optional for response Pydantic validation (#692)

---
 clients/python/lorax/types.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/clients/python/lorax/types.py b/clients/python/lorax/types.py
index 09624f3a5..2fc98b7b7 100644
--- a/clients/python/lorax/types.py
+++ b/clients/python/lorax/types.py
@@ -270,7 +270,7 @@ class Token(BaseModel):
     # Token text
     text: str
     # Logprob
-    logprob: float
+    logprob: Optional[float]
    # Is the token a special token
     # Can be used to ignore tokens when concatenating
     special: bool
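Patch 1 relaxes the Pydantic annotation so that responses whose tokens carry a null `logprob` still validate. A minimal sketch, not part of the patch, trimmed to the fields visible in the hunk; passing `logprob=None` explicitly keeps the example valid under both Pydantic v1 and v2:

```python
from typing import Optional

from pydantic import BaseModel


class Token(BaseModel):
    text: str
    logprob: Optional[float]  # was `float`; a null logprob no longer fails validation
    special: bool


# Before this change, a token with a null logprob raised a ValidationError on parsing.
tok = Token(text="hello", logprob=None, special=False)
print(tok.logprob)  # None
```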
From 02ff9f34675f3cc58f6b97e9e898caf396f27704 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Mon, 25 Nov 2024 16:50:19 -0600
Subject: [PATCH 2/5] Bump tornado from 6.4.1 to 6.4.2 in /server (#695)

Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jeffrey Tang
---
 .github/workflows/server_tests.yaml |  1 +
 server/poetry.lock                  | 26 +++++++++++++-------------
 2 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/.github/workflows/server_tests.yaml b/.github/workflows/server_tests.yaml
index 744dbc3c3..25dd2883b 100644
--- a/.github/workflows/server_tests.yaml
+++ b/.github/workflows/server_tests.yaml
@@ -1,6 +1,7 @@
 name: Server Tests
 
 on:
+  workflow_dispatch:
   pull_request:
     paths:
       - ".github/workflows/server_tests.yaml"
diff --git a/server/poetry.lock b/server/poetry.lock
index 1b68aff6e..60b0c0640 100644
--- a/server/poetry.lock
+++ b/server/poetry.lock
@@ -3672,22 +3672,22 @@ optree = ["optree (>=0.11.0)"]
 
 [[package]]
 name = "tornado"
-version = "6.4.1"
+version = "6.4.2"
 description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed."
-optional = true
+optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "tornado-6.4.1-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:163b0aafc8e23d8cdc3c9dfb24c5368af84a81e3364745ccb4427669bf84aec8"},
-    {file = "tornado-6.4.1-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:6d5ce3437e18a2b66fbadb183c1d3364fb03f2be71299e7d10dbeeb69f4b2a14"},
-    {file = "tornado-6.4.1-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2e20b9113cd7293f164dc46fffb13535266e713cdb87bd2d15ddb336e96cfc4"},
-    {file = "tornado-6.4.1-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8ae50a504a740365267b2a8d1a90c9fbc86b780a39170feca9bcc1787ff80842"},
-    {file = "tornado-6.4.1-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:613bf4ddf5c7a95509218b149b555621497a6cc0d46ac341b30bd9ec19eac7f3"},
-    {file = "tornado-6.4.1-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:25486eb223babe3eed4b8aecbac33b37e3dd6d776bc730ca14e1bf93888b979f"},
-    {file = "tornado-6.4.1-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:454db8a7ecfcf2ff6042dde58404164d969b6f5d58b926da15e6b23817950fc4"},
-    {file = "tornado-6.4.1-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a02a08cc7a9314b006f653ce40483b9b3c12cda222d6a46d4ac63bb6c9057698"},
-    {file = "tornado-6.4.1-cp38-abi3-win32.whl", hash = "sha256:d9a566c40b89757c9aa8e6f032bcdb8ca8795d7c1a9762910c722b1635c9de4d"},
-    {file = "tornado-6.4.1-cp38-abi3-win_amd64.whl", hash = "sha256:b24b8982ed444378d7f21d563f4180a2de31ced9d8d84443907a0a64da2072e7"},
-    {file = "tornado-6.4.1.tar.gz", hash = "sha256:92d3ab53183d8c50f8204a51e6f91d18a15d5ef261e84d452800d4ff6fc504e9"},
+    {file = "tornado-6.4.2-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:e828cce1123e9e44ae2a50a9de3055497ab1d0aeb440c5ac23064d9e44880da1"},
+    {file = "tornado-6.4.2-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:072ce12ada169c5b00b7d92a99ba089447ccc993ea2143c9ede887e0937aa803"},
+    {file = "tornado-6.4.2-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a017d239bd1bb0919f72af256a970624241f070496635784d9bf0db640d3fec"},
+    {file = "tornado-6.4.2-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c36e62ce8f63409301537222faffcef7dfc5284f27eec227389f2ad11b09d946"},
+    {file = "tornado-6.4.2-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bca9eb02196e789c9cb5c3c7c0f04fb447dc2adffd95265b2c7223a8a615ccbf"},
+    {file = "tornado-6.4.2-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:304463bd0772442ff4d0f5149c6f1c2135a1fae045adf070821c6cdc76980634"},
+    {file = "tornado-6.4.2-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:c82c46813ba483a385ab2a99caeaedf92585a1f90defb5693351fa7e4ea0bf73"},
+    {file = "tornado-6.4.2-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:932d195ca9015956fa502c6b56af9eb06106140d844a335590c1ec7f5277d10c"},
+    {file = "tornado-6.4.2-cp38-abi3-win32.whl", hash = "sha256:2876cef82e6c5978fde1e0d5b1f919d756968d5b4282418f3146b79b58556482"},
+    {file = "tornado-6.4.2-cp38-abi3-win_amd64.whl", hash = "sha256:908b71bf3ff37d81073356a5fadcc660eb10c1476ee6e2725588626ce7e5ca38"},
+    {file = "tornado-6.4.2.tar.gz", hash = "sha256:92bad5b4746e9879fd7bf1eb21dce4e3fc5128d71601f80005afa39237ad620b"},
 ]
 
 [[package]]
From 69e53062d40f161cda9d7a3c21205d4b9ffd7a4e Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Tue, 26 Nov 2024 11:41:30 -0800
Subject: [PATCH 3/5] Handle case where adapter exists but is still being trained (#691)

---
 server/lorax_server/models/flash_causal_lm.py |  3 +-
 server/lorax_server/models/mllama.py          | 12 ++++----
 server/lorax_server/utils/sources/__init__.py | 28 +++++++++++++++++--
 3 files changed, 34 insertions(+), 9 deletions(-)

diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py
index e7e295fd7..8b45f76fc 100644
--- a/server/lorax_server/models/flash_causal_lm.py
+++ b/server/lorax_server/models/flash_causal_lm.py
@@ -1741,8 +1741,7 @@ def generate_token(
                 # Only save tokens if we are done prefilling for this request
                 batch.all_input_ids_tensor[
                     i,
-                    batch.cache_lengths_tensor[i]
-                    + batch.input_lengths[i] : batch.cache_lengths_tensor[i]
+                    batch.cache_lengths_tensor[i] + batch.input_lengths[i] : batch.cache_lengths_tensor[i]
                     + batch.input_lengths[i]
                     + accepted_ids[i],
                 ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
diff --git a/server/lorax_server/models/mllama.py b/server/lorax_server/models/mllama.py
index a80d5ce22..2d2fd77ff 100644
--- a/server/lorax_server/models/mllama.py
+++ b/server/lorax_server/models/mllama.py
@@ -186,9 +186,11 @@ def supports_adapter_loading(self) -> bool:
 
     @property
     def adapter_layers(self) -> List[str]:
-        return TEXT_ADAPTER_LAYERS \
-            + [f'VISION_GLOBAL_TRANSFORMER_{layer_type}' for layer_type in VISION_ADAPTER_LAYERS] \
-            + [f'VISION_TRANSFORMER_{layer_type}' for layer_type in VISION_ADAPTER_LAYERS]
+        return (
+            TEXT_ADAPTER_LAYERS
+            + [f"VISION_GLOBAL_TRANSFORMER_{layer_type}" for layer_type in VISION_ADAPTER_LAYERS]
+            + [f"VISION_TRANSFORMER_{layer_type}" for layer_type in VISION_ADAPTER_LAYERS]
+        )
 
     @property
     def default_traced_adapter_layers(self) -> List[str]:
@@ -197,14 +199,14 @@ def default_traced_adapter_layers(self) -> List[str]:
     def get_num_layers_for_type(self, layer_type: str) -> int:
         if "LM_HEAD" in layer_type:
             return 1
-        if 'VISION_GLOBAL_TRANSFORMER_' in layer_type:
+        if "VISION_GLOBAL_TRANSFORMER_" in layer_type:
             return len(self.model.vision_model.global_transformer.layers)
         if "VISION_TRANSFORMER_" in layer_type:
             return len(self.model.vision_model.transformer.layers)
         return [
             layer_id
             for layer_id, layer in enumerate(self.model.text_model.model.layers)
-                if not isinstance(layer, FlashLlamaCrossLayer)
+            if not isinstance(layer, FlashLlamaCrossLayer)
         ]
 
     def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
diff --git a/server/lorax_server/utils/sources/__init__.py b/server/lorax_server/utils/sources/__init__.py
index 47cc566a5..f9c209df5 100644
--- a/server/lorax_server/utils/sources/__init__.py
+++ b/server/lorax_server/utils/sources/__init__.py
@@ -3,6 +3,7 @@
 from typing import Optional
 
 import requests
+from loguru import logger
 
 from .hub import (
     HubModelSource,
@@ -24,6 +25,18 @@
 PREDIBASE_ADAPTER_VERSION_URL_ENDPOINT = "/v2/repos/{}/version/{}"
 PREDIBASE_GATEWAY_ENDPOINT = os.getenv("PREDIBASE_GATEWAY_ENDPOINT", "https://api.app.predibase.com")
 
+# Predibase status codes
+PENDING = "pending"
+QUEUED = "queued"
+TRAINING = "training"
+STOPPING = "stopping"
+STOPPED = "stopped"
+CANCELED = "canceled"
+COMPLETED = "completed"
+ERRORED = "errored"
+STATUSES = {PENDING, QUEUED, TRAINING, STOPPING, STOPPED, CANCELED, COMPLETED, ERRORED}
+FINAL_STATUSES = {COMPLETED, ERRORED, CANCELED, STOPPED}
+
 
 @lru_cache(maxsize=256)
 def map_pbase_model_id_to_s3(model_id: str, api_token: str) -> str:
@@ -60,9 +73,20 @@ def fetch_legacy_url():
             # Not found in new path, fall back to legacy endpoint.
             return fetch_legacy_url()
 
-        path = resp.json().get("adapterPath", None)
-        if path is None:
+        resp_json = resp.json()
+
+        status = resp_json.get("status")
+        if status not in STATUSES:
+            # Status is unknown to us, so skip status validation
+            logger.warning(f"Unknown status {status} for adapter {model_id}")
+        elif status not in FINAL_STATUSES:
+            # Status is known to us, but not a final status, so raise a user error
+            raise RuntimeError(f"Adapter {model_id} has not completed training (status: {status})")
+
+        path = resp_json.get("adapterPath")
+        if not path:
             raise RuntimeError(f"Adapter {model_id} is not yet available")
+
         return path
     else:
         # Use legacy path only since new endpoint requires both name and version number.
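Patch 3 gates adapter resolution on the Predibase training status: unknown statuses are logged and allowed through, known non-final statuses raise, and a missing `adapterPath` still raises. A self-contained sketch of that logic follows; the function name and response shape are assumptions for illustration, and `print` stands in for loguru's `logger.warning`:

```python
# Illustrative sketch of the status-gating logic added in patch 3 (names are assumptions).
STATUSES = {"pending", "queued", "training", "stopping", "stopped", "canceled", "completed", "errored"}
FINAL_STATUSES = {"completed", "errored", "canceled", "stopped"}


def resolve_adapter_path(resp_json: dict, model_id: str) -> str:
    status = resp_json.get("status")
    if status not in STATUSES:
        # Unknown status: warn and skip status validation
        print(f"Unknown status {status} for adapter {model_id}")
    elif status not in FINAL_STATUSES:
        # Known but non-final status: the adapter is still training
        raise RuntimeError(f"Adapter {model_id} has not completed training (status: {status})")

    path = resp_json.get("adapterPath")
    if not path:
        raise RuntimeError(f"Adapter {model_id} is not yet available")
    return path


assert resolve_adapter_path({"status": "completed", "adapterPath": "s3://bucket/adapter"}, "demo/1") == "s3://bucket/adapter"
```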
From 8f51d0ea0e8b32f228ece7b31e0e83909b99413d Mon Sep 17 00:00:00 2001
From: Ajinkya Tejankar
Date: Tue, 26 Nov 2024 22:17:57 -0800
Subject: [PATCH 4/5] Add support for returning alternative_tokens with return_k_alternatives (#697)

---
 server/lorax_server/models/flash_causal_lm.py | 61 ++++++++++---------
 1 file changed, 31 insertions(+), 30 deletions(-)

diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py
index 8b45f76fc..2e3990328 100644
--- a/server/lorax_server/models/flash_causal_lm.py
+++ b/server/lorax_server/models/flash_causal_lm.py
@@ -24,6 +24,7 @@
 )
 from lorax_server.models.model import Model
 from lorax_server.models.types import (
+    AlternativeTokens,
     Batch,
     GeneratedText,
     Generation,
@@ -1877,36 +1878,6 @@ def generate_token(
         ) in enumerate(iterator):
             all_alternative_tokens = [] if request.parameters.return_k_alternatives > 0 else None
 
-            # TODO(travis): return_k_alternatives
-            # if request.parameters.return_k_alternatives > 0:
-            #     # Limit the number of alternatives to the vocabulary size
-            #     num_alternatives = min(
-            #         request.parameters.return_k_alternatives,
-            #         len(alternative_token_ids[token_idx]),
-            #     )
-
-            #     # Select top-k logprobs
-            #     request_alternative_token_ids = alternative_token_ids[token_idx][:num_alternatives]
-            #     request_alternative_token_logprobs = alternative_token_logprobs[token_idx][:num_alternatives]
-
-            #     # Decode tokens
-            #     request_alternative_token_texts = []
-            #     for alternative_token_id in request_alternative_token_ids:
-            #         all_input_ids.append(alternative_token_id)
-            #         alternative_token_text, _, _ = self.decode_token(
-            #             all_input_ids,
-            #             prefix_offset,
-            #             read_offset,
-            #         )
-            #         request_alternative_token_texts.append(alternative_token_text)
-            #         all_input_ids.pop()
-            #     alternative_tokens = AlternativeTokens(
-            #         request_alternative_token_ids,
-            #         request_alternative_token_logprobs,
-            #         request_alternative_token_texts,
-            #     )
-            #     all_alternative_tokens.append(alternative_tokens)
-
             # Compute logprobs first as, even though we might skip the token,
             # it can still be required to compute the logprobs
             # modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need
@@ -1985,6 +1956,36 @@ def generate_token(
                 )
                 next_token_texts.append(next_token_text)
 
+                if request.parameters.return_k_alternatives > 0:
+                    # Limit the number of alternatives to the vocabulary size
+                    num_alternatives = min(
+                        request.parameters.return_k_alternatives,
+                        len(alternative_token_ids[j]),
+                    )
+
+                    # Select top-k logprobs
+                    request_alternative_token_ids = alternative_token_ids[j][:num_alternatives]
+                    request_alternative_token_logprobs = alternative_token_logprobs[j][:num_alternatives]
+
+                    # Decode tokens
+                    request_alternative_token_texts = []
+                    for alternative_token_id in request_alternative_token_ids:
+                        all_input_ids.append(alternative_token_id)
+                        alternative_token_text, _, _ = self.decode_token(
+                            all_input_ids,
+                            prefix_offset,
+                            read_offset,
+                        )
+                        request_alternative_token_texts.append(alternative_token_text)
+                        all_input_ids.pop()
+                    alternative_tokens = AlternativeTokens(
+                        request_alternative_token_ids,
+                        request_alternative_token_logprobs,
+                        request_alternative_token_texts,
+                    )
+                    all_alternative_tokens.append(alternative_tokens)
+
+
                 stop, reason = stopping_criteria(
                     next_token_id,
                     next_token_text,
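Patch 4 re-enables the previously commented-out alternatives path so that, when `return_k_alternatives > 0`, each generated position also reports its top-k alternative token ids, logprobs, and decoded texts. A minimal sketch of the top-k selection step only; the tensor and helper below are assumptions for illustration, not the server's actual decode path:

```python
# Illustrative only: pick the k most likely alternative tokens from a logprob vector.
import torch


def top_k_alternatives(vocab_logprobs: torch.Tensor, k: int):
    # Never request more alternatives than the vocabulary holds.
    k = min(k, vocab_logprobs.shape[-1])
    logprobs, token_ids = torch.topk(vocab_logprobs, k)
    return token_ids.tolist(), logprobs.tolist()


vocab_logprobs = torch.log_softmax(torch.randn(32_000), dim=-1)
ids, lps = top_k_alternatives(vocab_logprobs, k=3)
print(list(zip(ids, lps)))  # e.g. [(1042, -2.31), (17, -2.95), (873, -3.10)]
```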
From da95224890789210aec105ac0beaba766d1f18cc Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Wed, 27 Nov 2024 10:38:46 -0800
Subject: [PATCH 5/5] Speed up best_of when using prefix caching (#698)

---
 clients/python/lorax/client.py              |  3 +++
 docs/guides/contributing/development_env.md |  2 +-
 router/src/infer.rs                         | 12 ++++++++++++
 router/src/server.rs                        |  4 +++-
 4 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/clients/python/lorax/client.py b/clients/python/lorax/client.py
index a8028834c..aa58e800d 100644
--- a/clients/python/lorax/client.py
+++ b/clients/python/lorax/client.py
@@ -293,6 +293,9 @@ def generate(
 
         if resp.status_code != 200:
             raise parse_error(resp.status_code, payload, resp.headers if LORAX_DEBUG_MODE else None)
+
+        if LORAX_DEBUG_MODE:
+            print(resp.headers)
 
         return Response(**payload[0])
 
diff --git a/docs/guides/contributing/development_env.md b/docs/guides/contributing/development_env.md
index 5f35eebb1..84c55ce91 100644
--- a/docs/guides/contributing/development_env.md
+++ b/docs/guides/contributing/development_env.md
@@ -47,7 +47,7 @@ We'll be working out of three different terminals during development, each servi
 Install development dependencies:
 
 ```shell
-DEBIAN_FRONTEND=noninteractive apt install pkg-config rsync tmux rust-gdb git -y && \
+apt update && DEBIAN_FRONTEND=noninteractive apt install pkg-config rsync tmux rust-gdb git -y && \
 PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
 curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
 unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
diff --git a/router/src/infer.rs b/router/src/infer.rs
index 52f7dc2a0..101360ba1 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -160,6 +160,8 @@ pub struct Infer {
     limit_concurrent_requests: Arc,
     /// tokenizer for NER processing
     tokenizer: Option>,
+    /// Whether prefix caching is enabled
+    pub prefix_caching: bool,
 }
 
 impl Infer {
@@ -268,6 +270,7 @@ impl Infer {
             chat_template,
             limit_concurrent_requests: semaphore,
             tokenizer: tokenizer,
+            prefix_caching,
         }
     }
 
@@ -861,10 +864,19 @@ impl Infer {
         &self,
         request: GenerateRequest,
         best_of: usize,
+        prefix_caching: bool,
     ) -> Result<(InferResponse, Vec), InferError> {
         // validate best_of parameter separately
         let best_of = self.validation.validate_best_of(best_of)?;
 
+        // If prefix caching is enabled, first generate a single token to cache the prefix, then generate
+        // subsequent responses.
+        if prefix_caching {
+            let mut prefix_request = request.clone();
+            prefix_request.parameters.max_new_tokens = Some(1);
+            self.generate(prefix_request).await?;
+        }
+
         // create multiple generate requests
         let mut infer_responses: Vec =
             try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?;
diff --git a/router/src/server.rs b/router/src/server.rs
index 7717d5335..d33412aaf 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -639,7 +639,9 @@ async fn generate(
     // Inference
     let (response, best_of_responses) = match req.0.parameters.best_of {
         Some(best_of) if best_of > 1 => {
-            let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?;
+            let (response, best_of_responses) = infer
+                .generate_best_of(req.0, best_of, infer.prefix_caching)
+                .await?;
             (response, Some(best_of_responses))
         }
         _ => (infer.generate(req.0).await?, None),
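Patch 5 is transparent to clients: when the router runs with prefix caching, a `best_of` request first issues a single-token generation to populate the prefix cache, then fans out the remaining candidates, which now share the cached prompt. A usage sketch with the Python client; the server URL and generation parameters are placeholder assumptions:

```python
from lorax import Client

# Assumes a LoRAX deployment with prefix caching enabled is listening locally.
client = Client("http://127.0.0.1:8080")

# best_of > 1 triggers the warm-then-fan-out path on the router; sampling is
# required so the candidate generations can differ.
response = client.generate(
    "Explain prefix caching in one sentence.",
    best_of=2,
    do_sample=True,
    max_new_tokens=64,
)
print(response.generated_text)
```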