
Commit

Merge branch 'main' into support-cross-attn-mllama
ajtejankar committed Nov 27, 2024
2 parents dab0740 + da95224 commit b3d16c7
Showing 10 changed files with 98 additions and 54 deletions.
1 change: 1 addition & 0 deletions .github/workflows/server_tests.yaml
@@ -1,6 +1,7 @@
name: Server Tests

on:
workflow_dispatch:
pull_request:
paths:
- ".github/workflows/server_tests.yaml"
3 changes: 3 additions & 0 deletions clients/python/lorax/client.py
@@ -293,6 +293,9 @@ def generate(

if resp.status_code != 200:
raise parse_error(resp.status_code, payload, resp.headers if LORAX_DEBUG_MODE else None)

if LORAX_DEBUG_MODE:
print(resp.headers)

return Response(**payload[0])

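For reference, the new branch above only prints the response headers when debug mode is enabled in the Python client. A minimal sketch of how this might be exercised, assuming LORAX_DEBUG_MODE is read from the environment when the lorax module is imported (the endpoint URL and prompt are placeholders):

```python
import os

# Assumption: the client picks up LORAX_DEBUG_MODE from the environment at
# import time, so set it before importing lorax.
os.environ["LORAX_DEBUG_MODE"] = "1"

from lorax import Client

client = Client("http://127.0.0.1:8080")  # placeholder endpoint
response = client.generate("What is deep learning?", max_new_tokens=32)
print(response.generated_text)
# With the patch above, the raw HTTP response headers are also printed,
# which can help when debugging routing or server-side behavior.
```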
2 changes: 1 addition & 1 deletion clients/python/lorax/types.py
@@ -270,7 +270,7 @@ class Token(BaseModel):
# Token text
text: str
# Logprob
- logprob: float
+ logprob: Optional[float]
# Is the token a special token
# Can be used to ignore tokens when concatenating
special: bool
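Since Token.logprob can now be None, downstream code that aggregates logprobs should guard against missing values (for example, positions where the server reports no logprob). A small illustrative helper, not part of the client API:

```python
from typing import List, Optional

def mean_logprob(logprobs: List[Optional[float]]) -> Optional[float]:
    # Ignore tokens whose logprob is None instead of failing on float math.
    known = [lp for lp in logprobs if lp is not None]
    return sum(known) / len(known) if known else None
```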
2 changes: 1 addition & 1 deletion docs/guides/contributing/development_env.md
@@ -47,7 +47,7 @@ We'll be working out of three different terminals during development, each servi
Install development dependencies:

```shell
- DEBIAN_FRONTEND=noninteractive apt install pkg-config rsync tmux rust-gdb git -y && \
+ apt update && DEBIAN_FRONTEND=noninteractive apt install pkg-config rsync tmux rust-gdb git -y && \
PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
12 changes: 12 additions & 0 deletions router/src/infer.rs
@@ -160,6 +160,8 @@ pub struct Infer {
limit_concurrent_requests: Arc<Semaphore>,
/// tokenizer for NER processing
tokenizer: Option<Arc<Tokenizer>>,
/// Whether prefix caching is enabled
pub prefix_caching: bool,
}

impl Infer {
@@ -268,6 +270,7 @@ impl Infer {
chat_template,
limit_concurrent_requests: semaphore,
tokenizer: tokenizer,
prefix_caching,
}
}

@@ -861,10 +864,19 @@ impl Infer {
&self,
request: GenerateRequest,
best_of: usize,
prefix_caching: bool,
) -> Result<(InferResponse, Vec<InferResponse>), InferError> {
// validate best_of parameter separately
let best_of = self.validation.validate_best_of(best_of)?;

// If prefix caching is enabled, first generate a single token to cache the prefix, then generate
// subsequent responses.
if prefix_caching {
let mut prefix_request = request.clone();
prefix_request.parameters.max_new_tokens = Some(1);
self.generate(prefix_request).await?;
}

// create multiple generate requests
let mut infer_responses: Vec<InferResponse> =
try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?;
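The change above warms the prefix cache with a single-token generation before fanning out the best_of candidates, so the shared prompt prefix is only computed once. A rough client-side Python sketch of the same idea; the request/response shapes and the generate callable are assumptions for illustration, not the router's actual API:

```python
import asyncio
import copy

async def generate_best_of(generate, request: dict, best_of: int, prefix_caching: bool):
    """generate: an async callable that sends one generation request.

    Mirrors the router logic: if prefix caching is enabled, issue a 1-token
    request first so the prompt prefix is cached, then run the best_of
    candidates, all of which can reuse the cached prefix.
    """
    if prefix_caching:
        warmup = copy.deepcopy(request)
        warmup["parameters"]["max_new_tokens"] = 1
        await generate(warmup)  # result discarded; only the cache side effect matters

    # Fan out best_of identical requests; the caller picks the best response.
    return await asyncio.gather(*(generate(copy.deepcopy(request)) for _ in range(best_of)))
```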
4 changes: 3 additions & 1 deletion router/src/server.rs
@@ -639,7 +639,9 @@ async fn generate(
// Inference
let (response, best_of_responses) = match req.0.parameters.best_of {
Some(best_of) if best_of > 1 => {
- let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?;
+ let (response, best_of_responses) = infer
+ .generate_best_of(req.0, best_of, infer.prefix_caching)
+ .await?;
(response, Some(best_of_responses))
}
_ => (infer.generate(req.0).await?, None),
64 changes: 32 additions & 32 deletions server/lorax_server/models/flash_causal_lm.py
@@ -24,6 +24,7 @@
)
from lorax_server.models.model import Model
from lorax_server.models.types import (
AlternativeTokens,
Batch,
GeneratedText,
Generation,
@@ -1747,8 +1748,7 @@ def generate_token(
# Only save tokens if we are done prefilling for this request
batch.all_input_ids_tensor[
i,
- batch.cache_lengths_tensor[i]
- + batch.input_lengths[i] : batch.cache_lengths_tensor[i]
+ batch.cache_lengths_tensor[i] + batch.input_lengths[i] : batch.cache_lengths_tensor[i]
+ batch.input_lengths[i]
+ accepted_ids[i],
] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
@@ -1884,36 +1884,6 @@ def generate_token(
) in enumerate(iterator):
all_alternative_tokens = [] if request.parameters.return_k_alternatives > 0 else None

# TODO(travis): return_k_alternatives
# if request.parameters.return_k_alternatives > 0:
# # Limit the number of alternatives to the vocabulary size
# num_alternatives = min(
# request.parameters.return_k_alternatives,
# len(alternative_token_ids[token_idx]),
# )

# # Select top-k logprobs
# request_alternative_token_ids = alternative_token_ids[token_idx][:num_alternatives]
# request_alternative_token_logprobs = alternative_token_logprobs[token_idx][:num_alternatives]

# # Decode tokens
# request_alternative_token_texts = []
# for alternative_token_id in request_alternative_token_ids:
# all_input_ids.append(alternative_token_id)
# alternative_token_text, _, _ = self.decode_token(
# all_input_ids,
# prefix_offset,
# read_offset,
# )
# request_alternative_token_texts.append(alternative_token_text)
# all_input_ids.pop()
# alternative_tokens = AlternativeTokens(
# request_alternative_token_ids,
# request_alternative_token_logprobs,
# request_alternative_token_texts,
# )
# all_alternative_tokens.append(alternative_tokens)

# Compute logprobs first as, even though we might skip the token,
# it can still be required to compute the logprobs
# modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need
@@ -1992,6 +1962,36 @@ def generate_token(
)
next_token_texts.append(next_token_text)

if request.parameters.return_k_alternatives > 0:
# Limit the number of alternatives to the vocabulary size
num_alternatives = min(
request.parameters.return_k_alternatives,
len(alternative_token_ids[j]),
)

# Select top-k logprobs
request_alternative_token_ids = alternative_token_ids[j][:num_alternatives]
request_alternative_token_logprobs = alternative_token_logprobs[j][:num_alternatives]

# Decode tokens
request_alternative_token_texts = []
for alternative_token_id in request_alternative_token_ids:
all_input_ids.append(alternative_token_id)
alternative_token_text, _, _ = self.decode_token(
all_input_ids,
prefix_offset,
read_offset,
)
request_alternative_token_texts.append(alternative_token_text)
all_input_ids.pop()
alternative_tokens = AlternativeTokens(
request_alternative_token_ids,
request_alternative_token_logprobs,
request_alternative_token_texts,
)
all_alternative_tokens.append(alternative_tokens)


stop, reason = stopping_criteria(
next_token_id,
next_token_text,
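The restored block above selects up to return_k_alternatives candidates per generated token and decodes each of them. The selection step itself is a top-k over the log-probabilities; a standalone PyTorch sketch of that idea (tensor and variable names here are illustrative, not the batch fields used in the server code):

```python
import torch

def top_k_alternatives(logprobs: torch.Tensor, k: int):
    """logprobs: [vocab_size] log-probabilities for a single position.

    Returns the ids and logprobs of the k most likely tokens, clamping k to
    the vocabulary size in the same way the server code does.
    """
    k = min(k, logprobs.shape[-1])
    values, indices = torch.topk(logprobs, k, dim=-1)
    return indices.tolist(), values.tolist()
```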
10 changes: 6 additions & 4 deletions server/lorax_server/models/mllama.py
@@ -186,9 +186,11 @@ def supports_adapter_loading(self) -> bool:

@property
def adapter_layers(self) -> List[str]:
- return TEXT_ADAPTER_LAYERS \
- + [f'VISION_GLOBAL_TRANSFORMER_{layer_type}' for layer_type in VISION_ADAPTER_LAYERS] \
- + [f'VISION_TRANSFORMER_{layer_type}' for layer_type in VISION_ADAPTER_LAYERS]
+ return (
+ TEXT_ADAPTER_LAYERS
+ + [f"VISION_GLOBAL_TRANSFORMER_{layer_type}" for layer_type in VISION_ADAPTER_LAYERS]
+ + [f"VISION_TRANSFORMER_{layer_type}" for layer_type in VISION_ADAPTER_LAYERS]
+ )

@property
def default_traced_adapter_layers(self) -> List[str]:
@@ -197,7 +199,7 @@ def default_traced_adapter_layers(self) -> List[str]:
def get_num_layers_for_type(self, layer_type: str) -> int:
if "LM_HEAD" in layer_type:
return 1
- if 'VISION_GLOBAL_TRANSFORMER_' in layer_type:
+ if "VISION_GLOBAL_TRANSFORMER_" in layer_type:
return len(self.model.vision_model.global_transformer.layers)
if "VISION_TRANSFORMER_" in layer_type:
return len(self.model.vision_model.transformer.layers)
28 changes: 26 additions & 2 deletions server/lorax_server/utils/sources/__init__.py
@@ -3,6 +3,7 @@
from typing import Optional

import requests
from loguru import logger

from .hub import (
HubModelSource,
@@ -24,6 +25,18 @@
PREDIBASE_ADAPTER_VERSION_URL_ENDPOINT = "/v2/repos/{}/version/{}"
PREDIBASE_GATEWAY_ENDPOINT = os.getenv("PREDIBASE_GATEWAY_ENDPOINT", "https://api.app.predibase.com")

# Predibase status codes
PENDING = "pending"
QUEUED = "queued"
TRAINING = "training"
STOPPING = "stopping"
STOPPED = "stopped"
CANCELED = "canceled"
COMPLETED = "completed"
ERRORED = "errored"
STATUSES = {PENDING, QUEUED, TRAINING, STOPPING, STOPPED, CANCELED, COMPLETED, ERRORED}
FINAL_STATUSES = {COMPLETED, ERRORED, CANCELED, STOPPED}


@lru_cache(maxsize=256)
def map_pbase_model_id_to_s3(model_id: str, api_token: str) -> str:
@@ -60,9 +73,20 @@ def fetch_legacy_url():
# Not found in new path, fall back to legacy endpoint.
return fetch_legacy_url()

- path = resp.json().get("adapterPath", None)
- if path is None:
resp_json = resp.json()

status = resp_json.get("status")
if status not in STATUSES:
# Status is unknown to us, so skip status validation
logger.warning(f"Unknown status {status} for adapter {model_id}")
elif status not in FINAL_STATUSES:
# Status is known to us, but not a final status, so raise a user error
raise RuntimeError(f"Adapter {model_id} has not completed training (status: {status})")

path = resp_json.get("adapterPath")
if not path:
raise RuntimeError(f"Adapter {model_id} is not yet available")

return path
else:
# Use legacy path only since new endpoint requires both name and version number.
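The new status handling above only hard-fails when the adapter is in a known, non-final state; unknown statuses are logged and allowed through so that newer or older control-plane responses do not break serving. A condensed, standalone sketch of that decision logic (function name and logging setup are illustrative):

```python
import logging
from typing import Optional

logger = logging.getLogger(__name__)

FINAL_STATUSES = {"completed", "errored", "canceled", "stopped"}
STATUSES = FINAL_STATUSES | {"pending", "queued", "training", "stopping"}

def check_adapter_ready(model_id: str, status: Optional[str], adapter_path: Optional[str]) -> str:
    if status not in STATUSES:
        # Unknown status: skip validation rather than rejecting the request.
        logger.warning(f"Unknown status {status} for adapter {model_id}")
    elif status not in FINAL_STATUSES:
        # Known but still in progress: surface a clear, user-facing error.
        raise RuntimeError(f"Adapter {model_id} has not completed training (status: {status})")

    if not adapter_path:
        raise RuntimeError(f"Adapter {model_id} is not yet available")
    return adapter_path
```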
26 changes: 13 additions & 13 deletions server/poetry.lock

Generated lockfile; diff not rendered.
